Bug fix for RaggedTensor.from_tensor(t) when t.size()==0 and ragged_rank>1.
PiperOrigin-RevId: 303174098 Change-Id: I2ee2baf4d04a4adb9896486ec341fb35d0cdc3b4
parent 24f11dc4fc
commit e8590130e3
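A minimal sketch of the case the commit message describes (assuming TensorFlow 2.x eager mode with this change applied; the shape [1, 0, 0] is illustrative only):

import tensorflow as tf

dt = tf.zeros([1, 0, 0])  # zero elements, and the trailing dims contain a 0
rt = tf.RaggedTensor.from_tensor(dt, ragged_rank=2)

print(rt.ragged_rank)  # 2
print(rt.to_list())    # [[]]

Before this change, the conversion went through a reshape with a -1 dimension, which cannot be inferred when the tensor has zero elements and the remaining dimensions also multiply to zero.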
@@ -29,8 +29,8 @@ from tensorflow.python.platform import googletest
 
 
 @test_util.run_all_in_graph_and_eager_modes
-class RaggedTensorToSparseOpTest(test_util.TensorFlowTestCase,
-                                 parameterized.TestCase):
+class RaggedTensorFromTensorOpTest(test_util.TensorFlowTestCase,
+                                   parameterized.TestCase):
 
   def testDocStringExamples(self):
     # The examples from RaggedTensor.from_tensor.__doc__.
@@ -366,6 +366,8 @@ class RaggedTensorToSparseOpTest(test_util.TensorFlowTestCase,
     if expected_shape is not None:
       self.assertEqual(rt.shape.as_list(), expected_shape)
     self.assertAllEqual(rt, expected)
+    self.assertAllEqual(rt, RaggedTensor.from_nested_row_splits(
+        rt.flat_values, rt.nested_row_splits, validate=True))
 
   def testHighDimensions(self):
     # Use distinct prime numbers for all dimension shapes in this test, so
@@ -380,6 +382,8 @@ class RaggedTensorToSparseOpTest(test_util.TensorFlowTestCase,
         dt.shape.is_compatible_with(rt.shape),
         '%s is incompatible with %s' % (dt.shape, rt.shape))
     self.assertAllEqual(rt, self.evaluate(dt).tolist())
+    self.assertAllEqual(rt, RaggedTensor.from_nested_row_splits(
+        rt.flat_values, rt.nested_row_splits, validate=True))
 
   @parameterized.parameters(
       # With no padding or lengths
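The assertion added in the two hunks above rebuilds the ragged tensor from its components; a hedged sketch of the same round-trip outside the test harness, using only public RaggedTensor APIs:

import tensorflow as tf

rt = tf.RaggedTensor.from_tensor(tf.zeros([2, 0, 3]), ragged_rank=2)

# Reassemble from flat values and nested row splits; validate=True adds
# checks that each row_splits vector is well-formed, which is the part
# most likely to break for empty inputs.
rebuilt = tf.RaggedTensor.from_nested_row_splits(
    rt.flat_values, rt.nested_row_splits, validate=True)

assert rt.to_list() == rebuilt.to_list() == [[], []]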
@@ -399,6 +403,10 @@ class RaggedTensorToSparseOpTest(test_util.TensorFlowTestCase,
           'dt_shape': [0, 2, 3],
           'expected': []
       },
+      {
+          'dt_shape': [1, 0, 0],
+          'expected': [[]]
+      },
       {
           'dt_shape': [2, 0, 3],
           'expected': [[], []]
@@ -485,11 +493,14 @@ class RaggedTensorToSparseOpTest(test_util.TensorFlowTestCase,
   )
   def testEmpty(self, dt_shape, expected, lengths=None, padding=None):
     dt = array_ops.zeros(dt_shape)
-    rt = RaggedTensor.from_tensor(dt, lengths, padding)
-    self.assertEqual(type(rt), RaggedTensor)
-    self.assertEqual(rt.ragged_rank, 1)
-    self.assertTrue(dt.shape.is_compatible_with(rt.shape))
-    self.assertAllEqual(rt, expected)
+    for ragged_rank in range(1, len(dt_shape) - 1):
+      rt = RaggedTensor.from_tensor(dt, lengths, padding, ragged_rank)
+      self.assertEqual(type(rt), RaggedTensor)
+      self.assertEqual(rt.ragged_rank, ragged_rank)
+      self.assertTrue(dt.shape.is_compatible_with(rt.shape))
+      self.assertAllEqual(rt, expected)
+      self.assertAllEqual(rt, RaggedTensor.from_nested_row_splits(
+          rt.flat_values, rt.nested_row_splits, validate=True))
 
   @parameterized.parameters(
       {
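The newly added [1, 0, 0] case is the shape that exposes the bug: once the leading dimensions are flattened away, the remaining dimensions multiply to zero, so a -1 in the reshape target cannot be inferred. A small NumPy illustration of that shape arithmetic (NumPy stands in for the TF ops; the values are illustrative):

import numpy as np

np.zeros([0, 2, 3]).reshape(-1, 3)      # OK: known dims give 3, so -1 resolves to 0
# np.zeros([1, 0, 0]).reshape(-1, 0)    # ValueError: -1 is ambiguous here

# The fix computes the collapsed leading dimension explicitly instead:
t = np.zeros([1, 0, 0])
ragged_rank = 2
dim_size = np.cumprod(t.shape)                                   # [1, 0, 0]
new_shape = [int(dim_size[ragged_rank - 1])] + list(t.shape[ragged_rank:])
t.reshape(new_shape)                                             # new_shape == [0, 0], works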
@@ -1481,15 +1481,15 @@ class RaggedTensor(composite_tensor.CompositeTensor):
       if ragged_rank > 1:
         if tensor.shape.is_fully_defined():
           input_shape = tensor.shape.as_list()
-          new_shape = [-1] + input_shape[ragged_rank:]
+          # The total number of elements in each dimension. E.g., if
+          # input_shape=[3, 4, 5, 6], then dim[2] has 3*4*5 elements in total.
+          dim_size = np.cumprod(input_shape)
+          new_shape = [dim_size[ragged_rank - 1]] + input_shape[ragged_rank:]
         else:
-          neg_one = constant_op.constant([-1], row_splits_dtype)
-          new_shape = array_ops.concat([neg_one, input_shape[ragged_rank:]],
-                                       axis=0)
+          dim_size = math_ops.cumprod(input_shape)
+          new_shape = array_ops.concat([[dim_size[ragged_rank - 1]],
+                                        input_shape[ragged_rank:]],
+                                       axis=0)
         flattened = array_ops.reshape(tensor, new_shape)
         result = cls.from_tensor(
             flattened, lengths, padding, row_splits_dtype=row_splits_dtype)
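The comment added in this hunk is the key bookkeeping; a worked example of the flattening it describes, for a fully defined non-empty shape (numbers chosen for illustration only):

import numpy as np

input_shape = [3, 4, 5, 6]
ragged_rank = 2

# Running product of the dimensions: [3, 12, 60, 360].  dim_size[i] is the
# total number of entries in dimensions 0..i combined.
dim_size = np.cumprod(input_shape)

# Collapse the first `ragged_rank` dims into one, keep the rest: [12, 5, 6].
# Equivalent to the old [-1, 5, 6] for non-empty tensors, but still well
# defined when the kept dimensions contain a 0.
new_shape = [int(dim_size[ragged_rank - 1])] + input_shape[ragged_rank:]
assert new_shape == [12, 5, 6]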
@@ -1563,7 +1563,8 @@ class RaggedTensor(composite_tensor.CompositeTensor):
       # If neither padding nor lengths were specified, then create a splits
       # vector that contains no default values, and reshape the input tensor
       # to form the values for the RaggedTensor.
-      values_shape = array_ops.concat([[-1], input_shape[2:]], axis=0)
+      values_shape = array_ops.concat([[input_shape[0] * input_shape[1]],
+                                       input_shape[2:]], axis=0)
       values = array_ops.reshape(tensor, values_shape)
       const_nrows = tensor_shape.dimension_at_index(tensor.shape, 0).value
       const_ncols = tensor_shape.dimension_at_index(tensor.shape, 1).value
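The same idea applies in the base case handled by this hunk (one ragged dimension, no lengths or padding): the leading dimension of the values tensor is spelled out as nrows * ncols rather than -1, so empty inputs whose trailing dimensions contain a 0 still reshape cleanly. A sketch of the arithmetic, with a plain Python list standing in for the dynamic input_shape:

import numpy as np

input_shape = [1, 0, 0]                  # shape of the dense input

# Old: [-1] + input_shape[2:] == [-1, 0]  -> reshape cannot infer the -1.
# New: compute the first dimension explicitly.
values_shape = [input_shape[0] * input_shape[1]] + input_shape[2:]
assert values_shape == [0, 0]

values = np.zeros(input_shape).reshape(values_shape)   # works: shape (0, 0)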