Update tf.expand_dims to always insert the new dimension as a non-ragged dimension (making use of uniform_row_length where necessary).

Also update RaggedTensor.nrows() to return a constant value when possible (which makes it easier to preserve static shape information).

PiperOrigin-RevId: 294426522
Change-Id: I20ed89a8bbb0969ebc5801d22ec08dba89145342
This commit is contained in:
Edward Loper 2020-02-11 06:33:29 -08:00 committed by TensorFlower Gardener
parent f752cf8155
commit ca8cf9f8ba
4 changed files with 24 additions and 32 deletions

View File

@ -384,15 +384,6 @@ def expand_dims(input, axis, name=None): # pylint: disable=redefined-builtin
Given a potentially ragged tensor `input`, this operation inserts a Given a potentially ragged tensor `input`, this operation inserts a
dimension with size 1 at the dimension `axis` of `input`'s shape. dimension with size 1 at the dimension `axis` of `input`'s shape.
* If `input` is a `Tensor`, then this is equivalent to
`tf.expand_dims`.
* If `input` is ragged, and `axis=0`, then the new dimension will be
uniform; but the previously outermost dimension will become ragged.
* If `input` is ragged, and `0 < axis < input.ragged_rank`, then the
new dimension will be ragged.
* If `input` is ragged, and `axis >= input.ragged_rank`, then the new
dimension will be uniform.
The following table gives some examples showing how `ragged.expand_dims` The following table gives some examples showing how `ragged.expand_dims`
impacts the shapes of different input tensors. Ragged dimensions are impacts the shapes of different input tensors. Ragged dimensions are
indicated by enclosing them in parentheses. indicated by enclosing them in parentheses.
@ -402,9 +393,9 @@ def expand_dims(input, axis, name=None): # pylint: disable=redefined-builtin
`[D1, D2]` | `0` | `[1, D1, D2]` `[D1, D2]` | `0` | `[1, D1, D2]`
`[D1, D2]` | `1` | `[D1, 1, D2]` `[D1, D2]` | `1` | `[D1, 1, D2]`
`[D1, D2]` | `2` | `[D1, D2, 1]` `[D1, D2]` | `2` | `[D1, D2, 1]`
`[D1, (D2), (D3), D4]` | `0` | `[1, (D1), (D2), (D3), D4]` `[D1, (D2), (D3), D4]` | `0` | `[1, D1, (D2), (D3), D4]`
`[D1, (D2), (D3), D4]` | `1` | `[D1, (1), (D2), (D3), D4]` `[D1, (D2), (D3), D4]` | `1` | `[D1, 1, (D2), (D3), D4]`
`[D1, (D2), (D3), D4]` | `2` | `[D1, (D2), (1), (D3), D4]` `[D1, (D2), (D3), D4]` | `2` | `[D1, (D2), 1, (D3), D4]`
`[D1, (D2), (D3), D4]` | `3` | `[D1, (D2), (D3), 1, D4]` `[D1, (D2), (D3), D4]` | `3` | `[D1, (D2), (D3), 1, D4]`
`[D1, (D2), (D3), D4]` | `4` | `[D1, (D2), (D3), D4, 1]` `[D1, (D2), (D3), D4]` | `4` | `[D1, (D2), (D3), D4, 1]`
@ -427,11 +418,11 @@ def expand_dims(input, axis, name=None): # pylint: disable=redefined-builtin
>>> expanded = tf.expand_dims(rt, axis=0) >>> expanded = tf.expand_dims(rt, axis=0)
>>> print(expanded.shape, expanded) >>> print(expanded.shape, expanded)
(1, None, None) <tf.RaggedTensor [[[1, 2], [3]]]> (1, 2, None) <tf.RaggedTensor [[[1, 2], [3]]]>
>>> expanded = tf.expand_dims(rt, axis=1) >>> expanded = tf.expand_dims(rt, axis=1)
>>> print(expanded.shape, expanded) >>> print(expanded.shape, expanded)
(2, None, None) <tf.RaggedTensor [[[1, 2]], [[3]]]> (2, 1, None) <tf.RaggedTensor [[[1, 2]], [[3]]]>
>>> expanded = tf.expand_dims(rt, axis=2) >>> expanded = tf.expand_dims(rt, axis=2)
>>> print(expanded.shape, expanded) >>> print(expanded.shape, expanded)
@ -446,18 +437,15 @@ def expand_dims(input, axis, name=None): # pylint: disable=redefined-builtin
ndims = None if input.shape.ndims is None else input.shape.ndims + 1 ndims = None if input.shape.ndims is None else input.shape.ndims + 1
axis = ragged_util.get_positive_axis(axis, ndims) axis = ragged_util.get_positive_axis(axis, ndims)
if axis == 0:
values = input
splits = array_ops.stack([0, input.nrows()])
elif axis == 1:
values = input
splits = math_ops.range(input.nrows() + 1)
else:
values = expand_dims(input.values, axis - 1)
splits = input.row_splits
return ragged_tensor.RaggedTensor.from_row_splits(values, splits, if axis == 0:
validate=False) return ragged_tensor.RaggedTensor.from_uniform_row_length(
input, uniform_row_length=input.nrows(), nrows=1, validate=False)
elif axis == 1:
return ragged_tensor.RaggedTensor.from_uniform_row_length(
input, uniform_row_length=1, nrows=input.nrows(), validate=False)
else:
return input.with_values(expand_dims(input.values, axis - 1))
#=============================================================================== #===============================================================================

