Prevent segfault in quantize_and_dequantize
Fixes #42105. If `tf.quantization.quantize_and_dequantize` is called with `axis` argument pointing to outside of the input tensor, we obtain a `CHECK` fail which then aborts the application/interpreter. This change adds a condition check and returns a `Status` instead of crashing. PiperOrigin-RevId: 337972243 Change-Id: I71ec32c00a87266e364fb017f0ad5dfd3e23542f
This commit is contained in:
parent
0225022b72
commit
eccb7ec454
tensorflow
@ -71,6 +71,10 @@ class QuantizeAndDequantizeV2Op : public OpKernel {
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
const Tensor& input = ctx->input(0);
|
||||
OP_REQUIRES(
|
||||
ctx, (axis_ == -1 || axis_ < input.shape().dims()),
|
||||
errors::InvalidArgument("Shape must be at least rank ", axis_ + 1,
|
||||
" but is rank ", input.shape().dims()));
|
||||
const int depth = (axis_ == -1) ? 1 : input.dim_size(axis_);
|
||||
Tensor input_min_tensor;
|
||||
Tensor input_max_tensor;
|
||||
|
@ -1628,6 +1628,22 @@ class QuantizeAndDequantizeTest(test_util.TensorFlowTestCase):
|
||||
axis=(axis - 4)))
|
||||
self.assertAllClose(fake_quantized, expected)
|
||||
|
||||
def testBadAxis(self):
|
||||
input_tensor = [2.5, 2.5]
|
||||
input_min = [0, 0]
|
||||
input_max = [1, 1]
|
||||
error_message_pattern = "Shape must be at least rank 11 but is rank 1"
|
||||
# TODO(b/171260356): Eager mode and graph mode throw different error types
|
||||
error = errors.InvalidArgumentError if context.executing_eagerly(
|
||||
) else ValueError
|
||||
with self.assertRaisesRegex(error, error_message_pattern):
|
||||
self.evaluate(
|
||||
array_ops.quantize_and_dequantize_v2(
|
||||
input=input_tensor,
|
||||
input_min=input_min,
|
||||
input_max=input_max,
|
||||
axis=10))
|
||||
|
||||
def testQuantizeDequantizeGrad(self):
|
||||
shape = (2, 2)
|
||||
max_threshold = 0
|
||||
|
Loading…
Reference in New Issue
Block a user