Add int8 transpose conv.
PiperOrigin-RevId: 281565221
Change-Id: I984c4e7e4dbb30a872c63778e52eac0bd91fd999

commit a7de1a78fe
parent 98c6b54f1c
@@ -421,6 +421,7 @@ cc_library(
         "reference/integer_ops/pooling.h",
         "reference/integer_ops/softmax.h",
         "reference/integer_ops/tanh.h",
+        "reference/integer_ops/transpose_conv.h",
         "reference/logistic.h",
         "reference/maximum_minimum.h",
         "reference/mul.h",
@@ -0,0 +1,118 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_TRANSPOSE_CONV_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_TRANSPOSE_CONV_H_
+
+#include "tensorflow/lite/kernels/internal/common.h"
+
+namespace tflite {
+namespace reference_integer_ops {
+
+// Fixed-point per-channel-quantization transpose convolution reference kernel.
+inline void TransposeConv(
+    const ConvParams& params, const int32* output_multiplier,
+    const int32* output_shift, const RuntimeShape& input_shape,
+    const int8* input_data, const RuntimeShape& filter_shape,
+    const int8* filter_data, const RuntimeShape& output_shape,
+    int8* output_data, const RuntimeShape& im2col_shape, int8* im2col_data,
+    int32* scratch_buffer) {
+  const int stride_width = params.stride_width;
+  const int stride_height = params.stride_height;
+  const int pad_width = params.padding_values.width;
+  const int pad_height = params.padding_values.height;
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+  (void)im2col_data;   // only used in optimized code.
+  (void)im2col_shape;  // only used in optimized code.
+
+  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+  const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+  const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
+  const int input_height = input_shape.Dims(1);
+  const int input_width = input_shape.Dims(2);
+  const int filter_height = filter_shape.Dims(1);
+  const int filter_width = filter_shape.Dims(2);
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
+  const int32 input_offset = params.input_offset;
+  const int32 output_offset = params.output_offset;
+  const int32 output_activation_min = std::numeric_limits<int8_t>::min();
+  const int32 output_activation_max = std::numeric_limits<int8_t>::max();
+  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+
+  const int num_elements = output_shape.FlatSize();
+  // We need to initialize scratch_buffer to all 0s, as we apply the same
+  // 'scatter' based trick as in float version.
+  memset(scratch_buffer, 0, num_elements * sizeof(int32));
+
+  // Loop through input elements one at a time.
+  for (int batch = 0; batch < batches; ++batch) {
+    for (int in_y = 0; in_y < input_height; ++in_y) {
+      for (int in_x = 0; in_x < input_width; ++in_x) {
+        for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
+          // Loop through the output elements it will influence.
+          const int out_x_origin = (in_x * stride_width) - pad_width;
+          const int out_y_origin = (in_y * stride_height) - pad_height;
+          for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+            for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+              for (int out_channel = 0; out_channel < output_depth;
+                   ++out_channel) {
+                // Compute output element location.
+                const int out_x = out_x_origin + filter_x;
+                const int out_y = out_y_origin + filter_y;
+                // We cannot accumulate out of bounds.
+                if ((out_x >= 0) && (out_x < output_width) && (out_y >= 0) &&
+                    (out_y < output_height)) {
+                  const int8 input_value = input_data[Offset(
+                      input_shape, batch, in_y, in_x, in_channel)];
+                  const int8 filter_value =
+                      filter_data[Offset(filter_shape, out_channel, filter_y,
+                                         filter_x, in_channel)];
+                  scratch_buffer[Offset(output_shape, batch, out_y, out_x,
+                                        out_channel)] +=
+                      (input_value + input_offset) * filter_value;
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
+  for (int batch = 0; batch < batches; ++batch) {
+    for (int out_y = 0; out_y < output_height; ++out_y) {
+      for (int out_x = 0; out_x < output_width; ++out_x) {
+        for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
+          int32 acc = scratch_buffer[Offset(output_shape, batch, out_y, out_x,
+                                            out_channel)];
+          acc = MultiplyByQuantizedMultiplier(
+              acc, output_multiplier[out_channel], output_shift[out_channel]);
+          acc += output_offset;
+          acc = std::max(acc, output_activation_min);
+          acc = std::min(acc, output_activation_max);
+          output_data[Offset(output_shape, batch, out_y, out_x, out_channel)] =
+              static_cast<int8_t>(acc);
+        }
+      }
+    }
+  }
+}
+
+}  // namespace reference_integer_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_TRANSPOSE_CONV_H_
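The new kernel runs in two passes: a "scatter" pass that accumulates (input_value + input_offset) * filter_value products into an int32 scratch buffer, then a requantization pass that scales each accumulator down to int8. As a minimal illustration of the scatter pass, here is a standalone sketch of my own (not part of the commit), reduced to 1-D with stride 2 and no padding:

// Standalone 1-D reduction of the scatter pass above (illustration only):
// every input element is multiplied by the whole filter and accumulated
// into the stride-spaced output window it influences.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<int32_t> input = {1, 2, 3};
  const std::vector<int32_t> filter = {1, 10, 100};
  const int stride = 2;
  // With no padding, output length is (input - 1) * stride + filter.
  std::vector<int32_t> scratch((input.size() - 1) * stride + filter.size(), 0);
  for (size_t in_x = 0; in_x < input.size(); ++in_x) {
    const size_t out_origin = in_x * stride;  // cf. out_x_origin in the kernel
    for (size_t f = 0; f < filter.size(); ++f) {
      scratch[out_origin + f] += input[in_x] * filter[f];  // scatter-accumulate
    }
  }
  for (int32_t v : scratch) std::printf("%d ", v);  // 1 10 102 20 203 30 300
  std::printf("\n");
  return 0;
}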
@@ -205,7 +205,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
              /* max_version */ 3);
   AddBuiltin(BuiltinOperator_SIN, Register_SIN());
   AddBuiltin(BuiltinOperator_COS, Register_COS());
-  AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV());
+  AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV(),
+             /* min_version */ 1,
+             /* max_version */ 2);
   AddBuiltin(BuiltinOperator_TILE, Register_TILE());
   AddBuiltin(BuiltinOperator_SUM, Register_SUM(),
              /* min_version */ 1,
@@ -168,7 +168,12 @@ class SingleOpModel {
   // Templated version of AddConstInput().
   template <typename T>
   int AddConstInput(const TensorData& t, std::initializer_list<T> data) {
-    int id = AddTensor(t, data);
+    int id = 0;
+    if (t.per_channel_quantization) {
+      id = AddTensorPerChannelQuant(t);
+    } else {
+      id = AddTensor(t, data);
+    }
     inputs_.push_back(id);
     return id;
   }
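For context, TensorData is initialized positionally in the tests, so the per-channel fields are easy to misread. The annotation below is my reading of the field order as used by the new per-channel tests later in this change (names are illustrative labels, not required keywords):

// Annotated copy of a TensorData literal from the multi-channel test below:
// an int8 filter with two output channels, symmetrically quantized per channel.
TensorData filter = {TensorType_INT8,
                     /*shape=*/{2, 3, 3, 1},
                     /*min=*/0, /*max=*/0,
                     /*scale=*/0, /*zero_point=*/0,
                     /*per_channel_quantization=*/true,
                     /*per_channel_quantization_scales=*/{17.0 / 127, 18.0 / 127},
                     /*per_channel_quantization_offsets=*/{0, 0},
                     /*channel_index=*/0};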
@@ -24,6 +24,8 @@ limitations under the License.
 #include "tensorflow/lite/kernels/cpu_backend_context.h"
 #include "tensorflow/lite/kernels/eigen_support.h"
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
+// NOLINTNEXTLINE - This header file shouldn't go to the top.
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h"
 #include "tensorflow/lite/kernels/internal/tensor.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/internal/types.h"
@@ -75,6 +77,12 @@ struct OpData {
   int32_t output_multiplier;
   int output_shift;
+
+  // Per channel output multiplier and shift.
+  // TODO(b/144846950): Add channel dimension index for the kernel to be more
+  // flexible.
+  std::vector<int32_t> per_channel_output_multiplier;
+  std::vector<int32_t> per_channel_output_shift;
 
   // The range of the fused activation layer. For example for kNone and
   // uint8_t these would be 0 and 255.
   int32_t output_activation_min;
@@ -144,7 +152,7 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
   }
 
   // Allocate scratch buffer tensor for UInt8 inputs.
-  if (input_type == kTfLiteUInt8) {
+  if (input_type == kTfLiteUInt8 || input_type == kTfLiteInt8) {
     if (data->scratch_tensor_id == kTensorNotAllocated) {
       context->AddTensors(context, 1, &data->scratch_tensor_id);
     }
@@ -214,6 +222,11 @@ TfLiteStatus ResizeAndTransposeWeights(TfLiteContext* context,
                              GetTensorData<uint8>(weights),
                              GetTensorShape(transposed_weights),
                              GetTensorData<uint8>(transposed_weights));
+  } else if (weights->type == kTfLiteInt8) {
+    optimized_ops::Transpose(transpose_params, input_shape,
+                             GetTensorData<int8>(weights),
+                             GetTensorShape(transposed_weights),
+                             GetTensorData<int8>(transposed_weights));
   } else {
     context->ReportError(
         context, "Transpose conv only support float & uint8 right now.");
@@ -242,8 +255,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1);
   TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
   TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 4);
-  TF_LITE_ENSURE(context,
-                 input->type == kTfLiteFloat32 || input->type == kTfLiteUInt8);
+  TF_LITE_ENSURE(context, input->type == kTfLiteFloat32 ||
+                              input->type == kTfLiteUInt8 ||
+                              input->type == kTfLiteInt8);
   TF_LITE_ENSURE_EQ(context, weights->type, input->type);
   TF_LITE_ENSURE_EQ(context, output->type, input->type);
   // Ensure that weights and inputs have the same channel dimension.
@@ -288,7 +302,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     }
   }
 
-  if (input->type == kTfLiteUInt8) {
+  if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
     node->temporaries->data[data->scratch_tensor_index] =
         data->scratch_tensor_id;
     TfLiteTensor* scratch_buffer =
@@ -302,19 +316,24 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                         ResizeTensor(context, output_shape, scratch_buffer));
     }
-    // Calcuate output multiplier for quantization.
-    double real_multiplier = 0.0;
-    TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
-        context, input, weights, output, &real_multiplier));
-    int exponent;
-    // Populate quantization parameteters with multiplier and shift.
-    QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent);
-    data->output_shift = -exponent;
-    // Populate max and min activation range.
-    CalculateActivationRangeUint8(kTfLiteActNone, output,
-                                  &data->output_activation_min,
-                                  &data->output_activation_max);
+
+    TF_LITE_ENSURE_EQ(context, weights->quantization.type,
+                      kTfLiteAffineQuantization);
+    const auto* affine_quantization =
+        reinterpret_cast<TfLiteAffineQuantization*>(
+            weights->quantization.params);
+    TF_LITE_ENSURE(context, affine_quantization);
+    TF_LITE_ENSURE(context, affine_quantization->scale);
+    const int number_channel = affine_quantization->scale->size;
+    data->per_channel_output_multiplier.resize(number_channel);
+    data->per_channel_output_shift.resize(number_channel);
+    TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
+        context, input, weights, nullptr, output, kTfLiteActNone,
+        &data->output_multiplier, &data->output_shift,
+        &data->output_activation_min, &data->output_activation_max,
+        data->per_channel_output_multiplier.data(),
+        data->per_channel_output_shift.data()));
   }
 
   return kTfLiteOk;
 }
 
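PopulateConvolutionQuantizationParams fills one multiplier/shift pair per output channel, each encoding the real rescale factor input_scale * filter_scale[c] / output_scale as a Q31 integer multiplier plus a power-of-two shift. The following is a minimal sketch of that idea, assuming the usual frexp-style decomposition; TFLite's exact rounding and saturation differ:

// Sketch of fixed-point requantization in the spirit of QuantizeMultiplier /
// MultiplyByQuantizedMultiplier (illustration, not the real implementation).
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical per-channel factor: input_scale * filter_scale[c] / output_scale.
  const double real_multiplier = 0.0037;
  int shift;
  const double fraction = std::frexp(real_multiplier, &shift);  // in [0.5, 1)
  const int64_t multiplier = std::llround(fraction * (1LL << 31));  // Q31

  const int32_t acc = 12345;  // raw int32 convolution accumulator
  const int64_t prod = acc * multiplier;  // 63-bit product
  // Rounding right-shift by 31 applies the Q31 fraction...
  const int32_t high = static_cast<int32_t>((prod + (1LL << 30)) >> 31);
  // ...and the power-of-two exponent finishes the rescale.
  const int32_t result =
      shift >= 0 ? high << shift : (high + (1 << (-shift - 1))) >> -shift;
  std::printf("exact %.3f -> fixed-point %d\n", acc * real_multiplier, result);
  return 0;
}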
@@ -403,6 +422,39 @@ void EvalQuantized(TfLiteContext* context,
   }
 }
 
+void EvalQuantizedPerChannel(TfLiteContext* context,
+                             const TfLiteTransposeConvParams* params,
+                             OpData* data, const TfLiteTensor* input,
+                             const TfLiteTensor* weights,
+                             const TfLiteTensor* transposed_weights,
+                             TfLiteTensor* col2im, TfLiteTensor* output,
+                             TfLiteTensor* scratch_buffer) {
+  tflite::ConvParams op_params;
+  op_params.padding_type = PaddingType::kSame;
+  op_params.padding_values.width = data->padding.width;
+  op_params.padding_values.height = data->padding.height;
+  op_params.padding_values.width_offset = data->padding.width_offset;
+  op_params.padding_values.height_offset = data->padding.height_offset;
+  op_params.stride_width = params->stride_width;
+  op_params.stride_height = params->stride_height;
+  // Need to flip the sign of input offset to add it directly to the quantized
+  // buffer.
+  op_params.input_offset = -input->params.zero_point;
+  op_params.output_offset = output->params.zero_point;
+  op_params.quantized_activation_min = data->output_activation_min;
+  op_params.quantized_activation_max = data->output_activation_max;
+
+  // TODO(b/143380105): Need to add optimized kernel for int8 quantized
+  // transpose conv.
+  reference_integer_ops::TransposeConv(
+      op_params, data->per_channel_output_multiplier.data(),
+      data->per_channel_output_shift.data(), GetTensorShape(input),
+      GetTensorData<int8>(input), GetTensorShape(weights),
+      GetTensorData<int8>(weights), GetTensorShape(output),
+      GetTensorData<int8>(output), GetTensorShape(col2im),
+      GetTensorData<int8>(col2im), GetTensorData<int32_t>(scratch_buffer));
+}
+
 template <KernelType kernel_type>
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   // Retrieve tensors (All should be allocated by now)
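A note on the offset sign flip, with a small illustration of my own: int8 affine quantization defines real = scale * (q - zero_point), so storing input_offset = -zero_point lets the kernel form (input_value + input_offset) with a single add. Filters need no offset because per-channel quantization is symmetric (zero point 0).

// Why the sign flip works (illustration only).
#include <cstdint>
#include <cstdio>

int main() {
  const float input_scale = 0.5f;
  const int32_t input_zero_point = -10;
  const int8_t q = 4;                              // quantized input value
  const int32_t input_offset = -input_zero_point;  // what op_params carries
  // (q + input_offset) == (q - zero_point), so this dequantizes correctly.
  const float real = input_scale * (q + input_offset);
  std::printf("dequantized input = %f\n", real);  // 0.5 * (4 + 10) = 7.0
  return 0;
}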
@@ -473,6 +525,21 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
                     scratch_buffer);
       break;
     }
+    case kTfLiteInt8: {
+      TfLiteTensor* scratch_buffer =
+          GetTemporary(context, node, data->scratch_tensor_index);
+      if (IsDynamicTensor(scratch_buffer)) {
+        TF_LITE_ENSURE_OK(context,
+                          ResizeTensor(context, output_shape, scratch_buffer));
+      }
+      if (data->weights_are_transposed && !IsConstantTensor(weights)) {
+        ResizeAndTransposeWeights(context, weights, transposed_weights);
+      }
+      EvalQuantizedPerChannel(context, params, data, input, weights,
+                              transposed_weights, col2im, output,
+                              scratch_buffer);
+      break;
+    }
     default:
       context->ReportError(context, "Type '%s' is not currently supported.",
                            TfLiteTypeGetName(input->type));
@@ -50,7 +50,7 @@ class BaseTransposeConvOpModel : public SingleOpModel {
                            std::initializer_list<InputType> filter_data,
                            const TensorData& input, const TensorData& output,
                            Padding padding, int stride_w, int stride_h,
-                           TestType test_type) {
+                           TestType test_type, int version = 1) {
     // Just to be confusing, transpose_conv has an _input_ named "output_shape"
     // that sets the shape of the output tensor of the op :). It must always be
     // an int32 1D four element tensor.
@@ -70,7 +70,7 @@ class BaseTransposeConvOpModel : public SingleOpModel {
             CreateTransposeConvOptions(builder_, padding, stride_w, stride_h)
                 .Union());
     resolver_ = absl::make_unique<SingleOpResolver>(
-        BuiltinOperator_TRANSPOSE_CONV, registration);
+        BuiltinOperator_TRANSPOSE_CONV, registration, version);
     BuildInterpreter(
         {GetShape(output_shape_), GetShape(filter_), GetShape(input_)});
 
|
|||||||
void SetInput(std::initializer_list<float> data) {
|
void SetInput(std::initializer_list<float> data) {
|
||||||
if (std::is_same<InputType, uint8_t>::value) {
|
if (std::is_same<InputType, uint8_t>::value) {
|
||||||
QuantizeAndPopulate<uint8_t>(input_, data);
|
QuantizeAndPopulate<uint8_t>(input_, data);
|
||||||
|
} else if (std::is_same<InputType, int8_t>::value) {
|
||||||
|
QuantizeAndPopulate<int8_t>(input_, data);
|
||||||
} else {
|
} else {
|
||||||
PopulateTensor(input_, data);
|
PopulateTensor(input_, data);
|
||||||
}
|
}
|
||||||
@@ -313,6 +315,92 @@ TEST_P(TransposeConvOpTest, SimpleTestQuantized) {
   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
 }
 
+class PerChannelQuantizedTransposeConvOpModel
+    : public BaseTransposeConvOpModel<int8_t> {
+ public:
+  using BaseTransposeConvOpModel::BaseTransposeConvOpModel;
+
+  std::vector<float> GetDequantizedOutput() {
+    return Dequantize<int8_t>(ExtractVector<int8_t>(output_), GetScale(output_),
+                              GetZeroPoint(output_));
+  }
+
+  void SetInput(const std::initializer_list<float>& data) {
+    QuantizeAndPopulate<int8_t>(input_, data);
+  }
+
+  void SetFilter(const std::initializer_list<float>& data) {
+    PerChannelSymmetricQuantizeAndPopulate(filter_, data);
+  }
+};
+
+TEST_P(TransposeConvOpTest, SimpleTestQuantizedPerChannelSingleChannel) {
+  // TODO(b/138722124): Enable these tests on NNAPI.
+  if (SingleOpModel::GetForceUseNnapi()) {
+    return;
+  }
+
+  const std::initializer_list<float> filter_data = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+  PerChannelQuantizedTransposeConvOpModel model(
+      GetRegistration(), {1, 4, 4, 1},
+      {TensorType_INT8, {1, 3, 3, 1}, 0, 0, 0, 0, true, {9.0 / 127}, {0}, 0},
+      {}, {TensorType_INT8, {1, 4, 4, 1}, 0, 0, 16.0 / 255, -128},
+      {TensorType_INT8, {}, 0, 0, 2, -128}, Padding_SAME, 1, 1, GetTestType(),
+      /* version */ 2);
+  model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
+  model.SetFilter(filter_data);
+  model.Invoke();
+
+  EXPECT_THAT(
+      model.GetDequantizedOutput(),
+      ElementsAreArray(ArrayFloatNear({28, 62, 82, 76, 98, 192, 238, 198, 206,
+                                       372, 416, 330, 262, 446, 486, 366},
+                                      1e-5)));
+
+  // GetOutputShape() should always be same as model.SetOutputShape(...);
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
+}
+
+// Test data copied from the float multi-channel test above.
+TEST_P(TransposeConvOpTest, TestQuantizedPerChannelMultiChannel) {
+  // TODO(b/138722124): Enable these tests on NNAPI.
+  if (SingleOpModel::GetForceUseNnapi()) {
+    return;
+  }
+
+  const std::initializer_list<float> filter_data = {
+      1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6, 8, 10, 12, 14, 16, 18};
+  PerChannelQuantizedTransposeConvOpModel model(
+      GetRegistration(), {1, 5, 5, 2},
+      {TensorType_INT8,
+       {2, 3, 3, 1},
+       0,
+       0,
+       0,
+       0,
+       true,
+       {17.0 / 127, 18.0 / 127},
+       {0, 0},
+       0},
+      {}, {TensorType_INT8, {1, 2, 2, 1}, 0, 0, 4.0 / 255, -128},
+      {TensorType_INT8, {}, 0, 0, 1, -128}, Padding_VALID, 2, 2, GetTestType(),
+      /* version */ 2);
+  model.SetInput({1, 2, 3, 4});
+  model.SetFilter(filter_data);
+  model.Invoke();
+
+  EXPECT_THAT(
+      model.GetDequantizedOutput(),
+      ElementsAreArray(ArrayFloatNear(
+          {1, 2, 3, 4, 7, 10, 6, 8, 10, 12, 7, 8, 9, 10, 25, 28, 18,
+           20, 22, 24, 16, 20, 24, 28, 62, 72, 42, 48, 54, 60, 21, 24, 27, 30,
+           61, 68, 36, 40, 44, 48, 39, 42, 45, 48, 103, 110, 60, 64, 68, 72},
+          1e-5)));
+
+  // GetOutputShape() should always be same as model.SetOutputShape(...);
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 5, 5, 2}));
+}
+
 TEST_P(TransposeConvOpTest, TwoFiltersTestQuantized) {
   // Float would be {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
   // 18}
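PerChannelSymmetricQuantizeAndPopulate in the tests above gives each output channel its own scale with a zero point of 0. A sketch of that mapping, assuming scale = max|w| / 127, which is consistent with the {9.0 / 127} and {17.0 / 127, 18.0 / 127} scales declared in the test TensorData:

// Per-channel symmetric quantization sketch (illustration only): one channel
// with weights 1..9 gets scale 9/127, so 9 maps to 127 and 1 maps to 14.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<float> channel_weights = {1, 2, 3, 4, 5, 6, 7, 8, 9};
  float max_abs = 0.f;
  for (float w : channel_weights) max_abs = std::max(max_abs, std::fabs(w));
  const float scale = max_abs / 127.0f;  // 9.0 / 127
  for (float w : channel_weights) {
    const auto q = static_cast<int8_t>(std::lround(w / scale));
    std::printf("%g -> %d\n", w, q);  // 1 -> 14, ..., 9 -> 127
  }
  return 0;
}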
@@ -93,6 +93,16 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
       property.version = 3;
       break;
     }
+    case BuiltinOperator_TRANSPOSE_CONV: {
+      TensorProperty tensor_property;
+      tensor_property.per_axis = true;
+      tensor_property.per_axis_index = 0;
+      tensor_property.symmetric = true;
+      property.inputs = {{1, tensor_property}, {2, {}}};
+      property.outputs = {{0, {}}};
+      property.version = 2;
+      break;
+    }
     case BuiltinOperator_DEPTHWISE_CONV_2D: {
       TensorProperty tensor_property;
       tensor_property.per_axis = true;
@@ -149,6 +149,13 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
       }
       return 1;
 
+    case BuiltinOperator_TRANSPOSE_CONV:
+      // If the op takes int8 input, it is version 2.
+      if (op_sig.input_types.at(0) == TensorType_INT8) {
+        return 2;
+      }
+      return 1;
+
     case BuiltinOperator_LSTM:
       // If the input tensor is float and a weight is int8, this is a version
       // 3 hybrid operation.
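The rule added above pairs with the min/max versions registered in BuiltinOpResolver: the converter stamps version 2 on int8 transpose conv ops, and the resolver will only find a kernel if it registered that version. A toy restatement of the rule, using my own stand-in types rather than the real OpSignature:

// Toy restatement of the new versioning rule (illustration only).
#include <cstdio>

enum class ToyTensorType { kFloat32, kUInt8, kInt8 };

int GetTransposeConvVersion(ToyTensorType first_input) {
  return first_input == ToyTensorType::kInt8 ? 2 : 1;
}

int main() {
  std::printf("%d\n", GetTransposeConvVersion(ToyTensorType::kFloat32));  // 1
  std::printf("%d\n", GetTransposeConvVersion(ToyTensorType::kInt8));     // 2
  return 0;
}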
@@ -351,4 +351,19 @@ TEST(OpVersionTest, VersioningFloorDivOperatorTest) {
   EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
 }
 
+TEST(OpVersionTest, VersioningTransposeConvOperatorTest) {
+  OpSignature fake_op_sig = {
+      .op = BuiltinOperator_TRANSPOSE_CONV,
+      .input_types =
+          std::vector<TensorType>{TensorType_FLOAT32, TensorType_UINT8},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
+
+  fake_op_sig = {
+      .op = BuiltinOperator_TRANSPOSE_CONV,
+      .input_types = std::vector<TensorType>{TensorType_INT8},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
+}
+
 }  // namespace tflite