Account for values shape when computing splits in variant -> ragged.
PiperOrigin-RevId: 253614932
This commit is contained in:
parent
93ebb3fb4a
commit
84cf4d436e
@ -138,7 +138,7 @@ Status NestedStackRaggedTensors(
|
||||
output_ragged->nested_splits[dims - 1].vec<SPLIT_TYPE>();
|
||||
dims_splits_vec(0) = 0;
|
||||
for (int i = 0; i < ragged_components.size(); i++) {
|
||||
int split_val = ragged_components[i].values.NumElements();
|
||||
int split_val = ragged_components[i].values.shape().dim_size(0);
|
||||
if (input_ragged_rank != 0 && !ragged_components[i].nested_splits.empty()) {
|
||||
split_val = ragged_components[i].nested_splits[0].NumElements() - 1;
|
||||
}
|
||||
|
@ -691,5 +691,48 @@ TEST_F(RaggedTensorFromVariantKernelTest, ShapeFnTest) {
|
||||
INFER_ERROR("Shape must be rank 3 but is rank 2", op, "[?,?]");
|
||||
INFER_OK(op, "[?,?,?]", "[?];[?];[?];[?];[?];[?];?");
|
||||
}
|
||||
|
||||
TEST_F(RaggedTensorFromVariantKernelTest, 2DValuesTensorIn1DOut) {
|
||||
// [
|
||||
// [
|
||||
// [[x, x], [x, x]],
|
||||
// [[x, x], [x, x]]
|
||||
// ],
|
||||
// [[[x, x], [x, x]]],
|
||||
// [],
|
||||
// [
|
||||
// [[x, x], [x, x]],
|
||||
// [[x, x], [x, x]]
|
||||
// ]
|
||||
// ]
|
||||
const std::vector<int64> batched_splits_1 = {0, 2, 3, 3, 5};
|
||||
const std::vector<int> batched_values = {1, 1, 1, 1, 2, 2, 2, 2, 3, 3,
|
||||
3, 3, 4, 4, 4, 4, 5, 5, 5, 5};
|
||||
|
||||
Tensor variant_component_1 = CreateVariantFromRagged<int, int64>(
|
||||
{}, TensorShape({2, 2, 2}), {1, 1, 1, 1, 2, 2, 2, 2});
|
||||
Tensor variant_component_2 = CreateVariantFromRagged<int, int64>(
|
||||
{}, TensorShape({1, 2, 2}), {3, 3, 3, 3});
|
||||
Tensor variant_component_3 =
|
||||
CreateVariantFromRagged<int, int64>({}, TensorShape({0, 2, 2}), {});
|
||||
Tensor variant_component_4 = CreateVariantFromRagged<int, int64>(
|
||||
{}, TensorShape({2, 2, 2}), {4, 4, 4, 4, 5, 5, 5, 5});
|
||||
|
||||
Tensor expected_splits_1(DT_INT64, TensorShape({5}));
|
||||
Tensor expected_values(DT_INT32, TensorShape({5, 2, 2}));
|
||||
test::FillValues<int64>(&expected_splits_1, batched_splits_1);
|
||||
test::FillValues<int>(&expected_values, batched_values);
|
||||
|
||||
int input_ragged_rank = 0;
|
||||
int output_ragged_rank = 1;
|
||||
BuildDecodeRaggedTensorGraph<int, int64>(
|
||||
input_ragged_rank, output_ragged_rank, TensorShape({4}),
|
||||
{variant_component_1, variant_component_2, variant_component_3,
|
||||
variant_component_4});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
test::ExpectTensorEqual<int64>(*GetOutput(0), expected_splits_1);
|
||||
test::ExpectTensorEqual<int>(*GetOutput(1), expected_values);
|
||||
}
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user