Define the fixed output range interface and fix the logistic quantization error
The logistic quantization error was introduced by accidentally fusing the requantize op into the logistic op. To fix the issue, this change defines an interface for the op so that its fixed output range can be queried by the pass. In follow-up CLs, this interface will be used to replace the FixedResultScale trait.

PiperOrigin-RevId: 314221651
Change-Id: Iece591195ca0146b93b5c1a1b9f65c0d205eed11
commit 25af574ff3 (parent 659b1fba99)
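For illustration only (not part of the diff below): a minimal C++ sketch of how a quantization pass could query the new interface instead of relying on the FixedResultScale trait. The helper function name is hypothetical; only FixedOutputRangeInterface and its GetFixedOutputRange(bool, int) method come from this change, and quantization_traits.h pulls in the generated interface declaration.

#include "llvm/Support/Casting.h"
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"

namespace mlir {

// Hypothetical helper: returns the op-declared fixed output range, or a null
// type if the op does not implement FixedOutputRangeInterface.
static quant::UniformQuantizedType GetFixedRangeOrNull(Operation* op,
                                                       bool is_signed,
                                                       int bit_width) {
  if (auto fixed = llvm::dyn_cast<FixedOutputRangeInterface>(op)) {
    // The pass can use this type directly instead of fusing a trailing
    // requantize op into the defining op.
    return fixed.GetFixedOutputRange(is_signed, bit_width);
  }
  return {};
}

}  // namespace mlir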
@@ -216,13 +216,13 @@ cc_library(
         "ir/tfl_ops.h",
         "transforms/passes.h",
         "utils/attribute_utils.h",
-        "//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h",
         "@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h",
     ],
     deps = [
         ":tensorflow_lite_ops_inc_gen",
         ":validators",
         "//tensorflow/compiler/mlir/lite/experimental/estimators:cost_estimators",
+        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
         "//tensorflow/lite/schema:schema_fbs",
         "@llvm-project//llvm:support",
@@ -1726,6 +1726,7 @@ def TFL_LogisticOp: TFL_Op<"logistic", [
     // scale = 1. / (max_value + 1)
     FixedResultScale<Int8UniformQuantizedType<-128, 390625, -8>>,
     FixedResultScale<UInt8UniformQuantizedType<0, 390625, -8>>,
+    FixedOutputRangeInterface,
     TFL_GpuTargetOp]> {
   let summary = "Logistic operator";
 
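A quick note on the fixed parameters above (not part of the diff): the logistic output lies in [0, 1), and scale = 1. / (max_value + 1) = 1/256 = 0.00390625, which the Int8/UInt8UniformQuantizedType arguments appear to encode as 390625 x 10^-8; the zero point is -128 for signed int8 and 0 for unsigned uint8. The GetFixedOutputRange implementation added in the next hunk hard-codes the same 1/256 scale and zero points.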
@@ -1736,6 +1737,36 @@ def TFL_LogisticOp: TFL_Op<"logistic", [
   let arguments = (ins TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$x);
 
   let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$y);
+
+  let extraClassDeclaration = [{
+    // FixedOutputRangeInterface:
+    quant::UniformQuantizedType GetFixedOutputRange(
+        bool is_signed, int bit_width) {
+      auto result_type = y().getType().cast<ShapedType>();
+      if (!result_type.getElementType().isa<FloatType>()) return {};
+      Builder builder(result_type.getContext());
+
+      // Only support 8-bits
+      if (bit_width != 8) return {};
+      IntegerType storage_type = builder.getIntegerType(bit_width);
+
+      double scale = 1.0 / 256;
+      int64_t zero_point, storage_min, storage_max;
+      if (is_signed) {
+        zero_point = -128;
+        storage_min = -128;
+        storage_max = 127;
+      } else {
+        zero_point = 0;
+        storage_min = 0;
+        storage_max = 255;
+      }
+
+      return quant::UniformQuantizedType::getChecked(
+          is_signed, storage_type, result_type.getElementType(), scale,
+          zero_point, storage_min, storage_max, builder.getUnknownLoc());
+    }
+  }];
 }
 
 def TFL_LogOp: TFL_Op<"log", [
@@ -3,6 +3,10 @@ load(
     "//tensorflow/core/platform:build_config.bzl",
     "tf_proto_library",
 )
+load(
+    "//third_party/mlir:tblgen.bzl",
+    "gentbl",
+)
 
 package(
     default_visibility = [
@@ -35,6 +39,25 @@ filegroup(
     ],
 )
 
+gentbl(
+    name = "quantization_interfaces_inc_gen",
+    tbl_outs = [
+        (
+            "-gen-op-interface-decls",
+            "quantization_interface.h.inc",
+        ),
+        (
+            "-gen-op-interface-defs",
+            "quantization_interface.cc.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "quantization.td",
+    td_srcs = [
+        ":quantization_td_files",
+    ],
+)
+
 tf_proto_library(
     name = "quantization_info_proto",
     srcs = [
@@ -72,9 +95,11 @@ cc_library(
     name = "quantization_lib",
     srcs = [
         "quantization_driver.cc",
+        "quantization_interface.cc.inc",
         "quantization_utils.cc",
     ],
     hdrs = [
+        "quantization_interface.h.inc",
         "quantization_traits.h",
         "quantization_utils.h",
     ],
@@ -53,6 +53,7 @@ cc_library(
     ],
     deps = [
         "//tensorflow/compiler/mlir/lite:tensorflow_lite",
+        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:QuantOps",
@@ -63,6 +63,22 @@ def QI32 : QuantizedType<"Uniform", [32], 1>;
 // https://www.tensorflow.org/lite/performance/quantization_spec
 //===----------------------------------------------------------------------===//
 
+// TODO(b/157870442): replace all FixedResultScale trait
+def FixedOutputRangeInterface : OpInterface<
+  "FixedOutputRangeInterface"> {
+  let description = [{
+    Interface for defining the fixed output range.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      [{Returns the fixed output range.}],
+      "UniformQuantizedType", "GetFixedOutputRange",
+      (ins "bool":$sign, "int":$bit_width)
+    >,
+  ];
+}
+
 // Specify this trait if the op has a fixed output value range.
 class FixedResultScale<QuantizedType qt> : NativeOpTrait<!strconcat(
   "quant::FixedResult", qt.name, "Scale<", qt.asTraitArgsStr, ">::Impl")>;
@@ -21,13 +21,18 @@ limitations under the License.
 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 
-namespace mlir {
-namespace OpTrait {
-namespace quant {
-
 using QuantizedType = mlir::quant::QuantizedType;
 using UniformQuantizedType = mlir::quant::UniformQuantizedType;
 
+namespace mlir {
+
+// This includes the interface class definition. It couldn't be in a namespace
+// because the table gen doesn't emit the namespace when it is used.
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_interface.h.inc"
+
+namespace OpTrait {
+namespace quant {
+
 // The base class that all the quantization related OpTrait implements.
 template <typename ConcreteType, template <typename> class TraitType>
 struct QuantizationSpecTraitBase : public TraitBase<ConcreteType, TraitType> {
@@ -23,6 +23,7 @@ limitations under the License.
 
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Twine.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
@@ -385,7 +386,8 @@ struct FoldTrivalRequantizeOp : public OpRewritePattern<RQ> {
 
     Operation* def = pre_quantized.getDefiningOp();
     if (!def) return failure();
-    if (def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() ||
+    if (llvm::isa<FixedOutputRangeInterface>(def) ||
+        def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() ||
         def->hasTrait<OpTrait::quant::NoQuantizableResult>()) {
       return failure();
     }