In RaggedTensor dispatchers for elementwise ops, convert RaggedTensorValue -> RaggedTensor.

PiperOrigin-RevId: 293807568
Change-Id: I839e2d510e84253d2d9a4a5325848d8471a2f74b
This commit is contained in:
Edward Loper 2020-02-07 07:08:30 -08:00 committed by TensorFlower Gardener
parent 99ec314b06
commit bb275635d1
2 changed files with 19 additions and 20 deletions

View File

@ -131,6 +131,10 @@ class UnaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
elif not _is_convertible_to_tensor(elt): elif not _is_convertible_to_tensor(elt):
return self.NOT_SUPPORTED return self.NOT_SUPPORTED
if found_ragged: if found_ragged:
x = [
ragged_tensor.convert_to_tensor_or_ragged_tensor(elt)
if ragged_tensor.is_ragged(elt) else elt for elt in x
]
x = ragged_tensor.match_row_splits_dtypes(*x) x = ragged_tensor.match_row_splits_dtypes(*x)
nested_splits_lists = [ nested_splits_lists = [
elt.nested_row_splits for elt in x if ragged_tensor.is_ragged(elt) elt.nested_row_splits for elt in x if ragged_tensor.is_ragged(elt)
@ -149,6 +153,7 @@ class UnaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
else: else:
found_ragged = ragged_tensor.is_ragged(x) found_ragged = ragged_tensor.is_ragged(x)
if found_ragged: if found_ragged:
x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name=self._x)
mapped_values = self._original_op(x.flat_values, *args, **kwargs) mapped_values = self._original_op(x.flat_values, *args, **kwargs)
return x.with_flat_values(mapped_values) return x.with_flat_values(mapped_values)
else: else:
@ -196,10 +201,10 @@ class BinaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
# Convert args to tensors. Bail if conversion fails. # Convert args to tensors. Bail if conversion fails.
try: try:
if not x_is_ragged: x = ragged_tensor.convert_to_tensor_or_ragged_tensor(
x = ops.convert_to_tensor(x, name=self._x, preferred_dtype=y.dtype) x, name=self._x, preferred_dtype=(y.dtype if y_is_ragged else None))
if not y_is_ragged: y = ragged_tensor.convert_to_tensor_or_ragged_tensor(
y = ops.convert_to_tensor(y, name=self._y, preferred_dtype=x.dtype) y, name=self._y, preferred_dtype=(x.dtype if x_is_ragged else None))
except (TypeError, ValueError): except (TypeError, ValueError):
return self.NOT_SUPPORTED return self.NOT_SUPPORTED

View File

@ -139,14 +139,15 @@ BINARY_INT_OPS = [
] ]
# pylint: disable=g-complex-comprehension
@test_util.run_all_in_graph_and_eager_modes @test_util.run_all_in_graph_and_eager_modes
class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase, class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
parameterized.TestCase): parameterized.TestCase):
def assertSameShape(self, x, y): def assertSameShape(self, x, y):
"""Checks that x and y have the same shape (including ragged shapes).""" """Checks that x and y have the same shape (including ragged shapes)."""
if isinstance(x, ragged_tensor.RaggedTensor): if ragged_tensor.is_ragged(x):
self.assertIsInstance(y, ragged_tensor.RaggedTensor) self.assertTrue(ragged_tensor.is_ragged(y))
self.assertEqual(x.ragged_rank, y.ragged_rank) self.assertEqual(x.ragged_rank, y.ragged_rank)
for (x_splits, y_splits) in zip(x.nested_row_splits, y.nested_row_splits): for (x_splits, y_splits) in zip(x.nested_row_splits, y.nested_row_splits):
self.assertAllEqual(x_splits, y_splits) self.assertAllEqual(x_splits, y_splits)
@ -234,18 +235,17 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
] ]
) # pyformat: disable ) # pyformat: disable
def testUnaryElementwiseOp(self, x, op=math_ops.abs, **extra_args): def testUnaryElementwiseOp(self, x, op=math_ops.abs, **extra_args):
x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x)
result = op(x, **extra_args) result = op(x, **extra_args)
# Run the wrapped op on the dense values, for comparison. # Run the wrapped op on the dense values, for comparison.
dense_x = x.flat_values if isinstance(x, ragged_tensor.RaggedTensor) else x dense_x = x.flat_values if ragged_tensor.is_ragged(x) else x
expected_flat_values = array_ops.reshape(op(dense_x, **extra_args), [-1]) expected_flat_values = array_ops.reshape(op(dense_x, **extra_args), [-1])
# Check that the result has the expected shape. # Check that the result has the expected shape.
self.assertSameShape(x, result) self.assertSameShape(x, result)
# Check that the result has the expected (flattened) values. # Check that the result has the expected (flattened) values.
if isinstance(result, ragged_tensor.RaggedTensor): if ragged_tensor.is_ragged(result):
result_flat_values = array_ops.reshape(result.flat_values, [-1]) result_flat_values = array_ops.reshape(result.flat_values, [-1])
else: else:
result_flat_values = array_ops.reshape(result, [-1]) result_flat_values = array_ops.reshape(result, [-1])
@ -350,8 +350,6 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
) # pyformat: disable ) # pyformat: disable
def testBinaryElementwiseOp(self, x, y, op=math_ops.add, **extra_args): def testBinaryElementwiseOp(self, x, y, op=math_ops.add, **extra_args):
use_kwargs = extra_args.pop('use_kwargs', ()) use_kwargs = extra_args.pop('use_kwargs', ())
x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x)
y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y)
if 'x' in use_kwargs and 'y' in use_kwargs: if 'x' in use_kwargs and 'y' in use_kwargs:
result = op(x=x, y=y, **extra_args) result = op(x=x, y=y, **extra_args)
elif 'y' in use_kwargs: elif 'y' in use_kwargs:
@ -360,8 +358,8 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
result = op(x, y, **extra_args) result = op(x, y, **extra_args)
# Run the wrapped op on the dense values, for comparison. # Run the wrapped op on the dense values, for comparison.
dense_x = x.flat_values if isinstance(x, ragged_tensor.RaggedTensor) else x dense_x = x.flat_values if ragged_tensor.is_ragged(x) else x
dense_y = y.flat_values if isinstance(y, ragged_tensor.RaggedTensor) else y dense_y = y.flat_values if ragged_tensor.is_ragged(y) else y
expected_flat_values = array_ops.reshape( expected_flat_values = array_ops.reshape(
op(dense_x, dense_y, **extra_args), [-1]) op(dense_x, dense_y, **extra_args), [-1])
@ -369,7 +367,7 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
self.assertSameShape(y, result) self.assertSameShape(y, result)
# Check that the result has the expected (flattened) values. # Check that the result has the expected (flattened) values.
if isinstance(result, ragged_tensor.RaggedTensor): if ragged_tensor.is_ragged(result):
result_flat_values = array_ops.reshape(result.flat_values, [-1]) result_flat_values = array_ops.reshape(result.flat_values, [-1])
else: else:
result_flat_values = array_ops.reshape(result, [-1]) result_flat_values = array_ops.reshape(result, [-1])
@ -415,9 +413,6 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
def testListValuedElementwiseOp(self, inputs, op=math_ops.add_n, def testListValuedElementwiseOp(self, inputs, op=math_ops.add_n,
**extra_args): **extra_args):
use_kwargs = extra_args.pop('use_kwargs', False) use_kwargs = extra_args.pop('use_kwargs', False)
inputs = [
ragged_tensor.convert_to_tensor_or_ragged_tensor(x) for x in inputs
]
if use_kwargs: if use_kwargs:
result = op(inputs=inputs, **extra_args) result = op(inputs=inputs, **extra_args)
else: else:
@ -425,8 +420,7 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
# Run the wrapped op on the dense values, for comparison. # Run the wrapped op on the dense values, for comparison.
dense_inputs = [ dense_inputs = [
x.flat_values if isinstance(x, ragged_tensor.RaggedTensor) else x x.flat_values if ragged_tensor.is_ragged(x) else x for x in inputs
for x in inputs
] ]
expected_flat_values = array_ops.reshape( expected_flat_values = array_ops.reshape(
op(dense_inputs, **extra_args), [-1]) op(dense_inputs, **extra_args), [-1])
@ -435,7 +429,7 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
self.assertSameShape(inputs[0], result) self.assertSameShape(inputs[0], result)
# Check that the result has the expected (flattened) values. # Check that the result has the expected (flattened) values.
if isinstance(result, ragged_tensor.RaggedTensor): if ragged_tensor.is_ragged(result):
result_flat_values = array_ops.reshape(result.flat_values, [-1]) result_flat_values = array_ops.reshape(result.flat_values, [-1])
else: else:
result_flat_values = array_ops.reshape(result, [-1]) result_flat_values = array_ops.reshape(result, [-1])