Add reference implementation for quantized TransposeConv.
PiperOrigin-RevId: 239298532
This commit is contained in:
parent
6ec89893ce
commit
91a8831b4a
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <sys/types.h>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
@ -4049,6 +4050,93 @@ inline void TransposeConv(
|
||||
}
|
||||
}
|
||||
|
||||
inline void TransposeConv(const ConvParams& params,
|
||||
const RuntimeShape& input_shape,
|
||||
const uint8* input_data,
|
||||
const RuntimeShape& filter_shape,
|
||||
const uint8* filter_data,
|
||||
const RuntimeShape& output_shape, uint8* output_data,
|
||||
const RuntimeShape& im2col_shape, uint8* im2col_data,
|
||||
int32* scratch_buffer) {
|
||||
const int stride_width = params.stride_width;
|
||||
const int stride_height = params.stride_height;
|
||||
const int pad_width = params.padding_values.width;
|
||||
const int pad_height = params.padding_values.height;
|
||||
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
|
||||
(void)im2col_data; // only used in optimized code.
|
||||
(void)im2col_shape; // only used in optimized code.
|
||||
|
||||
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
|
||||
const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
|
||||
const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
|
||||
const int input_height = input_shape.Dims(1);
|
||||
const int input_width = input_shape.Dims(2);
|
||||
const int filter_height = filter_shape.Dims(1);
|
||||
const int filter_width = filter_shape.Dims(2);
|
||||
const int output_height = output_shape.Dims(1);
|
||||
const int output_width = output_shape.Dims(2);
|
||||
const int32 input_offset = params.input_offset;
|
||||
const int32 filter_offset = params.weights_offset;
|
||||
const int32 output_offset = params.output_offset;
|
||||
const int32 output_multiplier = params.output_multiplier;
|
||||
const int output_shift = params.output_shift;
|
||||
const int32 output_activation_min = params.quantized_activation_min;
|
||||
const int32 output_activation_max = params.quantized_activation_max;
|
||||
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
|
||||
|
||||
const int num_elements = output_shape.FlatSize();
|
||||
// We need to initialize scratch_buffer to all 0s, as we apply the same
|
||||
// 'scatter' based trick as in float version.
|
||||
memset(scratch_buffer, 0, num_elements * sizeof(int32));
|
||||
|
||||
// Loop through input elements one at a time.
|
||||
for (int batch = 0; batch < batches; ++batch) {
|
||||
for (int in_y = 0; in_y < input_height; ++in_y) {
|
||||
for (int in_x = 0; in_x < input_width; ++in_x) {
|
||||
for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
|
||||
// Loop through the output elements it will influence.
|
||||
const int out_x_origin = (in_x * stride_width) - pad_width;
|
||||
const int out_y_origin = (in_y * stride_height) - pad_height;
|
||||
for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
|
||||
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
|
||||
for (int out_channel = 0; out_channel < output_depth;
|
||||
++out_channel) {
|
||||
// Compute output element location.
|
||||
const int out_x = out_x_origin + filter_x;
|
||||
const int out_y = out_y_origin + filter_y;
|
||||
// We cannot accumulate out of bounds.
|
||||
if ((out_x >= 0) && (out_x < output_width) && (out_y >= 0) &&
|
||||
(out_y < output_height)) {
|
||||
uint8 input_value = input_data[Offset(
|
||||
input_shape, batch, in_y, in_x, in_channel)];
|
||||
uint8 filter_value =
|
||||
filter_data[Offset(filter_shape, out_channel, filter_y,
|
||||
filter_x, in_channel)];
|
||||
scratch_buffer[Offset(output_shape, batch, out_y, out_x,
|
||||
out_channel)] +=
|
||||
(input_value + input_offset) *
|
||||
(filter_value + filter_offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < num_elements; ++i) {
|
||||
int32 acc = scratch_buffer[i];
|
||||
acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
|
||||
acc += output_offset;
|
||||
// Clamp the output before converting back to uint8.
|
||||
acc = std::max(acc, output_activation_min);
|
||||
acc = std::min(acc, output_activation_max);
|
||||
output_data[i] = static_cast<uint8>(acc);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool EqualFn(T lhs, T rhs) {
|
||||
return lhs == rhs;
|
||||
|
@ -111,15 +111,22 @@ TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context,
|
||||
double* multiplier) {
|
||||
const double input_product_scale = input->params.scale * filter->params.scale;
|
||||
const double bias_scale = bias->params.scale;
|
||||
const double output_scale = output->params.scale;
|
||||
|
||||
// TODO(ahentz): The following conditions must be guaranteed by the training
|
||||
// pipeline.
|
||||
TF_LITE_ENSURE(context, std::abs(input_product_scale - bias_scale) <=
|
||||
1e-6 * std::min(input_product_scale, bias_scale));
|
||||
TF_LITE_ENSURE(context, input_product_scale >= 0);
|
||||
return GetQuantizedConvolutionMultipler(context, input, filter, output,
|
||||
multiplier);
|
||||
}
|
||||
|
||||
*multiplier = input_product_scale / output_scale;
|
||||
TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context,
|
||||
const TfLiteTensor* input,
|
||||
const TfLiteTensor* filter,
|
||||
TfLiteTensor* output,
|
||||
double* multiplier) {
|
||||
const double input_product_scale = input->params.scale * filter->params.scale;
|
||||
TF_LITE_ENSURE(context, input_product_scale >= 0);
|
||||
*multiplier = input_product_scale / output->params.scale;
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
@ -113,6 +113,12 @@ TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context,
|
||||
TfLiteTensor* output,
|
||||
double* multiplier);
|
||||
|
||||
TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context,
|
||||
const TfLiteTensor* input,
|
||||
const TfLiteTensor* filter,
|
||||
TfLiteTensor* output,
|
||||
double* multiplier);
|
||||
|
||||
// Calculates the useful quantized range of an activation layer given its
|
||||
// activation tensor.
|
||||
TfLiteStatus CalculateActivationRangeQuantized(TfLiteContext* context,
|
||||
|
@ -54,6 +54,19 @@ struct OpData {
|
||||
// im2col is the only temporary currently tracked, therefore always index 0.
|
||||
// If more temporaries are added, they should be properly tracked.
|
||||
int32_t im2col_index = 0;
|
||||
|
||||
TfLitePaddingValues padding;
|
||||
// The scaling factor from input to output (aka the 'real multiplier') can
|
||||
// be represented as a fixed point multiplier plus a left shift.
|
||||
int32_t output_multiplier;
|
||||
int output_shift;
|
||||
|
||||
// The range of the fused activation layer. For example for kNone and
|
||||
// uint8_t these would be 0 and 255.
|
||||
int32_t output_activation_min;
|
||||
int32_t output_activation_max;
|
||||
|
||||
int scratch_tensor_index;
|
||||
};
|
||||
|
||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
@ -61,6 +74,9 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
// Instead, we allocate a new object to use as scratch space for im2col, and
|
||||
// to carry information from Prepare() to Eval().
|
||||
auto* data = new OpData;
|
||||
// Populate scratch_tensor_index.
|
||||
context->AddTensors(context, /*tensors_to_add=*/1,
|
||||
&data->scratch_tensor_index);
|
||||
eigen_support::IncrementUsageCounter(context);
|
||||
return data;
|
||||
}
|
||||
@ -70,9 +86,9 @@ void Free(TfLiteContext* context, void* buffer) {
|
||||
delete reinterpret_cast<OpData*>(buffer);
|
||||
}
|
||||
|
||||
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
|
||||
const TfLiteTensor* shape_tensor,
|
||||
TfLiteTensor* output) {
|
||||
TfLiteStatus ResizeTensor(TfLiteContext* context,
|
||||
const TfLiteTensor* shape_tensor,
|
||||
TfLiteTensor* tensor_to_resize) {
|
||||
// Currently only support int32 for output shape.
|
||||
if (shape_tensor->type != kTfLiteInt32) {
|
||||
context->ReportError(context, "Output shape is %d, not int32.",
|
||||
@ -85,7 +101,7 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
|
||||
shape->data[i] = GetTensorData<int32_t>(shape_tensor)[i];
|
||||
}
|
||||
|
||||
return context->ResizeTensor(context, output, shape);
|
||||
return context->ResizeTensor(context, tensor_to_resize, shape);
|
||||
}
|
||||
|
||||
static TfLiteStatus AllocateIm2colTensorIfRequired(TfLiteContext* context,
|
||||
@ -129,6 +145,8 @@ TfLiteStatus ResizeIm2ColTensor(TfLiteContext* context,
|
||||
}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
OpData* data = reinterpret_cast<OpData*>(node->user_data);
|
||||
|
||||
// Sanity checks on op
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
@ -150,9 +168,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 4);
|
||||
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
|
||||
TF_LITE_ENSURE_EQ(context, weights->type, kTfLiteFloat32);
|
||||
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
|
||||
TF_LITE_ENSURE(context,
|
||||
input->type == kTfLiteFloat32 || input->type == kTfLiteUInt8);
|
||||
TF_LITE_ENSURE_EQ(context, weights->type, input->type);
|
||||
TF_LITE_ENSURE_EQ(context, output->type, input->type);
|
||||
// Ensure that weights and inputs have the same channel dimension.
|
||||
// Note: TOCO will reorder weights in the following format: OHWI.
|
||||
TF_LITE_ENSURE_EQ(context, SizeOfDimension(input, 3),
|
||||
@ -163,13 +182,103 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
SetTensorToDynamic(output);
|
||||
SetTensorToDynamic(im2col);
|
||||
} else {
|
||||
TF_LITE_ENSURE_STATUS(ResizeOutputTensor(context, output_shape, output));
|
||||
TF_LITE_ENSURE_STATUS(ResizeTensor(context, output_shape, output));
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
ResizeIm2ColTensor(context, output_shape, weights, input, im2col));
|
||||
}
|
||||
|
||||
if (input->type == kTfLiteUInt8) {
|
||||
// Set up a scratch buffer tensor.
|
||||
TfLiteIntArrayFree(node->temporaries);
|
||||
node->temporaries = TfLiteIntArrayCreate(1);
|
||||
node->temporaries->data[0] = data->scratch_tensor_index;
|
||||
TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
|
||||
scratch_buffer->type = kTfLiteInt32;
|
||||
scratch_buffer->allocation_type = kTfLiteArenaRw;
|
||||
if (!IsConstantTensor(output_shape)) {
|
||||
SetTensorToDynamic(scratch_buffer);
|
||||
} else {
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
ResizeTensor(context, output_shape, scratch_buffer));
|
||||
}
|
||||
|
||||
// Calcuate output multiplier for quantization.
|
||||
double real_multiplier = 0.0;
|
||||
TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
|
||||
context, input, weights, output, &real_multiplier));
|
||||
int exponent;
|
||||
// Populate quantization parameteters with multiplier and shift.
|
||||
QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent);
|
||||
data->output_shift = -exponent;
|
||||
// Populate max and min activation range.
|
||||
CalculateActivationRangeUint8(kTfLiteActNone, output,
|
||||
&data->output_activation_min,
|
||||
&data->output_activation_max);
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
template <KernelType kernel_type>
|
||||
void EvalFloat(const TfLiteTransposeConvParams* params, const OpData* data,
|
||||
const TfLiteTensor* input, const TfLiteTensor* weights,
|
||||
TfLiteTensor* im2col, TfLiteTensor* output) {
|
||||
tflite::ConvParams op_params;
|
||||
op_params.padding_type = PaddingType::kSame;
|
||||
op_params.padding_values.width = data->padding.width;
|
||||
op_params.padding_values.height = data->padding.height;
|
||||
op_params.stride_width = params->stride_width;
|
||||
op_params.stride_height = params->stride_height;
|
||||
switch (kernel_type) {
|
||||
case kReference: {
|
||||
reference_ops::TransposeConv(
|
||||
op_params, GetTensorShape(input), GetTensorData<float>(input),
|
||||
GetTensorShape(weights), GetTensorData<float>(weights),
|
||||
GetTensorShape(output), GetTensorData<float>(output),
|
||||
GetTensorShape(im2col), GetTensorData<float>(im2col));
|
||||
break;
|
||||
}
|
||||
case kGenericOptimized: {
|
||||
optimized_ops::TransposeConv(
|
||||
op_params, GetTensorShape(input), GetTensorData<float>(input),
|
||||
GetTensorShape(weights), GetTensorData<float>(weights),
|
||||
GetTensorShape(output), GetTensorData<float>(output),
|
||||
GetTensorShape(im2col), GetTensorData<float>(im2col));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void EvalQuantized(const TfLiteTransposeConvParams* params, OpData* data,
|
||||
const TfLiteTensor* input, const TfLiteTensor* weights,
|
||||
TfLiteTensor* im2col, TfLiteTensor* output,
|
||||
TfLiteTensor* scratch_buffer) {
|
||||
int32_t input_offset = -input->params.zero_point;
|
||||
int32_t filter_offset = -weights->params.zero_point;
|
||||
int32_t output_offset = output->params.zero_point;
|
||||
|
||||
tflite::ConvParams op_params;
|
||||
op_params.padding_type = PaddingType::kSame;
|
||||
op_params.padding_values.width = data->padding.width;
|
||||
op_params.padding_values.height = data->padding.height;
|
||||
op_params.stride_width = params->stride_width;
|
||||
op_params.stride_height = params->stride_height;
|
||||
op_params.input_offset = input_offset;
|
||||
op_params.output_offset = output_offset;
|
||||
op_params.weights_offset = filter_offset;
|
||||
op_params.output_multiplier = data->output_multiplier;
|
||||
op_params.output_shift = -data->output_shift;
|
||||
op_params.quantized_activation_min = data->output_activation_min;
|
||||
op_params.quantized_activation_max = data->output_activation_max;
|
||||
|
||||
// TODO(haoliang): Add optimized implementation later.
|
||||
reference_ops::TransposeConv(
|
||||
op_params, GetTensorShape(input), GetTensorData<uint8>(input),
|
||||
GetTensorShape(weights), GetTensorData<uint8>(weights),
|
||||
GetTensorShape(output), GetTensorData<uint8>(output),
|
||||
GetTensorShape(im2col), GetTensorData<uint8>(im2col),
|
||||
GetTensorData<int32_t>(scratch_buffer));
|
||||
}
|
||||
|
||||
template <KernelType kernel_type>
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
// Retrieve tensors (All should be allocated by now)
|
||||
@ -178,16 +287,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor);
|
||||
const TfLiteTensor* input = GetInput(context, node, kDataInputTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
OpData* user_data = reinterpret_cast<OpData*>(node->user_data);
|
||||
OpData* data = reinterpret_cast<OpData*>(node->user_data);
|
||||
TfLiteTensor* im2col =
|
||||
&context->tensors[node->temporaries->data[user_data->im2col_index]];
|
||||
&context->tensors[node->temporaries->data[data->im2col_index]];
|
||||
const auto* params =
|
||||
reinterpret_cast<TfLiteTransposeConvParams*>(node->builtin_data);
|
||||
|
||||
// Resize any deferred dynamic tensors
|
||||
if (IsDynamicTensor(output)) {
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
ResizeOutputTensor(context, output_shape, output));
|
||||
TF_LITE_ENSURE_OK(context, ResizeTensor(context, output_shape, output));
|
||||
}
|
||||
if (IsDynamicTensor(im2col)) {
|
||||
TF_LITE_ENSURE_OK(context, ResizeIm2ColTensor(context, output_shape,
|
||||
@ -200,45 +308,31 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const int filter_width = SizeOfDimension(weights, 2);
|
||||
const int filter_height = SizeOfDimension(weights, 1);
|
||||
|
||||
const int stride_width = params->stride_width;
|
||||
const int stride_height = params->stride_height;
|
||||
data->padding = ComputePaddingHeightWidth(
|
||||
params->stride_height, params->stride_width, 1, height, width,
|
||||
filter_height, filter_width, params->padding);
|
||||
|
||||
const TfLitePaddingValues& padding_size =
|
||||
ComputePaddingHeightWidth(stride_height, stride_width, 1, height, width,
|
||||
filter_height, filter_width, params->padding);
|
||||
|
||||
// Currently only support float32.
|
||||
// Currently support float32 and uint8.
|
||||
switch (input->type) {
|
||||
case kTfLiteFloat32: {
|
||||
tflite::ConvParams op_params;
|
||||
op_params.padding_type = PaddingType::kSame;
|
||||
op_params.padding_values.width = padding_size.width;
|
||||
op_params.padding_values.height = padding_size.height;
|
||||
op_params.stride_width = stride_width;
|
||||
op_params.stride_height = stride_height;
|
||||
switch (kernel_type) {
|
||||
case kReference: {
|
||||
reference_ops::TransposeConv(
|
||||
op_params, GetTensorShape(input), GetTensorData<float>(input),
|
||||
GetTensorShape(weights), GetTensorData<float>(weights),
|
||||
GetTensorShape(output), GetTensorData<float>(output),
|
||||
GetTensorShape(im2col), GetTensorData<float>(im2col));
|
||||
break;
|
||||
}
|
||||
case kGenericOptimized: {
|
||||
optimized_ops::TransposeConv(
|
||||
op_params, GetTensorShape(input), GetTensorData<float>(input),
|
||||
GetTensorShape(weights), GetTensorData<float>(weights),
|
||||
GetTensorShape(output), GetTensorData<float>(output),
|
||||
GetTensorShape(im2col), GetTensorData<float>(im2col));
|
||||
break;
|
||||
}
|
||||
EvalFloat<kernel_type>(params, data, input, weights, im2col, output);
|
||||
break;
|
||||
}
|
||||
case kTfLiteUInt8: {
|
||||
// TODO(haoliang): support optimized implementation for quantized
|
||||
// TransposeConv.
|
||||
TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index*/ 0);
|
||||
if (IsDynamicTensor(scratch_buffer)) {
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
ResizeTensor(context, output_shape, scratch_buffer));
|
||||
}
|
||||
EvalQuantized(params, data, input, weights, im2col, output,
|
||||
scratch_buffer);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
context->ReportError(context, "Type %d, not currently supported.",
|
||||
input->type);
|
||||
context->ReportError(context, "Type '%s' is not currently supported.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
|
@ -35,12 +35,12 @@ namespace {
|
||||
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
class TransposeConvOpModel : public SingleOpModel {
|
||||
class BaseTransposeConvOpModel : public SingleOpModel {
|
||||
public:
|
||||
TransposeConvOpModel(TfLiteRegistration* registration,
|
||||
const TensorData& filter, const TensorData& input,
|
||||
const TensorData& output, Padding padding, int stride_w,
|
||||
int stride_h) {
|
||||
BaseTransposeConvOpModel(TfLiteRegistration* registration,
|
||||
const TensorData& filter, const TensorData& input,
|
||||
const TensorData& output, Padding padding,
|
||||
int stride_w, int stride_h) {
|
||||
// Just to be confusing, transpose_conv has an _input_ named "output_shape"
|
||||
// that sets the shape of the output tensor of the op :). It must always be
|
||||
// an int32 1D four element tensor.
|
||||
@ -63,18 +63,25 @@ class TransposeConvOpModel : public SingleOpModel {
|
||||
void SetOutputShape(std::initializer_list<int> i) {
|
||||
PopulateTensor(output_shape_, i);
|
||||
}
|
||||
|
||||
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
|
||||
|
||||
protected:
|
||||
int output_shape_;
|
||||
int filter_;
|
||||
int input_;
|
||||
int output_;
|
||||
};
|
||||
|
||||
class TransposeConvOpModel : public BaseTransposeConvOpModel {
|
||||
public:
|
||||
using BaseTransposeConvOpModel::BaseTransposeConvOpModel;
|
||||
|
||||
void SetFilter(std::initializer_list<float> f) { PopulateTensor(filter_, f); }
|
||||
void SetInput(std::initializer_list<float> data) {
|
||||
PopulateTensor(input_, data);
|
||||
}
|
||||
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
|
||||
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
|
||||
|
||||
private:
|
||||
int output_shape_;
|
||||
int filter_;
|
||||
int input_;
|
||||
int output_;
|
||||
};
|
||||
|
||||
const auto kKernelMap = new std::map<string, TfLiteRegistration*>({
|
||||
@ -97,19 +104,20 @@ class TransposeConvOpTest : public SingleOpTest {
|
||||
// [1, 1, 1, 1 ],
|
||||
// "SAME")
|
||||
TEST_P(TransposeConvOpTest, SimpleTest) {
|
||||
TransposeConvOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 3, 3, 1}},
|
||||
{TensorType_FLOAT32, {1, 4, 4, 1}},
|
||||
{TensorType_FLOAT32, {}}, Padding_SAME, 1, 1);
|
||||
m.SetOutputShape({1, 4, 4, 1});
|
||||
m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9});
|
||||
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||
m.Invoke();
|
||||
TransposeConvOpModel model(GetRegistration(),
|
||||
{TensorType_FLOAT32, {1, 3, 3, 1}},
|
||||
{TensorType_FLOAT32, {1, 4, 4, 1}},
|
||||
{TensorType_FLOAT32, {}}, Padding_SAME, 1, 1);
|
||||
model.SetOutputShape({1, 4, 4, 1});
|
||||
model.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9});
|
||||
model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(m.GetOutput(),
|
||||
EXPECT_THAT(model.GetOutput(),
|
||||
ElementsAreArray({29, 62, 83, 75, 99, 192, 237, 198, 207, 372,
|
||||
417, 330, 263, 446, 485, 365}));
|
||||
// GetOutputShape() should always be same as m.SetOutputShape(...);
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
|
||||
// GetOutputShape() should always be same as model.SetOutputShape(...);
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
|
||||
}
|
||||
|
||||
// Test case:
|
||||
@ -125,19 +133,22 @@ TEST_P(TransposeConvOpTest, SimpleTest) {
|
||||
// And filter value is derived by:
|
||||
// filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[18, 1])
|
||||
TEST_P(TransposeConvOpTest, TwoFiltersTest) {
|
||||
TransposeConvOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 3, 3, 2}},
|
||||
{TensorType_FLOAT32, {1, 4, 4, 2}},
|
||||
{TensorType_FLOAT32, {}}, Padding_SAME, 1, 1);
|
||||
m.SetOutputShape({1, 4, 4, 1});
|
||||
m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18});
|
||||
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
|
||||
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32});
|
||||
m.Invoke();
|
||||
TransposeConvOpModel model(GetRegistration(),
|
||||
{TensorType_FLOAT32, {1, 3, 3, 2}},
|
||||
{TensorType_FLOAT32, {1, 4, 4, 2}},
|
||||
{TensorType_FLOAT32, {}}, Padding_SAME, 1, 1);
|
||||
model.SetOutputShape({1, 4, 4, 1});
|
||||
model.SetFilter(
|
||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18});
|
||||
model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
|
||||
23, 24, 25, 26, 27, 28, 29, 30, 31, 32});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(m.GetOutput(),
|
||||
EXPECT_THAT(model.GetOutput(),
|
||||
ElementsAreArray({184, 412, 568, 528, 678, 1347, 1689, 1434, 1494,
|
||||
2715, 3057, 2442, 1968, 3352, 3652, 2760}));
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
|
||||
}
|
||||
|
||||
// Test case:
|
||||
@ -153,22 +164,25 @@ TEST_P(TransposeConvOpTest, TwoFiltersTest) {
|
||||
// And filter value is derived by:
|
||||
// filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[1, 18])
|
||||
TEST_P(TransposeConvOpTest, PaddingValidTest) {
|
||||
TransposeConvOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 3, 3, 2}},
|
||||
{TensorType_FLOAT32, {1, 4, 4, 2}},
|
||||
{TensorType_FLOAT32, {}}, Padding_VALID, 1, 1);
|
||||
m.SetOutputShape({1, 6, 6, 1});
|
||||
m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18});
|
||||
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
|
||||
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32});
|
||||
m.Invoke();
|
||||
TransposeConvOpModel model(GetRegistration(),
|
||||
{TensorType_FLOAT32, {1, 3, 3, 2}},
|
||||
{TensorType_FLOAT32, {1, 4, 4, 2}},
|
||||
{TensorType_FLOAT32, {}}, Padding_VALID, 1, 1);
|
||||
model.SetOutputShape({1, 6, 6, 1});
|
||||
model.SetFilter(
|
||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18});
|
||||
model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
|
||||
23, 24, 25, 26, 27, 28, 29, 30, 31, 32});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(
|
||||
m.GetOutput(),
|
||||
model.GetOutput(),
|
||||
ElementsAreArray({5, 22, 59, 101, 114, 83, 52, 184, 412,
|
||||
568, 528, 344, 237, 678, 1347, 1689, 1434, 879,
|
||||
597, 1494, 2715, 3057, 2442, 1431, 856, 1968, 3352,
|
||||
3652, 2760, 1548, 689, 1534, 2543, 2729, 2010, 1103}));
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 6, 6, 1}));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 6, 6, 1}));
|
||||
}
|
||||
|
||||
// Test case:
|
||||
@ -182,19 +196,20 @@ TEST_P(TransposeConvOpTest, PaddingValidTest) {
|
||||
// [1, 2, 2, 1 ],
|
||||
// "VALID")
|
||||
TEST_P(TransposeConvOpTest, StrideValidTest) {
|
||||
TransposeConvOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 3, 3, 1}},
|
||||
{TensorType_FLOAT32, {1, 2, 2, 1}},
|
||||
{TensorType_FLOAT32, {}}, Padding_VALID, 2, 2);
|
||||
m.SetOutputShape({1, 5, 5, 1});
|
||||
m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9});
|
||||
m.SetInput({1, 2, 3, 4});
|
||||
m.Invoke();
|
||||
TransposeConvOpModel model(GetRegistration(),
|
||||
{TensorType_FLOAT32, {1, 3, 3, 1}},
|
||||
{TensorType_FLOAT32, {1, 2, 2, 1}},
|
||||
{TensorType_FLOAT32, {}}, Padding_VALID, 2, 2);
|
||||
model.SetOutputShape({1, 5, 5, 1});
|
||||
model.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9});
|
||||
model.SetInput({1, 2, 3, 4});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(
|
||||
m.GetOutput(),
|
||||
model.GetOutput(),
|
||||
ElementsAreArray({1, 2, 5, 4, 6, 4, 5, 14, 10, 12, 10, 14, 36,
|
||||
24, 30, 12, 15, 34, 20, 24, 21, 24, 55, 32, 36}));
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 5, 5, 1}));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 5, 5, 1}));
|
||||
}
|
||||
|
||||
// Test case:
|
||||
@ -208,21 +223,23 @@ TEST_P(TransposeConvOpTest, StrideValidTest) {
|
||||
// [1, 2, 2, 1 ],
|
||||
// "VALID")
|
||||
TEST_P(TransposeConvOpTest, MultiChannelTest) {
|
||||
TransposeConvOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 3, 3, 1}},
|
||||
{TensorType_FLOAT32, {1, 2, 2, 1}},
|
||||
{TensorType_FLOAT32, {}}, Padding_VALID, 2, 2);
|
||||
m.SetOutputShape({1, 5, 5, 2});
|
||||
m.SetFilter({1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6, 8, 10, 12, 14, 16, 18});
|
||||
m.SetInput({1, 2, 3, 4});
|
||||
m.Invoke();
|
||||
TransposeConvOpModel model(GetRegistration(),
|
||||
{TensorType_FLOAT32, {2, 3, 3, 1}},
|
||||
{TensorType_FLOAT32, {1, 2, 2, 1}},
|
||||
{TensorType_FLOAT32, {}}, Padding_VALID, 2, 2);
|
||||
model.SetOutputShape({1, 5, 5, 2});
|
||||
model.SetFilter(
|
||||
{1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6, 8, 10, 12, 14, 16, 18});
|
||||
model.SetInput({1, 2, 3, 4});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(
|
||||
m.GetOutput(),
|
||||
model.GetOutput(),
|
||||
ElementsAreArray({1, 2, 3, 4, 7, 10, 6, 8, 10, 12, 7, 8, 9,
|
||||
10, 25, 28, 18, 20, 22, 24, 16, 20, 24, 28, 62, 72,
|
||||
42, 48, 54, 60, 21, 24, 27, 30, 61, 68, 36, 40, 44,
|
||||
48, 39, 42, 45, 48, 103, 110, 60, 64, 68, 72}));
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 5, 5, 2}));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 5, 5, 2}));
|
||||
}
|
||||
|
||||
// Test case:
|
||||
@ -238,18 +255,100 @@ TEST_P(TransposeConvOpTest, MultiChannelTest) {
|
||||
// And filter value is derived by:
|
||||
// filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[-1])
|
||||
TEST_P(TransposeConvOpTest, AccuracyTest) {
|
||||
TransposeConvOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 3, 3, 1}},
|
||||
{TensorType_FLOAT32, {1, 1, 2, 1}},
|
||||
{TensorType_FLOAT32, {}}, Padding_SAME, 3, 3);
|
||||
m.SetOutputShape({1, 3, 4, 1});
|
||||
m.SetFilter({9, 5, 6, 9, 8, 5, 3, 1, 4});
|
||||
m.SetInput({323, 521});
|
||||
m.Invoke();
|
||||
TransposeConvOpModel model(GetRegistration(),
|
||||
{TensorType_FLOAT32, {1, 3, 3, 1}},
|
||||
{TensorType_FLOAT32, {1, 1, 2, 1}},
|
||||
{TensorType_FLOAT32, {}}, Padding_SAME, 3, 3);
|
||||
model.SetOutputShape({1, 3, 4, 1});
|
||||
model.SetFilter({9, 5, 6, 9, 8, 5, 3, 1, 4});
|
||||
model.SetInput({323, 521});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
|
||||
{1615., 1938., 4689., 2605., 2584., 1615.,
|
||||
EXPECT_THAT(model.GetOutput(),
|
||||
ElementsAreArray(
|
||||
ArrayFloatNear({1615., 1938., 4689., 2605., 2584., 1615.,
|
||||
4689., 4168., 323., 1292., 1563., 521.})));
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 4, 1}));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 3, 4, 1}));
|
||||
}
|
||||
|
||||
class QuantizedTransposeConvOpModel : public BaseTransposeConvOpModel {
|
||||
public:
|
||||
using BaseTransposeConvOpModel::BaseTransposeConvOpModel;
|
||||
|
||||
void SetFilter(std::initializer_list<float> f) {
|
||||
QuantizeAndPopulate<uint8_t>(filter_, f);
|
||||
}
|
||||
void SetInput(std::initializer_list<float> data) {
|
||||
QuantizeAndPopulate<uint8_t>(input_, data);
|
||||
}
|
||||
std::vector<float> GetDequantizedOutput() {
|
||||
return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
|
||||
GetScale(output_), GetZeroPoint(output_));
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(TransposeConvOpTest, SimpleTestQuantized) {
|
||||
QuantizedTransposeConvOpModel model(
|
||||
GetRegistration(), {TensorType_UINT8, {1, 3, 3, 1}, -63.5, 64},
|
||||
{TensorType_UINT8, {1, 4, 4, 1}, -63.5, 64},
|
||||
{TensorType_UINT8, {}, -508, 512}, Padding_SAME, 1, 1);
|
||||
model.SetOutputShape({1, 4, 4, 1});
|
||||
model.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9});
|
||||
model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(
|
||||
model.GetDequantizedOutput(),
|
||||
ElementsAreArray(ArrayFloatNear({28, 64, 84, 76, 100, 192, 236, 200, 208,
|
||||
372, 416, 332, 264, 448, 484, 364},
|
||||
1e-5)));
|
||||
|
||||
// GetOutputShape() should always be same as model.SetOutputShape(...);
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
|
||||
}
|
||||
|
||||
TEST_P(TransposeConvOpTest, TwoFiltersTestQuantized) {
|
||||
QuantizedTransposeConvOpModel model(
|
||||
GetRegistration(), {TensorType_UINT8, {1, 3, 3, 2}, -63.5, 64},
|
||||
{TensorType_UINT8, {1, 4, 4, 2}, -63.5, 64},
|
||||
{TensorType_UINT8, {}, -4064, 4096}, Padding_SAME, 1, 1);
|
||||
model.SetOutputShape({1, 4, 4, 1});
|
||||
model.SetFilter(
|
||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18});
|
||||
model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
|
||||
23, 24, 25, 26, 27, 28, 29, 30, 31, 32});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetDequantizedOutput(),
|
||||
ElementsAreArray(ArrayFloatNear(
|
||||
{192, 416, 576, 544, 672, 1344, 1696, 1440, 1504, 2720, 3072,
|
||||
2432, 1984, 3360, 3648, 2752},
|
||||
1e-5)));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
|
||||
}
|
||||
|
||||
TEST_P(TransposeConvOpTest, PaddingValidTestQuantized) {
|
||||
QuantizedTransposeConvOpModel model(
|
||||
GetRegistration(), {TensorType_UINT8, {1, 3, 3, 2}, -63.5, 64},
|
||||
{TensorType_UINT8, {1, 4, 4, 2}, -63.5, 64},
|
||||
{TensorType_UINT8, {}, -4064, 4096}, Padding_VALID, 1, 1);
|
||||
model.SetOutputShape({1, 6, 6, 1});
|
||||
model.SetFilter(
|
||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18});
|
||||
model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
|
||||
23, 24, 25, 26, 27, 28, 29, 30, 31, 32});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetDequantizedOutput(),
|
||||
ElementsAreArray(ArrayFloatNear(
|
||||
{0, 32, 64, 96, 128, 96, 64, 192, 416,
|
||||
576, 544, 352, 224, 672, 1344, 1696, 1440, 864,
|
||||
608, 1504, 2720, 3072, 2432, 1440, 864, 1984, 3360,
|
||||
3648, 2752, 1536, 704, 1536, 2528, 2720, 2016, 1088},
|
||||
1e-5)));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 6, 6, 1}));
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
|
Loading…
Reference in New Issue
Block a user