[TF2XLA] Make dequantize support dynamic range.

PiperOrigin-RevId: 329444190
Change-Id: Icac703969a95093dd7982820c1706887a50d1bce
This commit is contained in:
Blake Hechtman 2020-08-31 22:22:02 -07:00 committed by TensorFlower Gardener
parent b8c0f25a4d
commit 22f5d50f9a

View File

@ -63,36 +63,27 @@ class DequantizeOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
DataType input_type = ctx->input_type(0);
double minrange, maxrange;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsFloatScalar(1, &minrange));
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsFloatScalar(2, &maxrange));
float min_range = static_cast<float>(minrange);
float max_range = static_cast<float>(maxrange);
float full_range, half_range;
xla::XlaOp input = ctx->Input(0);
xla::XlaOp output = xla::ConvertElementType(input, xla::F32);
xla::XlaOp min_range = xla::ConvertElementType(ctx->Input(1), xla::F32);
xla::XlaOp max_range = xla::ConvertElementType(ctx->Input(2), xla::F32);
xla::XlaOp full_range;
xla::XlaOp half_range;
if (input_type == DT_QINT8) {
full_range = get_fullrange<qint8>();
half_range = (full_range + 1.0f) / 2.0f;
full_range = ScalarLike(output, get_fullrange<qint8>());
half_range =
(full_range + ScalarLike(output, 1.0f)) / ScalarLike(output, 2.0f);
} else {
OP_REQUIRES(ctx, input_type == DT_QUINT8,
errors::InvalidArgument(
"Only support DT_QINT8 or DT_QUINT8, got ", input_type));
full_range = get_fullrange<quint8>();
half_range = 0.0f;
full_range = ScalarLike(output, get_fullrange<quint8>());
half_range = ScalarLike(output, 0.0f);
}
float scale_factor = (max_range - min_range) / full_range;
xla::XlaOp scale = (max_range - min_range) / full_range;
xla::XlaOp input = ctx->Input(0);
xla::XlaOp output;
output = xla::ConvertElementType(input, xla::F32);
auto scale = ScalarLike(output, scale_factor);
auto halfrange = ScalarLike(output, half_range);
output = xla::Add(xla::Mul(xla::Add(output, halfrange), scale),
ScalarLike(output, min_range));
output = xla::Add(xla::Mul(xla::Add(output, half_range), scale), min_range);
if (dtype_ == DT_BFLOAT16) {
output = xla::ConvertElementType(output, xla::BF16);