From 74833a04e032766a27890ff882d669d9e484a497 Mon Sep 17 00:00:00 2001
From: leslie-fang-intel
Date: Fri, 8 May 2020 07:34:54 +0800
Subject: [PATCH] fix dequantize op regression issue

---
 tensorflow/core/kernels/dequantize_op.cc | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/tensorflow/core/kernels/dequantize_op.cc b/tensorflow/core/kernels/dequantize_op.cc
index 0f5a7019b1f..3b38daf0067 100644
--- a/tensorflow/core/kernels/dequantize_op.cc
+++ b/tensorflow/core/kernels/dequantize_op.cc
@@ -61,7 +61,9 @@ class DequantizeOp : public OpKernel {
                                 " is '" +
                                 DataTypeString(ctx->output_type(0)) + "'"));
 
+    need_cast_ = true;
     if (ctx->output_type(0) == DT_FLOAT) {
+      need_cast_ = false;
       OP_REQUIRES(ctx,
                   (mode_string == "MIN_COMBINED" ||
                    mode_string == "MIN_FIRST" || mode_string == "SCALED"),
@@ -98,8 +100,9 @@ class DequantizeOp : public OpKernel {
     }
 
     Tensor* output = nullptr;
-    Tensor float_output = tensorflow::Tensor(DT_FLOAT, input.shape());
     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
+    Tensor float_output =
+        need_cast_ ? tensorflow::Tensor(DT_FLOAT, input.shape()) : *output;
     if (num_slices == 1) {
       const float min_range = input_min_tensor.flat<float>()(0);
       const float max_range = input_max_tensor.flat<float>()(0);
@@ -128,10 +131,12 @@ class DequantizeOp : public OpKernel {
                         max_ranges(i), output_tensor.template chip<1>(i));
       }
     }
-    S* out_ptr = output->flat<S>().data();
-    float* in_ptr = float_output.flat<float>().data();
-    for (int64 i = 0; i < float_output.NumElements(); ++i) {
-      out_ptr[i] = static_cast<S>(in_ptr[i]);
+    if (need_cast_) {
+      S* out_ptr = output->flat<S>().data();
+      float* in_ptr = float_output.flat<float>().data();
+      for (int64 i = 0; i < float_output.NumElements(); ++i) {
+        out_ptr[i] = static_cast<S>(in_ptr[i]);
+      }
     }
   }
 
@@ -219,6 +224,7 @@ class DequantizeOp : public OpKernel {
   int mode_;
   int axis_;
   bool narrow_range_;
+  bool need_cast_;
 };
 
 REGISTER_KERNEL_BUILDER(Name("Dequantize")