Add the scale propagation pass to the pipeline

PiperOrigin-RevId: 309259089
Change-Id: I6f1f1a2bc767a02f1c792ad9a4da0cb885b84de4
This commit is contained in:
Feng Liu 2020-04-30 10:59:07 -07:00 committed by TensorFlower Gardener
parent 068526d1d7
commit bac40c0346
4 changed files with 49 additions and 35 deletions

View File

@ -32,6 +32,7 @@ namespace mlir {
namespace quant {
constexpr int k8Bits = 8;
constexpr int k32Bits = 32;
constexpr unsigned kSigned = quant::QuantizationFlags::Signed;
DeviceTarget::DeviceTarget(MLIRContext* ctx) : ctx_(ctx) {
@ -39,20 +40,20 @@ DeviceTarget::DeviceTarget(MLIRContext* ctx) : ctx_(ctx) {
i8_ = IntegerType::get(k8Bits, ctx_);
i8_min_ = QuantizedType::getDefaultMinimumForInteger(kSigned, k8Bits);
i8_max_ = QuantizedType::getDefaultMaximumForInteger(kSigned, k8Bits);
i32_ = IntegerType::get(k32Bits, ctx_);
i32_min_ = QuantizedType::getDefaultMinimumForInteger(kSigned, k32Bits);
i32_max_ = QuantizedType::getDefaultMaximumForInteger(kSigned, k32Bits);
any_ = AnyQuantizedType();
qi8_ = AnyQuantizedType::get(kSigned, i8_, f32_, i8_min_, i8_max_);
qi8n_ = AnyQuantizedType::get(kSigned, i8_, f32_, i8_min_ + 1, i8_max_);
qi32_ = AnyQuantizedType::get(kSigned, i32_, f32_, i32_min_, i32_max_);
assert(qi8n_ == qi8n_);
}
Optional<KernelSpec> DeviceTarget::GetKernelSpec(QuantizeRegionOp op) const {
auto kernel_specs_it = specs_.find(op.logical_kernel());
Optional<KernelSpec> DeviceTarget::GetKernelSpec(
llvm::StringRef kernel, const KernelSpecs::Signature& signature) const {
auto kernel_specs_it = specs_.find(kernel);
if (kernel_specs_it == specs_.end()) return llvm::None;
KernelSpecs::Signature signature;
signature.reserve(op.input_specs().size() + op.output_specs().size());
AppendToSignature(op.input_specs(), &signature);
AppendToSignature(op.output_specs(), &signature);
return kernel_specs_it->getValue().Find(signature);
}
@ -62,6 +63,19 @@ ScaleDecomposeFn DeviceTarget::GetDecomposeFn(QuantizeRegionOp op) const {
return kernel_specs_it->second.GetDecomposeFn();
}
void DeviceTarget::AppendToSignature(Type spec,
KernelSpecs::Signature* signature) {
if (auto quant = spec.dyn_cast_or_null<UniformQuantizedType>()) {
signature->push_back(AnyQuantizedType::get(
quant.getFlags(), quant.getStorageType(), quant.getExpressedType(),
quant.getStorageTypeMin(), quant.getStorageTypeMax()));
} else if (auto any = spec.dyn_cast_or_null<AnyQuantizedType>()) {
signature->push_back(any);
} else { // float
signature->push_back(AnyQuantizedType());
}
}
LogicalResult DeviceTarget::RegisterKernel(
llvm::StringRef kernel, const KernelSpecs::Signature& signature,
const ScaleFn& fn, const ScaleDecomposeFn& dfn) {
@ -74,22 +88,6 @@ LogicalResult DeviceTarget::RegisterKernel(
return specs_[kernel].Add(signature, {constraint, {}});
}
void DeviceTarget::AppendToSignature(ArrayAttr specs_attr,
KernelSpecs::Signature* signature) const {
for (auto attr : specs_attr) {
Type spec = attr.cast<TypeAttr>().getValue();
if (auto quant = spec.dyn_cast<UniformQuantizedType>()) {
signature->push_back(AnyQuantizedType::get(
quant.getFlags(), quant.getStorageType(), quant.getExpressedType(),
quant.getStorageTypeMin(), quant.getStorageTypeMax()));
} else if (auto any = spec.dyn_cast<AnyQuantizedType>()) {
signature->push_back(any);
} else { // float
signature->push_back({});
}
}
}
LogicalResult DeviceTarget::DecomposeMultiplyAccumulateScale(
Operation* op, quant::QuantizedMultipliers* input_multipliers,
quant::QuantizedMultipliers* output_multipliers,

View File

@ -134,11 +134,18 @@ class DeviceTarget {
explicit DeviceTarget(MLIRContext* ctx);
// Retrieves the kernel spec for the quant region op.
Optional<KernelSpec> GetKernelSpec(quant::QuantizeRegionOp op) const;
Optional<KernelSpec> GetKernelSpec(
llvm::StringRef kernel, const KernelSpecs::Signature& signature) const;
// Retrieves the scale decomposition function for the quant region op.
ScaleDecomposeFn GetDecomposeFn(quant::QuantizeRegionOp op) const;
// converts specification to signature:
// - UniformedQuantizedType -> AnyQuantizedType
// - AnyQuantizedType (int) -> AnyQuantizedType
// - Float -> {}
static void AppendToSignature(Type spec, KernelSpecs::Signature* signature);
protected:
// Adds the kernel spec with the custom scale function for the kernel.
LogicalResult RegisterKernel(llvm::StringRef kernel,
@ -154,13 +161,6 @@ class DeviceTarget {
// added before.
KernelSpecs& RegisterKernel(llvm::StringRef kernel) { return specs_[kernel]; }
// converts specification to signature:
// - UniformedQuantizedType -> AnyQuantizedType
// - AnyQuantizedType (int) -> AnyQuantizedType
// - Float -> {}
void AppendToSignature(ArrayAttr specs_attr,
KernelSpecs::Signature* signature) const;
// For "mulmat->add" type of kernels, convert the scales of all the ports to
// multipliers.
static LogicalResult DecomposeMultiplyAccumulateScale(
@ -170,9 +170,9 @@ class DeviceTarget {
// A set of parameters are required to build the signatures.
FloatType f32_;
IntegerType i8_;
int64_t i8_min_, i8_max_;
AnyQuantizedType any_, qi8_, qi8n_;
IntegerType i8_, i32_;
int64_t i8_min_, i8_max_, i32_min_, i32_max_;
AnyQuantizedType any_, qi8_, qi8n_, qi32_;
private:
// Maps the kernel names to all the available kernels.

View File

@ -64,10 +64,23 @@ std::vector<quant::QuantizeRegionOp> QuantizeContext::GetAllOps() {
return all_ops;
}
KernelSpecs::Signature QuantizeContext::GetSignature(QuantizeRegionOp op) {
KernelSpecs::Signature signature;
signature.reserve(op.input_specs().size() + op.output_specs().size());
for (int i = 0; i < op.getNumOperands(); ++i) {
DeviceTarget::AppendToSignature(GetOperandParams(op, i), &signature);
}
for (int i = 0; i < op.getNumResults(); ++i) {
DeviceTarget::AppendToSignature(GetResultParams(op, i), &signature);
}
return signature;
}
LogicalResult QuantizeContext::Handle(
quant::QuantizeRegionOp op, llvm::SmallVectorImpl<Operation *> *new_items,
bool *changed) {
auto spec = target_spec_.GetKernelSpec(op);
auto signature = GetSignature(op);
auto spec = target_spec_.GetKernelSpec(op.logical_kernel(), signature);
if (!spec.hasValue()) {
op.emitWarning(
"Couldn't find kernel from the registeration for quantization.");

View File

@ -107,6 +107,9 @@ class QuantizeContext {
return states_manager_.GetOperandParams(op, index);
}
// Return the signature of the op.
KernelSpecs::Signature GetSignature(QuantizeRegionOp op);
// A heuristic to get quantization parameters satisfies the same scale
// constraints:
// - If there are immutable states,