Merge pull request #45675 from yongtang:45392-tf.sparse.reorder

PiperOrigin-RevId: 349589322
Change-Id: I36fa64ded7253c65c2df6ed12c74703a0b2da87b
This commit is contained in:
TensorFlower Gardener 2020-12-30 12:42:27 -08:00
commit b2a7c3d9da
2 changed files with 15 additions and 2 deletions

View File

@ -54,9 +54,10 @@ class SparseReorderOp : public OpKernel {
"Input shape should be a vector but received shape ",
input_shape_in.shape().DebugString()));
const TensorShape input_shape(input_shape_in.vec<int64>());
gtl::ArraySlice<int64> input_shape(input_shape_in.vec<int64>().data(),
input_shape_in.NumElements());
gtl::InlinedVector<int64, 8> std_order(input_shape.dims());
gtl::InlinedVector<int64, 8> std_order(input_shape.size());
std::iota(std_order.begin(), std_order.end(), 0);
// Check if the sparse tensor is already ordered correctly

View File

@ -124,6 +124,18 @@ class SparseReorderTest(test.TestCase):
x_init_value=input_val.values)
self.assertLess(err, 1e-11)
def testShapeOverflow(self):
# Test case for GitHub issue 45392
sp_input = sparse_tensor.SparseTensor(
indices=[[0, 0, 0, 0, 0, 0]],
values=[0.0],
dense_shape=[4096, 4096, 4096, 4096, 4096, 4096])
self.assertAllEqual((4096, 4096, 4096, 4096, 4096, 4096),
sp_input.get_shape())
sp_output = sparse_ops.sparse_reorder(sp_input)
self.assertAllEqual((4096, 4096, 4096, 4096, 4096, 4096),
sp_output.get_shape())
if __name__ == "__main__":
test.main()