diff --git a/tensorflow/python/ops/ragged/ragged_dispatch.py b/tensorflow/python/ops/ragged/ragged_dispatch.py index 7d024e53299..5b0f19358fa 100644 --- a/tensorflow/python/ops/ragged/ragged_dispatch.py +++ b/tensorflow/python/ops/ragged/ragged_dispatch.py @@ -131,6 +131,10 @@ class UnaryRaggedElementwiseDispatcher(dispatch.OpDispatcher): elif not _is_convertible_to_tensor(elt): return self.NOT_SUPPORTED if found_ragged: + x = [ + ragged_tensor.convert_to_tensor_or_ragged_tensor(elt) + if ragged_tensor.is_ragged(elt) else elt for elt in x + ] x = ragged_tensor.match_row_splits_dtypes(*x) nested_splits_lists = [ elt.nested_row_splits for elt in x if ragged_tensor.is_ragged(elt) @@ -149,6 +153,7 @@ class UnaryRaggedElementwiseDispatcher(dispatch.OpDispatcher): else: found_ragged = ragged_tensor.is_ragged(x) if found_ragged: + x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name=self._x) mapped_values = self._original_op(x.flat_values, *args, **kwargs) return x.with_flat_values(mapped_values) else: @@ -196,10 +201,10 @@ class BinaryRaggedElementwiseDispatcher(dispatch.OpDispatcher): # Convert args to tensors. Bail if conversion fails. try: - if not x_is_ragged: - x = ops.convert_to_tensor(x, name=self._x, preferred_dtype=y.dtype) - if not y_is_ragged: - y = ops.convert_to_tensor(y, name=self._y, preferred_dtype=x.dtype) + x = ragged_tensor.convert_to_tensor_or_ragged_tensor( + x, name=self._x, preferred_dtype=(y.dtype if y_is_ragged else None)) + y = ragged_tensor.convert_to_tensor_or_ragged_tensor( + y, name=self._y, preferred_dtype=(x.dtype if x_is_ragged else None)) except (TypeError, ValueError): return self.NOT_SUPPORTED diff --git a/tensorflow/python/ops/ragged/ragged_dispatch_test.py b/tensorflow/python/ops/ragged/ragged_dispatch_test.py index b7e15a40889..d3c2cd23fa4 100644 --- a/tensorflow/python/ops/ragged/ragged_dispatch_test.py +++ b/tensorflow/python/ops/ragged/ragged_dispatch_test.py @@ -139,14 +139,15 @@ BINARY_INT_OPS = [ ] +# pylint: disable=g-complex-comprehension @test_util.run_all_in_graph_and_eager_modes class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase): def assertSameShape(self, x, y): """Checks that x and y have the same shape (including ragged shapes).""" - if isinstance(x, ragged_tensor.RaggedTensor): - self.assertIsInstance(y, ragged_tensor.RaggedTensor) + if ragged_tensor.is_ragged(x): + self.assertTrue(ragged_tensor.is_ragged(y)) self.assertEqual(x.ragged_rank, y.ragged_rank) for (x_splits, y_splits) in zip(x.nested_row_splits, y.nested_row_splits): self.assertAllEqual(x_splits, y_splits) @@ -234,18 +235,17 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase, ] ) # pyformat: disable def testUnaryElementwiseOp(self, x, op=math_ops.abs, **extra_args): - x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x) result = op(x, **extra_args) # Run the wrapped op on the dense values, for comparison. - dense_x = x.flat_values if isinstance(x, ragged_tensor.RaggedTensor) else x + dense_x = x.flat_values if ragged_tensor.is_ragged(x) else x expected_flat_values = array_ops.reshape(op(dense_x, **extra_args), [-1]) # Check that the result has the expected shape. self.assertSameShape(x, result) # Check that the result has the expected (flattened) values. - if isinstance(result, ragged_tensor.RaggedTensor): + if ragged_tensor.is_ragged(result): result_flat_values = array_ops.reshape(result.flat_values, [-1]) else: result_flat_values = array_ops.reshape(result, [-1]) @@ -350,8 +350,6 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase, ) # pyformat: disable def testBinaryElementwiseOp(self, x, y, op=math_ops.add, **extra_args): use_kwargs = extra_args.pop('use_kwargs', ()) - x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x) - y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y) if 'x' in use_kwargs and 'y' in use_kwargs: result = op(x=x, y=y, **extra_args) elif 'y' in use_kwargs: @@ -360,8 +358,8 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase, result = op(x, y, **extra_args) # Run the wrapped op on the dense values, for comparison. - dense_x = x.flat_values if isinstance(x, ragged_tensor.RaggedTensor) else x - dense_y = y.flat_values if isinstance(y, ragged_tensor.RaggedTensor) else y + dense_x = x.flat_values if ragged_tensor.is_ragged(x) else x + dense_y = y.flat_values if ragged_tensor.is_ragged(y) else y expected_flat_values = array_ops.reshape( op(dense_x, dense_y, **extra_args), [-1]) @@ -369,7 +367,7 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase, self.assertSameShape(y, result) # Check that the result has the expected (flattened) values. - if isinstance(result, ragged_tensor.RaggedTensor): + if ragged_tensor.is_ragged(result): result_flat_values = array_ops.reshape(result.flat_values, [-1]) else: result_flat_values = array_ops.reshape(result, [-1]) @@ -415,9 +413,6 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase, def testListValuedElementwiseOp(self, inputs, op=math_ops.add_n, **extra_args): use_kwargs = extra_args.pop('use_kwargs', False) - inputs = [ - ragged_tensor.convert_to_tensor_or_ragged_tensor(x) for x in inputs - ] if use_kwargs: result = op(inputs=inputs, **extra_args) else: @@ -425,8 +420,7 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase, # Run the wrapped op on the dense values, for comparison. dense_inputs = [ - x.flat_values if isinstance(x, ragged_tensor.RaggedTensor) else x - for x in inputs + x.flat_values if ragged_tensor.is_ragged(x) else x for x in inputs ] expected_flat_values = array_ops.reshape( op(dense_inputs, **extra_args), [-1]) @@ -435,7 +429,7 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase, self.assertSameShape(inputs[0], result) # Check that the result has the expected (flattened) values. - if isinstance(result, ragged_tensor.RaggedTensor): + if ragged_tensor.is_ragged(result): result_flat_values = array_ops.reshape(result.flat_values, [-1]) else: result_flat_values = array_ops.reshape(result, [-1])