Bug fix for RaggedTensor.from_tensor(t) when t.size()==0 and ragged_rank>1.
PiperOrigin-RevId: 303174098 Change-Id: I2ee2baf4d04a4adb9896486ec341fb35d0cdc3b4
This commit is contained in:
parent
24f11dc4fc
commit
e8590130e3
@ -29,8 +29,8 @@ from tensorflow.python.platform import googletest
|
|||||||
|
|
||||||
|
|
||||||
@test_util.run_all_in_graph_and_eager_modes
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
class RaggedTensorToSparseOpTest(test_util.TensorFlowTestCase,
|
class RaggedTensorFromTensorOpTest(test_util.TensorFlowTestCase,
|
||||||
parameterized.TestCase):
|
parameterized.TestCase):
|
||||||
|
|
||||||
def testDocStringExamples(self):
|
def testDocStringExamples(self):
|
||||||
# The examples from RaggedTensor.from_tensor.__doc__.
|
# The examples from RaggedTensor.from_tensor.__doc__.
|
||||||
@ -366,6 +366,8 @@ class RaggedTensorToSparseOpTest(test_util.TensorFlowTestCase,
|
|||||||
if expected_shape is not None:
|
if expected_shape is not None:
|
||||||
self.assertEqual(rt.shape.as_list(), expected_shape)
|
self.assertEqual(rt.shape.as_list(), expected_shape)
|
||||||
self.assertAllEqual(rt, expected)
|
self.assertAllEqual(rt, expected)
|
||||||
|
self.assertAllEqual(rt, RaggedTensor.from_nested_row_splits(
|
||||||
|
rt.flat_values, rt.nested_row_splits, validate=True))
|
||||||
|
|
||||||
def testHighDimensions(self):
|
def testHighDimensions(self):
|
||||||
# Use distinct prime numbers for all dimension shapes in this test, so
|
# Use distinct prime numbers for all dimension shapes in this test, so
|
||||||
@ -380,6 +382,8 @@ class RaggedTensorToSparseOpTest(test_util.TensorFlowTestCase,
|
|||||||
dt.shape.is_compatible_with(rt.shape),
|
dt.shape.is_compatible_with(rt.shape),
|
||||||
'%s is incompatible with %s' % (dt.shape, rt.shape))
|
'%s is incompatible with %s' % (dt.shape, rt.shape))
|
||||||
self.assertAllEqual(rt, self.evaluate(dt).tolist())
|
self.assertAllEqual(rt, self.evaluate(dt).tolist())
|
||||||
|
self.assertAllEqual(rt, RaggedTensor.from_nested_row_splits(
|
||||||
|
rt.flat_values, rt.nested_row_splits, validate=True))
|
||||||
|
|
||||||
@parameterized.parameters(
|
@parameterized.parameters(
|
||||||
# With no padding or lengths
|
# With no padding or lengths
|
||||||
@ -399,6 +403,10 @@ class RaggedTensorToSparseOpTest(test_util.TensorFlowTestCase,
|
|||||||
'dt_shape': [0, 2, 3],
|
'dt_shape': [0, 2, 3],
|
||||||
'expected': []
|
'expected': []
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
'dt_shape': [1, 0, 0],
|
||||||
|
'expected': [[]]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
'dt_shape': [2, 0, 3],
|
'dt_shape': [2, 0, 3],
|
||||||
'expected': [[], []]
|
'expected': [[], []]
|
||||||
@ -485,11 +493,14 @@ class RaggedTensorToSparseOpTest(test_util.TensorFlowTestCase,
|
|||||||
)
|
)
|
||||||
def testEmpty(self, dt_shape, expected, lengths=None, padding=None):
|
def testEmpty(self, dt_shape, expected, lengths=None, padding=None):
|
||||||
dt = array_ops.zeros(dt_shape)
|
dt = array_ops.zeros(dt_shape)
|
||||||
rt = RaggedTensor.from_tensor(dt, lengths, padding)
|
for ragged_rank in range(1, len(dt_shape) - 1):
|
||||||
self.assertEqual(type(rt), RaggedTensor)
|
rt = RaggedTensor.from_tensor(dt, lengths, padding, ragged_rank)
|
||||||
self.assertEqual(rt.ragged_rank, 1)
|
self.assertEqual(type(rt), RaggedTensor)
|
||||||
self.assertTrue(dt.shape.is_compatible_with(rt.shape))
|
self.assertEqual(rt.ragged_rank, ragged_rank)
|
||||||
self.assertAllEqual(rt, expected)
|
self.assertTrue(dt.shape.is_compatible_with(rt.shape))
|
||||||
|
self.assertAllEqual(rt, expected)
|
||||||
|
self.assertAllEqual(rt, RaggedTensor.from_nested_row_splits(
|
||||||
|
rt.flat_values, rt.nested_row_splits, validate=True))
|
||||||
|
|
||||||
@parameterized.parameters(
|
@parameterized.parameters(
|
||||||
{
|
{
|
||||||
|
@ -1481,15 +1481,15 @@ class RaggedTensor(composite_tensor.CompositeTensor):
|
|||||||
if ragged_rank > 1:
|
if ragged_rank > 1:
|
||||||
if tensor.shape.is_fully_defined():
|
if tensor.shape.is_fully_defined():
|
||||||
input_shape = tensor.shape.as_list()
|
input_shape = tensor.shape.as_list()
|
||||||
new_shape = [-1] + input_shape[ragged_rank:]
|
|
||||||
# The total number of elements in each dimension. E.g., if
|
# The total number of elements in each dimension. E.g., if
|
||||||
# input_shape=[3, 4, 5, 6], then dim[2] has 3*4*5 elements in total.
|
# input_shape=[3, 4, 5, 6], then dim[2] has 3*4*5 elements in total.
|
||||||
dim_size = np.cumprod(input_shape)
|
dim_size = np.cumprod(input_shape)
|
||||||
|
new_shape = [dim_size[ragged_rank - 1]] + input_shape[ragged_rank:]
|
||||||
else:
|
else:
|
||||||
neg_one = constant_op.constant([-1], row_splits_dtype)
|
|
||||||
new_shape = array_ops.concat([neg_one, input_shape[ragged_rank:]],
|
|
||||||
axis=0)
|
|
||||||
dim_size = math_ops.cumprod(input_shape)
|
dim_size = math_ops.cumprod(input_shape)
|
||||||
|
new_shape = array_ops.concat([[dim_size[ragged_rank]],
|
||||||
|
input_shape[ragged_rank:]],
|
||||||
|
axis=0)
|
||||||
flattened = array_ops.reshape(tensor, new_shape)
|
flattened = array_ops.reshape(tensor, new_shape)
|
||||||
result = cls.from_tensor(
|
result = cls.from_tensor(
|
||||||
flattened, lengths, padding, row_splits_dtype=row_splits_dtype)
|
flattened, lengths, padding, row_splits_dtype=row_splits_dtype)
|
||||||
@ -1563,7 +1563,8 @@ class RaggedTensor(composite_tensor.CompositeTensor):
|
|||||||
# If neither padding nor lengths were specified, then create a splits
|
# If neither padding nor lengths were specified, then create a splits
|
||||||
# vector that contains no default values, and reshape the input tensor
|
# vector that contains no default values, and reshape the input tensor
|
||||||
# to form the values for the RaggedTensor.
|
# to form the values for the RaggedTensor.
|
||||||
values_shape = array_ops.concat([[-1], input_shape[2:]], axis=0)
|
values_shape = array_ops.concat([[input_shape[0] * input_shape[1]],
|
||||||
|
input_shape[2:]], axis=0)
|
||||||
values = array_ops.reshape(tensor, values_shape)
|
values = array_ops.reshape(tensor, values_shape)
|
||||||
const_nrows = tensor_shape.dimension_at_index(tensor.shape, 0).value
|
const_nrows = tensor_shape.dimension_at_index(tensor.shape, 0).value
|
||||||
const_ncols = tensor_shape.dimension_at_index(tensor.shape, 1).value
|
const_ncols = tensor_shape.dimension_at_index(tensor.shape, 1).value
|
||||||
|
Loading…
Reference in New Issue
Block a user