Replace tfl quantize ops by quant quantize-cast ops
PiperOrigin-RevId: 291814824 Change-Id: I06e484a77a1ac066ba8ab30254103fba3da0e084
This commit is contained in:
parent
be369f57e9
commit
aa8a6fa82c
@ -357,6 +357,7 @@ cc_library(
|
||||
":validators",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
|
||||
"//tensorflow/compiler/mlir/lite/quantization/lite:tfl_to_std",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@llvm-project//llvm:support",
|
||||
|
@ -71,18 +71,17 @@ cc_library(
|
||||
"quantization_utils.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"quantization_traits.h",
|
||||
"quantization_utils.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
# TODO(fengliuai): remove this dependence.
|
||||
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -12,6 +12,7 @@ package_group(
|
||||
includes = ["//third_party/mlir:subpackages"],
|
||||
packages = [
|
||||
"//learning/brain/experimental/mlir/...",
|
||||
"//tensorflow/compiler/mlir/lite/...",
|
||||
"//tensorflow/lite/...",
|
||||
],
|
||||
)
|
||||
@ -41,6 +42,24 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfl_to_std",
|
||||
srcs = [
|
||||
"tfl_to_std.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"tfl_to_std.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
],
|
||||
)
|
||||
|
||||
# Binary to apply quantization on the annotated files.
|
||||
tf_cc_binary(
|
||||
name = "tfl_quantizer",
|
||||
|
@ -0,0 +1,62 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h"
|
||||
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
|
||||
// Rewrites every tfl.dequantize / tfl.quantize op inside `func` into the
// corresponding cast op of the mlir quant dialect, erasing the TFL op.
void ConvertTFLQuantOpsToMlirQuantOps(FuncOp func) {
  OpBuilder builder(func);
  func.walk([&](Operation* op) {
    builder.setInsertionPoint(op);
    if (auto dequant = llvm::dyn_cast<DequantizeOp>(op)) {
      // tfl.dequantize -> quant.dcast with the same result type and input.
      auto replacement = builder.create<quant::DequantizeCastOp>(
          dequant.getLoc(), dequant.output().getType(), dequant.input());
      dequant.output().replaceAllUsesWith(replacement);
      dequant.erase();
      return;
    }
    if (auto quantize = llvm::dyn_cast<QuantizeOp>(op)) {
      // tfl.quantize -> quant.qcast; the quantized element type is already
      // carried by the result type, so the cast op needs no extra attribute.
      auto replacement = builder.create<quant::QuantizeCastOp>(
          quantize.getLoc(), quantize.output().getType(), quantize.input());
      quantize.output().replaceAllUsesWith(replacement);
      quantize.erase();
    }
  });
}
|
||||
|
||||
// Rewrites every quant.dcast / quant.qcast op inside `func` back into the
// corresponding TFL quantization op, erasing the mlir quant dialect op.
void ConvertMlirQuantOpsToTFLQuantOps(FuncOp func) {
  OpBuilder builder(func);
  func.walk([&](Operation* op) {
    builder.setInsertionPoint(op);
    if (auto dcast = llvm::dyn_cast<quant::DequantizeCastOp>(op)) {
      // quant.dcast -> tfl.dequantize with the same result type and input.
      auto replacement = builder.create<DequantizeOp>(
          dcast.getLoc(), dcast.getResult().getType(), dcast.arg());
      dcast.getResult().replaceAllUsesWith(replacement);
      dcast.erase();
      return;
    }
    if (auto qcast = llvm::dyn_cast<quant::QuantizeCastOp>(op)) {
      // tfl.quantize redundantly carries the quantized type as a TypeAttr,
      // so rebuild that attribute from the cast's result type.
      auto result_type = qcast.getResult().getType();
      auto replacement =
          builder.create<QuantizeOp>(qcast.getLoc(), result_type, qcast.arg(),
                                     TypeAttr::get(result_type));
      qcast.getResult().replaceAllUsesWith(replacement);
      qcast.erase();
    }
  });
}
|
||||
|
||||
} // namespace TFL
|
||||
} // namespace mlir
|
34
tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h
Normal file
34
tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h
Normal file
@ -0,0 +1,34 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_

#include "mlir/IR/Function.h"  // TF:llvm-project

