Add the propagation algorithm based on the quant region and target spec.

A simple test is added to demonstrate how the custom scale function is used.

PiperOrigin-RevId: 303865677
Change-Id: Ib9eeff4da4dba090fe7ceaa8c1bac97f1c92894f
This commit is contained in:
parent f3dcd9dc11
commit 877d642a1a
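The core of the change is QuantizeContext::Handle below, which dispatches on the kernel's ScaleConstraintType and, for CustomScale, invokes a scale function registered by the device target. As a rough sketch (not code from this commit), a custom scale function matches the scale_fn shape used by Handle and by the CpuDeviceTarget handler further down; the "same scale" rule shown here is purely illustrative:

// Illustrative sketch only: a custom scale function with the
// LogicalResult(QuantizeContext *, Operation *, AdjacentOperations *,
// bool *) shape expected by Handle(). It forces the first result to share
// the params of the first operand.
LogicalResult HandleSameScale(quant::QuantizeContext *ctx, Operation *op,
                              quant::AdjacentOperations *new_items,
                              bool *changed) {
  auto params = ctx->GetOperandParams(op, 0);
  if (quant::EmptyParams(params)) return success();  // nothing to propagate yet
  if (ctx->SetResultParams(op, 0, params)) {
    // The result got new params, so its users need to be revisited.
    for (Operation *user : op->getResult(0).getUsers()) {
      new_items->push_back(user);
    }
    *changed = true;
  }
  return success();
}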
@@ -123,3 +123,19 @@ cc_library(
        "@llvm-project//mlir:Support",
    ],
)

cc_library(
    name = "quantization_context",
    srcs = ["quantization_context.cc"],
    hdrs = ["quantization_context.h"],
    deps = [
        ":device_target",
        ":quantization_lib",
        "@com_google_absl//absl/memory",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Support",
    ],
)
@@ -0,0 +1,239 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/quantization/quantization_context.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Function.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/device_target.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"

#define DEBUG_TYPE "quantization-context"

namespace mlir {
namespace quant {

QuantizeContext::QuantizeContext(FuncOp func, const DeviceTarget &spec)
    : func_(func), target_spec_(spec) {
  llvm::DenseMap<Value, int> value_to_state;
  func.walk([&](quant::QuantizeRegionOp op) {
    for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
      states_manager_.InitializeOperandState(op, i, &value_to_state);
    }

    for (int res = 0, e = op.getNumResults(); res != e; ++res) {
      states_manager_.InitializeResultState(op, res, &value_to_state);
    }
  });
}

// Returns the ops by value so the caller doesn't hold a reference into a
// container that goes out of scope here.
std::vector<quant::QuantizeRegionOp> QuantizeContext::GetAllOps() {
  std::vector<quant::QuantizeRegionOp> all_ops;
  func_.walk([&](quant::QuantizeRegionOp op) { all_ops.push_back(op); });
  return all_ops;
}

LogicalResult QuantizeContext::Handle(
    quant::QuantizeRegionOp op, llvm::SmallVectorImpl<Operation *> *new_items,
    bool *changed) {
  auto spec = target_spec_.Get(op);
  if (!spec.hasValue()) {
    op.emitWarning(
        "Couldn't find kernel from the registration for quantization.");
    return success();
  }
  switch (spec->type) {
    case ScaleConstraintType::OutputInputFreeScale: {
      // no propagation.
      *changed = false;
      break;
    }
    case ScaleConstraintType::CustomScale: {
      if (failed(spec->scale_fn(this, op, new_items, changed))) {
        return failure();
      }
      break;
    }
    default: {
      llvm_unreachable("no implementation.");
      return failure();
    }
  }
  return success();
}

LogicalResult QuantizeContext::Finalize() {
  MLIRContext *context = func_.getContext();
  func_.walk([&](quant::QuantizeRegionOp op) {
    llvm::SmallVector<Attribute, 4> input_specs;
    auto original_input_specs = op.input_specs().getValue();
    for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
      auto &state = states_manager_.GetOperandQuantState(op, i);
      auto &requantize = states_manager_.GetOperandRequantizeState(op, i);
      if (state.IsEmpty() && requantize.pos == RequantizeState::NO_REQUANTIZE) {
        input_specs.push_back(original_input_specs[i]);
      } else if (requantize.pos == RequantizeState::ON_OUTPUT) {
        input_specs.push_back(TypeAttr::get(requantize.params));
      } else {
        input_specs.push_back(TypeAttr::get(state.params));
      }
    }
    op.setAttr("input_specs", ArrayAttr::get(input_specs, context));

    llvm::SmallVector<Attribute, 4> output_specs;
    auto original_output_specs = op.output_specs().getValue();
    for (int res = 0, e = op.getNumResults(); res != e; ++res) {
      auto &state = states_manager_.GetResultQuantState(op, res);
      auto &requantize = states_manager_.GetResultRequantizeState(op, res);
      if (state.IsEmpty() && requantize.pos == RequantizeState::NO_REQUANTIZE) {
        output_specs.push_back(original_output_specs[res]);
      } else if (requantize.pos == RequantizeState::ON_INPUT) {
        output_specs.push_back(TypeAttr::get(requantize.params));
      } else {
        output_specs.push_back(TypeAttr::get(state.params));
      }
    }
    op.setAttr("output_specs", ArrayAttr::get(output_specs, context));
  });
  return success();
}

void QuantizeContext::DumpStates(QuantizeRegionOp current_op) {
  if (current_op) {
    llvm::errs() << "\n\n\n" << current_op.logical_kernel() << "\n";
  }
  func_.walk([&](QuantizeRegionOp op) {
    if (current_op == op) llvm::errs() << "===>>>";
    llvm::errs() << op.logical_kernel() << " : (";
    for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
      if (auto params = GetOperandParams(op, i))
        params.print(llvm::errs());
      else
        llvm::errs() << "_";
      llvm::errs() << ",";
    }
    llvm::errs() << ") -> (";
    for (int i = 0, e = op.getNumResults(); i != e; ++i) {
      if (auto params = GetResultParams(op, i))
        params.print(llvm::errs());
      else
        llvm::errs() << "_";
      llvm::errs() << ",";
    }
    llvm::errs() << ")\n";
  });
}

int QuantizeContext::StatesManager::InitializeState(quant::QuantizeRegionOp op,
                                                    int index, bool as_result) {
  Attribute params_attr;
  if (as_result) {
    params_attr = op.output_specs()[index];
  } else {
    params_attr = op.input_specs()[index];
  }
  QuantParams params =
      params_attr.cast<TypeAttr>().getValue().dyn_cast<QuantParams>();
  bool immutable = !EmptyParams(params);
  int next_state_index = states_.size();
  states_.push_back({params, immutable});
  if (as_result) {
    result_states_.insert({{op, index}, next_state_index});
  } else {
    operand_states_.insert({{op, index}, next_state_index});
  }
  return next_state_index;
}

void QuantizeContext::StatesManager::InitializeOperandState(
    quant::QuantizeRegionOp op, int index, llvm::DenseMap<Value, int> *cache) {
  Value in = op.getOperand(index);
  auto cached = cache->insert({in, 0});
  if (!cached.second) {
    operand_states_.insert({{op, index}, cached.first->second});
    return;
  }
  cached.first->second = InitializeState(op, index, /*as_result=*/false);
}

void QuantizeContext::StatesManager::InitializeResultState(
    quant::QuantizeRegionOp op, int index, llvm::DenseMap<Value, int> *cache) {
  auto res = op.getResult(index);
  auto cached = cache->insert({res, 0});
  if (!cached.second) {
    result_states_.insert({{op, index}, cached.first->second});
    return;
  }
  cached.first->second = InitializeState(op, index, /*as_result=*/true);
}

bool QuantizeContext::StatesManager::SetConstantResultParams(Operation *op) {
  llvm_unreachable("no implementation.");
  return false;
}

bool QuantizeContext::StatesManager::SetResultParams(Operation *op,
                                                     int res_index,
                                                     QuantParams params) {
  auto &state = GetResultQuantState(op, res_index);
  if (state.params == params) {
    return false;
  }
  if (!state.IsEmpty()) {
    // The state already has params; record a requantize on the input side
    // instead of overwriting them.
    auto &rescale = GetResultRequantizeState(op, res_index);
    rescale.params = params;
    rescale.pos = RequantizeState::ON_INPUT;
    return false;
  }
  state.params = params;
  return true;
}

bool QuantizeContext::StatesManager::SetOperandParams(Operation *op, int index,
                                                      QuantParams params) {
  auto &state = GetOperandQuantState(op, index);
  if (state.params == params) {
    return false;
  }

  if (!state.IsEmpty()) {
    // Likewise, a non-empty operand state triggers a requantize on the
    // output side of the defining op.
    auto &rescale = GetOperandRequantizeState(op, index);
    rescale.params = params;
    rescale.pos = RequantizeState::ON_OUTPUT;
    return false;
  }
  state.params = params;
  return true;
}
}  // namespace quant
}  // namespace mlir
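The two Set*Params methods above never overwrite a non-empty state; a short sketch makes the bookkeeping concrete (DemoRequantize is a hypothetical helper, not part of this commit):

// Sketch: the 0th result of `op` is assumed to already carry non-empty
// params (e.g. from an annotated output_specs entry), and propagation now
// demands `p1`.
void DemoRequantize(quant::QuantizeContext &ctx, Operation *op,
                    quant::QuantParams p1) {
  bool revisit = ctx.SetResultParams(op, 0, p1);
  // `revisit` is false: the existing params were kept, and a
  // RequantizeState{ON_INPUT, p1} was recorded instead. Finalize() later
  // writes p1 into the op's "output_specs", marking where a requantize op
  // can be materialized.
  (void)revisit;
}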
@@ -0,0 +1,217 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONTEXT_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONTEXT_H_

#include <unordered_map>
#include <utility>
#include <vector>

#include "llvm/ADT/DenseMap.h"
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Function.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/device_target.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"

namespace mlir {
namespace quant {

static bool EmptyParams(QuantParams p) { return p == quant::QuantizedType(); }

// The state for each op result during the quantization parameters propagation.
struct QuantState {
  // Quantization parameters propagated to an op result.
  QuantParams params;
  // A flag indicating that this state (the params) shouldn't be changed after
  // it is initialized. This flag will be set to true if the quantization
  // parameters come from quantization-aware training.
  const bool immutable;

  bool IsEmpty() { return EmptyParams(params); }
};

// The state for rescaling the propagated quantization parameters. This can be
// on the input side to satisfy the constraint of the previous operation, or on
// the output side to satisfy the constraint of the next operation.
struct RequantizeState {
  // Sometimes, we have to "requantize" the quantization result to satisfy all
  // the constraints. The "requantize" can happen either on the input or the
  // output of the quantization result.
  enum RequantizePosition {
    NO_REQUANTIZE,
    ON_INPUT,
    ON_OUTPUT
  } pos = NO_REQUANTIZE;

  // Quantization parameters will be used to add the requantize ops.
  QuantParams params;
};

// This class manages all the intermediate quantization states.
class QuantizeContext {
 public:
  QuantizeContext(FuncOp func, const DeviceTarget &spec);

  // Returns all the quant region ops.
  std::vector<quant::QuantizeRegionOp> GetAllOps();

  // For each quant region op, propagates its quantization parameters according
  // to the kernel specification and also returns the adjacent quant region ops
  // which get the new quantization parameters propagated.
  LogicalResult Handle(quant::QuantizeRegionOp op,
                       llvm::SmallVectorImpl<Operation *> *new_items,
                       bool *changed);

  // Updates the port quantization specifications of all the quant region ops
  // with the propagation results.
  LogicalResult Finalize();

  // Dumps the states stored in the states manager.
  void DumpStates(QuantizeRegionOp current_op = {});

  // Updates the quantization parameter for a certain result of the op. By this
  // method, the quantization parameter is propagated to all the users of the
  // result as well.
  bool SetResultParams(Operation *op, int index, QuantParams params) {
    return states_manager_.SetResultParams(op, index, params);
  }

  // Updates the quantization parameter for a certain operand of the op. By
  // this method, the quantization parameter is propagated to the defining op
  // of the operand as well.
  bool SetOperandParams(Operation *op, int index, QuantParams params) {
    return states_manager_.SetOperandParams(op, index, params);
  }

  // Returns the quantization parameter of a certain result of the op.
  QuantParams GetResultParams(Operation *op, int index) {
    return states_manager_.GetResultParams(op, index);
  }

  // Returns the quantization parameter of a certain operand of the op.
  QuantParams GetOperandParams(Operation *op, int index) {
    return states_manager_.GetOperandParams(op, index);
  }

 private:
  class StatesManager {
   public:
    // Sets the quantization parameters of the constant result according to its
    // content.
    //
    // Always returns true.
    bool SetConstantResultParams(Operation *op);

    // Sets the quantization parameters of the result to a fixed value. If any
    // quantization parameters have been propagated, a `requantize` will happen
    // on the input of propagated quantization.
    //
    // Returns true if the users of the result need to be added to the
    // worklist.
    bool SetResultParams(Operation *op, int index, QuantParams params);

    // Sets the quantization parameters of the operand to a fixed value. If any
    // quantization parameters have been propagated, a `requantize` will happen
    // on the output of propagated quantization.
    //
    // Returns true if the defining op of the operand needs to be added to the
    // worklist.
    bool SetOperandParams(Operation *op, int index, QuantParams params);

    // Returns the quantization parameters of the index-th result of the op.
    QuantParams GetResultParams(Operation *op, int index) {
      return states_[result_states_[{op, index}]].params;
    }

    // Returns the quantization parameters of the index-th operand of the op.
    QuantParams GetOperandParams(Operation *op, int index) {
      return states_[operand_states_[{op, index}]].params;
    }

   private:
    friend class QuantizeContext;

    // Uses the type of the index-th result of `op` (if `as_result` is true) or
    // of the index-th operand (if `as_result` is false) to set the initial
    // state. The state is immutable if the type is a quantized type. Returns
    // the index of this new state in the state vector.
    int InitializeState(quant::QuantizeRegionOp op, int index, bool as_result);

    // Sets the state of the index-th operand of the op. If this operand is
    // cached, uses the cached result without creating a new entry in the state
    // vector. Otherwise, allocates a new entry in the state vector.
    void InitializeOperandState(quant::QuantizeRegionOp op, int index,
                                llvm::DenseMap<Value, int> *cache);

    // Sets the state of the index-th result of the op. If this result is
    // cached, uses the cached result without creating a new entry in the state
    // vector. Otherwise, allocates a new entry in the state vector.
    void InitializeResultState(quant::QuantizeRegionOp op, int index,
                               llvm::DenseMap<Value, int> *cache);

    // Returns the state of the index-th operand of the op.
    QuantState &GetOperandQuantState(Operation *op, int index) {
      return states_[operand_states_[{op, index}]];
    }

    // Returns the state of the index-th result of the op.
    QuantState &GetResultQuantState(Operation *op, int index) {
      return states_[result_states_[{op, index}]];
    }

    // Returns the requantize state of the index-th operand of the op.
    RequantizeState &GetOperandRequantizeState(Operation *op, int index) {
      return rescale_states_[operand_states_[{op, index}]];
    }

    // Returns the requantize state of the index-th result of the op.
    RequantizeState &GetResultRequantizeState(Operation *op, int index) {
      return rescale_states_[result_states_[{op, index}]];
    }

   private:
    // This is used to identify an operand or result of an op. The second
    // element of this pair is the index of the operand or result.
    using OpValue = std::pair<mlir::Operation *, int>;

    // The vector contains all the quantization parameters propagated from the
    // defining operations of the value, or from quantization-aware training.
    std::vector<QuantState> states_;

    // The map contains all the quantization parameters which are required to
    // satisfy the same operands and results constraint. The keys of this map
    // are the values from `operand_states_` and `result_states_`.
    std::unordered_map<int, RequantizeState> rescale_states_;

    // Maps from an op's operands and results to the indexes of their
    // propagation states in the state vector.
    llvm::DenseMap<OpValue, int> operand_states_;
    llvm::DenseMap<OpValue, int> result_states_;
  };

  FuncOp func_;

  DeviceTarget target_spec_;

  StatesManager states_manager_;
};

}  // namespace quant
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONTEXT_H_
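Putting the API together, a condensed usage sketch (it mirrors the PropagateQuantPass change below, trimmed to the essentials; PropagateOnFunction is a hypothetical wrapper):

// Condensed sketch of driving QuantizeContext: seed a worklist with every
// quant.region op, propagate until it drains, then write the results back.
void PropagateOnFunction(FuncOp func, const DeviceTarget &spec) {
  quant::QuantizeContext ctx(func, spec);
  std::vector<quant::QuantizeRegionOp> work_list = ctx.GetAllOps();
  bool changed = false;
  while (!work_list.empty()) {
    quant::QuantizeRegionOp op = work_list.back();
    work_list.pop_back();
    llvm::SmallVector<Operation *, 4> new_items;
    if (failed(ctx.Handle(op, &new_items, &changed))) return;
    // Adjacent region ops that received new params are revisited.
    for (Operation *item : new_items) {
      if (auto reg = llvm::dyn_cast_or_null<quant::QuantizeRegionOp>(item))
        work_list.push_back(reg);
    }
  }
  if (changed) (void)ctx.Finalize();
}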
@@ -35,6 +35,7 @@ cc_library(
    deps = [
        ":cpu_device_target",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_config",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_context",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
        "//tensorflow/compiler/mlir/xla:hlo",
        "//tensorflow/compiler/xla/client/lib:quantize",
@@ -60,6 +61,7 @@ cc_library(
    ],
    deps = [
        "//tensorflow/compiler/mlir/lite/quantization:device_target",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_context",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:Support",
@@ -18,6 +18,7 @@ limitations under the License.
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_context.h"

namespace mlir {
namespace xla_hlo {
@@ -36,5 +37,23 @@ CpuDeviceTarget::CpuDeviceTarget(MLIRContext* ctx) : DeviceTarget(ctx) {
                     std::bind(&CpuDeviceTarget::HandleMultiplyAccumulateScale,
                               this, ph::_1, ph::_2, ph::_3, ph::_4));
}

LogicalResult CpuDeviceTarget::HandleMultiplyAccumulateScale(
    quant::QuantizeContext* ctx, Operation* op,
    quant::AdjacentOperations* new_items, bool* changed) {
  auto bias_params = ctx->GetOperandParams(op, 2);
  if (!EmptyParams(bias_params)) {
    return success();
  }
  std::vector<quant::QuantParams> op_types{ctx->GetOperandParams(op, 0),
                                           ctx->GetOperandParams(op, 1)};
  auto bias_scale = GetUniformQuantizedTypeForBias(op_types);
  if (bias_scale && ctx->SetOperandParams(op, 2, bias_scale)) {
    *changed = true;
    new_items->push_back(op->getOperand(2).getDefiningOp());
  }
  return success();
}

}  // namespace xla_hlo
}  // namespace mlir
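The handler above fills in only the bias operand: for a multiply-accumulate kernel the accumulator is 32-bit, so the bias scale is pinned to the product of the two input scales with zero point 0. A standalone sketch of that arithmetic (BiasParamsForMulAdd is hypothetical; the real constraint lives in GetUniformQuantizedTypeForBias):

#include <cstdint>

// Hypothetical illustration of the bias constraint for
// q_out = q_lhs * q_rhs + q_bias: the bias is stored as i32 with scale
// s_lhs * s_rhs and zero point 0, so products and bias share one
// fixed-point grid.
struct UniformParams {
  double scale;
  int64_t zero_point;
};

UniformParams BiasParamsForMulAdd(const UniformParams &lhs,
                                  const UniformParams &rhs) {
  // With s_lhs = s_rhs = 1.0, as in the annotated test below, this yields
  // the checked spec !quant.uniform<i32:f32, 1.000000e+00>.
  return {lhs.scale * rhs.scale, /*zero_point=*/0};
}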
@@ -26,7 +26,9 @@ limitations under the License.
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_context.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.h"

// NOLINTNEXTLINE
static llvm::cl::opt<bool> disable_per_channel(
@@ -59,9 +61,36 @@ struct PropagateQuantPass : public FunctionPass<PropagateQuantPass> {

void PropagateQuantPass::runOnFunction() {
  FuncOp func = getFunction();
  // TODO(fengliuai): deprecate this old code generation path.
  // XLA only supports uint8/uint16 quantization for now.
  ApplyQuantizationParamsPropagation(func, /*is_signed=*/false,
                                     disable_per_channel, GetOpQuantSpec);

  CpuDeviceTarget spec(&getContext());
  quant::QuantizeContext ctx(func, spec);

  std::vector<quant::QuantizeRegionOp> work_list(ctx.GetAllOps());
  bool changed = false;
  while (!work_list.empty()) {
    quant::QuantizeRegionOp op = work_list.back();
    work_list.pop_back();

    llvm::SmallVector<Operation *, 4> new_items;
    if (failed(ctx.Handle(op, &new_items, &changed))) {
      // The IR is still valid at this point; just signal the pass failure.
      signalPassFailure();
    }
    for (auto item : new_items) {
      if (auto reg = llvm::dyn_cast_or_null<quant::QuantizeRegionOp>(item))
        work_list.push_back(reg);
    }
  }

  if (!changed) return;

  if (failed(ctx.Finalize())) {
    signalPassFailure();
  }
}

}  // namespace
@@ -0,0 +1,54 @@
// RUN: tf-opt -xla-hlo-propagate-quant %s | FileCheck %s --dump-input-on-failure

// -----

// CHECK-LABEL: @mul_add_source_no_params
func @mul_add_source_no_params(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
  %region = "quant.region"(%arg0, %arg1, %arg2) ( {
  ^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>):  // no predecessors
    %mul = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
    %add = xla_hlo.add %mul, %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
    "quant.return"(%add) : (tensor<4xf32>) -> ()
  }) {input_specs = [f32, f32, f32], logical_kernel = "generic.mul_add", output_specs = [f32]} :
      (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  return %region : tensor<4xf32>

// CHECK: input_specs = [f32, f32, f32]
// CHECK-SAME: output_specs = [f32]
}

// -----

// CHECK-LABEL: @mul_add_annotated_no_narrow_range
func @mul_add_annotated_no_narrow_range(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
  %region = "quant.region"(%arg0, %arg1, %arg2) ( {
  ^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>):  // no predecessors
    %mul = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
    %add = xla_hlo.add %mul, %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
    "quant.return"(%add) : (tensor<4xf32>) -> ()
  }) {input_specs = [!quant.uniform<i8:f32, 1.0:-128>, !quant.uniform<i8:f32, 1.0:-128>, f32],
      logical_kernel = "generic.mul_add", output_specs = [!quant.uniform<i8:f32, 1.0:-128>]} :
      (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  return %region : tensor<4xf32>

// CHECK: input_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>, !quant.uniform<i8:f32, 1.000000e+00:-128>, f32]
// CHECK-SAME: output_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>]
}

// -----

// CHECK-LABEL: @mul_add_annotated
func @mul_add_annotated(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
  %region = "quant.region"(%arg0, %arg1, %arg2) ( {
  ^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>):  // no predecessors
    %mul = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
    %add = xla_hlo.add %mul, %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
    "quant.return"(%add) : (tensor<4xf32>) -> ()
  }) {input_specs = [!quant.uniform<i8:f32, 1.0:-128>, !quant.uniform<i8<-127:127>:f32, 1.0:-128>, f32],
      logical_kernel = "generic.mul_add", output_specs = [!quant.uniform<i8:f32, 1.0:-128>]} :
      (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  return %region : tensor<4xf32>

// CHECK: input_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>, !quant.uniform<i8<-127:127>:f32, 1.000000e+00:-128>, !quant.uniform<i32:f32, 1.000000e+00>]
// CHECK-SAME: output_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>]
}