From eccb7ec454e6617738554a255d77f08e60ee0808 Mon Sep 17 00:00:00 2001
From: Mihai Maruseac <mihaimaruseac@google.com>
Date: Mon, 19 Oct 2020 17:56:36 -0700
Subject: [PATCH] Prevent segfault in `quantize_and_dequantize`

Fixes #42105.

If `tf.quantization.quantize_and_dequantize` is called with `axis` argument pointing to outside of the input tensor, we obtain a `CHECK` fail which then aborts the application/interpreter. This change adds a condition check and returns a `Status` instead of crashing.

PiperOrigin-RevId: 337972243
Change-Id: I71ec32c00a87266e364fb017f0ad5dfd3e23542f
---
 .../core/kernels/quantize_and_dequantize_op.cc   |  4 ++++
 tensorflow/python/kernel_tests/array_ops_test.py | 16 ++++++++++++++++
 2 files changed, 20 insertions(+)

diff --git a/tensorflow/core/kernels/quantize_and_dequantize_op.cc b/tensorflow/core/kernels/quantize_and_dequantize_op.cc
index dec0262cf04..675bdaec225 100644
--- a/tensorflow/core/kernels/quantize_and_dequantize_op.cc
+++ b/tensorflow/core/kernels/quantize_and_dequantize_op.cc
@@ -71,6 +71,10 @@ class QuantizeAndDequantizeV2Op : public OpKernel {
 
   void Compute(OpKernelContext* ctx) override {
     const Tensor& input = ctx->input(0);
+    OP_REQUIRES(
+        ctx, (axis_ == -1 || axis_ < input.shape().dims()),
+        errors::InvalidArgument("Shape must be at least rank ", axis_ + 1,
+                                " but is rank ", input.shape().dims()));
     const int depth = (axis_ == -1) ? 1 : input.dim_size(axis_);
     Tensor input_min_tensor;
     Tensor input_max_tensor;
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 8eb5af399b4..4106ea9b166 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -1628,6 +1628,22 @@ class QuantizeAndDequantizeTest(test_util.TensorFlowTestCase):
                   axis=(axis - 4)))
           self.assertAllClose(fake_quantized, expected)
 
+  def testBadAxis(self):
+    input_tensor = [2.5, 2.5]
+    input_min = [0, 0]
+    input_max = [1, 1]
+    error_message_pattern = "Shape must be at least rank 11 but is rank 1"
+    # TODO(b/171260356): Eager mode and graph mode throw different error types
+    error = errors.InvalidArgumentError if context.executing_eagerly(
+    ) else ValueError
+    with self.assertRaisesRegex(error, error_message_pattern):
+      self.evaluate(
+          array_ops.quantize_and_dequantize_v2(
+              input=input_tensor,
+              input_min=input_min,
+              input_max=input_max,
+              axis=10))
+
   def testQuantizeDequantizeGrad(self):
     shape = (2, 2)
     max_threshold = 0