In RaggedTensor dispatchers for elementwise ops, convert RaggedTensorValue -> RaggedTensor.
PiperOrigin-RevId: 293807568 Change-Id: I839e2d510e84253d2d9a4a5325848d8471a2f74b
This commit is contained in:
parent
99ec314b06
commit
bb275635d1
@ -131,6 +131,10 @@ class UnaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
|
||||
elif not _is_convertible_to_tensor(elt):
|
||||
return self.NOT_SUPPORTED
|
||||
if found_ragged:
|
||||
x = [
|
||||
ragged_tensor.convert_to_tensor_or_ragged_tensor(elt)
|
||||
if ragged_tensor.is_ragged(elt) else elt for elt in x
|
||||
]
|
||||
x = ragged_tensor.match_row_splits_dtypes(*x)
|
||||
nested_splits_lists = [
|
||||
elt.nested_row_splits for elt in x if ragged_tensor.is_ragged(elt)
|
||||
@ -149,6 +153,7 @@ class UnaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
|
||||
else:
|
||||
found_ragged = ragged_tensor.is_ragged(x)
|
||||
if found_ragged:
|
||||
x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name=self._x)
|
||||
mapped_values = self._original_op(x.flat_values, *args, **kwargs)
|
||||
return x.with_flat_values(mapped_values)
|
||||
else:
|
||||
@ -196,10 +201,10 @@ class BinaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
|
||||
|
||||
# Convert args to tensors. Bail if conversion fails.
|
||||
try:
|
||||
if not x_is_ragged:
|
||||
x = ops.convert_to_tensor(x, name=self._x, preferred_dtype=y.dtype)
|
||||
if not y_is_ragged:
|
||||
y = ops.convert_to_tensor(y, name=self._y, preferred_dtype=x.dtype)
|
||||
x = ragged_tensor.convert_to_tensor_or_ragged_tensor(
|
||||
x, name=self._x, preferred_dtype=(y.dtype if y_is_ragged else None))
|
||||
y = ragged_tensor.convert_to_tensor_or_ragged_tensor(
|
||||
y, name=self._y, preferred_dtype=(x.dtype if x_is_ragged else None))
|
||||
except (TypeError, ValueError):
|
||||
return self.NOT_SUPPORTED
|
||||
|
||||
|
@ -139,14 +139,15 @@ BINARY_INT_OPS = [
|
||||
]
|
||||
|
||||
|
||||
# pylint: disable=g-complex-comprehension
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
def assertSameShape(self, x, y):
|
||||
"""Checks that x and y have the same shape (including ragged shapes)."""
|
||||
if isinstance(x, ragged_tensor.RaggedTensor):
|
||||
self.assertIsInstance(y, ragged_tensor.RaggedTensor)
|
||||
if ragged_tensor.is_ragged(x):
|
||||
self.assertTrue(ragged_tensor.is_ragged(y))
|
||||
self.assertEqual(x.ragged_rank, y.ragged_rank)
|
||||
for (x_splits, y_splits) in zip(x.nested_row_splits, y.nested_row_splits):
|
||||
self.assertAllEqual(x_splits, y_splits)
|
||||
@ -234,18 +235,17 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
||||
]
|
||||
) # pyformat: disable
|
||||
def testUnaryElementwiseOp(self, x, op=math_ops.abs, **extra_args):
|
||||
x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x)
|
||||
result = op(x, **extra_args)
|
||||
|
||||
# Run the wrapped op on the dense values, for comparison.
|
||||
dense_x = x.flat_values if isinstance(x, ragged_tensor.RaggedTensor) else x
|
||||
dense_x = x.flat_values if ragged_tensor.is_ragged(x) else x
|
||||
expected_flat_values = array_ops.reshape(op(dense_x, **extra_args), [-1])
|
||||
|
||||
# Check that the result has the expected shape.
|
||||
self.assertSameShape(x, result)
|
||||
|
||||
# Check that the result has the expected (flattened) values.
|
||||
if isinstance(result, ragged_tensor.RaggedTensor):
|
||||
if ragged_tensor.is_ragged(result):
|
||||
result_flat_values = array_ops.reshape(result.flat_values, [-1])
|
||||
else:
|
||||
result_flat_values = array_ops.reshape(result, [-1])
|
||||
@ -350,8 +350,6 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
||||
) # pyformat: disable
|
||||
def testBinaryElementwiseOp(self, x, y, op=math_ops.add, **extra_args):
|
||||
use_kwargs = extra_args.pop('use_kwargs', ())
|
||||
x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x)
|
||||
y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y)
|
||||
if 'x' in use_kwargs and 'y' in use_kwargs:
|
||||
result = op(x=x, y=y, **extra_args)
|
||||
elif 'y' in use_kwargs:
|
||||
@ -360,8 +358,8 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
||||
result = op(x, y, **extra_args)
|
||||
|
||||
# Run the wrapped op on the dense values, for comparison.
|
||||
dense_x = x.flat_values if isinstance(x, ragged_tensor.RaggedTensor) else x
|
||||
dense_y = y.flat_values if isinstance(y, ragged_tensor.RaggedTensor) else y
|
||||
dense_x = x.flat_values if ragged_tensor.is_ragged(x) else x
|
||||
dense_y = y.flat_values if ragged_tensor.is_ragged(y) else y
|
||||
expected_flat_values = array_ops.reshape(
|
||||
op(dense_x, dense_y, **extra_args), [-1])
|
||||
|
||||
@ -369,7 +367,7 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
||||
self.assertSameShape(y, result)
|
||||
|
||||
# Check that the result has the expected (flattened) values.
|
||||
if isinstance(result, ragged_tensor.RaggedTensor):
|
||||
if ragged_tensor.is_ragged(result):
|
||||
result_flat_values = array_ops.reshape(result.flat_values, [-1])
|
||||
else:
|
||||
result_flat_values = array_ops.reshape(result, [-1])
|
||||
@ -415,9 +413,6 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
||||
def testListValuedElementwiseOp(self, inputs, op=math_ops.add_n,
|
||||
**extra_args):
|
||||
use_kwargs = extra_args.pop('use_kwargs', False)
|
||||
inputs = [
|
||||
ragged_tensor.convert_to_tensor_or_ragged_tensor(x) for x in inputs
|
||||
]
|
||||
if use_kwargs:
|
||||
result = op(inputs=inputs, **extra_args)
|
||||
else:
|
||||
@ -425,8 +420,7 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
||||
|
||||
# Run the wrapped op on the dense values, for comparison.
|
||||
dense_inputs = [
|
||||
x.flat_values if isinstance(x, ragged_tensor.RaggedTensor) else x
|
||||
for x in inputs
|
||||
x.flat_values if ragged_tensor.is_ragged(x) else x for x in inputs
|
||||
]
|
||||
expected_flat_values = array_ops.reshape(
|
||||
op(dense_inputs, **extra_args), [-1])
|
||||
@ -435,7 +429,7 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
||||
self.assertSameShape(inputs[0], result)
|
||||
|
||||
# Check that the result has the expected (flattened) values.
|
||||
if isinstance(result, ragged_tensor.RaggedTensor):
|
||||
if ragged_tensor.is_ragged(result):
|
||||
result_flat_values = array_ops.reshape(result.flat_values, [-1])
|
||||
else:
|
||||
result_flat_values = array_ops.reshape(result, [-1])
|
||||
|
Loading…
x
Reference in New Issue
Block a user