In RaggedTensor.to_tensor(), allow shape to be a list whose elements are ints or tensors.

PiperOrigin-RevId: 338261958
Change-Id: I352442c1f5f42e4d66910d04a5ed271ffbff5182
This commit is contained in:
Edward Loper 2020-10-21 07:50:27 -07:00 committed by TensorFlower Gardener
parent 243008368f
commit 112e89e34a
2 changed files with 10 additions and 0 deletions

View File

@ -1737,6 +1737,11 @@ class RaggedTensor(composite_tensor.CompositeTensor,
if default_value is None:
default_value = array_ops.zeros((), self.dtype)
if (isinstance(shape, (list, tuple)) and
any(isinstance(v, ops.Tensor) for v in shape) and
all(isinstance(v, (int, ops.Tensor)) for v in shape)):
shape = array_ops.stack(shape)
shape_tensor = _shape_as_tensor(shape, row_partition_tensors[0].dtype)
tensor = gen_ragged_conversion_ops.ragged_tensor_to_tensor(
shape=shape_tensor,

View File

@ -720,6 +720,11 @@ class RaggedTensorToTensorOpTest(test_util.TensorFlowTestCase,
return array_ops.placeholder_with_default(arg, [None] * arg.shape.rank)
raise AssertionError('Unexpected shape_info %r' % shape_info)
def test_shape_is_list_including_tensor_element(self):
rt = ragged_factory_ops.constant([[1, 2, 3], [4], [5, 6]])
result = rt.to_tensor(shape=[2, constant_op.constant(2)])
self.assertAllEqual(result, [[1, 2], [4, 0]])
class RaggedToDenseBenchmark(googletest.Benchmark):