From eccb7ec454e6617738554a255d77f08e60ee0808 Mon Sep 17 00:00:00 2001 From: Mihai Maruseac <mihaimaruseac@google.com> Date: Mon, 19 Oct 2020 17:56:36 -0700 Subject: [PATCH] Prevent segfault in `quantize_and_dequantize` Fixes #42105. If `tf.quantization.quantize_and_dequantize` is called with `axis` argument pointing to outside of the input tensor, we obtain a `CHECK` fail which then aborts the application/interpreter. This change adds a condition check and returns a `Status` instead of crashing. PiperOrigin-RevId: 337972243 Change-Id: I71ec32c00a87266e364fb017f0ad5dfd3e23542f --- .../core/kernels/quantize_and_dequantize_op.cc | 4 ++++ tensorflow/python/kernel_tests/array_ops_test.py | 16 ++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/tensorflow/core/kernels/quantize_and_dequantize_op.cc b/tensorflow/core/kernels/quantize_and_dequantize_op.cc index dec0262cf04..675bdaec225 100644 --- a/tensorflow/core/kernels/quantize_and_dequantize_op.cc +++ b/tensorflow/core/kernels/quantize_and_dequantize_op.cc @@ -71,6 +71,10 @@ class QuantizeAndDequantizeV2Op : public OpKernel { void Compute(OpKernelContext* ctx) override { const Tensor& input = ctx->input(0); + OP_REQUIRES( + ctx, (axis_ == -1 || axis_ < input.shape().dims()), + errors::InvalidArgument("Shape must be at least rank ", axis_ + 1, + " but is rank ", input.shape().dims())); const int depth = (axis_ == -1) ? 1 : input.dim_size(axis_); Tensor input_min_tensor; Tensor input_max_tensor; diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py index 8eb5af399b4..4106ea9b166 100644 --- a/tensorflow/python/kernel_tests/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops_test.py @@ -1628,6 +1628,22 @@ class QuantizeAndDequantizeTest(test_util.TensorFlowTestCase): axis=(axis - 4))) self.assertAllClose(fake_quantized, expected) + def testBadAxis(self): + input_tensor = [2.5, 2.5] + input_min = [0, 0] + input_max = [1, 1] + error_message_pattern = "Shape must be at least rank 11 but is rank 1" + # TODO(b/171260356): Eager mode and graph mode throw different error types + error = errors.InvalidArgumentError if context.executing_eagerly( + ) else ValueError + with self.assertRaisesRegex(error, error_message_pattern): + self.evaluate( + array_ops.quantize_and_dequantize_v2( + input=input_tensor, + input_min=input_min, + input_max=input_max, + axis=10)) + def testQuantizeDequantizeGrad(self): shape = (2, 2) max_threshold = 0