fix dequantize op regression issue
This commit is contained in:
parent
f3b33698c5
commit
74833a04e0
@ -61,7 +61,9 @@ class DequantizeOp : public OpKernel {
|
||||
" is '" +
|
||||
DataTypeString(ctx->output_type(0)) + "'"));
|
||||
|
||||
need_cast_ = true;
|
||||
if (ctx->output_type(0) == DT_FLOAT) {
|
||||
need_cast_ = false;
|
||||
OP_REQUIRES(ctx,
|
||||
(mode_string == "MIN_COMBINED" ||
|
||||
mode_string == "MIN_FIRST" || mode_string == "SCALED"),
|
||||
@ -98,8 +100,9 @@ class DequantizeOp : public OpKernel {
|
||||
}
|
||||
|
||||
Tensor* output = nullptr;
|
||||
Tensor float_output = tensorflow::Tensor(DT_FLOAT, input.shape());
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
|
||||
Tensor float_output =
|
||||
need_cast_ ? tensorflow::Tensor(DT_FLOAT, input.shape()) : *output;
|
||||
if (num_slices == 1) {
|
||||
const float min_range = input_min_tensor.flat<float>()(0);
|
||||
const float max_range = input_max_tensor.flat<float>()(0);
|
||||
@ -128,10 +131,12 @@ class DequantizeOp : public OpKernel {
|
||||
max_ranges(i), output_tensor.template chip<1>(i));
|
||||
}
|
||||
}
|
||||
S* out_ptr = output->flat<S>().data();
|
||||
float* in_ptr = float_output.flat<float>().data();
|
||||
for (int64 i = 0; i < float_output.NumElements(); ++i) {
|
||||
out_ptr[i] = static_cast<S>(in_ptr[i]);
|
||||
if (need_cast_) {
|
||||
S* out_ptr = output->flat<S>().data();
|
||||
float* in_ptr = float_output.flat<float>().data();
|
||||
for (int64 i = 0; i < float_output.NumElements(); ++i) {
|
||||
out_ptr[i] = static_cast<S>(in_ptr[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -219,6 +224,7 @@ class DequantizeOp : public OpKernel {
|
||||
int mode_;
|
||||
int axis_;
|
||||
bool narrow_range_;
|
||||
bool need_cast_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("Dequantize")
|
||||
|
Loading…
Reference in New Issue
Block a user