Create a pass to import quant stats from a proto definition

This is for advanced users to assign ranges to some tensors they know about.
Later on, these ranges are converted, with the sign and bitwidth, to quantize
and dequantize ops to start the quantization pass.
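For illustration, the ranges are supplied as a serialized QuantizationInfo proto (see the name_regex field change and the new test below). A minimal, hypothetical stats string could look like the following, pinning the range of result 0 of an op named "conv1" and of every op whose name matches "dense_.*"; the tensor names and ranges are placeholders, not part of this change:

// Hypothetical serialized QuantizationInfo in proto text format; ":0" selects
// result 0 of the named op.
const std::string stats_str = R"pb(
  entries { name: "conv1:0" params { min_max { min: -1.0 max: 1.0 } } }
  entries { name_regex: "dense_.*" params { min_max { min: -6.0 max: 6.0 } } }
)pb";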

PiperOrigin-RevId: 273404295
Feng Liu 2019-10-07 16:35:16 -07:00 committed by TensorFlower Gardener
parent e39d636d0d
commit 2b060e1a0d
7 changed files with 343 additions and 2 deletions

View File

@@ -60,6 +60,7 @@ cc_library(
"//tensorflow/compiler/mlir/lite:tensorflow_lite_legalize_tf",
"//tensorflow/compiler/mlir/lite:tensorflow_lite_optimize",
"//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize",
"//tensorflow/compiler/mlir/lite/quantization:quantization_passes",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",

View File

@@ -36,6 +36,31 @@ tf_proto_library(
cc_api_version = 2,
)
cc_library(
name = "quantization_passes",
srcs = [
"import_quant_stats_pass.cc",
],
hdrs = [
"quantization_passes.h",
],
deps = [
":quantization_info_proto_cc",
"//tensorflow/compiler/mlir/tensorflow:import_utils",
"//tensorflow/core:lib_proto_parsing",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm//:support",
"@local_config_mlir//:Analysis",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@local_config_mlir//:QuantOps",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
],
alwayslink = 1,
)
cc_library(
name = "quantization_lib",
srcs = [

View File

@@ -0,0 +1,217 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/memory/memory.h"
#include "absl/strings/str_split.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/AffineExpr.h" // TF:local_config_mlir
#include "mlir/IR/AffineMap.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/quantization/quantization_info.pb.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
// NOLINTNEXTLINE
static llvm::cl::opt<std::string> quantize_stats(
"quant-test-stats", llvm::cl::value_desc("string"),
llvm::cl::desc("serialized quant info string. Only used in tests"),
llvm::cl::init(""));
//===----------------------------------------------------------------------===//
// The pass to import quantization stats into the ops in a function. This
// requires a custom method to retrieve the unique name of the operation.
namespace mlir {
namespace quant {
using QuantParamsEntry = QuantizationInfo::QuantParams;
namespace {
class ImportQuantStatsPass : public FunctionPass<ImportQuantStatsPass> {
public:
explicit ImportQuantStatsPass(OperationToName op_to_name)
: op_to_name_(op_to_name) {}
void runOnFunction() override;
// Parses the serialized quant stats protobuf and initializes the internal
// data structure. This method must be called after the pass is created.
bool ParseQuantStats(const std::string &stats_str);
private:
void ImportAsStatsOps(OpBuilder b, Operation *op, int index,
const QuantParamsEntry &info);
void InsertStatsOpAtResult(OpBuilder b, Value *res, ElementsAttr layer_stats,
ElementsAttr axis_stats, IntegerAttr axis);
// If the index is out of range, this method returns false. Otherwise it
// returns true if the value is a float tensor.
bool IsQuantizableResult(Operation *op, int index) {
if (index < 0 || index >= op->getNumResults()) return false;
Value *res = op->getResult(index);
return res->getType().isa<ShapedType>() &&
res->getType().cast<ShapedType>().getElementType().isa<FloatType>();
}
// A method to retrieve the name for the given op.
OperationToName op_to_name_;
// We keep the normal names and the regex names separate, since the former
// can be looked up in a hash map while the latter requires iterating over
// all the regexes to find a match.
// The `int` in the following two containers specifies the result index of
// the given op. -1 indicates all the floating-point results.
llvm::StringMap<std::pair<int, const QuantParamsEntry>> name_to_info_;
llvm::StringMap<std::pair<int, const QuantParamsEntry>> regex_to_info_;
};
} // namespace
bool ImportQuantStatsPass::ParseQuantStats(const std::string &stats_str) {
QuantizationInfo quant_stats;
if (!tensorflow::LoadProtoFromBuffer(stats_str, &quant_stats).ok()) {
return true;
}
for (const auto &entry : quant_stats.entries()) {
if (!entry.name().empty()) {
std::vector<std::string> name_and_port =
absl::StrSplit(entry.name(), ':');
int port = name_and_port.size() == 2 ? std::stoi(name_and_port[1]) : -1;
name_to_info_.insert({name_and_port[0], {port, entry}});
} else if (!entry.name_regex().empty()) {
std::vector<std::string> name_and_port =
absl::StrSplit(entry.name_regex(), ':');
int port = name_and_port.size() == 2 ? std::stoi(name_and_port[1]) : -1;
regex_to_info_.insert({name_and_port[0], {port, entry}});
}
}
return false;
}
void ImportQuantStatsPass::InsertStatsOpAtResult(OpBuilder b, Value *res,
ElementsAttr layer_stats,
ElementsAttr axis_stats,
IntegerAttr axis) {
auto stats_op = b.create<quant::StatisticsOp>(b.getUnknownLoc(), res,
layer_stats, axis_stats, axis);
res->replaceAllUsesWith(stats_op);
stats_op.getOperation()->replaceUsesOfWith(stats_op, res);
}
void ImportQuantStatsPass::ImportAsStatsOps(OpBuilder b, Operation *op,
int index,
const QuantParamsEntry &info) {
if (info.params_size() == 0) return;
SmallVector<APFloat, 4> min_maxs;
min_maxs.reserve(info.params_size() * 2);
for (const auto &param : info.params()) {
llvm::APFloat min(param.min_max().min());
llvm::APFloat max(param.min_max().max());
min_maxs.push_back(min);
min_maxs.push_back(max);
}
// The layer stats contain only the first min/max pair.
ElementsAttr layer_stats = DenseFPElementsAttr::get(
b.getTensorType({2}, b.getF32Type()), {min_maxs[0], min_maxs[1]});
ElementsAttr axis_stats;
IntegerAttr axis;
if (info.params_size() > 1) {
SmallVector<int64_t, 4> axis_stats_shape{info.params_size(), 2};
axis_stats = DenseFPElementsAttr::get(
b.getTensorType(axis_stats_shape, b.getF32Type()), min_maxs);
axis = b.getI64IntegerAttr(info.meta().quantize_axis());
}
b.setInsertionPointAfter(op);
if (IsQuantizableResult(op, index)) {
InsertStatsOpAtResult(b, op->getResult(index), layer_stats, axis_stats,
axis);
} else {
for (int i = 0; i < op->getNumResults(); ++i) {
if (IsQuantizableResult(op, i)) {
InsertStatsOpAtResult(b, op->getResult(i), layer_stats, axis_stats,
axis);
}
}
}
}
void ImportQuantStatsPass::runOnFunction() {
FuncOp func = getFunction();
OpBuilder builder(func);
func.walk([&](Operation *op) {
if (op->isKnownTerminator()) return;
auto op_name = op_to_name_(op);
// Check the named info collection first.
auto it = name_to_info_.find(op_name);
if (it != name_to_info_.end()) {
ImportAsStatsOps(builder, op, it->second.first, it->second.second);
return;
}
// Iterate over all the regex names and apply the first one that matches.
for (auto &regex : regex_to_info_) {
if (llvm::Regex(regex.first()).match(op_name)) {
ImportAsStatsOps(builder, op, regex.second.first, regex.second.second);
break;
}
}
});
}
// Creates an instance of the import quant stats pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
OperationToName op_to_name, const std::string &stats_str) {
auto pass = absl::make_unique<ImportQuantStatsPass>(op_to_name);
if (pass->ParseQuantStats(stats_str)) return nullptr;
return pass;
}
// Registers this pass with default values, only for tests.
static PassRegistration<ImportQuantStatsPass> pass(
"quant-import-stats", "Import quantization stats to the model", [] {
return CreateImportQuantStatsPass(
[](Operation *op) {
if (auto name = op->getAttrOfType<StringAttr>("name"))
return name.getValue();
else
return StringRef();
},
quantize_stats);
});
} // namespace quant
} // namespace mlir

View File

@@ -1,6 +1,6 @@
syntax = "proto3";
package third_party.tensorflow.compiler.mlir.lite.quantization;
package mlir.quant;
option cc_enable_arenas = true;
@@ -56,7 +56,7 @@ message QuantizationInfo {
string name = 1;
// A regex can be used to match multiple tensors.
string name_regexp = 2;
string name_regex = 2;
}
// The quantization parameters for the tensor. If it is for per-axis, the

View File

@@ -0,0 +1,36 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_PASSES_H_
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Pass/PassManager.h" // TF:local_config_mlir
namespace mlir {
namespace quant {
using OperationToName = std::function<llvm::StringRef(Operation* op)>;
// Creates an instance of a pass to import quantization stats into the
// operations in the function. A custom method to get the name from the op is
// used because different dialect ops might have different ways to assign the
// name.
std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
OperationToName op_to_name, const std::string& stats_str);
} // namespace quant
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_PASSES_H_
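
Not part of this change, but a rough sketch of how a converter could wire up this API, assuming ops carry their unique name in a "name" string attribute (the same convention as the test-only registration above); the function name and surrounding setup are hypothetical:

#include <memory>
#include <string>

#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h"

// Builds the import pass from a serialized QuantizationInfo proto. Returns
// nullptr if the stats string cannot be parsed (see CreateImportQuantStatsPass).
std::unique_ptr<mlir::OpPassBase<mlir::FuncOp>> BuildImportQuantStatsPass(
    const std::string& stats_str) {
  // Resolve an op's unique name from its "name" attribute; ops without the
  // attribute are looked up under an empty name.
  auto op_to_name = [](mlir::Operation* op) -> llvm::StringRef {
    if (auto name = op->getAttrOfType<mlir::StringAttr>("name"))
      return name.getValue();
    return llvm::StringRef();
  };
  return mlir::quant::CreateImportQuantStatsPass(op_to_name, stats_str);
}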

View File

@@ -0,0 +1,19 @@
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
package(licenses = ["notice"])
glob_lit_tests(
data = [":test_utilities"],
driver = "@local_config_mlir//:run_lit.sh",
test_file_exts = ["mlir"],
)
# Bundle together all of the test utilities that are used by tests.
filegroup(
name = "test_utilities",
testonly = True,
data = [
"//tensorflow/compiler/mlir:tf-opt",
"@llvm//:FileCheck",
],
)

View File

@@ -0,0 +1,43 @@
// RUN: tf-opt %s -quant-import-stats --quant-test-stats='entries { name: "op" params { min_max { min: -1 max: 1 } } } entries { name: "op_0:0" params { min_max { min: -2 max: 2 } } } entries { name_regex: "op_*" params { min_max { min: -3 max: 3 } } }' | FileCheck %s --dump-input-on-failure
// CHECK-LABEL: import_stats_skip
func @import_stats_skip(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "skip"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
// CHECK-NEXT: "tfl.split"
// CHECK-NEXT: return
}
// CHECK-LABEL: import_stats_name
func @import_stats_name(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "op"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
// CHECK-NEXT: %[[split:.*]]:2 = "tfl.split"
// CHECK-NEXT: %[[stats1:.*]] = "quant.stats"(%[[split]]#0) {layerStats = dense<[-1.000000e+00, 1.000000e+00]>
// CHECK-NEXT: %[[stats2:.*]] = "quant.stats"(%[[split]]#1) {layerStats = dense<[-1.000000e+00, 1.000000e+00]>
// CHECK-NEXT: return %[[stats1]], %[[stats2]] : tensor<2xf32>, tensor<2xf32>
}
// CHECK-LABEL: import_stats_name_port
func @import_stats_name_port(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "op_0"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
// CHECK-NEXT: %[[split:.*]]:2 = "tfl.split"
// CHECK-NEXT: %[[stats1:.*]] = "quant.stats"(%[[split]]#0) {layerStats = dense<[-2.000000e+00, 2.000000e+00]>
// CHECK-NEXT: return %[[stats1]], %[[split]]#1 : tensor<2xf32>, tensor<2xf32>
}
// CHECK-LABEL: import_stats_name_regex
func @import_stats_name_regex(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "op_regex"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
// CHECK-NEXT: %[[split:.*]]:2 = "tfl.split"
// CHECK-NEXT: %[[stats1:.*]] = "quant.stats"(%[[split]]#0) {layerStats = dense<[-3.000000e+00, 3.000000e+00]>
// CHECK-NEXT: %[[stats2:.*]] = "quant.stats"(%[[split]]#1) {layerStats = dense<[-3.000000e+00, 3.000000e+00]>
// CHECK-NEXT: return %[[stats1]], %[[stats2]] : tensor<2xf32>, tensor<2xf32>
}