Optimized TransposeConv implementation.
PiperOrigin-RevId: 218181572
This commit is contained in:
parent
34e3129d8e
commit
2db1f885bd
@ -305,11 +305,6 @@ def generated_test_models():
|
||||
# If you have to disable a test, please add here with a link to the appropriate
|
||||
# bug or issue.
|
||||
def generated_test_models_failing(conversion_mode):
|
||||
if not conversion_mode:
|
||||
return [
|
||||
"transpose_conv", # disabled due to b/111213074
|
||||
]
|
||||
|
||||
if conversion_mode == "toco-flex":
|
||||
# TODO(b/117328698): Fix and enable the known flex failures.
|
||||
return [
|
||||
|
@ -1177,6 +1177,7 @@ tf_cc_test(
|
||||
":builtin_ops",
|
||||
"//tensorflow/contrib/lite:framework",
|
||||
"//tensorflow/contrib/lite/kernels:test_util",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/contrib/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/contrib/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/contrib/lite/kernels/eigen_support.h"
|
||||
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
|
||||
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
|
||||
@ -32,74 +33,165 @@ namespace ops {
|
||||
namespace builtin {
|
||||
namespace transpose_conv {
|
||||
|
||||
// This file has 2 implementation of TransposeConv.
|
||||
enum KernelType {
|
||||
kReference,
|
||||
kGenericOptimized, // Neon-free
|
||||
};
|
||||
|
||||
constexpr int kOutputShapeTensor = 0;
|
||||
constexpr int kWeightsTensor = 1;
|
||||
constexpr int kDataInputTensor = 2;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
TfLiteStatus ResizeOutputShape(TfLiteContext* context,
|
||||
const TfLiteTensor* output_shape,
|
||||
TfLiteTensor* output) {
|
||||
const int kTensorNotAllocated = -1;
|
||||
|
||||
struct OpData {
|
||||
// IDs are the arbitrary identifiers used by TF Lite to identify and access
|
||||
// memory buffers.
|
||||
int im2col_id = kTensorNotAllocated;
|
||||
|
||||
// im2col is the only temporary currently tracked, therefore always index 0.
|
||||
// If more temporaries are added, they should be properly tracked.
|
||||
int32_t im2col_index = 0;
|
||||
};
|
||||
|
||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
// This is a builtin op, so we don't use the contents in 'buffer', if any.
|
||||
// Instead, we allocate a new object to use as scratch space for im2col, and
|
||||
// to carry information from Prepare() to Eval().
|
||||
auto* data = new OpData;
|
||||
eigen_support::IncrementUsageCounter(context);
|
||||
return data;
|
||||
}
|
||||
|
||||
void Free(TfLiteContext* context, void* buffer) {
|
||||
eigen_support::DecrementUsageCounter(context);
|
||||
delete reinterpret_cast<OpData*>(buffer);
|
||||
}
|
||||
|
||||
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
|
||||
const TfLiteTensor* shape_tensor,
|
||||
TfLiteTensor* output) {
|
||||
// Currently only support int32 for output shape.
|
||||
if (output_shape->type != kTfLiteInt32) {
|
||||
if (shape_tensor->type != kTfLiteInt32) {
|
||||
context->ReportError(context, "Output shape is %d, not int32.",
|
||||
shape_tensor->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
TfLiteIntArray* shape = TfLiteIntArrayCreate(NumElements(shape_tensor));
|
||||
for (int i = 0; i < shape->size; ++i) {
|
||||
shape->data[i] = GetTensorData<int32_t>(shape_tensor)[i];
|
||||
}
|
||||
|
||||
return context->ResizeTensor(context, output, shape);
|
||||
}
|
||||
|
||||
static TfLiteStatus AllocateIm2colTensorIfRequired(TfLiteContext* context,
|
||||
TfLiteNode* node) {
|
||||
OpData* data = reinterpret_cast<OpData*>(node->user_data);
|
||||
if (data->im2col_id == kTensorNotAllocated) {
|
||||
context->AddTensors(context, 1, &data->im2col_id);
|
||||
context->tensors[data->im2col_id].type = kTfLiteFloat32;
|
||||
}
|
||||
|
||||
TfLiteIntArrayFree(node->temporaries);
|
||||
node->temporaries = TfLiteIntArrayCreate(1);
|
||||
node->temporaries->data[data->im2col_index] = data->im2col_id;
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus ResizeIm2ColTensor(TfLiteContext* context,
|
||||
const TfLiteTensor* output_shape,
|
||||
const TfLiteTensor* weights,
|
||||
const TfLiteTensor* input,
|
||||
TfLiteTensor* im2col) {
|
||||
if (output_shape->type != kTfLiteInt32) {
|
||||
context->ReportError(context, "im2col shape is %d, not int32.",
|
||||
output_shape->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
const int output_dimensions = NumElements(output_shape);
|
||||
TfLiteIntArray* output_shape_array = TfLiteIntArrayCreate(output_dimensions);
|
||||
for (int i = 0; i < output_dimensions; ++i) {
|
||||
output_shape_array->data[i] = GetTensorData<int32_t>(output_shape)[i];
|
||||
}
|
||||
TF_LITE_ENSURE_EQ(context, NumElements(output_shape), 4);
|
||||
TfLiteIntArray* im2col_shape_array = TfLiteIntArrayCreate(4);
|
||||
im2col_shape_array->data[0] = output_shape->data.i32[0];
|
||||
im2col_shape_array->data[1] = output_shape->data.i32[1];
|
||||
im2col_shape_array->data[2] = output_shape->data.i32[2];
|
||||
const int input_depth = SizeOfDimension(input, 3);
|
||||
const int filter_width = SizeOfDimension(weights, 1);
|
||||
const int filter_height = SizeOfDimension(weights, 2);
|
||||
im2col_shape_array->data[3] = input_depth * filter_height * filter_width;
|
||||
|
||||
return context->ResizeTensor(context, output, output_shape_array);
|
||||
im2col->type = input->type;
|
||||
im2col->allocation_type = kTfLiteDynamic;
|
||||
return context->ResizeTensor(context, im2col, im2col_shape_array);
|
||||
}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
// Sanity checks on op
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
// Allocate Im2col Tensor
|
||||
TF_LITE_ENSURE_STATUS(AllocateIm2colTensorIfRequired(context, node));
|
||||
|
||||
// Retrieve tensors
|
||||
const TfLiteTensor* output_shape =
|
||||
GetInput(context, node, kOutputShapeTensor);
|
||||
const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor);
|
||||
const TfLiteTensor* input = GetInput(context, node, kDataInputTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
OpData* user_data = reinterpret_cast<OpData*>(node->user_data);
|
||||
TfLiteTensor* im2col =
|
||||
&context->tensors[node->temporaries->data[user_data->im2col_index]];
|
||||
|
||||
// Tensor sanity checks
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 4);
|
||||
|
||||
// Currently only supports float32.
|
||||
const TfLiteType data_type = input->type;
|
||||
TF_LITE_ENSURE(context, data_type == kTfLiteFloat32);
|
||||
TF_LITE_ENSURE_EQ(context, output->type, data_type);
|
||||
TF_LITE_ENSURE_EQ(context, weights->type, data_type);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
|
||||
TF_LITE_ENSURE_EQ(context, weights->type, kTfLiteFloat32);
|
||||
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
|
||||
// Ensure that weights and inputs have the same channel dimension.
|
||||
// Note: TOCO will reorder weights in the following format: OHWI.
|
||||
TF_LITE_ENSURE_EQ(context, SizeOfDimension(input, 3),
|
||||
SizeOfDimension(weights, 3));
|
||||
|
||||
if (!IsConstantTensor(output_shape)) {
|
||||
// Defer resizing until Eval().
|
||||
SetTensorToDynamic(output);
|
||||
return kTfLiteOk;
|
||||
SetTensorToDynamic(im2col);
|
||||
} else {
|
||||
TF_LITE_ENSURE_STATUS(ResizeOutputTensor(context, output_shape, output));
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
ResizeIm2ColTensor(context, output_shape, weights, input, im2col));
|
||||
}
|
||||
return ResizeOutputShape(context, output_shape, output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
template <KernelType kernel_type>
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
// Retrieve tensors (All should be allocated by now)
|
||||
const TfLiteTensor* output_shape =
|
||||
GetInput(context, node, kOutputShapeTensor);
|
||||
const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor);
|
||||
const TfLiteTensor* input = GetInput(context, node, kDataInputTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
|
||||
OpData* user_data = reinterpret_cast<OpData*>(node->user_data);
|
||||
TfLiteTensor* im2col =
|
||||
&context->tensors[node->temporaries->data[user_data->im2col_index]];
|
||||
const auto* params =
|
||||
reinterpret_cast<TfLiteTransposeConvParams*>(node->builtin_data);
|
||||
|
||||
// Resize any deferred dynamic tensors
|
||||
if (IsDynamicTensor(output)) {
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
ResizeOutputShape(context, output_shape, output));
|
||||
ResizeOutputTensor(context, output_shape, output));
|
||||
}
|
||||
if (IsDynamicTensor(im2col)) {
|
||||
TF_LITE_ENSURE_OK(context, ResizeIm2ColTensor(context, output_shape,
|
||||
weights, input, im2col));
|
||||
}
|
||||
|
||||
// Get height and width of the output image.
|
||||
@ -124,17 +216,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
op_params.padding_values.height = padding_size.height;
|
||||
op_params.stride_width = stride_width;
|
||||
op_params.stride_height = stride_height;
|
||||
|
||||
reference_ops::TransposeConv(
|
||||
op_params, GetTensorShape(input), GetTensorData<float>(input),
|
||||
GetTensorShape(weights), GetTensorData<float>(weights),
|
||||
GetTensorShape(output), GetTensorData<float>(output),
|
||||
// Last two args specify im2col which reference_ops ignores.
|
||||
// (Note this does not lead to a performance regression, as the
|
||||
// previous optimized version was just a copy of the reference code.)
|
||||
// TODO(b/110208176): Allocate im2col tensors and switch to
|
||||
// optimized_ops.
|
||||
GetTensorShape(output), GetTensorData<float>(output));
|
||||
switch (kernel_type) {
|
||||
case kReference: {
|
||||
reference_ops::TransposeConv(
|
||||
op_params, GetTensorShape(input), GetTensorData<float>(input),
|
||||
GetTensorShape(weights), GetTensorData<float>(weights),
|
||||
GetTensorShape(output), GetTensorData<float>(output),
|
||||
GetTensorShape(im2col), GetTensorData<float>(im2col));
|
||||
break;
|
||||
}
|
||||
case kGenericOptimized: {
|
||||
optimized_ops::TransposeConv(
|
||||
op_params, GetTensorShape(input), GetTensorData<float>(input),
|
||||
GetTensorShape(weights), GetTensorData<float>(weights),
|
||||
GetTensorShape(output), GetTensorData<float>(output),
|
||||
GetTensorShape(im2col), GetTensorData<float>(im2col));
|
||||
break;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
@ -147,12 +246,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
|
||||
} // namespace transpose_conv
|
||||
|
||||
TfLiteRegistration* Register_TRANSPOSE_CONV() {
|
||||
static TfLiteRegistration r = {nullptr, nullptr, transpose_conv::Prepare,
|
||||
transpose_conv::Eval};
|
||||
TfLiteRegistration* Register_TRANSPOSECONV_REF() {
|
||||
static TfLiteRegistration r = {
|
||||
transpose_conv::Init, transpose_conv::Free, transpose_conv::Prepare,
|
||||
transpose_conv::Eval<transpose_conv::kReference>};
|
||||
return &r;
|
||||
}
|
||||
|
||||
TfLiteRegistration* Register_TRANSPOSECONV_GENERIC_OPT() {
|
||||
static TfLiteRegistration r = {
|
||||
transpose_conv::Init, transpose_conv::Free, transpose_conv::Prepare,
|
||||
transpose_conv::Eval<transpose_conv::kGenericOptimized>};
|
||||
return &r;
|
||||
}
|
||||
|
||||
TfLiteRegistration* Register_TRANSPOSE_CONV() {
|
||||
return Register_TRANSPOSECONV_GENERIC_OPT();
|
||||
}
|
||||
|
||||
} // namespace builtin
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
||||
|
@ -14,36 +14,59 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <cstdarg>
|
||||
#include <gtest/gtest.h>
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/contrib/lite/interpreter.h"
|
||||
#include "tensorflow/contrib/lite/kernels/register.h"
|
||||
#include "tensorflow/contrib/lite/kernels/test_util.h"
|
||||
#include "tensorflow/contrib/lite/model.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
namespace ops {
|
||||
namespace builtin {
|
||||
|
||||
TfLiteRegistration* Register_TRANSPOSECONV_REF();
|
||||
TfLiteRegistration* Register_TRANSPOSECONV_GENERIC_OPT();
|
||||
|
||||
} // namespace builtin
|
||||
} // namespace ops
|
||||
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
class TransposeConvOpModel : public SingleOpModel {
|
||||
public:
|
||||
TransposeConvOpModel(std::initializer_list<int> input_shape,
|
||||
std::initializer_list<int> filter_shape, Padding padding,
|
||||
int stride_w, int stride_h) {
|
||||
output_shape_ = AddInput(TensorType_INT32);
|
||||
filter_ = AddInput(TensorType_FLOAT32);
|
||||
input_ = AddInput(TensorType_FLOAT32);
|
||||
output_ = AddOutput(TensorType_FLOAT32);
|
||||
TransposeConvOpModel(TfLiteRegistration* registration,
|
||||
const TensorData& filter, const TensorData& input,
|
||||
const TensorData& output, Padding padding, int stride_w,
|
||||
int stride_h) {
|
||||
// Just to be confusing, transpose_conv has an _input_ named "output_shape"
|
||||
// that sets the shape of the output tensor of the op :). It must always be
|
||||
// an int32 1D four element tensor.
|
||||
output_shape_ = AddInput({TensorType_INT32, {4}});
|
||||
filter_ = AddInput(filter);
|
||||
input_ = AddInput(input);
|
||||
|
||||
output_ = AddOutput(output);
|
||||
|
||||
SetBuiltinOp(
|
||||
BuiltinOperator_TRANSPOSE_CONV, BuiltinOptions_TransposeConvOptions,
|
||||
CreateTransposeConvOptions(builder_, padding, stride_w, stride_h)
|
||||
.Union());
|
||||
BuildInterpreter({{4}, filter_shape, input_shape});
|
||||
resolver_ = absl::make_unique<SingleOpResolver>(
|
||||
BuiltinOperator_TRANSPOSE_CONV, registration);
|
||||
BuildInterpreter(
|
||||
{GetShape(output_shape_), GetShape(input_), GetShape(filter_)});
|
||||
}
|
||||
|
||||
int output_shape() { return output_shape_; }
|
||||
int filter() { return filter_; }
|
||||
int input() { return input_; }
|
||||
|
||||
void SetOutputShape(std::initializer_list<int> i) {
|
||||
PopulateTensor(output_shape_, i);
|
||||
}
|
||||
void SetFilter(std::initializer_list<float> f) { PopulateTensor(filter_, f); }
|
||||
void SetInput(std::initializer_list<float> data) {
|
||||
PopulateTensor(input_, data);
|
||||
}
|
||||
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
|
||||
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
|
||||
|
||||
@ -54,6 +77,18 @@ class TransposeConvOpModel : public SingleOpModel {
|
||||
int output_;
|
||||
};
|
||||
|
||||
const auto kKernelMap = new std::map<string, TfLiteRegistration*>({
|
||||
{"Reference", ops::builtin::Register_TRANSPOSECONV_REF()},
|
||||
{"GenericOptimized", ops::builtin::Register_TRANSPOSECONV_GENERIC_OPT()},
|
||||
});
|
||||
|
||||
class TransposeConvOpTest : public SingleOpTest {
|
||||
protected:
|
||||
const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
|
||||
return *kKernelMap;
|
||||
}
|
||||
};
|
||||
|
||||
// Test case:
|
||||
// output = tf.nn.conv2d_backprop_input(
|
||||
// tf.constant([ 1, 4, 4, 1 ]),
|
||||
@ -61,17 +96,19 @@ class TransposeConvOpModel : public SingleOpModel {
|
||||
// tf.constant(np.arange(1, 17), shape=[ 1, 4, 4, 1 ], dtype=tf.float32),
|
||||
// [1, 1, 1, 1 ],
|
||||
// "SAME")
|
||||
TEST(TransposeConvOpModelTest, SimpleTest) {
|
||||
TransposeConvOpModel m({1, 4, 4, 1}, {1, 3, 3, 1}, Padding_SAME, 1, 1);
|
||||
m.PopulateTensor<int>(m.output_shape(), {1, 4, 4, 1});
|
||||
m.PopulateTensor<float>(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9});
|
||||
m.PopulateTensor<float>(
|
||||
m.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||
TEST_P(TransposeConvOpTest, SimpleTest) {
|
||||
TransposeConvOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 4, 4, 1}},
|
||||
{TensorType_FLOAT32, {1, 3, 3, 1}},
|
||||
{TensorType_FLOAT32, {}}, Padding_SAME, 1, 1);
|
||||
m.SetOutputShape({1, 4, 4, 1});
|
||||
m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9});
|
||||
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(m.GetOutput(),
|
||||
ElementsAreArray({29, 62, 83, 75, 99, 192, 237, 198, 207, 372,
|
||||
417, 330, 263, 446, 485, 365}));
|
||||
// GetOutputShape() should always be same as m.SetOutputShape(...);
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
|
||||
}
|
||||
|
||||
@ -87,15 +124,14 @@ TEST(TransposeConvOpModelTest, SimpleTest) {
|
||||
// "SAME")
|
||||
// And filter value is derived by:
|
||||
// filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[18, 1])
|
||||
TEST(TransposeConvOpModelTest, TwoFiltersTest) {
|
||||
TransposeConvOpModel m({1, 4, 4, 2}, {1, 3, 3, 2}, Padding_SAME, 1, 1);
|
||||
m.PopulateTensor<int>(m.output_shape(), {1, 4, 4, 1});
|
||||
m.PopulateTensor<float>(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
|
||||
13, 14, 15, 16, 17, 18});
|
||||
m.PopulateTensor<float>(
|
||||
m.input(),
|
||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
|
||||
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32});
|
||||
TEST_P(TransposeConvOpTest, TwoFiltersTest) {
|
||||
TransposeConvOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 4, 4, 2}},
|
||||
{TensorType_FLOAT32, {1, 3, 3, 2}},
|
||||
{TensorType_FLOAT32, {}}, Padding_SAME, 1, 1);
|
||||
m.SetOutputShape({1, 4, 4, 1});
|
||||
m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18});
|
||||
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
|
||||
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32});
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(m.GetOutput(),
|
||||
@ -116,15 +152,14 @@ TEST(TransposeConvOpModelTest, TwoFiltersTest) {
|
||||
// "VALID")
|
||||
// And filter value is derived by:
|
||||
// filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[1, 18])
|
||||
TEST(TransposeConvOpModelTest, PaddingValidTest) {
|
||||
TransposeConvOpModel m({1, 4, 4, 2}, {1, 3, 3, 2}, Padding_VALID, 1, 1);
|
||||
m.PopulateTensor<int>(m.output_shape(), {1, 6, 6, 1});
|
||||
m.PopulateTensor<float>(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
|
||||
13, 14, 15, 16, 17, 18});
|
||||
m.PopulateTensor<float>(
|
||||
m.input(),
|
||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
|
||||
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32});
|
||||
TEST_P(TransposeConvOpTest, PaddingValidTest) {
|
||||
TransposeConvOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 4, 4, 2}},
|
||||
{TensorType_FLOAT32, {1, 3, 3, 2}},
|
||||
{TensorType_FLOAT32, {}}, Padding_VALID, 1, 1);
|
||||
m.SetOutputShape({1, 6, 6, 1});
|
||||
m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18});
|
||||
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
|
||||
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32});
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(
|
||||
@ -146,11 +181,13 @@ TEST(TransposeConvOpModelTest, PaddingValidTest) {
|
||||
// tf.constant(np.arange(1, 5), shape=[ 1, 2, 2, 1 ], dtype=tf.float32),
|
||||
// [1, 2, 2, 1 ],
|
||||
// "VALID")
|
||||
TEST(TransposeConvOpModelTest, StrideValidTest) {
|
||||
TransposeConvOpModel m({1, 2, 2, 1}, {1, 3, 3, 1}, Padding_VALID, 2, 2);
|
||||
m.PopulateTensor<int>(m.output_shape(), {1, 5, 5, 1});
|
||||
m.PopulateTensor<float>(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9});
|
||||
m.PopulateTensor<float>(m.input(), {1, 2, 3, 4});
|
||||
TEST_P(TransposeConvOpTest, StrideValidTest) {
|
||||
TransposeConvOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 2, 2, 1}},
|
||||
{TensorType_FLOAT32, {1, 3, 3, 1}},
|
||||
{TensorType_FLOAT32, {}}, Padding_VALID, 2, 2);
|
||||
m.SetOutputShape({1, 5, 5, 1});
|
||||
m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9});
|
||||
m.SetInput({1, 2, 3, 4});
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(
|
||||
@ -170,12 +207,13 @@ TEST(TransposeConvOpModelTest, StrideValidTest) {
|
||||
// tf.constant(np.arange(1, 5), shape=[ 1, 2, 2, 1 ], dtype=tf.float32),
|
||||
// [1, 2, 2, 1 ],
|
||||
// "VALID")
|
||||
TEST(TransposeConvOpModelTest, MultiChannelTest) {
|
||||
TransposeConvOpModel m({1, 2, 2, 1}, {2, 3, 3, 1}, Padding_VALID, 2, 2);
|
||||
m.PopulateTensor<int>(m.output_shape(), {1, 5, 5, 2});
|
||||
m.PopulateTensor<float>(m.filter(), {1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6,
|
||||
8, 10, 12, 14, 16, 18});
|
||||
m.PopulateTensor<float>(m.input(), {1, 2, 3, 4});
|
||||
TEST_P(TransposeConvOpTest, MultiChannelTest) {
|
||||
TransposeConvOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 2, 2, 1}},
|
||||
{TensorType_FLOAT32, {2, 3, 3, 1}},
|
||||
{TensorType_FLOAT32, {}}, Padding_VALID, 2, 2);
|
||||
m.SetOutputShape({1, 5, 5, 2});
|
||||
m.SetFilter({1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6, 8, 10, 12, 14, 16, 18});
|
||||
m.SetInput({1, 2, 3, 4});
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(
|
||||
@ -199,11 +237,13 @@ TEST(TransposeConvOpModelTest, MultiChannelTest) {
|
||||
// "SAME")
|
||||
// And filter value is derived by:
|
||||
// filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[-1])
|
||||
TEST(TransposeConvOpModelTest, AccuracyTest) {
|
||||
TransposeConvOpModel m({1, 1, 2, 1}, {1, 3, 3, 1}, Padding_SAME, 3, 3);
|
||||
m.PopulateTensor<int>(m.output_shape(), {1, 3, 4, 1});
|
||||
m.PopulateTensor<float>(m.filter(), {9, 5, 6, 9, 8, 5, 3, 1, 4});
|
||||
m.PopulateTensor<float>(m.input(), {323, 521});
|
||||
TEST_P(TransposeConvOpTest, AccuracyTest) {
|
||||
TransposeConvOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 1, 2, 1}},
|
||||
{TensorType_FLOAT32, {1, 3, 3, 1}},
|
||||
{TensorType_FLOAT32, {}}, Padding_SAME, 3, 3);
|
||||
m.SetOutputShape({1, 3, 4, 1});
|
||||
m.SetFilter({9, 5, 6, 9, 8, 5, 3, 1, 4});
|
||||
m.SetInput({323, 521});
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
|
||||
@ -212,6 +252,10 @@ TEST(TransposeConvOpModelTest, AccuracyTest) {
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 4, 1}));
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
TransposeConvOpTest, TransposeConvOpTest,
|
||||
::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user