In RaggedTensor dispatchers for elementwise ops, convert RaggedTensorValue -> RaggedTensor.
PiperOrigin-RevId: 293807568 Change-Id: I839e2d510e84253d2d9a4a5325848d8471a2f74b
This commit is contained in:
parent
99ec314b06
commit
bb275635d1
@ -131,6 +131,10 @@ class UnaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
|
|||||||
elif not _is_convertible_to_tensor(elt):
|
elif not _is_convertible_to_tensor(elt):
|
||||||
return self.NOT_SUPPORTED
|
return self.NOT_SUPPORTED
|
||||||
if found_ragged:
|
if found_ragged:
|
||||||
|
x = [
|
||||||
|
ragged_tensor.convert_to_tensor_or_ragged_tensor(elt)
|
||||||
|
if ragged_tensor.is_ragged(elt) else elt for elt in x
|
||||||
|
]
|
||||||
x = ragged_tensor.match_row_splits_dtypes(*x)
|
x = ragged_tensor.match_row_splits_dtypes(*x)
|
||||||
nested_splits_lists = [
|
nested_splits_lists = [
|
||||||
elt.nested_row_splits for elt in x if ragged_tensor.is_ragged(elt)
|
elt.nested_row_splits for elt in x if ragged_tensor.is_ragged(elt)
|
||||||
@ -149,6 +153,7 @@ class UnaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
|
|||||||
else:
|
else:
|
||||||
found_ragged = ragged_tensor.is_ragged(x)
|
found_ragged = ragged_tensor.is_ragged(x)
|
||||||
if found_ragged:
|
if found_ragged:
|
||||||
|
x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name=self._x)
|
||||||
mapped_values = self._original_op(x.flat_values, *args, **kwargs)
|
mapped_values = self._original_op(x.flat_values, *args, **kwargs)
|
||||||
return x.with_flat_values(mapped_values)
|
return x.with_flat_values(mapped_values)
|
||||||
else:
|
else:
|
||||||
@ -196,10 +201,10 @@ class BinaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
|
|||||||
|
|
||||||
# Convert args to tensors. Bail if conversion fails.
|
# Convert args to tensors. Bail if conversion fails.
|
||||||
try:
|
try:
|
||||||
if not x_is_ragged:
|
x = ragged_tensor.convert_to_tensor_or_ragged_tensor(
|
||||||
x = ops.convert_to_tensor(x, name=self._x, preferred_dtype=y.dtype)
|
x, name=self._x, preferred_dtype=(y.dtype if y_is_ragged else None))
|
||||||
if not y_is_ragged:
|
y = ragged_tensor.convert_to_tensor_or_ragged_tensor(
|
||||||
y = ops.convert_to_tensor(y, name=self._y, preferred_dtype=x.dtype)
|
y, name=self._y, preferred_dtype=(x.dtype if x_is_ragged else None))
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
return self.NOT_SUPPORTED
|
return self.NOT_SUPPORTED
|
||||||
|
|
||||||
|
@ -139,14 +139,15 @@ BINARY_INT_OPS = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=g-complex-comprehension
|
||||||
@test_util.run_all_in_graph_and_eager_modes
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
||||||
parameterized.TestCase):
|
parameterized.TestCase):
|
||||||
|
|
||||||
def assertSameShape(self, x, y):
|
def assertSameShape(self, x, y):
|
||||||
"""Checks that x and y have the same shape (including ragged shapes)."""
|
"""Checks that x and y have the same shape (including ragged shapes)."""
|
||||||
if isinstance(x, ragged_tensor.RaggedTensor):
|
if ragged_tensor.is_ragged(x):
|
||||||
self.assertIsInstance(y, ragged_tensor.RaggedTensor)
|
self.assertTrue(ragged_tensor.is_ragged(y))
|
||||||
self.assertEqual(x.ragged_rank, y.ragged_rank)
|
self.assertEqual(x.ragged_rank, y.ragged_rank)
|
||||||
for (x_splits, y_splits) in zip(x.nested_row_splits, y.nested_row_splits):
|
for (x_splits, y_splits) in zip(x.nested_row_splits, y.nested_row_splits):
|
||||||
self.assertAllEqual(x_splits, y_splits)
|
self.assertAllEqual(x_splits, y_splits)
|
||||||
@ -234,18 +235,17 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
|||||||
]
|
]
|
||||||
) # pyformat: disable
|
) # pyformat: disable
|
||||||
def testUnaryElementwiseOp(self, x, op=math_ops.abs, **extra_args):
|
def testUnaryElementwiseOp(self, x, op=math_ops.abs, **extra_args):
|
||||||
x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x)
|
|
||||||
result = op(x, **extra_args)
|
result = op(x, **extra_args)
|
||||||
|
|
||||||
# Run the wrapped op on the dense values, for comparison.
|
# Run the wrapped op on the dense values, for comparison.
|
||||||
dense_x = x.flat_values if isinstance(x, ragged_tensor.RaggedTensor) else x
|
dense_x = x.flat_values if ragged_tensor.is_ragged(x) else x
|
||||||
expected_flat_values = array_ops.reshape(op(dense_x, **extra_args), [-1])
|
expected_flat_values = array_ops.reshape(op(dense_x, **extra_args), [-1])
|
||||||
|
|
||||||
# Check that the result has the expected shape.
|
# Check that the result has the expected shape.
|
||||||
self.assertSameShape(x, result)
|
self.assertSameShape(x, result)
|
||||||
|
|
||||||
# Check that the result has the expected (flattened) values.
|
# Check that the result has the expected (flattened) values.
|
||||||
if isinstance(result, ragged_tensor.RaggedTensor):
|
if ragged_tensor.is_ragged(result):
|
||||||
result_flat_values = array_ops.reshape(result.flat_values, [-1])
|
result_flat_values = array_ops.reshape(result.flat_values, [-1])
|
||||||
else:
|
else:
|
||||||
result_flat_values = array_ops.reshape(result, [-1])
|
result_flat_values = array_ops.reshape(result, [-1])
|
||||||
@ -350,8 +350,6 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
|||||||
) # pyformat: disable
|
) # pyformat: disable
|
||||||
def testBinaryElementwiseOp(self, x, y, op=math_ops.add, **extra_args):
|
def testBinaryElementwiseOp(self, x, y, op=math_ops.add, **extra_args):
|
||||||
use_kwargs = extra_args.pop('use_kwargs', ())
|
use_kwargs = extra_args.pop('use_kwargs', ())
|
||||||
x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x)
|
|
||||||
y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y)
|
|
||||||
if 'x' in use_kwargs and 'y' in use_kwargs:
|
if 'x' in use_kwargs and 'y' in use_kwargs:
|
||||||
result = op(x=x, y=y, **extra_args)
|
result = op(x=x, y=y, **extra_args)
|
||||||
elif 'y' in use_kwargs:
|
elif 'y' in use_kwargs:
|
||||||
@ -360,8 +358,8 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
|||||||
result = op(x, y, **extra_args)
|
result = op(x, y, **extra_args)
|
||||||
|
|
||||||
# Run the wrapped op on the dense values, for comparison.
|
# Run the wrapped op on the dense values, for comparison.
|
||||||
dense_x = x.flat_values if isinstance(x, ragged_tensor.RaggedTensor) else x
|
dense_x = x.flat_values if ragged_tensor.is_ragged(x) else x
|
||||||
dense_y = y.flat_values if isinstance(y, ragged_tensor.RaggedTensor) else y
|
dense_y = y.flat_values if ragged_tensor.is_ragged(y) else y
|
||||||
expected_flat_values = array_ops.reshape(
|
expected_flat_values = array_ops.reshape(
|
||||||
op(dense_x, dense_y, **extra_args), [-1])
|
op(dense_x, dense_y, **extra_args), [-1])
|
||||||
|
|
||||||
@ -369,7 +367,7 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
|||||||
self.assertSameShape(y, result)
|
self.assertSameShape(y, result)
|
||||||
|
|
||||||
# Check that the result has the expected (flattened) values.
|
# Check that the result has the expected (flattened) values.
|
||||||
if isinstance(result, ragged_tensor.RaggedTensor):
|
if ragged_tensor.is_ragged(result):
|
||||||
result_flat_values = array_ops.reshape(result.flat_values, [-1])
|
result_flat_values = array_ops.reshape(result.flat_values, [-1])
|
||||||
else:
|
else:
|
||||||
result_flat_values = array_ops.reshape(result, [-1])
|
result_flat_values = array_ops.reshape(result, [-1])
|
||||||
@ -415,9 +413,6 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
|||||||
def testListValuedElementwiseOp(self, inputs, op=math_ops.add_n,
|
def testListValuedElementwiseOp(self, inputs, op=math_ops.add_n,
|
||||||
**extra_args):
|
**extra_args):
|
||||||
use_kwargs = extra_args.pop('use_kwargs', False)
|
use_kwargs = extra_args.pop('use_kwargs', False)
|
||||||
inputs = [
|
|
||||||
ragged_tensor.convert_to_tensor_or_ragged_tensor(x) for x in inputs
|
|
||||||
]
|
|
||||||
if use_kwargs:
|
if use_kwargs:
|
||||||
result = op(inputs=inputs, **extra_args)
|
result = op(inputs=inputs, **extra_args)
|
||||||
else:
|
else:
|
||||||
@ -425,8 +420,7 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
|||||||
|
|
||||||
# Run the wrapped op on the dense values, for comparison.
|
# Run the wrapped op on the dense values, for comparison.
|
||||||
dense_inputs = [
|
dense_inputs = [
|
||||||
x.flat_values if isinstance(x, ragged_tensor.RaggedTensor) else x
|
x.flat_values if ragged_tensor.is_ragged(x) else x for x in inputs
|
||||||
for x in inputs
|
|
||||||
]
|
]
|
||||||
expected_flat_values = array_ops.reshape(
|
expected_flat_values = array_ops.reshape(
|
||||||
op(dense_inputs, **extra_args), [-1])
|
op(dense_inputs, **extra_args), [-1])
|
||||||
@ -435,7 +429,7 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
|||||||
self.assertSameShape(inputs[0], result)
|
self.assertSameShape(inputs[0], result)
|
||||||
|
|
||||||
# Check that the result has the expected (flattened) values.
|
# Check that the result has the expected (flattened) values.
|
||||||
if isinstance(result, ragged_tensor.RaggedTensor):
|
if ragged_tensor.is_ragged(result):
|
||||||
result_flat_values = array_ops.reshape(result.flat_values, [-1])
|
result_flat_values = array_ops.reshape(result.flat_values, [-1])
|
||||||
else:
|
else:
|
||||||
result_flat_values = array_ops.reshape(result, [-1])
|
result_flat_values = array_ops.reshape(result, [-1])
|
||||||
|
Loading…
x
Reference in New Issue
Block a user