From 261fef924c12b8b0e7e95d1197cd8d50f4c641ec Mon Sep 17 00:00:00 2001 From: Edward Loper Date: Wed, 2 Oct 2019 13:06:18 -0700 Subject: [PATCH] Improve handling of variant-encoded ragged tensors with ragged_rank=0. This fixes a bug that was preventing tf.data.Dataset.from_tensor_slices from being used with ragged tensors that had ragged_rank=1. PiperOrigin-RevId: 272502146 --- .../kernels/ragged_tensor_to_variant_op.cc | 1 - .../ragged_tensor_to_variant_op_test.cc | 34 ++++++------ tensorflow/core/ops/ragged_conversion_ops.cc | 12 +++-- tensorflow/python/ops/ragged/ragged_tensor.py | 34 +++++++++--- .../python/ops/ragged/ragged_tensor_test.py | 53 ++++++++++++++++++- .../api/golden/v1/tensorflow.raw_ops.pbtxt | 2 +- .../api/golden/v2/tensorflow.raw_ops.pbtxt | 2 +- 7 files changed, 108 insertions(+), 30 deletions(-) diff --git a/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc b/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc index c9f09796239..7a5ae1c6240 100644 --- a/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc +++ b/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc @@ -152,7 +152,6 @@ class RaggedTensorToVariantOp : public OpKernel { OP_REQUIRES_OK(context, context->input_list("rt_nested_splits", &ragged_nested_splits_in)); const int ragged_nested_splits_len = ragged_nested_splits_in.size(); - DCHECK_GT(ragged_nested_splits_len, 0); // Enforced by REGISTER_OP. RaggedTensor batched_ragged_input; // Read ragged_values input. batched_ragged_input.values = context->input(ragged_nested_splits_len); diff --git a/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.cc b/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.cc index 2854044d19a..1cc2353d50a 100644 --- a/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.cc +++ b/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.cc @@ -525,7 +525,10 @@ TEST_F(RaggedTensorToVariantKernelTest, ShapeFnTestBatched) { // Tests with len(ragged_splits)==0. (*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(0); - INFER_ERROR("Shape inference should have returned error", op, "?"); + INFER_ERROR( + "ragged_rank=0 is not currently supported " + "when batched_input=true.", + op, "?"); // Tests with len(ragged_splits)==1. (*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(1); @@ -563,7 +566,7 @@ TEST_F(RaggedTensorToVariantKernelTest, ShapeFnTestNotBatched) { // Tests with len(ragged_splits)==0. (*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(0); - INFER_ERROR("Shape inference should have returned error", op, "?"); + INFER_OK(op, "?", "[]"); // Tests with len(ragged_splits)==1. (*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(1); @@ -592,19 +595,20 @@ TEST_F(RaggedTensorToVariantKernelTest, ShapeFnTestNotBatched) { INFER_OK(op, "?;?;?;[5]", "[]"); } -TEST_F(RaggedTensorToVariantKernelTest, NoSplits) { - const auto dtype = DataTypeToEnum::v(); - TF_ASSERT_OK(NodeDefBuilder("tested_op", "RaggedTensorToVariant") - .Input(FakeInput(0)) - .Input(FakeInput(dtype)) - .Attr("RAGGED_RANK", 0) - .Attr("Tvalues", dtype) - .Attr("Tsplits", DT_INT64) - .Attr("batched_input", true) - .Finalize(node_def())); - EXPECT_TRUE(absl::StartsWith( - InitOp().error_message(), - "Value for attr 'RAGGED_RANK' of 0 must be at least minimum 1")); +TEST_F(RaggedTensorToVariantKernelTest, NonRaggedInput) { + const std::vector values = {1, 2, 3, 4, 5, 6}; + Tensor expected_values(DT_INT32, TensorShape({6})); + test::FillValues(&expected_values, values); + + BuildEncodeRaggedTensorGraph({}, TensorShape({6}), values, false); + TF_ASSERT_OK(RunOpKernel()); + + const auto& encoded_scalar = GetOutput(0)->scalar()(); + const Variant& encoded_values = + encoded_scalar.get()->vec()(0); + + test::ExpectTensorEqual(*encoded_values.get(), expected_values); } + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/ops/ragged_conversion_ops.cc b/tensorflow/core/ops/ragged_conversion_ops.cc index 78fa5db34b2..6bee189c85e 100644 --- a/tensorflow/core/ops/ragged_conversion_ops.cc +++ b/tensorflow/core/ops/ragged_conversion_ops.cc @@ -113,9 +113,9 @@ REGISTER_OP("RaggedTensorToVariant") .Input("rt_nested_splits: RAGGED_RANK * Tsplits") .Input("rt_dense_values: Tvalues") .Output("encoded_ragged: variant") - .Attr("RAGGED_RANK: int >= 1") + .Attr("RAGGED_RANK: int >= 0") .Attr("Tvalues: type") - .Attr("Tsplits: {int32, int64}") + .Attr("Tsplits: {int32, int64} = DT_INT64") .Attr("batched_input: bool") .SetShapeFn(RaggedTensorToVariantShapeFn); @@ -124,9 +124,9 @@ REGISTER_OP("RaggedTensorFromVariant") .Output("output_nested_splits: output_ragged_rank * Tsplits") .Output("output_dense_values: Tvalues") .Attr("input_ragged_rank: int >= -1") - .Attr("output_ragged_rank: int >= 1") + .Attr("output_ragged_rank: int >= 0") .Attr("Tvalues: type") - .Attr("Tsplits: {int32, int64}") + .Attr("Tsplits: {int32, int64} = DT_INT64") .SetShapeFn(RaggedTensorFromVariantShapeFn); REGISTER_OP("RaggedTensorToTensor") @@ -194,6 +194,10 @@ Status RaggedTensorToVariantShapeFn(InferenceContext* c) { } else { c->set_output(0, c->Scalar()); } + if (batched && num_splits == 0) { + return errors::InvalidArgument( + "ragged_rank=0 is not currently supported when batched_input=true."); + } return Status::OK(); } diff --git a/tensorflow/python/ops/ragged/ragged_tensor.py b/tensorflow/python/ops/ragged/ragged_tensor.py index 06e2590e408..da496eb70c8 100644 --- a/tensorflow/python/ops/ragged/ragged_tensor.py +++ b/tensorflow/python/ops/ragged/ragged_tensor.py @@ -2201,15 +2201,32 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec): return [tensor_spec.TensorSpec(None, dtypes.variant)] def _to_tensor_list(self, value): + ragged_rank = value.ragged_rank if isinstance(value, RaggedTensor) else 0 + if ragged_rank != self._ragged_rank: + raise ValueError("Ragged rank of value (%d) does not match ragged " + "rank of type (%d)" % (ragged_rank, self._ragged_rank)) + if ragged_rank == 0: + return [ + gen_ragged_conversion_ops.ragged_tensor_to_variant( + (), value, batched_input=False) + ] # pylint: disable=protected-access return [value._to_variant(batched_input=False)] def _to_batched_tensor_list(self, value): + ragged_rank = value.ragged_rank if isinstance(value, RaggedTensor) else 0 + if ragged_rank != self._ragged_rank: + raise ValueError("Ragged rank of value (%d) does not match ragged " + "rank of type (%d)" % (ragged_rank, self._ragged_rank)) + if ragged_rank == 0: + # TODO(b/141789000) Update this to handle ragged_rank=0. + raise ValueError( + "_to_batched_tensor_list doesn't support ragged_rank=0 yet") # pylint: disable=protected-access return [value._to_variant(batched_input=True)] def _from_compatible_tensor_list(self, tensor_list): - if self._ragged_rank <= 0: + if self._ragged_rank < 0: raise ValueError( "ragged_rank must be non-negative; got %s." % self._ragged_rank) result = RaggedTensor._from_variant( # pylint: disable=protected-access @@ -2217,12 +2234,15 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec): row_splits_dtype=self._row_splits_dtype, output_ragged_rank=self._ragged_rank) if self._shape.ndims is not None: - outer_dim = tensor_shape.dimension_value(self._shape[0]) - if outer_dim is not None: - result.row_splits.set_shape([outer_dim + 1]) - result.flat_values.set_shape( - tensor_shape.TensorShape([None]).concatenate( - self._shape[1 + self._ragged_rank:])) + if isinstance(result, RaggedTensor): + outer_dim = tensor_shape.dimension_value(self._shape[0]) + if outer_dim is not None: + result.row_splits.set_shape([outer_dim + 1]) + result.flat_values.set_shape( + tensor_shape.TensorShape([None]).concatenate( + self._shape[1 + self._ragged_rank:])) + else: + result.set_shape(self._shape) return result def _batch(self, batch_size): diff --git a/tensorflow/python/ops/ragged/ragged_tensor_test.py b/tensorflow/python/ops/ragged/ragged_tensor_test.py index 892735fc156..1d8b71dc18c 100644 --- a/tensorflow/python/ops/ragged/ragged_tensor_test.py +++ b/tensorflow/python/ops/ragged/ragged_tensor_test.py @@ -23,6 +23,7 @@ import re from absl.testing import parameterized import numpy as np +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -1700,6 +1701,27 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, [[6], [7]], [[8], [9]]]) self.assertAllEqual(decoded_rt, expected_rt) + def testUnbatchVariant(self): # b/141789000 + rt = ragged_factory_ops.constant([[1, 2, 3], [4, 5], [], [6, 7, 8, 9]]) + batched = rt._to_variant(batched_input=True) + for i in range(4): + row = RaggedTensor._from_variant( + batched[i], dtype=dtypes.int32, output_ragged_rank=0) + self.assertAllEqual(rt[i], row) + + def testUnbatchVariantInDataset(self): + rt = ragged_factory_ops.constant([[1, 2, 3], [4, 5], [], [6, 7, 8, 9]]) + ds = dataset_ops.Dataset.from_tensor_slices(rt) + if context.executing_eagerly(): + for i, value in enumerate(ds): + self.assertAllEqual(rt[i], value) + else: + it = dataset_ops.make_one_shot_iterator(ds) + out = it.get_next() + with self.cached_session() as sess: + for i in range(3): + self.assertAllEqual(sess.run(rt[i]), out) + def testFromVariantInvalidParams(self): rt = ragged_factory_ops.constant([[0], [1], [2], [3]]) batched_variant = rt._to_variant(batched_input=True) @@ -1847,12 +1869,19 @@ class RaggedTensorSpecTest(test_util.TensorFlowTestCase, self.assertEqual(rt_spec._flat_tensor_specs, [tensor_spec.TensorSpec(None, dtypes.variant)]) - @parameterized.parameters([ + @parameterized.named_parameters([ { + 'testcase_name': 'RaggedRank0', + 'rt_spec': RaggedTensorSpec(ragged_rank=0), + 'rt': [1.0, 2.0, 3.0], + }, + { + 'testcase_name': 'RaggedRank1', 'rt_spec': RaggedTensorSpec(ragged_rank=1), 'rt': [[1.0, 2.0], [3.0]] }, { + 'testcase_name': 'RaggedRank2', 'rt_spec': RaggedTensorSpec(shape=[2, None, None]), 'rt': [[[1.0, 2.0], [3.0]], [[], [4.0]]] }, @@ -1863,6 +1892,28 @@ class RaggedTensorSpecTest(test_util.TensorFlowTestCase, rt_reconstructed = rt_spec._from_tensor_list(tensor_list) self.assertAllEqual(rt, rt_reconstructed) + @parameterized.named_parameters([ + # TODO(b/141789000) Test ragged_rank=0 when support is added. + { + 'testcase_name': 'RaggedRank1', + 'rt_spec': RaggedTensorSpec(ragged_rank=1), + 'rt': [[1.0, 2.0], [3.0]] + }, + { + 'testcase_name': 'RaggedRank2', + 'rt_spec': RaggedTensorSpec(shape=[2, None, None]), + 'rt': [[[1.0, 2.0], [3.0]], [[], [4.0]]] + }, + ]) + def testToFromBatchedTensorList(self, rt_spec, rt): + rt = ragged_factory_ops.constant(rt) + tensor_list = rt_spec._to_batched_tensor_list(rt) + rt_reconstructed = rt_spec._from_tensor_list(tensor_list) + self.assertAllEqual(rt, rt_reconstructed) + first_row = rt_spec._unbatch()._from_tensor_list( + [t[0] for t in tensor_list]) + self.assertAllEqual(rt[0], first_row) + @parameterized.parameters([ (RaggedTensorSpec([2, None], dtypes.float32, 1), 32, RaggedTensorSpec([32, 2, None], dtypes.float32, 2)), diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 71422248dac..9d5cb659208 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -2882,7 +2882,7 @@ tf_module { } member_method { name: "RaggedTensorFromVariant" - argspec: "args=[\'encoded_ragged\', \'input_ragged_rank\', \'output_ragged_rank\', \'Tvalues\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'encoded_ragged\', \'input_ragged_rank\', \'output_ragged_rank\', \'Tvalues\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " } member_method { name: "RaggedTensorToSparse" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 71422248dac..9d5cb659208 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -2882,7 +2882,7 @@ tf_module { } member_method { name: "RaggedTensorFromVariant" - argspec: "args=[\'encoded_ragged\', \'input_ragged_rank\', \'output_ragged_rank\', \'Tvalues\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'encoded_ragged\', \'input_ragged_rank\', \'output_ragged_rank\', \'Tvalues\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " } member_method { name: "RaggedTensorToSparse"