Used Eigen in Dequantize op for performance.

PiperOrigin-RevId: 253618908
This commit is contained in:
A. Unique TensorFlower 2019-06-17 11:12:49 -07:00 committed by TensorFlower Gardener
parent 84cf4d436e
commit 1efa1e39c1

View File

@ -41,11 +41,6 @@ template <typename Device, typename T>
class DequantizeOp : public OpKernel {
public:
explicit DequantizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
half_range_ = !std::is_signed<T>::value
? 0.0f
: (static_cast<float>(std::numeric_limits<T>::max()) -
std::numeric_limits<T>::min() + 1) /
2.0f;
string mode_string;
OP_REQUIRES_OK(ctx, ctx->GetAttr("mode", &mode_string));
OP_REQUIRES(ctx,
@ -67,6 +62,12 @@ class DequantizeOp : public OpKernel {
const Tensor& input = ctx->input(0);
const float min_range = ctx->input(1).flat<float>()(0);
const float max_range = ctx->input(2).flat<float>()(0);
const float half_range =
!std::is_signed<T>::value
? 0.0f
: (static_cast<float>(std::numeric_limits<T>::max()) -
std::numeric_limits<T>::min() + 1) /
2.0f;
Tensor* output = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
@ -76,15 +77,11 @@ class DequantizeOp : public OpKernel {
(static_cast<float>(std::numeric_limits<T>::max()) -
std::numeric_limits<T>::min());
float* out_ptr = output->flat<float>().data();
const T* in_ptr = input.flat<T>().data();
const auto& input_tensor = input.flat<T>();
output->flat<float>().device(ctx->eigen_device<Device>()) =
((input_tensor.template cast<float>() + half_range) * scale_factor) +
min_range;
const int64 num_elements = input.NumElements();
for (int i = 0; i < num_elements; ++i) {
out_ptr[i] =
((static_cast<int>(in_ptr[i]) + half_range_) * scale_factor) +
min_range;
}
} else if (mode_ == QUANTIZE_MODE_MIN_FIRST) {
if (meta::IsSupportedAndEnabled() && std::is_same<T, quint8>()) {
auto input_ui8_array = input.flat<quint8>();
@ -101,17 +98,14 @@ class DequantizeOp : public OpKernel {
? (max_range / std::numeric_limits<T>::max())
: std::max(min_range / std::numeric_limits<T>::min(),
max_range / std::numeric_limits<T>::max());
float* out_ptr = output->flat<float>().data();
const T* in_ptr = input.flat<T>().data();
const int64 num_elements = input.NumElements();
for (int64 i = 0; i < num_elements; ++i) {
out_ptr[i] = static_cast<int>(in_ptr[i]) * scale_factor;
}
const auto& input_tensor = input.flat<T>();
output->flat<float>().device(ctx->eigen_device<Device>()) =
input_tensor.template cast<int>().template cast<float>() *
scale_factor;
}
}
private:
float half_range_;
int mode_;
};