Replace tfl quantize ops with quant quantizecast ops

PiperOrigin-RevId: 291814824
Change-Id: I06e484a77a1ac066ba8ab30254103fba3da0e084
This commit is contained in:
Feng Liu 2020-01-27 15:34:41 -08:00 committed by TensorFlower Gardener
parent be369f57e9
commit aa8a6fa82c
9 changed files with 150 additions and 34 deletions

View File

@ -357,6 +357,7 @@ cc_library(
":validators",
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/lite/quantization/lite:tfl_to_std",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",

View File

@ -71,18 +71,17 @@ cc_library(
"quantization_utils.cc",
],
hdrs = [
"quantization_traits.h",
"quantization_utils.h",
],
deps = [
"//tensorflow/core:lib_proto_parsing",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
# TODO(fengliuai): remove this dependency.
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/core:lib_proto_parsing",
],
)

View File

@ -12,6 +12,7 @@ package_group(
includes = ["//third_party/mlir:subpackages"],
packages = [
"//learning/brain/experimental/mlir/...",
"//tensorflow/compiler/mlir/lite/...",
"//tensorflow/lite/...",
],
)
@ -41,6 +42,24 @@ cc_library(
],
)
cc_library(
name = "tfl_to_std",
srcs = [
"tfl_to_std.cc",
],
hdrs = [
"tfl_to_std.h",
],
deps = [
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
],
)
# Binary to apply quantization on the annotated files.
tf_cc_binary(
name = "tfl_quantizer",

View File

@ -0,0 +1,62 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
namespace mlir {
namespace TFL {
// Rewrites every tfl.quantize / tfl.dequantize op in `func` into the
// corresponding quant.qcast / quant.dcast op of the mlir quant dialect,
// preserving the original result types and operands.
void ConvertTFLQuantOpsToMlirQuantOps(FuncOp func) {
  OpBuilder builder(func);
  func.walk([&](Operation* op) {
    // New ops are created at the position of the op they replace.
    builder.setInsertionPoint(op);
    if (auto tfl_q = llvm::dyn_cast<QuantizeOp>(op)) {
      auto qcast = builder.create<quant::QuantizeCastOp>(
          tfl_q.getLoc(), tfl_q.output().getType(), tfl_q.input());
      tfl_q.output().replaceAllUsesWith(qcast);
      tfl_q.erase();
    } else if (auto tfl_dq = llvm::dyn_cast<DequantizeOp>(op)) {
      auto dcast = builder.create<quant::DequantizeCastOp>(
          tfl_dq.getLoc(), tfl_dq.output().getType(), tfl_dq.input());
      tfl_dq.output().replaceAllUsesWith(dcast);
      tfl_dq.erase();
    }
  });
}
// Rewrites every quant.qcast / quant.dcast op in `func` back into the
// corresponding tfl.quantize / tfl.dequantize op, preserving the original
// result types and operands. tfl.quantize additionally carries its output
// type as a TypeAttr, which is reconstructed from the cast's result type.
void ConvertMlirQuantOpsToTFLQuantOps(FuncOp func) {
  OpBuilder builder(func);
  func.walk([&](Operation* op) {
    // New ops are created at the position of the op they replace.
    builder.setInsertionPoint(op);
    if (auto qcast = llvm::dyn_cast<quant::QuantizeCastOp>(op)) {
      auto result_type = qcast.getResult().getType();
      auto tfl_q = builder.create<QuantizeOp>(qcast.getLoc(), result_type,
                                              qcast.arg(),
                                              TypeAttr::get(result_type));
      qcast.getResult().replaceAllUsesWith(tfl_q);
      qcast.erase();
    } else if (auto dcast = llvm::dyn_cast<quant::DequantizeCastOp>(op)) {
      auto tfl_dq = builder.create<DequantizeOp>(
          dcast.getLoc(), dcast.getResult().getType(), dcast.arg());
      dcast.getResult().replaceAllUsesWith(tfl_dq);
      dcast.erase();
    }
  });
}
} // namespace TFL
} // namespace mlir

View File

@ -0,0 +1,34 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_
#include "mlir/IR/Function.h"  // TF:llvm-project
namespace mlir {
namespace TFL {
// Replaces every tfl.quantize/tfl.dequantize op in the function with the
// equivalent op of the mlir quant dialect (quant.qcast/quant.dcast).
void ConvertTFLQuantOpsToMlirQuantOps(FuncOp func);
// Replaces every mlir quant dialect quantize/dequantize cast op in the
// function with the equivalent tfl.quantize/tfl.dequantize op.
void ConvertMlirQuantOpsToTFLQuantOps(FuncOp func);
}  // namespace TFL
}  // namespace mlir
#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
@ -34,7 +35,6 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/core/platform/logging.h"
@ -457,11 +457,9 @@ void QuantizationDriver::QuantizeValue(Value value, QuantParams params,
// This value isn't an expressed type (float), skip.
if (!new_type) return;
TypeAttr type_attr = TypeAttr::get(new_type);
auto quantize =
builder_.create<TFL::QuantizeOp>(loc, new_type, value, type_attr);
auto dequantize = builder_.create<TFL::DequantizeOp>(loc, expressed_type,
quantize.output());
auto quantize = builder_.create<quant::QuantizeCastOp>(loc, new_type, value);
auto dequantize = builder_.create<quant::DequantizeCastOp>(
loc, expressed_type, quantize.getResult());
// `original_result` has a use to `quantize`, so this will replace that use
// by the result of `dequantize`. Remember to reset that use afterwards
value.replaceAllUsesWith(dequantize);
@ -475,7 +473,7 @@ void QuantizationDriver::RequantizeOpResult(Operation *op, int index,
Value value = op->getResult(index);
if (state->pos == RequantizeState::ON_OUTPUT) {
Operation *user = value.getUses().begin().getUser();
if (llvm::isa<TFL::QuantizeOp>(user)) {
if (llvm::isa<quant::QuantizeCastOp>(user)) {
// The requantize op is inserted between `quantize` and `dequantize` ops.
value = user->getResult(0);
builder_.setInsertionPointAfter(user);
@ -490,8 +488,8 @@ void QuantizationDriver::RequantizeArg(BlockArgument arg,
builder_.setInsertionPointToStart(arg.getOwner());
if (value.hasOneUse()) {
auto user = value.use_begin().getUser();
if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
value = q.output();
if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
value = q.getResult();
builder_.setInsertionPoint(arg.getOwner(), ++Block::iterator(user));
}
}
@ -518,9 +516,8 @@ void QuantizationDriver::RequantizeValue(Value value, RequantizeState *state,
// This value isn't an expressed type (float), skip.
if (!new_type) return;
TypeAttr type_attr = TypeAttr::get(new_type);
auto requantize_op =
builder_.create<TFL::QuantizeOp>(loc, new_type, value, type_attr);
builder_.create<quant::QuantizeCastOp>(loc, new_type, value);
value.replaceAllUsesWith(requantize_op);
requantize_op.getOperation()->replaceUsesOfWith(requantize_op, value);
}
@ -650,8 +647,8 @@ void QuantizationDriver::SetupAllStates() {
// If the argument is quantized, it should only have one user.
if (arg.hasOneUse()) {
auto user = value.use_begin().getUser();
if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
value = q.output();
if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
value = q.getResult();
}
}
InitializeArgState(arg, value, &value_to_state);
@ -659,7 +656,9 @@ void QuantizationDriver::SetupAllStates() {
fn_.walk([&](Operation *op) {
if (op->isKnownTerminator() ||
op->hasTrait<OpTrait::quant::NoQuantizableResult>())
op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
llvm::isa<quant::DequantizeCastOp>(op) ||
llvm::isa<quant::QuantizeCastOp>(op))
return;
work_list_.push_back(op);
@ -668,8 +667,8 @@ void QuantizationDriver::SetupAllStates() {
if (auto *inst = operand.getDefiningOp()) {
// If the operand comes from a dequantize (quant.dcast) op, we use the
// quantized input of this dequantize op to set the state.
if (auto dq = llvm::dyn_cast<TFL::DequantizeOp>(inst)) {
operand = dq.input();
if (auto dq = llvm::dyn_cast<quant::DequantizeCastOp>(inst)) {
operand = dq.arg();
}
}
InitializeOperandState(op, i, operand, &value_to_state);
@ -682,8 +681,8 @@ void QuantizationDriver::SetupAllStates() {
// create the state and mark it immutable.
if (result.hasOneUse()) {
auto user = result.use_begin().getUser();
if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
result = q.output();
if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
result = q.getResult();
}
}
InitializeResultState(op, res, result, &value_to_state);

View File

@ -30,7 +30,6 @@ limitations under the License.
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
namespace mlir {
namespace TFL {

View File

@ -113,8 +113,7 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
rewriter.setInsertionPointAfter(op);
Type result_type = quant_type.castFromExpressedType(op.getType());
auto q = rewriter.create<Q>(op.getLoc(), result_type, op.arg(),
TypeAttr::get(result_type));
auto q = rewriter.create<Q>(op.getLoc(), result_type, op.arg());
auto dq = rewriter.create<DQ>(op.getLoc(), op.getType(), q);
op.getResult().replaceAllUsesWith(dq);
q.getOperation()->replaceUsesOfWith(dq, op.arg());
@ -316,7 +315,7 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
PatternMatchResult matchAndRewrite(Q op,
PatternRewriter& rewriter) const override {
Type output_type = op.output().getType();
Type output_type = op.getResult().getType();
auto qtype = QType::getQuantizedElementType(output_type);
if (!qtype || qtype.isSigned()) return this->matchFailure();
@ -355,8 +354,7 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
if (!new_qtype) return this->matchFailure();
Type new_output_type = new_qtype.castFromExpressedType(
QType::castToExpressedType(output_type));
rewriter.replaceOpWithNewOp<Q>(op, new_output_type, op.input(),
TypeAttr::get(new_output_type));
rewriter.replaceOpWithNewOp<Q>(op, new_output_type, op.arg());
return this->matchSuccess();
}
};

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
@ -149,11 +150,11 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
builder.getF64FloatAttr(min_max.second), /*quant_dim=*/-1, num_bits,
narrow_range, is_signed);
builder.setInsertionPoint(block, insertion_point);
auto q_op = builder.create<TFL::QuantizeOp>(loc, params.getValue(), arg,
params);
auto dq_op =
builder.create<TFL::DequantizeOp>(loc, input_type, q_op.output());
arg.replaceAllUsesWith(dq_op.output());
auto q_op =
builder.create<quant::QuantizeCastOp>(loc, params.getValue(), arg);
auto dq_op = builder.create<quant::DequantizeCastOp>(loc, input_type,
q_op.getResult());
arg.replaceAllUsesWith(dq_op.getResult());
q_op.setOperand(arg);
}
}
@ -176,12 +177,14 @@ bool PrepareQuantizePass::RemoveRedundantStats(FuncOp func) {
}
using PrepareQuantStats =
TFL::ConvertStatsToQDQs<TFL::QuantizeOp, TFL::DequantizeOp>;
TFL::ConvertStatsToQDQs<quant::QuantizeCastOp, quant::DequantizeCastOp>;
void PrepareQuantizePass::runOnFunction() {
FuncOp func = getFunction();
MLIRContext* ctx = func.getContext();
ConvertTFLQuantOpsToMlirQuantOps(func);
if (quant_specs_.post_training_quantization) {
RemoveRedundantStats(func);
} else {
@ -198,7 +201,7 @@ void PrepareQuantizePass::runOnFunction() {
OwningRewritePatternList patterns;
bool is_signed = quant_specs_.IsSignedInferenceType();
if (is_signed) {
patterns.insert<ConvertUnsignedToSigned<TFL::QuantizeOp>>(ctx);
patterns.insert<ConvertUnsignedToSigned<quant::QuantizeCastOp>>(ctx);
// Convert quant stats to int8 quantization parameters.
// Currently, only activation stats are imported, so narrow_range = false.
patterns.insert<PrepareQuantStats>(8, false, true, ctx);
@ -213,6 +216,8 @@ void PrepareQuantizePass::runOnFunction() {
// values (tensors).
ApplyQuantizationParamsPropagation(func, is_signed, disable_per_channel,
GetOpQuantSpec);
ConvertMlirQuantOpsToTFLQuantOps(func);
}
} // namespace