Account for values shape when computing splits in variant -> ragged.

PiperOrigin-RevId: 253614932
This commit is contained in:
Gaurav Mishra 2019-06-17 10:56:45 -07:00 committed by TensorFlower Gardener
parent 93ebb3fb4a
commit 84cf4d436e
2 changed files with 44 additions and 1 deletions

View File

@ -138,7 +138,7 @@ Status NestedStackRaggedTensors(
output_ragged->nested_splits[dims - 1].vec<SPLIT_TYPE>();
dims_splits_vec(0) = 0;
for (int i = 0; i < ragged_components.size(); i++) {
int split_val = ragged_components[i].values.NumElements();
int split_val = ragged_components[i].values.shape().dim_size(0);
if (input_ragged_rank != 0 && !ragged_components[i].nested_splits.empty()) {
split_val = ragged_components[i].nested_splits[0].NumElements() - 1;
}

View File

@ -691,5 +691,48 @@ TEST_F(RaggedTensorFromVariantKernelTest, ShapeFnTest) {
INFER_ERROR("Shape must be rank 3 but is rank 2", op, "[?,?]");
INFER_OK(op, "[?,?,?]", "[?];[?];[?];[?];[?];[?];?");
}
TEST_F(RaggedTensorFromVariantKernelTest, 2DValuesTensorIn1DOut) {
// [
// [
// [[x, x], [x, x]],
// [[x, x], [x, x]]
// ],
// [[[x, x], [x, x]]],
// [],
// [
// [[x, x], [x, x]],
// [[x, x], [x, x]]
// ]
// ]
const std::vector<int64> batched_splits_1 = {0, 2, 3, 3, 5};
const std::vector<int> batched_values = {1, 1, 1, 1, 2, 2, 2, 2, 3, 3,
3, 3, 4, 4, 4, 4, 5, 5, 5, 5};
Tensor variant_component_1 = CreateVariantFromRagged<int, int64>(
{}, TensorShape({2, 2, 2}), {1, 1, 1, 1, 2, 2, 2, 2});
Tensor variant_component_2 = CreateVariantFromRagged<int, int64>(
{}, TensorShape({1, 2, 2}), {3, 3, 3, 3});
Tensor variant_component_3 =
CreateVariantFromRagged<int, int64>({}, TensorShape({0, 2, 2}), {});
Tensor variant_component_4 = CreateVariantFromRagged<int, int64>(
{}, TensorShape({2, 2, 2}), {4, 4, 4, 4, 5, 5, 5, 5});
Tensor expected_splits_1(DT_INT64, TensorShape({5}));
Tensor expected_values(DT_INT32, TensorShape({5, 2, 2}));
test::FillValues<int64>(&expected_splits_1, batched_splits_1);
test::FillValues<int>(&expected_values, batched_values);
int input_ragged_rank = 0;
int output_ragged_rank = 1;
BuildDecodeRaggedTensorGraph<int, int64>(
input_ragged_rank, output_ragged_rank, TensorShape({4}),
{variant_component_1, variant_component_2, variant_component_3,
variant_component_4});
TF_ASSERT_OK(RunOpKernel());
test::ExpectTensorEqual<int64>(*GetOutput(0), expected_splits_1);
test::ExpectTensorEqual<int>(*GetOutput(1), expected_values);
}
} // namespace
} // namespace tensorflow