Use Eigen tensor expressions in the Dequantize op, in place of hand-written per-element loops, for performance.

PiperOrigin-RevId: 253618908
This commit is contained in:
A. Unique TensorFlower 2019-06-17 11:12:49 -07:00 committed by TensorFlower Gardener
parent 84cf4d436e
commit 1efa1e39c1

View File

@ -41,11 +41,6 @@ template <typename Device, typename T>
class DequantizeOp : public OpKernel { class DequantizeOp : public OpKernel {
public: public:
explicit DequantizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) { explicit DequantizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
half_range_ = !std::is_signed<T>::value
? 0.0f
: (static_cast<float>(std::numeric_limits<T>::max()) -
std::numeric_limits<T>::min() + 1) /
2.0f;
string mode_string; string mode_string;
OP_REQUIRES_OK(ctx, ctx->GetAttr("mode", &mode_string)); OP_REQUIRES_OK(ctx, ctx->GetAttr("mode", &mode_string));
OP_REQUIRES(ctx, OP_REQUIRES(ctx,
@ -67,6 +62,12 @@ class DequantizeOp : public OpKernel {
const Tensor& input = ctx->input(0); const Tensor& input = ctx->input(0);
const float min_range = ctx->input(1).flat<float>()(0); const float min_range = ctx->input(1).flat<float>()(0);
const float max_range = ctx->input(2).flat<float>()(0); const float max_range = ctx->input(2).flat<float>()(0);
const float half_range =
!std::is_signed<T>::value
? 0.0f
: (static_cast<float>(std::numeric_limits<T>::max()) -
std::numeric_limits<T>::min() + 1) /
2.0f;
Tensor* output = nullptr; Tensor* output = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output)); OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
@ -76,15 +77,11 @@ class DequantizeOp : public OpKernel {
(static_cast<float>(std::numeric_limits<T>::max()) - (static_cast<float>(std::numeric_limits<T>::max()) -
std::numeric_limits<T>::min()); std::numeric_limits<T>::min());
float* out_ptr = output->flat<float>().data(); const auto& input_tensor = input.flat<T>();
const T* in_ptr = input.flat<T>().data(); output->flat<float>().device(ctx->eigen_device<Device>()) =
((input_tensor.template cast<float>() + half_range) * scale_factor) +
min_range;
const int64 num_elements = input.NumElements();
for (int i = 0; i < num_elements; ++i) {
out_ptr[i] =
((static_cast<int>(in_ptr[i]) + half_range_) * scale_factor) +
min_range;
}
} else if (mode_ == QUANTIZE_MODE_MIN_FIRST) { } else if (mode_ == QUANTIZE_MODE_MIN_FIRST) {
if (meta::IsSupportedAndEnabled() && std::is_same<T, quint8>()) { if (meta::IsSupportedAndEnabled() && std::is_same<T, quint8>()) {
auto input_ui8_array = input.flat<quint8>(); auto input_ui8_array = input.flat<quint8>();
@ -101,17 +98,14 @@ class DequantizeOp : public OpKernel {
? (max_range / std::numeric_limits<T>::max()) ? (max_range / std::numeric_limits<T>::max())
: std::max(min_range / std::numeric_limits<T>::min(), : std::max(min_range / std::numeric_limits<T>::min(),
max_range / std::numeric_limits<T>::max()); max_range / std::numeric_limits<T>::max());
float* out_ptr = output->flat<float>().data(); const auto& input_tensor = input.flat<T>();
const T* in_ptr = input.flat<T>().data(); output->flat<float>().device(ctx->eigen_device<Device>()) =
const int64 num_elements = input.NumElements(); input_tensor.template cast<int>().template cast<float>() *
for (int64 i = 0; i < num_elements; ++i) { scale_factor;
out_ptr[i] = static_cast<int>(in_ptr[i]) * scale_factor;
}
} }
} }
private: private:
float half_range_;
int mode_; int mode_;
}; };