Add the same scale decompose function

PiperOrigin-RevId: 309438076
Change-Id: I41144256127bab4d5796c382239a7cfad2f9c5ad
This commit is contained in:
Feng Liu 2020-05-01 11:06:39 -07:00 committed by TensorFlower Gardener
parent 5f1964bdea
commit ae76544efc
2 changed files with 52 additions and 1 deletions

View File

@ -82,10 +82,20 @@ LogicalResult DeviceTarget::RegisterKernel(
return specs_[kernel].Add(signature, {ScaleConstraintType::CustomScale, fn});
}
namespace ph = std::placeholders;
LogicalResult DeviceTarget::RegisterKernel(
llvm::StringRef kernel, const KernelSpecs::Signature& signature,
const ScaleConstraintType constraint) {
return specs_[kernel].Add(signature, {constraint, {}});
if (failed(specs_[kernel].Add(signature, {constraint, {}}))) return failure();
switch (constraint) {
case ScaleConstraintType::OutputInputSameScale:
specs_[kernel].WithImpl(std::bind(&DeviceTarget::DecomposeSameScale,
ph::_1, ph::_2, ph::_3, ph::_4));
return success();
default:
return failure();
}
}
LogicalResult DeviceTarget::DecomposeMultiplyAccumulateScale(
@ -132,5 +142,40 @@ LogicalResult DeviceTarget::DecomposeMultiplyAccumulateScale(
return success();
}
LogicalResult DeviceTarget::DecomposeSameScale(
Operation* op, quant::QuantizedMultipliers* input_multipliers,
quant::QuantizedMultipliers* output_multipliers,
quant::QuantizedRanges* output_ranges) {
auto rop = llvm::dyn_cast<quant::QuantizeRegionOp>(op);
if (!rop) return failure();
// input multipliers
for (int i = 0; i < op->getNumOperands(); ++i) {
input_multipliers->push_back(kUnitQuantizedMultiplier);
}
// output multipliers
for (int i = 0; i < op->getNumResults(); ++i) {
output_multipliers->push_back(kUnitQuantizedMultiplier);
}
auto o_spec = rop.output_specs()[0]
.cast<TypeAttr>()
.getValue()
.dyn_cast<quant::UniformQuantizedType>();
if (!o_spec) return failure();
// output ranges
auto min = rop.getAttrOfType<FloatAttr>("min");
auto max = rop.getAttrOfType<FloatAttr>("max");
output_ranges->push_back(quant::CalculateQuantizedRange(
o_spec.getScale(), o_spec.getZeroPoint(),
(min ? absl::optional<double>(min.getValueAsDouble()) : absl::nullopt),
(max ? absl::optional<double>(max.getValueAsDouble()) : absl::nullopt),
o_spec.getStorageTypeMin(), o_spec.getStorageTypeMax()));
return success();
}
} // namespace quant
} // namespace mlir

View File

@ -168,6 +168,12 @@ class DeviceTarget {
quant::QuantizedMultipliers* output_multipliers,
quant::QuantizedRanges* output_ranges);
// For "reshape" type of kernels.
static LogicalResult DecomposeSameScale(
Operation* op, quant::QuantizedMultipliers* input_multipliers,
quant::QuantizedMultipliers* output_multipliers,
quant::QuantizedRanges* output_ranges);
// A set of parameters are required to build the signatures.
FloatType f32_;
IntegerType i8_, i32_;