Add broadcasting support for RaggedTensors
RELNOTES: Broadcasting support for Ragged Tensors.
PiperOrigin-RevId: 223373179
parent f1263f34f5
commit a1f1abbe53
tensorflow/python/ops/ragged/BUILD
@@ -33,6 +33,7 @@ py_library(
         ":ragged_math_ops",
         ":ragged_operators",
         ":ragged_tensor",
+        ":ragged_tensor_shape",
         ":ragged_tensor_value",
         ":ragged_util",
         ":segment_id_ops",
@@ -155,6 +156,7 @@ py_library(
     deps = [
         ":ragged_factory_ops",
         ":ragged_tensor",
+        ":ragged_tensor_shape",
         ":ragged_util",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:clip_ops",
@@ -190,6 +192,25 @@ py_library(
     ],
 )
 
+py_library(
+    name = "ragged_tensor_shape",
+    srcs = ["ragged_tensor_shape.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":ragged_array_ops",
+        ":ragged_conversion_ops",
+        ":ragged_factory_ops",
+        ":ragged_tensor",
+        ":ragged_util",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:tensor_shape",
+        "//tensorflow/python:tensor_util",
+    ],
+)
+
 py_library(
     name = "ragged_tensor_value",
     srcs = ["ragged_tensor_value.py"],
@@ -207,6 +228,7 @@ py_library(
         "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:math_ops",
+        "//tensorflow/python:ragged_math_ops_gen",
     ],
 )
 
@@ -690,3 +712,15 @@ py_test(
         "@absl_py//absl/testing:parameterized",
     ],
 )
+
+py_test(
+    name = "ragged_tensor_shape_test",
+    srcs = ["ragged_tensor_shape_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":ragged",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
tensorflow/python/ops/ragged/__init__.py
@@ -143,6 +143,11 @@ The following operations are specific to ragged tensors:
 <!-- Elementwise Ops -->
 @@make_elementwise_op
+
+<!-- Shape & broadcasting -->
+@@RaggedTensorDynamicShape
+@@broadcast_to
+@@broadcast_dynamic_shape
 
 <!-- Symbols from ragged_elementwise_ops._symbols_to_export are whitelisted -->
 """
 
@@ -214,6 +219,10 @@ from tensorflow.python.ops.ragged.ragged_tensor import is_ragged
 from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor
 from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensorType
+
+from tensorflow.python.ops.ragged.ragged_tensor_shape import broadcast_dynamic_shape
+from tensorflow.python.ops.ragged.ragged_tensor_shape import broadcast_to
+from tensorflow.python.ops.ragged.ragged_tensor_shape import RaggedTensorDynamicShape
 
 from tensorflow.python.ops.ragged.ragged_tensor_value import RaggedTensorValue
 
 from tensorflow.python.ops.ragged.segment_id_ops import row_splits_to_segment_ids
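These imports expose the new broadcasting API at the package level, as `ragged.RaggedTensorDynamicShape`, `ragged.broadcast_to`, and `ragged.broadcast_dynamic_shape`. As a quick illustration of what "broadcasting to a ragged shape" means — a plain-Python sketch over nested lists, not the TF API (the helper name `broadcast_rows` is made up for this note):

```python
# Broadcasting a scalar across a ragged row structure: each row gets the
# value repeated to that row's own length.
def broadcast_rows(value, row_lengths):
    return [[value] * length for length in row_lengths]

print(broadcast_rows(7, [3, 0, 2]))  # [[7, 7, 7], [], [7, 7]]
```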
tensorflow/python/ops/ragged/ragged_array_ops.py
@@ -225,6 +225,28 @@ def row_lengths(rt_input, axis=1, name=None):
     return array_ops.ones(shape[:axis], dtypes.int64) * shape[axis]
 
 
+def nested_row_lengths(rt_input, name=None):
+  """Returns a tuple containing the row_lengths for all ragged dimensions.
+
+  `nested_row_lengths(rt)` is a tuple containing the `row_lengths` tensors for
+  all ragged dimensions in `rt`, ordered from outermost to innermost.
+
+  Args:
+    rt_input: A potentially ragged tensor.
+    name: A name prefix for the returned tensors (optional).
+
+  Returns:
+    A `tuple` of 1-D `int64` `Tensors`.  The length of the tuple is equal to
+    `rt_input.ragged_rank`.
+  """
+  with ops.name_scope(name, 'RaggedNestedRowLengths', [rt_input]):
+    rt_nested_row_lengths = []
+    while isinstance(rt_input, ragged_tensor.RaggedTensor):
+      rt_nested_row_lengths.append(row_lengths(rt_input))
+      rt_input = rt_input.values
+    return tuple(rt_nested_row_lengths)
+
+
 #===============================================================================
 # Bounding Shape
 #===============================================================================
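For reference, a list-based sketch of what `nested_row_lengths` computes (illustrative only: the real function walks `RaggedTensor.values`, and this sketch treats every list level as ragged):

```python
def nested_row_lengths(nested):
    """Row lengths for each nested-list level, outermost to innermost."""
    lengths = []
    while nested and isinstance(nested[0], list):
        lengths.append([len(row) for row in nested])
        nested = [item for row in nested for item in row]  # flatten one level
    return tuple(lengths)

rt = [[[1, 2], [3]], [[4, 5, 6]]]
print(nested_row_lengths(rt))  # ([2, 1], [2, 1, 3])
```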
@@ -451,8 +473,7 @@ def batch_gather(params, indices, name=None):
     adjusted_indices = math_ops.to_int64(indices) + adjustments
     return gather(params.values, adjusted_indices)
   else:
-    raise ValueError(
-        'batch shape from indices does not match params shape')
+    raise ValueError('batch shape from indices does not match params shape')
 
 
 #===============================================================================
@@ -719,7 +740,7 @@ def boolean_mask(data, mask, keepdims=False, name=None):
       int_mask = ragged_functional_ops.map_inner_values(
           math_ops.cast, mask, dtype=dtypes.int64)
       masked_row_lengths = ragged_math_ops.reduce_sum(int_mask, axis=1)
-      splits.append(_lengths_to_splits(masked_row_lengths))
+      splits.append(ragged_util.lengths_to_splits(masked_row_lengths))
       mask = mask.values
       data = data.values
 
@@ -741,7 +762,7 @@ def boolean_mask(data, mask, keepdims=False, name=None):
       # masks back to a splits tensor.
       lengths = row_lengths(data)
       masked_lengths = array_ops.boolean_mask(lengths, mask)
-      masked_splits = _lengths_to_splits(masked_lengths)
+      masked_splits = ragged_util.lengths_to_splits(masked_lengths)
 
       # Get the masked values: first get row ids corresponding to each
       # value, then use tf.gather to build a boolean mask that's false for
@@ -977,7 +998,7 @@ def _ragged_stack_concat_axis_0(rt_inputs, stack_values):
   # If we are performing a stack operation, then add another splits.
   if stack_values:
     stack_lengths = array_ops.stack([nrows(rt) for rt in rt_inputs])
-    stack_splits = _lengths_to_splits(stack_lengths)
+    stack_splits = ragged_util.lengths_to_splits(stack_lengths)
     concatenated_nested_splits.insert(0, stack_splits)
 
   return ragged_factory_ops.from_nested_row_splits(concatenated_inner_values,
@@ -1131,7 +1152,8 @@ def _tile_ragged_values(rt_input, multiples, const_multiples=None):
 
     # Repeat each element in this ragged dimension `multiples[axis]` times.
     if const_multiples is None or const_multiples[axis] != 1:
-      inner_value_ids = _repeat_ranges(inner_value_ids, splits, multiples[axis])
+      inner_value_ids = ragged_util.repeat_ranges(inner_value_ids, splits,
+                                                  multiples[axis])
 
     prev_splits = splits
 
@@ -1200,15 +1222,15 @@ def _tile_ragged_splits(rt_input, multiples, const_multiples=None):
   for d in range(axis - 1, -1, -1):
     if const_multiples is None or const_multiples[d + 1] != 1:
       splits = projected_splits[d][axis - 1] * repeats
-      output_lengths = _repeat_ranges(output_lengths, splits,
-                                      multiples[d + 1])
+      output_lengths = ragged_util.repeat_ranges(output_lengths, splits,
+                                                 multiples[d + 1])
       repeats *= multiples[d + 1]
 
   # Tile splits for the outermost (uniform) dimension.
   output_lengths = array_ops.tile(output_lengths, multiples[:1])
 
   # Convert to splits.
-  result_splits.append(_lengths_to_splits(output_lengths))
+  result_splits.append(ragged_util.lengths_to_splits(output_lengths))
 
   return result_splits
 
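The two hunks above reroute the tiling helpers through `ragged_util.repeat_ranges`. Conceptually, tiling a ragged tensor repeats each row's contents along the inner dimension and repeats the row list along the outer one; a plain-Python sketch (the function name `tile_ragged` is invented for this note):

```python
def tile_ragged(rows, outer_multiple, inner_multiple):
    # Inner tiling repeats each row's own contents; outer tiling repeats
    # the whole list of rows.
    tiled_rows = [row * inner_multiple for row in rows]
    return tiled_rows * outer_multiple

print(tile_ragged([[1, 2], [3]], 2, 2))
# [[1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3]]
```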
@@ -1436,11 +1458,6 @@ def _coordinate_where(condition):
 #===============================================================================
 
 
-def _lengths_to_splits(lengths):
-  """Returns splits corresponding to the given lengths."""
-  return array_ops.concat([[0], math_ops.cumsum(lengths)], axis=0)
-
-
 def _increase_ragged_rank_to(rt_input, ragged_rank):
   """Adds ragged dimensions to `rt_input` so it has the desired ragged rank."""
   if ragged_rank > 0:
@@ -1460,45 +1477,3 @@ def _concat_ragged_splits(splits_list):
     pieces.append(splits[1:] + splits_offset)
     splits_offset += splits[-1]
   return array_ops.concat(pieces, axis=0)
-
-
-def _repeat_ranges(params, splits, multiple):
-  """Repeats each range of `params` (as specified by `splits`) `multiple` times.
-
-  Let the `i`th range of `params` be defined as
-  `params[splits[i]:splits[i + 1]]`.  Then this function returns a tensor
-  containing range 0 repeated `multiple` times, followed by range 1 repeated
-  `multiple` times, ..., followed by the last range repeated `multiple` times.
-
-  Args:
-    params: The `Tensor` whose values should be repeated.
-    splits: A splits tensor indicating the ranges of `params` that should be
-      repeated.
-    multiple: The number of times each range should be repeated.
-
-  Returns:
-    A `Tensor` with the same rank and type as `params`.
-
-  #### Example:
-    ```python
-    >>> _repeat_ranges(['a', 'b', 'c'], [0, 2, 3], 3)
-    ['a', 'b', 'a', 'b', 'a', 'b', 'c', 'c', 'c']
-    ```
-  """
-  # Repeat each split value `multiple` times.  E.g., if `splits=[0 3 4]` and
-  # `multiples=3`, then `repeated_splits=[0 0 0 3 3 3 4 4 4]`.
-  repeated_splits = array_ops.tile(
-      array_ops.expand_dims(splits, axis=1), array_ops.stack([1, multiple]))
-  repeated_splits = array_ops.reshape(repeated_splits, [-1])
-
-  # Divide the splits into repeated starts & repeated limits.  E.g., if
-  # `repeated_splits=[0 0 0 3 3 3 4 4 4]` then `repeated_starts=[0 0 0 3 3 3]`
-  # and `repeated_limits=[3 3 3 4 4 4]`.
-  n_splits = array_ops.shape(repeated_splits, out_type=dtypes.int64)[0]
-  repeated_starts = repeated_splits[:n_splits - multiple]
-  repeated_limits = repeated_splits[multiple:]
-
-  # Get indices for each range from starts to limits, and use those to gather
-  # the values in the desired repetition pattern.
-  offsets = ragged_math_ops.range(repeated_starts, repeated_limits).values
-  return array_ops.gather(params, offsets)
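`_lengths_to_splits` and `_repeat_ranges` are deleted here because the same helpers now live in `ragged_util` (as `lengths_to_splits` and `repeat_ranges`, per the call sites updated above). The lengths-to-splits conversion is just a cumulative sum with a leading zero; a NumPy sketch:

```python
import numpy as np

def lengths_to_splits(lengths):
    # Row i occupies values[splits[i]:splits[i + 1]].
    return np.concatenate([[0], np.cumsum(lengths)])

print(lengths_to_splits([2, 0, 3]))  # [0 2 2 5]
```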
tensorflow/python/ops/ragged/ragged_elementwise_ops.py
@@ -28,7 +28,7 @@ from tensorflow.python.ops import parsing_ops
 from tensorflow.python.ops import string_ops
 from tensorflow.python.ops.ragged import ragged_factory_ops
 from tensorflow.python.ops.ragged import ragged_tensor
-from tensorflow.python.ops.ragged import ragged_util
+from tensorflow.python.ops.ragged import ragged_tensor_shape
 from tensorflow.python.util import tf_decorator
 from tensorflow.python.util import tf_export
 from tensorflow.python.util import tf_inspect
@@ -209,28 +209,45 @@ def _broadcast_elementwise_args(elementwise_args):
   if not any(is_ragged):
     return elementwise_args, (), ()
 
-  # Support limited broadcasting (namely, scalar + ragged).  Full
-  # broadcasting support will be added later.
-  if all((ragged_tensor.is_ragged(t) or t.shape.ndims == 0)
-         for t in elementwise_args.values()):
+  # If we have a single ragged tensor plus a set of scalars, then we can
+  # rely on the underlying elementwise op to do broadcasting.
+  if (sum(is_ragged) == 1 and
+      all((ragged_tensor.is_ragged(t) or t.shape.ndims == 0)
+          for t in elementwise_args.values())):
     nested_splits_lists = [
         t.nested_row_splits
         for t in elementwise_args.values()
-        if ragged_tensor.is_ragged(t)
-    ]
-    if len(nested_splits_lists) == 1:
-      checks = ()
-    else:
-      if any(t.shape.ndims is None for t in elementwise_args.values()):
-        raise ValueError('Ragged elementwise ops require that rank (number '
-                         'of dimensions) be statically known.')
-      if len(set(t.shape.ndims for t in elementwise_args.values())) != 1:
-        raise ValueError('Ragged elementwise ops do not support '
-                         'broadcasting yet')
-      checks = ragged_util.assert_splits_match(nested_splits_lists)
-    return (elementwise_args, nested_splits_lists[0], checks)
+        if ragged_tensor.is_ragged(t)][0]
+    return elementwise_args, nested_splits_lists, ()
   else:
-    raise ValueError('Ragged elementwise ops do not support broadcasting yet')
+    # Get the shapes of all the elementwise arguments.
+    shapes = [ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(t)
+              for t in elementwise_args.values()]
+
+    # Broadcast the shapes to all have the same rank (the max rank).
+    ranks = [t.shape.ndims for t in elementwise_args.values()]
+    if any(rank is None for rank in ranks):
+      raise ValueError('Unable to broadcast: unknown rank')
+    broadcast_rank = max(ranks)
+    shapes = [shape.broadcast_to_rank(broadcast_rank) for shape in shapes]
+
+    # For each dimension, broadcast the shapes to be compatible.
+    for axis in range(broadcast_rank):
+      # For each i, broadcast shape[i+1] to be compatible with shape[i]; and
+      # then finally broadcast shape[0] to be compatible with shape[-1].
+      for i in range(len(shapes)):
+        j = (i + 1) % len(shapes)
+        dim_size = shapes[i].dimension_size(axis)
+        shapes[j] = shapes[j].broadcast_dimension(axis, dim_size)
+    broadcast_shape = shapes[0]
+
+    # Broadcast every elementwise arg to the shape that we calculated.
+    elementwise_args = dict([
+        (key, ragged_tensor_shape.broadcast_to(t, broadcast_shape, False))
+        for (key, t) in elementwise_args.items()])
+    nested_splits_lists = list(elementwise_args.values())[0].nested_row_splits
+    return elementwise_args, nested_splits_lists, ()
 
 
 # A list of symbols that should be exported in the "ragged" package.
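In the new general path, each argument's shape is broadcast against its predecessor's, and the loop wraps around so that `shapes[0]` is updated last; by then it has accumulated the constraints of every argument, which is why `shapes[0]` is taken as the final broadcast shape. A simplified sketch with plain integer dimension sizes (no ragged dimensions; `broadcast_dim` is invented for this note):

```python
def broadcast_dim(a, b):
    # Standard broadcasting for one axis: sizes must match or one must be 1.
    if a == b or a == 1:
        return b
    if b == 1:
        return a
    raise ValueError('incompatible sizes: %r vs %r' % (a, b))

sizes = [1, 1, 4]  # one axis's dimension size for each of three args
for i in range(len(sizes)):
    j = (i + 1) % len(sizes)
    sizes[j] = broadcast_dim(sizes[j], sizes[i])
print(sizes[0])  # 4 -- the last-updated entry holds the broadcast size
```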
tensorflow/python/ops/ragged/ragged_elementwise_ops_test.py
@@ -399,44 +399,37 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
     y = ragged.from_row_splits(
         array_ops.placeholder_with_default([1, 2, 3], shape=None), x.row_splits)
     with self.assertRaisesRegexp(
-        ValueError, r'Ragged elementwise ops require that rank \(number '
-        r'of dimensions\) be statically known.'):
+        ValueError, r'Unable to broadcast: unknown rank'):
       ragged.add(x, y)
 
-  def testBroadcastError1(self):
-    x = ragged.constant([[1, 2], [3]])
-    y = [[12]]
-    with self.assertRaisesRegexp(
-        ValueError, 'Ragged elementwise ops do not support broadcasting yet'):
-      ragged.add(x, y)
-
-  def testBroadcastError2(self):
-    x = ragged.constant([[[1, 2], [3, 4]], [[5]]], ragged_rank=2)
-    y = ragged.constant([[[8], [3]], [[2]]], ragged_rank=1)
-    with self.assertRaisesRegexp(ValueError,
-                                 'Inputs must have identical ragged splits'):
-      ragged.add(x, y)
-
-  def testBroadcastError3(self):
-    x = ragged.constant([[[1, 2], [3]], [[4, 5], [6]]], ragged_rank=2)
-    y = ragged.constant([[7, 8], [9]], ragged_rank=1)
-    with self.assertRaisesRegexp(
-        ValueError, 'Ragged elementwise ops do not support broadcasting yet'):
-      ragged.add(x, y)
-
-  def testBroadcastError4(self):
-    x = ragged.constant([[[1]]])
-    y = ragged.constant([[1]])
-    with self.assertRaisesRegexp(
-        ValueError, 'Ragged elementwise ops do not support broadcasting yet'):
-      ragged.add(x, y)
+  @parameterized.parameters([
+      dict(
+          x=ragged.constant_value([[1, 2], [3]]),
+          y=[[10]],
+          expected=[[11, 12], [13]]),
+      dict(
+          x=ragged.constant_value([[[1, 2], [3, 4]], [[5]]], ragged_rank=2),
+          y=ragged.constant_value([[[10], [20]], [[30]]], ragged_rank=1),
+          expected=[[[11, 12], [23, 24]], [[35]]]),
+      dict(
+          x=ragged.constant_value([[[1]]]),
+          y=ragged.constant_value([[1]]),
+          expected=[[[2]]]),
+  ])
+  def testBroadcastAdd(self, x, y, expected):
+    x = ragged.convert_to_tensor_or_ragged_tensor(x, dtype=dtypes.int32)
+    y = ragged.convert_to_tensor_or_ragged_tensor(y, dtype=dtypes.int32)
+    result = x + y
+    with self.cached_session():
+      self.assertEqual(result.eval().tolist(), expected)
 
   def testShapeMismatch(self):
     x = ragged.constant([[1, 2, 3], [4, 5]])
     y = ragged.constant([[1, 2, 3], [4, 5, 6]])
     with self.assertRaisesRegexp(errors.InvalidArgumentError,
-                                 'Inputs must have identical ragged splits'):
-      ragged.add(x, y)
+                                 'Incompatible shapes'):
+      with self.cached_session():
+        ragged.add(x, y).eval()
 
   def testDocstring(self):
     self.assertRegexpMatches(
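The removed "broadcasting not supported" tests come back as positive `testBroadcastAdd` cases. The first case in plain Python (illustrative; `[[10]]` has shape `[1, 1]`, so its single element is added to every value of the ragged operand):

```python
x = [[1, 2], [3]]  # ragged operand, row lengths [2, 1]
y = 10             # the single element of [[10]]
print([[v + y for v in row] for row in x])  # [[11, 12], [13]]
```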
tensorflow/python/ops/ragged/ragged_factory_ops.py
@@ -676,3 +676,33 @@ def from_nested_row_splits(inner_values, nested_row_splits, name=None):
   for splits in reversed(nested_row_splits):
     result = from_row_splits(result, splits)
   return result
+
+
+def from_nested_row_lengths(inner_values, nested_row_lengths, name=None):
+  """Creates a `RaggedTensor` from a nested list of `row_lengths` tensors.
+
+  Equivalent to:
+
+  ```python
+  result = inner_values
+  for row_lengths in reversed(nested_row_lengths):
+    result = from_row_lengths(result, row_lengths)
+  ```
+
+  Args:
+    inner_values: A potentially ragged tensor.
+    nested_row_lengths: A list of 1-D int64 tensors.  The `i`th tensor is used
+      as the `row_lengths` for the `i`th ragged dimension.
+    name: A name prefix for the RaggedTensor (optional).
+
+  Returns:
+    A `RaggedTensor` (or `inner_values` if `nested_row_lengths` is empty).
+  """
+  if isinstance(nested_row_lengths, ops.Tensor):
+    raise TypeError('nested_row_lengths must be a list of Tensors')
+  with ops.name_scope(name, 'RaggedFromNestedRowlengths',
+                      [inner_values] + list(nested_row_lengths)):
+    result = inner_values
+    for lengths in reversed(nested_row_lengths):
+      result = from_row_lengths(result, lengths)
+    return result
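A list-based sketch of the equivalence stated in the new docstring — the inverse of the `nested_row_lengths` sketch earlier in this diff (illustrative; the real function builds `RaggedTensor`s rather than lists):

```python
def from_row_lengths(values, row_lengths):
    """Partitions `values` into rows with the given lengths."""
    rows, start = [], 0
    for n in row_lengths:
        rows.append(values[start:start + n])
        start += n
    return rows

result = [1, 2, 3, 4, 5, 6]
for lengths in reversed([[2, 1], [2, 1, 3]]):
    result = from_row_lengths(result, lengths)
print(result)  # [[[1, 2], [3]], [[4, 5, 6]]]
```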
tensorflow/python/ops/ragged/ragged_tensor.py
@@ -257,6 +257,7 @@ class RaggedTensor(object):
       raise TypeError("Row-partitioning argument must be a Tensor.")
     values.shape.with_rank_at_least(1)
     row_splits.shape.assert_has_rank(1)
+    row_splits.set_shape([None])
 
     self._values = values
     self._row_splits = row_splits
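The added `set_shape([None])` records statically that `row_splits` is a vector even when nothing else is known about its shape. A minimal sketch, assuming TF 1.x graph mode (the era of this commit):

```python
import tensorflow as tf

splits = tf.placeholder(tf.int64, shape=None)  # static shape fully unknown
splits.set_shape([None])                       # now statically a 1-D vector
print(splits.shape)                            # (?,)
```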
tensorflow/python/ops/ragged/ragged_tensor_shape.py (new file, 570 lines added)
@@ -0,0 +1,570 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Shapes & broadcasting for RaggedTensors."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.ragged import ragged_array_ops
+from tensorflow.python.ops.ragged import ragged_conversion_ops
+from tensorflow.python.ops.ragged import ragged_factory_ops
+from tensorflow.python.ops.ragged import ragged_tensor
+from tensorflow.python.ops.ragged import ragged_util
+
+
+class RaggedTensorDynamicShape(object):
+  """A collection of tensors encoding the shape of a potentially ragged tensor.
+
+  Each `RaggedTensorDynamicShape` consists of an ordered list of dimension
+  sizes.  There are two dimension types:
+
+    * "Uniform dimensions" are dimensions where all slices have the same
+      length.  `RaggedTensorDynamicShape` records the size of each uniform
+      dimension using a single scalar integer.
+
+    * "Ragged dimensions" are dimensions whose slices may have different
+      lengths.  `RaggedTensorDynamicShape` records the size of each ragged
+      dimension using an integer vector containing the slice lengths for all
+      the slices across that dimension.
+
+  Furthermore, there are two ways a dimension might be encoded:
+
+    * "Partitioned dimensions" are dimensions that are encoded using a
+      `RaggedTensor`'s `nested_row_splits`.  The outermost partitioned
+      dimension must be uniform, and the innermost partitioned dimension must
+      be ragged.
+
+    * "Inner dimensions" are dimensions that are encoded using a
+      `RaggedTensor`'s `inner_values`.  Inner dimensions are always uniform.
+
+  The sizes of these dimensions are recorded using `partitioned_dim_sizes`
+  and `inner_dim_sizes`:
+
+    * `partitioned_dim_sizes` is a list of tensors (one for each partitioned
+      dimension).
+
+      * For uniform dimensions, the tensor is an integer scalar specifying the
+        size of all slices across that dimension.
+      * For ragged dimensions, the tensor is an integer vector specifying the
+        size of each slice across that dimension.
+
+    * `inner_dim_sizes` is a single integer vector, where each element
+      specifies the size of a single inner dimension.
+
+  Examples:
+
+  Tensor                         | Ragged | Partitioned Dim Sizes  | Inner Dim
+                                 : Rank   :                        : Sizes
+  ------------------------------ | ------ | ---------------------- | ----------
+  `[[1, 2, 3], [4, 5, 6]]`       |      0 |                        | `2, 3`
+  `[[1, 2], [], [3, 4, 5]]`      |      1 | `3, (2, 0, 3)`         |
+  `[[[1, 2], [3, 4]], [[5, 6]]]` |      1 | `2, (2, 1)`            | 2
+  `[[[1, 2], [3]], [[4, 5]]]`    |      2 | `2, (2, 1), (2, 1, 2)` |
+  """
+
+  def __init__(self, partitioned_dim_sizes, inner_dim_sizes):
+    """Creates a RaggedTensorDynamicShape.
+
+    Args:
+      partitioned_dim_sizes: A `list` of 0-D or 1-D integer `Tensor`, one for
+        each partitioned dimension.  If dimension `d` is uniform, then
+        `partitioned_dim_sizes[d]` must be an integer scalar, specifying the
+        size of all slices across dimension `d`.  If dimension `d` is ragged,
+        then `partitioned_dim_sizes[d]` must be an integer vector, specifying
+        the size of each slice across dimension `d`.
+      inner_dim_sizes: A 1-D integer `Tensor`, whose length is equal to the
+        number of inner dimensions.  `inner_dim_sizes[n]` is the size of all
+        slices across the `n`th inner dimension (which is the
+        `(len(partitioned_dim_sizes)+n)`th dimension in the overall tensor).
+    """
+    assert isinstance(partitioned_dim_sizes, (list, tuple))
+    with ops.name_scope(None, 'RaggedTensorDynamicShape',
+                        (partitioned_dim_sizes, inner_dim_sizes)):
+      partitioned_dim_sizes = tuple(
+          ragged_util.convert_to_int_tensor(
+              size, dtype=dtypes.int64, name='partitioned_dimension_size')
+          for size in partitioned_dim_sizes)
+      inner_dim_sizes = ragged_util.convert_to_int_tensor(
+          inner_dim_sizes, dtype=dtypes.int64, name='inner_dim_sizes')
+
+      # Validate shapes.
+      if partitioned_dim_sizes:
+        for axis, dimension_size in enumerate(partitioned_dim_sizes):
+          if dimension_size.shape.ndims is None:
+            raise ValueError(
+                'rank of partitioned_dim_sizes[%d] is unknown' % axis)
+          dimension_size.shape.with_rank_at_most(1)
+        if partitioned_dim_sizes[0].shape.ndims == 1:
+          raise ValueError('outermost partitioned dimension must be uniform')
+        if partitioned_dim_sizes[-1].shape.ndims == 0:
+          raise ValueError('innermost partitioned dimension must be ragged')
+      inner_dim_sizes.shape.assert_has_rank(1)
+
+      self._partitioned_dim_sizes = partitioned_dim_sizes
+      self._inner_dim_sizes = inner_dim_sizes
+
+  def __repr__(self):
+    return ('RaggedTensorDynamicShape'
+            '(partitioned_dim_sizes=%r, inner_dim_sizes=%r)' %
+            (self._partitioned_dim_sizes, self._inner_dim_sizes))
+
+  @staticmethod
+  def from_dim_sizes(dim_sizes):
+    """Constructs a ragged shape from a list of dimension sizes.
+
+    This list contains a single tensor for each dimension, where the tensor
+    is a scalar if the dimension is uniform, or a vector if the dimension is
+    ragged.
+
+    Args:
+      dim_sizes: List of int64 scalars or vectors.
+
+    Returns:
+      A RaggedTensorDynamicShape.
+    """
+    with ops.name_scope(None, 'RaggedTensorDynamicShapeFromDimensionSizes',
+                        [dim_sizes]):
+      dim_sizes = tuple(
+          ragged_util.convert_to_int_tensor(
+              size, dtype=dtypes.int64, name='dim_sizes') for size in dim_sizes)
+      # Split the dimensions into partitioned & inner dimensions.
+      inner_split = 0
+      for dim, dim_size in enumerate(dim_sizes):
+        if dim_size.shape.ndims == 1:
+          inner_split = dim + 1
+        elif dim_size.shape.ndims != 0:
+          raise ValueError('Each dim_size must be a scalar or a vector')
+      return RaggedTensorDynamicShape(dim_sizes[:inner_split],
+                                      dim_sizes[inner_split:])
+
+  @classmethod
+  def from_tensor(cls, rt_input):
+    """Constructs a ragged shape for a potentially ragged tensor."""
+    with ops.name_scope(None, 'RaggedTensorDynamicShapeFromTensor', [rt_input]):
+      rt_input = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(rt_input)
+      if not ragged_tensor.is_ragged(rt_input):
+        return cls([], array_ops.shape(rt_input))
+      else:
+        partitioned_dim_sizes = ((ragged_array_ops.nrows(rt_input),) +
+                                 ragged_array_ops.nested_row_lengths(rt_input))
+        return RaggedTensorDynamicShape(
+            partitioned_dim_sizes,
+            array_ops.shape(rt_input.inner_values)[1:])
+
+  def dimension_size(self, axis):
+    """Returns the size of slices across the specified dimension."""
+    if not isinstance(axis, int):
+      raise TypeError('axis must be an integer')
+    partitioned_ndims = len(self._partitioned_dim_sizes)
+    if axis < partitioned_ndims:
+      return self._partitioned_dim_sizes[axis]
+    else:
+      return self._inner_dim_sizes[axis - partitioned_ndims]
+
+  def is_ragged(self, axis):
+    """Returns true if the indicated dimension is ragged."""
+    if not isinstance(axis, int):
+      raise TypeError('axis must be an integer')
+    rank = self.rank
+    if axis < 0:
+      raise ValueError('Negative axis values are not supported')
+    elif rank is not None and axis >= rank:
+      raise ValueError('Expected axis=%s < rank=%s' % (axis, rank))
+    else:
+      return (axis > 0 and axis < len(self._partitioned_dim_sizes) and
+              self._partitioned_dim_sizes[axis].shape.ndims == 1)
+
+  @property
+  def rank(self):
+    """The number of dimensions in this shape, or None if unknown."""
+    inner_ndims = self._inner_dim_sizes.shape[0].value
+    if inner_ndims is None:
+      return None
+    else:
+      return len(self._partitioned_dim_sizes) + inner_ndims
+
+  @property
+  def partitioned_dim_sizes(self):
+    """The partitioned dimension sizes for this shape.
+
+    Returns:
+      A `list` of 0-D or 1-D integer `Tensor`.
+    """
+    return self._partitioned_dim_sizes
+
+  @property
+  def inner_dim_sizes(self):
+    """The inner dimension sizes for this shape.
+
+    Returns:
+      A 1-D integer `Tensor`.
+    """
+    return self._inner_dim_sizes
+
+  @property
+  def num_partitioned_dimensions(self):
+    """The number of partitioned dimensions in this shape."""
+    return len(self._partitioned_dim_sizes)
+
+  @property
+  def num_inner_dimensions(self):
+    """The number of inner dimensions, or `None` if not statically known."""
+    return self._inner_dim_sizes.shape[0].value
+
+  def broadcast_to_rank(self, rank):
+    """Adds leading size-1 dimensions to broadcast `self` to the given rank.
+
+    E.g., if `shape1` is `[3, (D2), 4]`, then `shape1.broadcast_to_rank(5)`
+    is `[1, 1, 3, (D2), 4]`.
+
+    Args:
+      rank: The rank for the returned shape.
+
+    Returns:
+      A RaggedTensorDynamicShape with `rank` dimensions, whose inner dimensions
+      have the same size as `self` and whose outer dimensions have size `1`.
+
+    Raises:
+      ValueError: If `self.rank` is unknown or greater than `rank`.
+    """
+    if self.rank is None:
+      raise ValueError('Unable to broadcast: self.rank is unknown')
+    dims_to_add = rank - self.rank
+    if dims_to_add < 0:
+      raise ValueError('Unable to broadcast: rank=%d must be greater than '
+                       'self.rank=%d.' % (rank, self.rank))
+    elif dims_to_add == 0:
+      return self
+    elif self._partitioned_dim_sizes:
+      partitioned_dims = (1,) * dims_to_add + self._partitioned_dim_sizes
+      return RaggedTensorDynamicShape(partitioned_dims, self._inner_dim_sizes)
+    else:
+      inner_dims = array_ops.concat(
+          [array_ops.ones([dims_to_add], dtypes.int64), self.inner_dim_sizes],
+          axis=0)
+      return RaggedTensorDynamicShape([], inner_dims)
+
+  def broadcast_dimension(self, axis, lengths):
+    """Returns a shape that is broadcast-compatible with self & lengths.
+
+    * If dimension[axis] is uniform and lengths is a scalar, then check
+      that either lengths==1 or dimension[axis]==1 or
+      lengths==dimension[axis], and tile dimension[axis] with
+      tf.where(dimension[axis]==1, lengths, 1) repeats.
+
+    * If dimension[axis] is uniform and lengths is a vector, then check
+      that dimension[axis]==1, and raggedly tile dimension[axis] with
+      lengths repeats.  (we can skip tiling if we statically know that
+      slice_lengths == 1??)
+
+    * If dimension[axis] is ragged and lengths is a scalar, then check
+      that lengths==1.
+
+    * If dimension[axis] is ragged and lengths is a vector, then check
+      that self.dimension_size(axis) == lengths.
+
+    Args:
+      axis: `int`.  The dimension to broadcast.
+      lengths: 0-D or 1-D integer `Tensor`.
+
+    Returns:
+      A `RaggedTensorDynamicShape`.
+    """
+    lengths = ragged_util.convert_to_int_tensor(
+        lengths, name='lengths', dtype=dtypes.int64)
+    # Check whether lengths is a scalar (for uniform dimensions) or
+    # vector (for ragged dimensions).
+    if lengths.shape.ndims is None:
+      raise ValueError('lengths must have a known rank.')
+    elif lengths.shape.ndims > 1:
+      raise ValueError('lengths must be a scalar or vector')
+    else:
+      lengths_is_scalar = (lengths.shape.ndims == 0)
+
+    # Verify that the shapes are compatible.
+    if self.is_ragged(axis):
+      if lengths_is_scalar:
+        condition = math_ops.equal(lengths, 1)
+      else:
+        condition = math_ops.reduce_all(
+            math_ops.equal(lengths, self.dimension_size(axis)))
+    else:
+      axis_dim_size = self.dimension_size(axis)
+      if lengths_is_scalar:
+        condition = (
+            math_ops.equal(lengths, 1) | math_ops.equal(axis_dim_size, 1)
+            | math_ops.equal(axis_dim_size, lengths))
+      else:
+        condition = math_ops.equal(axis_dim_size, 1)
+    broadcast_err = [
+        'Unable to broadcast: dimension size mismatch in dimension', axis,
+        'lengths=', lengths, 'dim_size=',
+        self.dimension_size(axis)
+    ]
+    broadcast_check = control_flow_ops.Assert(
+        condition, data=broadcast_err, summarize=10)
+
+    with ops.control_dependencies([broadcast_check]):
+      # Partitioned dimensions:
+      if axis < self.num_partitioned_dimensions:
+        if self.is_ragged(axis):
+          # Use an identity op to make sure the check actually gets run.
+          return RaggedTensorDynamicShape(
+              self._partitioned_dim_sizes,
+              array_ops.identity(self.inner_dim_sizes))
+        else:
+          return self._broadcast_uniform_partitioned_dimension(axis, lengths)
+
+      # Inner dimensions:
+      else:
+        if lengths_is_scalar:
+          return self._broadcast_inner_dimension_to_uniform(axis, lengths)
+        else:
+          if axis == 0:
+            raise ValueError('Unable to broadcast: '
+                             'outermost dimension must be uniform.')
+          return self._broadcast_inner_dimension_to_ragged(axis, lengths)
+
+  def num_slices_in_dimension(self, axis):
+    """Returns the total number of slices across the indicated dimension."""
+    if axis < 0:
+      return constant_op.constant(1, dtype=dtypes.int64)
+    elif self.is_ragged(axis):
+      return math_ops.reduce_sum(self._partitioned_dim_sizes[axis])
+    else:
+      return self.dimension_size(axis) * self.num_slices_in_dimension(axis - 1)
+
+  def _broadcast_uniform_partitioned_dimension(self, axis, lengths):
+    """Broadcasts the partitioned dimension `axis` to match `lengths`."""
+    axis_dim_size = self.dimension_size(axis)
+    partitioned_sizes = list(self._partitioned_dim_sizes[:axis])
+
+    if lengths.shape.ndims == 0:
+      lengths = array_ops.where(
+          math_ops.equal(axis_dim_size, 1), lengths, axis_dim_size)
+      repeats = array_ops.where(math_ops.equal(axis_dim_size, 1), lengths, 1)
+      splits = array_ops.stack([0, self.num_slices_in_dimension(axis)])
+    else:
+      splits = math_ops.range(
+          array_ops.size(lengths, out_type=dtypes.int64) + 1)
+      repeats = lengths
+
+    partitioned_sizes.append(lengths)
+
+    for dim_size in self._partitioned_dim_sizes[axis + 1:]:
+      if dim_size.shape.ndims == 0:
+        partitioned_sizes.append(dim_size)
+        splits *= dim_size
+      else:
+        partitioned_sizes.append(
+            ragged_util.repeat_ranges(dim_size, splits, repeats))
+        splits = array_ops.gather(
+            ragged_util.lengths_to_splits(dim_size), splits)
+    inner_sizes = self._inner_dim_sizes
+    return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes)
+
+  def _broadcast_inner_dimension_to_uniform(self, axis, length):
+    """Broadcasts the inner dimension `axis` to match `length`."""
+    dim_size = self.dimension_size(axis)
+    axis_in_inner_dims = axis - self.num_partitioned_dimensions
+    partitioned_sizes = self._partitioned_dim_sizes
+    inner_sizes = array_ops.concat([
+        self._inner_dim_sizes[:axis_in_inner_dims],
+        [array_ops.where(math_ops.equal(dim_size, 1), length, dim_size)],
+        self._inner_dim_sizes[axis_in_inner_dims + 1:]
+    ], axis=0)
+    return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes)
+
+  def _broadcast_inner_dimension_to_ragged(self, axis, lengths):
+    axis_in_inner_dims = axis - self.num_partitioned_dimensions
+    partitioned_sizes = (
+        self._partitioned_dim_sizes + tuple([
+            self._inner_dim_sizes[i] for i in range(axis_in_inner_dims)
+        ]) + (lengths,))
+    inner_sizes = self._inner_dim_sizes[axis_in_inner_dims + 1:]
+    return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes)
+
+
+def broadcast_dynamic_shape(shape_x, shape_y):
+  """Returns the shape formed by broadcasting two shapes to be compatible.
+
+  Args:
+    shape_x: A `RaggedTensorDynamicShape`
+    shape_y: A `RaggedTensorDynamicShape`
+
+  Returns:
+    A `RaggedTensorDynamicShape`.
+
+  Raises:
+    ValueError: If `shape_x` and `shape_y` are not broadcast-compatible.
+  """
+  if not isinstance(shape_x, RaggedTensorDynamicShape):
+    raise TypeError('shape_x must be a RaggedTensorDynamicShape')
+  if not isinstance(shape_y, RaggedTensorDynamicShape):
+    raise TypeError('shape_y must be a RaggedTensorDynamicShape')
+
+  # Broadcast both shapes to have the same rank.
+  if shape_x.rank is None or shape_y.rank is None:
+    raise ValueError('Unable to broadcast: unknown rank')
+  broadcast_rank = max(shape_x.rank, shape_y.rank)
+  shape_x = shape_x.broadcast_to_rank(broadcast_rank)
+  shape_y = shape_y.broadcast_to_rank(broadcast_rank)
+
+  # Broadcast dimensions one at a time, starting from the outermost dimension.
+  for axis in range(broadcast_rank):
+    shape_x = shape_x.broadcast_dimension(axis, shape_y.dimension_size(axis))
+    shape_y = shape_y.broadcast_dimension(axis, shape_x.dimension_size(axis))
+
+  return shape_x
+
+
+def broadcast_to(rt_input, shape, broadcast_inner_dimensions=True):
+  """Broadcasts a potentially ragged tensor to a ragged shape.
+
+  Tiles `rt_input` as necessary to match the given shape.
+
+  Behavior is undefined if `rt_input` is not broadcast-compatible with `shape`.
+
+  Args:
+    rt_input: The potentially ragged tensor to broadcast.
+    shape: A `RaggedTensorDynamicShape`
+    broadcast_inner_dimensions: If false, then inner dimensions will not be
+      tiled.
+
+  Returns:
+    A potentially ragged tensor whose values are taken from
+    `rt_input`, and whose shape matches `shape`.
+  """
+  if not isinstance(shape, RaggedTensorDynamicShape):
+    raise TypeError('shape must be a RaggedTensorDynamicShape')
+  rt_input = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(rt_input)
+
+  # Broadcasting to a uniform shape.
+  if shape.num_partitioned_dimensions == 0:
+    return _broadcast_to_uniform_shape(rt_input, shape,
+                                       broadcast_inner_dimensions)
+  else:
+    return _broadcast_to_ragged_shape(rt_input, shape,
+                                      broadcast_inner_dimensions)
+
+
+def _broadcast_to_uniform_shape(rt_input, shape, broadcast_inner_dimensions):
+  """Broadcasts rt_input to the uniform shape `shape`."""
+  if isinstance(rt_input, ragged_tensor.RaggedTensor):
+    raise ValueError('Incompatible with shape: ragged rank mismatch')
+  if broadcast_inner_dimensions:
+    return array_ops.broadcast_to(rt_input, shape.inner_dim_sizes)
+  else:
+    return rt_input
+
+
+def _broadcast_to_ragged_shape(rt_input, dst_shape, broadcast_inner_dimensions):
+  """Broadcasts rt_input to the ragged shape `dst_shape`."""
+  # dst_shape's rank and ragged_rank must be greater than or equal to rt_input's
+  if rt_input.shape.ndims is None or dst_shape.rank is None:
+    raise ValueError('Unable to broadcast: unknown rank')
+  if rt_input.shape.ndims > dst_shape.rank:
+    raise ValueError('Incompatible with shape: rank mismatch')
+  if (isinstance(rt_input, ragged_tensor.RaggedTensor) and
+      rt_input.ragged_rank >= dst_shape.num_partitioned_dimensions):
+    raise ValueError('Incompatible with shape: ragged rank mismatch')
+
+  src_shape = RaggedTensorDynamicShape.from_tensor(rt_input)
+  src_shape = src_shape.broadcast_to_rank(dst_shape.rank)
+
+  # Add dimensions to rt_input so its rank and ragged_rank matches dst_shape.
+  if dst_shape.rank > rt_input.shape.ndims:
+    if rt_input.shape.ndims < dst_shape.num_inner_dimensions + 1:
+      rt_input = array_ops.reshape(
+          rt_input, array_ops.concat([[-1], dst_shape.inner_dim_sizes], axis=0))
+    for _ in range(dst_shape.rank - rt_input.shape.ndims):
+      rt_input = ragged_factory_ops.from_row_lengths(
+          rt_input, [ragged_array_ops.nrows(rt_input)])
+
+  # Add ragged dimensions to match dst_shape.
+  if ragged_tensor.is_ragged(rt_input):
+    inner_rank_diff = (
+        rt_input.inner_values.shape.ndims - 1 - dst_shape.num_inner_dimensions)
+    if inner_rank_diff > 0:
+      rt_input = rt_input.with_inner_values(
+          ragged_conversion_ops.from_tensor(
+              rt_input.inner_values, ragged_rank=inner_rank_diff))
+  else:
+    rt_input = ragged_conversion_ops.from_tensor(
+        rt_input, ragged_rank=dst_shape.num_partitioned_dimensions - 1)
+
+  # Do broadcasting for any dimensions that will remain uniform.  We can do
+  # these all at once, since they're independent of one another.
+  multiples = [1] * dst_shape.rank
+  for axis in range(dst_shape.num_partitioned_dimensions):
+    if not src_shape.is_ragged(axis) and not dst_shape.is_ragged(axis):
+      src_size = src_shape.dimension_size(axis)
+      dst_size = dst_shape.dimension_size(axis)
+      if ((tensor_util.constant_value(src_size) in (1, None)) and
+          (tensor_util.constant_value(dst_size) != 1)):
+        multiples[axis] = array_ops.where(
+            math_ops.equal(src_size, 1), dst_size, 1)
+  if not all(isinstance(v, int) and v == 1 for v in multiples):
+    multiples = array_ops.stack(multiples, axis=0)
+    rt_input = ragged_array_ops.tile(rt_input, multiples)
+
+  if broadcast_inner_dimensions:
+    rt_input = rt_input.with_inner_values(
+        array_ops.reshape(
+            rt_input.inner_values,
+            array_ops.concat([[-1], dst_shape.inner_dim_sizes], axis=0)))
+
+  # Do broadcasting for dimensions that become ragged.  We must do these from
+  # outermost to innermost.
+  for axis in range(dst_shape.num_partitioned_dimensions):
+    if not src_shape.is_ragged(axis) and dst_shape.is_ragged(axis):
+      dst_size = dst_shape.dimension_size(axis)
+      rt_input = _ragged_tile_axis(rt_input, axis, dst_size)
+
+  return rt_input
+
+
+def _ragged_tile_axis(rt_input, axis, repeats):
+  """Tile a dimension of a RaggedTensor to match a ragged shape."""
+  assert axis > 0  # Outermost dimension may not be ragged.
+
+  if not ragged_tensor.is_ragged(rt_input):
+    rt_input = ragged_conversion_ops.from_tensor(rt_input, ragged_rank=1)
+
+  if axis > 1:
+    return rt_input.with_values(
+        _ragged_tile_axis(rt_input.values, axis - 1, repeats))
+  else:
+    src_row_splits = rt_input.nested_row_splits
+    src_row_lengths = ragged_array_ops.nested_row_lengths(rt_input)
+    splits = src_row_splits[0]
+
+    dst_row_lengths = [repeats]
+    for i in range(1, len(src_row_lengths)):
+      dst_row_lengths.append(
+          ragged_util.repeat_ranges(src_row_lengths[i], splits, repeats))
+      splits = array_ops.gather(src_row_splits[i], splits)
+    dst_values = ragged_util.repeat_ranges(rt_input.inner_values, splits,
+                                           repeats)
+    return ragged_factory_ops.from_nested_row_lengths(dst_values,
+                                                      dst_row_lengths)
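A worked example of the encoding described in the class docstring, computed with plain Python rather than the new class. For `[[[1, 2], [3, 4]], [[5, 6]]]` with `ragged_rank=1`, the partitioned dimensions are the uniform outer dimension plus one ragged dimension, and the trailing uniform dimension of the flat values is an inner dimension:

```python
value = [[[1, 2], [3, 4]], [[5, 6]]]
outer = len(value)                         # 2 (uniform partitioned dim)
row_lengths = [len(row) for row in value]  # [2, 1] (ragged partitioned dim)
inner = len(value[0][0])                   # 2 (uniform inner dim)
print((outer, row_lengths), [inner])       # (2, [2, 1]) [2]
```

This matches the docstring's table row `2, (2, 1)` / `2`, and is what `RaggedTensorDynamicShape.from_tensor` encodes as `partitioned_dim_sizes` and `inner_dim_sizes`.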
tensorflow/python/ops/ragged/ragged_tensor_shape_test.py (new file, 487 lines added; listing truncated below)
@@ -0,0 +1,487 @@
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for tf.ragged.ragged_tensor_shape."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.ops import ragged
|
||||||
|
from tensorflow.python.platform import googletest
|
||||||
|
|
||||||
|
|
||||||
|
class RaggedTensorShapeTest(test_util.TensorFlowTestCase,
|
||||||
|
parameterized.TestCase):
|
||||||
|
|
||||||
|
def assertShapeEq(self, x, y):
|
||||||
|
assert isinstance(x, ragged.RaggedTensorDynamicShape)
|
||||||
|
assert isinstance(y, ragged.RaggedTensorDynamicShape)
|
||||||
|
x_partitioned_dim_sizes = [
|
||||||
|
splits.eval().tolist() #
|
||||||
|
for splits in x.partitioned_dim_sizes
|
||||||
|
]
|
||||||
|
y_partitioned_dim_sizes = [
|
||||||
|
splits.eval().tolist() #
|
||||||
|
for splits in y.partitioned_dim_sizes
|
||||||
|
]
|
||||||
|
self.assertEqual(x_partitioned_dim_sizes, y_partitioned_dim_sizes)
|
||||||
|
self.assertEqual(x.inner_dim_sizes.eval().tolist(),
|
||||||
|
y.inner_dim_sizes.eval().tolist())
|
||||||
|
|
||||||
|
@parameterized.parameters([
|
||||||
|
dict(value='x', expected_dim_sizes=[]),
|
||||||
|
dict(value=['a', 'b', 'c'], expected_dim_sizes=[3]),
|
||||||
|
dict(value=[['a', 'b', 'c'], ['d', 'e', 'f']], expected_dim_sizes=[2, 3]),
|
||||||
|
dict(
|
||||||
|
value=[[['a', 'b', 'c'], ['d', 'e', 'f']]],
|
||||||
|
expected_dim_sizes=[1, 2, 3]),
|
||||||
|
dict(
|
||||||
|
value=ragged.constant_value([['a', 'b', 'c'], ['d', 'e']]),
|
||||||
|
expected_dim_sizes=[2, [3, 2]]),
|
||||||
|
dict(
|
||||||
|
value=ragged.constant_value([[['a', 'b', 'c'], ['d', 'e']]]),
|
||||||
|
expected_dim_sizes=[1, [2], [3, 2]]),
|
||||||
|
dict(
|
||||||
|
value=ragged.constant_value([[['a', 'b', 'c'], ['d', 'e', 'f']]],
|
||||||
|
ragged_rank=1),
|
||||||
|
expected_dim_sizes=[1, [2], 3]),
|
||||||
|
dict(
|
||||||
|
value=ragged.constant_value([[[[1], [2]], [[3], [4]]],
|
||||||
|
[[[5], [6]]]], ragged_rank=1),
|
||||||
|
expected_dim_sizes=[2, [2, 1], 2, 1]),
|
||||||
|
dict(
|
||||||
|
value=ragged.constant_value([[10, 20], [30]]),
|
||||||
|
expected_dim_sizes=[2, [2, 1]]),
|
||||||
|
# Docstring examples:
|
||||||
|
dict(value=[[1, 2, 3], [4, 5, 6]], expected_dim_sizes=[2, 3]),
|
||||||
|
dict(
|
||||||
|
value=ragged.constant_value([[1, 2], [], [3, 4, 5]]),
|
||||||
|
expected_dim_sizes=[3, [2, 0, 3]]),
|
||||||
|
dict(
|
||||||
|
value=ragged.constant_value([[[1, 2], [3, 4]], [[5, 6]]],
|
||||||
|
ragged_rank=1),
|
||||||
|
expected_dim_sizes=[2, [2, 1], 2]),
|
||||||
|
dict(
|
||||||
|
value=ragged.constant_value([[[1, 2], [3]], [[4, 5]]]),
|
||||||
|
expected_dim_sizes=[2, [2, 1], [2, 1, 2]]),
|
||||||
|
])
|
||||||
|
def testFromTensor(self, value, expected_dim_sizes):
|
||||||
|
shape = ragged.RaggedTensorDynamicShape.from_tensor(value)
|
||||||
|
expected = ragged.RaggedTensorDynamicShape.from_dim_sizes(
|
||||||
|
expected_dim_sizes)
|
||||||
|
with self.cached_session():
|
||||||
|
self.assertShapeEq(shape, expected)
|
||||||
|
|
||||||
|
@parameterized.parameters([
|
||||||
|
dict(dim_sizes=[], rank=0, expected_dim_sizes=[]),
|
||||||
|
dict(dim_sizes=[], rank=3, expected_dim_sizes=[1, 1, 1]),
|
||||||
|
dict(dim_sizes=[3], rank=1, expected_dim_sizes=[3]),
|
||||||
|
dict(dim_sizes=[3], rank=3, expected_dim_sizes=[1, 1, 3]),
|
||||||
|
dict(dim_sizes=[2, 3], rank=3, expected_dim_sizes=[1, 2, 3]),
|
||||||
|
dict(dim_sizes=[3, [3, 2, 4]], rank=2, expected_dim_sizes=[3, [3, 2, 4]]),
|
||||||
|
dict(
|
||||||
|
dim_sizes=[3, [3, 2, 4]],
|
||||||
|
rank=4,
|
||||||
|
expected_dim_sizes=[1, 1, 3, [3, 2, 4]]),
|
||||||
|
dict(
|
||||||
|
dim_sizes=[3, [3, 2, 4], 2, 3],
|
||||||
|
rank=5,
|
||||||
|
expected_dim_sizes=[1, 3, [3, 2, 4], 2, 3]),
|
||||||
|
])
|
||||||
|
def testBroadcastToRank(self, dim_sizes, rank, expected_dim_sizes):
|
||||||
|
shape = ragged.RaggedTensorDynamicShape.from_dim_sizes(dim_sizes)
|
||||||
|
expected = ragged.RaggedTensorDynamicShape.from_dim_sizes(
|
||||||
|
expected_dim_sizes)
|
||||||
|
broadcasted_shape = shape.broadcast_to_rank(rank)
|
||||||
|
with self.cached_session():
|
||||||
|
self.assertShapeEq(broadcasted_shape, expected)
|
||||||
|
self.assertEqual(broadcasted_shape.rank, rank)
|
||||||
|
|
||||||
|
  @parameterized.parameters([
      #=========================================================================
      # dimension[axis] is uniform inner; and row_lengths is a scalar
      #=========================================================================
      # shape: [BROADCAST(UNIFORM), UNIFORM, UNIFORM]
      dict(axis=0,
           row_length=3,
           original_dim_sizes=[1, 4, 5],
           broadcast_dim_sizes=[3, 4, 5]),

      # shape: [UNIFORM, UNIFORM, BROADCAST(UNIFORM)]
      dict(axis=2,
           row_length=5,
           original_dim_sizes=[3, 4, 1],
           broadcast_dim_sizes=[3, 4, 5]),

      # shape: [UNIFORM, RAGGED, BROADCAST(UNIFORM)]
      dict(axis=2,
           row_length=5,
           original_dim_sizes=[3, [3, 2, 8], 1],
           broadcast_dim_sizes=[3, [3, 2, 8], 5]),

      # shape: [UNIFORM, RAGGED, RAGGED, UNIFORM, UNIFORM, BROADCAST(UNIFORM)]
      dict(axis=5,
           row_length=5,
           original_dim_sizes=[2, [2, 1], [3, 2, 8], 3, 4, 1],
           broadcast_dim_sizes=[2, [2, 1], [3, 2, 8], 3, 4, 5]),

      #=========================================================================
      # dimension[axis] is uniform inner; and row_lengths is a vector
      #=========================================================================
      # shape: [UNIFORM, BROADCAST(UNIFORM)]
      dict(axis=1,
           row_length=[2, 0, 1],
           original_dim_sizes=[3, 1],
           broadcast_dim_sizes=[3, [2, 0, 1]]),

      # shape: [UNIFORM, BROADCAST(UNIFORM), UNIFORM]
      dict(axis=1,
           row_length=[2, 0, 1],
           original_dim_sizes=[3, 1, 5],
           broadcast_dim_sizes=[3, [2, 0, 1], 5]),

      # shape: [UNIFORM, UNIFORM, BROADCAST(UNIFORM)]
      dict(axis=2,
           row_length=[2, 0, 1, 3, 8, 2, 3, 4, 1, 8, 7, 0],
           original_dim_sizes=[4, 3, 1],
           broadcast_dim_sizes=[4, 3, [2, 0, 1, 3, 8, 2, 3, 4, 1, 8, 7, 0]]),

      # shape: [UNIFORM, RAGGED, BROADCAST(UNIFORM)]
      dict(axis=2,
           row_length=[2, 5, 3],
           original_dim_sizes=[2, [2, 1], 1],
           broadcast_dim_sizes=[2, [2, 1], [2, 5, 3]]),

      # shape: [UNIFORM, RAGGED, UNIFORM, UNIFORM, BROADCAST(UNIFORM), UNIFORM]
      dict(axis=4,
           row_length=list(range(18)),
           original_dim_sizes=[2, [2, 1], 3, 2, 1, 8],
           broadcast_dim_sizes=[2, [2, 1], 3, 2, list(range(18)), 8]),

      #=========================================================================
      # dimension[axis] is uniform partitioned; and row_lengths is a scalar
      #=========================================================================
      # shape: [BROADCAST(UNIFORM), RAGGED]
      dict(axis=0,
           row_length=3,
           original_dim_sizes=[1, [5]],
           broadcast_dim_sizes=[3, [5, 5, 5]]),

      # shape: [BROADCAST(UNIFORM), UNIFORM, RAGGED]
      dict(axis=0,
           row_length=2,
           original_dim_sizes=[1, 3, [3, 0, 2]],
           broadcast_dim_sizes=[2, 3, [3, 0, 2, 3, 0, 2]]),

      # shape: [BROADCAST(UNIFORM), RAGGED, RAGGED, UNIFORM, UNIFORM, UNIFORM]
      dict(axis=0,
           row_length=3,
           original_dim_sizes=[1, [3], [3, 5, 2], 9, 4, 5],
           broadcast_dim_sizes=[3, [3, 3, 3], [3, 5, 2, 3, 5, 2, 3, 5, 2],
                                9, 4, 5]),

      # shape: [BROADCAST(UNIFORM), UNIFORM, RAGGED, RAGGED, UNIFORM]
      dict(axis=0,
           row_length=2,
           original_dim_sizes=[1, 2, [2, 1], [3, 5, 2], 2],
           broadcast_dim_sizes=[2, 2, [2, 1, 2, 1], [3, 5, 2, 3, 5, 2], 2]),

      # shape: [UNIFORM, BROADCAST(UNIFORM), RAGGED, UNIFORM]
      dict(axis=1,
           row_length=2,
           original_dim_sizes=[3, 1, [4, 0, 2], 5],
           broadcast_dim_sizes=[3, 2, [4, 4, 0, 0, 2, 2], 5]),

      # shape: [UNIFORM, BROADCAST(UNIFORM), RAGGED]
      dict(axis=1,
           row_length=1,
           original_dim_sizes=[2, 3, (1, 2, 3, 4, 5, 6)],
           broadcast_dim_sizes=[2, 3, (1, 2, 3, 4, 5, 6)]),

      #=========================================================================
      # dimension[axis] is uniform partitioned; and row_lengths is a vector
      #=========================================================================
      # shape: [UNIFORM, BROADCAST(UNIFORM), RAGGED, UNIFORM]
      dict(axis=1,
           row_length=[4, 1, 2],
           original_dim_sizes=[
               3,                      # axis=0
               1,                      # axis=1 (broadcast)
               [3, 1, 2],              # axis=2
               5],                     # axis=3
           broadcast_dim_sizes=[
               3,                      # axis=0
               [4, 1, 2],              # axis=1 (broadcast)
               [3, 3, 3, 3, 1, 2, 2],  # axis=2
               5]),                    # axis=3

      # shape: [UNIFORM, BROADCAST(UNIFORM), RAGGED, RAGGED]
      dict(axis=1,
           row_length=[2, 0, 3],
           original_dim_sizes=[
               3,                                       # axis=0
               1,                                       # axis=1 (broadcast)
               [3, 1, 2],                               # axis=2
               [3, 1, 4, 1, 5, 9]],                     # axis=3
           broadcast_dim_sizes=[
               3,                                       # axis=0
               [2, 0, 3],                               # axis=1 (broadcast)
               [3, 3, 2, 2, 2],                         # axis=2
               [3, 1, 4, 3, 1, 4, 5, 9, 5, 9, 5, 9]]),  # axis=3

      # shape: [UNIFORM, RAGGED, BROADCAST(UNIFORM), RAGGED, RAGGED, UNIFORM]
      dict(axis=2,
           row_length=[4, 1, 2],
           original_dim_sizes=[
               3,                                       # axis=0
               [2, 0, 1],                               # axis=1
               1,                                       # axis=2 (broadcast)
               [3, 2, 1],                               # axis=3
               [1, 0, 1, 0, 2, 3],                      # axis=4
               5],                                      # axis=5
           broadcast_dim_sizes=[
               3,                                       # axis=0
               [2, 0, 1],                               # axis=1
               [4, 1, 2],                               # axis=2 (broadcast)
               [3, 3, 3, 3, 2, 1, 1],                   # axis=3
               [1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0,  # axis=4
                2, 3, 3],
               5]),                                     # axis=5

      dict(axis=0,
           row_length=2,
           original_dim_sizes=[1, 1, 2, (2, 1)],
           broadcast_dim_sizes=[2, 1, 2, (2, 1, 2, 1)]),
      dict(axis=1,
           row_length=(2, 1),
           original_dim_sizes=[2, 1, 2, (2, 1, 2, 1)],
           broadcast_dim_sizes=[2, (2, 1), 2, (2, 1, 2, 1, 2, 1)]),
      dict(axis=2,
           row_length=2,
           original_dim_sizes=[2, (2, 1), 2, (2, 1, 2, 1, 2, 1)],
           broadcast_dim_sizes=[2, (2, 1), 2, (2, 1, 2, 1, 2, 1)]),
      dict(axis=3,
           row_length=(2, 1, 2, 1, 2, 1),
           original_dim_sizes=[2, (2, 1), 2, 1],
           broadcast_dim_sizes=[2, (2, 1), 2, (2, 1, 2, 1, 2, 1)]),
  ])  # pyformat: disable
  def testBroadcastDimension(self, axis, row_length, original_dim_sizes,
                             broadcast_dim_sizes):
    """Tests for the broadcast_dimension method.

    Verifies that:

    * `original.broadcast_dimension(axis, row_length) == broadcast`
    * `broadcast.broadcast_dimension(axis, row_length) == broadcast`
    * `broadcast.broadcast_dimension(axis, 1) == broadcast`

    Args:
      axis: The axis to broadcast.
      row_length: The slice lengths to broadcast to.
      original_dim_sizes: The dimension sizes before broadcasting.
        `original_dim_sizes[axis]` should be equal to `1` or `row_length`.
      broadcast_dim_sizes: The dimension sizes after broadcasting.
    """
    original_shape = ragged.RaggedTensorDynamicShape.from_dim_sizes(
        original_dim_sizes)
    broadcast_shape = ragged.RaggedTensorDynamicShape.from_dim_sizes(
        broadcast_dim_sizes)
    self.assertEqual(original_shape.rank, broadcast_shape.rank)
    with self.cached_session():
      # shape[axis].value == 1 and row_length > 1:
      bcast1 = original_shape.broadcast_dimension(axis, row_length)
      # shape[axis].value > 1 and row_length == shape[axis].value:
      bcast2 = broadcast_shape.broadcast_dimension(axis, row_length)
      # shape[axis].value > 1 and row_length == 1:
      bcast3 = broadcast_shape.broadcast_dimension(axis, 1)

      self.assertShapeEq(bcast1, broadcast_shape)
      self.assertShapeEq(bcast2, broadcast_shape)
      self.assertShapeEq(bcast3, broadcast_shape)

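  # For illustration, a minimal sketch of one case from the parameters above:
  # dimension 2 has size 1 and is broadcast to a ragged dimension, with one
  # row length supplied per existing slice.
  #
  #   original = ragged.RaggedTensorDynamicShape.from_dim_sizes([2, [2, 1], 1])
  #   original.broadcast_dimension(2, [2, 5, 3])
  #   # equivalent to from_dim_sizes([2, [2, 1], [2, 5, 3]])
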
  @parameterized.parameters(
      [
          # Broadcast scalar
          dict(x_dims=[], y_dims=[], expected_dims=[]),
          dict(x_dims=[], y_dims=[2], expected_dims=[2]),
          dict(x_dims=[], y_dims=[2, 3], expected_dims=[2, 3]),
          dict(
              x_dims=[],
              y_dims=[2, (2, 3), (5, 7, 2, 0, 9)],
              expected_dims=[2, (2, 3), (5, 7, 2, 0, 9)]),
          # Broadcast vector
          dict(x_dims=[3], y_dims=[4, 2, 3], expected_dims=[4, 2, 3]),
          dict(x_dims=[1], y_dims=[4, 2, 3], expected_dims=[4, 2, 3]),
          dict(x_dims=[3], y_dims=[4, 2, 1], expected_dims=[4, 2, 3]),
          dict(
              x_dims=[3],
              y_dims=[3, (2, 3, 1), 1],
              expected_dims=[3, (2, 3, 1), 3]),
          dict(x_dims=[1], y_dims=[3, (2, 1, 3)], expected_dims=[3, (2, 1, 3)]),
          dict(
              x_dims=[1],
              y_dims=[3, (2, 1, 3), 8],
              expected_dims=[3, (2, 1, 3), 8]),
          dict(
              x_dims=[1],
              y_dims=[2, (2, 3), (5, 7, 2, 0, 9)],
              expected_dims=[2, (2, 3), (5, 7, 2, 0, 9)]),
          # Mixed broadcasting
          dict(
              x_dims=[
                  1,          # axis=0
                  3,          # axis=1
                  (3, 0, 2),  # axis=2
                  1,          # axis=3
                  2,          # axis=4
              ],
              y_dims=[
                  2,       # axis=0
                  1,       # axis=1
                  1,       # axis=2
                  (7, 2),  # axis=3
                  1,       # axis=4
              ],
              expected_dims=[
                  2,                               # axis=0
                  3,                               # axis=1
                  (3, 0, 2, 3, 0, 2),              # axis=2
                  (7, 7, 7, 7, 7, 2, 2, 2, 2, 2),  # axis=3
                  2,                               # axis=4
              ]),
          dict(
              x_dims=[2, (2, 1), 2, 1],
              y_dims=[1, 1, 2, (2, 1)],
              expected_dims=[2, (2, 1), 2, (2, 1, 2, 1, 2, 1)]),
      ])
  def testBroadcastDynamicShape(self, x_dims, y_dims, expected_dims):
    x_shape = ragged.RaggedTensorDynamicShape.from_dim_sizes(x_dims)
    y_shape = ragged.RaggedTensorDynamicShape.from_dim_sizes(y_dims)
    expected = ragged.RaggedTensorDynamicShape.from_dim_sizes(expected_dims)
    result1 = ragged.broadcast_dynamic_shape(x_shape, y_shape)
    result2 = ragged.broadcast_dynamic_shape(y_shape, x_shape)
    with self.cached_session():
      self.assertShapeEq(expected, result1)
      self.assertShapeEq(expected, result2)

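  # For illustration, a minimal sketch (values taken from the last parameter
  # set above): broadcast_dynamic_shape is symmetric, and each argument may
  # supply the ragged dimensions that the other lacks.
  #
  #   x = ragged.RaggedTensorDynamicShape.from_dim_sizes([2, (2, 1), 2, 1])
  #   y = ragged.RaggedTensorDynamicShape.from_dim_sizes([1, 1, 2, (2, 1)])
  #   ragged.broadcast_dynamic_shape(x, y)
  #   # equivalent to from_dim_sizes([2, (2, 1), 2, (2, 1, 2, 1, 2, 1)])
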
  def testRepr(self):
    shape = ragged.RaggedTensorDynamicShape.from_dim_sizes([2, (2, 1), 2, 1])
    self.assertRegexpMatches(
        repr(shape),
        r'RaggedTensorDynamicShape\('
        r'partitioned_dim_sizes=\(<[^>]+>, <[^>]+>\), '
        r'inner_dim_sizes=<[^>]+>\)')

  @parameterized.parameters([
      dict(
          x=[[10], [20], [30]],  # shape=[3, 1]
          dim_sizes=[3, 2],
          expected=[[10, 10], [20, 20], [30, 30]]),
      dict(
          x=[[10], [20], [30]],  # shape=[3, 1]
          dim_sizes=[3, [3, 0, 2]],
          expected=ragged.constant_value([[10, 10, 10], [], [30, 30]],
                                         dtype=np.int32)),
      dict(
          x=[[[1, 2, 3]], [[4, 5, 6]]],  # shape = [2, 1, 3]
          dim_sizes=[2, [2, 3], 3],
          expected=ragged.constant_value(
              [[[1, 2, 3], [1, 2, 3]], [[4, 5, 6], [4, 5, 6], [4, 5, 6]]],
              dtype=np.int32,
              ragged_rank=1)),
      dict(
          x=[[[1]], [[2]]],  # shape = [2, 1, 1]
          dim_sizes=[2, [2, 3], [0, 2, 1, 2, 0]],
          expected=ragged.constant_value([[[], [1, 1]], [[2], [2, 2], []]],
                                         dtype=np.int32,
                                         ragged_rank=2)),
      dict(
          x=10,
          dim_sizes=[3, [3, 0, 2]],
          expected=ragged.constant_value([[10, 10, 10], [], [10, 10]])),
  ])
  def testRaggedBroadcastTo(self, x, dim_sizes, expected):
    shape = ragged.RaggedTensorDynamicShape.from_dim_sizes(dim_sizes)
    result = ragged.broadcast_to(x, shape)
    with self.cached_session():
      self.assertEqual(
          getattr(result, 'ragged_rank', 0),
          getattr(expected, 'ragged_rank', 0))
      if hasattr(expected, 'tolist'):
        expected = expected.tolist()
      self.assertEqual(result.eval().tolist(), expected)

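  # For illustration, a minimal sketch (values taken from the parameters
  # above): broadcasting a dense [3, 1] input to a ragged target shape tiles
  # each row out to that row's length.
  #
  #   shape = ragged.RaggedTensorDynamicShape.from_dim_sizes([3, [3, 0, 2]])
  #   ragged.broadcast_to([[10], [20], [30]], shape)
  #   # -> [[10, 10, 10], [], [30, 30]]
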
  @parameterized.parameters([
      dict(
          doc='x.shape=[3, (D1)]; y.shape=[3, 1]; bcast.shape=[3, (D1)]',
          x=ragged.constant_value([[1, 2, 3], [], [4, 5]], dtype=np.int32),
          y=[[10], [20], [30]],
          expected=ragged.constant_value([[11, 12, 13], [], [34, 35]])),
      dict(
          doc='x.shape=[3, (D1)]; y.shape=[]; bcast.shape=[3, (D1)]',
          x=ragged.constant_value([[1, 2, 3], [], [4, 5]], dtype=np.int32),
          y=10,
          expected=ragged.constant_value([[11, 12, 13], [], [14, 15]])),
      dict(
          doc='x.shape=[1, (D1)]; y.shape=[3, 1]; bcast.shape=[3, (D1)]',
          x=ragged.constant_value([[1, 2, 3]], dtype=np.int32),
          y=[[10], [20], [30]],
          expected=ragged.constant_value(
              [[11, 12, 13], [21, 22, 23], [31, 32, 33]], dtype=np.int32)),
      dict(
          doc=('x.shape=[2, (D1), 1]; y.shape=[1, (D2)]; '
               'bcast.shape=[2, (D1), (D2)]'),
          x=ragged.constant_value([[[1], [2], [3]], [[4]]], ragged_rank=1),
          y=ragged.constant_value([[10, 20, 30]]),
          expected=ragged.constant_value([[[11, 21, 31], [12, 22, 32],
                                           [13, 23, 33]], [[14, 24, 34]]])),
      dict(
          doc=('x.shape=[2, (D1), 1]; y.shape=[1, 1, 4]; '
               'bcast.shape=[2, (D1), 4]'),
          x=ragged.constant_value([[[10], [20]], [[30]]], ragged_rank=1),
          y=[[[1, 2, 3, 4]]],
          expected=ragged.constant_value(
              [[[11, 12, 13, 14], [21, 22, 23, 24]], [[31, 32, 33, 34]]],
              ragged_rank=1)),
      dict(
          doc=('x.shape=[2, (D1), 2, 1]; y.shape=[2, (D2)]; '
               'bcast.shape=[2, (D1), (2), (D2)]'),
          x=ragged.constant_value([[[[1], [2]], [[3], [4]]],
                                   [[[5], [6]]]],
                                  ragged_rank=1),
          y=ragged.constant_value([[10, 20], [30]]),
          expected=ragged.constant_value(
              [[[[11, 21], [32]], [[13, 23], [34]]],
               [[[15, 25], [36]]]])),
  ])
  def testRaggedAddWithBroadcasting(self, x, y, expected, doc):
    expected_rrank = getattr(expected, 'ragged_rank', 0)
    x = ragged.convert_to_tensor_or_ragged_tensor(x, dtype=dtypes.int32)
    y = ragged.convert_to_tensor_or_ragged_tensor(y, dtype=dtypes.int32)
    result = x + y
    result_rrank = getattr(result, 'ragged_rank', 0)
    self.assertEqual(expected_rrank, result_rrank)
    if hasattr(expected, 'tolist'):
      expected = expected.tolist()
    with self.cached_session():
      self.assertEqual(result.eval().tolist(), expected)

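  # For illustration, a minimal sketch (values taken from the first parameter
  # set above): with broadcasting support, the overloaded elementwise
  # operators align a dense [3, 1] operand against a ragged [3, (D1)] operand
  # automatically.
  #
  #   x = ragged.convert_to_tensor_or_ragged_tensor(
  #       ragged.constant_value([[1, 2, 3], [], [4, 5]]))
  #   y = ragged.convert_to_tensor_or_ragged_tensor([[10], [20], [30]])
  #   x + y  # -> [[11, 12, 13], [], [34, 35]]
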


if __name__ == '__main__':
  googletest.main()

@@ -23,6 +23,7 @@ from tensorflow.python.framework import errors
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import ragged
 from tensorflow.python.platform import googletest
@@ -178,7 +179,7 @@ class RaggedTensorToSparseOpTest(test_util.TensorFlowTestCase):
                           ragged_rank=2)
     rt2 = ragged.constant([[[[9.0, 8.0], [7.0, 6.0]], [[5.0, 4.0]]]],
                           ragged_rank=2)
-    rt = rt1 + rt2 * 2.0
+    rt = ragged.map_inner_values(math_ops.add, rt1, rt2 * 2.0)
     st = ragged.to_sparse(rt)

     g1, g2 = gradients_impl.gradients(st.values, [rt1.inner_values,

@@ -25,6 +25,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import gen_ragged_math_ops
 from tensorflow.python.ops import math_ops

@@ -229,3 +230,51 @@ def _with_nonzero_rank(data):
     return array_ops.reshape(
         data,
         array_ops.concat([[1], data_shape], axis=0)[-data_ndims:])
+
+
+def lengths_to_splits(lengths):
+  """Returns splits corresponding to the given lengths."""
+  return array_ops.concat([[0], math_ops.cumsum(lengths)], axis=-1)
+
+
+def repeat_ranges(params, splits, repeats):
+  """Repeats each range of `params` (as specified by `splits`) `repeats` times.
+
+  Let the `i`th range of `params` be defined as
+  `params[splits[i]:splits[i + 1]]`.  Then this function returns a tensor
+  containing range 0 repeated `repeats[0]` times, followed by range 1 repeated
+  `repeats[1]` times, ..., followed by the last range repeated `repeats[-1]`
+  times.
+
+  Args:
+    params: The `Tensor` whose values should be repeated.
+    splits: A splits tensor indicating the ranges of `params` that should be
+      repeated.
+    repeats: The number of times each range should be repeated.  Supports
+      broadcasting from a scalar value.
+
+  Returns:
+    A `Tensor` with the same rank and type as `params`.
+
+  #### Example:
+    ```python
+    >>> repeat_ranges(['a', 'b', 'c'], [0, 2, 3], 3)
+    ['a', 'b', 'a', 'b', 'a', 'b', 'c', 'c', 'c']
+    ```
+  """
+  # Divide `splits` into starts and limits, and repeat them `repeats` times.
+  if repeats.shape.ndims != 0:
+    repeated_starts = repeat(splits[:-1], repeats, axis=0)
+    repeated_limits = repeat(splits[1:], repeats, axis=0)
+  else:
+    # Optimization: we can just call repeat once, and then slice the result.
+    repeated_splits = repeat(splits, repeats, axis=0)
+    n_splits = array_ops.shape(repeated_splits, out_type=dtypes.int64)[0]
+    repeated_starts = repeated_splits[:n_splits - repeats]
+    repeated_limits = repeated_splits[repeats:]
+
+  # Get indices for each range from starts to limits, and use those to gather
+  # the values in the desired repetition pattern.
+  one = array_ops.ones((), repeated_starts.dtype)
+  offsets = gen_ragged_math_ops.ragged_range(
+      repeated_starts, repeated_limits, one)
+  return array_ops.gather(params, offsets.rt_dense_values)
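
A short worked sketch of the scalar-`repeats` fast path in `repeat_ranges` above, traced with plain Python lists so the slicing is visible (values taken from the function's docstring example; `repeat` and `ragged_range` are assumed to behave as described in this commit):

splits, repeats = [0, 2, 3], 3
repeated = [s for s in splits for _ in range(repeats)]  # [0, 0, 0, 2, 2, 2, 3, 3, 3]
starts = repeated[:len(repeated) - repeats]             # [0, 0, 0, 2, 2, 2]
limits = repeated[repeats:]                             # [2, 2, 2, 3, 3, 3]
offsets = [i for s, l in zip(starts, limits) for i in range(s, l)]
print([['a', 'b', 'c'][i] for i in offsets])
# prints ['a', 'b', 'a', 'b', 'a', 'b', 'c', 'c', 'c'], matching the docstring.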