Create a pass to import quant stats from a proto definition
This is for advanced users to assign ranges to some tensors they know about. Later on, these ranges are converted, with the sign and bitwidth, to quantize and dequantize ops to start the quantization pass. PiperOrigin-RevId: 273404295
This commit is contained in:
parent
e39d636d0d
commit
2b060e1a0d
@ -60,6 +60,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/lite:tensorflow_lite_legalize_tf",
|
||||
"//tensorflow/compiler/mlir/lite:tensorflow_lite_optimize",
|
||||
"//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_passes",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
|
||||
|
@ -36,6 +36,31 @@ tf_proto_library(
|
||||
cc_api_version = 2,
|
||||
)
|
||||
|
||||
# Library holding the quantization passes; at this point it contains only the
# pass that imports quant stats (import_quant_stats_pass.cc).
cc_library(
    name = "quantization_passes",
    srcs = ["import_quant_stats_pass.cc"],
    hdrs = ["quantization_passes.h"],
    deps = [
        ":quantization_info_proto_cc",
        "//tensorflow/compiler/mlir/tensorflow:import_utils",
        "//tensorflow/core:lib_proto_parsing",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@llvm//:support",
        "@local_config_mlir//:Analysis",
        "@local_config_mlir//:IR",
        "@local_config_mlir//:Pass",
        "@local_config_mlir//:QuantOps",
        "@local_config_mlir//:StandardOps",
        "@local_config_mlir//:Support",
    ],
    # The source registers its pass through a static object, so the library
    # must always be linked in.
    alwayslink = 1,
)
|
||||
|
||||
cc_library(
|
||||
name = "quantization_lib",
|
||||
srcs = [
|
||||
|
@ -0,0 +1,217 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringMap.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/Support/Regex.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/AffineExpr.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/AffineMap.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Location.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/Functional.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_info.pb.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static llvm::cl::opt<std::string> quantize_stats(
|
||||
"quant-test-stats", llvm::cl::value_desc("string"),
|
||||
llvm::cl::desc("serialized quant info string. Only used in tests"),
|
||||
llvm::cl::init(""));
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// The Pass to import quantization stats to the ops in a function. This requires
|
||||
// a custom method to retrieve the unique name of the operation.
|
||||
|
||||
namespace mlir {
|
||||
namespace quant {
|
||||
|
||||
using QuantParamsEntry = QuantizationInfo::QuantParams;
|
||||
|
||||
namespace {
|
||||
// Function pass that looks up quantization statistics by op name and records
// them in the IR by inserting `quant::StatisticsOp`s after the matching ops.
class ImportQuantStatsPass : public FunctionPass<ImportQuantStatsPass> {
 public:
  explicit ImportQuantStatsPass(OperationToName op_to_name)
      : op_to_name_(op_to_name) {}

  // Visits every op of the function and imports the stats of the ops whose
  // name matches one of the parsed entries.
  void runOnFunction() override;

  // Parses the serialized quant stats protobuf and fills the lookup tables
  // below. This method must be called after the pass is created.
  // Returns true if `stats_str` cannot be parsed and false on success.
  bool ParseQuantStats(const std::string &stats_str);

 private:
  // Builds the stats attributes from `info` and inserts stats ops after `op`,
  // either for the result selected by `index` or for all quantizable results.
  void ImportAsStatsOps(OpBuilder b, Operation *op, int index,
                        const QuantParamsEntry &info);

  // Creates a stats op consuming `res` and reroutes the users of `res` to it.
  void InsertStatsOpAtResult(OpBuilder b, Value *res, ElementsAttr layer_stats,
                             ElementsAttr axis_stats, IntegerAttr axis);

  // Returns false when `index` is out of range. Otherwise returns true only
  // if the selected result has a shaped type with a float element type.
  bool IsQuantizableResult(Operation *op, int index) {
    const bool in_range = index >= 0 && index < op->getNumResults();
    if (!in_range) return false;

    auto result_type = op->getResult(index)->getType();
    if (auto shaped_type = result_type.dyn_cast<ShapedType>()) {
      return shaped_type.getElementType().isa<FloatType>();
    }
    return false;
  }

  // A method to retrieve the name for the given op.
  OperationToName op_to_name_;

  // Plain names and regex names are stored separately: the former are looked
  // up through the hash map, while the latter have to be tried one by one
  // until a match is found.
  // The `int` in both containers is the result index of the given op, where
  // -1 stands for all the floating-point results.
  llvm::StringMap<std::pair<int, const QuantParamsEntry>> name_to_info_;
  llvm::StringMap<std::pair<int, const QuantParamsEntry>> regex_to_info_;
};
|
||||
} // namespace
|
||||
|
||||
bool ImportQuantStatsPass::ParseQuantStats(const std::string &stats_str) {
|
||||
QuantizationInfo quant_stats;
|
||||
if (!tensorflow::LoadProtoFromBuffer(stats_str, &quant_stats).ok()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
for (const auto &entry : quant_stats.entries()) {
|
||||
if (!entry.name().empty()) {
|
||||
std::vector<std::string> name_and_port =
|
||||
absl::StrSplit(entry.name(), ':');
|
||||
int port = name_and_port.size() == 2 ? std::stoi(name_and_port[1]) : -1;
|
||||
name_to_info_.insert({name_and_port[0], {port, entry}});
|
||||
} else if (!entry.name_regex().empty()) {
|
||||
std::vector<std::string> name_and_port =
|
||||
absl::StrSplit(entry.name_regex(), ':');
|
||||
int port = name_and_port.size() == 2 ? std::stoi(name_and_port[1]) : -1;
|
||||
regex_to_info_.insert({name_and_port[0], {port, entry}});
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Creates a `quant::StatisticsOp` that consumes `res` at the builder's current
// insertion point and makes it the producer seen by all users of `res`.
void ImportQuantStatsPass::InsertStatsOpAtResult(OpBuilder b, Value *res,
                                                 ElementsAttr layer_stats,
                                                 ElementsAttr axis_stats,
                                                 IntegerAttr axis) {
  const auto loc = b.getUnknownLoc();
  auto stats_op =
      b.create<quant::StatisticsOp>(loc, res, layer_stats, axis_stats, axis);

  // Redirecting all uses of `res` also rewrites the operand of the new stats
  // op itself, so that operand is pointed back at `res` right afterwards.
  res->replaceAllUsesWith(stats_op);
  Operation *stats = stats_op.getOperation();
  stats->replaceUsesOfWith(stats_op, res);
}
|
||||
|
||||
// Converts the min/max pairs of `info` to stats attributes and inserts stats
// ops after `op`.
//
// - Does nothing when `info` has no params.
// - The layer stats are the first min/max pair. When there is more than one
//   pair, all pairs are also attached as per-axis stats together with the
//   quantize axis from the entry's meta data.
// - If `index` selects a quantizable result, only that result gets a stats
//   op. In every other case (including -1, an out-of-range index, or an index
//   selecting a non-float result) all quantizable results of `op` get one.
void ImportQuantStatsPass::ImportAsStatsOps(OpBuilder b, Operation *op,
                                            int index,
                                            const QuantParamsEntry &info) {
  if (info.params_size() == 0) return;

  SmallVector<APFloat, 4> min_maxs;
  min_maxs.reserve(info.params_size() * 2);
  for (const auto &param : info.params()) {
    llvm::APFloat min(param.min_max().min());
    llvm::APFloat max(param.min_max().max());
    min_maxs.push_back(min);
    min_maxs.push_back(max);
  }

  // The layer stats contain only the first min/max pair.
  ElementsAttr layer_stats = DenseFPElementsAttr::get(
      b.getTensorType({2}, b.getF32Type()), {min_maxs[0], min_maxs[1]});
  // Both stay null unless there are per-axis params.
  ElementsAttr axis_stats;
  IntegerAttr axis;

  if (info.params_size() > 1) {
    SmallVector<int64_t, 4> axis_stats_shape{info.params_size(), 2};
    axis_stats = DenseFPElementsAttr::get(
        b.getTensorType(axis_stats_shape, b.getF32Type()), min_maxs);
    axis = b.getI64IntegerAttr(info.meta().quantize_axis());
  }

  b.setInsertionPointAfter(op);
  if (IsQuantizableResult(op, index)) {
    InsertStatsOpAtResult(b, op->getResult(index), layer_stats, axis_stats,
                          axis);
    return;
  }

  // Cache the result count as an int to avoid a signed/unsigned comparison.
  for (int i = 0, e = op->getNumResults(); i < e; ++i) {
    if (IsQuantizableResult(op, i)) {
      InsertStatsOpAtResult(b, op->getResult(i), layer_stats, axis_stats,
                            axis);
    }
  }
}
|
||||
|
||||
void ImportQuantStatsPass::runOnFunction() {
|
||||
FuncOp func = getFunction();
|
||||
OpBuilder builder(func);
|
||||
|
||||
func.walk([&](Operation *op) {
|
||||
if (op->isKnownTerminator()) return;
|
||||
auto op_name = op_to_name_(op);
|
||||
|
||||
// Check the named info collection first.
|
||||
auto it = name_to_info_.find(op_name);
|
||||
if (it != name_to_info_.end()) {
|
||||
ImportAsStatsOps(builder, op, it->second.first, it->second.second);
|
||||
return;
|
||||
}
|
||||
|
||||
// Iterate all the regex names and matches the first one.
|
||||
for (auto ®ex : regex_to_info_) {
|
||||
if (llvm::Regex(regex.first()).match(op_name)) {
|
||||
ImportAsStatsOps(builder, op, regex.second.first, regex.second.second);
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Creates an instance of the default quant parameters pass.
|
||||
std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
|
||||
OperationToName op_to_name, const std::string &stats_str) {
|
||||
auto pass = absl::make_unique<ImportQuantStatsPass>(op_to_name);
|
||||
if (pass->ParseQuantStats(stats_str)) return nullptr;
|
||||
return pass;
|
||||
}
|
||||
|
||||
// Registers this pass with default values, only for test
|
||||
static PassRegistration<ImportQuantStatsPass> pass(
|
||||
"quant-import-stats", "Import quantization stats to the model", [] {
|
||||
return CreateImportQuantStatsPass(
|
||||
[](Operation *op) {
|
||||
if (auto name = op->getAttrOfType<StringAttr>("name"))
|
||||
return name.getValue();
|
||||
else
|
||||
return StringRef();
|
||||
},
|
||||
quantize_stats);
|
||||
});
|
||||
|
||||
} // namespace quant
|
||||
} // namespace mlir
|
@ -1,6 +1,6 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package third_party.tensorflow.compiler.mlir.lite.quantization;
|
||||
package mlir.quant;
|
||||
|
||||
option cc_enable_arenas = true;
|
||||
|
||||
@ -56,7 +56,7 @@ message QuantizationInfo {
|
||||
string name = 1;
|
||||
|
||||
// A regex can be used to match multiple tensors.
|
||||
string name_regexp = 2;
|
||||
string name_regex = 2;
|
||||
}
|
||||
|
||||
// The quantization parameters for the tensor. If it is for per-axis, the
|
||||
|
@ -0,0 +1,36 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_PASSES_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_PASSES_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/PassManager.h" // TF:local_config_mlir
|
||||
|
||||
namespace mlir {
|
||||
namespace quant {
|
||||
|
||||
// Callback that returns the name under which the quantization stats of `op`
// are looked up.
using OperationToName = std::function<llvm::StringRef(Operation* op)>;
|
||||
|
||||
// Creates an instance of the pass that imports quantization stats to the
|
||||
// operations in the function. A custom method to get the name from the op is
|
||||
// used because different dialect ops might have different ways to assign the
// name. `stats_str` is the serialized quantization info; nullptr is returned
// if it cannot be parsed.
|
||||
std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
|
||||
OperationToName op_to_name, const std::string& stats_str);
|
||||
|
||||
} // namespace quant
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_PASSES_H_
|
19
tensorflow/compiler/mlir/lite/quantization/tests/BUILD
Normal file
19
tensorflow/compiler/mlir/lite/quantization/tests/BUILD
Normal file
@ -0,0 +1,19 @@
|
||||
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
|
||||
|
||||
package(licenses = ["notice"])
|
||||
|
||||
# Declares the lit tests of this package from the files with the listed
# extensions, using the shared test utilities as runtime data.
# NOTE(review): the exact expansion lives in glob_lit_test.bzl - confirm there.
glob_lit_tests(
|
||||
data = [":test_utilities"],
|
||||
driver = "@local_config_mlir//:run_lit.sh",
|
||||
test_file_exts = ["mlir"],
|
||||
)
|
||||
|
||||
# Bundle together all of the test utilities that are used by tests: the tf-opt
# driver and FileCheck, which are the tools invoked by the RUN lines.
|
||||
filegroup(
|
||||
name = "test_utilities",
|
||||
testonly = True,
|
||||
data = [
|
||||
"//tensorflow/compiler/mlir:tf-opt",
|
||||
"@llvm//:FileCheck",
|
||||
],
|
||||
)
|
@ -0,0 +1,43 @@
|
||||
// RUN: tf-opt %s -quant-import-stats --quant-test-stats='entries { name: "op" params { min_max { min: -1 max: 1 } } } entries { name: "op_0:0" params { min_max { min: -2 max: 2 } } } entries { name_regex: "op_*" params { min_max { min: -3 max: 3 } } }' | FileCheck %s --dump-input-on-failure
|
||||
|
||||
|
||||
// The op name "skip" matches none of the stats entries given on the RUN line,
// so no quant.stats op is expected after the split.
// CHECK-LABEL: import_stats_skip
|
||||
func @import_stats_skip(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
|
||||
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "skip"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
|
||||
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
|
||||
|
||||
// CHECK-NEXT: "tfl.split"
|
||||
// CHECK-NEXT: return
|
||||
}
|
||||
|
||||
// The op name "op" matches the entry with the plain name "op" and no port, so
// both results are expected to get a quant.stats op with the range [-1, 1].
// CHECK-LABEL: import_stats_name
|
||||
func @import_stats_name(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
|
||||
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "op"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
|
||||
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
|
||||
|
||||
// CHECK-NEXT: %[[split:.*]]:2 = "tfl.split"
|
||||
// CHECK-NEXT: %[[stats1:.*]] = "quant.stats"(%[[split]]#0) {layerStats = dense<[-1.000000e+00, 1.000000e+00]>
|
||||
// CHECK-NEXT: %[[stats2:.*]] = "quant.stats"(%[[split]]#1) {layerStats = dense<[-1.000000e+00, 1.000000e+00]>
|
||||
// CHECK-NEXT: return %[[stats1]], %[[stats2]] : tensor<2xf32>, tensor<2xf32>
|
||||
}
|
||||
|
||||
// The op name "op_0" matches the entry "op_0:0", which selects result 0 only,
// so just the first result is expected to get a quant.stats op ([-2, 2]).
// CHECK-LABEL: import_stats_name_port
|
||||
func @import_stats_name_port(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
|
||||
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "op_0"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
|
||||
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
|
||||
|
||||
// CHECK-NEXT: %[[split:.*]]:2 = "tfl.split"
|
||||
// CHECK-NEXT: %[[stats1:.*]] = "quant.stats"(%[[split]]#0) {layerStats = dense<[-2.000000e+00, 2.000000e+00]>
|
||||
// CHECK-NEXT: return %[[stats1]], %[[split]]#1 : tensor<2xf32>, tensor<2xf32>
|
||||
}
|
||||
|
||||
// The op name "op_regex" equals none of the plain names but is matched by the
// name_regex entry, so both results are expected to get the range [-3, 3].
// CHECK-LABEL: import_stats_name_regex
|
||||
func @import_stats_name_regex(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
|
||||
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "op_regex"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
|
||||
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
|
||||
|
||||
// CHECK-NEXT: %[[split:.*]]:2 = "tfl.split"
|
||||
// CHECK-NEXT: %[[stats1:.*]] = "quant.stats"(%[[split]]#0) {layerStats = dense<[-3.000000e+00, 3.000000e+00]>
|
||||
// CHECK-NEXT: %[[stats2:.*]] = "quant.stats"(%[[split]]#1) {layerStats = dense<[-3.000000e+00, 3.000000e+00]>
|
||||
// CHECK-NEXT: return %[[stats1]], %[[stats2]] : tensor<2xf32>, tensor<2xf32>
|
||||
}
|
Loading…
Reference in New Issue
Block a user