Fix support of ragged expand_dims on other value types (experimental api).

PiperOrigin-RevId: 339066276
Change-Id: If2ec63ede25764f0cdd3edabe290ccc66cdab256
This commit is contained in:
Chenliang Xu 2020-10-26 10:26:52 -07:00 committed by TensorFlower Gardener
parent 6ca5914663
commit 57690e40eb
2 changed files with 7 additions and 1 deletions

View File

@ -447,7 +447,10 @@ def expand_dims(input, axis, name=None): # pylint: disable=redefined-builtin
return ragged_tensor.RaggedTensor.from_uniform_row_length(
input, uniform_row_length=1, nrows=input.nrows(), validate=False)
else:
return input.with_values(expand_dims(input.values, axis - 1))
if ragged_tensor.is_ragged(input.values):
return input.with_values(expand_dims(input.values, axis - 1))
else:
return input.with_values(array_ops.expand_dims(input.values, axis - 1))
#===============================================================================

View File

@ -271,6 +271,9 @@ class RaggedTensorSupportedValuesTest(test_util.TensorFlowTestCase,
'x': ([[-2.0, 3.0], [-3.0]]),
'rate': 0.5,
'seed': 1},
{'op': array_ops.expand_dims_v2,
'x': ([[-2.0, 3.0], [-3.0]]),
'axis': -1},
]) # pyformat: disable
def testUnaryElementwiseOp(self,
x,