From 4b28727644fa4265a58655ecac5b4035f069ab38 Mon Sep 17 00:00:00 2001 From: Edward Loper Date: Mon, 14 Sep 2020 17:51:17 -0700 Subject: [PATCH] Do not add validation ops in tf.ragged.placeholder. PiperOrigin-RevId: 331668447 Change-Id: Ia4e3ee27ff7cffd2ba5743468fe0204f2a7b3d2d --- .../python/ops/ragged/ragged_factory_ops.py | 3 +- .../ops/ragged/ragged_placeholder_op_test.py | 41 ++++++++++--------- 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/tensorflow/python/ops/ragged/ragged_factory_ops.py b/tensorflow/python/ops/ragged/ragged_factory_ops.py index 0513c6b690b..1d57187e518 100644 --- a/tensorflow/python/ops/ragged/ragged_factory_ops.py +++ b/tensorflow/python/ops/ragged/ragged_factory_ops.py @@ -346,5 +346,6 @@ def placeholder(dtype, ragged_rank, value_shape=None, name=None): for i in reversed(range(ragged_rank)): row_splits = array_ops.placeholder(dtypes.int64, [None], "row_splits_%d" % i) - result = ragged_tensor.RaggedTensor.from_row_splits(result, row_splits) + result = ragged_tensor.RaggedTensor.from_row_splits(result, row_splits, + validate=False) return result diff --git a/tensorflow/python/ops/ragged/ragged_placeholder_op_test.py b/tensorflow/python/ops/ragged/ragged_placeholder_op_test.py index d2261d408b3..cdad7d49205 100644 --- a/tensorflow/python/ops/ragged/ragged_placeholder_op_test.py +++ b/tensorflow/python/ops/ragged/ragged_placeholder_op_test.py @@ -21,6 +21,7 @@ from absl.testing import parameterized from tensorflow.python.eager import context from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.platform import googletest @@ -34,30 +35,20 @@ class RaggedPlaceholderOpTest(test_util.TensorFlowTestCase, # dtype, ragged_rank, value_shape, name -> expected (dtypes.int32, 0, [5], None, 'Tensor("Placeholder:0", shape=(5,), dtype=int32)'), - (dtypes.int32, 1, [], 'ph', - 'tf.RaggedTensor(' + (dtypes.int32, 1, [], 'ph', 'tf.RaggedTensor(' 'values=Tensor("ph/flat_values:0", shape=(None,), dtype=int32), ' - 'row_splits=Tensor("ph/RaggedFromRowSplits/control_dependency:0", ' - 'shape=(None,), dtype=int64))'), - (dtypes.string, 1, [5], 'ph', - 'tf.RaggedTensor(' + 'row_splits=Tensor("ph/row_splits_0:0", shape=(None,), dtype=int64))'), + (dtypes.string, 1, [5], 'ph', 'tf.RaggedTensor(' 'values=Tensor("ph/flat_values:0", shape=(None, 5), dtype=string), ' - 'row_splits=Tensor("ph/RaggedFromRowSplits/control_dependency:0", ' - 'shape=(None,), dtype=int64))'), - (dtypes.float32, 2, [], 'ph', - 'tf.RaggedTensor(values=tf.RaggedTensor(' + 'row_splits=Tensor("ph/row_splits_0:0", shape=(None,), dtype=int64))'), + (dtypes.float32, 2, [], 'ph', 'tf.RaggedTensor(values=tf.RaggedTensor(' 'values=Tensor("ph/flat_values:0", shape=(None,), dtype=float32), ' - 'row_splits=Tensor("ph/RaggedFromRowSplits/control_dependency:0", ' - 'shape=(None,), dtype=int64)), ' - 'row_splits=Tensor("ph/RaggedFromRowSplits_1/control_dependency:0", ' - 'shape=(None,), dtype=int64))'), - (dtypes.int32, 2, [3, 5], 'ph', - 'tf.RaggedTensor(values=tf.RaggedTensor(' + 'row_splits=Tensor("ph/row_splits_1:0", shape=(None,), dtype=int64)), ' + 'row_splits=Tensor("ph/row_splits_0:0", shape=(None,), dtype=int64))'), + (dtypes.int32, 2, [3, 5], 'ph', 'tf.RaggedTensor(values=tf.RaggedTensor(' 'values=Tensor("ph/flat_values:0", shape=(None, 3, 5), dtype=int32), ' - 'row_splits=Tensor("ph/RaggedFromRowSplits/control_dependency:0", ' - 'shape=(None,), dtype=int64)), ' - 'row_splits=Tensor("ph/RaggedFromRowSplits_1/control_dependency:0", ' - 'shape=(None,), dtype=int64))'), + 'row_splits=Tensor("ph/row_splits_1:0", shape=(None,), dtype=int64)), ' + 'row_splits=Tensor("ph/row_splits_0:0", shape=(None,), dtype=int64))'), ]) def testRaggedPlaceholder(self, dtype, ragged_rank, value_shape, name, expected): @@ -72,6 +63,16 @@ class RaggedPlaceholderOpTest(test_util.TensorFlowTestCase, with self.assertRaises(RuntimeError): ragged_factory_ops.placeholder(dtypes.int32, 1, []) + def testRaggedPlaceholderDoesNotIncludeValidationOps(self): + if context.executing_eagerly(): + return + graph = ops.Graph() + with graph.as_default(): + ragged_factory_ops.placeholder( + dtypes.float32, ragged_rank=1, value_shape=[]) + self.assertEqual([op.type for op in graph.get_operations()], + ['Placeholder', 'Placeholder']) + if __name__ == '__main__': googletest.main()