Merge pull request #39285 from Intel-tensorflow:lesliefang/fix_dequantize_op_regression

PiperOrigin-RevId: 311173226
Change-Id: I9cba49711d9b511b9f7deb8fb361a5118c29cd46
This commit is contained in:
TensorFlower Gardener 2020-05-12 12:07:04 -07:00
commit ac7e71cc3d

View File

@ -61,7 +61,9 @@ class DequantizeOp : public OpKernel {
" is '" +
DataTypeString(ctx->output_type(0)) + "'"));
need_cast_ = true;
if (ctx->output_type(0) == DT_FLOAT) {
need_cast_ = false;
OP_REQUIRES(ctx,
(mode_string == "MIN_COMBINED" ||
mode_string == "MIN_FIRST" || mode_string == "SCALED"),
@ -98,8 +100,9 @@ class DequantizeOp : public OpKernel {
}
Tensor* output = nullptr;
Tensor float_output = tensorflow::Tensor(DT_FLOAT, input.shape());
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
Tensor float_output =
need_cast_ ? tensorflow::Tensor(DT_FLOAT, input.shape()) : *output;
if (num_slices == 1) {
const float min_range = input_min_tensor.flat<float>()(0);
const float max_range = input_max_tensor.flat<float>()(0);
@ -128,10 +131,12 @@ class DequantizeOp : public OpKernel {
max_ranges(i), output_tensor.template chip<1>(i));
}
}
S* out_ptr = output->flat<S>().data();
float* in_ptr = float_output.flat<float>().data();
for (int64 i = 0; i < float_output.NumElements(); ++i) {
out_ptr[i] = static_cast<S>(in_ptr[i]);
if (need_cast_) {
S* out_ptr = output->flat<S>().data();
float* in_ptr = float_output.flat<float>().data();
for (int64 i = 0; i < float_output.NumElements(); ++i) {
out_ptr[i] = static_cast<S>(in_ptr[i]);
}
}
}
@ -219,6 +224,7 @@ class DequantizeOp : public OpKernel {
int mode_;
int axis_;
bool narrow_range_;
bool need_cast_;
};
REGISTER_KERNEL_BUILDER(Name("Dequantize")