Merge pull request #46717 from yongtang:46693-reshape-validation
PiperOrigin-RevId: 360890621 Change-Id: I53d1ff444227858a24ec5d3bae76c3d137129f5d
This commit is contained in:
commit
fab8fb9b9a
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_CORE_KERNELS_RESHAPE_OP_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
@ -24,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/util/overflow.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -135,6 +137,17 @@ class ReshapeOp : public OpKernel {
|
||||
shape->AddDim(size);
|
||||
*has_zero_dim = true;
|
||||
} else {
|
||||
if (MultiplyWithoutOverflow(shape->num_elements(), size) < 0) {
|
||||
string msg;
|
||||
for (int ii = 0; ii < num_dims; ++ii) {
|
||||
if (ii != 0) {
|
||||
strings::StrAppend(&msg, ", ");
|
||||
}
|
||||
strings::StrAppend(&msg, Svec(ii));
|
||||
}
|
||||
return errors::InvalidArgument("Shape [", msg,
|
||||
"] has too many elements");
|
||||
}
|
||||
shape->AddDim(size);
|
||||
(*product) *= size;
|
||||
}
|
||||
|
@ -22,6 +22,7 @@ import numpy as np
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import test_util
|
||||
@ -214,6 +215,13 @@ class ReshapeTest(test.TestCase):
|
||||
y = array_ops.reshape(x, [1, 50000**2])
|
||||
self.assertEqual([1, 50000**2], y.get_shape().as_list())
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testTooLargeShape(self):
|
||||
with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
|
||||
"too many elements"):
|
||||
x = array_ops.reshape([1], np.array([21943, 45817, 30516, 61760, 38987]))
|
||||
self.evaluate(x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user