From 25af574ff39d3562eb613c9f05effffee6e59493 Mon Sep 17 00:00:00 2001
From: Feng Liu
Date: Mon, 1 Jun 2020 16:03:16 -0700
Subject: [PATCH] Define the fixed output range interface and fix the logistic
 quantization error

The logistic quantization error was introduced by accidentally fusing the
requantize op into the logistic op. To fix the issue, an interface needs to
be defined for this op, so that the fixed output range property can be
queried by the pass. In follow-up CLs, this interface will be used to
replace the FixedResultScale trait.

PiperOrigin-RevId: 314221651
Change-Id: Iece591195ca0146b93b5c1a1b9f65c0d205eed11
---
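Note (reviewer aid, not part of the change): once the tblgen'd interface below
is linked in, a quantization pass can query the fixed output range instead of
matching on the FixedResultScale trait. A minimal sketch, assuming `op` is an
Operation* visited by a hypothetical pass; only FixedOutputRangeInterface and
GetFixedOutputRange come from this change:

    // Ops with a fixed output range implement FixedOutputRangeInterface.
    if (auto fixed_range = llvm::dyn_cast<FixedOutputRangeInterface>(op)) {
      // For TFL_LogisticOp this yields scale = 1/256 and, in the signed
      // 8-bit case, zero_point = -128.
      quant::UniformQuantizedType type = fixed_range.GetFixedOutputRange(
          /*is_signed=*/true, /*bit_width=*/8);
      if (type) {
        // Use `type` as the quantized result type; in particular, do not
        // fold a trailing requantize op into this op.
      }
    }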
 tensorflow/compiler/mlir/lite/BUILD           |  2 +-
 tensorflow/compiler/mlir/lite/ir/tfl_ops.td   | 31 +++++++++++++++++++
 .../compiler/mlir/lite/quantization/BUILD     | 25 +++++++++++++++
 .../mlir/lite/quantization/lite/BUILD         |  1 +
 .../mlir/lite/quantization/quantization.td    | 16 ++++++++++
 .../lite/quantization/quantization_traits.h   | 13 +++++---
 .../lite/quantization/quantization_utils.h    |  4 ++-
 7 files changed, 86 insertions(+), 6 deletions(-)

diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index 1a508bdb190..c60c0c0edbf 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -216,13 +216,13 @@ cc_library(
         "ir/tfl_ops.h",
         "transforms/passes.h",
         "utils/attribute_utils.h",
-        "//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h",
         "@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h",
     ],
     deps = [
         ":tensorflow_lite_ops_inc_gen",
         ":validators",
         "//tensorflow/compiler/mlir/lite/experimental/estimators:cost_estimators",
+        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
         "//tensorflow/lite/schema:schema_fbs",
         "@llvm-project//llvm:support",
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index e645f98e922..76c342bd10a 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -1726,6 +1726,7 @@ def TFL_LogisticOp: TFL_Op<"logistic", [
     // scale = 1. / (max_value + 1)
     FixedResultScale<Int8UniformQuantizedType<-128, 390625, -8>>,
     FixedResultScale<UInt8UniformQuantizedType<0, 390625, -8>>,
+    FixedOutputRangeInterface,
     TFL_GpuTargetOp]> {
   let summary = "Logistic operator";

@@ -1736,6 +1737,36 @@ def TFL_LogisticOp: TFL_Op<"logistic", [
   let arguments = (ins TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$x);

   let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$y);
+
+  let extraClassDeclaration = [{
+    // FixedOutputRangeInterface:
+    quant::UniformQuantizedType GetFixedOutputRange(
+        bool is_signed, int bit_width) {
+      auto result_type = y().getType().cast<ShapedType>();
+      if (!result_type.getElementType().isa<FloatType>()) return {};
+      Builder builder(result_type.getContext());
+
+      // Only 8-bit storage types are supported.
+      if (bit_width != 8) return {};
+      IntegerType storage_type = builder.getIntegerType(bit_width);
+
+      double scale = 1.0 / 256;
+      int64_t zero_point, storage_min, storage_max;
+      if (is_signed) {
+        zero_point = -128;
+        storage_min = -128;
+        storage_max = 127;
+      } else {
+        zero_point = 0;
+        storage_min = 0;
+        storage_max = 255;
+      }
+
+      return quant::UniformQuantizedType::getChecked(
+          is_signed, storage_type, result_type.getElementType(), scale,
+          zero_point, storage_min, storage_max, builder.getUnknownLoc());
+    }
+  }];
 }
diff --git a/tensorflow/compiler/mlir/lite/quantization/BUILD b/tensorflow/compiler/mlir/lite/quantization/BUILD
index 91590bfbc13..57417e95ec6 100644
--- a/tensorflow/compiler/mlir/lite/quantization/BUILD
+++ b/tensorflow/compiler/mlir/lite/quantization/BUILD
@@ -3,6 +3,10 @@ load(
     "//tensorflow/core/platform:build_config.bzl",
     "tf_proto_library",
 )
+load(
+    "//third_party/mlir:tblgen.bzl",
+    "gentbl",
+)

 package(
     default_visibility = [
@@ -35,6 +39,25 @@ filegroup(
     ],
 )

+gentbl(
+    name = "quantization_interfaces_inc_gen",
+    tbl_outs = [
+        (
+            "-gen-op-interface-decls",
+            "quantization_interface.h.inc",
+        ),
+        (
+            "-gen-op-interface-defs",
+            "quantization_interface.cc.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "quantization.td",
+    td_srcs = [
+        ":quantization_td_files",
+    ],
+)
+
 tf_proto_library(
     name = "quantization_info_proto",
     srcs = [
@@ -72,9 +95,11 @@ cc_library(
     name = "quantization_lib",
     srcs = [
         "quantization_driver.cc",
+        "quantization_interface.cc.inc",
         "quantization_utils.cc",
     ],
     hdrs = [
+        "quantization_interface.h.inc",
         "quantization_traits.h",
         "quantization_utils.h",
     ],
diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD
index d9e478950e6..2783297814b 100644
--- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD
@@ -53,6 +53,7 @@ cc_library(
     ],
     deps = [
         "//tensorflow/compiler/mlir/lite:tensorflow_lite",
+        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
         "@llvm-project//llvm:support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:QuantOps",
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization.td b/tensorflow/compiler/mlir/lite/quantization/quantization.td
index 7bfcdb65686..c1e392bd3ad 100644
--- a/tensorflow/compiler/mlir/lite/quantization/quantization.td
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization.td
@@ -63,6 +63,22 @@ def QI32 : QuantizedType<"Uniform", [32], 1>;
 // https://www.tensorflow.org/lite/performance/quantization_spec
 //===----------------------------------------------------------------------===//

+// TODO(b/157870442): replace all FixedResultScale traits
+def FixedOutputRangeInterface : OpInterface<
+    "FixedOutputRangeInterface"> {
+  let description = [{
+    Interface for defining the fixed output range.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      [{Returns the fixed output range.}],
+      "UniformQuantizedType", "GetFixedOutputRange",
+      (ins "bool":$sign, "int":$bit_width)
+    >,
+  ];
+}
+
 // Specify this trait if the op has a fixed output value range.
 class FixedResultScale<QuantizedType qt> : NativeOpTrait<!strconcat(
   "quant::FixedResult", qt.name, "Scale<", qt.asTraitArgsStr, ">::Impl")>;
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h b/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h
index b59164b72e6..693f692c61a 100644
--- a/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h
@@ -21,13 +21,18 @@ limitations under the License.

 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project

-namespace mlir {
-namespace OpTrait {
-namespace quant {
-
 using QuantizedType = mlir::quant::QuantizedType;
 using UniformQuantizedType = mlir::quant::UniformQuantizedType;

+namespace mlir {
+
+// This includes the interface class definition. It cannot be in a namespace
+// because tblgen does not emit the namespace qualifier when it is used.
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_interface.h.inc"
+
+namespace OpTrait {
+namespace quant {
+
 // The base class that all the quantization related OpTrait implements.
 template <typename ConcreteType, template <typename> class TraitType>
 struct QuantizationSpecTraitBase : public TraitBase<ConcreteType, TraitType> {
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
index b9ff9869232..f17e44cd756 100644
--- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
@@ -23,6 +23,7 @@ limitations under the License.

 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Twine.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
@@ -385,7 +386,8 @@ struct FoldTrivalRequantizeOp : public OpRewritePattern<RQ> {
     Operation* def = pre_quantized.getDefiningOp();
     if (!def) return failure();

-    if (def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() ||
+    if (llvm::isa<FixedOutputRangeInterface>(def) ||
+        def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() ||
         def->hasTrait<OpTrait::quant::NoQuantizableResult>()) {
       return failure();
     }
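
Note (reviewer aid, not part of the change): the constants hard-coded in
GetFixedOutputRange follow from quantizing the logistic output range [0, 1]
with real_value = scale * (quantized_value - zero_point). With scale = 1/256
and signed 8-bit storage [-128, 127], zero_point = -128 maps q = -128 to 0.0
and q = 127 to 255/256; with unsigned storage [0, 255], zero_point = 0 maps
q = 0 to 0.0 and q = 255 to 255/256. This is the same range the
FixedResultScale traits in tfl_ops.td encode, since their mantissa/exponent
pair gives 390625 * 10^-8 = 0.00390625 = 1/256.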