Improve error message when tf.ragged.map_flat_values is called with a function that doesn't preserve the outer dimension size of ragged inputs.
PiperOrigin-RevId: 325887317 Change-Id: Ibdc85269a9ff8b842844b1f22c4bd005e554bf52
This commit is contained in:
parent
642360f24e
commit
9b4f994681
tensorflow/python/ops/ragged
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.ragged import ragged_config
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
@ -70,10 +71,22 @@ def map_flat_values(op, *args, **kwargs):
|
||||
# Replace RaggedTensors with their values; and collect the splits tensors
|
||||
# from each RaggedTensor.
|
||||
nested_splits_lists = []
|
||||
inner_args = _replace_ragged_with_flat_values(args, nested_splits_lists)
|
||||
inner_kwargs = _replace_ragged_with_flat_values(kwargs, nested_splits_lists)
|
||||
flat_values_nrows = []
|
||||
inner_args = _replace_ragged_with_flat_values(args, nested_splits_lists,
|
||||
flat_values_nrows)
|
||||
inner_kwargs = _replace_ragged_with_flat_values(kwargs, nested_splits_lists,
|
||||
flat_values_nrows)
|
||||
if not nested_splits_lists:
|
||||
return op(*args, **kwargs)
|
||||
if flat_values_nrows:
|
||||
flat_values_nrows = set(flat_values_nrows)
|
||||
if len(flat_values_nrows) != 1:
|
||||
raise ValueError("Input RaggedTensors' flat_values must all have the "
|
||||
"same outer-dimension size. Got sizes: %s" %
|
||||
flat_values_nrows)
|
||||
flat_values_nrows = flat_values_nrows.pop() # Get the single element
|
||||
else:
|
||||
flat_values_nrows = None
|
||||
|
||||
split_dtypes = set(splits[0].dtype for splits in nested_splits_lists)
|
||||
if len(split_dtypes) > 1:
|
||||
@ -88,13 +101,23 @@ def map_flat_values(op, *args, **kwargs):
|
||||
|
||||
with ops.control_dependencies(
|
||||
ragged_util.assert_splits_match(nested_splits_lists)):
|
||||
# Delegate to op, and then compose the result from the transformed values
|
||||
# and the splits.
|
||||
# Delegate to `op`
|
||||
op_output = op(*inner_args, **inner_kwargs)
|
||||
# Check that the result has the expected shape (if known).
|
||||
if flat_values_nrows is not None:
|
||||
if not op_output.shape[:1].is_compatible_with([flat_values_nrows]):
|
||||
raise ValueError(
|
||||
"tf.ragged.map_flat_values requires that the output of `op` have "
|
||||
"the same outer-dimension size as flat_values of any ragged "
|
||||
"inputs. (output shape: %s; expected outer dimension size: %s)" %
|
||||
(op_output.shape, flat_values_nrows))
|
||||
# Compose the result from the transformed values and the splits.
|
||||
return ragged_tensor.RaggedTensor.from_nested_row_splits(
|
||||
op(*inner_args, **inner_kwargs), nested_splits_lists[0], validate=False)
|
||||
op_output, nested_splits_lists[0], validate=False)
|
||||
|
||||
|
||||
def _replace_ragged_with_flat_values(value, nested_splits_lists):
|
||||
def _replace_ragged_with_flat_values(value, nested_splits_lists,
|
||||
flat_values_nrows):
|
||||
"""Replace RaggedTensors with their flat_values, and record their splits.
|
||||
|
||||
Returns a copy of `value`, with any nested `RaggedTensor`s replaced by their
|
||||
@ -106,6 +129,9 @@ def _replace_ragged_with_flat_values(value, nested_splits_lists):
|
||||
value: The value that should be transformed by replacing `RaggedTensors`.
|
||||
nested_splits_lists: An output parameter used to record the `nested_splits`
|
||||
for any `RaggedTensors` that were replaced.
|
||||
flat_values_nrows: An output parameter used to record the outer dimension
|
||||
size for each replacement `flat_values` (when known). Contains a list of
|
||||
int.
|
||||
|
||||
Returns:
|
||||
A copy of `value` with nested `RaggedTensors` replaced by their `values`.
|
||||
@ -114,11 +140,15 @@ def _replace_ragged_with_flat_values(value, nested_splits_lists):
|
||||
if ragged_tensor.is_ragged(value):
|
||||
value = ragged_tensor.convert_to_tensor_or_ragged_tensor(value)
|
||||
nested_splits_lists.append(value.nested_row_splits)
|
||||
nrows = tensor_shape.dimension_at_index(value.flat_values.shape, 0).value
|
||||
if nrows is not None:
|
||||
flat_values_nrows.append(nrows)
|
||||
return value.flat_values
|
||||
|
||||
# Recursion cases
|
||||
def recurse(v):
|
||||
return _replace_ragged_with_flat_values(v, nested_splits_lists)
|
||||
return _replace_ragged_with_flat_values(v, nested_splits_lists,
|
||||
flat_values_nrows)
|
||||
|
||||
if isinstance(value, list):
|
||||
return [recurse(v) for v in value]
|
||||
|
@ -178,18 +178,33 @@ class RaggedMapInnerValuesOpTest(test_util.TensorFlowTestCase):
|
||||
def testRaggedTensorSplitsRaggedRankMismatchError(self):
|
||||
x = ragged_factory_ops.constant([[3, 1, 4], [], [1, 5]])
|
||||
y = ragged_factory_ops.constant([[[3, 1, 4], []], [], [[1, 5]]])
|
||||
self.assertRaisesRegex(ValueError,
|
||||
r'Inputs must have identical ragged splits.*',
|
||||
ragged_functional_ops.map_flat_values, math_ops.add,
|
||||
x, y)
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r'Inputs must have identical ragged splits.*'):
|
||||
ragged_functional_ops.map_flat_values(math_ops.add, x, y)
|
||||
|
||||
def testRaggedTensorSplitsValueMismatchError(self):
|
||||
x = ragged_factory_ops.constant([[3, 1, 4], [], [1, 5]])
|
||||
y = ragged_factory_ops.constant([[1], [2, 3], [4, 5]])
|
||||
self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
r'Inputs must have identical ragged splits.*',
|
||||
ragged_functional_ops.map_flat_values, math_ops.add,
|
||||
x, y)
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
r'Inputs must have identical ragged splits.*'):
|
||||
ragged_functional_ops.map_flat_values(math_ops.add, x, y)
|
||||
|
||||
z_splits = array_ops.placeholder_with_default(
|
||||
constant_op.constant([0, 3], dtypes.int64), None)
|
||||
z = ragged_tensor.RaggedTensor.from_row_splits([0, 1, 2], z_splits)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r"Input RaggedTensors' flat_values must all have the same "
|
||||
r'outer-dimension size. Got sizes: \{3, 5\}'):
|
||||
ragged_functional_ops.map_flat_values(math_ops.add, x, z)
|
||||
|
||||
def testRaggedTensorShapeMismatchError(self):
|
||||
x = ragged_factory_ops.constant([[1, 2, 3], [4, 5]])
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'tf.ragged.map_flat_values requires that the output of '
|
||||
'`op` have the same outer-dimension size as flat_values of any ragged '
|
||||
r'inputs. \(output shape: \(\); expected outer dimension size: 5\)'):
|
||||
ragged_functional_ops.map_flat_values(math_ops.argmax, x)
|
||||
|
||||
def testRaggedTensorSplitsMismatchErrorAtRuntime(self):
|
||||
splits1 = array_ops.placeholder_with_default(
|
||||
|
Loading…
Reference in New Issue
Block a user