View File

@ -470,7 +470,7 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
def testElementwiseOpShapeMismatch(self): def testElementwiseOpShapeMismatch(self):
x = ragged_factory_ops.constant([[1, 2, 3], [4, 5]]) x = ragged_factory_ops.constant([[1, 2, 3], [4, 5]])
y = ragged_factory_ops.constant([[1, 2, 3], [4, 5, 6]]) y = ragged_factory_ops.constant([[1, 2, 3], [4, 5, 6]])
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises((ValueError, errors.InvalidArgumentError)):
self.evaluate(math_ops.add(x, y)) self.evaluate(math_ops.add(x, y))
def testBinaryOpSparseAndRagged(self): def testBinaryOpSparseAndRagged(self):

View File

@ -50,11 +50,11 @@ class RaggedExpandDimsOpTest(test_util.TensorFlowTestCase,
dict(rt_input=[[1, 2], [3]], dict(rt_input=[[1, 2], [3]],
axis=0, axis=0,
expected=[[[1, 2], [3]]], expected=[[[1, 2], [3]]],
expected_shape=[1, None, None]), expected_shape=[1, 2, None]),
dict(rt_input=[[1, 2], [3]], dict(rt_input=[[1, 2], [3]],
axis=1, axis=1,
expected=[[[1, 2]], [[3]]], expected=[[[1, 2]], [[3]]],
expected_shape=[2, None, None]), expected_shape=[2, 1, None]),
dict(rt_input=[[1, 2], [3]], dict(rt_input=[[1, 2], [3]],
axis=2, axis=2,
expected=[[[1], [2]], [[3]]], expected=[[[1], [2]], [[3]]],
@ -85,17 +85,17 @@ class RaggedExpandDimsOpTest(test_util.TensorFlowTestCase,
ragged_rank=2, ragged_rank=2,
axis=0, axis=0,
expected=EXAMPLE4D_EXPAND_AXIS[0], expected=EXAMPLE4D_EXPAND_AXIS[0],
expected_shape=[1, None, None, None, 2]), expected_shape=[1, 3, None, None, 2]),
dict(rt_input=EXAMPLE4D, dict(rt_input=EXAMPLE4D,
ragged_rank=2, ragged_rank=2,
axis=1, axis=1,
expected=EXAMPLE4D_EXPAND_AXIS[1], expected=EXAMPLE4D_EXPAND_AXIS[1],
expected_shape=[3, None, None, None, 2]), expected_shape=[3, 1, None, None, 2]),
dict(rt_input=EXAMPLE4D, dict(rt_input=EXAMPLE4D,
ragged_rank=2, ragged_rank=2,
axis=2, axis=2,
expected=EXAMPLE4D_EXPAND_AXIS[2], expected=EXAMPLE4D_EXPAND_AXIS[2],
expected_shape=[3, None, None, None, 2]), expected_shape=[3, None, 1, None, 2]),
dict(rt_input=EXAMPLE4D, dict(rt_input=EXAMPLE4D,
ragged_rank=2, ragged_rank=2,
axis=3, axis=3,

View File

@ -1224,7 +1224,11 @@ class RaggedTensor(composite_tensor.CompositeTensor):
if self._cached_nrows is not None: if self._cached_nrows is not None:
return math_ops.cast(self._cached_nrows, out_type) return math_ops.cast(self._cached_nrows, out_type)
with ops.name_scope(name, "RaggedNRows", [self]): with ops.name_scope(name, "RaggedNRows", [self]):
return array_ops.shape(self.row_splits, out_type=out_type)[0] - 1 nsplits = tensor_shape.dimension_at_index(self.row_splits.shape, 0)
if nsplits.value is None:
return array_ops.shape(self.row_splits, out_type=out_type)[0] - 1
else:
return constant_op.constant(nsplits.value - 1, dtype=out_type)
def row_starts(self, name=None): def row_starts(self, name=None):
"""Returns the start indices for rows in this ragged tensor. """Returns the start indices for rows in this ragged tensor.