Used Eigen in Dequantize op for performance.
PiperOrigin-RevId: 253618908
This commit is contained in:
parent
84cf4d436e
commit
1efa1e39c1
@ -41,11 +41,6 @@ template <typename Device, typename T>
|
|||||||
class DequantizeOp : public OpKernel {
|
class DequantizeOp : public OpKernel {
|
||||||
public:
|
public:
|
||||||
explicit DequantizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
explicit DequantizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||||
half_range_ = !std::is_signed<T>::value
|
|
||||||
? 0.0f
|
|
||||||
: (static_cast<float>(std::numeric_limits<T>::max()) -
|
|
||||||
std::numeric_limits<T>::min() + 1) /
|
|
||||||
2.0f;
|
|
||||||
string mode_string;
|
string mode_string;
|
||||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("mode", &mode_string));
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("mode", &mode_string));
|
||||||
OP_REQUIRES(ctx,
|
OP_REQUIRES(ctx,
|
||||||
@ -67,6 +62,12 @@ class DequantizeOp : public OpKernel {
|
|||||||
const Tensor& input = ctx->input(0);
|
const Tensor& input = ctx->input(0);
|
||||||
const float min_range = ctx->input(1).flat<float>()(0);
|
const float min_range = ctx->input(1).flat<float>()(0);
|
||||||
const float max_range = ctx->input(2).flat<float>()(0);
|
const float max_range = ctx->input(2).flat<float>()(0);
|
||||||
|
const float half_range =
|
||||||
|
!std::is_signed<T>::value
|
||||||
|
? 0.0f
|
||||||
|
: (static_cast<float>(std::numeric_limits<T>::max()) -
|
||||||
|
std::numeric_limits<T>::min() + 1) /
|
||||||
|
2.0f;
|
||||||
|
|
||||||
Tensor* output = nullptr;
|
Tensor* output = nullptr;
|
||||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
|
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
|
||||||
@ -76,15 +77,11 @@ class DequantizeOp : public OpKernel {
|
|||||||
(static_cast<float>(std::numeric_limits<T>::max()) -
|
(static_cast<float>(std::numeric_limits<T>::max()) -
|
||||||
std::numeric_limits<T>::min());
|
std::numeric_limits<T>::min());
|
||||||
|
|
||||||
float* out_ptr = output->flat<float>().data();
|
const auto& input_tensor = input.flat<T>();
|
||||||
const T* in_ptr = input.flat<T>().data();
|
output->flat<float>().device(ctx->eigen_device<Device>()) =
|
||||||
|
((input_tensor.template cast<float>() + half_range) * scale_factor) +
|
||||||
|
min_range;
|
||||||
|
|
||||||
const int64 num_elements = input.NumElements();
|
|
||||||
for (int i = 0; i < num_elements; ++i) {
|
|
||||||
out_ptr[i] =
|
|
||||||
((static_cast<int>(in_ptr[i]) + half_range_) * scale_factor) +
|
|
||||||
min_range;
|
|
||||||
}
|
|
||||||
} else if (mode_ == QUANTIZE_MODE_MIN_FIRST) {
|
} else if (mode_ == QUANTIZE_MODE_MIN_FIRST) {
|
||||||
if (meta::IsSupportedAndEnabled() && std::is_same<T, quint8>()) {
|
if (meta::IsSupportedAndEnabled() && std::is_same<T, quint8>()) {
|
||||||
auto input_ui8_array = input.flat<quint8>();
|
auto input_ui8_array = input.flat<quint8>();
|
||||||
@ -101,17 +98,14 @@ class DequantizeOp : public OpKernel {
|
|||||||
? (max_range / std::numeric_limits<T>::max())
|
? (max_range / std::numeric_limits<T>::max())
|
||||||
: std::max(min_range / std::numeric_limits<T>::min(),
|
: std::max(min_range / std::numeric_limits<T>::min(),
|
||||||
max_range / std::numeric_limits<T>::max());
|
max_range / std::numeric_limits<T>::max());
|
||||||
float* out_ptr = output->flat<float>().data();
|
const auto& input_tensor = input.flat<T>();
|
||||||
const T* in_ptr = input.flat<T>().data();
|
output->flat<float>().device(ctx->eigen_device<Device>()) =
|
||||||
const int64 num_elements = input.NumElements();
|
input_tensor.template cast<int>().template cast<float>() *
|
||||||
for (int64 i = 0; i < num_elements; ++i) {
|
scale_factor;
|
||||||
out_ptr[i] = static_cast<int>(in_ptr[i]) * scale_factor;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
float half_range_;
|
|
||||||
int mode_;
|
int mode_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user