INT16 reference_op for TRANSPOSE_CONVOLUTION

add INT16 reference op for transpose_conv
    add test code for INT16 transpose_conv
Peng Sun 2019-11-26 19:03:24 +00:00
parent d5d92b241b
commit e71f56e928
5 changed files with 254 additions and 17 deletions


@ -83,6 +83,8 @@ using int16 = std::int16_t;
using uint16 = std::uint16_t;
using int32 = std::int32_t;
using uint32 = std::uint32_t;
using int64 = std::int64_t;
using uint64 = std::uint64_t;
// TFLITE_DEPRECATED()
//


@ -112,6 +112,98 @@ inline void TransposeConv(
}
}
// int16 input, int8 filter, int64 accumulator
inline void TransposeConv(
    const ConvParams& params, const int32* output_multiplier,
    const int32* output_shift, const RuntimeShape& input_shape,
    const int16* input_data, const RuntimeShape& filter_shape,
    const int8* filter_data, const RuntimeShape& output_shape,
    int16* output_data, const RuntimeShape& im2col_shape, int8* im2col_data,
    int64* scratch_buffer) {
  const int stride_width = params.stride_width;
  const int stride_height = params.stride_height;
  const int pad_width = params.padding_values.width;
  const int pad_height = params.padding_values.height;
  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
  (void)im2col_data;   // only used in optimized code.
  (void)im2col_shape;  // only used in optimized code.
  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
  const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
  const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
  const int input_height = input_shape.Dims(1);
  const int input_width = input_shape.Dims(2);
  const int filter_height = filter_shape.Dims(1);
  const int filter_width = filter_shape.Dims(2);
  const int output_height = output_shape.Dims(1);
  const int output_width = output_shape.Dims(2);
  const int32 input_offset = params.input_offset;
  const int32 output_offset = params.output_offset;
  const int32 output_activation_min = std::numeric_limits<int16_t>::min();
  const int32 output_activation_max = std::numeric_limits<int16_t>::max();
  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
  const int num_elements = output_shape.FlatSize();
  // We need to initialize scratch_buffer to all 0s, as we apply the same
  // 'scatter' based trick as in float version.
  memset(scratch_buffer, 0, num_elements * sizeof(int64));
  // Loop through input elements one at a time.
  for (int batch = 0; batch < batches; ++batch) {
    for (int in_y = 0; in_y < input_height; ++in_y) {
      for (int in_x = 0; in_x < input_width; ++in_x) {
        for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
          // Loop through the output elements it will influence.
          const int out_x_origin = (in_x * stride_width) - pad_width;
          const int out_y_origin = (in_y * stride_height) - pad_height;
          for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
            for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
              for (int out_channel = 0; out_channel < output_depth;
                   ++out_channel) {
                // Compute output element location.
                const int out_x = out_x_origin + filter_x;
                const int out_y = out_y_origin + filter_y;
                // We cannot accumulate out of bounds.
                if ((out_x >= 0) && (out_x < output_width) && (out_y >= 0) &&
                    (out_y < output_height)) {
                  const int32 input_value = input_data[Offset(
                      input_shape, batch, in_y, in_x, in_channel)];
                  const int32 filter_value =
                      filter_data[Offset(filter_shape, out_channel, filter_y,
                                         filter_x, in_channel)];
                  scratch_buffer[Offset(output_shape, batch, out_y, out_x,
                                        out_channel)] +=
                      (input_value + input_offset) * filter_value;
                }
              }
            }
          }
        }
      }
    }
  }
  for (int batch = 0; batch < batches; ++batch) {
    for (int out_y = 0; out_y < output_height; ++out_y) {
      for (int out_x = 0; out_x < output_width; ++out_x) {
        for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
          int64 acc = scratch_buffer[Offset(output_shape, batch, out_y, out_x,
                                            out_channel)];
          int32 scaled_acc = MultiplyByQuantizedMultiplier(
              acc, output_multiplier[out_channel], output_shift[out_channel]);
          scaled_acc += output_offset;
          scaled_acc = std::max(scaled_acc, output_activation_min);
          scaled_acc = std::min(scaled_acc, output_activation_max);
          output_data[Offset(output_shape, batch, out_y, out_x, out_channel)] =
              static_cast<int16_t>(scaled_acc);
        }
      }
    }
  }
}
} // namespace reference_integer_ops
} // namespace tflite
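Background on the requantization at the end of the int16 kernel above: output_multiplier[out_channel] and output_shift[out_channel] encode the real rescale factor input_scale * filter_scale[out_channel] / output_scale as a Q31 fixed-point multiplier plus a power-of-two shift, which MultiplyByQuantizedMultiplier then applies to the int64 accumulator. The stand-alone sketch below is not part of this commit; it only illustrates that decomposition (mirroring what TFLite's QuantizeMultiplier helper in quantization_util.h does), assuming a non-negative multiplier.

#include <cmath>
#include <cstdint>

// Decompose a real rescale factor into a Q31 fixed-point multiplier and a
// power-of-two shift, so that real_multiplier ~= quantized_multiplier *
// 2^(shift - 31). Assumes real_multiplier >= 0.
void DecomposeMultiplier(double real_multiplier, int32_t* quantized_multiplier,
                         int* shift) {
  if (real_multiplier == 0.0) {
    *quantized_multiplier = 0;
    *shift = 0;
    return;
  }
  // frexp returns a fraction in [0.5, 1) and the matching base-2 exponent.
  const double fraction = std::frexp(real_multiplier, shift);
  int64_t fixed = static_cast<int64_t>(std::round(fraction * (1LL << 31)));
  if (fixed == (1LL << 31)) {  // rounding pushed the fraction up to 1.0
    fixed /= 2;
    ++*shift;
  }
  *quantized_multiplier = static_cast<int32_t>(fixed);
}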


@ -46,8 +46,9 @@ TfLiteStatus PopulateConvolutionQuantizationParams(
TF_LITE_ENSURE(context, affine_quantization->scale);
const bool is_per_channel = affine_quantization->scale->size > 1;
if (is_per_channel) {
// Currently only Int8 is supported for per channel quantization.
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteInt8);
// Currently only Int8/Int16 is supported for per channel quantization.
TF_LITE_ENSURE(context,
input->type == kTfLiteInt8 || input->type == kTfLiteInt16);
TF_LITE_ENSURE_EQ(context, filter->type, kTfLiteInt8);
TF_LITE_ENSURE_EQ(
context, affine_quantization->scale->size,


@ -154,8 +154,9 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
++temporaries_count;
}
// Allocate scratch buffer tensor for UInt8 inputs.
if (input_type == kTfLiteUInt8 || input_type == kTfLiteInt8) {
// Allocate scratch buffer tensor
if (input_type == kTfLiteUInt8 || input_type == kTfLiteInt8 ||
input_type == kTfLiteInt16) {
if (data->scratch_tensor_id == kTensorNotAllocated) {
context->AddTensors(context, 1, &data->scratch_tensor_id);
}
@ -226,13 +227,15 @@ TfLiteStatus ResizeAndTransposeWeights(TfLiteContext* context,
GetTensorShape(transposed_weights),
GetTensorData<uint8>(transposed_weights));
} else if (weights->type == kTfLiteInt8) {
// int16 transpose_conv also with int8 weights
optimized_ops::Transpose(transpose_params, input_shape,
GetTensorData<int8>(weights),
GetTensorShape(transposed_weights),
GetTensorData<int8>(transposed_weights));
} else {
context->ReportError(
context, "Transpose conv only support float & uint8 right now.");
context,
"Transpose conv only support float, uint8, int8, int16 right now.");
return kTfLiteError;
}
@ -258,10 +261,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 4);
TF_LITE_ENSURE(context, input->type == kTfLiteFloat32 ||
input->type == kTfLiteUInt8 ||
input->type == kTfLiteInt8);
TF_LITE_ENSURE_EQ(context, weights->type, input->type);
TF_LITE_ENSURE(context,
input->type == kTfLiteFloat32 || input->type == kTfLiteUInt8 ||
input->type == kTfLiteInt8 || input->type == kTfLiteInt16);
if (input->type == kTfLiteInt16) {
TF_LITE_ENSURE_EQ(context, weights->type, kTfLiteInt8);
} else {
TF_LITE_ENSURE_EQ(context, weights->type, input->type);
}
TF_LITE_ENSURE_EQ(context, output->type, input->type);
// Ensure that weights and inputs have the same channel dimension.
// Note: TOCO will reorder weights in the following format: OHWI.
@ -305,12 +312,18 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
}
if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8 ||
input->type == kTfLiteInt16) {
node->temporaries->data[data->scratch_tensor_index] =
data->scratch_tensor_id;
TfLiteTensor* scratch_buffer =
GetTemporary(context, node, data->scratch_tensor_index);
scratch_buffer->type = kTfLiteInt32;
if (input->type == kTfLiteInt16) {
scratch_buffer->type = kTfLiteInt64;
} else {
scratch_buffer->type = kTfLiteInt32;
}
scratch_buffer->allocation_type = kTfLiteDynamic;
if (!IsConstantTensor(output_shape)) {
SetTensorToDynamic(scratch_buffer);
@ -473,6 +486,38 @@ void EvalQuantizedPerChannel(TfLiteContext* context,
}
}
void EvalQuantizedPerChannel16x8(TfLiteContext* context,
                                 const TfLiteTransposeConvParams* params,
                                 OpData* data, const TfLiteTensor* input,
                                 const TfLiteTensor* weights,
                                 const TfLiteTensor* transposed_weights,
                                 TfLiteTensor* col2im, TfLiteTensor* output,
                                 TfLiteTensor* scratch_buffer) {
  tflite::ConvParams op_params;
  op_params.padding_type = PaddingType::kSame;
  op_params.padding_values.width = data->padding.width;
  op_params.padding_values.height = data->padding.height;
  op_params.padding_values.width_offset = data->padding.width_offset;
  op_params.padding_values.height_offset = data->padding.height_offset;
  op_params.stride_width = params->stride_width;
  op_params.stride_height = params->stride_height;
  // Need to flip the sign of input offset to add it directly to the quantized
  // buffer.
  op_params.input_offset = -input->params.zero_point;
  op_params.output_offset = output->params.zero_point;
  op_params.quantized_activation_min = data->output_activation_min;
  op_params.quantized_activation_max = data->output_activation_max;
  // Need to add optimized kernel
  reference_integer_ops::TransposeConv(
      op_params, data->per_channel_output_multiplier.data(),
      data->per_channel_output_shift.data(), GetTensorShape(input),
      GetTensorData<int16>(input), GetTensorShape(weights),
      GetTensorData<int8>(weights), GetTensorShape(output),
      GetTensorData<int16>(output), GetTensorShape(col2im),
      GetTensorData<int8>(col2im), GetTensorData<int64_t>(scratch_buffer));
}
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Retrieve tensors (All should be allocated by now)
@ -513,7 +558,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
filter_height, filter_width, params->padding, &unused_output_height,
&unused_output_width);
// Currently support float32 and uint8.
// Currently support float32, uint8, int8, int16.
switch (input->type) {
case kTfLiteFloat32: {
// Only for GenericOptimized path, we use transposed weights.
@ -558,6 +603,21 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
                              output, scratch_buffer);
      break;
    }
    case kTfLiteInt16: {
      TfLiteTensor* scratch_buffer =
          GetTemporary(context, node, data->scratch_tensor_index);
      if (IsDynamicTensor(scratch_buffer)) {
        TF_LITE_ENSURE_OK(context,
                          ResizeTensor(context, output_shape, scratch_buffer));
      }
      if (data->weights_are_transposed && !IsConstantTensor(weights)) {
        ResizeAndTransposeWeights(context, weights, transposed_weights);
      }
      EvalQuantizedPerChannel16x8(context, params, data, input, weights,
                                  transposed_weights, col2im, output,
                                  scratch_buffer);
      break;
    }
    default:
      context->ReportError(context, "Type '%s' is not currently supported.",
                           TfLiteTypeGetName(input->type));
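A back-of-the-envelope note (my arithmetic, not from the commit) on why Prepare() above switches the scratch buffer to kTfLiteInt64 for int16 inputs: each accumulated term is an int16 activation times an int8 weight, so its magnitude is at most 2^15 * 2^7 = 2^22, and an int32 accumulator only has headroom for about 2^31 / 2^22 = 512 such worst-case terms; a 3x3 filter over 64 input channels already sums 576 terms per output element. A minimal sketch of that bound:

#include <cstdint>

// Rough worst-case accumulator magnitude for the 16x8 transpose conv: every
// term is |int16| <= 2^15 times |int8| <= 2^7, i.e. at most 2^22, and one
// output element sums filter_height * filter_width * input_depth terms. For
// a 3x3 filter with 64 input channels this is 576 * 2^22 ~= 2.4e9, which
// already exceeds INT32_MAX, hence the int64 scratch buffer.
int64_t WorstCaseAccumulatorMagnitude(int filter_height, int filter_width,
                                      int input_depth) {
  const int64_t per_term = INT64_C(1) << 22;  // 2^15 * 2^7
  return per_term * filter_height * filter_width * input_depth;
}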


@ -76,7 +76,10 @@ class BaseTransposeConvOpModel : public SingleOpModel {
if (test_type == TestType::DYNAMIC) {
PopulateTensor<int32_t>(output_shape_, output_shape_data);
PopulateTensor<InputType>(filter_, filter_data);
if (!std::is_same<InputType, int16_t>::value &&
!std::is_same<InputType, int8_t>::value) {
PopulateTensor<InputType>(filter_, filter_data);
}
}
}
@ -85,6 +88,8 @@ class BaseTransposeConvOpModel : public SingleOpModel {
QuantizeAndPopulate<uint8_t>(input_, data);
} else if (std::is_same<InputType, int8_t>::value) {
QuantizeAndPopulate<int8_t>(input_, data);
} else if (std::is_same<InputType, int16_t>::value) {
QuantizeAndPopulate<int16_t>(input_, data);
} else {
PopulateTensor(input_, data);
}
@ -325,10 +330,6 @@ class PerChannelQuantizedTransposeConvOpModel
GetZeroPoint(output_));
}
void SetInput(const std::initializer_list<float>& data) {
QuantizeAndPopulate<int8_t>(input_, data);
}
void SetFilter(const std::initializer_list<float>& data) {
PerChannelSymmetricQuantizeAndPopulate(filter_, data);
}
@ -451,6 +452,87 @@ TEST_P(TransposeConvOpTest, PaddingValidTestQuantized) {
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 6, 6, 1}));
}
class PerChannelQuantizedTransposeConvOpModel16x8
    : public BaseTransposeConvOpModel<int16_t> {
 public:
  using BaseTransposeConvOpModel::BaseTransposeConvOpModel;

  std::vector<float> GetDequantizedOutput() {
    return Dequantize<int16_t>(ExtractVector<int16_t>(output_),
                               GetScale(output_), GetZeroPoint(output_));
  }

  void SetFilter(const std::initializer_list<float>& data) {
    PerChannelSymmetricQuantizeAndPopulate(filter_, data);
  }
};

TEST_P(TransposeConvOpTest, SimpleTestQuantizedPerChannel16x8) {
  // TensorData(TensorType type = TensorType_FLOAT32, std::vector<int> shape =
  //            {},
  //            float min = 0.0f, float max = 0.0f, float scale = 0.0f,
  //            int32_t zero_point = 0, bool per_channel_quantization = false,
  //            std::vector<float> per_channel_quantization_scales = {},
  //            std::vector<int64_t> per_channel_quantization_offsets = {},
  //            int32_t channel_index = 0)
  const std::initializer_list<float> filter_data = {
      // [2 * 2 * 2 * 2] as [output_channel, y, x, input_channel]
      1, 2,  // out channel = 0, y = 0, x = 0
      3, 4,  // out channel = 0, y = 0, x = 1
      3, 4,  // out channel = 0, y = 1, x = 0
      5, 6,  // out channel = 0, y = 1, x = 1
      7, 8,  // out channel = 1, y = 0, x = 0
      5, 6,  // out channel = 1, y = 0, x = 1
      3, 4,  // out channel = 1, y = 1, x = 0
      1, 2,  // out channel = 1, y = 1, x = 1
  };
  PerChannelQuantizedTransposeConvOpModel16x8 model(
      GetRegistration(),
      /*output_shape_data=*/{1, 2, 3, 2},
      /*filter=*/
      {TensorType_INT8,
       /*shape=*/{2, 2, 2, 2},
       /*min=*/-64, /*max=*/64,
       /*scale=*/0, /*zero_point=*/0,
       /*per_channel=*/true,
       /*per_channel_scales=*/{7.0 / 127, 8.0 / 127},
       /*per_channel_offsets=*/{0, 0},
       /*channel_index=*/0},
      /*filter_data=*/{},
      /*input=*/
      {TensorType_INT16,
       /*shape=*/{1, 2, 3, 2},
       /*min=*/0, /*max=*/0,
       /*scale=*/4.0 / 127,
       /*zero_point=*/0},
      /*output=*/
      {TensorType_INT16,
       /*shape=*/{},
       /*min=*/0, /*max=*/0,
       /*scale=*/1.0,
       /*zero_point=*/0},
      /*padding=*/Padding_SAME,
      /*stride_w=*/1, /*stride_h=*/1, GetTestType());
  model.SetInput({
      // [1 * 2 * 3 * 2] as [batch, y, x, input_channel]
      3, 2,    // batch = 0, y = 0, x = 0
      1, -1,   // batch = 0, y = 0, x = 1
      -2, -3,  // batch = 0, y = 0, x = 2
      4, 3,    // batch = 0, y = 1, x = 0
      2, -2,   // batch = 0, y = 1, x = 1
      -3, -4,  // batch = 0, y = 1, x = 2
  });
  model.SetFilter(filter_data);
  model.Invoke();
  EXPECT_THAT(model.GetDequantizedOutput(),
              ElementsAreArray(ArrayFloatNear(
                  {7, 37, 16, 26, -9, -39, 27, 69, 48, 42, -32, -74}, 1e-5)));

  // GetOutputShape() should always be same as model.SetOutputShape(...);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 3, 2}));
}
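For context on the expectation above (a minimal sketch of standard affine dequantization, not the actual SingleOpModel helper): real = scale * (quantized - zero_point). Because the output tensor is declared with scale 1.0 and zero_point 0, the dequantized outputs equal the raw int16 results, which is why whole-number expected values with a 1e-5 tolerance work here.

#include <cstdint>
#include <vector>

// Affine dequantization: real = scale * (quantized - zero_point).
std::vector<float> DequantizeInt16(const std::vector<int16_t>& quantized,
                                   float scale, int32_t zero_point) {
  std::vector<float> real;
  real.reserve(quantized.size());
  for (int16_t q : quantized) {
    real.push_back(scale * (q - zero_point));
  }
  return real;
}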
INSTANTIATE_TEST_SUITE_P(
TransposeConvOpTest, TransposeConvOpTest,
::testing::Combine(