Prevent segfault in quantize_and_dequantize
Fixes #42105. If `tf.quantization.quantize_and_dequantize` is called with `axis` argument pointing to outside of the input tensor, we obtain a `CHECK` fail which then aborts the application/interpreter. This change adds a condition check and returns a `Status` instead of crashing. PiperOrigin-RevId: 337972243 Change-Id: I71ec32c00a87266e364fb017f0ad5dfd3e23542f
This commit is contained in:
parent
0225022b72
commit
eccb7ec454
@ -71,6 +71,10 @@ class QuantizeAndDequantizeV2Op : public OpKernel {
|
|||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
void Compute(OpKernelContext* ctx) override {
|
||||||
const Tensor& input = ctx->input(0);
|
const Tensor& input = ctx->input(0);
|
||||||
|
OP_REQUIRES(
|
||||||
|
ctx, (axis_ == -1 || axis_ < input.shape().dims()),
|
||||||
|
errors::InvalidArgument("Shape must be at least rank ", axis_ + 1,
|
||||||
|
" but is rank ", input.shape().dims()));
|
||||||
const int depth = (axis_ == -1) ? 1 : input.dim_size(axis_);
|
const int depth = (axis_ == -1) ? 1 : input.dim_size(axis_);
|
||||||
Tensor input_min_tensor;
|
Tensor input_min_tensor;
|
||||||
Tensor input_max_tensor;
|
Tensor input_max_tensor;
|
||||||
|
@ -1628,6 +1628,22 @@ class QuantizeAndDequantizeTest(test_util.TensorFlowTestCase):
|
|||||||
axis=(axis - 4)))
|
axis=(axis - 4)))
|
||||||
self.assertAllClose(fake_quantized, expected)
|
self.assertAllClose(fake_quantized, expected)
|
||||||
|
|
||||||
|
def testBadAxis(self):
|
||||||
|
input_tensor = [2.5, 2.5]
|
||||||
|
input_min = [0, 0]
|
||||||
|
input_max = [1, 1]
|
||||||
|
error_message_pattern = "Shape must be at least rank 11 but is rank 1"
|
||||||
|
# TODO(b/171260356): Eager mode and graph mode throw different error types
|
||||||
|
error = errors.InvalidArgumentError if context.executing_eagerly(
|
||||||
|
) else ValueError
|
||||||
|
with self.assertRaisesRegex(error, error_message_pattern):
|
||||||
|
self.evaluate(
|
||||||
|
array_ops.quantize_and_dequantize_v2(
|
||||||
|
input=input_tensor,
|
||||||
|
input_min=input_min,
|
||||||
|
input_max=input_max,
|
||||||
|
axis=10))
|
||||||
|
|
||||||
def testQuantizeDequantizeGrad(self):
|
def testQuantizeDequantizeGrad(self):
|
||||||
shape = (2, 2)
|
shape = (2, 2)
|
||||||
max_threshold = 0
|
max_threshold = 0
|
||||||
|
Loading…
Reference in New Issue
Block a user