Add the same scale decompose function
PiperOrigin-RevId: 309438076 Change-Id: I41144256127bab4d5796c382239a7cfad2f9c5ad
This commit is contained in:
parent
5f1964bdea
commit
ae76544efc
|
@ -82,10 +82,20 @@ LogicalResult DeviceTarget::RegisterKernel(
|
|||
return specs_[kernel].Add(signature, {ScaleConstraintType::CustomScale, fn});
|
||||
}
|
||||
|
||||
namespace ph = std::placeholders;
|
||||
|
||||
LogicalResult DeviceTarget::RegisterKernel(
|
||||
llvm::StringRef kernel, const KernelSpecs::Signature& signature,
|
||||
const ScaleConstraintType constraint) {
|
||||
return specs_[kernel].Add(signature, {constraint, {}});
|
||||
if (failed(specs_[kernel].Add(signature, {constraint, {}}))) return failure();
|
||||
switch (constraint) {
|
||||
case ScaleConstraintType::OutputInputSameScale:
|
||||
specs_[kernel].WithImpl(std::bind(&DeviceTarget::DecomposeSameScale,
|
||||
ph::_1, ph::_2, ph::_3, ph::_4));
|
||||
return success();
|
||||
default:
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
LogicalResult DeviceTarget::DecomposeMultiplyAccumulateScale(
|
||||
|
@ -132,5 +142,40 @@ LogicalResult DeviceTarget::DecomposeMultiplyAccumulateScale(
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult DeviceTarget::DecomposeSameScale(
|
||||
Operation* op, quant::QuantizedMultipliers* input_multipliers,
|
||||
quant::QuantizedMultipliers* output_multipliers,
|
||||
quant::QuantizedRanges* output_ranges) {
|
||||
auto rop = llvm::dyn_cast<quant::QuantizeRegionOp>(op);
|
||||
if (!rop) return failure();
|
||||
|
||||
// input multipliers
|
||||
for (int i = 0; i < op->getNumOperands(); ++i) {
|
||||
input_multipliers->push_back(kUnitQuantizedMultiplier);
|
||||
}
|
||||
|
||||
// output multipliers
|
||||
for (int i = 0; i < op->getNumResults(); ++i) {
|
||||
output_multipliers->push_back(kUnitQuantizedMultiplier);
|
||||
}
|
||||
|
||||
auto o_spec = rop.output_specs()[0]
|
||||
.cast<TypeAttr>()
|
||||
.getValue()
|
||||
.dyn_cast<quant::UniformQuantizedType>();
|
||||
if (!o_spec) return failure();
|
||||
|
||||
// output ranges
|
||||
auto min = rop.getAttrOfType<FloatAttr>("min");
|
||||
auto max = rop.getAttrOfType<FloatAttr>("max");
|
||||
output_ranges->push_back(quant::CalculateQuantizedRange(
|
||||
o_spec.getScale(), o_spec.getZeroPoint(),
|
||||
(min ? absl::optional<double>(min.getValueAsDouble()) : absl::nullopt),
|
||||
(max ? absl::optional<double>(max.getValueAsDouble()) : absl::nullopt),
|
||||
o_spec.getStorageTypeMin(), o_spec.getStorageTypeMax()));
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace quant
|
||||
} // namespace mlir
|
||||
|
|
|
@ -168,6 +168,12 @@ class DeviceTarget {
|
|||
quant::QuantizedMultipliers* output_multipliers,
|
||||
quant::QuantizedRanges* output_ranges);
|
||||
|
||||
// For "reshape" type of kernels.
|
||||
static LogicalResult DecomposeSameScale(
|
||||
Operation* op, quant::QuantizedMultipliers* input_multipliers,
|
||||
quant::QuantizedMultipliers* output_multipliers,
|
||||
quant::QuantizedRanges* output_ranges);
|
||||
|
||||
// A set of parameters are required to build the signatures.
|
||||
FloatType f32_;
|
||||
IntegerType i8_, i32_;
|
||||
|
|
Loading…
Reference in New Issue