Fix bug where the ragged version of tf.where could fail if the condition is a vector, and only one of (x, y) is ragged.
PiperOrigin-RevId: 246023503
This commit is contained in:
parent
1c6d02d8b8
commit
52c7b0d597
@ -126,10 +126,11 @@ def _elementwise_where(condition, x, y):
|
||||
elif not condition_is_ragged:
|
||||
# Concatenate x and y, and then use `gather` to assemble the selected rows.
|
||||
condition.shape.assert_has_rank(1)
|
||||
x_nrows = _nrows(x)
|
||||
x_and_y = ragged_concat_ops.concat([x, y], axis=0)
|
||||
x_nrows = _nrows(x, out_type=x_and_y.row_splits.dtype)
|
||||
y_nrows = _nrows(y, out_type=x_and_y.row_splits.dtype)
|
||||
indices = array_ops.where(condition, math_ops.range(x_nrows),
|
||||
x_nrows + math_ops.range(_nrows(y)))
|
||||
x_nrows + math_ops.range(y_nrows))
|
||||
return ragged_gather_ops.gather(x_and_y, indices)
|
||||
|
||||
else:
|
||||
@ -159,8 +160,8 @@ def _coordinate_where(condition):
|
||||
axis=1)
|
||||
|
||||
|
||||
def _nrows(rt_input):
|
||||
def _nrows(rt_input, out_type):
|
||||
if isinstance(rt_input, ragged_tensor.RaggedTensor):
|
||||
return rt_input.nrows()
|
||||
return rt_input.nrows(out_type=out_type)
|
||||
else:
|
||||
return array_ops.shape(rt_input)[0]
|
||||
return array_ops.shape(rt_input, out_type=out_type)[0]
|
||||
|
@ -155,12 +155,19 @@ class RaggedWhereOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
#=========================================================================
|
||||
# Elementwise row-selection mode
|
||||
#=========================================================================
|
||||
dict( # shape=[D1, D2]
|
||||
dict( # x.shape=[D1, D2], y.shape=[D1, D2]
|
||||
condition=[True, False, True],
|
||||
x=[['A', 'B'], ['C', 'D'], ['E', 'F']],
|
||||
y=[['a', 'b'], ['c', 'd'], ['e', 'f']],
|
||||
expected=[[b'A', b'B'], [b'c', b'd'], [b'E', b'F']]),
|
||||
dict( # shape=[D1, (D2)]
|
||||
dict( # x.shape=[D1, D2], y.shape=[D1, (D2)]
|
||||
condition=[True, False, True],
|
||||
x=[['A', 'B'], ['C', 'D'], ['E', 'F']],
|
||||
y=ragged_factory_ops.constant_value(
|
||||
[['a', 'b'], ['c'], ['d', 'e']]),
|
||||
expected=ragged_factory_ops.constant_value(
|
||||
[[b'A', b'B'], [b'c'], [b'E', b'F']])),
|
||||
dict( # x.shape=[D1, (D2)], y.shape=[D1, (D2)]
|
||||
condition=[True, False, True],
|
||||
x=ragged_factory_ops.constant_value(
|
||||
[['A', 'B', 'C'], ['D', 'E'], ['F', 'G']]),
|
||||
|
Loading…
Reference in New Issue
Block a user