Merge pull request #46717 from yongtang:46693-reshape-validation

PiperOrigin-RevId: 360890621
Change-Id: I53d1ff444227858a24ec5d3bae76c3d137129f5d
This commit is contained in:
TensorFlower Gardener 2021-03-04 05:16:35 -08:00
commit fab8fb9b9a
2 changed files with 21 additions and 0 deletions

View File

@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_CORE_KERNELS_RESHAPE_OP_H_
#include <memory>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@ -24,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/overflow.h"
namespace tensorflow {
@ -135,6 +137,17 @@ class ReshapeOp : public OpKernel {
shape->AddDim(size);
*has_zero_dim = true;
} else {
if (MultiplyWithoutOverflow(shape->num_elements(), size) < 0) {
string msg;
for (int ii = 0; ii < num_dims; ++ii) {
if (ii != 0) {
strings::StrAppend(&msg, ", ");
}
strings::StrAppend(&msg, Svec(ii));
}
return errors::InvalidArgument("Shape [", msg,
"] has too many elements");
}
shape->AddDim(size);
(*product) *= size;
}

View File

@ -22,6 +22,7 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
@ -214,6 +215,13 @@ class ReshapeTest(test.TestCase):
y = array_ops.reshape(x, [1, 50000**2])
self.assertEqual([1, 50000**2], y.get_shape().as_list())
@test_util.run_v2_only
def testTooLargeShape(self):
with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
"too many elements"):
x = array_ops.reshape([1], np.array([21943, 45817, 30516, 61760, 38987]))
self.evaluate(x)
if __name__ == "__main__":
test.main()