Update tf.expand_dims to always insert the new dimension as a non-ragged dimension (making use of uniform_row_length where necessary).
Also update RaggedTensor.nrows() to return a constant value when possible (which makes it easier to preserve static shape information). PiperOrigin-RevId: 294426522 Change-Id: I20ed89a8bbb0969ebc5801d22ec08dba89145342
This commit is contained in:
parent
f752cf8155
commit
ca8cf9f8ba
@ -384,15 +384,6 @@ def expand_dims(input, axis, name=None): # pylint: disable=redefined-builtin
|
||||
Given a potentially ragged tenor `input`, this operation inserts a
|
||||
dimension with size 1 at the dimension `axis` of `input`'s shape.
|
||||
|
||||
* If `input` is a `Tensor`, then this is equivalent to
|
||||
`tf.expand_dims`.
|
||||
* If `input` is ragged, and `axis=0`, then the new dimension will be
|
||||
uniform; but the previously outermost dimension will become ragged.
|
||||
* If `input` is ragged, and `0 < axis < input.ragged_rank`, then the
|
||||
new dimension will be ragged.
|
||||
* If `input` is ragged, and axis >= input.ragged_rank`, then the new
|
||||
dimension will be uniform.
|
||||
|
||||
The following table gives some examples showing how `ragged.expand_dims`
|
||||
impacts the shapes of different input tensors. Ragged dimensions are
|
||||
indicated by enclosing them in parentheses.
|
||||
@ -402,9 +393,9 @@ def expand_dims(input, axis, name=None): # pylint: disable=redefined-builtin
|
||||
`[D1, D2]` | `0` | `[1, D1, D2]`
|
||||
`[D1, D2]` | `1` | `[D1, 1, D2]`
|
||||
`[D1, D2]` | `2` | `[D1, D2, 1]`
|
||||
`[D1, (D2), (D3), D4]` | `0` | `[1, (D1), (D2), (D3), D4]`
|
||||
`[D1, (D2), (D3), D4]` | `1` | `[D1, (1), (D2), (D3), D4]`
|
||||
`[D1, (D2), (D3), D4]` | `2` | `[D1, (D2), (1), (D3), D4]`
|
||||
`[D1, (D2), (D3), D4]` | `0` | `[1, D1, (D2), (D3), D4]`
|
||||
`[D1, (D2), (D3), D4]` | `1` | `[D1, 1, (D2), (D3), D4]`
|
||||
`[D1, (D2), (D3), D4]` | `2` | `[D1, (D2), 1, (D3), D4]`
|
||||
`[D1, (D2), (D3), D4]` | `3` | `[D1, (D2), (D3), 1, D4]`
|
||||
`[D1, (D2), (D3), D4]` | `4` | `[D1, (D2), (D3), D4, 1]`
|
||||
|
||||
@ -427,11 +418,11 @@ def expand_dims(input, axis, name=None): # pylint: disable=redefined-builtin
|
||||
|
||||
>>> expanded = tf.expand_dims(rt, axis=0)
|
||||
>>> print(expanded.shape, expanded)
|
||||
(1, None, None) <tf.RaggedTensor [[[1, 2], [3]]]>
|
||||
(1, 2, None) <tf.RaggedTensor [[[1, 2], [3]]]>
|
||||
|
||||
>>> expanded = tf.expand_dims(rt, axis=1)
|
||||
>>> print(expanded.shape, expanded)
|
||||
(2, None, None) <tf.RaggedTensor [[[1, 2]], [[3]]]>
|
||||
(2, 1, None) <tf.RaggedTensor [[[1, 2]], [[3]]]>
|
||||
|
||||
>>> expanded = tf.expand_dims(rt, axis=2)
|
||||
>>> print(expanded.shape, expanded)
|
||||
@ -446,18 +437,15 @@ def expand_dims(input, axis, name=None): # pylint: disable=redefined-builtin
|
||||
|
||||
ndims = None if input.shape.ndims is None else input.shape.ndims + 1
|
||||
axis = ragged_util.get_positive_axis(axis, ndims)
|
||||
if axis == 0:
|
||||
values = input
|
||||
splits = array_ops.stack([0, input.nrows()])
|
||||
elif axis == 1:
|
||||
values = input
|
||||
splits = math_ops.range(input.nrows() + 1)
|
||||
else:
|
||||
values = expand_dims(input.values, axis - 1)
|
||||
splits = input.row_splits
|
||||
|
||||
return ragged_tensor.RaggedTensor.from_row_splits(values, splits,
|
||||
validate=False)
|
||||
if axis == 0:
|
||||
return ragged_tensor.RaggedTensor.from_uniform_row_length(
|
||||
input, uniform_row_length=input.nrows(), nrows=1, validate=False)
|
||||
elif axis == 1:
|
||||
return ragged_tensor.RaggedTensor.from_uniform_row_length(
|
||||
input, uniform_row_length=1, nrows=input.nrows(), validate=False)
|
||||
else:
|
||||
return input.with_values(expand_dims(input.values, axis - 1))
|
||||
|
||||
|
||||
#===============================================================================
|
||||
|
@ -470,7 +470,7 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
||||
def testElementwiseOpShapeMismatch(self):
|
||||
x = ragged_factory_ops.constant([[1, 2, 3], [4, 5]])
|
||||
y = ragged_factory_ops.constant([[1, 2, 3], [4, 5, 6]])
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
with self.assertRaises((ValueError, errors.InvalidArgumentError)):
|
||||
self.evaluate(math_ops.add(x, y))
|
||||
|
||||
def testBinaryOpSparseAndRagged(self):
|
||||
|
@ -50,11 +50,11 @@ class RaggedExpandDimsOpTest(test_util.TensorFlowTestCase,
|
||||
dict(rt_input=[[1, 2], [3]],
|
||||
axis=0,
|
||||
expected=[[[1, 2], [3]]],
|
||||
expected_shape=[1, None, None]),
|
||||
expected_shape=[1, 2, None]),
|
||||
dict(rt_input=[[1, 2], [3]],
|
||||
axis=1,
|
||||
expected=[[[1, 2]], [[3]]],
|
||||
expected_shape=[2, None, None]),
|
||||
expected_shape=[2, 1, None]),
|
||||
dict(rt_input=[[1, 2], [3]],
|
||||
axis=2,
|
||||
expected=[[[1], [2]], [[3]]],
|
||||
@ -85,17 +85,17 @@ class RaggedExpandDimsOpTest(test_util.TensorFlowTestCase,
|
||||
ragged_rank=2,
|
||||
axis=0,
|
||||
expected=EXAMPLE4D_EXPAND_AXIS[0],
|
||||
expected_shape=[1, None, None, None, 2]),
|
||||
expected_shape=[1, 3, None, None, 2]),
|
||||
dict(rt_input=EXAMPLE4D,
|
||||
ragged_rank=2,
|
||||
axis=1,
|
||||
expected=EXAMPLE4D_EXPAND_AXIS[1],
|
||||
expected_shape=[3, None, None, None, 2]),
|
||||
expected_shape=[3, 1, None, None, 2]),
|
||||
dict(rt_input=EXAMPLE4D,
|
||||
ragged_rank=2,
|
||||
axis=2,
|
||||
expected=EXAMPLE4D_EXPAND_AXIS[2],
|
||||
expected_shape=[3, None, None, None, 2]),
|
||||
expected_shape=[3, None, 1, None, 2]),
|
||||
dict(rt_input=EXAMPLE4D,
|
||||
ragged_rank=2,
|
||||
axis=3,
|
||||
|
@ -1224,7 +1224,11 @@ class RaggedTensor(composite_tensor.CompositeTensor):
|
||||
if self._cached_nrows is not None:
|
||||
return math_ops.cast(self._cached_nrows, out_type)
|
||||
with ops.name_scope(name, "RaggedNRows", [self]):
|
||||
return array_ops.shape(self.row_splits, out_type=out_type)[0] - 1
|
||||
nsplits = tensor_shape.dimension_at_index(self.row_splits.shape, 0)
|
||||
if nsplits.value is None:
|
||||
return array_ops.shape(self.row_splits, out_type=out_type)[0] - 1
|
||||
else:
|
||||
return constant_op.constant(nsplits.value - 1, dtype=out_type)
|
||||
|
||||
def row_starts(self, name=None):
|
||||
"""Returns the start indices for rows in this ragged tensor.
|
||||
|
Loading…
Reference in New Issue
Block a user