Prevent segfault in quantize_and_dequantize

Fixes #42105.

If `tf.quantization.quantize_and_dequantize` is called with `axis` argument pointing to outside of the input tensor, we obtain a `CHECK` fail which then aborts the application/interpreter. This change adds a condition check and returns a `Status` instead of crashing.

PiperOrigin-RevId: 337972243
Change-Id: I71ec32c00a87266e364fb017f0ad5dfd3e23542f
This commit is contained in:
Mihai Maruseac 2020-10-19 17:56:36 -07:00 committed by TensorFlower Gardener
parent 0225022b72
commit eccb7ec454
2 changed files with 20 additions and 0 deletions

View File

@ -71,6 +71,10 @@ class QuantizeAndDequantizeV2Op : public OpKernel {
void Compute(OpKernelContext* ctx) override { void Compute(OpKernelContext* ctx) override {
const Tensor& input = ctx->input(0); const Tensor& input = ctx->input(0);
OP_REQUIRES(
ctx, (axis_ == -1 || axis_ < input.shape().dims()),
errors::InvalidArgument("Shape must be at least rank ", axis_ + 1,
" but is rank ", input.shape().dims()));
const int depth = (axis_ == -1) ? 1 : input.dim_size(axis_); const int depth = (axis_ == -1) ? 1 : input.dim_size(axis_);
Tensor input_min_tensor; Tensor input_min_tensor;
Tensor input_max_tensor; Tensor input_max_tensor;

View File

@ -1628,6 +1628,22 @@ class QuantizeAndDequantizeTest(test_util.TensorFlowTestCase):
axis=(axis - 4))) axis=(axis - 4)))
self.assertAllClose(fake_quantized, expected) self.assertAllClose(fake_quantized, expected)
def testBadAxis(self):
input_tensor = [2.5, 2.5]
input_min = [0, 0]
input_max = [1, 1]
error_message_pattern = "Shape must be at least rank 11 but is rank 1"
# TODO(b/171260356): Eager mode and graph mode throw different error types
error = errors.InvalidArgumentError if context.executing_eagerly(
) else ValueError
with self.assertRaisesRegex(error, error_message_pattern):
self.evaluate(
array_ops.quantize_and_dequantize_v2(
input=input_tensor,
input_min=input_min,
input_max=input_max,
axis=10))
def testQuantizeDequantizeGrad(self): def testQuantizeDequantizeGrad(self):
shape = (2, 2) shape = (2, 2)
max_threshold = 0 max_threshold = 0