Merge pull request #45675 from yongtang:45392-tf.sparse.reorder
PiperOrigin-RevId: 349589322 Change-Id: I36fa64ded7253c65c2df6ed12c74703a0b2da87b
This commit is contained in:
commit
b2a7c3d9da
@ -54,9 +54,10 @@ class SparseReorderOp : public OpKernel {
|
||||
"Input shape should be a vector but received shape ",
|
||||
input_shape_in.shape().DebugString()));
|
||||
|
||||
const TensorShape input_shape(input_shape_in.vec<int64>());
|
||||
gtl::ArraySlice<int64> input_shape(input_shape_in.vec<int64>().data(),
|
||||
input_shape_in.NumElements());
|
||||
|
||||
gtl::InlinedVector<int64, 8> std_order(input_shape.dims());
|
||||
gtl::InlinedVector<int64, 8> std_order(input_shape.size());
|
||||
std::iota(std_order.begin(), std_order.end(), 0);
|
||||
|
||||
// Check if the sparse tensor is already ordered correctly
|
||||
|
@ -124,6 +124,18 @@ class SparseReorderTest(test.TestCase):
|
||||
x_init_value=input_val.values)
|
||||
self.assertLess(err, 1e-11)
|
||||
|
||||
def testShapeOverflow(self):
|
||||
# Test case for GitHub issue 45392
|
||||
sp_input = sparse_tensor.SparseTensor(
|
||||
indices=[[0, 0, 0, 0, 0, 0]],
|
||||
values=[0.0],
|
||||
dense_shape=[4096, 4096, 4096, 4096, 4096, 4096])
|
||||
self.assertAllEqual((4096, 4096, 4096, 4096, 4096, 4096),
|
||||
sp_input.get_shape())
|
||||
sp_output = sparse_ops.sparse_reorder(sp_input)
|
||||
self.assertAllEqual((4096, 4096, 4096, 4096, 4096, 4096),
|
||||
sp_output.get_shape())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user