fix dequantize op regression issue

leslie-fang-intel 2020-05-08 07:34:54 +08:00
parent f3b33698c5
commit 74833a04e0
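In short, the diff below makes the kernel write dequantized values straight into the allocated output tensor when the requested output type is already DT_FLOAT, and only stage results in a temporary float tensor (followed by an element-wise cast) when another output type such as bfloat16 is requested; the new need_cast_ member records which path applies. A minimal standalone sketch of that pattern follows; the names and the simplified MIN_COMBINED-style formula are illustrative only, not the actual TensorFlow kernel interface.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <type_traits>
#include <vector>

// Standalone sketch (not the TensorFlow kernel): dequantize uint8 data into
// `out`. When S is float, values are written directly into the caller's
// buffer; otherwise they are staged in a float scratch vector and cast once
// at the end, mirroring the need_cast_ logic the commit introduces.
template <typename S>
void DequantizeInto(const std::vector<std::uint8_t>& in, float min_range,
                    float max_range, std::vector<S>* out) {
  out->resize(in.size());
  // Simplified MIN_COMBINED-style scale for an 8-bit unsigned range.
  const float scale = (max_range - min_range) / 255.0f;
  if constexpr (std::is_same_v<S, float>) {
    // No cast needed: compute directly into the output buffer.
    for (std::size_t i = 0; i < in.size(); ++i) {
      (*out)[i] = min_range + static_cast<float>(in[i]) * scale;
    }
  } else {
    // Cast needed: compute in float, then convert element by element.
    std::vector<float> scratch(in.size());
    for (std::size_t i = 0; i < in.size(); ++i) {
      scratch[i] = min_range + static_cast<float>(in[i]) * scale;
    }
    for (std::size_t i = 0; i < in.size(); ++i) {
      (*out)[i] = static_cast<S>(scratch[i]);
    }
  }
}

int main() {
  const std::vector<std::uint8_t> quantized = {0, 128, 255};
  std::vector<float> dequantized;
  DequantizeInto(quantized, -1.0f, 1.0f, &dequantized);
  for (float v : dequantized) std::printf("%g ", v);  // prints: -1 0.00392157 1
  std::printf("\n");
  return 0;
}

Built with any C++17 compiler, the float instantiation takes the direct-write branch, which is the fast path the commit restores for DT_FLOAT outputs.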


@@ -61,7 +61,9 @@ class DequantizeOp : public OpKernel {
                                 " is '" +
                                 DataTypeString(ctx->output_type(0)) + "'"));
+    need_cast_ = true;
     if (ctx->output_type(0) == DT_FLOAT) {
+      need_cast_ = false;
       OP_REQUIRES(ctx,
                   (mode_string == "MIN_COMBINED" ||
                    mode_string == "MIN_FIRST" || mode_string == "SCALED"),
@@ -98,8 +100,9 @@ class DequantizeOp : public OpKernel {
     }
     Tensor* output = nullptr;
-    Tensor float_output = tensorflow::Tensor(DT_FLOAT, input.shape());
     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
+    Tensor float_output =
+        need_cast_ ? tensorflow::Tensor(DT_FLOAT, input.shape()) : *output;
     if (num_slices == 1) {
       const float min_range = input_min_tensor.flat<float>()(0);
       const float max_range = input_max_tensor.flat<float>()(0);
@@ -128,10 +131,12 @@ class DequantizeOp : public OpKernel {
                         max_ranges(i), output_tensor.template chip<1>(i));
       }
     }
-    S* out_ptr = output->flat<S>().data();
-    float* in_ptr = float_output.flat<float>().data();
-    for (int64 i = 0; i < float_output.NumElements(); ++i) {
-      out_ptr[i] = static_cast<S>(in_ptr[i]);
+    if (need_cast_) {
+      S* out_ptr = output->flat<S>().data();
+      float* in_ptr = float_output.flat<float>().data();
+      for (int64 i = 0; i < float_output.NumElements(); ++i) {
+        out_ptr[i] = static_cast<S>(in_ptr[i]);
+      }
     }
   }
@@ -219,6 +224,7 @@ class DequantizeOp : public OpKernel {
   int mode_;
   int axis_;
   bool narrow_range_;
+  bool need_cast_;
 };
 REGISTER_KERNEL_BUILDER(Name("Dequantize")