Fix bug where the ragged version of tf.where could fail if the condition is a vector, and only one of (x, y) is ragged.

PiperOrigin-RevId: 246023503
This commit is contained in:
Edward Loper 2019-04-30 14:37:05 -07:00 committed by TensorFlower Gardener
parent 1c6d02d8b8
commit 52c7b0d597
2 changed files with 15 additions and 7 deletions

View File

@ -126,10 +126,11 @@ def _elementwise_where(condition, x, y):
elif not condition_is_ragged:
# Concatenate x and y, and then use `gather` to assemble the selected rows.
condition.shape.assert_has_rank(1)
x_nrows = _nrows(x)
x_and_y = ragged_concat_ops.concat([x, y], axis=0)
x_nrows = _nrows(x, out_type=x_and_y.row_splits.dtype)
y_nrows = _nrows(y, out_type=x_and_y.row_splits.dtype)
indices = array_ops.where(condition, math_ops.range(x_nrows),
x_nrows + math_ops.range(_nrows(y)))
x_nrows + math_ops.range(y_nrows))
return ragged_gather_ops.gather(x_and_y, indices)
else:
@ -159,8 +160,8 @@ def _coordinate_where(condition):
axis=1)
def _nrows(rt_input):
def _nrows(rt_input, out_type):
if isinstance(rt_input, ragged_tensor.RaggedTensor):
return rt_input.nrows()
return rt_input.nrows(out_type=out_type)
else:
return array_ops.shape(rt_input)[0]
return array_ops.shape(rt_input, out_type=out_type)[0]

View File

@ -155,12 +155,19 @@ class RaggedWhereOpTest(ragged_test_util.RaggedTensorTestCase,
#=========================================================================
# Elementwise row-selection mode
#=========================================================================
dict( # shape=[D1, D2]
dict( # x.shape=[D1, D2], y.shape=[D1, D2]
condition=[True, False, True],
x=[['A', 'B'], ['C', 'D'], ['E', 'F']],
y=[['a', 'b'], ['c', 'd'], ['e', 'f']],
expected=[[b'A', b'B'], [b'c', b'd'], [b'E', b'F']]),
dict( # shape=[D1, (D2)]
dict( # x.shape=[D1, D2], y.shape=[D1, (D2)]
condition=[True, False, True],
x=[['A', 'B'], ['C', 'D'], ['E', 'F']],
y=ragged_factory_ops.constant_value(
[['a', 'b'], ['c'], ['d', 'e']]),
expected=ragged_factory_ops.constant_value(
[[b'A', b'B'], [b'c'], [b'E', b'F']])),
dict( # x.shape=[D1, (D2)], y.shape=[D1, (D2)]
condition=[True, False, True],
x=ragged_factory_ops.constant_value(
[['A', 'B', 'C'], ['D', 'E'], ['F', 'G']]),