From 63690f568c87237a76d91af21177456fcd8fdfd1 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Thu, 17 Dec 2020 00:00:10 +0000 Subject: [PATCH] Allow sparse tensor with reorder work on potentially large dims, Update: review comments addressed. Signed-off-by: Yong Tang --- tensorflow/core/kernels/sparse_reorder_op.cc | 18 +++--------------- .../kernel_tests/sparse_reorder_op_test.py | 7 +++---- 2 files changed, 6 insertions(+), 19 deletions(-) diff --git a/tensorflow/core/kernels/sparse_reorder_op.cc b/tensorflow/core/kernels/sparse_reorder_op.cc index 465bff8d3cf..b7fb56185c3 100644 --- a/tensorflow/core/kernels/sparse_reorder_op.cc +++ b/tensorflow/core/kernels/sparse_reorder_op.cc @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/util/sparse/sparse_tensor.h" -#include "tensorflow/core/util/overflow.h" namespace tensorflow { @@ -55,21 +54,10 @@ class SparseReorderOp : public OpKernel { "Input shape should be a vector but received shape ", input_shape_in.shape().DebugString())); - // Check if the sparse tensor input shape is valid - int64 total = 1; - for (int64 i = 0; i < input_shape_in.NumElements(); ++i) { - int dim = input_shape_in.vec()(i); - OP_REQUIRES(context, (dim >= 0), - errors::InvalidArgument("Dimension ", dim, " must be >= 0")); - total = MultiplyWithoutOverflow(total, dim); - OP_REQUIRES(context, (total > 0), - errors::InvalidArgument( - "Shape would have more than 2**63 - 1 elements")); - } + gtl::ArraySlice input_shape( + input_shape_in.vec().data(), input_shape_in.NumElements()); - const TensorShape input_shape(input_shape_in.vec()); - - gtl::InlinedVector std_order(input_shape.dims()); + gtl::InlinedVector std_order(input_shape.size()); std::iota(std_order.begin(), std_order.end(), 0); // Check if the sparse tensor is already ordered correctly diff --git a/tensorflow/python/kernel_tests/sparse_reorder_op_test.py b/tensorflow/python/kernel_tests/sparse_reorder_op_test.py index d05fd4e9985..c889fb14fbd 100644 --- a/tensorflow/python/kernel_tests/sparse_reorder_op_test.py +++ b/tensorflow/python/kernel_tests/sparse_reorder_op_test.py @@ -21,7 +21,6 @@ from __future__ import print_function import numpy as np from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors_impl from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -132,9 +131,9 @@ class SparseReorderTest(test.TestCase): dense_shape=[4096, 4096, 4096, 4096, 4096, 4096]) self.assertAllEqual( (4096, 4096, 4096, 4096, 4096, 4096), sp_input.get_shape()) - with self.assertRaisesRegex(errors_impl.InvalidArgumentError, - "Shape would have more than"): - sp_output = sparse_ops.sparse_reorder(sp_input) + sp_output = sparse_ops.sparse_reorder(sp_input) + self.assertAllEqual( + (4096, 4096, 4096, 4096, 4096, 4096), sp_output.get_shape()) if __name__ == "__main__":