[TF2XLA] Make dequantize support dynamic range.
PiperOrigin-RevId: 329444190 Change-Id: Icac703969a95093dd7982820c1706887a50d1bce
This commit is contained in:
parent
b8c0f25a4d
commit
22f5d50f9a
@ -63,36 +63,27 @@ class DequantizeOp : public XlaOpKernel {
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
DataType input_type = ctx->input_type(0);
|
||||
|
||||
double minrange, maxrange;
|
||||
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsFloatScalar(1, &minrange));
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsFloatScalar(2, &maxrange));
|
||||
|
||||
float min_range = static_cast<float>(minrange);
|
||||
float max_range = static_cast<float>(maxrange);
|
||||
float full_range, half_range;
|
||||
xla::XlaOp input = ctx->Input(0);
|
||||
xla::XlaOp output = xla::ConvertElementType(input, xla::F32);
|
||||
xla::XlaOp min_range = xla::ConvertElementType(ctx->Input(1), xla::F32);
|
||||
xla::XlaOp max_range = xla::ConvertElementType(ctx->Input(2), xla::F32);
|
||||
xla::XlaOp full_range;
|
||||
xla::XlaOp half_range;
|
||||
if (input_type == DT_QINT8) {
|
||||
full_range = get_fullrange<qint8>();
|
||||
half_range = (full_range + 1.0f) / 2.0f;
|
||||
full_range = ScalarLike(output, get_fullrange<qint8>());
|
||||
half_range =
|
||||
(full_range + ScalarLike(output, 1.0f)) / ScalarLike(output, 2.0f);
|
||||
} else {
|
||||
OP_REQUIRES(ctx, input_type == DT_QUINT8,
|
||||
errors::InvalidArgument(
|
||||
"Only support DT_QINT8 or DT_QUINT8, got ", input_type));
|
||||
full_range = get_fullrange<quint8>();
|
||||
half_range = 0.0f;
|
||||
full_range = ScalarLike(output, get_fullrange<quint8>());
|
||||
half_range = ScalarLike(output, 0.0f);
|
||||
}
|
||||
|
||||
float scale_factor = (max_range - min_range) / full_range;
|
||||
xla::XlaOp scale = (max_range - min_range) / full_range;
|
||||
|
||||
xla::XlaOp input = ctx->Input(0);
|
||||
xla::XlaOp output;
|
||||
|
||||
output = xla::ConvertElementType(input, xla::F32);
|
||||
|
||||
auto scale = ScalarLike(output, scale_factor);
|
||||
auto halfrange = ScalarLike(output, half_range);
|
||||
output = xla::Add(xla::Mul(xla::Add(output, halfrange), scale),
|
||||
ScalarLike(output, min_range));
|
||||
output = xla::Add(xla::Mul(xla::Add(output, half_range), scale), min_range);
|
||||
|
||||
if (dtype_ == DT_BFLOAT16) {
|
||||
output = xla::ConvertElementType(output, xla::BF16);
|
||||
|
Loading…
Reference in New Issue
Block a user