In RaggedTensor dispatchers for elementwise ops, convert RaggedTensorValue -> RaggedTensor.

PiperOrigin-RevId: 293807568
Change-Id: I839e2d510e84253d2d9a4a5325848d8471a2f74b
This commit is contained in:
Edward Loper 2020-02-07 07:08:30 -08:00 committed by TensorFlower Gardener
parent 99ec314b06
commit bb275635d1
2 changed files with 19 additions and 20 deletions

View File

@ -131,6 +131,10 @@ class UnaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
elif not _is_convertible_to_tensor(elt):
return self.NOT_SUPPORTED
if found_ragged:
x = [
ragged_tensor.convert_to_tensor_or_ragged_tensor(elt)
if ragged_tensor.is_ragged(elt) else elt for elt in x
]
x = ragged_tensor.match_row_splits_dtypes(*x)
nested_splits_lists = [
elt.nested_row_splits for elt in x if ragged_tensor.is_ragged(elt)
@ -149,6 +153,7 @@ class UnaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
else:
found_ragged = ragged_tensor.is_ragged(x)
if found_ragged:
x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name=self._x)
mapped_values = self._original_op(x.flat_values, *args, **kwargs)
return x.with_flat_values(mapped_values)
else:
@ -196,10 +201,10 @@ class BinaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
# Convert args to tensors. Bail if conversion fails.
try:
if not x_is_ragged:
x = ops.convert_to_tensor(x, name=self._x, preferred_dtype=y.dtype)
if not y_is_ragged:
y = ops.convert_to_tensor(y, name=self._y, preferred_dtype=x.dtype)
x = ragged_tensor.convert_to_tensor_or_ragged_tensor(
x, name=self._x, preferred_dtype=(y.dtype if y_is_ragged else None))
y = ragged_tensor.convert_to_tensor_or_ragged_tensor(
y, name=self._y, preferred_dtype=(x.dtype if x_is_ragged else None))
except (TypeError, ValueError):
return self.NOT_SUPPORTED

View File

@ -139,14 +139,15 @@ BINARY_INT_OPS = [
]
# pylint: disable=g-complex-comprehension
@test_util.run_all_in_graph_and_eager_modes
class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
def assertSameShape(self, x, y):
"""Checks that x and y have the same shape (including ragged shapes)."""
if isinstance(x, ragged_tensor.RaggedTensor):
self.assertIsInstance(y, ragged_tensor.RaggedTensor)
if ragged_tensor.is_ragged(x):
self.assertTrue(ragged_tensor.is_ragged(y))
self.assertEqual(x.ragged_rank, y.ragged_rank)
for (x_splits, y_splits) in zip(x.nested_row_splits, y.nested_row_splits):
self.assertAllEqual(x_splits, y_splits)
@ -234,18 +235,17 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
]
) # pyformat: disable
def testUnaryElementwiseOp(self, x, op=math_ops.abs, **extra_args):
x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x)
result = op(x, **extra_args)
# Run the wrapped op on the dense values, for comparison.
dense_x = x.flat_values if isinstance(x, ragged_tensor.RaggedTensor) else x
dense_x = x.flat_values if ragged_tensor.is_ragged(x) else x
expected_flat_values = array_ops.reshape(op(dense_x, **extra_args), [-1])
# Check that the result has the expected shape.
self.assertSameShape(x, result)
# Check that the result has the expected (flattened) values.
if isinstance(result, ragged_tensor.RaggedTensor):
if ragged_tensor.is_ragged(result):
result_flat_values = array_ops.reshape(result.flat_values, [-1])
else:
result_flat_values = array_ops.reshape(result, [-1])
@ -350,8 +350,6 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
) # pyformat: disable
def testBinaryElementwiseOp(self, x, y, op=math_ops.add, **extra_args):
use_kwargs = extra_args.pop('use_kwargs', ())
x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x)
y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y)
if 'x' in use_kwargs and 'y' in use_kwargs:
result = op(x=x, y=y, **extra_args)
elif 'y' in use_kwargs:
@ -360,8 +358,8 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
result = op(x, y, **extra_args)
# Run the wrapped op on the dense values, for comparison.
dense_x = x.flat_values if isinstance(x, ragged_tensor.RaggedTensor) else x
dense_y = y.flat_values if isinstance(y, ragged_tensor.RaggedTensor) else y
dense_x = x.flat_values if ragged_tensor.is_ragged(x) else x
dense_y = y.flat_values if ragged_tensor.is_ragged(y) else y
expected_flat_values = array_ops.reshape(
op(dense_x, dense_y, **extra_args), [-1])
@ -369,7 +367,7 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
self.assertSameShape(y, result)
# Check that the result has the expected (flattened) values.
if isinstance(result, ragged_tensor.RaggedTensor):
if ragged_tensor.is_ragged(result):
result_flat_values = array_ops.reshape(result.flat_values, [-1])
else:
result_flat_values = array_ops.reshape(result, [-1])
@ -415,9 +413,6 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
def testListValuedElementwiseOp(self, inputs, op=math_ops.add_n,
**extra_args):
use_kwargs = extra_args.pop('use_kwargs', False)
inputs = [
ragged_tensor.convert_to_tensor_or_ragged_tensor(x) for x in inputs
]
if use_kwargs:
result = op(inputs=inputs, **extra_args)
else:
@ -425,8 +420,7 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
# Run the wrapped op on the dense values, for comparison.
dense_inputs = [
x.flat_values if isinstance(x, ragged_tensor.RaggedTensor) else x
for x in inputs
x.flat_values if ragged_tensor.is_ragged(x) else x for x in inputs
]
expected_flat_values = array_ops.reshape(
op(dense_inputs, **extra_args), [-1])
@ -435,7 +429,7 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
self.assertSameShape(inputs[0], result)
# Check that the result has the expected (flattened) values.
if isinstance(result, ragged_tensor.RaggedTensor):
if ragged_tensor.is_ragged(result):
result_flat_values = array_ops.reshape(result.flat_values, [-1])
else:
result_flat_values = array_ops.reshape(result, [-1])