Improve error message when tf.ragged.map_flat_values is called with a function that doesn't preserve the outer dimension size of ragged inputs.

PiperOrigin-RevId: 325887317
Change-Id: Ibdc85269a9ff8b842844b1f22c4bd005e554bf52
This commit is contained in:
Edward Loper 2020-08-10 14:25:56 -07:00 committed by TensorFlower Gardener
parent 642360f24e
commit 9b4f994681
2 changed files with 60 additions and 15 deletions

View File

@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_config
from tensorflow.python.ops.ragged import ragged_tensor
@ -70,10 +71,22 @@ def map_flat_values(op, *args, **kwargs):
# Replace RaggedTensors with their flat_values, and collect the nested splits
# tensors from each RaggedTensor.
nested_splits_lists = []
inner_args = _replace_ragged_with_flat_values(args, nested_splits_lists)
inner_kwargs = _replace_ragged_with_flat_values(kwargs, nested_splits_lists)
flat_values_nrows = []
inner_args = _replace_ragged_with_flat_values(args, nested_splits_lists,
flat_values_nrows)
inner_kwargs = _replace_ragged_with_flat_values(kwargs, nested_splits_lists,
flat_values_nrows)
if not nested_splits_lists:
return op(*args, **kwargs)
if flat_values_nrows:
flat_values_nrows = set(flat_values_nrows)
if len(flat_values_nrows) != 1:
raise ValueError("Input RaggedTensors' flat_values must all have the "
"same outer-dimension size. Got sizes: %s" %
flat_values_nrows)
flat_values_nrows = flat_values_nrows.pop() # Get the single element
else:
flat_values_nrows = None
split_dtypes = set(splits[0].dtype for splits in nested_splits_lists)
if len(split_dtypes) > 1:
@ -88,13 +101,23 @@ def map_flat_values(op, *args, **kwargs):
with ops.control_dependencies(
ragged_util.assert_splits_match(nested_splits_lists)):
# Delegate to op, and then compose the result from the transformed values
# and the splits.
# Delegate to `op`
op_output = op(*inner_args, **inner_kwargs)
# Check that the result has the expected shape (if known).
if flat_values_nrows is not None:
if not op_output.shape[:1].is_compatible_with([flat_values_nrows]):
raise ValueError(
"tf.ragged.map_flat_values requires that the output of `op` have "
"the same outer-dimension size as flat_values of any ragged "
"inputs. (output shape: %s; expected outer dimension size: %s)" %
(op_output.shape, flat_values_nrows))
# Compose the result from the transformed values and the splits.
return ragged_tensor.RaggedTensor.from_nested_row_splits(
op(*inner_args, **inner_kwargs), nested_splits_lists[0], validate=False)
op_output, nested_splits_lists[0], validate=False)
def _replace_ragged_with_flat_values(value, nested_splits_lists):
def _replace_ragged_with_flat_values(value, nested_splits_lists,
flat_values_nrows):
"""Replace RaggedTensors with their flat_values, and record their splits.
Returns a copy of `value`, with any nested `RaggedTensor`s replaced by their
@ -106,6 +129,9 @@ def _replace_ragged_with_flat_values(value, nested_splits_lists):
value: The value that should be transformed by replacing `RaggedTensors`.
nested_splits_lists: An output parameter used to record the `nested_splits`
for any `RaggedTensors` that were replaced.
flat_values_nrows: An output parameter used to record the outer dimension
size of each replacement `flat_values` (when known). A list of ints that
is appended to in place.
Returns:
A copy of `value` with nested `RaggedTensor`s replaced by their `flat_values`.
@ -114,11 +140,15 @@ def _replace_ragged_with_flat_values(value, nested_splits_lists):
if ragged_tensor.is_ragged(value):
value = ragged_tensor.convert_to_tensor_or_ragged_tensor(value)
nested_splits_lists.append(value.nested_row_splits)
nrows = tensor_shape.dimension_at_index(value.flat_values.shape, 0).value
if nrows is not None:
flat_values_nrows.append(nrows)
return value.flat_values
# Recursion cases
def recurse(v):
return _replace_ragged_with_flat_values(v, nested_splits_lists)
return _replace_ragged_with_flat_values(v, nested_splits_lists,
flat_values_nrows)
if isinstance(value, list):
return [recurse(v) for v in value]

View File

@ -178,18 +178,33 @@ class RaggedMapInnerValuesOpTest(test_util.TensorFlowTestCase):
def testRaggedTensorSplitsRaggedRankMismatchError(self):
x = ragged_factory_ops.constant([[3, 1, 4], [], [1, 5]])
y = ragged_factory_ops.constant([[[3, 1, 4], []], [], [[1, 5]]])
self.assertRaisesRegex(ValueError,
r'Inputs must have identical ragged splits.*',
ragged_functional_ops.map_flat_values, math_ops.add,
x, y)
with self.assertRaisesRegex(ValueError,
r'Inputs must have identical ragged splits.*'):
ragged_functional_ops.map_flat_values(math_ops.add, x, y)
def testRaggedTensorSplitsValueMismatchError(self):
x = ragged_factory_ops.constant([[3, 1, 4], [], [1, 5]])
y = ragged_factory_ops.constant([[1], [2, 3], [4, 5]])
self.assertRaisesRegex(errors.InvalidArgumentError,
r'Inputs must have identical ragged splits.*',
ragged_functional_ops.map_flat_values, math_ops.add,
x, y)
with self.assertRaisesRegex(errors.InvalidArgumentError,
r'Inputs must have identical ragged splits.*'):
ragged_functional_ops.map_flat_values(math_ops.add, x, y)
z_splits = array_ops.placeholder_with_default(
constant_op.constant([0, 3], dtypes.int64), None)
z = ragged_tensor.RaggedTensor.from_row_splits([0, 1, 2], z_splits)
with self.assertRaisesRegex(
ValueError,
r"Input RaggedTensors' flat_values must all have the same "
r'outer-dimension size. Got sizes: \{3, 5\}'):
ragged_functional_ops.map_flat_values(math_ops.add, x, z)
def testRaggedTensorShapeMismatchError(self):
  """map_flat_values rejects an `op` whose output outer dimension is wrong.

  `math_ops.argmax` is mapped over a ragged input holding five flat values.
  The expected error reports an output shape of `()` against an expected
  outer-dimension size of 5.
  """
  ragged_input = ragged_factory_ops.constant([[1, 2, 3], [4, 5]])
  # Pattern assembled from raw and plain string pieces; the parentheses that
  # appear literally in the error message are escaped.
  expected_message = (
      r'tf.ragged.map_flat_values requires that the output of '
      '`op` have the same outer-dimension size as flat_values of any ragged '
      r'inputs. \(output shape: \(\); expected outer dimension size: 5\)')
  with self.assertRaisesRegex(ValueError, expected_message):
    ragged_functional_ops.map_flat_values(math_ops.argmax, ragged_input)
def testRaggedTensorSplitsMismatchErrorAtRuntime(self):
splits1 = array_ops.placeholder_with_default(