Add the scale propagation pass to the pipeline
PiperOrigin-RevId: 309259089 Change-Id: I6f1f1a2bc767a02f1c792ad9a4da0cb885b84de4
This commit is contained in:
parent
068526d1d7
commit
bac40c0346
@ -32,6 +32,7 @@ namespace mlir {
|
|||||||
namespace quant {
|
namespace quant {
|
||||||
|
|
||||||
constexpr int k8Bits = 8;
|
constexpr int k8Bits = 8;
|
||||||
|
constexpr int k32Bits = 32;
|
||||||
constexpr unsigned kSigned = quant::QuantizationFlags::Signed;
|
constexpr unsigned kSigned = quant::QuantizationFlags::Signed;
|
||||||
|
|
||||||
DeviceTarget::DeviceTarget(MLIRContext* ctx) : ctx_(ctx) {
|
DeviceTarget::DeviceTarget(MLIRContext* ctx) : ctx_(ctx) {
|
||||||
@ -39,20 +40,20 @@ DeviceTarget::DeviceTarget(MLIRContext* ctx) : ctx_(ctx) {
|
|||||||
i8_ = IntegerType::get(k8Bits, ctx_);
|
i8_ = IntegerType::get(k8Bits, ctx_);
|
||||||
i8_min_ = QuantizedType::getDefaultMinimumForInteger(kSigned, k8Bits);
|
i8_min_ = QuantizedType::getDefaultMinimumForInteger(kSigned, k8Bits);
|
||||||
i8_max_ = QuantizedType::getDefaultMaximumForInteger(kSigned, k8Bits);
|
i8_max_ = QuantizedType::getDefaultMaximumForInteger(kSigned, k8Bits);
|
||||||
|
i32_ = IntegerType::get(k32Bits, ctx_);
|
||||||
|
i32_min_ = QuantizedType::getDefaultMinimumForInteger(kSigned, k32Bits);
|
||||||
|
i32_max_ = QuantizedType::getDefaultMaximumForInteger(kSigned, k32Bits);
|
||||||
any_ = AnyQuantizedType();
|
any_ = AnyQuantizedType();
|
||||||
qi8_ = AnyQuantizedType::get(kSigned, i8_, f32_, i8_min_, i8_max_);
|
qi8_ = AnyQuantizedType::get(kSigned, i8_, f32_, i8_min_, i8_max_);
|
||||||
qi8n_ = AnyQuantizedType::get(kSigned, i8_, f32_, i8_min_ + 1, i8_max_);
|
qi8n_ = AnyQuantizedType::get(kSigned, i8_, f32_, i8_min_ + 1, i8_max_);
|
||||||
|
qi32_ = AnyQuantizedType::get(kSigned, i32_, f32_, i32_min_, i32_max_);
|
||||||
assert(qi8n_ == qi8n_);
|
assert(qi8n_ == qi8n_);
|
||||||
}
|
}
|
||||||
|
|
||||||
Optional<KernelSpec> DeviceTarget::GetKernelSpec(QuantizeRegionOp op) const {
|
Optional<KernelSpec> DeviceTarget::GetKernelSpec(
|
||||||
auto kernel_specs_it = specs_.find(op.logical_kernel());
|
llvm::StringRef kernel, const KernelSpecs::Signature& signature) const {
|
||||||
|
auto kernel_specs_it = specs_.find(kernel);
|
||||||
if (kernel_specs_it == specs_.end()) return llvm::None;
|
if (kernel_specs_it == specs_.end()) return llvm::None;
|
||||||
|
|
||||||
KernelSpecs::Signature signature;
|
|
||||||
signature.reserve(op.input_specs().size() + op.output_specs().size());
|
|
||||||
AppendToSignature(op.input_specs(), &signature);
|
|
||||||
AppendToSignature(op.output_specs(), &signature);
|
|
||||||
return kernel_specs_it->getValue().Find(signature);
|
return kernel_specs_it->getValue().Find(signature);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -62,6 +63,19 @@ ScaleDecomposeFn DeviceTarget::GetDecomposeFn(QuantizeRegionOp op) const {
|
|||||||
return kernel_specs_it->second.GetDecomposeFn();
|
return kernel_specs_it->second.GetDecomposeFn();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void DeviceTarget::AppendToSignature(Type spec,
|
||||||
|
KernelSpecs::Signature* signature) {
|
||||||
|
if (auto quant = spec.dyn_cast_or_null<UniformQuantizedType>()) {
|
||||||
|
signature->push_back(AnyQuantizedType::get(
|
||||||
|
quant.getFlags(), quant.getStorageType(), quant.getExpressedType(),
|
||||||
|
quant.getStorageTypeMin(), quant.getStorageTypeMax()));
|
||||||
|
} else if (auto any = spec.dyn_cast_or_null<AnyQuantizedType>()) {
|
||||||
|
signature->push_back(any);
|
||||||
|
} else { // float
|
||||||
|
signature->push_back(AnyQuantizedType());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
LogicalResult DeviceTarget::RegisterKernel(
|
LogicalResult DeviceTarget::RegisterKernel(
|
||||||
llvm::StringRef kernel, const KernelSpecs::Signature& signature,
|
llvm::StringRef kernel, const KernelSpecs::Signature& signature,
|
||||||
const ScaleFn& fn, const ScaleDecomposeFn& dfn) {
|
const ScaleFn& fn, const ScaleDecomposeFn& dfn) {
|
||||||
@ -74,22 +88,6 @@ LogicalResult DeviceTarget::RegisterKernel(
|
|||||||
return specs_[kernel].Add(signature, {constraint, {}});
|
return specs_[kernel].Add(signature, {constraint, {}});
|
||||||
}
|
}
|
||||||
|
|
||||||
void DeviceTarget::AppendToSignature(ArrayAttr specs_attr,
|
|
||||||
KernelSpecs::Signature* signature) const {
|
|
||||||
for (auto attr : specs_attr) {
|
|
||||||
Type spec = attr.cast<TypeAttr>().getValue();
|
|
||||||
if (auto quant = spec.dyn_cast<UniformQuantizedType>()) {
|
|
||||||
signature->push_back(AnyQuantizedType::get(
|
|
||||||
quant.getFlags(), quant.getStorageType(), quant.getExpressedType(),
|
|
||||||
quant.getStorageTypeMin(), quant.getStorageTypeMax()));
|
|
||||||
} else if (auto any = spec.dyn_cast<AnyQuantizedType>()) {
|
|
||||||
signature->push_back(any);
|
|
||||||
} else { // float
|
|
||||||
signature->push_back({});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult DeviceTarget::DecomposeMultiplyAccumulateScale(
|
LogicalResult DeviceTarget::DecomposeMultiplyAccumulateScale(
|
||||||
Operation* op, quant::QuantizedMultipliers* input_multipliers,
|
Operation* op, quant::QuantizedMultipliers* input_multipliers,
|
||||||
quant::QuantizedMultipliers* output_multipliers,
|
quant::QuantizedMultipliers* output_multipliers,
|
||||||
|
@ -134,11 +134,18 @@ class DeviceTarget {
|
|||||||
explicit DeviceTarget(MLIRContext* ctx);
|
explicit DeviceTarget(MLIRContext* ctx);
|
||||||
|
|
||||||
// Retrieves the kernel spec for the quant region op.
|
// Retrieves the kernel spec registered for the given kernel name and signature.
|
||||||
Optional<KernelSpec> GetKernelSpec(quant::QuantizeRegionOp op) const;
|
Optional<KernelSpec> GetKernelSpec(
|
||||||
|
llvm::StringRef kernel, const KernelSpecs::Signature& signature) const;
|
||||||
|
|
||||||
// Retrieves the scale decomposition function for the quant region op.
|
// Retrieves the scale decomposition function for the quant region op.
|
||||||
ScaleDecomposeFn GetDecomposeFn(quant::QuantizeRegionOp op) const;
|
ScaleDecomposeFn GetDecomposeFn(quant::QuantizeRegionOp op) const;
|
||||||
|
|
||||||
|
// converts specification to signature:
|
||||||
|
// - UniformQuantizedType -> AnyQuantizedType
|
||||||
|
// - AnyQuantizedType (int) -> AnyQuantizedType
|
||||||
|
// - Float -> {}
|
||||||
|
static void AppendToSignature(Type spec, KernelSpecs::Signature* signature);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// Adds the kernel spec with the custom scale function for the kernel.
|
// Adds the kernel spec with the custom scale function for the kernel.
|
||||||
LogicalResult RegisterKernel(llvm::StringRef kernel,
|
LogicalResult RegisterKernel(llvm::StringRef kernel,
|
||||||
@ -154,13 +161,6 @@ class DeviceTarget {
|
|||||||
// added before.
|
// added before.
|
||||||
KernelSpecs& RegisterKernel(llvm::StringRef kernel) { return specs_[kernel]; }
|
KernelSpecs& RegisterKernel(llvm::StringRef kernel) { return specs_[kernel]; }
|
||||||
|
|
||||||
// converts specification to signature:
|
|
||||||
// - UniformQuantizedType -> AnyQuantizedType
|
|
||||||
// - AnyQuantizedType (int) -> AnyQuantizedType
|
|
||||||
// - Float -> {}
|
|
||||||
void AppendToSignature(ArrayAttr specs_attr,
|
|
||||||
KernelSpecs::Signature* signature) const;
|
|
||||||
|
|
||||||
// For "mulmat->add" type of kernels, convert the scales of all the ports to
|
// For "mulmat->add" type of kernels, convert the scales of all the ports to
|
||||||
// multipliers.
|
// multipliers.
|
||||||
static LogicalResult DecomposeMultiplyAccumulateScale(
|
static LogicalResult DecomposeMultiplyAccumulateScale(
|
||||||
@ -170,9 +170,9 @@ class DeviceTarget {
|
|||||||
|
|
||||||
// A set of parameters are required to build the signatures.
|
// A set of parameters are required to build the signatures.
|
||||||
FloatType f32_;
|
FloatType f32_;
|
||||||
IntegerType i8_;
|
IntegerType i8_, i32_;
|
||||||
int64_t i8_min_, i8_max_;
|
int64_t i8_min_, i8_max_, i32_min_, i32_max_;
|
||||||
AnyQuantizedType any_, qi8_, qi8n_;
|
AnyQuantizedType any_, qi8_, qi8n_, qi32_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Maps the kernel names to all the available kernels.
|
// Maps the kernel names to all the available kernels.
|
||||||
|
@ -64,10 +64,23 @@ std::vector<quant::QuantizeRegionOp> QuantizeContext::GetAllOps() {
|
|||||||
return all_ops;
|
return all_ops;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
KernelSpecs::Signature QuantizeContext::GetSignature(QuantizeRegionOp op) {
|
||||||
|
KernelSpecs::Signature signature;
|
||||||
|
signature.reserve(op.input_specs().size() + op.output_specs().size());
|
||||||
|
for (int i = 0; i < op.getNumOperands(); ++i) {
|
||||||
|
DeviceTarget::AppendToSignature(GetOperandParams(op, i), &signature);
|
||||||
|
}
|
||||||
|
for (int i = 0; i < op.getNumResults(); ++i) {
|
||||||
|
DeviceTarget::AppendToSignature(GetResultParams(op, i), &signature);
|
||||||
|
}
|
||||||
|
return signature;
|
||||||
|
}
|
||||||
|
|
||||||
LogicalResult QuantizeContext::Handle(
|
LogicalResult QuantizeContext::Handle(
|
||||||
quant::QuantizeRegionOp op, llvm::SmallVectorImpl<Operation *> *new_items,
|
quant::QuantizeRegionOp op, llvm::SmallVectorImpl<Operation *> *new_items,
|
||||||
bool *changed) {
|
bool *changed) {
|
||||||
auto spec = target_spec_.GetKernelSpec(op);
|
auto signature = GetSignature(op);
|
||||||
|
auto spec = target_spec_.GetKernelSpec(op.logical_kernel(), signature);
|
||||||
if (!spec.hasValue()) {
|
if (!spec.hasValue()) {
|
||||||
op.emitWarning(
|
op.emitWarning(
|
||||||
"Couldn't find kernel from the registeration for quantization.");
|
"Couldn't find kernel from the registeration for quantization.");
|
||||||
|
@ -107,6 +107,9 @@ class QuantizeContext {
|
|||||||
return states_manager_.GetOperandParams(op, index);
|
return states_manager_.GetOperandParams(op, index);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return the signature of the op.
|
||||||
|
KernelSpecs::Signature GetSignature(QuantizeRegionOp op);
|
||||||
|
|
||||||
// A heuristic to get quantization parameters that satisfy the same scale
|
// A heuristic to get quantization parameters that satisfy the same scale
|
||||||
// constraints:
|
// constraints:
|
||||||
// - If there are immutable states,
|
// - If there are immutable states,
|
||||||
|
Loading…
Reference in New Issue
Block a user