diff --git a/tensorflow/python/ops/ragged/ragged_where_op.py b/tensorflow/python/ops/ragged/ragged_where_op.py index 28f79ec8752..542f53a176e 100644 --- a/tensorflow/python/ops/ragged/ragged_where_op.py +++ b/tensorflow/python/ops/ragged/ragged_where_op.py @@ -126,10 +126,11 @@ def _elementwise_where(condition, x, y): elif not condition_is_ragged: # Concatenate x and y, and then use `gather` to assemble the selected rows. condition.shape.assert_has_rank(1) - x_nrows = _nrows(x) x_and_y = ragged_concat_ops.concat([x, y], axis=0) + x_nrows = _nrows(x, out_type=x_and_y.row_splits.dtype) + y_nrows = _nrows(y, out_type=x_and_y.row_splits.dtype) indices = array_ops.where(condition, math_ops.range(x_nrows), - x_nrows + math_ops.range(_nrows(y))) + x_nrows + math_ops.range(y_nrows)) return ragged_gather_ops.gather(x_and_y, indices) else: @@ -159,8 +160,8 @@ def _coordinate_where(condition): axis=1) -def _nrows(rt_input): +def _nrows(rt_input, out_type): if isinstance(rt_input, ragged_tensor.RaggedTensor): - return rt_input.nrows() + return rt_input.nrows(out_type=out_type) else: - return array_ops.shape(rt_input)[0] + return array_ops.shape(rt_input, out_type=out_type)[0] diff --git a/tensorflow/python/ops/ragged/ragged_where_op_test.py b/tensorflow/python/ops/ragged/ragged_where_op_test.py index e76a04072a5..d54e2c76ef1 100644 --- a/tensorflow/python/ops/ragged/ragged_where_op_test.py +++ b/tensorflow/python/ops/ragged/ragged_where_op_test.py @@ -155,12 +155,19 @@ class RaggedWhereOpTest(ragged_test_util.RaggedTensorTestCase, #========================================================================= # Elementwise row-selection mode #========================================================================= - dict( # shape=[D1, D2] + dict( # x.shape=[D1, D2], y.shape=[D1, D2] condition=[True, False, True], x=[['A', 'B'], ['C', 'D'], ['E', 'F']], y=[['a', 'b'], ['c', 'd'], ['e', 'f']], expected=[[b'A', b'B'], [b'c', b'd'], [b'E', b'F']]), - dict( # shape=[D1, (D2)] + dict( # x.shape=[D1, D2], y.shape=[D1, (D2)] + condition=[True, False, True], + x=[['A', 'B'], ['C', 'D'], ['E', 'F']], + y=ragged_factory_ops.constant_value( + [['a', 'b'], ['c'], ['d', 'e']]), + expected=ragged_factory_ops.constant_value( + [[b'A', b'B'], [b'c'], [b'E', b'F']])), + dict( # x.shape=[D1, (D2)], y.shape=[D1, (D2)] condition=[True, False, True], x=ragged_factory_ops.constant_value( [['A', 'B', 'C'], ['D', 'E'], ['F', 'G']]),