diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 1a508bdb190..c60c0c0edbf 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -216,13 +216,13 @@ cc_library( "ir/tfl_ops.h", "transforms/passes.h", "utils/attribute_utils.h", - "//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h", "@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h", ], deps = [ ":tensorflow_lite_ops_inc_gen", ":validators", "//tensorflow/compiler/mlir/lite/experimental/estimators:cost_estimators", + "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/lite/schema:schema_fbs", "@llvm-project//llvm:support", diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index e645f98e922..76c342bd10a 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -1726,6 +1726,7 @@ def TFL_LogisticOp: TFL_Op<"logistic", [ // scale = 1. / (max_value + 1) FixedResultScale>, FixedResultScale>, + FixedOutputRangeInterface, TFL_GpuTargetOp]> { let summary = "Logistic operator"; @@ -1736,6 +1737,36 @@ def TFL_LogisticOp: TFL_Op<"logistic", [ let arguments = (ins TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$x); let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$y); + + let extraClassDeclaration = [{ + // FixedOutputRangeInterface: + quant::UniformQuantizedType GetFixedOutputRange( + bool is_signed, int bit_width) { + auto result_type = y().getType().cast(); + if (!result_type.getElementType().isa()) return {}; + Builder builder(result_type.getContext()); + + // Only support 8-bits + if (bit_width != 8) return {}; + IntegerType storage_type = builder.getIntegerType(bit_width); + + double scale = 1.0 / 256; + int64_t zero_point, storage_min, storage_max; + if (is_signed) { + zero_point = -128; + storage_min = -128; + storage_max = 127; + } else { + zero_point = 0; + storage_min = 0; + storage_max = 255; + } + + return quant::UniformQuantizedType::getChecked( + is_signed, storage_type, result_type.getElementType(), scale, + zero_point, storage_min, storage_max, builder.getUnknownLoc()); + } + }]; } def TFL_LogOp: TFL_Op<"log", [ diff --git a/tensorflow/compiler/mlir/lite/quantization/BUILD b/tensorflow/compiler/mlir/lite/quantization/BUILD index 91590bfbc13..57417e95ec6 100644 --- a/tensorflow/compiler/mlir/lite/quantization/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/BUILD @@ -3,6 +3,10 @@ load( "//tensorflow/core/platform:build_config.bzl", "tf_proto_library", ) +load( + "//third_party/mlir:tblgen.bzl", + "gentbl", +) package( default_visibility = [ @@ -35,6 +39,25 @@ filegroup( ], ) +gentbl( + name = "quantization_interfaces_inc_gen", + tbl_outs = [ + ( + "-gen-op-interface-decls", + "quantization_interface.h.inc", + ), + ( + "-gen-op-interface-defs", + "quantization_interface.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "quantization.td", + td_srcs = [ + ":quantization_td_files", + ], +) + tf_proto_library( name = "quantization_info_proto", srcs = [ @@ -72,9 +95,11 @@ cc_library( name = "quantization_lib", srcs = [ "quantization_driver.cc", + "quantization_interface.cc.inc", "quantization_utils.cc", ], hdrs = [ + "quantization_interface.h.inc", "quantization_traits.h", "quantization_utils.h", ], diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD index d9e478950e6..2783297814b 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD @@ -53,6 +53,7 @@ cc_library( ], deps = [ "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "@llvm-project//llvm:support", "@llvm-project//mlir:IR", "@llvm-project//mlir:QuantOps", diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization.td b/tensorflow/compiler/mlir/lite/quantization/quantization.td index 7bfcdb65686..c1e392bd3ad 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization.td +++ b/tensorflow/compiler/mlir/lite/quantization/quantization.td @@ -63,6 +63,22 @@ def QI32 : QuantizedType<"Uniform", [32], 1>; // https://www.tensorflow.org/lite/performance/quantization_spec //===----------------------------------------------------------------------===// +// TODO(b/157870442): replace all FixedResultScale trait +def FixedOutputRangeInterface : OpInterface< + "FixedOutputRangeInterface"> { + let description = [{ + Interface for defining the fixed output range. + }]; + + let methods = [ + InterfaceMethod< + [{Returns the fixed output range.}], + "UniformQuantizedType", "GetFixedOutputRange", + (ins "bool":$sign, "int":$bit_width) + >, + ]; +} + // Specify this trait if the op has a fixed output value range. class FixedResultScale : NativeOpTrait::Impl")>; diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h b/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h index b59164b72e6..693f692c61a 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h @@ -21,13 +21,18 @@ limitations under the License. #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project -namespace mlir { -namespace OpTrait { -namespace quant { - using QuantizedType = mlir::quant::QuantizedType; using UniformQuantizedType = mlir::quant::UniformQuantizedType; +namespace mlir { + +// This includes the interface class definition. It couldn't be in a namespace +// because the table gen doesn't emit the namespace when it is used. +#include "tensorflow/compiler/mlir/lite/quantization/quantization_interface.h.inc" + +namespace OpTrait { +namespace quant { + // The base class that all the quantization related OpTrait implements. template class TraitType> struct QuantizationSpecTraitBase : public TraitBase { diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h index b9ff9869232..f17e44cd756 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h @@ -23,6 +23,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Twine.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project @@ -385,7 +386,8 @@ struct FoldTrivalRequantizeOp : public OpRewritePattern { Operation* def = pre_quantized.getDefiningOp(); if (!def) return failure(); - if (def->hasTrait() || + if (llvm::isa(def) || + def->hasTrait() || def->hasTrait()) { return failure(); }