fix dequantize op regression issue

This commit is contained in:
leslie-fang-intel 2020-05-08 07:34:54 +08:00
parent f3b33698c5
commit 74833a04e0

View File

@ -61,7 +61,9 @@ class DequantizeOp : public OpKernel {
" is '" + " is '" +
DataTypeString(ctx->output_type(0)) + "'")); DataTypeString(ctx->output_type(0)) + "'"));
need_cast_ = true;
if (ctx->output_type(0) == DT_FLOAT) { if (ctx->output_type(0) == DT_FLOAT) {
need_cast_ = false;
OP_REQUIRES(ctx, OP_REQUIRES(ctx,
(mode_string == "MIN_COMBINED" || (mode_string == "MIN_COMBINED" ||
mode_string == "MIN_FIRST" || mode_string == "SCALED"), mode_string == "MIN_FIRST" || mode_string == "SCALED"),
@ -98,8 +100,9 @@ class DequantizeOp : public OpKernel {
} }
Tensor* output = nullptr; Tensor* output = nullptr;
Tensor float_output = tensorflow::Tensor(DT_FLOAT, input.shape());
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output)); OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
Tensor float_output =
need_cast_ ? tensorflow::Tensor(DT_FLOAT, input.shape()) : *output;
if (num_slices == 1) { if (num_slices == 1) {
const float min_range = input_min_tensor.flat<float>()(0); const float min_range = input_min_tensor.flat<float>()(0);
const float max_range = input_max_tensor.flat<float>()(0); const float max_range = input_max_tensor.flat<float>()(0);
@ -128,12 +131,14 @@ class DequantizeOp : public OpKernel {
max_ranges(i), output_tensor.template chip<1>(i)); max_ranges(i), output_tensor.template chip<1>(i));
} }
} }
if (need_cast_) {
S* out_ptr = output->flat<S>().data(); S* out_ptr = output->flat<S>().data();
float* in_ptr = float_output.flat<float>().data(); float* in_ptr = float_output.flat<float>().data();
for (int64 i = 0; i < float_output.NumElements(); ++i) { for (int64 i = 0; i < float_output.NumElements(); ++i) {
out_ptr[i] = static_cast<S>(in_ptr[i]); out_ptr[i] = static_cast<S>(in_ptr[i]);
} }
} }
}
void DequantizeTensor(OpKernelContext* ctx, const Tensor& input, void DequantizeTensor(OpKernelContext* ctx, const Tensor& input,
const float min_range, const float max_range, const float min_range, const float max_range,
@ -219,6 +224,7 @@ class DequantizeOp : public OpKernel {
int mode_; int mode_;
int axis_; int axis_;
bool narrow_range_; bool narrow_range_;
bool need_cast_;
}; };
REGISTER_KERNEL_BUILDER(Name("Dequantize") REGISTER_KERNEL_BUILDER(Name("Dequantize")