Adds support for bias input to TransposeConv in Hexagon delegate
PiperOrigin-RevId: 349330322
Change-Id: I8af4d3394409af7bebadc0f3a7b89eeebf7bd8de
parent 59f5abfbc8
commit 727a286fa3
tensorflow/lite/delegates/hexagon
@@ -176,7 +176,7 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
   static int dummy = 0;
   stride_shape_ = {1, stride_height, stride_width, 1};
   auto* stride_node = graph_builder_->AddConstNodeWithData(
-      stride_shape_.data(), (char*)&dummy, sizeof(dummy));
+      stride_shape_.data(), reinterpret_cast<char*>(&dummy), sizeof(dummy));

   // Output dimensions.
   int output_batch_size, output_height_size, output_width_size,
@@ -237,13 +237,16 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
         dilation_factors_h_w_, padding_type, &space_to_batch_paddings_,
         &batch_to_space_crops_);
     auto* dilation_factors_const = graph_builder_->AddConstNodeWithData(
-        dilation_factors_shape.data(), (char*)dilation_factors_h_w_.data(),
+        dilation_factors_shape.data(),
+        reinterpret_cast<char*>(dilation_factors_h_w_.data()),
         dilation_factors_h_w_.size() * sizeof(stride_height));
     auto* paddings_const = graph_builder_->AddConstNodeWithData(
-        paddings_shape.data(), (char*)space_to_batch_paddings_.data(),
+        paddings_shape.data(),
+        reinterpret_cast<char*>(space_to_batch_paddings_.data()),
         space_to_batch_paddings_.size() * sizeof(stride_height));
     auto* crops_const = graph_builder_->AddConstNodeWithData(
-        paddings_shape.data(), (char*)batch_to_space_crops_.data(),
+        paddings_shape.data(),
+        reinterpret_cast<char*>(batch_to_space_crops_.data()),
         batch_to_space_crops_.size() * sizeof(stride_height));

     // 1. SpaceToBatch.
@@ -278,8 +281,9 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
     conv_op->AddInput(TensorID(bias_max_node_->GetID(), 0));
     conv_op->AddInput(TensorID(conv_output_min_const->GetID(), 0));
     conv_op->AddInput(TensorID(conv_output_max_const->GetID(), 0));
-    if (channel_scales_node_ != nullptr) {
-      conv_op->AddInput(TensorID(channel_scales_node_->GetID(), 0));
+    if (per_channel_quant_.channel_scales_node != nullptr) {
+      conv_op->AddInput(
+          TensorID(per_channel_quant_.channel_scales_node->GetID(), 0));
     }
     // The padding is handled by the SpaceToBatch/BatchToSpace ops surrounding
     // this node. Hence, this op's padding remains VALID only.
@@ -341,8 +345,8 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
   AddInput(TensorID(bias_max_node_->GetID(), 0));
   AddInput(TensorID(conv_output_min_const->GetID(), 0));
   AddInput(TensorID(conv_output_max_const->GetID(), 0));
-  if (channel_scales_node_ != nullptr) {
-    AddInput(TensorID(channel_scales_node_->GetID(), 0));
+  if (per_channel_quant_.channel_scales_node != nullptr) {
+    AddInput(TensorID(per_channel_quant_.channel_scales_node->GetID(), 0));
   }
   // Outputs
   output_tensor = AddOutput(sizeof(uint8_t), 4,
@@ -23,6 +23,19 @@ namespace tflite {
 namespace delegates {
 namespace hexagon {

+// Stores quantization data for Conv/TransposeConv nodes.
+// This information is used to handle the per-channel quantized weights & biases
+// correctly in the Hexagon delegate.
+struct PerChannelQuantData {
+  // This is initialized while processing quantized weights, and acts as an
+  // input to Hexagon Conv nodes.
+  OpBuilder* channel_scales_node = nullptr;
+  // Scale information is obtained from TfLiteAffineQuantization in the weights
+  // tensor.
+  float* scales_data = nullptr;
+  int num_scale_values = 1;
+};
+
 class Conv2dOpBuilder : public OpBuilder {
  public:
  explicit Conv2dOpBuilder(GraphBuilder* graph_builder, int op_type)
@@ -37,23 +50,11 @@ class Conv2dOpBuilder : public OpBuilder {
   ~Conv2dOpBuilder() override;

  private:
-  // TODO(b/142009955): Combine into common util for all types of Conv.
-  TfLiteStatus ProcessPerChannelQuantizedWeights(const TfLiteIntArray* inputs,
-                                                 const TfLiteIntArray* outputs,
-                                                 TfLiteContext* context,
-                                                 float* weights_min,
-                                                 float* weights_max);
-
   TfLiteStatus InitializeWeightsNodes(const TfLiteIntArray* inputs,
                                       const TfLiteIntArray* outputs,
                                       TfLiteContext* context,
                                       const int input_depth);

-  TfLiteStatus ProcessPerChannelQuantizedBias(const TfLiteIntArray* inputs,
-                                              const TfLiteIntArray* outputs,
-                                              TfLiteContext* context,
-                                              float* bias_min, float* bias_max);
-
   TfLiteStatus InitializeBiasNodes(const TfLiteIntArray* inputs,
                                    const TfLiteIntArray* outputs,
                                    TfLiteContext* context);
@@ -67,10 +68,8 @@ class Conv2dOpBuilder : public OpBuilder {
   OpBuilder* bias_min_node_ = nullptr;
   OpBuilder* bias_max_node_ = nullptr;

-  // Non-null only if node has per-channel quantized weights/biases.
-  OpBuilder* channel_scales_node_ = nullptr;
-  float* scales_data_ = nullptr;
-  int num_scale_values_ = 1;
+  // Modified only if node has per-channel quantized weights/biases.
+  PerChannelQuantData per_channel_quant_;

   // Only used for dilated Depthwise Conv.
   std::vector<int> dilation_factors_h_w_;
@@ -78,6 +77,23 @@ class Conv2dOpBuilder : public OpBuilder {
   std::vector<int> batch_to_space_crops_;
 };

+// ProcessPerChannelQuantizedWeights & ProcessPerChannelQuantizedBias can be
+// used to pre-process per-channel quantized weights & biases for Hexagon.
+// NOTE: ProcessPerChannelQuantizedWeights should be run before
+// ProcessPerChannelQuantizedBias. This is because we set PerChannelQuantData
+// based on the weights tensor, which is utilized while preprocessing bias.
+
+TfLiteStatus ProcessPerChannelQuantizedWeights(
+    const TfLiteIntArray* inputs, const TfLiteIntArray* outputs,
+    TfLiteContext* context, float* weights_min, float* weights_max,
+    GraphBuilder* graph_builder, PerChannelQuantData* per_channel_quant);
+
+TfLiteStatus ProcessPerChannelQuantizedBias(
+    const TfLiteIntArray* inputs, const TfLiteIntArray* outputs,
+    TfLiteContext* context, float* bias_min, float* bias_max,
+    GraphBuilder* graph_builder, PerChannelQuantData* per_channel_quant,
+    OpBuilder** bias_const_node = nullptr);
+
 }  // namespace hexagon
 }  // namespace delegates
 }  // namespace tflite
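Note on the header above: the comment spells out the required call order, since ProcessPerChannelQuantizedWeights fills PerChannelQuantData from the weights tensor and ProcessPerChannelQuantizedBias then consumes those scales. A minimal caller-side sketch of that order, assuming it runs inside an op builder where inputs, outputs, context and graph_builder_ are in scope (illustrative only, not code from this change):

  PerChannelQuantData per_channel_quant;
  float weights_min = 0, weights_max = 0;
  // Weights first: this populates per_channel_quant.scales_data and
  // per_channel_quant.channel_scales_node.
  TF_LITE_ENSURE_STATUS(ProcessPerChannelQuantizedWeights(
      inputs, outputs, context, &weights_min, &weights_max, graph_builder_,
      &per_channel_quant));
  // Bias second: it reads the scales stored above and can hand back the
  // const node holding the preprocessed bias data.
  float bias_min = 0, bias_max = 0;
  OpBuilder* bias_const = nullptr;
  TF_LITE_ENSURE_STATUS(ProcessPerChannelQuantizedBias(
      inputs, outputs, context, &bias_min, &bias_max, graph_builder_,
      &per_channel_quant, &bias_const));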
@@ -38,25 +38,27 @@ constexpr float kHexagonMinRelativeScale = 0.0009766f;

 }  // namespace

-TfLiteStatus Conv2dOpBuilder::ProcessPerChannelQuantizedWeights(
+TfLiteStatus ProcessPerChannelQuantizedWeights(
     const TfLiteIntArray* inputs, const TfLiteIntArray* outputs,
-    TfLiteContext* context, float* weights_min, float* weights_max) {
+    TfLiteContext* context, float* weights_min, float* weights_max,
+    GraphBuilder* graph_builder, PerChannelQuantData* per_channel_quant) {
+  if (!per_channel_quant) return kTfLiteError;
   const auto& weights_tensor = context->tensors[inputs->data[1]];
   TfLiteAffineQuantization* weights_quant_params =
       reinterpret_cast<TfLiteAffineQuantization*>(
           weights_tensor.quantization.params);

   // Retrieve channel scales.
-  num_scale_values_ = weights_quant_params->scale->size;
+  per_channel_quant->num_scale_values = weights_quant_params->scale->size;
   // Normalize the scales as expected by Hexagon.
-  scales_data_ = weights_quant_params->scale->data;
+  per_channel_quant->scales_data = weights_quant_params->scale->data;
   std::vector<float> normalized_scales;
-  normalized_scales.reserve(num_scale_values_);
+  normalized_scales.reserve(per_channel_quant->num_scale_values);
   float scale_max = 0.0;
-  for (int i = 0; i < num_scale_values_; ++i) {
-    normalized_scales.push_back(scales_data_[i]);
-    if (scales_data_[i] > scale_max) {
-      scale_max = scales_data_[i];
+  for (int i = 0; i < per_channel_quant->num_scale_values; ++i) {
+    normalized_scales.push_back(per_channel_quant->scales_data[i]);
+    if (per_channel_quant->scales_data[i] > scale_max) {
+      scale_max = per_channel_quant->scales_data[i];
     }
   }
   if (scale_max == 0.0) {
@@ -64,13 +66,14 @@ TfLiteStatus Conv2dOpBuilder::ProcessPerChannelQuantizedWeights(
                        weights_tensor.name);
     return kTfLiteError;
   }
-  for (int i = 0; i < num_scale_values_; ++i) {
+  for (int i = 0; i < per_channel_quant->num_scale_values; ++i) {
     normalized_scales[i] =
         std::max(normalized_scales[i] / scale_max, kHexagonMinRelativeScale);
   }
   // Add node for channel scales data.
-  const std::vector<int> scales_shape = {1, 1, 1, num_scale_values_};
-  channel_scales_node_ = graph_builder_->AddConstNodeWithData(
+  const std::vector<int> scales_shape = {1, 1, 1,
+                                         per_channel_quant->num_scale_values};
+  per_channel_quant->channel_scales_node = graph_builder->AddConstNodeWithData(
       scales_shape.data(), reinterpret_cast<char*>(normalized_scales.data()),
       normalized_scales.size() * sizeof(normalized_scales[0]));
   *weights_min = -128 * scale_max;
@@ -78,6 +81,60 @@ TfLiteStatus Conv2dOpBuilder::ProcessPerChannelQuantizedWeights(
   return kTfLiteOk;
 }

+TfLiteStatus ProcessPerChannelQuantizedBias(
+    const TfLiteIntArray* inputs, const TfLiteIntArray* outputs,
+    TfLiteContext* context, float* bias_min, float* bias_max,
+    GraphBuilder* graph_builder, PerChannelQuantData* per_channel_quant,
+    OpBuilder** bias_const_node) {
+  const auto& bias_tensor = context->tensors[inputs->data[2]];
+
+  const TfLiteAffineQuantization* input_quant_params =
+      static_cast<const TfLiteAffineQuantization*>(
+          context->tensors[inputs->data[0]].quantization.params);
+  const float input_scale = input_quant_params->scale->data[0];
+  // Now dequantize bias values to float first, to adjust for the
+  // normalization of channel scales.
+  auto* bias_data = bias_tensor.data.i32;
+  const int bias_size = NumElements(&bias_tensor);
+  if (bias_size != per_channel_quant->num_scale_values) {
+    TF_LITE_KERNEL_LOG(
+        context, "Bias/channel scales number mismatch for bias tensor: %s",
+        bias_tensor.name);
+    return kTfLiteError;
+  }
+  std::vector<float> dequantized_bias;
+  dequantized_bias.reserve(bias_size);
+  for (int i = 0; i < bias_size; ++i) {
+    const float dequantized_value =
+        bias_data[i] * input_scale * per_channel_quant->scales_data[i];
+    const float abs_dequantized_value = std::abs(dequantized_value);
+    if (abs_dequantized_value > *bias_max) {
+      *bias_max = abs_dequantized_value;
+    }
+    dequantized_bias.push_back(dequantized_value);
+  }
+  *bias_max = *bias_max * 8;
+  *bias_min = -1 * *bias_max;
+  // Now requantize the bias values to the new min/max values.
+  std::vector<int> preprocessed_bias_data;
+  preprocessed_bias_data.reserve(per_channel_quant->num_scale_values);
+  for (int i = 0; i < bias_size; ++i) {
+    preprocessed_bias_data.push_back(static_cast<int>(
+        std::round(std::pow(2, 31) * (dequantized_bias[i] / *bias_max))));
+  }
+  // Add nodes for bias.
+  const std::vector<int> bias_shape = {1, 1, 1, bias_size};
+  auto* bias_data_node = graph_builder->AddConstNodeWithData(
+      bias_shape.data(), reinterpret_cast<char*>(preprocessed_bias_data.data()),
+      preprocessed_bias_data.size() * sizeof(preprocessed_bias_data[0]));
+  if (bias_const_node) {
+    *bias_const_node = bias_data_node;
+  }
+  graph_builder->AddTensorWithID(inputs->data[2], bias_data_node->GetID(), 0,
+                                 /*overwrite=*/true);
+  return kTfLiteOk;
+}
+
 TfLiteStatus Conv2dOpBuilder::InitializeWeightsNodes(
     const TfLiteIntArray* inputs, const TfLiteIntArray* outputs,
     TfLiteContext* context, const int input_depth) {
@@ -174,7 +231,8 @@ TfLiteStatus Conv2dOpBuilder::InitializeWeightsNodes(
   float weights_max = 0;
   if (is_per_channel_quant) {
     ProcessPerChannelQuantizedWeights(inputs, outputs, context, &weights_min,
-                                      &weights_max);
+                                      &weights_max, graph_builder_,
+                                      &per_channel_quant_);
   } else {
     TF_LITE_ENSURE_STATUS(ComputeMinAndMaxQuantValues(
         weights_tensor, &weights_min, &weights_max));
@@ -189,55 +247,6 @@ TfLiteStatus Conv2dOpBuilder::InitializeWeightsNodes(
   return kTfLiteOk;
 }

-TfLiteStatus Conv2dOpBuilder::ProcessPerChannelQuantizedBias(
-    const TfLiteIntArray* inputs, const TfLiteIntArray* outputs,
-    TfLiteContext* context, float* bias_min, float* bias_max) {
-  const auto& bias_tensor = context->tensors[inputs->data[2]];
-
-  const TfLiteAffineQuantization* input_quant_params =
-      static_cast<const TfLiteAffineQuantization*>(
-          context->tensors[inputs->data[0]].quantization.params);
-  const float input_scale = input_quant_params->scale->data[0];
-  // Now dequantize bias values to float first, to adjust for the
-  // normalization of channel scales.
-  auto* bias_data = bias_tensor.data.i32;
-  const int bias_size = NumElements(&bias_tensor);
-  if (bias_size != num_scale_values_) {
-    TF_LITE_KERNEL_LOG(
-        context, "Bias/channel scales number mismatch for bias tensor: %s",
-        bias_tensor.name);
-    return kTfLiteError;
-  }
-  std::vector<float> dequantized_bias;
-  dequantized_bias.reserve(bias_size);
-  for (int i = 0; i < bias_size; ++i) {
-    const float dequantized_value =
-        bias_data[i] * input_scale * scales_data_[i];
-    const float abs_dequantized_value = std::abs(dequantized_value);
-    if (abs_dequantized_value > *bias_max) {
-      *bias_max = abs_dequantized_value;
-    }
-    dequantized_bias.push_back(dequantized_value);
-  }
-  *bias_max = *bias_max * 8;
-  *bias_min = -1 * *bias_max;
-  // Now requantize the bias values to the new min/max values.
-  std::vector<int> preprocessed_bias_data;
-  preprocessed_bias_data.reserve(num_scale_values_);
-  for (int i = 0; i < bias_size; ++i) {
-    preprocessed_bias_data.push_back(static_cast<int>(
-        std::round(std::pow(2, 31) * (dequantized_bias[i] / *bias_max))));
-  }
-  // Add nodes for bias.
-  const std::vector<int> bias_shape = {1, 1, 1, bias_size};
-  auto* bias_data_node = graph_builder_->AddConstNodeWithData(
-      bias_shape.data(), reinterpret_cast<char*>(preprocessed_bias_data.data()),
-      preprocessed_bias_data.size() * sizeof(preprocessed_bias_data[0]));
-  graph_builder_->AddTensorWithID(inputs->data[2], bias_data_node->GetID(), 0,
-                                  /*overwrite=*/true);
-  return kTfLiteOk;
-}
-
 TfLiteStatus Conv2dOpBuilder::InitializeBiasNodes(const TfLiteIntArray* inputs,
                                                   const TfLiteIntArray* outputs,
                                                   TfLiteContext* context) {
@@ -247,9 +256,10 @@ TfLiteStatus Conv2dOpBuilder::InitializeBiasNodes(const TfLiteIntArray* inputs,

   float bias_min = 0;
   float bias_max = 0;
-  if (channel_scales_node_ != nullptr) {
+  if (per_channel_quant_.channel_scales_node != nullptr) {
     ProcessPerChannelQuantizedBias(inputs, outputs, context, &bias_min,
-                                   &bias_max);
+                                   &bias_max, graph_builder_,
+                                   &per_channel_quant_);
   } else {
     auto* bias_data_node =
         graph_builder_->AddConstNodeWithData(inputs->data[2], bias_tensor);
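Note on the bias path above: each int32 bias value is dequantized with input_scale * channel_scale[i], the symmetric range [bias_min, bias_max] is set to eight times the largest magnitude, and the values are requantized against 2^31 within that range. A self-contained sketch of just that arithmetic (the function name and shape are illustrative, not part of this change; it assumes at least one non-zero bias value so bias_max ends up non-zero):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<int32_t> RequantizeBiasForHexagon(
    const std::vector<int32_t>& bias, float input_scale,
    const std::vector<float>& channel_scales, float* bias_min,
    float* bias_max) {
  // Dequantize each value with its per-channel effective scale and track the
  // largest magnitude seen.
  std::vector<float> dequantized(bias.size());
  *bias_max = 0.0f;
  for (size_t i = 0; i < bias.size(); ++i) {
    dequantized[i] = bias[i] * input_scale * channel_scales[i];
    *bias_max = std::max(*bias_max, std::abs(dequantized[i]));
  }
  // Symmetric range with 8x headroom, then map it onto the int32 range.
  *bias_max *= 8;
  *bias_min = -*bias_max;
  std::vector<int32_t> requantized(bias.size());
  for (size_t i = 0; i < bias.size(); ++i) {
    requantized[i] = static_cast<int32_t>(
        std::round(std::pow(2, 31) * (dequantized[i] / *bias_max)));
  }
  return requantized;
}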
@@ -27,7 +27,8 @@ class QuantizedTransposeConvOpModel : public SingleOpModelWithHexagon {
                                 std::initializer_list<InputType> filter_data,
                                 const TensorData& input,
                                 const TensorData& output, Padding padding,
-                                int stride_w, int stride_h) {
+                                int stride_w, int stride_h,
+                                bool add_bias = false) {
     // Just to be confusing, transpose_conv has an _input_ named "output_shape"
     // that sets the shape of the output tensor of the op :). It must always be
     // an int32 1D four element tensor.
@@ -35,6 +36,39 @@ class QuantizedTransposeConvOpModel : public SingleOpModelWithHexagon {
     filter_ = AddConstInput(filter, filter_data);
     input_ = AddInput(input);

+    if (add_bias) {
+      int bias_size = GetShape(filter_)[0];
+      if (input.type == TensorType_INT8) {
+        // per channel quantization.
+        std::vector<float> bias_scale(
+            filter.per_channel_quantization_scales.size());
+        std::vector<int64_t> bias_zero_points(
+            filter.per_channel_quantization_scales.size());
+        for (size_t i = 0; i < filter.per_channel_quantization_scales.size();
+             ++i) {
+          bias_scale[i] =
+              input.scale * filter.per_channel_quantization_scales[i];
+          bias_zero_points[i] = 0;
+        }
+        TensorData bias{TensorType_INT32,
+                        {bias_size},
+                        /*min=*/0,
+                        /*max=*/0,
+                        /*scale=*/0,
+                        /*zero_point=*/0,
+                        true,
+                        /*per_channel_quantization_scales=*/bias_scale,
+                        /*per_channel_quantization_offsets=*/bias_zero_points,
+                        /*channel_index=*/0};
+        bias_ = AddInput(bias);
+      } else {
+        // per tensor quantization.
+        auto bias_scale = GetScale(input_) * GetScale(filter_);
+        TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale};
+        bias_ = AddInput(bias);
+      }
+    }
+
     output_ = AddOutput(output);

     SetBuiltinOp(
@@ -49,6 +83,17 @@ class QuantizedTransposeConvOpModel : public SingleOpModelWithHexagon {
     QuantizeAndPopulate<InputType>(input_, data);
   }

+  void SetBias(std::initializer_list<float> bias) {
+    if (std::is_same<InputType, uint8_t>::value) {
+      QuantizeAndPopulate<int32_t>(bias_, bias);
+    } else if (std::is_same<InputType, int8_t>::value) {
+      PerChannelQuantizeBias(bias_, bias);
+    }
+    // Set allocation type to MmapRo to simulate a 'constant' tensor.
+    auto* bias_tensor = interpreter_->tensor(bias_);
+    bias_tensor->allocation_type = kTfLiteMmapRo;
+  }
+
   std::vector<float> GetDequantizedOutput() {
     return Dequantize<InputType>(ExtractVector<InputType>(output_),
                                  GetScale(output_), GetZeroPoint(output_));
@@ -60,6 +105,7 @@ class QuantizedTransposeConvOpModel : public SingleOpModelWithHexagon {
   int output_shape_;
   int filter_;
   int input_;
+  int bias_;
   int output_;
 };

@@ -183,4 +229,45 @@ TEST(QuantizedTransposeConvOpModel, TestQuantizedPerChannelMultiChannel) {
   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 5, 5, 2}));
 }

+TEST(QuantizedTransposeConvOpModel, SimpleBiasQuantized) {
+  const std::initializer_list<uint8_t> filter_data = {129, 131, 133, 135, 137,
+                                                      139, 141, 143, 145};
+  auto model = QuantizedTransposeConvOpModel<uint8_t>(
+      {1, 4, 4, 1}, {TensorType_UINT8, {1, 3, 3, 1}, -63.5, 64}, filter_data,
+      {TensorType_UINT8, {1, 4, 4, 1}, -63.5, 64},
+      {TensorType_UINT8, {}, -508, 512}, Padding_SAME, 1, 1,
+      /*add_bias=*/true);
+  model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
+  model.SetBias({1});
+  model.ApplyDelegateAndInvoke();
+
+  EXPECT_THAT(
+      model.GetDequantizedOutput(),
+      ElementsAreArray(ArrayFloatNear({32, 64, 84, 76, 100, 192, 240, 200, 208,
+                                       372, 420, 332, 264, 448, 488, 368},
+                                      1e-5)));
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
+}
+
+TEST(QuantizedTransposeConvOpModel, PerChannelQuantizedBias) {
+  const std::initializer_list<int8_t> filter_data = {14, 28, 42, 56, 70,
+                                                     84, 98, 112, 126};
+  auto model = QuantizedTransposeConvOpModel<int8_t>(
+      {1, 4, 4, 1},
+      {TensorType_INT8, {1, 3, 3, 1}, 0, 0, 0, 0, true, {9.0 / 127}, {0}, 0},
+      filter_data, {TensorType_INT8, {1, 4, 4, 1}, 0, 0, 16.0 / 255, -128},
+      {TensorType_INT8, {}, 0, 0, 2, -128}, Padding_SAME, 1, 1,
+      /*add_bias=*/true);
+  model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
+  model.SetBias({1});
+  model.ApplyDelegateAndInvoke();
+
+  EXPECT_THAT(
+      model.GetDequantizedOutput(),
+      ElementsAreArray(ArrayFloatNear({30, 62, 84, 76, 100, 192, 236, 198, 206,
+                                       370, 414, 328, 262, 442, 482, 362},
+                                      1e-5)));
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
+}
+
 }  // namespace tflite
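As a rough check of the numbers in SimpleBiasQuantized: the uint8 input and filter both use min -63.5 and max 64, i.e. a scale of (64 - (-63.5)) / 255 = 0.5, so the per-tensor bias scale is 0.5 * 0.5 = 0.25 and SetBias({1}) should store the quantized value round(1 / 0.25) = 4. PerChannelQuantizedBias applies the same rule per output channel, with bias_scale[i] = input.scale * filter.per_channel_quantization_scales[i].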
@@ -36,46 +36,6 @@ constexpr float kHexagonMinRelativeScale = 0.0009766f;

 }  // namespace

-TfLiteStatus TransposeConv2dOpBuilder::ProcessPerChannelQuantizedWeights(
-    const TfLiteIntArray* inputs, const TfLiteIntArray* outputs,
-    TfLiteContext* context, float* weights_min, float* weights_max) {
-  const auto& weights_tensor = context->tensors[inputs->data[1]];
-  TfLiteAffineQuantization* weights_quant_params =
-      reinterpret_cast<TfLiteAffineQuantization*>(
-          weights_tensor.quantization.params);
-
-  // Retrieve channel scales.
-  num_scale_values_ = weights_quant_params->scale->size;
-  // Normalize the scales as expected by Hexagon.
-  scales_data_ = weights_quant_params->scale->data;
-  std::vector<float> normalized_scales;
-  normalized_scales.reserve(num_scale_values_);
-  float scale_max = 0.0;
-  for (int i = 0; i < num_scale_values_; ++i) {
-    normalized_scales.push_back(scales_data_[i]);
-    if (scales_data_[i] > scale_max) {
-      scale_max = scales_data_[i];
-    }
-  }
-  if (scale_max == 0.0) {
-    TF_LITE_KERNEL_LOG(context, "Scale max is zero for: %s",
-                       weights_tensor.name);
-    return kTfLiteError;
-  }
-  for (int i = 0; i < num_scale_values_; ++i) {
-    normalized_scales[i] =
-        std::max(normalized_scales[i] / scale_max, kHexagonMinRelativeScale);
-  }
-  // Add node for channel scales data.
-  const std::vector<int> scales_shape = {1, 1, 1, num_scale_values_};
-  channel_scales_node_ = graph_builder_->AddConstNodeWithData(
-      scales_shape.data(), reinterpret_cast<char*>(normalized_scales.data()),
-      normalized_scales.size() * sizeof(normalized_scales[0]));
-  *weights_min = -128 * scale_max;
-  *weights_max = 127 * scale_max;
-  return kTfLiteOk;
-}
-
 TfLiteStatus TransposeConv2dOpBuilder::PopulateSubGraph(
     const TfLiteIntArray* inputs, const TfLiteIntArray* outputs,
     TfLiteContext* context) {
@@ -111,7 +71,8 @@ TfLiteStatus TransposeConv2dOpBuilder::PopulateSubGraph(
   float weights_max = 0;
   if (is_per_channel_quant) {
     ProcessPerChannelQuantizedWeights(inputs, outputs, context, &weights_min,
-                                      &weights_max);
+                                      &weights_max, graph_builder_,
+                                      &per_channel_quant_);
   } else {
     TF_LITE_ENSURE_STATUS(ComputeMinAndMaxQuantValues(
         weights_tensor, &weights_min, &weights_max));
@@ -160,18 +121,48 @@ TfLiteStatus TransposeConv2dOpBuilder::PopulateSubGraph(
   AddInput(TensorID(stride_node->GetID(), 0));

   // BIAS.
-  // TFLite's TransposeConv doesn't have a bias input, so we just feed in 0s.
-  std::vector<int> bias_data(output_depth_size, 0);
-  // Hexagon's conv ops require bias as a [1, 1, 1, dout] tensor.
-  bias_shape_ = {1, 1, 1, output_depth_size};
-  auto* bias_const = graph_builder_->AddConstNodeWithData(
-      bias_shape_.data(), reinterpret_cast<char*>(bias_data.data()),
-      sizeof(bias_data[0]) * bias_data.size());
-  float zero_bound = 0;
-  auto* bias_min_const = graph_builder_->AddConstNodeWithData(
-      kScalarShape, reinterpret_cast<char*>(&zero_bound), sizeof(zero_bound));
-  auto* bias_max_const = graph_builder_->AddConstNodeWithData(
-      kScalarShape, reinterpret_cast<char*>(&zero_bound), sizeof(zero_bound));
+  const bool has_bias = inputs->size == 4;
+  OpBuilder* bias_const = nullptr;
+  OpBuilder* bias_min_const = nullptr;
+  OpBuilder* bias_max_const = nullptr;
+  if (!has_bias) {
+    // If the TFLite node does not have a bias, we simply feed in 0s.
+    std::vector<int> bias_data(output_depth_size, 0);
+    bias_shape_ = {1, 1, 1, output_depth_size};
+    bias_const = graph_builder_->AddConstNodeWithData(
+        bias_shape_.data(), reinterpret_cast<char*>(bias_data.data()),
+        sizeof(bias_data[0]) * bias_data.size());
+    float zero_bound = 0;
+    bias_min_const = graph_builder_->AddConstNodeWithData(
+        kScalarShape, reinterpret_cast<char*>(&zero_bound), sizeof(zero_bound));
+    bias_max_const = graph_builder_->AddConstNodeWithData(
+        kScalarShape, reinterpret_cast<char*>(&zero_bound), sizeof(zero_bound));
+  } else {
+    const auto& bias_tensor = context->tensors[inputs->data[3]];
+    if (bias_tensor.allocation_type != kTfLiteMmapRo) {
+      TF_LITE_KERNEL_LOG(context,
+                         "Bias tensor doesn't have correct allocation type: %s",
+                         bias_tensor.name);
+      return kTfLiteError;
+    }
+    float bias_min = 0;
+    float bias_max = 0;
+    if (per_channel_quant_.channel_scales_node != nullptr) {
+      ProcessPerChannelQuantizedBias(inputs, outputs, context, &bias_min,
+                                     &bias_max, graph_builder_,
+                                     &per_channel_quant_, &bias_const);
+    } else {
+      bias_const =
+          graph_builder_->AddConstNodeWithData(inputs->data[3], bias_tensor);
+      TF_LITE_ENSURE_STATUS(
+          ComputeMinAndMaxQuantValues(bias_tensor, &bias_min, &bias_max));
+    }
+
+    bias_min_const = graph_builder_->AddConstNodeWithData(
+        kScalarShape, reinterpret_cast<char*>(&bias_min), sizeof(bias_min));
+    bias_max_const = graph_builder_->AddConstNodeWithData(
+        kScalarShape, reinterpret_cast<char*>(&bias_max), sizeof(bias_max));
+  }
   AddInput(TensorID(bias_const->GetID(), 0));
   AddInput(TensorID(bias_min_const->GetID(), 0));
   AddInput(TensorID(bias_max_const->GetID(), 0));
@@ -181,8 +172,8 @@ TfLiteStatus TransposeConv2dOpBuilder::PopulateSubGraph(
       ComputeAndAddMinAndMax(context, context->tensors[outputs->data[0]]));

   // Channel scales, if this op is per-channel quantized.
-  if (channel_scales_node_ != nullptr) {
-    AddInput(TensorID(channel_scales_node_->GetID(), 0));
+  if (per_channel_quant_.channel_scales_node != nullptr) {
+    AddInput(TensorID(per_channel_quant_.channel_scales_node->GetID(), 0));
   }

   // Hexagon outputs for this node.
@@ -17,6 +17,7 @@ limitations under the License.

 #include <vector>

+#include "tensorflow/lite/delegates/hexagon/builders/conv_2d_builder.h"
 #include "tensorflow/lite/delegates/hexagon/builders/op_builder.h"

 namespace tflite {
@@ -34,26 +35,16 @@ class TransposeConv2dOpBuilder : public OpBuilder {
   TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs,
                                TfLiteContext* context) override;

-  ~TransposeConv2dOpBuilder();
+  ~TransposeConv2dOpBuilder() override;

  private:
-  // TODO(b/142009955): Combine into common util for all types of Conv.
-  TfLiteStatus ProcessPerChannelQuantizedWeights(const TfLiteIntArray* inputs,
-                                                 const TfLiteIntArray* outputs,
-                                                 TfLiteContext* context,
-                                                 float* weights_min,
-                                                 float* weights_max);
-
   TensorID node_output_;
   std::vector<float> transposed_weights_;
   std::vector<int> stride_shape_;
   std::vector<int> bias_shape_;
   std::vector<int> bias_data_;

-  // Non-null only if node has per-channel quantized weights/biases.
-  OpBuilder* channel_scales_node_ = nullptr;
-  float* scales_data_ = nullptr;
-  int num_scale_values_ = 1;
+  // Modified only if node has per-channel quantized weights/biases.
+  PerChannelQuantData per_channel_quant_;
 };

 }  // namespace hexagon
@@ -247,11 +247,22 @@ bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration,
               pool_params->activation == kTfLiteActNone);
     }
     case kTfLiteBuiltinTransposeConv: {
-      if (!InputsWithCorrectTypes(node, context,
-                                  {{kTfLiteInt32},
-                                   {kTfLiteUInt8, kTfLiteInt8},
-                                   {kTfLiteUInt8, kTfLiteInt8}}))
+      if (NumInputs(node) == 3) {
+        if (!InputsWithCorrectTypes(node, context,
+                                    {{kTfLiteInt32},
+                                     {kTfLiteUInt8, kTfLiteInt8},
+                                     {kTfLiteUInt8, kTfLiteInt8}}))
          return false;
+      } else if (NumInputs(node) == 4) {
+        if (!InputsWithCorrectTypes(node, context,
+                                    {{kTfLiteInt32},
+                                     {kTfLiteUInt8, kTfLiteInt8},
+                                     {kTfLiteUInt8, kTfLiteInt8},
+                                     {kTfLiteInt32}}))
+          return false;
+      } else {
+        return false;
+      }
       const TfLiteTransposeConvParams* params =
           reinterpret_cast<const TfLiteTransposeConvParams*>(
               node->builtin_data);
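For reference, the two accepted arities follow TFLite's TransposeConv input layout: index 0 is the int32 output_shape tensor, index 1 the uint8/int8 weights, index 2 the uint8/int8 activation input, and the optional index 3 is the int32 bias that this change starts delegating; any other input count is rejected.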