Merge pull request #45675 from yongtang:45392-tf.sparse.reorder
PiperOrigin-RevId: 349589322 Change-Id: I36fa64ded7253c65c2df6ed12c74703a0b2da87b
This commit is contained in:
commit
b2a7c3d9da
@ -54,9 +54,10 @@ class SparseReorderOp : public OpKernel {
|
|||||||
"Input shape should be a vector but received shape ",
|
"Input shape should be a vector but received shape ",
|
||||||
input_shape_in.shape().DebugString()));
|
input_shape_in.shape().DebugString()));
|
||||||
|
|
||||||
const TensorShape input_shape(input_shape_in.vec<int64>());
|
gtl::ArraySlice<int64> input_shape(input_shape_in.vec<int64>().data(),
|
||||||
|
input_shape_in.NumElements());
|
||||||
|
|
||||||
gtl::InlinedVector<int64, 8> std_order(input_shape.dims());
|
gtl::InlinedVector<int64, 8> std_order(input_shape.size());
|
||||||
std::iota(std_order.begin(), std_order.end(), 0);
|
std::iota(std_order.begin(), std_order.end(), 0);
|
||||||
|
|
||||||
// Check if the sparse tensor is already ordered correctly
|
// Check if the sparse tensor is already ordered correctly
|
||||||
|
@ -124,6 +124,18 @@ class SparseReorderTest(test.TestCase):
|
|||||||
x_init_value=input_val.values)
|
x_init_value=input_val.values)
|
||||||
self.assertLess(err, 1e-11)
|
self.assertLess(err, 1e-11)
|
||||||
|
|
||||||
|
def testShapeOverflow(self):
|
||||||
|
# Test case for GitHub issue 45392
|
||||||
|
sp_input = sparse_tensor.SparseTensor(
|
||||||
|
indices=[[0, 0, 0, 0, 0, 0]],
|
||||||
|
values=[0.0],
|
||||||
|
dense_shape=[4096, 4096, 4096, 4096, 4096, 4096])
|
||||||
|
self.assertAllEqual((4096, 4096, 4096, 4096, 4096, 4096),
|
||||||
|
sp_input.get_shape())
|
||||||
|
sp_output = sparse_ops.sparse_reorder(sp_input)
|
||||||
|
self.assertAllEqual((4096, 4096, 4096, 4096, 4096, 4096),
|
||||||
|
sp_output.get_shape())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user