Replace usage of math_ops.maximum with math_ops.reduce_max when getting max length from SparseTensors.
PiperOrigin-RevId: 170748309
This commit is contained in:
parent
f08c961c97
commit
75cac0a5d5
@ -30,6 +30,7 @@ from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
@ -527,6 +528,50 @@ class PaddingTest(test.TestCase):
|
||||
self.assertTrue(
|
||||
math_ops.reduce_all(math_ops.equal(val, padded_seq[key])).eval())
|
||||
|
||||
def testPaddingOnlySparse(self):
|
||||
ind1 = np.array([[0], [2]])
|
||||
val1 = np.array([3, 4])
|
||||
shape1 = np.array([4])
|
||||
|
||||
ind2 = np.array([[1], [2]])
|
||||
val2 = np.array([9, 12])
|
||||
shape2 = np.array([5])
|
||||
|
||||
with ops.Graph().as_default() as g, self.test_session(graph=g):
|
||||
sp_tensor1 = sparse_tensor.SparseTensor(
|
||||
indices=array_ops.constant(ind1, dtypes.int64),
|
||||
values=array_ops.constant(val1, dtypes.int64),
|
||||
dense_shape=array_ops.constant(shape1, dtypes.int64))
|
||||
sp_tensor2 = sparse_tensor.SparseTensor(
|
||||
indices=array_ops.constant(ind2, dtypes.int64),
|
||||
values=array_ops.constant(val2, dtypes.int64),
|
||||
dense_shape=array_ops.constant(shape2, dtypes.int64))
|
||||
|
||||
sp_tensor1_expected = sparse_tensor.SparseTensor(
|
||||
indices=sp_tensor1.indices,
|
||||
values=sp_tensor1.values,
|
||||
dense_shape=[8])
|
||||
sp_tensor2_expected = sparse_tensor.SparseTensor(
|
||||
indices=sp_tensor2.indices,
|
||||
values=sp_tensor2.values,
|
||||
dense_shape=[8])
|
||||
|
||||
sequences = {
|
||||
"key_1": sp_tensor1,
|
||||
"key_2": sp_tensor2,
|
||||
}
|
||||
_, padded_seq = sqss._padding(sequences, 4)
|
||||
|
||||
expected_padded_seq = {
|
||||
"key_1": sp_tensor1_expected,
|
||||
"key_2": sp_tensor2_expected,
|
||||
}
|
||||
|
||||
for key, val in expected_padded_seq.items():
|
||||
self.assertAllEqual(
|
||||
sparse_ops.sparse_tensor_to_dense(val).eval(),
|
||||
sparse_ops.sparse_tensor_to_dense(padded_seq[key]).eval())
|
||||
|
||||
|
||||
class SparseTensorReConstructionTest(test.TestCase):
|
||||
|
||||
|
@ -1596,7 +1596,7 @@ def _padding(sequences, num_unroll):
|
||||
else: # Only have SparseTensors
|
||||
sparse_lengths = [value.dense_shape[0] for value in sequences_dict.values()
|
||||
if isinstance(value, sparse_tensor.SparseTensor)]
|
||||
length = math_ops.maximum(sparse_lengths)
|
||||
length = math_ops.reduce_max(math_ops.to_int32(sparse_lengths))
|
||||
|
||||
unroll = array_ops.constant(num_unroll)
|
||||
padded_length = length + ((unroll - (length % unroll)) % unroll)
|
||||
|
Loading…
x
Reference in New Issue
Block a user