Define the fixed output range interface and fix the logistic quantization error

The logistic quantization error was introduced by accidentally fusing the
requantize op into the logistic op. To fix the issue, an interface needs to be
defined for this op, so that its fixed output range property can be queried by
the pass.

In follow-up CLs, this interface will be used to replace the FixedResultScale trait.

PiperOrigin-RevId: 314221651
Change-Id: Iece591195ca0146b93b5c1a1b9f65c0d205eed11
This commit is contained in:
Feng Liu 2020-06-01 16:03:16 -07:00 committed by TensorFlower Gardener
parent 659b1fba99
commit 25af574ff3
7 changed files with 86 additions and 6 deletions

View File

@ -216,13 +216,13 @@ cc_library(
"ir/tfl_ops.h", "ir/tfl_ops.h",
"transforms/passes.h", "transforms/passes.h",
"utils/attribute_utils.h", "utils/attribute_utils.h",
"//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h",
"@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h", "@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h",
], ],
deps = [ deps = [
":tensorflow_lite_ops_inc_gen", ":tensorflow_lite_ops_inc_gen",
":validators", ":validators",
"//tensorflow/compiler/mlir/lite/experimental/estimators:cost_estimators", "//tensorflow/compiler/mlir/lite/experimental/estimators:cost_estimators",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/schema:schema_fbs",
"@llvm-project//llvm:support", "@llvm-project//llvm:support",

View File

@ -1726,6 +1726,7 @@ def TFL_LogisticOp: TFL_Op<"logistic", [
// scale = 1. / (max_value + 1) // scale = 1. / (max_value + 1)
FixedResultScale<Int8UniformQuantizedType<-128, 390625, -8>>, FixedResultScale<Int8UniformQuantizedType<-128, 390625, -8>>,
FixedResultScale<UInt8UniformQuantizedType<0, 390625, -8>>, FixedResultScale<UInt8UniformQuantizedType<0, 390625, -8>>,
FixedOutputRangeInterface,
TFL_GpuTargetOp]> { TFL_GpuTargetOp]> {
let summary = "Logistic operator"; let summary = "Logistic operator";
@ -1736,6 +1737,36 @@ def TFL_LogisticOp: TFL_Op<"logistic", [
let arguments = (ins TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$x); let arguments = (ins TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$x);
let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$y); let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$y);
let extraClassDeclaration = [{
// FixedOutputRangeInterface:
// Builds the fixed quantized type for the result `y` of this op.
//
// `is_signed` selects the signed (-128..127) or unsigned (0..255) storage
// range and `bit_width` is the requested storage width. An empty
// (default-constructed) type is returned when the result element type is not
// a float type or when `bit_width` is not 8.
quant::UniformQuantizedType GetFixedOutputRange(
    bool is_signed, int bit_width) {
  auto element_type = y().getType().cast<ShapedType>().getElementType();
  // Only float results with 8-bit storage are supported.
  if (!element_type.isa<FloatType>() || bit_width != 8) return {};

  Builder type_builder(element_type.getContext());

  // The zero point coincides with the lower storage bound in both the signed
  // and the unsigned case; the scale is always 1/256.
  const int64_t lower_bound = is_signed ? -128 : 0;
  const int64_t upper_bound = is_signed ? 127 : 255;
  const int64_t zero_point = lower_bound;
  const double scale = 1.0 / 256;

  return quant::UniformQuantizedType::getChecked(
      is_signed, type_builder.getIntegerType(bit_width), element_type, scale,
      zero_point, lower_bound, upper_bound, type_builder.getUnknownLoc());
}
}];
} }
def TFL_LogOp: TFL_Op<"log", [ def TFL_LogOp: TFL_Op<"log", [

View File

@ -3,6 +3,10 @@ load(
"//tensorflow/core/platform:build_config.bzl", "//tensorflow/core/platform:build_config.bzl",
"tf_proto_library", "tf_proto_library",
) )
load(
"//third_party/mlir:tblgen.bzl",
"gentbl",
)
package( package(
default_visibility = [ default_visibility = [
@ -35,6 +39,25 @@ filegroup(
], ],
) )
# Runs mlir-tblgen over quantization.td to generate the op-interface
# declarations (quantization_interface.h.inc) and definitions
# (quantization_interface.cc.inc).
gentbl(
name = "quantization_interfaces_inc_gen",
tbl_outs = [
(
"-gen-op-interface-decls",
"quantization_interface.h.inc",
),
(
"-gen-op-interface-defs",
"quantization_interface.cc.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "quantization.td",
td_srcs = [
":quantization_td_files",
],
)
tf_proto_library( tf_proto_library(
name = "quantization_info_proto", name = "quantization_info_proto",
srcs = [ srcs = [
@ -72,9 +95,11 @@ cc_library(
name = "quantization_lib", name = "quantization_lib",
srcs = [ srcs = [
"quantization_driver.cc", "quantization_driver.cc",
"quantization_interface.cc.inc",
"quantization_utils.cc", "quantization_utils.cc",
], ],
hdrs = [ hdrs = [
"quantization_interface.h.inc",
"quantization_traits.h", "quantization_traits.h",
"quantization_utils.h", "quantization_utils.h",
], ],

View File

@ -53,6 +53,7 @@ cc_library(
], ],
deps = [ deps = [
"//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"@llvm-project//llvm:support", "@llvm-project//llvm:support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps", "@llvm-project//mlir:QuantOps",

View File

@ -63,6 +63,22 @@ def QI32 : QuantizedType<"Uniform", [32], 1>;
// https://www.tensorflow.org/lite/performance/quantization_spec // https://www.tensorflow.org/lite/performance/quantization_spec
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TODO(b/157870442): replace all FixedResultScale traits with this interface.
//
// Op interface exposing a single method, `GetFixedOutputRange`, which takes
// the signedness (`sign`) and storage bit width (`bit_width`) and returns a
// `UniformQuantizedType` describing the op's fixed output range.
def FixedOutputRangeInterface : OpInterface<
"FixedOutputRangeInterface"> {
let description = [{
Interface for defining the fixed output range.
}];
let methods = [
InterfaceMethod<
[{Returns the fixed output range.}],
"UniformQuantizedType", "GetFixedOutputRange",
(ins "bool":$sign, "int":$bit_width)
>,
];
}
// Specify this trait if the op has a fixed output value range. // Specify this trait if the op has a fixed output value range.
class FixedResultScale<QuantizedType qt> : NativeOpTrait<!strconcat( class FixedResultScale<QuantizedType qt> : NativeOpTrait<!strconcat(
"quant::FixedResult", qt.name, "Scale<", qt.asTraitArgsStr, ">::Impl")>; "quant::FixedResult", qt.name, "Scale<", qt.asTraitArgsStr, ">::Impl")>;

View File

@ -21,13 +21,18 @@ limitations under the License.
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project
namespace mlir {
namespace OpTrait {
namespace quant {
using QuantizedType = mlir::quant::QuantizedType; using QuantizedType = mlir::quant::QuantizedType;
using UniformQuantizedType = mlir::quant::UniformQuantizedType; using UniformQuantizedType = mlir::quant::UniformQuantizedType;
namespace mlir {
// This includes the interface class definition. It is included directly in the
// `mlir` namespace rather than in a nested one, because the code generated by
// TableGen does not emit the namespace where the interface is used.
#include "tensorflow/compiler/mlir/lite/quantization/quantization_interface.h.inc"
namespace OpTrait {
namespace quant {
// The base class that all the quantization related OpTrait implements. // The base class that all the quantization related OpTrait implements.
template <typename ConcreteType, template <typename> class TraitType> template <typename ConcreteType, template <typename> class TraitType>
struct QuantizationSpecTraitBase : public TraitBase<ConcreteType, TraitType> { struct QuantizationSpecTraitBase : public TraitBase<ConcreteType, TraitType> {

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h" #include "llvm/ADT/Twine.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project #include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
@ -385,7 +386,8 @@ struct FoldTrivalRequantizeOp : public OpRewritePattern<RQ> {
Operation* def = pre_quantized.getDefiningOp(); Operation* def = pre_quantized.getDefiningOp();
if (!def) return failure(); if (!def) return failure();
if (def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() || if (llvm::isa<FixedOutputRangeInterface>(def) ||
def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() ||
def->hasTrait<OpTrait::quant::NoQuantizableResult>()) { def->hasTrait<OpTrait::quant::NoQuantizableResult>()) {
return failure(); return failure();
} }