Update DepthwiseConv2D tests for XNNPACK delegate
- Extract DepthwiseConv2DTester class into a separate target - Add test cases for depthwise convolution with depth multiplier, activations, valid padding, and multi-threaded inference PiperOrigin-RevId: 307335761 Change-Id: I735dc1236fb69deb2f089761155db9ba9e0f26d3
This commit is contained in:
parent
c5b775f4e9
commit
0ca4f367ab
@ -57,6 +57,21 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "depthwise_conv_2d_tester",
|
||||
testonly = 1,
|
||||
srcs = ["depthwise_conv_2d_tester.cc"],
|
||||
hdrs = ["depthwise_conv_2d_tester.h"],
|
||||
deps = [
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite:schema_fbs_version",
|
||||
"//tensorflow/lite/kernels:builtin_ops",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"@com_google_googletest//:gtest",
|
||||
"@flatbuffers",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "pool_2d_tester",
|
||||
testonly = 1,
|
||||
@ -130,14 +145,10 @@ cc_test(
|
||||
}),
|
||||
tags = ["nomsan"], # b/145129478
|
||||
deps = [
|
||||
":depthwise_conv_2d_tester",
|
||||
":test_main",
|
||||
":xnnpack_delegate",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite:schema_fbs_version",
|
||||
"//tensorflow/lite/kernels:builtin_ops",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
":xnnpack_delegate_test_mode",
|
||||
"@com_google_googletest//:gtest",
|
||||
"@flatbuffers",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -15,300 +15,36 @@ limitations under the License.
|
||||
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
|
||||
#include "tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.h"
|
||||
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/version.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace xnnpack {
|
||||
|
||||
namespace {
|
||||
TEST(DepthwiseConv2D, 1x1) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
class DepthwiseConv2DTester {
|
||||
public:
|
||||
DepthwiseConv2DTester() = default;
|
||||
DepthwiseConv2DTester(const DepthwiseConv2DTester&) = delete;
|
||||
DepthwiseConv2DTester& operator=(const DepthwiseConv2DTester&) = delete;
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto input_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(5, 25), std::ref(rng));
|
||||
auto channel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(3, 32), std::ref(rng));
|
||||
|
||||
DepthwiseConv2DTester& BatchSize(int32_t batch_size) {
|
||||
EXPECT_GT(batch_size, 0);
|
||||
batch_size_ = batch_size;
|
||||
return *this;
|
||||
}
|
||||
|
||||
int32_t BatchSize() const { return batch_size_; }
|
||||
|
||||
DepthwiseConv2DTester& Groups(int32_t groups) {
|
||||
EXPECT_GT(groups, 0);
|
||||
groups_ = groups;
|
||||
return *this;
|
||||
}
|
||||
|
||||
int32_t Groups() const { return groups_; }
|
||||
|
||||
DepthwiseConv2DTester& DepthMultiplier(int32_t depth_multiplier) {
|
||||
EXPECT_GT(depth_multiplier, 0);
|
||||
depth_multiplier_ = depth_multiplier;
|
||||
return *this;
|
||||
}
|
||||
|
||||
int32_t DepthMultiplier() const { return depth_multiplier_; }
|
||||
|
||||
int32_t InputChannels() const { return Groups(); }
|
||||
|
||||
int32_t OutputChannels() const { return DepthMultiplier() * Groups(); }
|
||||
|
||||
DepthwiseConv2DTester& InputHeight(int32_t input_height) {
|
||||
EXPECT_GT(input_height, 0);
|
||||
input_height_ = input_height;
|
||||
return *this;
|
||||
}
|
||||
|
||||
int32_t InputHeight() const { return input_height_; }
|
||||
|
||||
DepthwiseConv2DTester& InputWidth(int32_t input_width) {
|
||||
EXPECT_GT(input_width, 0);
|
||||
input_width_ = input_width;
|
||||
return *this;
|
||||
}
|
||||
|
||||
int32_t InputWidth() const { return input_width_; }
|
||||
|
||||
int32_t OutputWidth() const {
|
||||
const int32_t output_width = (InputWidth() - 1) / StrideWidth() + 1;
|
||||
EXPECT_GT(output_width, 0);
|
||||
return output_width;
|
||||
}
|
||||
|
||||
int32_t OutputHeight() const {
|
||||
const int32_t output_height = (InputHeight() - 1) / StrideHeight() + 1;
|
||||
EXPECT_GT(output_height, 0);
|
||||
return output_height;
|
||||
}
|
||||
|
||||
DepthwiseConv2DTester& KernelHeight(int32_t kernel_height) {
|
||||
EXPECT_GT(kernel_height, 0);
|
||||
kernel_height_ = kernel_height;
|
||||
return *this;
|
||||
}
|
||||
|
||||
int32_t KernelHeight() const { return kernel_height_; }
|
||||
|
||||
DepthwiseConv2DTester& KernelWidth(int32_t kernel_width) {
|
||||
EXPECT_GT(kernel_width, 0);
|
||||
kernel_width_ = kernel_width;
|
||||
return *this;
|
||||
}
|
||||
|
||||
int32_t KernelWidth() const { return kernel_width_; }
|
||||
|
||||
DepthwiseConv2DTester& StrideHeight(int32_t stride_height) {
|
||||
EXPECT_GT(stride_height, 0);
|
||||
stride_height_ = stride_height;
|
||||
return *this;
|
||||
}
|
||||
|
||||
int32_t StrideHeight() const { return stride_height_; }
|
||||
|
||||
DepthwiseConv2DTester& StrideWidth(int32_t stride_width) {
|
||||
EXPECT_GT(stride_width, 0);
|
||||
stride_width_ = stride_width;
|
||||
return *this;
|
||||
}
|
||||
|
||||
int32_t StrideWidth() const { return stride_width_; }
|
||||
|
||||
DepthwiseConv2DTester& DilationHeight(int32_t dilation_height) {
|
||||
EXPECT_GT(dilation_height, 0);
|
||||
dilation_height_ = dilation_height;
|
||||
return *this;
|
||||
}
|
||||
|
||||
int32_t DilationHeight() const { return dilation_height_; }
|
||||
|
||||
DepthwiseConv2DTester& DilationWidth(int32_t dilation_width) {
|
||||
EXPECT_GT(dilation_width, 0);
|
||||
dilation_width_ = dilation_width;
|
||||
return *this;
|
||||
}
|
||||
|
||||
int32_t DilationWidth() const { return dilation_width_; }
|
||||
|
||||
void Test(TfLiteDelegate* delegate) const {
|
||||
ASSERT_EQ(DepthMultiplier(), 1) << "Flow does not support depth multiplier";
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);
|
||||
|
||||
std::vector<char> buffer = CreateTfLiteModel(std::ref(f32rng));
|
||||
const Model* model = GetModel(buffer.data());
|
||||
|
||||
std::unique_ptr<Interpreter> delegate_interpreter;
|
||||
ASSERT_EQ(
|
||||
InterpreterBuilder(model, ::tflite::ops::builtin::BuiltinOpResolver())(
|
||||
&delegate_interpreter),
|
||||
kTfLiteOk);
|
||||
std::unique_ptr<Interpreter> default_interpreter;
|
||||
ASSERT_EQ(
|
||||
InterpreterBuilder(model, ::tflite::ops::builtin::BuiltinOpResolver())(
|
||||
&default_interpreter),
|
||||
kTfLiteOk);
|
||||
|
||||
ASSERT_TRUE(delegate_interpreter);
|
||||
ASSERT_TRUE(default_interpreter);
|
||||
|
||||
ASSERT_EQ(delegate_interpreter->inputs().size(), 1);
|
||||
ASSERT_EQ(default_interpreter->inputs().size(), 1);
|
||||
|
||||
ASSERT_EQ(delegate_interpreter->outputs().size(), 1);
|
||||
ASSERT_EQ(default_interpreter->outputs().size(), 1);
|
||||
|
||||
ASSERT_EQ(delegate_interpreter->AllocateTensors(), kTfLiteOk);
|
||||
ASSERT_EQ(default_interpreter->AllocateTensors(), kTfLiteOk);
|
||||
|
||||
ASSERT_EQ(delegate_interpreter->ModifyGraphWithDelegate(delegate),
|
||||
kTfLiteOk);
|
||||
|
||||
float* default_input_data = default_interpreter->typed_tensor<float>(
|
||||
default_interpreter->inputs()[0]);
|
||||
std::generate(default_input_data,
|
||||
default_input_data + BatchSize() * InputChannels() *
|
||||
InputHeight() * InputWidth(),
|
||||
std::ref(f32rng));
|
||||
|
||||
float* xnnpack_input_data = delegate_interpreter->typed_tensor<float>(
|
||||
delegate_interpreter->inputs()[0]);
|
||||
std::copy(default_input_data,
|
||||
default_input_data +
|
||||
BatchSize() * InputChannels() * InputHeight() * InputWidth(),
|
||||
xnnpack_input_data);
|
||||
|
||||
default_interpreter->Invoke();
|
||||
delegate_interpreter->Invoke();
|
||||
|
||||
float* default_output_data = default_interpreter->typed_tensor<float>(
|
||||
default_interpreter->outputs()[0]);
|
||||
float* xnnpack_output_data = delegate_interpreter->typed_tensor<float>(
|
||||
delegate_interpreter->outputs()[0]);
|
||||
|
||||
for (size_t i = 0;
|
||||
i < BatchSize() * OutputChannels() * OutputHeight() * OutputWidth();
|
||||
i++) {
|
||||
ASSERT_NEAR(default_output_data[i], xnnpack_output_data[i],
|
||||
std::numeric_limits<float>::epsilon() *
|
||||
std::max(std::abs(default_output_data[i]) * 10.0f, 1.0f));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<char> CreateTfLiteModel(std::function<float()> f32rng) const {
|
||||
flatbuffers::FlatBufferBuilder builder;
|
||||
flatbuffers::Offset<OperatorCode> operator_code =
|
||||
CreateOperatorCode(builder, BuiltinOperator_DEPTHWISE_CONV_2D, 0);
|
||||
|
||||
flatbuffers::Offset<DepthwiseConv2DOptions> depthwise_conv2d_options =
|
||||
CreateDepthwiseConv2DOptions(builder, Padding_SAME, StrideWidth(),
|
||||
StrideHeight(), DepthMultiplier(),
|
||||
ActivationFunctionType_NONE,
|
||||
DilationWidth(), DilationHeight());
|
||||
|
||||
std::vector<float> filter_data(KernelHeight() * KernelWidth() *
|
||||
OutputChannels());
|
||||
std::vector<float> bias_data(OutputChannels());
|
||||
|
||||
std::generate(filter_data.begin(), filter_data.end(), f32rng);
|
||||
std::generate(bias_data.begin(), bias_data.end(), f32rng);
|
||||
|
||||
flatbuffers::Offset<Buffer> buffers[3] = {
|
||||
CreateBuffer(builder, builder.CreateVector({})),
|
||||
CreateBuffer(builder,
|
||||
builder.CreateVector(
|
||||
reinterpret_cast<const uint8_t*>(filter_data.data()),
|
||||
sizeof(float) * filter_data.size())),
|
||||
CreateBuffer(builder,
|
||||
builder.CreateVector(
|
||||
reinterpret_cast<const uint8_t*>(bias_data.data()),
|
||||
sizeof(float) * bias_data.size())),
|
||||
};
|
||||
|
||||
const int32_t input_shape[4] = {BatchSize(), InputHeight(), InputWidth(),
|
||||
InputChannels()};
|
||||
const int32_t output_shape[4] = {BatchSize(), OutputHeight(), OutputWidth(),
|
||||
OutputChannels()};
|
||||
const int32_t filter_shape[4] = {1, KernelHeight(), KernelWidth(),
|
||||
OutputChannels()};
|
||||
const int32_t bias_shape[1] = {OutputChannels()};
|
||||
|
||||
flatbuffers::Offset<Tensor> tensors[4] = {
|
||||
CreateTensor(builder, builder.CreateVector<int32_t>(input_shape, 4),
|
||||
TensorType_FLOAT32, /*buffer=*/0,
|
||||
builder.CreateString("X")),
|
||||
CreateTensor(builder, builder.CreateVector<int32_t>(filter_shape, 4),
|
||||
TensorType_FLOAT32, /*buffer=*/1,
|
||||
builder.CreateString("W")),
|
||||
CreateTensor(builder, builder.CreateVector<int32_t>(bias_shape, 1),
|
||||
TensorType_FLOAT32, /*buffer=*/2,
|
||||
builder.CreateString("b")),
|
||||
CreateTensor(builder, builder.CreateVector<int32_t>(output_shape, 4),
|
||||
TensorType_FLOAT32, /*buffer=*/0,
|
||||
builder.CreateString("Y")),
|
||||
};
|
||||
|
||||
const int32_t op_inputs[3] = {0, 1, 2};
|
||||
const int32_t op_outputs[1] = {3};
|
||||
|
||||
flatbuffers::Offset<Operator> op = CreateOperator(
|
||||
builder, /*opcode_index=*/0,
|
||||
builder.CreateVector<int32_t>(op_inputs, 3),
|
||||
builder.CreateVector<int32_t>(op_outputs, 1),
|
||||
BuiltinOptions_DepthwiseConv2DOptions, depthwise_conv2d_options.Union(),
|
||||
/*custom_options=*/0, CustomOptionsFormat_FLEXBUFFERS);
|
||||
|
||||
int32_t subgraph_inputs[1] = {0};
|
||||
int32_t subgraph_outputs[1] = {3};
|
||||
flatbuffers::Offset<SubGraph> subgraph =
|
||||
CreateSubGraph(builder, builder.CreateVector(tensors, 4),
|
||||
builder.CreateVector<int32_t>(subgraph_inputs, 1),
|
||||
builder.CreateVector<int32_t>(subgraph_outputs, 1),
|
||||
builder.CreateVector(&op, 1), /*name=*/0);
|
||||
|
||||
flatbuffers::Offset<flatbuffers::String> description =
|
||||
builder.CreateString("DepthwiseConv2D model");
|
||||
|
||||
flatbuffers::Offset<Model> model_buffer = CreateModel(
|
||||
builder, TFLITE_SCHEMA_VERSION, builder.CreateVector(&operator_code, 1),
|
||||
builder.CreateVector(&subgraph, 1), description,
|
||||
builder.CreateVector(buffers, 3));
|
||||
|
||||
builder.Finish(model_buffer);
|
||||
|
||||
return std::vector<char>(builder.GetBufferPointer(),
|
||||
builder.GetBufferPointer() + builder.GetSize());
|
||||
}
|
||||
|
||||
int32_t batch_size_ = 1;
|
||||
int32_t groups_ = 1;
|
||||
int32_t depth_multiplier_ = 1;
|
||||
int32_t input_height_ = 1;
|
||||
int32_t input_width_ = 1;
|
||||
int32_t kernel_height_ = 1;
|
||||
int32_t kernel_width_ = 1;
|
||||
int32_t stride_height_ = 1;
|
||||
int32_t stride_width_ = 1;
|
||||
int32_t dilation_height_ = 1;
|
||||
int32_t dilation_width_ = 1;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
DepthwiseConv2DTester()
|
||||
.InputHeight(input_rng())
|
||||
.InputWidth(input_rng())
|
||||
.InputChannels(channel_rng())
|
||||
.KernelHeight(1)
|
||||
.KernelWidth(1)
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
TEST(DepthwiseConv2D, 2x2) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
@ -319,15 +55,16 @@ TEST(DepthwiseConv2D, 2x2) {
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto input_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(5, 25), std::ref(rng));
|
||||
auto groups_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 32), std::ref(rng));
|
||||
auto channel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(3, 32), std::ref(rng));
|
||||
|
||||
DepthwiseConv2DTester()
|
||||
.InputHeight(input_rng())
|
||||
.InputWidth(input_rng())
|
||||
.Groups(groups_rng())
|
||||
.InputChannels(channel_rng())
|
||||
.KernelHeight(2)
|
||||
.KernelWidth(2)
|
||||
.SamePadding()
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
@ -340,19 +77,20 @@ TEST(DepthwiseConv2D, 3x3) {
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto input_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(5, 25), std::ref(rng));
|
||||
auto groups_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 32), std::ref(rng));
|
||||
auto channel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(3, 32), std::ref(rng));
|
||||
|
||||
DepthwiseConv2DTester()
|
||||
.InputHeight(input_rng())
|
||||
.InputWidth(input_rng())
|
||||
.Groups(groups_rng())
|
||||
.InputChannels(channel_rng())
|
||||
.KernelHeight(3)
|
||||
.KernelWidth(3)
|
||||
.SamePadding()
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
TEST(DepthwiseConv2D, SmallKernel) {
|
||||
TEST(DepthwiseConv2D, 3x3Stride2) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
@ -360,41 +98,460 @@ TEST(DepthwiseConv2D, SmallKernel) {
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto input_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
|
||||
auto kernel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 7), std::ref(rng));
|
||||
auto groups_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 32), std::ref(rng));
|
||||
std::bind(std::uniform_int_distribution<int32_t>(5, 25), std::ref(rng));
|
||||
auto channel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(3, 32), std::ref(rng));
|
||||
|
||||
DepthwiseConv2DTester()
|
||||
.InputHeight(input_rng())
|
||||
.InputWidth(input_rng())
|
||||
.Groups(groups_rng())
|
||||
.KernelHeight(kernel_rng())
|
||||
.KernelWidth(kernel_rng())
|
||||
.InputChannels(channel_rng())
|
||||
.KernelHeight(3)
|
||||
.KernelWidth(3)
|
||||
.StrideHeight(2)
|
||||
.StrideWidth(2)
|
||||
.SamePadding()
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
TEST(DepthwiseConv2D, Stride) {
|
||||
TEST(DepthwiseConv2D, 5x5) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto input_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(5, 25), std::ref(rng));
|
||||
auto channel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(3, 32), std::ref(rng));
|
||||
|
||||
DepthwiseConv2DTester()
|
||||
.InputHeight(input_rng())
|
||||
.InputWidth(input_rng())
|
||||
.InputChannels(channel_rng())
|
||||
.KernelHeight(3)
|
||||
.KernelWidth(3)
|
||||
.SamePadding()
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
TEST(DepthwiseConv2D, 5x5Stride2) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto input_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(5, 25), std::ref(rng));
|
||||
auto channel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(3, 32), std::ref(rng));
|
||||
|
||||
DepthwiseConv2DTester()
|
||||
.InputHeight(input_rng())
|
||||
.InputWidth(input_rng())
|
||||
.InputChannels(channel_rng())
|
||||
.KernelHeight(3)
|
||||
.KernelWidth(3)
|
||||
.StrideHeight(2)
|
||||
.StrideWidth(2)
|
||||
.SamePadding()
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
TEST(DepthwiseConv2D, SmallKernelWithSamePadding) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto batch_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
|
||||
auto input_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
|
||||
auto kernel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 7), std::ref(rng));
|
||||
auto channel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(3, 32), std::ref(rng));
|
||||
|
||||
DepthwiseConv2DTester()
|
||||
.BatchSize(batch_rng())
|
||||
.InputHeight(input_rng())
|
||||
.InputWidth(input_rng())
|
||||
.InputChannels(channel_rng())
|
||||
.KernelHeight(kernel_rng())
|
||||
.KernelWidth(kernel_rng())
|
||||
.SamePadding()
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
TEST(DepthwiseConv2D, SmallKernelWithValidPadding) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto batch_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
|
||||
auto input_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
|
||||
auto kernel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 7), std::ref(rng));
|
||||
auto channel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(3, 32), std::ref(rng));
|
||||
|
||||
DepthwiseConv2DTester()
|
||||
.BatchSize(batch_rng())
|
||||
.InputHeight(input_rng())
|
||||
.InputWidth(input_rng())
|
||||
.InputChannels(channel_rng())
|
||||
.KernelHeight(kernel_rng())
|
||||
.KernelWidth(kernel_rng())
|
||||
.ValidPadding()
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
TEST(DepthwiseConv2D, StrideWithSamePadding) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto batch_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
|
||||
auto input_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
|
||||
auto kernel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(3, 5), std::ref(rng));
|
||||
auto stride_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
|
||||
auto groups_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 32), std::ref(rng));
|
||||
auto channel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(3, 32), std::ref(rng));
|
||||
|
||||
DepthwiseConv2DTester()
|
||||
.BatchSize(batch_rng())
|
||||
.InputHeight(input_rng())
|
||||
.InputWidth(input_rng())
|
||||
.Groups(groups_rng())
|
||||
.InputChannels(channel_rng())
|
||||
.KernelHeight(kernel_rng())
|
||||
.KernelWidth(kernel_rng())
|
||||
.StrideHeight(stride_rng())
|
||||
.StrideWidth(stride_rng())
|
||||
.SamePadding()
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
TEST(DepthwiseConv2D, StrideWithValidPadding) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto batch_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
|
||||
auto input_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
|
||||
auto kernel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(3, 5), std::ref(rng));
|
||||
auto stride_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
|
||||
auto channel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(3, 32), std::ref(rng));
|
||||
|
||||
DepthwiseConv2DTester()
|
||||
.BatchSize(batch_rng())
|
||||
.InputHeight(input_rng())
|
||||
.InputWidth(input_rng())
|
||||
.InputChannels(channel_rng())
|
||||
.KernelHeight(kernel_rng())
|
||||
.KernelWidth(kernel_rng())
|
||||
.StrideHeight(stride_rng())
|
||||
.StrideWidth(stride_rng())
|
||||
.ValidPadding()
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
TEST(DepthwiseConv2D, DilationWithSamePadding) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto batch_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
|
||||
auto input_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
|
||||
auto kernel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
|
||||
auto dilation_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
|
||||
auto channel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(3, 32), std::ref(rng));
|
||||
|
||||
DepthwiseConv2DTester()
|
||||
.BatchSize(batch_rng())
|
||||
.InputHeight(input_rng())
|
||||
.InputWidth(input_rng())
|
||||
.InputChannels(channel_rng())
|
||||
.KernelHeight(kernel_rng())
|
||||
.KernelWidth(kernel_rng())
|
||||
.DilationHeight(dilation_rng())
|
||||
.DilationWidth(dilation_rng())
|
||||
.SamePadding()
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
TEST(DepthwiseConv2D, DilationWithValidPadding) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto batch_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
|
||||
auto input_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
|
||||
auto kernel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
|
||||
auto dilation_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
|
||||
auto channel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(3, 32), std::ref(rng));
|
||||
|
||||
DepthwiseConv2DTester()
|
||||
.BatchSize(batch_rng())
|
||||
.InputHeight(input_rng())
|
||||
.InputWidth(input_rng())
|
||||
.InputChannels(channel_rng())
|
||||
.KernelHeight(kernel_rng())
|
||||
.KernelWidth(kernel_rng())
|
||||
.DilationHeight(dilation_rng())
|
||||
.DilationWidth(dilation_rng())
|
||||
.ValidPadding()
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
TEST(DepthwiseConv2D, DepthMultiplier) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto batch_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
|
||||
auto input_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
|
||||
auto kernel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(3, 5), std::ref(rng));
|
||||
auto stride_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
|
||||
auto channel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(3, 32), std::ref(rng));
|
||||
auto multiplier_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 8), std::ref(rng));
|
||||
|
||||
DepthwiseConv2DTester()
|
||||
.BatchSize(batch_rng())
|
||||
.InputHeight(input_rng())
|
||||
.InputWidth(input_rng())
|
||||
.InputChannels(channel_rng())
|
||||
.KernelHeight(kernel_rng())
|
||||
.KernelWidth(kernel_rng())
|
||||
.StrideHeight(stride_rng())
|
||||
.StrideWidth(stride_rng())
|
||||
.DepthMultiplier(multiplier_rng())
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
TEST(DepthwiseConv2D, ReluActivation) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto batch_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
|
||||
auto input_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
|
||||
auto kernel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(3, 5), std::ref(rng));
|
||||
auto stride_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
|
||||
auto channel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(3, 32), std::ref(rng));
|
||||
|
||||
DepthwiseConv2DTester()
|
||||
.BatchSize(batch_rng())
|
||||
.InputHeight(input_rng())
|
||||
.InputWidth(input_rng())
|
||||
.InputChannels(channel_rng())
|
||||
.KernelHeight(kernel_rng())
|
||||
.KernelWidth(kernel_rng())
|
||||
.StrideHeight(stride_rng())
|
||||
.StrideWidth(stride_rng())
|
||||
.ReluActivation()
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
TEST(DepthwiseConv2D, Relu6Activation) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto batch_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
|
||||
auto input_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
|
||||
auto kernel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(3, 5), std::ref(rng));
|
||||
auto stride_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
|
||||
auto channel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(3, 32), std::ref(rng));
|
||||
|
||||
DepthwiseConv2DTester()
|
||||
.BatchSize(batch_rng())
|
||||
.InputHeight(input_rng())
|
||||
.InputWidth(input_rng())
|
||||
.InputChannels(channel_rng())
|
||||
.KernelHeight(kernel_rng())
|
||||
.KernelWidth(kernel_rng())
|
||||
.StrideHeight(stride_rng())
|
||||
.StrideWidth(stride_rng())
|
||||
.Relu6Activation()
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
TEST(DepthwiseConv2D, ReluMinus1To1Activation) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto batch_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
|
||||
auto input_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
|
||||
auto kernel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(3, 5), std::ref(rng));
|
||||
auto stride_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
|
||||
auto channel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(3, 32), std::ref(rng));
|
||||
|
||||
DepthwiseConv2DTester()
|
||||
.BatchSize(batch_rng())
|
||||
.InputHeight(input_rng())
|
||||
.InputWidth(input_rng())
|
||||
.InputChannels(channel_rng())
|
||||
.KernelHeight(kernel_rng())
|
||||
.KernelWidth(kernel_rng())
|
||||
.StrideHeight(stride_rng())
|
||||
.StrideWidth(stride_rng())
|
||||
.ReluMinus1To1Activation()
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
TEST(DepthwiseConv2D, DISABLED_TanhActivation) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto batch_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
|
||||
auto input_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
|
||||
auto kernel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(3, 5), std::ref(rng));
|
||||
auto stride_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
|
||||
auto channel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(3, 32), std::ref(rng));
|
||||
|
||||
DepthwiseConv2DTester()
|
||||
.BatchSize(batch_rng())
|
||||
.InputHeight(input_rng())
|
||||
.InputWidth(input_rng())
|
||||
.InputChannels(channel_rng())
|
||||
.KernelHeight(kernel_rng())
|
||||
.KernelWidth(kernel_rng())
|
||||
.StrideHeight(stride_rng())
|
||||
.StrideWidth(stride_rng())
|
||||
.TanhActivation()
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
TEST(DepthwiseConv2D, DISABLED_SignBitActivation) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto batch_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
|
||||
auto input_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
|
||||
auto kernel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(3, 5), std::ref(rng));
|
||||
auto stride_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
|
||||
auto channel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(3, 32), std::ref(rng));
|
||||
|
||||
DepthwiseConv2DTester()
|
||||
.BatchSize(batch_rng())
|
||||
.InputHeight(input_rng())
|
||||
.InputWidth(input_rng())
|
||||
.InputChannels(channel_rng())
|
||||
.KernelHeight(kernel_rng())
|
||||
.KernelWidth(kernel_rng())
|
||||
.StrideHeight(stride_rng())
|
||||
.StrideWidth(stride_rng())
|
||||
.SignBitActivation()
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
TEST(DepthwiseConv2D, MultiThreading) {
|
||||
TfLiteXNNPackDelegateOptions delegate_options =
|
||||
TfLiteXNNPackDelegateOptionsDefault();
|
||||
delegate_options.num_threads = 2;
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(&delegate_options),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto batch_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
|
||||
auto input_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
|
||||
auto kernel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(3, 5), std::ref(rng));
|
||||
auto stride_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
|
||||
auto channel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(3, 32), std::ref(rng));
|
||||
|
||||
DepthwiseConv2DTester()
|
||||
.BatchSize(batch_rng())
|
||||
.InputHeight(input_rng())
|
||||
.InputWidth(input_rng())
|
||||
.InputChannels(channel_rng())
|
||||
.KernelHeight(kernel_rng())
|
||||
.KernelWidth(kernel_rng())
|
||||
.StrideHeight(stride_rng())
|
||||
@ -402,32 +559,5 @@ TEST(DepthwiseConv2D, Stride) {
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
TEST(DepthwiseConv2D, Dilation) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto input_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
|
||||
auto kernel_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
|
||||
auto dilation_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
|
||||
auto group_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(3, 32), std::ref(rng));
|
||||
|
||||
DepthwiseConv2DTester()
|
||||
.InputHeight(input_rng())
|
||||
.InputWidth(input_rng())
|
||||
.Groups(group_rng())
|
||||
.KernelHeight(kernel_rng())
|
||||
.KernelWidth(kernel_rng())
|
||||
.DilationHeight(dilation_rng())
|
||||
.DilationWidth(dilation_rng())
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
} // namespace xnnpack
|
||||
} // namespace tflite
|
||||
|
222
tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.cc
Normal file
222
tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.cc
Normal file
@ -0,0 +1,222 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.h"
|
||||
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/version.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace xnnpack {
|
||||
|
||||
void DepthwiseConv2DTester::Test(TfLiteDelegate* delegate) const {
|
||||
std::vector<char> buffer = CreateTfLiteModel();
|
||||
const Model* model = GetModel(buffer.data());
|
||||
|
||||
std::unique_ptr<Interpreter> delegate_interpreter;
|
||||
ASSERT_EQ(
|
||||
InterpreterBuilder(model, ::tflite::ops::builtin::BuiltinOpResolver())(
|
||||
&delegate_interpreter),
|
||||
kTfLiteOk);
|
||||
std::unique_ptr<Interpreter> default_interpreter;
|
||||
ASSERT_EQ(
|
||||
InterpreterBuilder(model, ::tflite::ops::builtin::BuiltinOpResolver())(
|
||||
&default_interpreter),
|
||||
kTfLiteOk);
|
||||
|
||||
ASSERT_TRUE(delegate_interpreter);
|
||||
ASSERT_TRUE(default_interpreter);
|
||||
|
||||
ASSERT_EQ(delegate_interpreter->inputs().size(), 1);
|
||||
ASSERT_EQ(default_interpreter->inputs().size(), 1);
|
||||
|
||||
ASSERT_EQ(delegate_interpreter->outputs().size(), 1);
|
||||
ASSERT_EQ(default_interpreter->outputs().size(), 1);
|
||||
|
||||
ASSERT_EQ(delegate_interpreter->AllocateTensors(), kTfLiteOk);
|
||||
ASSERT_EQ(default_interpreter->AllocateTensors(), kTfLiteOk);
|
||||
|
||||
ASSERT_EQ(delegate_interpreter->ModifyGraphWithDelegate(delegate), kTfLiteOk);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto input_rng =
|
||||
std::bind(std::uniform_real_distribution<float>(), std::ref(rng));
|
||||
float* default_input_data = default_interpreter->typed_tensor<float>(
|
||||
default_interpreter->inputs()[0]);
|
||||
std::generate(default_input_data,
|
||||
default_input_data + BatchSize() * InputHeight() *
|
||||
InputWidth() * InputChannels(),
|
||||
input_rng);
|
||||
|
||||
float* delegate_input_data = delegate_interpreter->typed_tensor<float>(
|
||||
delegate_interpreter->inputs()[0]);
|
||||
std::copy(default_input_data,
|
||||
default_input_data +
|
||||
BatchSize() * InputHeight() * InputWidth() * InputChannels(),
|
||||
delegate_input_data);
|
||||
|
||||
ASSERT_EQ(default_interpreter->Invoke(), kTfLiteOk);
|
||||
ASSERT_EQ(delegate_interpreter->Invoke(), kTfLiteOk);
|
||||
|
||||
float* default_output_data = default_interpreter->typed_tensor<float>(
|
||||
default_interpreter->outputs()[0]);
|
||||
float* delegate_output_data = delegate_interpreter->typed_tensor<float>(
|
||||
delegate_interpreter->outputs()[0]);
|
||||
|
||||
for (int32_t i = 0; i < BatchSize(); i++) {
|
||||
for (int32_t y = 0; y < OutputHeight(); y++) {
|
||||
for (int32_t x = 0; x < OutputWidth(); x++) {
|
||||
for (int32_t c = 0; c < OutputChannels(); c++) {
|
||||
const int32_t index = ((i * OutputHeight() + y) * OutputWidth() + x) *
|
||||
OutputChannels() +
|
||||
c;
|
||||
ASSERT_NEAR(default_output_data[index], delegate_output_data[index],
|
||||
std::abs(default_output_data[index]) * 3.0e-6f)
|
||||
<< "batch " << i << " / " << BatchSize() << ", y position " << y
|
||||
<< " / " << OutputHeight() << ", x position " << x << " / "
|
||||
<< OutputWidth() << ", channel " << c << " / "
|
||||
<< OutputChannels();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<char> DepthwiseConv2DTester::CreateTfLiteModel() const {
|
||||
flatbuffers::FlatBufferBuilder builder;
|
||||
flatbuffers::Offset<OperatorCode> operator_code =
|
||||
CreateOperatorCode(builder, BuiltinOperator_DEPTHWISE_CONV_2D);
|
||||
|
||||
flatbuffers::Offset<DepthwiseConv2DOptions> depthwise_conv2d_options =
|
||||
CreateDepthwiseConv2DOptions(
|
||||
builder, Padding(), StrideWidth(), StrideHeight(), DepthMultiplier(),
|
||||
Activation(), DilationWidth(), DilationHeight());
|
||||
|
||||
std::vector<float> filter_data(KernelHeight() * KernelWidth() *
|
||||
OutputChannels());
|
||||
std::vector<float> bias_data(OutputChannels());
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto range_rng = std::bind(
|
||||
std::uniform_real_distribution<float>(-25.0f, 25.0f), std::ref(rng));
|
||||
for (int32_t ic = 0; ic < InputChannels(); ic++) {
|
||||
// Use the same range of all-positive or all-negative values to generate
|
||||
// all pixels within the same batch index & channel, but different ranges
|
||||
// for different channels or batches. This ensures that no catastrophic
|
||||
// cancellation occur, but test covers both positive and negative inputs.
|
||||
const float range = range_rng();
|
||||
auto value_rng =
|
||||
std::bind(std::uniform_real_distribution<float>(std::min(range, 0.0f),
|
||||
std::max(range, 0.0f)),
|
||||
std::ref(rng));
|
||||
for (int32_t m = 0; m < DepthMultiplier(); m++) {
|
||||
const int32_t oc = ic * DepthMultiplier() + m;
|
||||
bias_data[oc] = value_rng();
|
||||
for (int32_t y = 0; y < KernelHeight(); y++) {
|
||||
for (int32_t x = 0; x < KernelWidth(); x++) {
|
||||
const int32_t index = (y * KernelWidth() + x) * OutputChannels() + oc;
|
||||
filter_data[index] = value_rng();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const std::array<flatbuffers::Offset<tflite::Buffer>, 3> buffers{{
|
||||
CreateBuffer(builder, builder.CreateVector({})),
|
||||
CreateBuffer(builder,
|
||||
builder.CreateVector(
|
||||
reinterpret_cast<const uint8_t*>(filter_data.data()),
|
||||
sizeof(float) * filter_data.size())),
|
||||
CreateBuffer(builder,
|
||||
builder.CreateVector(
|
||||
reinterpret_cast<const uint8_t*>(bias_data.data()),
|
||||
sizeof(float) * bias_data.size())),
|
||||
}};
|
||||
|
||||
const std::array<int32_t, 4> input_shape{
|
||||
{BatchSize(), InputHeight(), InputWidth(), InputChannels()}};
|
||||
const std::array<int32_t, 4> output_shape{
|
||||
{BatchSize(), OutputHeight(), OutputWidth(), OutputChannels()}};
|
||||
const std::array<int32_t, 4> filter_shape{
|
||||
{1, KernelHeight(), KernelWidth(), OutputChannels()}};
|
||||
const std::array<int32_t, 1> bias_shape{{OutputChannels()}};
|
||||
|
||||
const std::array<flatbuffers::Offset<tflite::Tensor>, 4> tensors{{
|
||||
CreateTensor(
|
||||
builder,
|
||||
builder.CreateVector<int32_t>(input_shape.data(), input_shape.size()),
|
||||
TensorType_FLOAT32),
|
||||
CreateTensor(builder,
|
||||
builder.CreateVector<int32_t>(filter_shape.data(),
|
||||
filter_shape.size()),
|
||||
TensorType_FLOAT32, /*buffer=*/1),
|
||||
CreateTensor(
|
||||
builder,
|
||||
builder.CreateVector<int32_t>(bias_shape.data(), bias_shape.size()),
|
||||
TensorType_FLOAT32, /*buffer=*/2),
|
||||
CreateTensor(builder,
|
||||
builder.CreateVector<int32_t>(output_shape.data(),
|
||||
output_shape.size()),
|
||||
TensorType_FLOAT32),
|
||||
}};
|
||||
|
||||
const std::array<int32_t, 3> op_inputs{{0, 1, 2}};
|
||||
const std::array<int32_t, 1> op_outputs{{3}};
|
||||
|
||||
flatbuffers::Offset<tflite::Operator> op = CreateOperator(
|
||||
builder, /*opcode_index=*/0,
|
||||
builder.CreateVector<int32_t>(op_inputs.data(), op_inputs.size()),
|
||||
builder.CreateVector<int32_t>(op_outputs.data(), op_outputs.size()),
|
||||
BuiltinOptions_DepthwiseConv2DOptions, depthwise_conv2d_options.Union());
|
||||
|
||||
const std::array<int32_t, 1> subgraph_inputs{{0}};
|
||||
const std::array<int32_t, 1> subgraph_outputs{{3}};
|
||||
flatbuffers::Offset<SubGraph> subgraph = CreateSubGraph(
|
||||
builder, builder.CreateVector(tensors.data(), tensors.size()),
|
||||
builder.CreateVector<int32_t>(subgraph_inputs.data(),
|
||||
subgraph_inputs.size()),
|
||||
builder.CreateVector<int32_t>(subgraph_outputs.data(),
|
||||
subgraph_outputs.size()),
|
||||
builder.CreateVector(&op, 1));
|
||||
|
||||
flatbuffers::Offset<flatbuffers::String> description =
|
||||
builder.CreateString("DepthwiseConv2D model");
|
||||
|
||||
flatbuffers::Offset<Model> model_buffer = CreateModel(
|
||||
builder, TFLITE_SCHEMA_VERSION, builder.CreateVector(&operator_code, 1),
|
||||
builder.CreateVector(&subgraph, 1), description,
|
||||
builder.CreateVector(buffers.data(), buffers.size()));
|
||||
|
||||
builder.Finish(model_buffer);
|
||||
|
||||
return std::vector<char>(builder.GetBufferPointer(),
|
||||
builder.GetBufferPointer() + builder.GetSize());
|
||||
}
|
||||
|
||||
} // namespace xnnpack
|
||||
} // namespace tflite
|
226
tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.h
Normal file
226
tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.h
Normal file
@ -0,0 +1,226 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_DELEGATES_XNNPACK_DEPTHWISE_CONV_2D_TESTER_H_
|
||||
#define TENSORFLOW_LITE_DELEGATES_XNNPACK_DEPTHWISE_CONV_2D_TESTER_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/version.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace xnnpack {
|
||||
|
||||
class DepthwiseConv2DTester {
|
||||
public:
|
||||
DepthwiseConv2DTester() = default;
|
||||
DepthwiseConv2DTester(const DepthwiseConv2DTester&) = delete;
|
||||
DepthwiseConv2DTester& operator=(const DepthwiseConv2DTester&) = delete;
|
||||
|
||||
inline DepthwiseConv2DTester& BatchSize(int32_t batch_size) {
|
||||
EXPECT_GT(batch_size, 0);
|
||||
batch_size_ = batch_size;
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline int32_t BatchSize() const { return batch_size_; }
|
||||
|
||||
inline DepthwiseConv2DTester& InputChannels(int32_t input_channels) {
|
||||
EXPECT_GT(input_channels, 0);
|
||||
input_channels_ = input_channels;
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline int32_t InputChannels() const { return input_channels_; }
|
||||
|
||||
inline DepthwiseConv2DTester& DepthMultiplier(int32_t depth_multiplier) {
|
||||
EXPECT_GT(depth_multiplier, 0);
|
||||
depth_multiplier_ = depth_multiplier;
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline int32_t DepthMultiplier() const { return depth_multiplier_; }
|
||||
|
||||
inline int32_t OutputChannels() const {
|
||||
return DepthMultiplier() * InputChannels();
|
||||
}
|
||||
|
||||
inline DepthwiseConv2DTester& InputHeight(int32_t input_height) {
|
||||
EXPECT_GT(input_height, 0);
|
||||
input_height_ = input_height;
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline int32_t InputHeight() const { return input_height_; }
|
||||
|
||||
inline DepthwiseConv2DTester& InputWidth(int32_t input_width) {
|
||||
EXPECT_GT(input_width, 0);
|
||||
input_width_ = input_width;
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline int32_t InputWidth() const { return input_width_; }
|
||||
|
||||
inline int32_t OutputWidth() const {
|
||||
if (Padding() == ::tflite::Padding_SAME) {
|
||||
EXPECT_GE(InputWidth(), 1);
|
||||
return (InputWidth() - 1) / StrideWidth() + 1;
|
||||
} else {
|
||||
EXPECT_GE(InputWidth(), DilatedKernelWidth());
|
||||
return 1 + (InputWidth() - DilatedKernelWidth()) / StrideWidth();
|
||||
}
|
||||
}
|
||||
|
||||
inline int32_t OutputHeight() const {
|
||||
if (Padding() == ::tflite::Padding_SAME) {
|
||||
EXPECT_GE(InputHeight(), 1);
|
||||
return (InputHeight() - 1) / StrideHeight() + 1;
|
||||
} else {
|
||||
EXPECT_GE(InputHeight(), DilatedKernelHeight());
|
||||
return 1 + (InputHeight() - DilatedKernelHeight()) / StrideHeight();
|
||||
}
|
||||
}
|
||||
|
||||
inline DepthwiseConv2DTester& KernelHeight(int32_t kernel_height) {
|
||||
EXPECT_GT(kernel_height, 0);
|
||||
kernel_height_ = kernel_height;
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline int32_t KernelHeight() const { return kernel_height_; }
|
||||
|
||||
inline DepthwiseConv2DTester& KernelWidth(int32_t kernel_width) {
|
||||
EXPECT_GT(kernel_width, 0);
|
||||
kernel_width_ = kernel_width;
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline int32_t KernelWidth() const { return kernel_width_; }
|
||||
|
||||
inline DepthwiseConv2DTester& StrideHeight(int32_t stride_height) {
|
||||
EXPECT_GT(stride_height, 0);
|
||||
stride_height_ = stride_height;
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline int32_t StrideHeight() const { return stride_height_; }
|
||||
|
||||
inline DepthwiseConv2DTester& StrideWidth(int32_t stride_width) {
|
||||
EXPECT_GT(stride_width, 0);
|
||||
stride_width_ = stride_width;
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline int32_t StrideWidth() const { return stride_width_; }
|
||||
|
||||
inline DepthwiseConv2DTester& DilationHeight(int32_t dilation_height) {
|
||||
EXPECT_GT(dilation_height, 0);
|
||||
dilation_height_ = dilation_height;
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline int32_t DilationHeight() const { return dilation_height_; }
|
||||
|
||||
inline DepthwiseConv2DTester& DilationWidth(int32_t dilation_width) {
|
||||
EXPECT_GT(dilation_width, 0);
|
||||
dilation_width_ = dilation_width;
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline int32_t DilationWidth() const { return dilation_width_; }
|
||||
|
||||
inline int32_t DilatedKernelHeight() const {
|
||||
return (KernelHeight() - 1) * DilationHeight() + 1;
|
||||
}
|
||||
|
||||
inline int32_t DilatedKernelWidth() const {
|
||||
return (KernelWidth() - 1) * DilationWidth() + 1;
|
||||
}
|
||||
|
||||
inline DepthwiseConv2DTester& SamePadding() {
|
||||
padding_ = ::tflite::Padding_SAME;
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline DepthwiseConv2DTester& ValidPadding() {
|
||||
padding_ = ::tflite::Padding_VALID;
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline DepthwiseConv2DTester& ReluActivation() {
|
||||
activation_ = ::tflite::ActivationFunctionType_RELU;
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline DepthwiseConv2DTester& Relu6Activation() {
|
||||
activation_ = ::tflite::ActivationFunctionType_RELU6;
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline DepthwiseConv2DTester& ReluMinus1To1Activation() {
|
||||
activation_ = ::tflite::ActivationFunctionType_RELU_N1_TO_1;
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline DepthwiseConv2DTester& TanhActivation() {
|
||||
activation_ = ::tflite::ActivationFunctionType_TANH;
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline DepthwiseConv2DTester& SignBitActivation() {
|
||||
activation_ = ::tflite::ActivationFunctionType_SIGN_BIT;
|
||||
return *this;
|
||||
}
|
||||
|
||||
void Test(TfLiteDelegate* delegate) const;
|
||||
|
||||
private:
|
||||
std::vector<char> CreateTfLiteModel() const;
|
||||
|
||||
inline ::tflite::Padding Padding() const { return padding_; }
|
||||
|
||||
inline ::tflite::ActivationFunctionType Activation() const {
|
||||
return activation_;
|
||||
}
|
||||
|
||||
int32_t batch_size_ = 1;
|
||||
int32_t input_channels_ = 1;
|
||||
int32_t depth_multiplier_ = 1;
|
||||
int32_t input_height_ = 1;
|
||||
int32_t input_width_ = 1;
|
||||
int32_t kernel_height_ = 1;
|
||||
int32_t kernel_width_ = 1;
|
||||
int32_t stride_height_ = 1;
|
||||
int32_t stride_width_ = 1;
|
||||
int32_t dilation_height_ = 1;
|
||||
int32_t dilation_width_ = 1;
|
||||
::tflite::Padding padding_ = ::tflite::Padding_VALID;
|
||||
::tflite::ActivationFunctionType activation_ =
|
||||
::tflite::ActivationFunctionType_NONE;
|
||||
};
|
||||
|
||||
} // namespace xnnpack
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_DELEGATES_XNNPACK_DEPTHWISE_CONV_2D_TESTER_H_
|
@ -1307,10 +1307,18 @@ void* SubgraphInit(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
}
|
||||
|
||||
TfLiteStatus SubgraphPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
if (node->user_data == nullptr) {
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
return static_cast<Subgraph*>(node->user_data)->Prepare(context);
|
||||
}
|
||||
|
||||
TfLiteStatus SubgraphInvoke(TfLiteContext* context, TfLiteNode* node) {
|
||||
if (node->user_data == nullptr) {
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
return static_cast<Subgraph*>(node->user_data)->Invoke(context);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user