Update tf.expand_dims to always insert the new dimension as a non-ragged dimension (making use of uniform_row_length where necessary).

Also update RaggedTensor.nrows() to return a constant value when possible (which makes it easier to preserve static shape information).

PiperOrigin-RevId: 294426522
Change-Id: I20ed89a8bbb0969ebc5801d22ec08dba89145342
This commit is contained in:
Edward Loper 2020-02-11 06:33:29 -08:00 committed by TensorFlower Gardener
parent f752cf8155
commit ca8cf9f8ba
4 changed files with 24 additions and 32 deletions

View File

@ -384,15 +384,6 @@ def expand_dims(input, axis, name=None): # pylint: disable=redefined-builtin
Given a potentially ragged tensor `input`, this operation inserts a
dimension with size 1 at the dimension `axis` of `input`'s shape.
* If `input` is a `Tensor`, then this is equivalent to
`tf.expand_dims`.
* If `input` is ragged, and `axis=0`, then the new dimension will be
uniform; but the previously outermost dimension will become ragged.
* If `input` is ragged, and `0 < axis < input.ragged_rank`, then the
new dimension will be ragged.
* If `input` is ragged, and `axis >= input.ragged_rank`, then the new
dimension will be uniform.
The following table gives some examples showing how `ragged.expand_dims`
impacts the shapes of different input tensors. Ragged dimensions are
indicated by enclosing them in parentheses.
@ -402,9 +393,9 @@ def expand_dims(input, axis, name=None): # pylint: disable=redefined-builtin
`[D1, D2]` | `0` | `[1, D1, D2]`
`[D1, D2]` | `1` | `[D1, 1, D2]`
`[D1, D2]` | `2` | `[D1, D2, 1]`
`[D1, (D2), (D3), D4]` | `0` | `[1, (D1), (D2), (D3), D4]`
`[D1, (D2), (D3), D4]` | `1` | `[D1, (1), (D2), (D3), D4]`
`[D1, (D2), (D3), D4]` | `2` | `[D1, (D2), (1), (D3), D4]`
`[D1, (D2), (D3), D4]` | `0` | `[1, D1, (D2), (D3), D4]`
`[D1, (D2), (D3), D4]` | `1` | `[D1, 1, (D2), (D3), D4]`
`[D1, (D2), (D3), D4]` | `2` | `[D1, (D2), 1, (D3), D4]`
`[D1, (D2), (D3), D4]` | `3` | `[D1, (D2), (D3), 1, D4]`
`[D1, (D2), (D3), D4]` | `4` | `[D1, (D2), (D3), D4, 1]`
@ -427,11 +418,11 @@ def expand_dims(input, axis, name=None): # pylint: disable=redefined-builtin
>>> expanded = tf.expand_dims(rt, axis=0)
>>> print(expanded.shape, expanded)
(1, None, None) <tf.RaggedTensor [[[1, 2], [3]]]>
(1, 2, None) <tf.RaggedTensor [[[1, 2], [3]]]>
>>> expanded = tf.expand_dims(rt, axis=1)
>>> print(expanded.shape, expanded)
(2, None, None) <tf.RaggedTensor [[[1, 2]], [[3]]]>
(2, 1, None) <tf.RaggedTensor [[[1, 2]], [[3]]]>
>>> expanded = tf.expand_dims(rt, axis=2)
>>> print(expanded.shape, expanded)
@ -446,18 +437,15 @@ def expand_dims(input, axis, name=None): # pylint: disable=redefined-builtin
ndims = None if input.shape.ndims is None else input.shape.ndims + 1
axis = ragged_util.get_positive_axis(axis, ndims)
if axis == 0:
values = input
splits = array_ops.stack([0, input.nrows()])
elif axis == 1:
values = input
splits = math_ops.range(input.nrows() + 1)
else:
values = expand_dims(input.values, axis - 1)
splits = input.row_splits
return ragged_tensor.RaggedTensor.from_row_splits(values, splits,
validate=False)
if axis == 0:
return ragged_tensor.RaggedTensor.from_uniform_row_length(
input, uniform_row_length=input.nrows(), nrows=1, validate=False)
elif axis == 1:
return ragged_tensor.RaggedTensor.from_uniform_row_length(
input, uniform_row_length=1, nrows=input.nrows(), validate=False)
else:
return input.with_values(expand_dims(input.values, axis - 1))
#===============================================================================

View File

@ -470,7 +470,7 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
def testElementwiseOpShapeMismatch(self):
x = ragged_factory_ops.constant([[1, 2, 3], [4, 5]])
y = ragged_factory_ops.constant([[1, 2, 3], [4, 5, 6]])
with self.assertRaises(errors.InvalidArgumentError):
with self.assertRaises((ValueError, errors.InvalidArgumentError)):
self.evaluate(math_ops.add(x, y))
def testBinaryOpSparseAndRagged(self):

View File

@ -50,11 +50,11 @@ class RaggedExpandDimsOpTest(test_util.TensorFlowTestCase,
dict(rt_input=[[1, 2], [3]],
axis=0,
expected=[[[1, 2], [3]]],
expected_shape=[1, None, None]),
expected_shape=[1, 2, None]),
dict(rt_input=[[1, 2], [3]],
axis=1,
expected=[[[1, 2]], [[3]]],
expected_shape=[2, None, None]),
expected_shape=[2, 1, None]),
dict(rt_input=[[1, 2], [3]],
axis=2,
expected=[[[1], [2]], [[3]]],
@ -85,17 +85,17 @@ class RaggedExpandDimsOpTest(test_util.TensorFlowTestCase,
ragged_rank=2,
axis=0,
expected=EXAMPLE4D_EXPAND_AXIS[0],
expected_shape=[1, None, None, None, 2]),
expected_shape=[1, 3, None, None, 2]),
dict(rt_input=EXAMPLE4D,
ragged_rank=2,
axis=1,
expected=EXAMPLE4D_EXPAND_AXIS[1],
expected_shape=[3, None, None, None, 2]),
expected_shape=[3, 1, None, None, 2]),
dict(rt_input=EXAMPLE4D,
ragged_rank=2,
axis=2,
expected=EXAMPLE4D_EXPAND_AXIS[2],
expected_shape=[3, None, None, None, 2]),
expected_shape=[3, None, 1, None, 2]),
dict(rt_input=EXAMPLE4D,
ragged_rank=2,
axis=3,

View File

@ -1224,7 +1224,11 @@ class RaggedTensor(composite_tensor.CompositeTensor):
if self._cached_nrows is not None:
return math_ops.cast(self._cached_nrows, out_type)
with ops.name_scope(name, "RaggedNRows", [self]):
return array_ops.shape(self.row_splits, out_type=out_type)[0] - 1
nsplits = tensor_shape.dimension_at_index(self.row_splits.shape, 0)
if nsplits.value is None:
return array_ops.shape(self.row_splits, out_type=out_type)[0] - 1
else:
return constant_op.constant(nsplits.value - 1, dtype=out_type)
def row_starts(self, name=None):
"""Returns the start indices for rows in this ragged tensor.