namespace mlir {
namespace TFL {

// Converts all the tfl.quantize/tfl.dequantize ops in `func` to the
// equivalent cast ops of the mlir.quant dialect. Intended to run before
// passes that operate on the generic quant dialect ops.
void ConvertTFLQuantOpsToMlirQuantOps(FuncOp func);

// Converts all the mlir.quant dialect quantize/dequantize cast ops in `func`
// back to tfl.quantize/tfl.dequantize ops. Inverse of the function above.
void ConvertMlirQuantOpsToTFLQuantOps(FuncOp func);

}  // namespace TFL
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
@ -34,7 +35,6 @@ limitations under the License.
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Value.h" // TF:llvm-project
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
@ -457,11 +457,9 @@ void QuantizationDriver::QuantizeValue(Value value, QuantParams params,
|
||||
// This value isn't an expressed type (float), skip.
|
||||
if (!new_type) return;
|
||||
|
||||
TypeAttr type_attr = TypeAttr::get(new_type);
|
||||
auto quantize =
|
||||
builder_.create<TFL::QuantizeOp>(loc, new_type, value, type_attr);
|
||||
auto dequantize = builder_.create<TFL::DequantizeOp>(loc, expressed_type,
|
||||
quantize.output());
|
||||
auto quantize = builder_.create<quant::QuantizeCastOp>(loc, new_type, value);
|
||||
auto dequantize = builder_.create<quant::DequantizeCastOp>(
|
||||
loc, expressed_type, quantize.getResult());
|
||||
// `original_result` has a use to `quantize`, so this will replace that use
|
||||
// by the result of `dequantize`. Remember to reset that use afterwards
|
||||
value.replaceAllUsesWith(dequantize);
|
||||
@ -475,7 +473,7 @@ void QuantizationDriver::RequantizeOpResult(Operation *op, int index,
|
||||
Value value = op->getResult(index);
|
||||
if (state->pos == RequantizeState::ON_OUTPUT) {
|
||||
Operation *user = value.getUses().begin().getUser();
|
||||
if (llvm::isa<TFL::QuantizeOp>(user)) {
|
||||
if (llvm::isa<quant::QuantizeCastOp>(user)) {
|
||||
// The requantize op is inserted between `quantize` and `dequantize` ops.
|
||||
value = user->getResult(0);
|
||||
builder_.setInsertionPointAfter(user);
|
||||
@ -490,8 +488,8 @@ void QuantizationDriver::RequantizeArg(BlockArgument arg,
|
||||
builder_.setInsertionPointToStart(arg.getOwner());
|
||||
if (value.hasOneUse()) {
|
||||
auto user = value.use_begin().getUser();
|
||||
if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
|
||||
value = q.output();
|
||||
if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
|
||||
value = q.getResult();
|
||||
builder_.setInsertionPoint(arg.getOwner(), ++Block::iterator(user));
|
||||
}
|
||||
}
|
||||
@ -518,9 +516,8 @@ void QuantizationDriver::RequantizeValue(Value value, RequantizeState *state,
|
||||
// This value isn't an expressed type (float), skip.
|
||||
if (!new_type) return;
|
||||
|
||||
TypeAttr type_attr = TypeAttr::get(new_type);
|
||||
auto requantize_op =
|
||||
builder_.create<TFL::QuantizeOp>(loc, new_type, value, type_attr);
|
||||
builder_.create<quant::QuantizeCastOp>(loc, new_type, value);
|
||||
value.replaceAllUsesWith(requantize_op);
|
||||
requantize_op.getOperation()->replaceUsesOfWith(requantize_op, value);
|
||||
}
|
||||
@ -650,8 +647,8 @@ void QuantizationDriver::SetupAllStates() {
|
||||
// If the argument is quantized, it should only have one user.
|
||||
if (arg.hasOneUse()) {
|
||||
auto user = value.use_begin().getUser();
|
||||
if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
|
||||
value = q.output();
|
||||
if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
|
||||
value = q.getResult();
|
||||
}
|
||||
}
|
||||
InitializeArgState(arg, value, &value_to_state);
|
||||
@ -659,7 +656,9 @@ void QuantizationDriver::SetupAllStates() {
|
||||
|
||||
fn_.walk([&](Operation *op) {
|
||||
if (op->isKnownTerminator() ||
|
||||
op->hasTrait<OpTrait::quant::NoQuantizableResult>())
|
||||
op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
|
||||
llvm::isa<quant::DequantizeCastOp>(op) ||
|
||||
llvm::isa<quant::QuantizeCastOp>(op))
|
||||
return;
|
||||
work_list_.push_back(op);
|
||||
|
||||
@ -668,8 +667,8 @@ void QuantizationDriver::SetupAllStates() {
|
||||
if (auto *inst = operand.getDefiningOp()) {
|
||||
// If the operand comes from a tfl.dequantize op, we use the quantized
|
||||
// input of this tfl.dequantize op to set the state.
|
||||
if (auto dq = llvm::dyn_cast<TFL::DequantizeOp>(inst)) {
|
||||
operand = dq.input();
|
||||
if (auto dq = llvm::dyn_cast<quant::DequantizeCastOp>(inst)) {
|
||||
operand = dq.arg();
|
||||
}
|
||||
}
|
||||
InitializeOperandState(op, i, operand, &value_to_state);
|
||||
@ -682,8 +681,8 @@ void QuantizationDriver::SetupAllStates() {
|
||||
// create the state and mark it immutable.
|
||||
if (result.hasOneUse()) {
|
||||
auto user = result.use_begin().getUser();
|
||||
if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
|
||||
result = q.output();
|
||||
if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
|
||||
result = q.getResult();
|
||||
}
|
||||
}
|
||||
InitializeResultState(op, res, result, &value_to_state);
|
||||
|
@ -30,7 +30,6 @@ limitations under the License.
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
|
@ -113,8 +113,7 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
|
||||
|
||||
rewriter.setInsertionPointAfter(op);
|
||||
Type result_type = quant_type.castFromExpressedType(op.getType());
|
||||
auto q = rewriter.create<Q>(op.getLoc(), result_type, op.arg(),
|
||||
TypeAttr::get(result_type));
|
||||
auto q = rewriter.create<Q>(op.getLoc(), result_type, op.arg());
|
||||
auto dq = rewriter.create<DQ>(op.getLoc(), op.getType(), q);
|
||||
op.getResult().replaceAllUsesWith(dq);
|
||||
q.getOperation()->replaceUsesOfWith(dq, op.arg());
|
||||
@ -316,7 +315,7 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
|
||||
|
||||
PatternMatchResult matchAndRewrite(Q op,
|
||||
PatternRewriter& rewriter) const override {
|
||||
Type output_type = op.output().getType();
|
||||
Type output_type = op.getResult().getType();
|
||||
auto qtype = QType::getQuantizedElementType(output_type);
|
||||
if (!qtype || qtype.isSigned()) return this->matchFailure();
|
||||
|
||||
@ -355,8 +354,7 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
|
||||
if (!new_qtype) return this->matchFailure();
|
||||
Type new_output_type = new_qtype.castFromExpressedType(
|
||||
QType::castToExpressedType(output_type));
|
||||
rewriter.replaceOpWithNewOp<Q>(op, new_output_type, op.input(),
|
||||
TypeAttr::get(new_output_type));
|
||||
rewriter.replaceOpWithNewOp<Q>(op, new_output_type, op.arg());
|
||||
return this->matchSuccess();
|
||||
}
|
||||
};
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "mlir/IR/Value.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
@ -149,11 +150,11 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
|
||||
builder.getF64FloatAttr(min_max.second), /*quant_dim=*/-1, num_bits,
|
||||
narrow_range, is_signed);
|
||||
builder.setInsertionPoint(block, insertion_point);
|
||||
auto q_op = builder.create<TFL::QuantizeOp>(loc, params.getValue(), arg,
|
||||
params);
|
||||
auto dq_op =
|
||||
builder.create<TFL::DequantizeOp>(loc, input_type, q_op.output());
|
||||
arg.replaceAllUsesWith(dq_op.output());
|
||||
auto q_op =
|
||||
builder.create<quant::QuantizeCastOp>(loc, params.getValue(), arg);
|
||||
auto dq_op = builder.create<quant::DequantizeCastOp>(loc, input_type,
|
||||
q_op.getResult());
|
||||
arg.replaceAllUsesWith(dq_op.getResult());
|
||||
q_op.setOperand(arg);
|
||||
}
|
||||
}
|
||||
@ -176,12 +177,14 @@ bool PrepareQuantizePass::RemoveRedundantStats(FuncOp func) {
|
||||
}
|
||||
|
||||
using PrepareQuantStats =
|
||||
TFL::ConvertStatsToQDQs<TFL::QuantizeOp, TFL::DequantizeOp>;
|
||||
TFL::ConvertStatsToQDQs<quant::QuantizeCastOp, quant::DequantizeCastOp>;
|
||||
|
||||
void PrepareQuantizePass::runOnFunction() {
|
||||
FuncOp func = getFunction();
|
||||
MLIRContext* ctx = func.getContext();
|
||||
|
||||
ConvertTFLQuantOpsToMlirQuantOps(func);
|
||||
|
||||
if (quant_specs_.post_training_quantization) {
|
||||
RemoveRedundantStats(func);
|
||||
} else {
|
||||
@ -198,7 +201,7 @@ void PrepareQuantizePass::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
bool is_signed = quant_specs_.IsSignedInferenceType();
|
||||
if (is_signed) {
|
||||
patterns.insert<ConvertUnsignedToSigned<TFL::QuantizeOp>>(ctx);
|
||||
patterns.insert<ConvertUnsignedToSigned<quant::QuantizeCastOp>>(ctx);
|
||||
// Convert quant stats to int8 quantization parameters.
|
||||
// Currently, only activation stats are imported, so narrow_range = false.
|
||||
patterns.insert<PrepareQuantStats>(8, false, true, ctx);
|
||||
@ -213,6 +216,8 @@ void PrepareQuantizePass::runOnFunction() {
|
||||
// values (tensors).
|
||||
ApplyQuantizationParamsPropagation(func, is_signed, disable_per_channel,
|
||||
GetOpQuantSpec);
|
||||
|
||||
ConvertMlirQuantOpsToTFLQuantOps(func);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user