Bug fix for RaggedTensor.from_tensor(t) when t.size()==0 and ragged_rank>1.

PiperOrigin-RevId: 303174098
Change-Id: I2ee2baf4d04a4adb9896486ec341fb35d0cdc3b4
This commit is contained in:
Edward Loper 2020-03-26 12:45:59 -07:00 committed by TensorFlower Gardener
parent 24f11dc4fc
commit e8590130e3
2 changed files with 24 additions and 12 deletions

View File

@ -29,7 +29,7 @@ from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedTensorToSparseOpTest(test_util.TensorFlowTestCase,
class RaggedTensorFromTensorOpTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
def testDocStringExamples(self):
@ -366,6 +366,8 @@ class RaggedTensorToSparseOpTest(test_util.TensorFlowTestCase,
if expected_shape is not None:
self.assertEqual(rt.shape.as_list(), expected_shape)
self.assertAllEqual(rt, expected)
self.assertAllEqual(rt, RaggedTensor.from_nested_row_splits(
rt.flat_values, rt.nested_row_splits, validate=True))
def testHighDimensions(self):
# Use distinct prime numbers for all dimension shapes in this test, so
@ -380,6 +382,8 @@ class RaggedTensorToSparseOpTest(test_util.TensorFlowTestCase,
dt.shape.is_compatible_with(rt.shape),
'%s is incompatible with %s' % (dt.shape, rt.shape))
self.assertAllEqual(rt, self.evaluate(dt).tolist())
self.assertAllEqual(rt, RaggedTensor.from_nested_row_splits(
rt.flat_values, rt.nested_row_splits, validate=True))
@parameterized.parameters(
# With no padding or lengths
@ -399,6 +403,10 @@ class RaggedTensorToSparseOpTest(test_util.TensorFlowTestCase,
'dt_shape': [0, 2, 3],
'expected': []
},
{
'dt_shape': [1, 0, 0],
'expected': [[]]
},
{
'dt_shape': [2, 0, 3],
'expected': [[], []]
@ -485,11 +493,14 @@ class RaggedTensorToSparseOpTest(test_util.TensorFlowTestCase,
)
def testEmpty(self, dt_shape, expected, lengths=None, padding=None):
dt = array_ops.zeros(dt_shape)
rt = RaggedTensor.from_tensor(dt, lengths, padding)
for ragged_rank in range(1, len(dt_shape) - 1):
rt = RaggedTensor.from_tensor(dt, lengths, padding, ragged_rank)
self.assertEqual(type(rt), RaggedTensor)
self.assertEqual(rt.ragged_rank, 1)
self.assertEqual(rt.ragged_rank, ragged_rank)
self.assertTrue(dt.shape.is_compatible_with(rt.shape))
self.assertAllEqual(rt, expected)
self.assertAllEqual(rt, RaggedTensor.from_nested_row_splits(
rt.flat_values, rt.nested_row_splits, validate=True))
@parameterized.parameters(
{

View File

@ -1481,15 +1481,15 @@ class RaggedTensor(composite_tensor.CompositeTensor):
if ragged_rank > 1:
if tensor.shape.is_fully_defined():
input_shape = tensor.shape.as_list()
new_shape = [-1] + input_shape[ragged_rank:]
# The total number of elements in each dimension. E.g., if
# input_shape=[3, 4, 5, 6], then dim[2] has 3*4*5 elements in total.
dim_size = np.cumprod(input_shape)
new_shape = [dim_size[ragged_rank - 1]] + input_shape[ragged_rank:]
else:
neg_one = constant_op.constant([-1], row_splits_dtype)
new_shape = array_ops.concat([neg_one, input_shape[ragged_rank:]],
axis=0)
dim_size = math_ops.cumprod(input_shape)
new_shape = array_ops.concat([[dim_size[ragged_rank]],
input_shape[ragged_rank:]],
axis=0)
flattened = array_ops.reshape(tensor, new_shape)
result = cls.from_tensor(
flattened, lengths, padding, row_splits_dtype=row_splits_dtype)
@ -1563,7 +1563,8 @@ class RaggedTensor(composite_tensor.CompositeTensor):
# If neither padding nor lengths were specified, then create a splits
# vector that contains no default values, and reshape the input tensor
# to form the values for the RaggedTensor.
values_shape = array_ops.concat([[-1], input_shape[2:]], axis=0)
values_shape = array_ops.concat([[input_shape[0] * input_shape[1]],
input_shape[2:]], axis=0)
values = array_ops.reshape(tensor, values_shape)
const_nrows = tensor_shape.dimension_at_index(tensor.shape, 0).value
const_ncols = tensor_shape.dimension_at_index(tensor.shape, 1).value