Fix bug in RaggedTensor broadcasting that impacted broadcasting inner dense dimensions.
PiperOrigin-RevId: 330564050 Change-Id: I2226c49996cd386260c999906ed6fdf5cf8da417
This commit is contained in:
parent
f67204eea0
commit
2db7fda9dd
@ -582,10 +582,12 @@ def _broadcast_to_ragged_shape(rt_input, dst_shape, broadcast_inner_dimensions):
|
||||
rt_input = ragged_array_ops.tile(rt_input, multiples)
|
||||
|
||||
if broadcast_inner_dimensions:
|
||||
new_shape = array_ops.broadcast_dynamic_shape(
|
||||
array_ops.shape(
|
||||
rt_input.flat_values, out_type=dst_shape.dim_size_dtype),
|
||||
array_ops.concat([[1], dst_shape.inner_dim_sizes], axis=0))
|
||||
rt_input = rt_input.with_flat_values(
|
||||
array_ops.reshape(
|
||||
rt_input.flat_values,
|
||||
array_ops.concat([[-1], dst_shape.inner_dim_sizes], axis=0)))
|
||||
array_ops.broadcast_to(rt_input.flat_values, new_shape))
|
||||
|
||||
# Do broadcasting for dimensions that become ragged. We must do these from
|
||||
# outermost to innermost.
|
||||
|
@ -378,37 +378,41 @@ class RaggedTensorShapeTest(test_util.TensorFlowTestCase,
|
||||
r'partitioned_dim_sizes=\(<[^>]+>, <[^>]+>\), '
|
||||
r'inner_dim_sizes=<[^>]+>\)')
|
||||
|
||||
@parameterized.parameters(
|
||||
[
|
||||
dict(
|
||||
x=[[10], [20], [30]], # shape=[3, 1]
|
||||
dim_sizes=[3, 2],
|
||||
expected=[[10, 10], [20, 20], [30, 30]]),
|
||||
dict(
|
||||
x=[[10], [20], [30]], # shape=[3, 1]
|
||||
dim_sizes=[3, [3, 0, 2]],
|
||||
expected=ragged_factory_ops.constant_value(
|
||||
[[10, 10, 10], [], [30, 30]], dtype=np.int32)),
|
||||
dict(
|
||||
x=[[[1, 2, 3]], [[4, 5, 6]]], # shape = [2, 1, 3]
|
||||
dim_sizes=[2, [2, 3], 3],
|
||||
expected=ragged_factory_ops.constant_value(
|
||||
[[[1, 2, 3], [1, 2, 3]], [[4, 5, 6], [4, 5, 6], [4, 5, 6]]],
|
||||
dtype=np.int32,
|
||||
ragged_rank=1)),
|
||||
dict(
|
||||
x=[[[1]], [[2]]], # shape = [2, 1, 1]
|
||||
dim_sizes=[2, [2, 3], [0, 2, 1, 2, 0]],
|
||||
expected=ragged_factory_ops.constant_value(
|
||||
[[[], [1, 1]], [[2], [2, 2], []]],
|
||||
dtype=np.int32,
|
||||
ragged_rank=2)),
|
||||
dict(
|
||||
x=10,
|
||||
dim_sizes=[3, [3, 0, 2]],
|
||||
expected=ragged_factory_ops.constant_value([[10, 10, 10], [],
|
||||
[10, 10]])),
|
||||
])
|
||||
@parameterized.parameters([
|
||||
dict(
|
||||
x=[[10], [20], [30]], # shape=[3, 1]
|
||||
dim_sizes=[3, 2],
|
||||
expected=[[10, 10], [20, 20], [30, 30]]),
|
||||
dict(
|
||||
x=[[10], [20], [30]], # shape=[3, 1]
|
||||
dim_sizes=[3, [3, 0, 2]],
|
||||
expected=ragged_factory_ops.constant_value(
|
||||
[[10, 10, 10], [], [30, 30]], dtype=np.int32)),
|
||||
dict(
|
||||
x=[[[1, 2, 3]], [[4, 5, 6]]], # shape = [2, 1, 3]
|
||||
dim_sizes=[2, [2, 3], 3],
|
||||
expected=ragged_factory_ops.constant_value(
|
||||
[[[1, 2, 3], [1, 2, 3]], [[4, 5, 6], [4, 5, 6], [4, 5, 6]]],
|
||||
dtype=np.int32,
|
||||
ragged_rank=1)),
|
||||
dict(
|
||||
x=[[[1]], [[2]]], # shape = [2, 1, 1]
|
||||
dim_sizes=[2, [2, 3], [0, 2, 1, 2, 0]],
|
||||
expected=ragged_factory_ops.constant_value(
|
||||
[[[], [1, 1]], [[2], [2, 2], []]], dtype=np.int32,
|
||||
ragged_rank=2)),
|
||||
dict(
|
||||
x=10,
|
||||
dim_sizes=[3, [3, 0, 2]],
|
||||
expected=ragged_factory_ops.constant_value([[10, 10, 10], [],
|
||||
[10, 10]])),
|
||||
dict(
|
||||
x=ragged_factory_ops.constant_value([[[1], [2]], [[3]]],
|
||||
ragged_rank=1),
|
||||
dim_sizes=[2, [2, 1], 2],
|
||||
expected=ragged_factory_ops.constant_value(
|
||||
[[[1, 1], [2, 2]], [[3, 3]]], ragged_rank=1)),
|
||||
])
|
||||
def testRaggedBroadcastTo(self, x, dim_sizes, expected):
|
||||
shape = RaggedTensorDynamicShape.from_dim_sizes(dim_sizes)
|
||||
result = ragged_tensor_shape.broadcast_to(x, shape)
|
||||
|
Loading…
Reference in New Issue
Block a user