Do not add validation ops in tf.ragged.placeholder.
PiperOrigin-RevId: 331668447 Change-Id: Ia4e3ee27ff7cffd2ba5743468fe0204f2a7b3d2d
This commit is contained in:
parent
40fb595cc2
commit
4b28727644
tensorflow/python/ops/ragged
@ -346,5 +346,6 @@ def placeholder(dtype, ragged_rank, value_shape=None, name=None):
|
||||
for i in reversed(range(ragged_rank)):
|
||||
row_splits = array_ops.placeholder(dtypes.int64, [None],
|
||||
"row_splits_%d" % i)
|
||||
result = ragged_tensor.RaggedTensor.from_row_splits(result, row_splits)
|
||||
result = ragged_tensor.RaggedTensor.from_row_splits(result, row_splits,
|
||||
validate=False)
|
||||
return result
|
||||
|
@ -21,6 +21,7 @@ from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
@ -34,30 +35,20 @@ class RaggedPlaceholderOpTest(test_util.TensorFlowTestCase,
|
||||
# dtype, ragged_rank, value_shape, name -> expected
|
||||
(dtypes.int32, 0, [5], None,
|
||||
'Tensor("Placeholder:0", shape=(5,), dtype=int32)'),
|
||||
(dtypes.int32, 1, [], 'ph',
|
||||
'tf.RaggedTensor('
|
||||
(dtypes.int32, 1, [], 'ph', 'tf.RaggedTensor('
|
||||
'values=Tensor("ph/flat_values:0", shape=(None,), dtype=int32), '
|
||||
'row_splits=Tensor("ph/RaggedFromRowSplits/control_dependency:0", '
|
||||
'shape=(None,), dtype=int64))'),
|
||||
(dtypes.string, 1, [5], 'ph',
|
||||
'tf.RaggedTensor('
|
||||
'row_splits=Tensor("ph/row_splits_0:0", shape=(None,), dtype=int64))'),
|
||||
(dtypes.string, 1, [5], 'ph', 'tf.RaggedTensor('
|
||||
'values=Tensor("ph/flat_values:0", shape=(None, 5), dtype=string), '
|
||||
'row_splits=Tensor("ph/RaggedFromRowSplits/control_dependency:0", '
|
||||
'shape=(None,), dtype=int64))'),
|
||||
(dtypes.float32, 2, [], 'ph',
|
||||
'tf.RaggedTensor(values=tf.RaggedTensor('
|
||||
'row_splits=Tensor("ph/row_splits_0:0", shape=(None,), dtype=int64))'),
|
||||
(dtypes.float32, 2, [], 'ph', 'tf.RaggedTensor(values=tf.RaggedTensor('
|
||||
'values=Tensor("ph/flat_values:0", shape=(None,), dtype=float32), '
|
||||
'row_splits=Tensor("ph/RaggedFromRowSplits/control_dependency:0", '
|
||||
'shape=(None,), dtype=int64)), '
|
||||
'row_splits=Tensor("ph/RaggedFromRowSplits_1/control_dependency:0", '
|
||||
'shape=(None,), dtype=int64))'),
|
||||
(dtypes.int32, 2, [3, 5], 'ph',
|
||||
'tf.RaggedTensor(values=tf.RaggedTensor('
|
||||
'row_splits=Tensor("ph/row_splits_1:0", shape=(None,), dtype=int64)), '
|
||||
'row_splits=Tensor("ph/row_splits_0:0", shape=(None,), dtype=int64))'),
|
||||
(dtypes.int32, 2, [3, 5], 'ph', 'tf.RaggedTensor(values=tf.RaggedTensor('
|
||||
'values=Tensor("ph/flat_values:0", shape=(None, 3, 5), dtype=int32), '
|
||||
'row_splits=Tensor("ph/RaggedFromRowSplits/control_dependency:0", '
|
||||
'shape=(None,), dtype=int64)), '
|
||||
'row_splits=Tensor("ph/RaggedFromRowSplits_1/control_dependency:0", '
|
||||
'shape=(None,), dtype=int64))'),
|
||||
'row_splits=Tensor("ph/row_splits_1:0", shape=(None,), dtype=int64)), '
|
||||
'row_splits=Tensor("ph/row_splits_0:0", shape=(None,), dtype=int64))'),
|
||||
])
|
||||
def testRaggedPlaceholder(self, dtype, ragged_rank, value_shape, name,
|
||||
expected):
|
||||
@ -72,6 +63,16 @@ class RaggedPlaceholderOpTest(test_util.TensorFlowTestCase,
|
||||
with self.assertRaises(RuntimeError):
|
||||
ragged_factory_ops.placeholder(dtypes.int32, 1, [])
|
||||
|
||||
def testRaggedPlaceholderDoesNotIncludeValidationOps(self):
|
||||
if context.executing_eagerly():
|
||||
return
|
||||
graph = ops.Graph()
|
||||
with graph.as_default():
|
||||
ragged_factory_ops.placeholder(
|
||||
dtypes.float32, ragged_rank=1, value_shape=[])
|
||||
self.assertEqual([op.type for op in graph.get_operations()],
|
||||
['Placeholder', 'Placeholder'])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
googletest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user