Do not add validation ops in tf.ragged.placeholder.
PiperOrigin-RevId: 331668447 Change-Id: Ia4e3ee27ff7cffd2ba5743468fe0204f2a7b3d2d
This commit is contained in:
parent
40fb595cc2
commit
4b28727644
@ -346,5 +346,6 @@ def placeholder(dtype, ragged_rank, value_shape=None, name=None):
|
|||||||
for i in reversed(range(ragged_rank)):
|
for i in reversed(range(ragged_rank)):
|
||||||
row_splits = array_ops.placeholder(dtypes.int64, [None],
|
row_splits = array_ops.placeholder(dtypes.int64, [None],
|
||||||
"row_splits_%d" % i)
|
"row_splits_%d" % i)
|
||||||
result = ragged_tensor.RaggedTensor.from_row_splits(result, row_splits)
|
result = ragged_tensor.RaggedTensor.from_row_splits(result, row_splits,
|
||||||
|
validate=False)
|
||||||
return result
|
return result
|
||||||
|
@ -21,6 +21,7 @@ from absl.testing import parameterized
|
|||||||
|
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||||
from tensorflow.python.platform import googletest
|
from tensorflow.python.platform import googletest
|
||||||
@ -34,30 +35,20 @@ class RaggedPlaceholderOpTest(test_util.TensorFlowTestCase,
|
|||||||
# dtype, ragged_rank, value_shape, name -> expected
|
# dtype, ragged_rank, value_shape, name -> expected
|
||||||
(dtypes.int32, 0, [5], None,
|
(dtypes.int32, 0, [5], None,
|
||||||
'Tensor("Placeholder:0", shape=(5,), dtype=int32)'),
|
'Tensor("Placeholder:0", shape=(5,), dtype=int32)'),
|
||||||
(dtypes.int32, 1, [], 'ph',
|
(dtypes.int32, 1, [], 'ph', 'tf.RaggedTensor('
|
||||||
'tf.RaggedTensor('
|
|
||||||
'values=Tensor("ph/flat_values:0", shape=(None,), dtype=int32), '
|
'values=Tensor("ph/flat_values:0", shape=(None,), dtype=int32), '
|
||||||
'row_splits=Tensor("ph/RaggedFromRowSplits/control_dependency:0", '
|
'row_splits=Tensor("ph/row_splits_0:0", shape=(None,), dtype=int64))'),
|
||||||
'shape=(None,), dtype=int64))'),
|
(dtypes.string, 1, [5], 'ph', 'tf.RaggedTensor('
|
||||||
(dtypes.string, 1, [5], 'ph',
|
|
||||||
'tf.RaggedTensor('
|
|
||||||
'values=Tensor("ph/flat_values:0", shape=(None, 5), dtype=string), '
|
'values=Tensor("ph/flat_values:0", shape=(None, 5), dtype=string), '
|
||||||
'row_splits=Tensor("ph/RaggedFromRowSplits/control_dependency:0", '
|
'row_splits=Tensor("ph/row_splits_0:0", shape=(None,), dtype=int64))'),
|
||||||
'shape=(None,), dtype=int64))'),
|
(dtypes.float32, 2, [], 'ph', 'tf.RaggedTensor(values=tf.RaggedTensor('
|
||||||
(dtypes.float32, 2, [], 'ph',
|
|
||||||
'tf.RaggedTensor(values=tf.RaggedTensor('
|
|
||||||
'values=Tensor("ph/flat_values:0", shape=(None,), dtype=float32), '
|
'values=Tensor("ph/flat_values:0", shape=(None,), dtype=float32), '
|
||||||
'row_splits=Tensor("ph/RaggedFromRowSplits/control_dependency:0", '
|
'row_splits=Tensor("ph/row_splits_1:0", shape=(None,), dtype=int64)), '
|
||||||
'shape=(None,), dtype=int64)), '
|
'row_splits=Tensor("ph/row_splits_0:0", shape=(None,), dtype=int64))'),
|
||||||
'row_splits=Tensor("ph/RaggedFromRowSplits_1/control_dependency:0", '
|
(dtypes.int32, 2, [3, 5], 'ph', 'tf.RaggedTensor(values=tf.RaggedTensor('
|
||||||
'shape=(None,), dtype=int64))'),
|
|
||||||
(dtypes.int32, 2, [3, 5], 'ph',
|
|
||||||
'tf.RaggedTensor(values=tf.RaggedTensor('
|
|
||||||
'values=Tensor("ph/flat_values:0", shape=(None, 3, 5), dtype=int32), '
|
'values=Tensor("ph/flat_values:0", shape=(None, 3, 5), dtype=int32), '
|
||||||
'row_splits=Tensor("ph/RaggedFromRowSplits/control_dependency:0", '
|
'row_splits=Tensor("ph/row_splits_1:0", shape=(None,), dtype=int64)), '
|
||||||
'shape=(None,), dtype=int64)), '
|
'row_splits=Tensor("ph/row_splits_0:0", shape=(None,), dtype=int64))'),
|
||||||
'row_splits=Tensor("ph/RaggedFromRowSplits_1/control_dependency:0", '
|
|
||||||
'shape=(None,), dtype=int64))'),
|
|
||||||
])
|
])
|
||||||
def testRaggedPlaceholder(self, dtype, ragged_rank, value_shape, name,
|
def testRaggedPlaceholder(self, dtype, ragged_rank, value_shape, name,
|
||||||
expected):
|
expected):
|
||||||
@ -72,6 +63,16 @@ class RaggedPlaceholderOpTest(test_util.TensorFlowTestCase,
|
|||||||
with self.assertRaises(RuntimeError):
|
with self.assertRaises(RuntimeError):
|
||||||
ragged_factory_ops.placeholder(dtypes.int32, 1, [])
|
ragged_factory_ops.placeholder(dtypes.int32, 1, [])
|
||||||
|
|
||||||
|
def testRaggedPlaceholderDoesNotIncludeValidationOps(self):
|
||||||
|
if context.executing_eagerly():
|
||||||
|
return
|
||||||
|
graph = ops.Graph()
|
||||||
|
with graph.as_default():
|
||||||
|
ragged_factory_ops.placeholder(
|
||||||
|
dtypes.float32, ragged_rank=1, value_shape=[])
|
||||||
|
self.assertEqual([op.type for op in graph.get_operations()],
|
||||||
|
['Placeholder', 'Placeholder'])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
googletest.main()
|
googletest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user