diff --git a/tensorflow/core/kernels/reshape_op.h b/tensorflow/core/kernels/reshape_op.h index 155f8dafc9c..5eea0a2a537 100644 --- a/tensorflow/core/kernels/reshape_op.h +++ b/tensorflow/core/kernels/reshape_op.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_CORE_KERNELS_RESHAPE_OP_H_ #include + #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -24,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/overflow.h" namespace tensorflow { @@ -135,6 +137,17 @@ class ReshapeOp : public OpKernel { shape->AddDim(size); *has_zero_dim = true; } else { + if (MultiplyWithoutOverflow(shape->num_elements(), size) < 0) { + string msg; + for (int ii = 0; ii < num_dims; ++ii) { + if (ii != 0) { + strings::StrAppend(&msg, ", "); + } + strings::StrAppend(&msg, Svec(ii)); + } + return errors::InvalidArgument("Shape [", msg, + "] has too many elements"); + } shape->AddDim(size); (*product) *= size; } diff --git a/tensorflow/python/kernel_tests/reshape_op_test.py b/tensorflow/python/kernel_tests/reshape_op_test.py index 80f72554aeb..f33bebb5bb5 100644 --- a/tensorflow/python/kernel_tests/reshape_op_test.py +++ b/tensorflow/python/kernel_tests/reshape_op_test.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util @@ -214,6 +215,13 @@ class ReshapeTest(test.TestCase): y = array_ops.reshape(x, [1, 50000**2]) self.assertEqual([1, 50000**2], y.get_shape().as_list()) + @test_util.run_v2_only + def testTooLargeShape(self): + with self.assertRaisesRegex(errors_impl.InvalidArgumentError, + "too many elements"): + x = array_ops.reshape([1], np.array([21943, 45817, 30516, 61760, 38987])) + self.evaluate(x) + if __name__ == "__main__": test.main()