Add the scale propagation pass to the pipeline
PiperOrigin-RevId: 309259089 Change-Id: I6f1f1a2bc767a02f1c792ad9a4da0cb885b84de4
This commit is contained in:
parent
068526d1d7
commit
bac40c0346
@ -32,6 +32,7 @@ namespace mlir {
|
||||
namespace quant {
|
||||
|
||||
constexpr int k8Bits = 8;
|
||||
constexpr int k32Bits = 32;
|
||||
constexpr unsigned kSigned = quant::QuantizationFlags::Signed;
|
||||
|
||||
DeviceTarget::DeviceTarget(MLIRContext* ctx) : ctx_(ctx) {
|
||||
@ -39,20 +40,20 @@ DeviceTarget::DeviceTarget(MLIRContext* ctx) : ctx_(ctx) {
|
||||
i8_ = IntegerType::get(k8Bits, ctx_);
|
||||
i8_min_ = QuantizedType::getDefaultMinimumForInteger(kSigned, k8Bits);
|
||||
i8_max_ = QuantizedType::getDefaultMaximumForInteger(kSigned, k8Bits);
|
||||
i32_ = IntegerType::get(k32Bits, ctx_);
|
||||
i32_min_ = QuantizedType::getDefaultMinimumForInteger(kSigned, k32Bits);
|
||||
i32_max_ = QuantizedType::getDefaultMaximumForInteger(kSigned, k32Bits);
|
||||
any_ = AnyQuantizedType();
|
||||
qi8_ = AnyQuantizedType::get(kSigned, i8_, f32_, i8_min_, i8_max_);
|
||||
qi8n_ = AnyQuantizedType::get(kSigned, i8_, f32_, i8_min_ + 1, i8_max_);
|
||||
qi32_ = AnyQuantizedType::get(kSigned, i32_, f32_, i32_min_, i32_max_);
|
||||
assert(qi8n_ == qi8n_);
|
||||
}
|
||||
|
||||
Optional<KernelSpec> DeviceTarget::GetKernelSpec(QuantizeRegionOp op) const {
|
||||
auto kernel_specs_it = specs_.find(op.logical_kernel());
|
||||
Optional<KernelSpec> DeviceTarget::GetKernelSpec(
|
||||
llvm::StringRef kernel, const KernelSpecs::Signature& signature) const {
|
||||
auto kernel_specs_it = specs_.find(kernel);
|
||||
if (kernel_specs_it == specs_.end()) return llvm::None;
|
||||
|
||||
KernelSpecs::Signature signature;
|
||||
signature.reserve(op.input_specs().size() + op.output_specs().size());
|
||||
AppendToSignature(op.input_specs(), &signature);
|
||||
AppendToSignature(op.output_specs(), &signature);
|
||||
return kernel_specs_it->getValue().Find(signature);
|
||||
}
|
||||
|
||||
@ -62,6 +63,19 @@ ScaleDecomposeFn DeviceTarget::GetDecomposeFn(QuantizeRegionOp op) const {
|
||||
return kernel_specs_it->second.GetDecomposeFn();
|
||||
}
|
||||
|
||||
void DeviceTarget::AppendToSignature(Type spec,
|
||||
KernelSpecs::Signature* signature) {
|
||||
if (auto quant = spec.dyn_cast_or_null<UniformQuantizedType>()) {
|
||||
signature->push_back(AnyQuantizedType::get(
|
||||
quant.getFlags(), quant.getStorageType(), quant.getExpressedType(),
|
||||
quant.getStorageTypeMin(), quant.getStorageTypeMax()));
|
||||
} else if (auto any = spec.dyn_cast_or_null<AnyQuantizedType>()) {
|
||||
signature->push_back(any);
|
||||
} else { // float
|
||||
signature->push_back(AnyQuantizedType());
|
||||
}
|
||||
}
|
||||
|
||||
LogicalResult DeviceTarget::RegisterKernel(
|
||||
llvm::StringRef kernel, const KernelSpecs::Signature& signature,
|
||||
const ScaleFn& fn, const ScaleDecomposeFn& dfn) {
|
||||
@ -74,22 +88,6 @@ LogicalResult DeviceTarget::RegisterKernel(
|
||||
return specs_[kernel].Add(signature, {constraint, {}});
|
||||
}
|
||||
|
||||
void DeviceTarget::AppendToSignature(ArrayAttr specs_attr,
|
||||
KernelSpecs::Signature* signature) const {
|
||||
for (auto attr : specs_attr) {
|
||||
Type spec = attr.cast<TypeAttr>().getValue();
|
||||
if (auto quant = spec.dyn_cast<UniformQuantizedType>()) {
|
||||
signature->push_back(AnyQuantizedType::get(
|
||||
quant.getFlags(), quant.getStorageType(), quant.getExpressedType(),
|
||||
quant.getStorageTypeMin(), quant.getStorageTypeMax()));
|
||||
} else if (auto any = spec.dyn_cast<AnyQuantizedType>()) {
|
||||
signature->push_back(any);
|
||||
} else { // float
|
||||
signature->push_back({});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
LogicalResult DeviceTarget::DecomposeMultiplyAccumulateScale(
|
||||
Operation* op, quant::QuantizedMultipliers* input_multipliers,
|
||||
quant::QuantizedMultipliers* output_multipliers,
|
||||
|
@ -134,11 +134,18 @@ class DeviceTarget {
|
||||
explicit DeviceTarget(MLIRContext* ctx);
|
||||
|
||||
// Retrieves the kernel spec for the quant region op.
|
||||
Optional<KernelSpec> GetKernelSpec(quant::QuantizeRegionOp op) const;
|
||||
Optional<KernelSpec> GetKernelSpec(
|
||||
llvm::StringRef kernel, const KernelSpecs::Signature& signature) const;
|
||||
|
||||
// Retrieves the scale decomposition function for the quant region op.
|
||||
ScaleDecomposeFn GetDecomposeFn(quant::QuantizeRegionOp op) const;
|
||||
|
||||
// converts specification to signature:
|
||||
// - UniformedQuantizedType -> AnyQuantizedType
|
||||
// - AnyQuantizedType (int) -> AnyQuantizedType
|
||||
// - Float -> {}
|
||||
static void AppendToSignature(Type spec, KernelSpecs::Signature* signature);
|
||||
|
||||
protected:
|
||||
// Adds the kernel spec with the custom scale function for the kernel.
|
||||
LogicalResult RegisterKernel(llvm::StringRef kernel,
|
||||
@ -154,13 +161,6 @@ class DeviceTarget {
|
||||
// added before.
|
||||
KernelSpecs& RegisterKernel(llvm::StringRef kernel) { return specs_[kernel]; }
|
||||
|
||||
// converts specification to signature:
|
||||
// - UniformedQuantizedType -> AnyQuantizedType
|
||||
// - AnyQuantizedType (int) -> AnyQuantizedType
|
||||
// - Float -> {}
|
||||
void AppendToSignature(ArrayAttr specs_attr,
|
||||
KernelSpecs::Signature* signature) const;
|
||||
|
||||
// For "mulmat->add" type of kernels, convert the scales of all the ports to
|
||||
// multipliers.
|
||||
static LogicalResult DecomposeMultiplyAccumulateScale(
|
||||
@ -170,9 +170,9 @@ class DeviceTarget {
|
||||
|
||||
// A set of parameters are required to build the signatures.
|
||||
FloatType f32_;
|
||||
IntegerType i8_;
|
||||
int64_t i8_min_, i8_max_;
|
||||
AnyQuantizedType any_, qi8_, qi8n_;
|
||||
IntegerType i8_, i32_;
|
||||
int64_t i8_min_, i8_max_, i32_min_, i32_max_;
|
||||
AnyQuantizedType any_, qi8_, qi8n_, qi32_;
|
||||
|
||||
private:
|
||||
// Maps the kernel names to all the available kernels.
|
||||
|
@ -64,10 +64,23 @@ std::vector<quant::QuantizeRegionOp> QuantizeContext::GetAllOps() {
|
||||
return all_ops;
|
||||
}
|
||||
|
||||
KernelSpecs::Signature QuantizeContext::GetSignature(QuantizeRegionOp op) {
|
||||
KernelSpecs::Signature signature;
|
||||
signature.reserve(op.input_specs().size() + op.output_specs().size());
|
||||
for (int i = 0; i < op.getNumOperands(); ++i) {
|
||||
DeviceTarget::AppendToSignature(GetOperandParams(op, i), &signature);
|
||||
}
|
||||
for (int i = 0; i < op.getNumResults(); ++i) {
|
||||
DeviceTarget::AppendToSignature(GetResultParams(op, i), &signature);
|
||||
}
|
||||
return signature;
|
||||
}
|
||||
|
||||
LogicalResult QuantizeContext::Handle(
|
||||
quant::QuantizeRegionOp op, llvm::SmallVectorImpl<Operation *> *new_items,
|
||||
bool *changed) {
|
||||
auto spec = target_spec_.GetKernelSpec(op);
|
||||
auto signature = GetSignature(op);
|
||||
auto spec = target_spec_.GetKernelSpec(op.logical_kernel(), signature);
|
||||
if (!spec.hasValue()) {
|
||||
op.emitWarning(
|
||||
"Couldn't find kernel from the registeration for quantization.");
|
||||
|
@ -107,6 +107,9 @@ class QuantizeContext {
|
||||
return states_manager_.GetOperandParams(op, index);
|
||||
}
|
||||
|
||||
// Return the signature of the op.
|
||||
KernelSpecs::Signature GetSignature(QuantizeRegionOp op);
|
||||
|
||||
// A heuristic to get quantization parameters satisfies the same scale
|
||||
// constraints:
|
||||
// - If there are immutable states,
|
||||
|
Loading…
Reference in New Issue
Block a user