Get per-channel bias quantization parameters

The bias quantization parameters are derived from the inputs and weights. When
the input is per-tensor quantized and the weight is per-channel quantized, the
scale of the input is broadcast to each channel. The quantization dimension
indexes and dimension sizes of all the per-channel quantization parameters must
match; otherwise, nullptr is returned.
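
For example (matching the tests below), a per-tensor input scale of 1.5
combined with per-channel weight scales {1.0, 2.0, 3.0} yields per-channel
bias scales {1.5 * 1.0, 1.5 * 2.0, 1.5 * 3.0} = {1.5, 3.0, 4.5}, stored with a
signed 32-bit type and zero point 0.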

PiperOrigin-RevId: 270840445
Feng Liu 2019-09-23 23:36:41 -07:00 committed by TensorFlower Gardener
parent aebcf43046
commit 4aa7dbce08
2 changed files with 106 additions and 15 deletions


@@ -17,11 +17,14 @@ limitations under the License.
#include <cstdint>
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantizeUtils.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
@@ -48,6 +51,7 @@ static Type GetQuantizedType(Builder builder, Type input_type,
if (!shape || min.size() != shape.getDimSize(shape.getRank() - 1)) {
return {};
}
// TODO(b/141508873): the quantization dim is set to the last dimension.
quantizedEleType = quant::fakeQuantAttrsToType(
builder.getUnknownLoc(), storage_type_width, shape.getRank() - 1, min,
max, narrow_range, converter.expressedType, is_signed);
@@ -152,22 +156,69 @@ quant::QuantizedType GetUniformQuantizedTypeForBias(
const std::vector<quant::QuantizedType>& op_types) {
if (op_types.empty()) return {};
double scale = 1.0;
for (unsigned i = 0, e = op_types.size(); i != e; ++i) {
auto qtype = op_types[i].dyn_cast_or_null<quant::UniformQuantizedType>();
if (!qtype) return {};
scale *= qtype.getScale();
int axis_size = 1;
int32_t quant_dim = -1;
Type expressed_type;
// Requires that all the op types are valid UniformQuantizedTypes or
// UniformQuantizedPerAxisTypes and that they share the same expressed type.
// All the UniformQuantizedPerAxisTypes must also have the same quantization
// dimension index and dimension size.
for (auto op_type : op_types) {
if (!op_type) return {};
if (expressed_type && expressed_type != op_type.getExpressedType()) {
return {};
}
expressed_type = op_type.getExpressedType();
if (auto type = op_type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
if ((axis_size != 1 && axis_size != type.getScales().size())) return {};
if (quant_dim != -1 && quant_dim != type.getQuantizedDimension())
return {};
axis_size = type.getScales().size();
quant_dim = type.getQuantizedDimension();
} else if (!op_type.isa<quant::UniformQuantizedType>()) {
return {};
}
}
// The scale from each UniformQuantizedType is broadcast to every channel
// if there are UniformQuantizedPerAxisTypes.
llvm::SmallVector<double, 4> scales(axis_size, 1.0);
for (auto op_type : op_types) {
if (auto type = op_type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
for (auto index_scale : llvm::enumerate(type.getScales())) {
scales[index_scale.index()] *= index_scale.value();
}
} else if (auto type = op_type.dyn_cast<quant::UniformQuantizedType>()) {
for (int index = 0; index != axis_size; ++index) {
scales[index] *= type.getScale();
}
}
}
// Builds the result quantized type, which has a signed 32-bit storage type.
Builder builder(expressed_type.getContext());
IntegerType storage_type = builder.getIntegerType(32);
int64_t storage_type_min =
quant::QuantizedType::getDefaultMininumForInteger(/*isSigned=*/true, 32);
int64_t storage_type_max =
quant::QuantizedType::getDefaultMaxinumForInteger(/*isSigned=*/true, 32);
if (axis_size == 1) {
return quant::UniformQuantizedType::getChecked(
/*flags=*/true, storage_type, expressed_type, scales[0],
/*zeroPoint=*/0, storage_type_min, storage_type_max,
builder.getUnknownLoc());
} else {
llvm::SmallVector<int64_t, 4> zero_points(axis_size, 0);
// TODO(b/141508873): Assume the bias is a 1-D tensor and set the
// quantization dim to its last (and only) dimension, which is 0. If the
// bias rank is larger than 1, the returned quantized type cannot be used
// to quantize the bias.
return quant::UniformQuantizedPerAxisType::getChecked(
/*flags=*/true, storage_type, expressed_type, scales, zero_points,
/*quantizedDimension=*/0, storage_type_min, storage_type_max,
builder.getUnknownLoc());
}
auto type = op_types.back().cast<quant::UniformQuantizedType>();
Builder builder(type.getContext());
// TODO(fengliuai): make the bit width configurable.
IntegerType storageType = builder.getIntegerType(32);
return quant::UniformQuantizedType::getChecked(
/*flags=*/true, storageType, type.getExpressedType(), scale,
/*zeroPoint=*/0,
quant::QuantizedType::getDefaultMininumForInteger(/*isSigned=*/true, 32),
quant::QuantizedType::getDefaultMaxinumForInteger(/*isSigned=*/true, 32),
builder.getUnknownLoc());
}
ElementsAttr Quantize(Attribute real_value, Type tensor_type) {
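
In isolation, the per-channel scale combination implemented by
GetUniformQuantizedTypeForBias reduces to the sketch below (a minimal
standalone illustration, not the MLIR code itself: plain std::vector stands in
for the quantized types, and CombineScales is a hypothetical name used only
here).

#include <vector>

// Multiplies the scales of all operands together, channel by channel. A
// per-tensor scale (a vector of size 1) is broadcast across every channel;
// a per-channel scale contributes its own entry for each channel.
std::vector<double> CombineScales(
    const std::vector<std::vector<double>>& op_scales, int axis_size) {
  std::vector<double> scales(axis_size, 1.0);
  for (const auto& s : op_scales) {
    for (int i = 0; i < axis_size; ++i) {
      scales[i] *= (s.size() == 1) ? s[0] : s[i];
    }
  }
  return scales;
}

// CombineScales({{1.5}, {1.0, 2.0, 3.0}}, 3) returns {1.5, 3.0, 4.5}.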


@@ -13,6 +13,46 @@ func @DequantizeAndQuantize() -> tensor<2x2x!quant.uniform<u8:f32, 7.84313725490
// CHECK: return %2
}
// CHECK-LABEL: QuantizeConv2DPerChannel
func @QuantizeConv2DPerChannel(%arg0: tensor<1x224x224x3x!quant.uniform<u8:f32, 1.5>>,
%arg1: tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32:3, {1.0,2.0,3.0}>>) -> tensor<1x112x112x32xf32> {
%bias = constant dense<1.0> : tensor<32xf32>
%input = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 1.5>>) -> tensor<1x224x224x3xf32>
%weight = "tfl.dequantize"(%arg1) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32:3, {1.0,2.0,3.0}>>) -> tensor<32x3x3x3xf32>
%conv = "tfl.conv_2d"(%input, %weight, %bias) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32,
fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32}
: (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
return %conv : tensor<1x112x112x32xf32>
// CHECK-NEXT: %[[cst:.*]] = constant dense<1.000000e+00> : tensor<32xf32>
// CHECK-NEXT: %[[qbias:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<32x!quant.uniform<i32:f32:0, {1.500000e+00,3.000000e+00,4.500000e+00}>>}
// CHECK-NEXT: %[[bias:.*]] = "tfl.dequantize"(%[[qbias]])
// CHECK-NEXT: %[[in:.*]] = "tfl.dequantize"(%arg0)
// CHECK-NEXT: %[[w:.*]] = "tfl.dequantize"(%arg1)
// CHECK-NEXT: %[[conv:.*]] = "tfl.conv_2d"(%[[in]], %[[w]], %[[bias]])
// CHECK-NEXT: return %[[conv]]
}
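// Note: the per-tensor input scale 1.5 is broadcast and multiplied into the
// per-channel weight scales {1.0, 2.0, 3.0}, giving the bias scales
// {1.5, 3.0, 4.5} checked in the qtype above.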
// CHECK-LABEL: QuantizeConv2DPerChannels
func @QuantizeConv2DPerChannels(%arg0: tensor<1x224x224x3x!quant.uniform<u8:f32:3, {1.0,2.0,3.0}>>,
%arg1: tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32:3, {1.0,2.0,3.0}>>) -> tensor<1x112x112x32xf32> {
%bias = constant dense<1.0> : tensor<32xf32>
%input = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32:3, {1.0,2.0,3.0}>>) -> tensor<1x224x224x3xf32>
%weight = "tfl.dequantize"(%arg1) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32:3, {1.0,2.0,3.0}>>) -> tensor<32x3x3x3xf32>
%conv = "tfl.conv_2d"(%input, %weight, %bias) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32,
fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32}
: (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
return %conv : tensor<1x112x112x32xf32>
// CHECK-NEXT: %[[cst:.*]] = constant dense<1.000000e+00> : tensor<32xf32>
// CHECK-NEXT: %[[qbias:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<32x!quant.uniform<i32:f32:0, {1.000000e+00,4.000000e+00,9.000000e+00}>>}
// CHECK-NEXT: %[[bias:.*]] = "tfl.dequantize"(%[[qbias]])
// CHECK-NEXT: %[[in:.*]] = "tfl.dequantize"(%arg0)
// CHECK-NEXT: %[[w:.*]] = "tfl.dequantize"(%arg1)
// CHECK-NEXT: %[[conv:.*]] = "tfl.conv_2d"(%[[in]], %[[w]], %[[bias]])
// CHECK-NEXT: return %[[conv]]
}
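// Note: input and weight are both per-channel with the same quantization
// dimension and size, so the bias scales are the element-wise products
// {1.0*1.0, 2.0*2.0, 3.0*3.0} = {1.0, 4.0, 9.0} checked in the qtype above.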
// CHECK-LABEL: QuantizeConv2D
func @QuantizeConv2D(tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>> {
^bb0(%arg0: tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>):