Do not add validation ops in tf.ragged.placeholder.

PiperOrigin-RevId: 331668447
Change-Id: Ia4e3ee27ff7cffd2ba5743468fe0204f2a7b3d2d
This commit is contained in:
Edward Loper 2020-09-14 17:51:17 -07:00 committed by TensorFlower Gardener
parent 40fb595cc2
commit 4b28727644
2 changed files with 23 additions and 21 deletions

View File

@ -346,5 +346,6 @@ def placeholder(dtype, ragged_rank, value_shape=None, name=None):
for i in reversed(range(ragged_rank)):
row_splits = array_ops.placeholder(dtypes.int64, [None],
"row_splits_%d" % i)
result = ragged_tensor.RaggedTensor.from_row_splits(result, row_splits)
result = ragged_tensor.RaggedTensor.from_row_splits(result, row_splits,
validate=False)
return result

View File

@ -21,6 +21,7 @@ from absl.testing import parameterized
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import googletest
@ -34,30 +35,20 @@ class RaggedPlaceholderOpTest(test_util.TensorFlowTestCase,
# dtype, ragged_rank, value_shape, name -> expected
(dtypes.int32, 0, [5], None,
'Tensor("Placeholder:0", shape=(5,), dtype=int32)'),
(dtypes.int32, 1, [], 'ph',
'tf.RaggedTensor('
(dtypes.int32, 1, [], 'ph', 'tf.RaggedTensor('
'values=Tensor("ph/flat_values:0", shape=(None,), dtype=int32), '
'row_splits=Tensor("ph/RaggedFromRowSplits/control_dependency:0", '
'shape=(None,), dtype=int64))'),
(dtypes.string, 1, [5], 'ph',
'tf.RaggedTensor('
'row_splits=Tensor("ph/row_splits_0:0", shape=(None,), dtype=int64))'),
(dtypes.string, 1, [5], 'ph', 'tf.RaggedTensor('
'values=Tensor("ph/flat_values:0", shape=(None, 5), dtype=string), '
'row_splits=Tensor("ph/RaggedFromRowSplits/control_dependency:0", '
'shape=(None,), dtype=int64))'),
(dtypes.float32, 2, [], 'ph',
'tf.RaggedTensor(values=tf.RaggedTensor('
'row_splits=Tensor("ph/row_splits_0:0", shape=(None,), dtype=int64))'),
(dtypes.float32, 2, [], 'ph', 'tf.RaggedTensor(values=tf.RaggedTensor('
'values=Tensor("ph/flat_values:0", shape=(None,), dtype=float32), '
'row_splits=Tensor("ph/RaggedFromRowSplits/control_dependency:0", '
'shape=(None,), dtype=int64)), '
'row_splits=Tensor("ph/RaggedFromRowSplits_1/control_dependency:0", '
'shape=(None,), dtype=int64))'),
(dtypes.int32, 2, [3, 5], 'ph',
'tf.RaggedTensor(values=tf.RaggedTensor('
'row_splits=Tensor("ph/row_splits_1:0", shape=(None,), dtype=int64)), '
'row_splits=Tensor("ph/row_splits_0:0", shape=(None,), dtype=int64))'),
(dtypes.int32, 2, [3, 5], 'ph', 'tf.RaggedTensor(values=tf.RaggedTensor('
'values=Tensor("ph/flat_values:0", shape=(None, 3, 5), dtype=int32), '
'row_splits=Tensor("ph/RaggedFromRowSplits/control_dependency:0", '
'shape=(None,), dtype=int64)), '
'row_splits=Tensor("ph/RaggedFromRowSplits_1/control_dependency:0", '
'shape=(None,), dtype=int64))'),
'row_splits=Tensor("ph/row_splits_1:0", shape=(None,), dtype=int64)), '
'row_splits=Tensor("ph/row_splits_0:0", shape=(None,), dtype=int64))'),
])
def testRaggedPlaceholder(self, dtype, ragged_rank, value_shape, name,
expected):
@ -72,6 +63,16 @@ class RaggedPlaceholderOpTest(test_util.TensorFlowTestCase,
with self.assertRaises(RuntimeError):
ragged_factory_ops.placeholder(dtypes.int32, 1, [])
def testRaggedPlaceholderDoesNotIncludeValidationOps(self):
if context.executing_eagerly():
return
graph = ops.Graph()
with graph.as_default():
ragged_factory_ops.placeholder(
dtypes.float32, ragged_rank=1, value_shape=[])
self.assertEqual([op.type for op in graph.get_operations()],
['Placeholder', 'Placeholder'])
if __name__ == '__main__':
googletest.main()