In RaggedTensor.to_tensor(), allow shape to be a list whose elements are ints or tensors.
PiperOrigin-RevId: 338261958 Change-Id: I352442c1f5f42e4d66910d04a5ed271ffbff5182
This commit is contained in:
parent
243008368f
commit
112e89e34a
@ -1737,6 +1737,11 @@ class RaggedTensor(composite_tensor.CompositeTensor,
|
||||
if default_value is None:
|
||||
default_value = array_ops.zeros((), self.dtype)
|
||||
|
||||
if (isinstance(shape, (list, tuple)) and
|
||||
any(isinstance(v, ops.Tensor) for v in shape) and
|
||||
all(isinstance(v, (int, ops.Tensor)) for v in shape)):
|
||||
shape = array_ops.stack(shape)
|
||||
|
||||
shape_tensor = _shape_as_tensor(shape, row_partition_tensors[0].dtype)
|
||||
tensor = gen_ragged_conversion_ops.ragged_tensor_to_tensor(
|
||||
shape=shape_tensor,
|
||||
|
@ -720,6 +720,11 @@ class RaggedTensorToTensorOpTest(test_util.TensorFlowTestCase,
|
||||
return array_ops.placeholder_with_default(arg, [None] * arg.shape.rank)
|
||||
raise AssertionError('Unexpected shape_info %r' % shape_info)
|
||||
|
||||
def test_shape_is_list_including_tensor_element(self):
|
||||
rt = ragged_factory_ops.constant([[1, 2, 3], [4], [5, 6]])
|
||||
result = rt.to_tensor(shape=[2, constant_op.constant(2)])
|
||||
self.assertAllEqual(result, [[1, 2], [4, 0]])
|
||||
|
||||
|
||||
class RaggedToDenseBenchmark(googletest.Benchmark):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user