diff --git a/tensorflow/core/kernels/dequantize_op.cc b/tensorflow/core/kernels/dequantize_op.cc index 28940e08494..2bf7b9c1834 100644 --- a/tensorflow/core/kernels/dequantize_op.cc +++ b/tensorflow/core/kernels/dequantize_op.cc @@ -41,11 +41,6 @@ template class DequantizeOp : public OpKernel { public: explicit DequantizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) { - half_range_ = !std::is_signed::value - ? 0.0f - : (static_cast(std::numeric_limits::max()) - - std::numeric_limits::min() + 1) / - 2.0f; string mode_string; OP_REQUIRES_OK(ctx, ctx->GetAttr("mode", &mode_string)); OP_REQUIRES(ctx, @@ -67,6 +62,12 @@ class DequantizeOp : public OpKernel { const Tensor& input = ctx->input(0); const float min_range = ctx->input(1).flat()(0); const float max_range = ctx->input(2).flat()(0); + const float half_range = + !std::is_signed::value + ? 0.0f + : (static_cast(std::numeric_limits::max()) - + std::numeric_limits::min() + 1) / + 2.0f; Tensor* output = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output)); @@ -76,15 +77,11 @@ class DequantizeOp : public OpKernel { (static_cast(std::numeric_limits::max()) - std::numeric_limits::min()); - float* out_ptr = output->flat().data(); - const T* in_ptr = input.flat().data(); + const auto& input_tensor = input.flat(); + output->flat().device(ctx->eigen_device()) = + ((input_tensor.template cast() + half_range) * scale_factor) + + min_range; - const int64 num_elements = input.NumElements(); - for (int i = 0; i < num_elements; ++i) { - out_ptr[i] = - ((static_cast(in_ptr[i]) + half_range_) * scale_factor) + - min_range; - } } else if (mode_ == QUANTIZE_MODE_MIN_FIRST) { if (meta::IsSupportedAndEnabled() && std::is_same()) { auto input_ui8_array = input.flat(); @@ -101,17 +98,14 @@ class DequantizeOp : public OpKernel { ? (max_range / std::numeric_limits::max()) : std::max(min_range / std::numeric_limits::min(), max_range / std::numeric_limits::max()); - float* out_ptr = output->flat().data(); - const T* in_ptr = input.flat().data(); - const int64 num_elements = input.NumElements(); - for (int64 i = 0; i < num_elements; ++i) { - out_ptr[i] = static_cast(in_ptr[i]) * scale_factor; - } + const auto& input_tensor = input.flat(); + output->flat().device(ctx->eigen_device()) = + input_tensor.template cast().template cast() * + scale_factor; } } private: - float half_range_; int mode_; };