Merge pull request #46717 from yongtang:46693-reshape-validation
PiperOrigin-RevId: 360890621 Change-Id: I53d1ff444227858a24ec5d3bae76c3d137129f5d
This commit is contained in:
commit
fab8fb9b9a
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_CORE_KERNELS_RESHAPE_OP_H_
|
#define TENSORFLOW_CORE_KERNELS_RESHAPE_OP_H_
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
@ -24,6 +25,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/util/overflow.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
@ -135,6 +137,17 @@ class ReshapeOp : public OpKernel {
|
|||||||
shape->AddDim(size);
|
shape->AddDim(size);
|
||||||
*has_zero_dim = true;
|
*has_zero_dim = true;
|
||||||
} else {
|
} else {
|
||||||
|
if (MultiplyWithoutOverflow(shape->num_elements(), size) < 0) {
|
||||||
|
string msg;
|
||||||
|
for (int ii = 0; ii < num_dims; ++ii) {
|
||||||
|
if (ii != 0) {
|
||||||
|
strings::StrAppend(&msg, ", ");
|
||||||
|
}
|
||||||
|
strings::StrAppend(&msg, Svec(ii));
|
||||||
|
}
|
||||||
|
return errors::InvalidArgument("Shape [", msg,
|
||||||
|
"] has too many elements");
|
||||||
|
}
|
||||||
shape->AddDim(size);
|
shape->AddDim(size);
|
||||||
(*product) *= size;
|
(*product) *= size;
|
||||||
}
|
}
|
||||||
|
@ -22,6 +22,7 @@ import numpy as np
|
|||||||
|
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import errors_impl
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
@ -214,6 +215,13 @@ class ReshapeTest(test.TestCase):
|
|||||||
y = array_ops.reshape(x, [1, 50000**2])
|
y = array_ops.reshape(x, [1, 50000**2])
|
||||||
self.assertEqual([1, 50000**2], y.get_shape().as_list())
|
self.assertEqual([1, 50000**2], y.get_shape().as_list())
|
||||||
|
|
||||||
|
@test_util.run_v2_only
|
||||||
|
def testTooLargeShape(self):
|
||||||
|
with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
|
||||||
|
"too many elements"):
|
||||||
|
x = array_ops.reshape([1], np.array([21943, 45817, 30516, 61760, 38987]))
|
||||||
|
self.evaluate(x)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user