Fix support of ragged expand_dims on other value types (experimental api).
PiperOrigin-RevId: 339066276 Change-Id: If2ec63ede25764f0cdd3edabe290ccc66cdab256
This commit is contained in:
parent
6ca5914663
commit
57690e40eb
@ -447,7 +447,10 @@ def expand_dims(input, axis, name=None): # pylint: disable=redefined-builtin
|
||||
return ragged_tensor.RaggedTensor.from_uniform_row_length(
|
||||
input, uniform_row_length=1, nrows=input.nrows(), validate=False)
|
||||
else:
|
||||
return input.with_values(expand_dims(input.values, axis - 1))
|
||||
if ragged_tensor.is_ragged(input.values):
|
||||
return input.with_values(expand_dims(input.values, axis - 1))
|
||||
else:
|
||||
return input.with_values(array_ops.expand_dims(input.values, axis - 1))
|
||||
|
||||
|
||||
#===============================================================================
|
||||
|
@ -271,6 +271,9 @@ class RaggedTensorSupportedValuesTest(test_util.TensorFlowTestCase,
|
||||
'x': ([[-2.0, 3.0], [-3.0]]),
|
||||
'rate': 0.5,
|
||||
'seed': 1},
|
||||
{'op': array_ops.expand_dims_v2,
|
||||
'x': ([[-2.0, 3.0], [-3.0]]),
|
||||
'axis': -1},
|
||||
]) # pyformat: disable
|
||||
def testUnaryElementwiseOp(self,
|
||||
x,
|
||||
|
Loading…
x
Reference in New Issue
Block a user