From 57db297c71026f6c388484fbcfc391e426062831 Mon Sep 17 00:00:00 2001
From: David Rim
Date: Wed, 3 Feb 2021 22:28:09 -0800
Subject: [PATCH] Quantize abs uses same input and output scale for
 compatibility with mlir.

PiperOrigin-RevId: 355556020
Change-Id: I29a2cbc5f3cfb45721907354287b086703a0314b
---
 tensorflow/compiler/mlir/lite/ir/tfl_ops.td    |  1 +
 tensorflow/lite/kernels/elementwise.cc         |  9 ++++--
 tensorflow/lite/kernels/elementwise_test.cc    | 28 +++++++++++++++++++
 .../lite/tools/optimize/operator_property.cc   |  5 ++++
 4 files changed, 41 insertions(+), 2 deletions(-)

diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index e39f45619c0..6d254d78f7a 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -531,6 +531,7 @@ def TFL_AbsOp : TFL_Op<"abs", [
     NoSideEffect,
     SameOperandsAndResultShape,
     SameOperandsAndResultType,
+    SameOperandsAndResultsScale,
     NoQuantizableResult]> {
   let summary = "Absolute value operator";
 
diff --git a/tensorflow/lite/kernels/elementwise.cc b/tensorflow/lite/kernels/elementwise.cc
index c6ed1f4d7fb..f59f95263aa 100644
--- a/tensorflow/lite/kernels/elementwise.cc
+++ b/tensorflow/lite/kernels/elementwise.cc
@@ -47,6 +47,7 @@ struct OpData {
   int32_t shift;
   int input_offset;
   int output_offset;
+  bool needs_rescale;
 };
 
 bool IsNumericSupportedType(const TfLiteType type) {
@@ -118,7 +119,8 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
     }
     const float input_scale = input_params->scale->data[0];
     const float output_scale = output_params->scale->data[0];
-    if (op_name == kAbsName) {
+    op_data->needs_rescale = input_scale != output_scale;
+    if (op_name == kAbsName && op_data->needs_rescale) {
       SetAbsOutputMultiplier(input_scale, output_scale, &op_data->multiplier,
                              &op_data->shift);
     } else if (op_name == kRsqrtName) {
@@ -188,10 +190,13 @@ TfLiteStatus AbsEvalQuantized(TfLiteContext* context, TfLiteNode* node,
 
   std::function<T(T)> func = [&](T i) {
     const int32_t value = std::abs(i - op_data->input_offset);
+    if (!op_data->needs_rescale) {
+      return static_cast<T>(
+          std::min(std::max(value + op_data->output_offset, kMin), kMax));
+    }
     const int32_t output = MultiplyByQuantizedMultiplier(
                                value, op_data->multiplier, op_data->shift) +
                            op_data->output_offset;
-
     return static_cast<T>(std::min(std::max(output, kMin), kMax));
   };
 
diff --git a/tensorflow/lite/kernels/elementwise_test.cc b/tensorflow/lite/kernels/elementwise_test.cc
index a1a41df5a31..974bbaf08f5 100644
--- a/tensorflow/lite/kernels/elementwise_test.cc
+++ b/tensorflow/lite/kernels/elementwise_test.cc
@@ -191,6 +191,34 @@ TEST(ElementWise, AbsInt8) {
               ElementsAreArray(ArrayFloatNear(abs_data, kInputScale)));
 }
 
+TEST(ElementWise, AbsSameScaleInt8) {
+  std::vector<float> data = {15., 46., 78., -142., -1., -17., -49., 113.};
+  std::vector<float> abs_data(data.size());
+  for (int i = 0; i < abs_data.size(); i++) {
+    abs_data[i] = std::abs(data[i]);
+  }
+  const auto minmax = std::minmax_element(data.begin(), data.end());
+  const float abs_max = std::max(std::abs(*minmax.first), *minmax.second);
+  const float kInputScale = (*minmax.second - *minmax.first) / 255.0;
+  const int input_zero_point = 127 - *minmax.second;
+  ElementWiseOpQuantizedModel m(
+      BuiltinOperator_ABS,
+      {TensorType_INT8,
+       {1, 8},
+       *minmax.first,
+       *minmax.second,
+       kInputScale,
+       input_zero_point,
+       true,
+       {kInputScale},
+       {input_zero_point}},
+      {TensorType_INT8, {1, 8}, 0, abs_max, kInputScale, input_zero_point});
+  m.AsymmetricQuantizeAndPopulate(m.input(), data);
+  m.Invoke();
+  EXPECT_THAT(m.ExtractDequantVector<float>(m.output()),
+              ElementsAreArray(ArrayFloatNear(abs_data, kInputScale)));
+}
+
 TEST(ElementWise, AbsInt16) {
   const float kQuantizedTolerance = GetQuantizationStep<int16_t>(-150, 150);
   std::vector<float> data = {15., 46., 78., -142., -1., -17., -49., 113.};
diff --git a/tensorflow/lite/tools/optimize/operator_property.cc b/tensorflow/lite/tools/optimize/operator_property.cc
index 970bd9769c9..eff4514ca3f 100644
--- a/tensorflow/lite/tools/optimize/operator_property.cc
+++ b/tensorflow/lite/tools/optimize/operator_property.cc
@@ -66,6 +66,11 @@ OperatorProperty GetOperatorProperty(OpVariant op_variant) {
   OperatorProperty property;
   switch (op_code) {
     case BuiltinOperator_ABS:
+      property.inputs = {{0, {}}};
+      property.outputs = {{0, {}}};
+      property.version = 2;
+      property.restrict_same_input_output_scale = true;
+      break;
    case BuiltinOperator_RSQRT:
       property.inputs = {{0, {}}};
       property.outputs = {{0, {}}};
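
For reference, the same-scale fast path this patch adds to AbsEvalQuantized can be illustrated with the standalone sketch below. It is not TFLite code: QuantizedAbsSameScale and the values in main() are hypothetical, and the arithmetic simply mirrors the branch taken when op_data->needs_rescale is false. Because input and output share one scale, the absolute value can be computed directly on the quantized integers, so no requantization multiplier or shift is needed.

// Standalone sketch (hypothetical names, not TFLite code) of the int8 abs
// fast path used when input and output tensors share the same scale.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <limits>

int8_t QuantizedAbsSameScale(int8_t input, int32_t input_offset,
                             int32_t output_offset) {
  constexpr int32_t kMin = std::numeric_limits<int8_t>::min();  // -128
  constexpr int32_t kMax = std::numeric_limits<int8_t>::max();  // 127
  // |real value| expressed in units of the shared scale.
  const int32_t value = std::abs(static_cast<int32_t>(input) - input_offset);
  // Re-apply the output zero point and clamp to the int8 range; no
  // requantization (multiplier/shift) step is required.
  return static_cast<int8_t>(
      std::min(std::max(value + output_offset, kMin), kMax));
}

int main() {
  // Example: scale 1.0 and zero point 15 for both tensors. Quantized -5
  // represents real value -20, and |-20| quantizes back to 35.
  const int32_t zero_point = 15;
  std::printf("%d\n", QuantizedAbsSameScale(-5, zero_point, zero_point));
  return 0;
}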