Merge pull request #35997 from psunn:int16_transpose_conv
PiperOrigin-RevId: 314541621
Change-Id: I684cba53a54b29188a5ee97bbcb80889c49ea5f5

commit f72f233547
@@ -381,6 +381,7 @@ TransposeTest/.+
 # transpose_conv_test
 -TransposeConvOpTest/TransposeConvOpTest.SimpleTestQuantizedPerChannelSingleChannel/0
+-TransposeConvOpTest/TransposeConvOpTest.SimpleTestQuantizedPerChannel16x8/0
 -TransposeConvOpTest/TransposeConvOpTest.TestQuantizedPerChannelMultiChannel/0
 # Const tensor only
 TransposeConvOpTest/TransposeConvOpTest/.+/0,29
@@ -119,6 +119,102 @@ inline void TransposeConv(
   }
 }
 
+// int16 input (zero_point=0), int8 filter, int64 accumulator
+inline void TransposeConv(
+    const ConvParams& params, const int32* output_multiplier,
+    const int32* output_shift, const RuntimeShape& input_shape,
+    const int16* input_data, const RuntimeShape& filter_shape,
+    const int8* filter_data, const RuntimeShape& bias_shape,
+    const std::int64_t* bias_data, const RuntimeShape& output_shape,
+    int16* output_data, const RuntimeShape& im2col_shape, int8* im2col_data,
+    std::int64_t* scratch_buffer) {
+  const int stride_width = params.stride_width;
+  const int stride_height = params.stride_height;
+  const int pad_width = params.padding_values.width;
+  const int pad_height = params.padding_values.height;
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+  (void)im2col_data;   // only used in optimized code.
+  (void)im2col_shape;  // only used in optimized code.
+
+  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+  const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+  const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
+  if (bias_data) {
+    TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
+  }
+  const int input_height = input_shape.Dims(1);
+  const int input_width = input_shape.Dims(2);
+  const int filter_height = filter_shape.Dims(1);
+  const int filter_width = filter_shape.Dims(2);
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
+  const int32 output_activation_min = std::numeric_limits<int16_t>::min();
+  const int32 output_activation_max = std::numeric_limits<int16_t>::max();
+  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+
+  const int num_elements = output_shape.FlatSize();
+  // We need to initialize scratch_buffer to all 0s, as we apply the same
+  // 'scatter' based trick as in float version.
+  memset(scratch_buffer, 0, num_elements * sizeof(std::int64_t));
+
+  // Loop through input elements one at a time.
+  for (int batch = 0; batch < batches; ++batch) {
+    for (int in_y = 0; in_y < input_height; ++in_y) {
+      for (int in_x = 0; in_x < input_width; ++in_x) {
+        for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
+          // Loop through the output elements it will influence.
+          const int out_x_origin = (in_x * stride_width) - pad_width;
+          const int out_y_origin = (in_y * stride_height) - pad_height;
+          for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+            for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+              for (int out_channel = 0; out_channel < output_depth;
+                   ++out_channel) {
+                // Compute output element location.
+                const int out_x = out_x_origin + filter_x;
+                const int out_y = out_y_origin + filter_y;
+                // We cannot accumulate out of bounds.
+                if ((out_x >= 0) && (out_x < output_width) && (out_y >= 0) &&
+                    (out_y < output_height)) {
+                  const int32 input_value = input_data[Offset(
+                      input_shape, batch, in_y, in_x, in_channel)];
+                  const int32 filter_value =
+                      filter_data[Offset(filter_shape, out_channel, filter_y,
+                                         filter_x, in_channel)];
+                  scratch_buffer[Offset(output_shape, batch, out_y, out_x,
+                                        out_channel)] +=
+                      input_value * filter_value;
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
+  for (int batch = 0; batch < batches; ++batch) {
+    for (int out_y = 0; out_y < output_height; ++out_y) {
+      for (int out_x = 0; out_x < output_width; ++out_x) {
+        for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
+          std::int64_t acc = scratch_buffer[Offset(output_shape, batch, out_y,
+                                                   out_x, out_channel)];
+          if (bias_data) {
+            acc += bias_data[out_channel];
+          }
+          int32 scaled_acc = MultiplyByQuantizedMultiplier(
+              acc, output_multiplier[out_channel], output_shift[out_channel]);
+          scaled_acc = std::max(scaled_acc, output_activation_min);
+          scaled_acc = std::min(scaled_acc, output_activation_max);
+          output_data[Offset(output_shape, batch, out_y, out_x, out_channel)] =
+              static_cast<int16_t>(scaled_acc);
+        }
+      }
+    }
+  }
+}
+
 }  // namespace reference_integer_ops
 }  // namespace tflite
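A note on why the scratch buffer and bias are 64-bit in this kernel: each product in the scatter loop is bounded by |int16| x |int8| ≈ 2^22, and one output element can accumulate up to filter_height * filter_width * input_depth such products before the bias is added, so an int32 accumulator can overflow at quite ordinary shapes. A back-of-the-envelope check (the 3x3x64 filter geometry is illustrative, not taken from this change):

```cpp
#include <cstdint>
#include <iostream>

int main() {
  // Upper bound on a single int16 * int8 product.
  const int64_t max_product = static_cast<int64_t>(1 << 15) * (1 << 7);  // 2^22
  // Illustrative filter geometry (not from this commit).
  const int64_t filter_h = 3, filter_w = 3, input_depth = 64;
  const int64_t worst_case = max_product * filter_h * filter_w * input_depth;
  std::cout << "worst-case accumulation: " << worst_case << "\n";  // 2415919104
  std::cout << "INT32_MAX:               " << INT32_MAX << "\n";   // 2147483647
  // The worst case already exceeds INT32_MAX, which is why scratch_buffer and
  // bias_data use std::int64_t in the 16x8 reference kernel above.
  return 0;
}
```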
@@ -155,8 +155,9 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
     ++temporaries_count;
   }
 
-  // Allocate scratch buffer tensor for UInt8 inputs.
-  if (input_type == kTfLiteUInt8 || input_type == kTfLiteInt8) {
+  // Allocate scratch buffer tensor
+  if (input_type == kTfLiteUInt8 || input_type == kTfLiteInt8 ||
+      input_type == kTfLiteInt16) {
     if (data->scratch_tensor_id == kTensorNotAllocated) {
       context->AddTensors(context, 1, &data->scratch_tensor_id);
     }
@@ -227,13 +228,16 @@ TfLiteStatus ResizeAndTransposeWeights(TfLiteContext* context,
                              GetTensorShape(transposed_weights),
                              GetTensorData<uint8>(transposed_weights));
   } else if (weights->type == kTfLiteInt8) {
+    // int16 transpose_conv also with int8 weights
     optimized_ops::Transpose(transpose_params, input_shape,
                              GetTensorData<int8>(weights),
                              GetTensorShape(transposed_weights),
                              GetTensorData<int8>(transposed_weights));
   } else {
-    context->ReportError(
-        context, "Transpose conv only support float & uint8 & int8 right now.");
+    TF_LITE_KERNEL_LOG(
+        context,
+        "Only float32, uint8, int8, int16 is supported currently, got %s.",
+        TfLiteTypeGetName(weights->type));
     return kTfLiteError;
   }
 
@@ -263,9 +267,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1);
   TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
   TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 4);
-  TF_LITE_ENSURE(context, input->type == kTfLiteFloat32 ||
-                              input->type == kTfLiteUInt8 ||
-                              input->type == kTfLiteInt8);
+  TF_LITE_ENSURE(context,
+                 input->type == kTfLiteFloat32 || input->type == kTfLiteUInt8 ||
+                     input->type == kTfLiteInt8 || input->type == kTfLiteInt16);
 
   if (has_bias) {
     bias = GetOptionalInputTensor(context, node, kBiasTensor);
@@ -275,6 +279,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
       if (input->type == kTfLiteInt8) {
         TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0);
       }
+    } else if (input->type == kTfLiteInt16) {
+      TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt64);
+      TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0);
     } else {
       TF_LITE_ENSURE_EQ(context, bias->type, input->type);
     }
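A note on the kTfLiteInt64 bias requirement above: in the usual TFLite convention the bias is stored pre-quantized with scale input_scale * filter_scale[channel] and zero point 0; because int16 activation scales are much finer than int8 ones, the quantized bias values grow correspondingly and can outgrow int32, which is what the 64-bit type accommodates. A hedged sketch of that convention (the helper name and example scales are mine, not from the diff):

```cpp
#include <cmath>
#include <cstdint>

// Hedged sketch of the convention the kTfLiteInt64 bias check relies on:
// bias_scale = input_scale * filter_scale[channel], zero_point = 0.
int64_t QuantizeBiasInt64(float bias, float input_scale, float filter_scale) {
  const double bias_scale =
      static_cast<double>(input_scale) * static_cast<double>(filter_scale);
  return static_cast<int64_t>(std::round(bias / bias_scale));
}

// Example: with input_scale = 4.0f / 32767 (int16) and filter_scale = 7.0f / 127
// (int8), a bias of 1.0f quantizes to roughly 1.5e5. Since int16 scales are
// ~256x smaller than int8 scales for the same real range, larger biases can
// overflow int32, hence the 64-bit bias type on the 16x8 path.
```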
@@ -283,6 +290,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     }
   }
 
+  if (input->type == kTfLiteInt16) {
+    TF_LITE_ENSURE_EQ(context, weights->type, kTfLiteInt8);
+    TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
+    TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
+  } else {
+    TF_LITE_ENSURE_EQ(context, weights->type, input->type);
+  }
   TF_LITE_ENSURE_EQ(context, output->type, input->type);
   // Ensure that weights and inputs have the same channel dimension.
   // Note: TOCO will reorder weights in the following format: OHWI.
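The zero-point checks in this hunk pin the 16x8 scheme to symmetric quantization: int16 activations carry only a scale, and the weights stay int8 (per-channel, also symmetric). A minimal sketch of what the symmetric activation side looks like for a caller preparing data, under that assumption:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Symmetric int16 quantization: zero_point is fixed at 0, so q = round(x / scale),
// clamped to the int16 range. The scale value used below is illustrative.
int16_t QuantizeSymmetricInt16(float x, float scale) {
  const float q = std::round(x / scale);
  return static_cast<int16_t>(std::min(std::max(q, -32768.0f), 32767.0f));
}

// With scale = 1.0f / 32768, 0.5f maps to 16384 and dequantization is simply
// q * scale; no zero-point additions appear anywhere in the reference kernel.
```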
@@ -326,12 +340,18 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     }
   }
 
-  if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
+  if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8 ||
+      input->type == kTfLiteInt16) {
     node->temporaries->data[data->scratch_tensor_index] =
         data->scratch_tensor_id;
     TfLiteTensor* scratch_buffer =
         GetTemporary(context, node, data->scratch_tensor_index);
-    scratch_buffer->type = kTfLiteInt32;
+    if (input->type == kTfLiteInt16) {
+      scratch_buffer->type = kTfLiteInt64;
+    } else {
+      scratch_buffer->type = kTfLiteInt32;
+    }
+
     scratch_buffer->allocation_type = kTfLiteDynamic;
     if (!IsConstantTensor(output_shape)) {
       SetTensorToDynamic(scratch_buffer);
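A cost note on the scratch tensor: it holds one raw accumulator per output element, so switching it from kTfLiteInt32 to kTfLiteInt64 on the int16 path doubles its footprint. A quick illustrative calculation (the output shape is made up for the example):

```cpp
#include <cstdint>
#include <iostream>

int main() {
  // Illustrative output shape {batch, height, width, channels}; not from the diff.
  const int64_t dims[4] = {1, 64, 64, 32};
  const int64_t num_elements = dims[0] * dims[1] * dims[2] * dims[3];  // 131072
  std::cout << "int32 scratch: " << num_elements * sizeof(int32_t) << " bytes\n"   // 524288
            << "int64 scratch: " << num_elements * sizeof(int64_t) << " bytes\n";  // 1048576
  return 0;
}
```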
@@ -500,6 +520,37 @@ void EvalQuantizedPerChannel(
   }
 }
 
+void EvalQuantizedPerChannel16x8(
+    TfLiteContext* context, const TfLiteTransposeConvParams* params,
+    OpData* data, const TfLiteTensor* input, const TfLiteTensor* weights,
+    const TfLiteTensor* transposed_weights, const TfLiteTensor* bias,
+    TfLiteTensor* col2im, TfLiteTensor* output, TfLiteTensor* scratch_buffer) {
+  tflite::ConvParams op_params;
+  op_params.padding_type = PaddingType::kSame;
+  op_params.padding_values.width = data->padding.width;
+  op_params.padding_values.height = data->padding.height;
+  op_params.padding_values.width_offset = data->padding.width_offset;
+  op_params.padding_values.height_offset = data->padding.height_offset;
+  op_params.stride_width = params->stride_width;
+  op_params.stride_height = params->stride_height;
+  // Need to flip the sign of input offset to add it directly to the quantized
+  // buffer.
+  op_params.input_offset = -input->params.zero_point;
+  op_params.output_offset = output->params.zero_point;
+  op_params.quantized_activation_min = data->output_activation_min;
+  op_params.quantized_activation_max = data->output_activation_max;
+
+  // Need to add optimized kernel
+  reference_integer_ops::TransposeConv(
+      op_params, data->per_channel_output_multiplier.data(),
+      data->per_channel_output_shift.data(), GetTensorShape(input),
+      GetTensorData<int16>(input), GetTensorShape(weights),
+      GetTensorData<int8>(weights), GetTensorShape(bias),
+      GetTensorData<int64_t>(bias), GetTensorShape(output),
+      GetTensorData<int16>(output), GetTensorShape(col2im),
+      GetTensorData<int8>(col2im), GetTensorData<int64_t>(scratch_buffer));
+}
+
 template <KernelType kernel_type>
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   // Retrieve tensors (All should be allocated by now)
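EvalQuantizedPerChannel16x8 hands the kernel per_channel_output_multiplier and per_channel_output_shift; these encode the effective rescale input_scale * filter_scale[c] / output_scale as a fixed-point multiplier plus a power-of-two shift, which MultiplyByQuantizedMultiplier then applies to the raw accumulator. A hedged sketch of that decomposition, mirroring what tflite::QuantizeMultiplier does in spirit (not copied from this change):

```cpp
#include <cmath>
#include <cstdint>

// Decompose an effective scale into a Q31 fixed-point multiplier and a shift,
// so that scale ~= multiplier * 2^(shift - 31). Illustrative sketch only.
void DecomposeScale(double effective_scale, int32_t* multiplier, int* shift) {
  if (effective_scale == 0.0) {
    *multiplier = 0;
    *shift = 0;
    return;
  }
  const double fraction = std::frexp(effective_scale, shift);  // in [0.5, 1)
  int64_t q = static_cast<int64_t>(std::round(fraction * (1LL << 31)));
  if (q == (1LL << 31)) {  // rounding pushed the fraction up to exactly 1.0
    q /= 2;
    ++*shift;
  }
  *multiplier = static_cast<int32_t>(q);
}

// Example: effective_scale = (4.0 / 32768) * (7.0 / 127) / 1.0 ≈ 6.73e-6 yields
// a multiplier near 0.88 * 2^31 with a shift of about -17.
```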
@@ -544,7 +595,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       filter_height, filter_width, params->padding, &unused_output_height,
       &unused_output_width);
 
-  // Currently support float32 and uint8.
+  // Currently support float32, uint8, int8, int16.
   switch (input->type) {
     case kTfLiteFloat32: {
       // Only for GenericOptimized path, we use transposed weights.
@@ -589,6 +640,21 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
                               col2im, output, scratch_buffer);
       break;
     }
+    case kTfLiteInt16: {
+      TfLiteTensor* scratch_buffer =
+          GetTemporary(context, node, data->scratch_tensor_index);
+      if (IsDynamicTensor(scratch_buffer)) {
+        TF_LITE_ENSURE_OK(context,
+                          ResizeTensor(context, output_shape, scratch_buffer));
+      }
+      if (data->weights_are_transposed && !IsConstantTensor(weights)) {
+        ResizeAndTransposeWeights(context, weights, transposed_weights);
+      }
+      EvalQuantizedPerChannel16x8(context, params, data, input, weights,
+                                  transposed_weights, bias, col2im, output,
+                                  scratch_buffer);
+      break;
+    }
     default:
       context->ReportError(context, "Type '%s' is not currently supported.",
                            TfLiteTypeGetName(input->type));
@@ -76,7 +76,10 @@ class BaseTransposeConvOpModel : public SingleOpModel {
 
     if (test_type == TestType::kDynamic) {
      PopulateTensor<int32_t>(output_shape_, output_shape_data);
-      PopulateTensor<InputType>(filter_, filter_data);
+      if (!std::is_same<InputType, int16_t>::value &&
+          !std::is_same<InputType, int8_t>::value) {
+        PopulateTensor<InputType>(filter_, filter_data);
+      }
     }
   }
 
@@ -85,6 +88,8 @@ class BaseTransposeConvOpModel : public SingleOpModel {
       QuantizeAndPopulate<uint8_t>(input_, data);
     } else if (std::is_same<InputType, int8_t>::value) {
       QuantizeAndPopulate<int8_t>(input_, data);
+    } else if (std::is_same<InputType, int16_t>::value) {
+      QuantizeAndPopulate<int16_t>(input_, data);
     } else {
       PopulateTensor(input_, data);
     }
@@ -315,6 +320,56 @@ TEST_P(TransposeConvOpTest, SimpleTestQuantized) {
   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
 }
 
+TEST_P(TransposeConvOpTest, TwoFiltersTestQuantized) {
+  // Float would be {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
+  // 18}
+  std::initializer_list<uint8_t> filter_data = {129, 131, 133, 135, 137, 139,
+                                                141, 143, 145, 147, 149, 151,
+                                                153, 155, 157, 159, 161, 163};
+  QuantizedTransposeConvOpModel model(
+      GetRegistration(), {1, 4, 4, 1},
+      {TensorType_UINT8, {1, 3, 3, 2}, -63.5, 64}, filter_data,
+      {TensorType_UINT8, {1, 4, 4, 2}, -63.5, 64},
+      {TensorType_UINT8, {}, -4064, 4096}, Padding_SAME, 1, 1, GetTestType());
+  model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+                  12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
+                  23, 24, 25, 26, 27, 28, 29, 30, 31, 32});
+  model.Invoke();
+
+  EXPECT_THAT(model.GetDequantizedOutput(),
+              ElementsAreArray(ArrayFloatNear(
+                  {192, 416, 576, 544, 672, 1344, 1696, 1440, 1504, 2720, 3072,
+                   2432, 1984, 3360, 3648, 2752},
+                  1e-5)));
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
+}
+
+TEST_P(TransposeConvOpTest, PaddingValidTestQuantized) {
+  // Float would be {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
+  // 18}
+  std::initializer_list<uint8_t> filter_data = {129, 131, 133, 135, 137, 139,
+                                                141, 143, 145, 147, 149, 151,
+                                                153, 155, 157, 159, 161, 163};
+  QuantizedTransposeConvOpModel model(
+      GetRegistration(), {1, 6, 6, 1},
+      {TensorType_UINT8, {1, 3, 3, 2}, -63.5, 64}, filter_data,
+      {TensorType_UINT8, {1, 4, 4, 2}, -63.5, 64},
+      {TensorType_UINT8, {}, -4064, 4096}, Padding_VALID, 1, 1, GetTestType());
+  model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+                  12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
+                  23, 24, 25, 26, 27, 28, 29, 30, 31, 32});
+  model.Invoke();
+
+  EXPECT_THAT(model.GetDequantizedOutput(),
+              ElementsAreArray(ArrayFloatNear(
+                  {0, 32, 64, 96, 128, 96, 64, 192, 416,
+                   576, 544, 352, 224, 672, 1344, 1696, 1440, 864,
+                   608, 1504, 2720, 3072, 2432, 1440, 864, 1984, 3360,
+                   3648, 2752, 1536, 704, 1536, 2528, 2720, 2016, 1088},
+                  1e-5)));
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 6, 6, 1}));
+}
+
 class PerChannelQuantizedTransposeConvOpModel
     : public BaseTransposeConvOpModel<int8_t> {
  public:
@@ -325,10 +380,6 @@ class PerChannelQuantizedTransposeConvOpModel
                              GetZeroPoint(output_));
   }
 
-  void SetInput(const std::initializer_list<float>& data) {
-    QuantizeAndPopulate<int8_t>(input_, data);
-  }
-
   void SetFilter(const std::initializer_list<float>& data) {
     PerChannelSymmetricQuantizeAndPopulate(filter_, data);
   }
@@ -391,54 +442,78 @@ TEST_P(TransposeConvOpTest, TestQuantizedPerChannelMultiChannel) {
   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 5, 5, 2}));
 }
 
-TEST_P(TransposeConvOpTest, TwoFiltersTestQuantized) {
-  // Float would be {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
-  // 18}
-  std::initializer_list<uint8_t> filter_data = {129, 131, 133, 135, 137, 139,
-                                                141, 143, 145, 147, 149, 151,
-                                                153, 155, 157, 159, 161, 163};
-  QuantizedTransposeConvOpModel model(
-      GetRegistration(), {1, 4, 4, 1},
-      {TensorType_UINT8, {1, 3, 3, 2}, -63.5, 64}, filter_data,
-      {TensorType_UINT8, {1, 4, 4, 2}, -63.5, 64},
-      {TensorType_UINT8, {}, -4064, 4096}, Padding_SAME, 1, 1, GetTestType());
-  model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
-                  12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
-                  23, 24, 25, 26, 27, 28, 29, 30, 31, 32});
-  model.Invoke();
-
-  EXPECT_THAT(model.GetDequantizedOutput(),
-              ElementsAreArray(ArrayFloatNear(
-                  {192, 416, 576, 544, 672, 1344, 1696, 1440, 1504, 2720, 3072,
-                   2432, 1984, 3360, 3648, 2752},
-                  1e-5)));
-  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
-}
-
-TEST_P(TransposeConvOpTest, PaddingValidTestQuantized) {
-  // Float would be {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
-  // 18}
-  std::initializer_list<uint8_t> filter_data = {129, 131, 133, 135, 137, 139,
-                                                141, 143, 145, 147, 149, 151,
-                                                153, 155, 157, 159, 161, 163};
-  QuantizedTransposeConvOpModel model(
-      GetRegistration(), {1, 6, 6, 1},
-      {TensorType_UINT8, {1, 3, 3, 2}, -63.5, 64}, filter_data,
-      {TensorType_UINT8, {1, 4, 4, 2}, -63.5, 64},
-      {TensorType_UINT8, {}, -4064, 4096}, Padding_VALID, 1, 1, GetTestType());
-  model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
-                  12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
-                  23, 24, 25, 26, 27, 28, 29, 30, 31, 32});
-  model.Invoke();
-
-  EXPECT_THAT(model.GetDequantizedOutput(),
-              ElementsAreArray(ArrayFloatNear(
-                  {0, 32, 64, 96, 128, 96, 64, 192, 416,
-                   576, 544, 352, 224, 672, 1344, 1696, 1440, 864,
-                   608, 1504, 2720, 3072, 2432, 1440, 864, 1984, 3360,
-                   3648, 2752, 1536, 704, 1536, 2528, 2720, 2016, 1088},
-                  1e-5)));
-  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 6, 6, 1}));
-}
+class PerChannelQuantizedTransposeConvOpModel16x8
+    : public BaseTransposeConvOpModel<int16_t> {
+ public:
+  using BaseTransposeConvOpModel::BaseTransposeConvOpModel;
+
+  std::vector<float> GetDequantizedOutput() {
+    return Dequantize<int16_t>(ExtractVector<int16_t>(output_),
+                               GetScale(output_), GetZeroPoint(output_));
+  }
+
+  void SetFilter(const std::initializer_list<float>& data) {
+    PerChannelSymmetricQuantizeAndPopulate(filter_, data);
+  }
+};
+
+TEST_P(TransposeConvOpTest, SimpleTestQuantizedPerChannel16x8) {
+  const std::initializer_list<float> filter_data = {
+      // [2 * 2 * 2 * 2] as [output_channel, y, x, input_channel]
+      1, 2,  // out channel = 0, y = 0, x = 0
+      3, 4,  // out channel = 0, y = 0, x = 1
+      3, 4,  // out channel = 0, y = 1, x = 0
+      5, 6,  // out channel = 0, y = 1, x = 1
+      7, 8,  // out channel = 1, y = 0, x = 0
+      5, 6,  // out channel = 1, y = 0, x = 1
+      3, 4,  // out channel = 1, y = 1, x = 0
+      1, 2,  // out channel = 1, y = 1, x = 1
+  };
+  PerChannelQuantizedTransposeConvOpModel16x8 model(
+      GetRegistration(),
+      /*output_shape_data=*/{1, 2, 3, 2},
+      /*filter=*/
+      {TensorType_INT8,
+       /*shape=*/{2, 2, 2, 2},
+       /*min=*/-64, /*max=*/64,
+       /*scale=*/0, /*zero_point=*/0,
+       /*per_channel_quantization=*/true,
+       /*per_channel_quantization_scales=*/{7.0 / 127, 8.0 / 127},
+       /*per_channel_quantization_offsets=*/{0, 0},
+       /*channel_index=*/0},
+      /*filter_data=*/{},
+      /*input=*/
+      {TensorType_INT16,
+       /*shape=*/{1, 2, 3, 2},
+       /*min=*/0, /*max=*/0,
+       /*scale=*/4.0 / 127,
+       /*zero_point=*/0},
+      /*output=*/
+      {TensorType_INT16,
+       /*shape=*/{},
+       /*min=*/0, /*max=*/0,
+       /*scale=*/1.0,
+       /*zero_point=*/0},
+      /*padding=*/Padding_SAME,
+      /*stride_w=*/1, /*stride_h=*/1, GetTestType());
+  model.SetInput({
+      // [1 * 2 * 3 * 2] as [batch, y, x, input_channel]
+      3, 2,    // batch = 0, y = 0, x = 0
+      1, -1,   // batch = 0, y = 0, x = 1
+      -2, -3,  // batch = 0, y = 0, x = 2
+      4, 3,    // batch = 0, y = 1, x = 0
+      2, -2,   // batch = 0, y = 1, x = 1
+      -3, -4,  // batch = 0, y = 1, x = 2
+  });
+  model.SetFilter(filter_data);
+  model.Invoke();
+
+  EXPECT_THAT(model.GetDequantizedOutput(),
+              ElementsAreArray(ArrayFloatNear(
+                  {7, 37, 16, 26, -9, -39, 27, 69, 48, 42, -32, -74}, 1e-5)));
+
+  // GetOutputShape() should always be same as model.SetOutputShape(...);
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 3, 2}));
+}
 
 template <typename InputType>
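As a sanity check on the golden values in SimpleTestQuantizedPerChannel16x8: with stride 1, a 2x2 filter and SAME padding, the transposed convolution can be recomputed by hand in the float domain, and it reproduces the expected vector {7, 37, 16, 26, -9, -39, 27, 69, 48, 42, -32, -74} exactly. A standalone recomputation follows (the pad = 0 reading of SAME padding for this geometry is my assumption, not stated in the diff):

```cpp
#include <cstdio>

// Float-domain recomputation of the SimpleTestQuantizedPerChannel16x8 golden
// values. Shapes and data mirror the test above.
int main() {
  const int in_h = 2, in_w = 3, in_c = 2, out_h = 2, out_w = 3, out_c = 2;
  const int filter_h = 2, filter_w = 2;
  // Input  [y][x][channel]
  const float input[2][3][2] = {{{3, 2}, {1, -1}, {-2, -3}},
                                {{4, 3}, {2, -2}, {-3, -4}}};
  // Filter [out_channel][y][x][in_channel]
  const float filter[2][2][2][2] = {{{{1, 2}, {3, 4}}, {{3, 4}, {5, 6}}},
                                    {{{7, 8}, {5, 6}}, {{3, 4}, {1, 2}}}};
  float output[2][3][2] = {};

  // Scatter formulation, as in the reference kernel: each input element adds
  // its contribution to every output element its filter window touches.
  for (int in_y = 0; in_y < in_h; ++in_y) {
    for (int in_x = 0; in_x < in_w; ++in_x) {
      for (int ic = 0; ic < in_c; ++ic) {
        for (int fy = 0; fy < filter_h; ++fy) {
          for (int fx = 0; fx < filter_w; ++fx) {
            const int out_y = in_y + fy;  // stride 1, pad 0 (assumed)
            const int out_x = in_x + fx;
            if (out_y < out_h && out_x < out_w) {
              for (int oc = 0; oc < out_c; ++oc) {
                output[out_y][out_x][oc] +=
                    input[in_y][in_x][ic] * filter[oc][fy][fx][ic];
              }
            }
          }
        }
      }
    }
  }

  // Prints: 7 37 16 26 -9 -39 27 69 48 42 -32 -74
  for (int y = 0; y < out_h; ++y)
    for (int x = 0; x < out_w; ++x)
      for (int oc = 0; oc < out_c; ++oc) printf("%g ", output[y][x][oc]);
  printf("\n");
  return 0;
}
```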