Open-source XNNPACK delegate

PiperOrigin-RevId: 293253085
Change-Id: I361098e9b9cc064e2514134f6c6cd692e416320b

parent 45d5427d1c
commit 34bec1ebd4

tensorflow/lite/delegates/xnnpack/BUILD (new file, 83 lines)
@@ -0,0 +1,83 @@
load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite_combined")

package(
    default_visibility = ["//visibility:public"],
    licenses = ["notice"],  # Apache 2.0
)

EMSCRIPTEN_LINKOPTS = [
    "-s ASSERTIONS=2",
    "-s ERROR_ON_UNDEFINED_SYMBOLS=1",
    "-s DEMANGLE_SUPPORT=1",
    "-s EXIT_RUNTIME=1",
    "-s ALLOW_MEMORY_GROWTH=1",
    "-s TOTAL_MEMORY=134217728",
]

cc_library(
    name = "xnnpack_delegate",
    srcs = ["xnnpack_delegate.cc"],
    hdrs = ["xnnpack_delegate.h"],
    deps = [
        "//tensorflow/lite:kernel_api",
        "//tensorflow/lite:util",
        "//tensorflow/lite/c:common",
        "//tensorflow/lite/schema:schema_fbs",
        "@XNNPACK",
    ],
)

############################## Integration tests ###############################

cc_library(
    name = "test_main",
    testonly = 1,
    linkopts = select({
        "//tensorflow:emscripten": EMSCRIPTEN_LINKOPTS,
        "//conditions:default": [],
    }),
    deps = [
        "@com_google_googletest//:gtest_main",
    ],
)

cc_test(
    name = "conv_2d_test",
    srcs = ["conv_2d_test.cc"],
    linkopts = select({
        "//tensorflow:emscripten": EMSCRIPTEN_LINKOPTS,
        "//conditions:default": [],
    }),
    deps = [
        ":test_main",
        ":xnnpack_delegate",
        "//tensorflow/lite:framework",
        "//tensorflow/lite:schema_fbs_version",
        "//tensorflow/lite/kernels:builtin_ops",
        "//tensorflow/lite/schema:schema_fbs",
        "@com_google_googletest//:gtest",
        "@flatbuffers",
    ],
)

cc_test(
    name = "depthwise_conv_2d_test",
    srcs = ["depthwise_conv_2d_test.cc"],
    linkopts = select({
        "//tensorflow:emscripten": EMSCRIPTEN_LINKOPTS,
        "//conditions:default": [],
    }),
    tags = ["nomsan"],  # b/145129478
    deps = [
        ":test_main",
        ":xnnpack_delegate",
        "//tensorflow/lite:framework",
        "//tensorflow/lite:schema_fbs_version",
        "//tensorflow/lite/kernels:builtin_ops",
        "//tensorflow/lite/schema:schema_fbs",
        "@com_google_googletest//:gtest",
        "@flatbuffers",
    ],
)

tflite_portable_test_suite_combined(combine_conditions = {"deps": [":test_main"]})

tensorflow/lite/delegates/xnnpack/README.md (new file, 81 lines)
@@ -0,0 +1,81 @@
# XNNPACK backend for TensorFlow Lite

XNNPACK is a highly optimized library of floating-point neural network
inference operators for ARM, WebAssembly, and x86 platforms. This document
describes how to use the XNNPACK library as a backend for TensorFlow Lite.

## Enabling XNNPACK backend in TensorFlow Lite models

XNNPACK integrates with the TensorFlow Lite interpreter through the delegation
mechanism. To leverage the XNNPACK library for acceleration, users need to
create an XNNPACK delegate with the `TfLiteXNNPackDelegateCreate` function and
call `Interpreter::ModifyGraphWithDelegate` to delegate the supported parts of
the model to XNNPACK. The users must destroy the delegate with
`TfLiteXNNPackDelegateDelete` **after** releasing the TensorFlow Lite
interpreter. The snippet below illustrates the typical usage:

```c++
// Build the interpreter
std::unique_ptr<tflite::Interpreter> interpreter;
...

// IMPORTANT: initialize options with TfLiteXNNPackDelegateOptionsDefault() for
// API-compatibility with future extensions of the TfLiteXNNPackDelegateOptions
// structure.
TfLiteXNNPackDelegateOptions xnnpack_options =
    TfLiteXNNPackDelegateOptionsDefault();
xnnpack_options.num_threads = num_threads;

TfLiteDelegate* xnnpack_delegate =
    TfLiteXNNPackDelegateCreate(&xnnpack_options);
if (interpreter->ModifyGraphWithDelegate(xnnpack_delegate) != kTfLiteOk) {
  // Report error and fall back to another delegate, or the default backend
}

...

// Run inference using XNNPACK
interpreter->Invoke();

...

// IMPORTANT: release the interpreter before destroying the delegate
interpreter.reset();
TfLiteXNNPackDelegateDelete(xnnpack_delegate);
```
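
Passing `nullptr` instead of the options to `TfLiteXNNPackDelegateCreate`
selects the default options, in which `num_threads` is 0 and XNNPACK runs
without a thread pool.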

## Limitations and supported operators

The XNNPACK delegate is a work in progress, and currently supports a limited
set of operators. Unsupported operators will fall back to the default
implementations, so models using a combination of supported and unsupported
operators can still benefit from the XNNPACK delegate.

Below is the list of currently supported operators and their limitations:

### `CONV_2D`

* Inputs and outputs must be in 32-bit floating-point format.
* Bias is mandatory.
* Both filter and bias must be static (use `kTfLiteMmapRo` allocation type).
* Fused `NONE`, `RELU`, `RELU_N1_TO_1`, and `RELU6` activations are supported,
  but fused `TANH` and `SIGN_BIT` activations are not.
* Dynamically allocated (with `kTfLiteDynamic` allocation type) input and
  output are not supported.

### `DEPTHWISE_CONV_2D`

* Inputs and outputs must be in 32-bit floating-point format.
* Bias is mandatory.
* Both filter and bias must be static (use `kTfLiteMmapRo` allocation type).
* Fused `NONE`, `RELU`, `RELU_N1_TO_1`, and `RELU6` activations are supported,
  but fused `TANH` and `SIGN_BIT` activations are not.
* Dynamically allocated (with `kTfLiteDynamic` allocation type) input and
  output are not supported.

### Other limitations

* Resizing model inputs (via `Interpreter::ResizeInputTensor`) is supported,
  but causes a complete reinitialization of the delegate instance, which has
  considerable overhead.
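
As an illustration of the resize flow, a hedged sketch (the input index and the
new shape below are hypothetical examples, not part of the delegate API):

```c++
// Resize the first model input, then re-allocate tensors; re-preparing the
// graph is what re-initializes the XNNPACK delegate instance (the costly
// step noted above).
interpreter->ResizeInputTensor(interpreter->inputs()[0], {1, 224, 224, 3});
if (interpreter->AllocateTensors() != kTfLiteOk) {
  // Report error
}
interpreter->Invoke();
```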

tensorflow/lite/delegates/xnnpack/conv_2d_test.cc (new file, 510 lines)
@@ -0,0 +1,510 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <functional>
#include <limits>
#include <memory>
#include <random>
#include <vector>

#include <gtest/gtest.h>
#include "flatbuffers/flatbuffers.h"  // TF:flatbuffers
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/version.h"

namespace tflite {
namespace xnnpack {

namespace {

class Conv2DTester {
 public:
  Conv2DTester() = default;
  Conv2DTester(const Conv2DTester&) = delete;
  Conv2DTester& operator=(const Conv2DTester&) = delete;

  Conv2DTester& BatchSize(int32_t batch_size) {
    EXPECT_GT(batch_size, 0);
    batch_size_ = batch_size;
    return *this;
  }

  int32_t BatchSize() const { return batch_size_; }

  Conv2DTester& InputChannels(int32_t input_channels) {
    EXPECT_GT(input_channels, 0);
    input_channels_ = input_channels;
    return *this;
  }

  int32_t InputChannels() const { return input_channels_; }

  Conv2DTester& OutputChannels(int32_t output_channels) {
    EXPECT_GT(output_channels, 0);
    output_channels_ = output_channels;
    return *this;
  }

  int32_t OutputChannels() const { return output_channels_; }

  Conv2DTester& InputHeight(int32_t input_height) {
    EXPECT_GT(input_height, 0);
    input_height_ = input_height;
    return *this;
  }

  int32_t InputHeight() const { return input_height_; }

  Conv2DTester& InputWidth(int32_t input_width) {
    EXPECT_GT(input_width, 0);
    input_width_ = input_width;
    return *this;
  }

  int32_t InputWidth() const { return input_width_; }

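  // Expected output spatial size: SAME padding gives ceil(input / stride);
  // VALID padding first subtracts the dilated kernel extent,
  // (kernel - 1) * dilation + 1, then divides by the stride.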
  int32_t OutputWidth() const {
    if (SamePadding()) {
      return (InputWidth() - 1) / StrideWidth() + 1;
    } else {
      return (InputWidth() - (KernelWidth() - 1) * DilationWidth() - 1) /
                 StrideWidth() +
             1;
    }
  }

  int32_t OutputHeight() const {
    if (SamePadding()) {
      return (InputHeight() - 1) / StrideHeight() + 1;
    } else {
      return (InputHeight() - (KernelHeight() - 1) * DilationHeight() - 1) /
                 StrideHeight() +
             1;
    }
  }

  Conv2DTester& KernelHeight(int32_t kernel_height) {
    EXPECT_GT(kernel_height, 0);
    kernel_height_ = kernel_height;
    return *this;
  }

  int32_t KernelHeight() const { return kernel_height_; }

  Conv2DTester& KernelWidth(int32_t kernel_width) {
    EXPECT_GT(kernel_width, 0);
    kernel_width_ = kernel_width;
    return *this;
  }

  int32_t KernelWidth() const { return kernel_width_; }

  Conv2DTester& StrideHeight(int32_t stride_height) {
    EXPECT_GT(stride_height, 0);
    stride_height_ = stride_height;
    return *this;
  }

  int32_t StrideHeight() const { return stride_height_; }

  Conv2DTester& StrideWidth(int32_t stride_width) {
    EXPECT_GT(stride_width, 0);
    stride_width_ = stride_width;
    return *this;
  }

  int32_t StrideWidth() const { return stride_width_; }

  Conv2DTester& DilationHeight(int32_t dilation_height) {
    EXPECT_GT(dilation_height, 0);
    dilation_height_ = dilation_height;
    return *this;
  }

  int32_t DilationHeight() const { return dilation_height_; }

  Conv2DTester& DilationWidth(int32_t dilation_width) {
    EXPECT_GT(dilation_width, 0);
    dilation_width_ = dilation_width;
    return *this;
  }

  int32_t DilationWidth() const { return dilation_width_; }

  Conv2DTester& SamePadding(bool same_padding) {
    same_padding_ = same_padding;
    return *this;
  }

  bool SamePadding() const { return same_padding_; }

  void Test(TfLiteDelegate* delegate) const {
    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);

    std::vector<char> buffer = CreateTfLiteModel(std::ref(f32rng));
    const Model* model = GetModel(buffer.data());

    std::unique_ptr<Interpreter> delegate_interpreter;
    ASSERT_EQ(
        InterpreterBuilder(model, ::tflite::ops::builtin::BuiltinOpResolver())(
            &delegate_interpreter),
        kTfLiteOk);
    std::unique_ptr<Interpreter> default_interpreter;
    ASSERT_EQ(
        InterpreterBuilder(model, ::tflite::ops::builtin::BuiltinOpResolver())(
            &default_interpreter),
        kTfLiteOk);

    ASSERT_TRUE(delegate_interpreter);
    ASSERT_TRUE(default_interpreter);

    ASSERT_EQ(delegate_interpreter->inputs().size(), 1);
    ASSERT_EQ(default_interpreter->inputs().size(), 1);

    ASSERT_EQ(delegate_interpreter->outputs().size(), 1);
    ASSERT_EQ(default_interpreter->outputs().size(), 1);

    ASSERT_EQ(delegate_interpreter->AllocateTensors(), kTfLiteOk);
    ASSERT_EQ(default_interpreter->AllocateTensors(), kTfLiteOk);

    ASSERT_EQ(delegate_interpreter->ModifyGraphWithDelegate(delegate),
              kTfLiteOk);

    float* default_input_data = default_interpreter->typed_tensor<float>(
        default_interpreter->inputs()[0]);
    std::generate(default_input_data,
                  default_input_data + BatchSize() * InputHeight() *
                                           InputWidth() * InputChannels(),
                  std::ref(f32rng));

    float* xnnpack_input_data = delegate_interpreter->typed_tensor<float>(
        delegate_interpreter->inputs()[0]);
    std::copy(default_input_data,
              default_input_data +
                  BatchSize() * InputHeight() * InputWidth() * InputChannels(),
              xnnpack_input_data);

    default_interpreter->Invoke();
    delegate_interpreter->Invoke();

    float* default_output_data = default_interpreter->typed_tensor<float>(
        default_interpreter->outputs()[0]);
    float* xnnpack_output_data = delegate_interpreter->typed_tensor<float>(
        delegate_interpreter->outputs()[0]);

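    // Tolerance: a relative bound of 25 float epsilons of the reference
    // value, with an absolute floor of one epsilon for near-zero outputs.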
    for (size_t i = 0;
         i < BatchSize() * OutputHeight() * OutputWidth() * OutputChannels();
         i++) {
      ASSERT_NEAR(default_output_data[i], xnnpack_output_data[i],
                  std::numeric_limits<float>::epsilon() *
                      std::max(std::abs(default_output_data[i]) * 25.0f, 1.0f));
    }
  }

 private:
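  // Serializes a single-operator CONV_2D model to an in-memory FlatBuffer:
  // one subgraph with tensors X (input), W (static filter), b (static bias),
  // and Y (output).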
  std::vector<char> CreateTfLiteModel(std::function<float()> f32rng) const {
    flatbuffers::FlatBufferBuilder builder;
    flatbuffers::Offset<OperatorCode> operator_code =
        CreateOperatorCode(builder, BuiltinOperator_CONV_2D, 0);

    flatbuffers::Offset<Conv2DOptions> conv2d_options = CreateConv2DOptions(
        builder, SamePadding() ? tflite::Padding_SAME : tflite::Padding_VALID,
        StrideWidth(), StrideHeight(), ActivationFunctionType_NONE,
        DilationWidth(), DilationHeight());

    std::vector<float> filter_data(OutputChannels() * KernelHeight() *
                                   KernelWidth() * InputChannels());
    std::vector<float> bias_data(OutputChannels());

    std::generate(filter_data.begin(), filter_data.end(), f32rng);
    std::generate(bias_data.begin(), bias_data.end(), f32rng);

    flatbuffers::Offset<Buffer> buffers[3] = {
        CreateBuffer(builder, builder.CreateVector({})),
        CreateBuffer(builder,
                     builder.CreateVector(
                         reinterpret_cast<const uint8_t*>(filter_data.data()),
                         sizeof(float) * filter_data.size())),
        CreateBuffer(builder,
                     builder.CreateVector(
                         reinterpret_cast<const uint8_t*>(bias_data.data()),
                         sizeof(float) * bias_data.size())),
    };

    const int32_t input_shape[4] = {BatchSize(), InputHeight(), InputWidth(),
                                    InputChannels()};
    const int32_t output_shape[4] = {BatchSize(), OutputHeight(),
                                     OutputWidth(), OutputChannels()};
    const int32_t filter_shape[4] = {OutputChannels(), KernelHeight(),
                                     KernelWidth(), InputChannels()};
    const int32_t bias_shape[1] = {OutputChannels()};

    flatbuffers::Offset<Tensor> tensors[4] = {
        CreateTensor(builder, builder.CreateVector<int32_t>(input_shape, 4),
                     TensorType_FLOAT32, /*buffer=*/0,
                     builder.CreateString("X")),
        CreateTensor(builder, builder.CreateVector<int32_t>(filter_shape, 4),
                     TensorType_FLOAT32, /*buffer=*/1,
                     builder.CreateString("W")),
        CreateTensor(builder, builder.CreateVector<int32_t>(bias_shape, 1),
                     TensorType_FLOAT32, /*buffer=*/2,
                     builder.CreateString("b")),
        CreateTensor(builder, builder.CreateVector<int32_t>(output_shape, 4),
                     TensorType_FLOAT32, /*buffer=*/0,
                     builder.CreateString("Y")),
    };

    const int32_t op_inputs[3] = {0, 1, 2};
    const int32_t op_outputs[1] = {3};

    flatbuffers::Offset<Operator> op =
        CreateOperator(builder, /*opcode_index=*/0,
                       builder.CreateVector<int32_t>(op_inputs, 3),
                       builder.CreateVector<int32_t>(op_outputs, 1),
                       BuiltinOptions_Conv2DOptions, conv2d_options.Union());

    int32_t subgraph_inputs[1] = {0};
    int32_t subgraph_outputs[1] = {3};
    flatbuffers::Offset<SubGraph> subgraph =
        CreateSubGraph(builder, builder.CreateVector(tensors, 4),
                       builder.CreateVector<int32_t>(subgraph_inputs, 1),
                       builder.CreateVector<int32_t>(subgraph_outputs, 1),
                       builder.CreateVector(&op, 1), /*name=*/0);

    flatbuffers::Offset<flatbuffers::String> description =
        builder.CreateString("Conv2D model");

    flatbuffers::Offset<Model> model_buffer = CreateModel(
        builder, TFLITE_SCHEMA_VERSION, builder.CreateVector(&operator_code, 1),
        builder.CreateVector(&subgraph, 1), description,
        builder.CreateVector(buffers, 3));

    builder.Finish(model_buffer);

    return std::vector<char>(builder.GetBufferPointer(),
                             builder.GetBufferPointer() + builder.GetSize());
  }

  int32_t batch_size_ = 1;
  int32_t input_channels_ = 1;
  int32_t output_channels_ = 1;
  int32_t input_height_ = 1;
  int32_t input_width_ = 1;
  int32_t kernel_height_ = 1;
  int32_t kernel_width_ = 1;
  int32_t stride_height_ = 1;
  int32_t stride_width_ = 1;
  int32_t dilation_height_ = 1;
  int32_t dilation_width_ = 1;
  bool same_padding_ = true;
};

}  // namespace

TEST(Conv2D, Pointwise) {
  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
      xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
                       TfLiteXNNPackDelegateDelete);

  std::random_device random_device;
  auto rng = std::mt19937(random_device());
  auto input_rng =
      std::bind(std::uniform_int_distribution<int32_t>(5, 25), std::ref(rng));
  auto channel_rng =
      std::bind(std::uniform_int_distribution<int32_t>(1, 16), std::ref(rng));

  Conv2DTester()
      .InputHeight(input_rng())
      .InputWidth(input_rng())
      .InputChannels(channel_rng())
      .OutputChannels(channel_rng())
      .KernelHeight(1)
      .KernelWidth(1)
      .Test(xnnpack_delegate.get());
}

TEST(Conv2D, SmallKernelWithSamePadding) {
  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
      xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
                       TfLiteXNNPackDelegateDelete);

  std::random_device random_device;
  auto rng = std::mt19937(random_device());
  auto input_rng =
      std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
  auto kernel_rng =
      std::bind(std::uniform_int_distribution<int32_t>(2, 7), std::ref(rng));
  auto channel_rng =
      std::bind(std::uniform_int_distribution<int32_t>(1, 16), std::ref(rng));

  Conv2DTester()
      .InputHeight(input_rng())
      .InputWidth(input_rng())
      .InputChannels(channel_rng())
      .OutputChannels(channel_rng())
      .KernelHeight(kernel_rng())
      .KernelWidth(kernel_rng())
      .SamePadding(true)
      .Test(xnnpack_delegate.get());
}

TEST(Conv2D, SmallKernelWithValidPadding) {
  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
      xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
                       TfLiteXNNPackDelegateDelete);

  std::random_device random_device;
  auto rng = std::mt19937(random_device());
  auto input_rng =
      std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
  auto kernel_rng =
      std::bind(std::uniform_int_distribution<int32_t>(2, 7), std::ref(rng));
  auto channel_rng =
      std::bind(std::uniform_int_distribution<int32_t>(1, 16), std::ref(rng));

  Conv2DTester()
      .InputHeight(input_rng())
      .InputWidth(input_rng())
      .InputChannels(channel_rng())
      .OutputChannels(channel_rng())
      .KernelHeight(kernel_rng())
      .KernelWidth(kernel_rng())
      .SamePadding(false)
      .Test(xnnpack_delegate.get());
}

TEST(Conv2D, StrideWithSamePadding) {
  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
      xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
                       TfLiteXNNPackDelegateDelete);

  std::random_device random_device;
  auto rng = std::mt19937(random_device());
  auto input_rng =
      std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
  auto kernel_rng =
      std::bind(std::uniform_int_distribution<int32_t>(3, 5), std::ref(rng));
  auto stride_rng =
      std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
  auto channel_rng =
      std::bind(std::uniform_int_distribution<int32_t>(1, 16), std::ref(rng));

  Conv2DTester()
      .InputHeight(input_rng())
      .InputWidth(input_rng())
      .InputChannels(channel_rng())
      .OutputChannels(channel_rng())
      .KernelHeight(kernel_rng())
      .KernelWidth(kernel_rng())
      .StrideHeight(stride_rng())
      .StrideWidth(stride_rng())
      .SamePadding(true)
      .Test(xnnpack_delegate.get());
}

TEST(Conv2D, StrideWithValidPadding) {
  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
      xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
                       TfLiteXNNPackDelegateDelete);

  std::random_device random_device;
  auto rng = std::mt19937(random_device());
  auto input_rng =
      std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
  auto kernel_rng =
      std::bind(std::uniform_int_distribution<int32_t>(3, 5), std::ref(rng));
  auto stride_rng =
      std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
  auto channel_rng =
      std::bind(std::uniform_int_distribution<int32_t>(1, 16), std::ref(rng));

  Conv2DTester()
      .InputHeight(input_rng())
      .InputWidth(input_rng())
      .InputChannels(channel_rng())
      .OutputChannels(channel_rng())
      .KernelHeight(kernel_rng())
      .KernelWidth(kernel_rng())
      .StrideHeight(stride_rng())
      .StrideWidth(stride_rng())
      .SamePadding(false)
      .Test(xnnpack_delegate.get());
}

TEST(Conv2D, DilationWithSamePadding) {
  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
      xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
                       TfLiteXNNPackDelegateDelete);

  std::random_device random_device;
  auto rng = std::mt19937(random_device());
  auto input_rng =
      std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
  auto kernel_rng =
      std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
  auto dilation_rng =
      std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
  auto channel_rng =
      std::bind(std::uniform_int_distribution<int32_t>(1, 16), std::ref(rng));

  Conv2DTester()
      .InputHeight(input_rng())
      .InputWidth(input_rng())
      .InputChannels(channel_rng())
      .OutputChannels(channel_rng())
      .KernelHeight(kernel_rng())
      .KernelWidth(kernel_rng())
      .DilationHeight(dilation_rng())
      .DilationWidth(dilation_rng())
      .SamePadding(true)
      .Test(xnnpack_delegate.get());
}

TEST(Conv2D, DilationWithValidPadding) {
  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
      xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
                       TfLiteXNNPackDelegateDelete);

  std::random_device random_device;
  auto rng = std::mt19937(random_device());
  auto input_rng =
      std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
  auto kernel_rng =
      std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
  auto dilation_rng =
      std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
  auto channel_rng =
      std::bind(std::uniform_int_distribution<int32_t>(1, 16), std::ref(rng));

  Conv2DTester()
      .InputHeight(input_rng())
      .InputWidth(input_rng())
      .InputChannels(channel_rng())
      .OutputChannels(channel_rng())
      .KernelHeight(kernel_rng())
      .KernelWidth(kernel_rng())
      .DilationHeight(dilation_rng())
      .DilationWidth(dilation_rng())
      .SamePadding(false)
      .Test(xnnpack_delegate.get());
}

}  // namespace xnnpack
}  // namespace tflite

tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_test.cc (new file, 433 lines)
@@ -0,0 +1,433 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <functional>
#include <limits>
#include <memory>
#include <random>
#include <vector>

#include <gtest/gtest.h>
#include "flatbuffers/flatbuffers.h"  // TF:flatbuffers
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/version.h"

namespace tflite {
namespace xnnpack {

namespace {

class DepthwiseConv2DTester {
 public:
  DepthwiseConv2DTester() = default;
  DepthwiseConv2DTester(const DepthwiseConv2DTester&) = delete;
  DepthwiseConv2DTester& operator=(const DepthwiseConv2DTester&) = delete;

  DepthwiseConv2DTester& BatchSize(int32_t batch_size) {
    EXPECT_GT(batch_size, 0);
    batch_size_ = batch_size;
    return *this;
  }

  int32_t BatchSize() const { return batch_size_; }

  DepthwiseConv2DTester& Groups(int32_t groups) {
    EXPECT_GT(groups, 0);
    groups_ = groups;
    return *this;
  }

  int32_t Groups() const { return groups_; }

  DepthwiseConv2DTester& DepthMultiplier(int32_t depth_multiplier) {
    EXPECT_GT(depth_multiplier, 0);
    depth_multiplier_ = depth_multiplier;
    return *this;
  }

  int32_t DepthMultiplier() const { return depth_multiplier_; }

  int32_t InputChannels() const { return Groups(); }

  int32_t OutputChannels() const { return DepthMultiplier() * Groups(); }

  DepthwiseConv2DTester& InputHeight(int32_t input_height) {
    EXPECT_GT(input_height, 0);
    input_height_ = input_height;
    return *this;
  }

  int32_t InputHeight() const { return input_height_; }

  DepthwiseConv2DTester& InputWidth(int32_t input_width) {
    EXPECT_GT(input_width, 0);
    input_width_ = input_width;
    return *this;
  }

  int32_t InputWidth() const { return input_width_; }

  // The model always uses SAME padding, so the output spatial size is
  // ceil(input / stride).
  int32_t OutputWidth() const {
    const int32_t output_width = (InputWidth() - 1) / StrideWidth() + 1;
    EXPECT_GT(output_width, 0);
    return output_width;
  }

  int32_t OutputHeight() const {
    const int32_t output_height = (InputHeight() - 1) / StrideHeight() + 1;
    EXPECT_GT(output_height, 0);
    return output_height;
  }

  DepthwiseConv2DTester& KernelHeight(int32_t kernel_height) {
    EXPECT_GT(kernel_height, 0);
    kernel_height_ = kernel_height;
    return *this;
  }

  int32_t KernelHeight() const { return kernel_height_; }

  DepthwiseConv2DTester& KernelWidth(int32_t kernel_width) {
    EXPECT_GT(kernel_width, 0);
    kernel_width_ = kernel_width;
    return *this;
  }

  int32_t KernelWidth() const { return kernel_width_; }

  DepthwiseConv2DTester& StrideHeight(int32_t stride_height) {
    EXPECT_GT(stride_height, 0);
    stride_height_ = stride_height;
    return *this;
  }

  int32_t StrideHeight() const { return stride_height_; }

  DepthwiseConv2DTester& StrideWidth(int32_t stride_width) {
    EXPECT_GT(stride_width, 0);
    stride_width_ = stride_width;
    return *this;
  }

  int32_t StrideWidth() const { return stride_width_; }

  DepthwiseConv2DTester& DilationHeight(int32_t dilation_height) {
    EXPECT_GT(dilation_height, 0);
    dilation_height_ = dilation_height;
    return *this;
  }

  int32_t DilationHeight() const { return dilation_height_; }

  DepthwiseConv2DTester& DilationWidth(int32_t dilation_width) {
    EXPECT_GT(dilation_width, 0);
    dilation_width_ = dilation_width;
    return *this;
  }

  int32_t DilationWidth() const { return dilation_width_; }

  void Test(TfLiteDelegate* delegate) const {
    ASSERT_EQ(DepthMultiplier(), 1) << "depth multiplier is not supported";

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);

    std::vector<char> buffer = CreateTfLiteModel(std::ref(f32rng));
    const Model* model = GetModel(buffer.data());

    std::unique_ptr<Interpreter> delegate_interpreter;
    ASSERT_EQ(
        InterpreterBuilder(model, ::tflite::ops::builtin::BuiltinOpResolver())(
            &delegate_interpreter),
        kTfLiteOk);
    std::unique_ptr<Interpreter> default_interpreter;
    ASSERT_EQ(
        InterpreterBuilder(model, ::tflite::ops::builtin::BuiltinOpResolver())(
            &default_interpreter),
        kTfLiteOk);

    ASSERT_TRUE(delegate_interpreter);
    ASSERT_TRUE(default_interpreter);

    ASSERT_EQ(delegate_interpreter->inputs().size(), 1);
    ASSERT_EQ(default_interpreter->inputs().size(), 1);

    ASSERT_EQ(delegate_interpreter->outputs().size(), 1);
    ASSERT_EQ(default_interpreter->outputs().size(), 1);

    ASSERT_EQ(delegate_interpreter->AllocateTensors(), kTfLiteOk);
    ASSERT_EQ(default_interpreter->AllocateTensors(), kTfLiteOk);

    ASSERT_EQ(delegate_interpreter->ModifyGraphWithDelegate(delegate),
              kTfLiteOk);

    float* default_input_data = default_interpreter->typed_tensor<float>(
        default_interpreter->inputs()[0]);
    std::generate(default_input_data,
                  default_input_data + BatchSize() * InputChannels() *
                                           InputHeight() * InputWidth(),
                  std::ref(f32rng));

    float* xnnpack_input_data = delegate_interpreter->typed_tensor<float>(
        delegate_interpreter->inputs()[0]);
    std::copy(default_input_data,
              default_input_data +
                  BatchSize() * InputChannels() * InputHeight() * InputWidth(),
              xnnpack_input_data);

    default_interpreter->Invoke();
    delegate_interpreter->Invoke();

    float* default_output_data = default_interpreter->typed_tensor<float>(
        default_interpreter->outputs()[0]);
    float* xnnpack_output_data = delegate_interpreter->typed_tensor<float>(
        delegate_interpreter->outputs()[0]);

    for (size_t i = 0;
         i < BatchSize() * OutputChannels() * OutputHeight() * OutputWidth();
         i++) {
      ASSERT_NEAR(default_output_data[i], xnnpack_output_data[i],
                  std::numeric_limits<float>::epsilon() *
                      std::max(std::abs(default_output_data[i]) * 10.0f, 1.0f));
    }
  }

 private:
  std::vector<char> CreateTfLiteModel(std::function<float()> f32rng) const {
    flatbuffers::FlatBufferBuilder builder;
    flatbuffers::Offset<OperatorCode> operator_code =
        CreateOperatorCode(builder, BuiltinOperator_DEPTHWISE_CONV_2D, 0);

    flatbuffers::Offset<DepthwiseConv2DOptions> depthwise_conv2d_options =
        CreateDepthwiseConv2DOptions(builder, Padding_SAME, StrideWidth(),
                                     StrideHeight(), DepthMultiplier(),
                                     ActivationFunctionType_NONE,
                                     DilationWidth(), DilationHeight());

    std::vector<float> filter_data(KernelHeight() * KernelWidth() *
                                   OutputChannels());
    std::vector<float> bias_data(OutputChannels());

    std::generate(filter_data.begin(), filter_data.end(), f32rng);
    std::generate(bias_data.begin(), bias_data.end(), f32rng);

    flatbuffers::Offset<Buffer> buffers[3] = {
        CreateBuffer(builder, builder.CreateVector({})),
        CreateBuffer(builder,
                     builder.CreateVector(
                         reinterpret_cast<const uint8_t*>(filter_data.data()),
                         sizeof(float) * filter_data.size())),
        CreateBuffer(builder,
                     builder.CreateVector(
                         reinterpret_cast<const uint8_t*>(bias_data.data()),
                         sizeof(float) * bias_data.size())),
    };

    const int32_t input_shape[4] = {BatchSize(), InputHeight(), InputWidth(),
                                    InputChannels()};
    const int32_t output_shape[4] = {BatchSize(), OutputHeight(),
                                     OutputWidth(), OutputChannels()};
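    // Depthwise filter layout is [1, kernel_height, kernel_width,
    // output_channels], unlike the [output_channels, height, width,
    // input_channels] layout that CONV_2D uses.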
    const int32_t filter_shape[4] = {1, KernelHeight(), KernelWidth(),
                                     OutputChannels()};
    const int32_t bias_shape[1] = {OutputChannels()};

    flatbuffers::Offset<Tensor> tensors[4] = {
        CreateTensor(builder, builder.CreateVector<int32_t>(input_shape, 4),
                     TensorType_FLOAT32, /*buffer=*/0,
                     builder.CreateString("X")),
        CreateTensor(builder, builder.CreateVector<int32_t>(filter_shape, 4),
                     TensorType_FLOAT32, /*buffer=*/1,
                     builder.CreateString("W")),
        CreateTensor(builder, builder.CreateVector<int32_t>(bias_shape, 1),
                     TensorType_FLOAT32, /*buffer=*/2,
                     builder.CreateString("b")),
        CreateTensor(builder, builder.CreateVector<int32_t>(output_shape, 4),
                     TensorType_FLOAT32, /*buffer=*/0,
                     builder.CreateString("Y")),
    };

    const int32_t op_inputs[3] = {0, 1, 2};
    const int32_t op_outputs[1] = {3};

    flatbuffers::Offset<Operator> op = CreateOperator(
        builder, /*opcode_index=*/0,
        builder.CreateVector<int32_t>(op_inputs, 3),
        builder.CreateVector<int32_t>(op_outputs, 1),
        BuiltinOptions_DepthwiseConv2DOptions, depthwise_conv2d_options.Union(),
        /*custom_options=*/0, CustomOptionsFormat_FLEXBUFFERS);

    int32_t subgraph_inputs[1] = {0};
    int32_t subgraph_outputs[1] = {3};
    flatbuffers::Offset<SubGraph> subgraph =
        CreateSubGraph(builder, builder.CreateVector(tensors, 4),
                       builder.CreateVector<int32_t>(subgraph_inputs, 1),
                       builder.CreateVector<int32_t>(subgraph_outputs, 1),
                       builder.CreateVector(&op, 1), /*name=*/0);

    flatbuffers::Offset<flatbuffers::String> description =
        builder.CreateString("DepthwiseConv2D model");

    flatbuffers::Offset<Model> model_buffer = CreateModel(
        builder, TFLITE_SCHEMA_VERSION, builder.CreateVector(&operator_code, 1),
        builder.CreateVector(&subgraph, 1), description,
        builder.CreateVector(buffers, 3));

    builder.Finish(model_buffer);

    return std::vector<char>(builder.GetBufferPointer(),
                             builder.GetBufferPointer() + builder.GetSize());
  }

  int32_t batch_size_ = 1;
  int32_t groups_ = 1;
  int32_t depth_multiplier_ = 1;
  int32_t input_height_ = 1;
  int32_t input_width_ = 1;
  int32_t kernel_height_ = 1;
  int32_t kernel_width_ = 1;
  int32_t stride_height_ = 1;
  int32_t stride_width_ = 1;
  int32_t dilation_height_ = 1;
  int32_t dilation_width_ = 1;
};

}  // namespace

TEST(DepthwiseConv2D, 2x2) {
  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
      xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
                       TfLiteXNNPackDelegateDelete);

  std::random_device random_device;
  auto rng = std::mt19937(random_device());
  auto input_rng =
      std::bind(std::uniform_int_distribution<int32_t>(5, 25), std::ref(rng));
  auto groups_rng =
      std::bind(std::uniform_int_distribution<int32_t>(2, 32), std::ref(rng));

  DepthwiseConv2DTester()
      .InputHeight(input_rng())
      .InputWidth(input_rng())
      .Groups(groups_rng())
      .KernelHeight(2)
      .KernelWidth(2)
      .Test(xnnpack_delegate.get());
}

TEST(DepthwiseConv2D, 3x3) {
  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
      xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
                       TfLiteXNNPackDelegateDelete);

  std::random_device random_device;
  auto rng = std::mt19937(random_device());
  auto input_rng =
      std::bind(std::uniform_int_distribution<int32_t>(5, 25), std::ref(rng));
  auto groups_rng =
      std::bind(std::uniform_int_distribution<int32_t>(2, 32), std::ref(rng));

  DepthwiseConv2DTester()
      .InputHeight(input_rng())
      .InputWidth(input_rng())
      .Groups(groups_rng())
      .KernelHeight(3)
      .KernelWidth(3)
      .Test(xnnpack_delegate.get());
}

TEST(DepthwiseConv2D, SmallKernel) {
  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
      xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
                       TfLiteXNNPackDelegateDelete);

  std::random_device random_device;
  auto rng = std::mt19937(random_device());
  auto input_rng =
      std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
  auto kernel_rng =
      std::bind(std::uniform_int_distribution<int32_t>(2, 7), std::ref(rng));
  auto groups_rng =
      std::bind(std::uniform_int_distribution<int32_t>(2, 32), std::ref(rng));

  DepthwiseConv2DTester()
      .InputHeight(input_rng())
      .InputWidth(input_rng())
      .Groups(groups_rng())
      .KernelHeight(kernel_rng())
      .KernelWidth(kernel_rng())
      .Test(xnnpack_delegate.get());
}

TEST(DepthwiseConv2D, Stride) {
  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
      xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
                       TfLiteXNNPackDelegateDelete);

  std::random_device random_device;
  auto rng = std::mt19937(random_device());
  auto input_rng =
      std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
  auto kernel_rng =
      std::bind(std::uniform_int_distribution<int32_t>(3, 5), std::ref(rng));
  auto stride_rng =
      std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
  auto groups_rng =
      std::bind(std::uniform_int_distribution<int32_t>(2, 32), std::ref(rng));

  DepthwiseConv2DTester()
      .InputHeight(input_rng())
      .InputWidth(input_rng())
      .Groups(groups_rng())
      .KernelHeight(kernel_rng())
      .KernelWidth(kernel_rng())
      .StrideHeight(stride_rng())
      .StrideWidth(stride_rng())
      .Test(xnnpack_delegate.get());
}

TEST(DepthwiseConv2D, Dilation) {
  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
      xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
                       TfLiteXNNPackDelegateDelete);

  std::random_device random_device;
  auto rng = std::mt19937(random_device());
  auto input_rng =
      std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
  auto kernel_rng =
      std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
  auto dilation_rng =
      std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
  auto group_rng =
      std::bind(std::uniform_int_distribution<int32_t>(3, 32), std::ref(rng));

  DepthwiseConv2DTester()
      .InputHeight(input_rng())
      .InputWidth(input_rng())
      .Groups(group_rng())
      .KernelHeight(kernel_rng())
      .KernelWidth(kernel_rng())
      .DilationHeight(dilation_rng())
      .DilationWidth(dilation_rng())
      .Test(xnnpack_delegate.get());
}

}  // namespace xnnpack
}  // namespace tflite

tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc (new file, 797 lines)
@@ -0,0 +1,797 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"

#include <algorithm>
#include <cstdint>
#include <limits>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include <xnnpack.h>
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"

namespace tflite {
namespace xnnpack {
namespace {

// Forward declaration.
TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate);

class Delegate {
 public:
  explicit Delegate(const TfLiteXNNPackDelegateOptions* options) {
    if (options) {
      options_ = *options;
    } else {
      // Default: don't use a thread pool.
      options_.num_threads = 0;
    }
  }

  TfLiteDelegate* tflite_delegate() { return &delegate_; }

 private:
  TfLiteDelegate delegate_ = {
      reinterpret_cast<void*>(this),  // .data_
      DelegatePrepare,                // .Prepare
      nullptr,                        // .CopyFromBufferHandle
      nullptr,                        // .CopyToBufferHandle
      nullptr,                        // .FreeBufferHandle
      kTfLiteDelegateFlagsNone,       // .flags
  };

  TfLiteXNNPackDelegateOptions options_;
};

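// Wraps the XNNPACK subgraph/runtime built for one delegated partition of the
// TFLite graph: Create() translates the claimed TFLite nodes into XNNPACK
// Values and Nodes, and Invoke() runs the resulting runtime.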
class Subgraph {
 public:
  static Subgraph* Create(TfLiteContext* context,
                          const TfLiteDelegateParams* params) {
    // Convert subgraph inputs and outputs to hash sets for faster lookup.
    const std::unordered_set<int> inputs(
        &params->input_tensors->data[0],
        &params->input_tensors->data[params->input_tensors->size]);
    const std::unordered_set<int> outputs(
        &params->output_tensors->data[0],
        &params->output_tensors->data[params->output_tensors->size]);
    std::unordered_set<int> externals(outputs);

    TfLiteIntArray* execution_plan;
    if (context->GetExecutionPlan(context, &execution_plan) != kTfLiteOk) {
      return nullptr;
    }

    xnn_subgraph_t subgraph_ptr = nullptr;
    xnn_status status = xnn_create_subgraph(
        /*external_value_ids=*/context->tensors_size, /*flags=*/0,
        &subgraph_ptr);
    if (status != xnn_status_success) {
      context->ReportError(context, "failed to create XNNPACK subgraph");
      return nullptr;
    }

    // Smart pointer to automatically release the subgraph on exit.
    std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> subgraph(
        subgraph_ptr, &xnn_delete_subgraph);

    // Detect which tensors are used as inputs or outputs of any subgraph
    // nodes. -1 denotes a tensor not used in the subgraph. These indexes will
    // be filtered out and removed later.
    std::vector<int> tensors(context->tensors_size, -1);
    for (int i = 0; i < params->nodes_to_replace->size; i++) {
      TfLiteNode* node = nullptr;
      TfLiteRegistration* registration = nullptr;
      if (context->GetNodeAndRegistration(context,
                                          params->nodes_to_replace->data[i],
                                          &node, &registration) != kTfLiteOk) {
        return nullptr;
      }

      for (int k = 0; k < node->inputs->size; k++) {
        const int t = node->inputs->data[k];
        tensors[t] = t;
      }
      for (int k = 0; k < node->outputs->size; k++) {
        const int t = node->outputs->data[k];
        tensors[t] = t;
      }
    }
    // Filter out and remove -1 (unused) indexes.
    tensors.erase(std::remove_if(tensors.begin(), tensors.end(),
                                 [](int i) { return i < 0; }),
                  tensors.end());
    std::sort(tensors.begin(), tensors.end());

    // XNNPACK Value IDs for TFLite tensors.
    std::vector<uint32_t> xnnpack_tensors(tensors.back() + 1);
    for (int t : tensors) {
      if (context->tensors[t].type != kTfLiteFloat32) {
        context->ReportError(
            context,
            "unsupported datatype (%s) of tensor %d in XNNPACK delegate",
            TfLiteTypeGetName(context->tensors[t].type), t);
        return nullptr;
      }

      uint32_t flags = 0;
      const void* data = nullptr;
      if (context->tensors[t].allocation_type == kTfLiteMmapRo) {
        data = context->tensors[t].data.raw_const;
      }
      if (inputs.count(t) != 0) {
        flags |= XNN_VALUE_FLAG_EXTERNAL_INPUT;
        if (data == nullptr) {
          externals.insert(t);
        }
      }
      if (outputs.count(t) != 0) {
        flags |= XNN_VALUE_FLAG_EXTERNAL_OUTPUT;
      }

      std::vector<size_t> dims(
          &context->tensors[t].dims->data[0],
          &context->tensors[t].dims->data[context->tensors[t].dims->size]);

      const xnn_status status = xnn_define_tensor_value(
          subgraph.get(), xnn_datatype_fp32, dims.size(), dims.data(), data,
          static_cast<uint32_t>(t), flags, &xnnpack_tensors[t]);
      if (status != xnn_status_success) {
        context->ReportError(
            context, "failed to create XNNPACK Value for tensor %d", t);
        return nullptr;
      }
    }

    // Create XNNPACK Nodes for TFLite delegate nodes.
    for (int i = 0; i < params->nodes_to_replace->size; i++) {
      TfLiteNode* node = nullptr;
      TfLiteRegistration* registration = nullptr;
      if (context->GetNodeAndRegistration(context,
                                          params->nodes_to_replace->data[i],
                                          &node, &registration) != kTfLiteOk) {
        return nullptr;
      }

      if (VisitNode(subgraph.get(), context, registration, node, i,
                    xnnpack_tensors) != kTfLiteOk) {
        return nullptr;
      }
    }

    xnn_runtime_t runtime_ptr = nullptr;
    status = xnn_create_runtime(subgraph.get(), &runtime_ptr);
    if (status != xnn_status_success) {
      context->ReportError(context, "failed to create XNNPACK runtime");
      return nullptr;
    }

    return new Subgraph(runtime_ptr, std::move(externals));
  }

  TfLiteStatus Prepare(TfLiteContext* context) { return kTfLiteOk; }

  TfLiteStatus Invoke(TfLiteContext* context) {
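    // Runtime setup binds the external (input/output) tensor pointers and is
    // done lazily on the first invocation.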
if (first_run_) {
|
||||
std::vector<xnn_external_value> external_values;
|
||||
for (int t : externals_) {
|
||||
xnn_external_value value = {0};
|
||||
value.id = static_cast<uint32_t>(t);
|
||||
value.data = context->tensors[t].data.raw;
|
||||
external_values.push_back(value);
|
||||
}
|
||||
|
||||
const xnn_status status = xnn_setup_runtime(
|
||||
runtime_.get(), external_values.size(), external_values.data());
|
||||
if (status != xnn_status_success) {
|
||||
context->ReportError(context, "failed to setup XNNPACK runtime");
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
first_run_ = false;
|
||||
}
|
||||
|
||||
const xnn_status status = xnn_invoke_runtime(runtime_.get());
|
||||
if (status != xnn_status_success) {
|
||||
context->ReportError(context, "failed to invoke XNNPACK runtime");
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
static TfLiteStatus CalculatePadding(TfLiteContext* context,
|
||||
TfLitePadding padding, uint32_t* flags,
|
||||
int node_index) {
|
||||
switch (padding) {
|
||||
case kTfLitePaddingSame: {
|
||||
*flags = XNN_FLAG_TENSORFLOW_SAME_PADDING;
|
||||
return kTfLiteOk;
|
||||
}
|
||||
case kTfLitePaddingValid:
|
||||
*flags = 0;
|
||||
return kTfLiteOk;
|
||||
default:
|
||||
if (context) {
|
||||
context->ReportError(context, "invalid padding mode (%d) in node #%d",
|
||||
static_cast<int>(padding), node_index);
|
||||
}
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
|
||||
static TfLiteStatus ConvertActivationToOutputRange(
|
||||
TfLiteContext* context, int node_index, TfLiteFusedActivation activation,
|
||||
float* output_min, float* output_max) {
|
||||
switch (activation) {
|
||||
case kTfLiteActNone:
|
||||
*output_min = -std::numeric_limits<float>::infinity();
|
||||
*output_max = +std::numeric_limits<float>::infinity();
|
||||
return kTfLiteOk;
|
||||
case kTfLiteActRelu:
|
||||
*output_min = 0.0f;
|
||||
*output_max = +std::numeric_limits<float>::infinity();
|
||||
return kTfLiteOk;
|
||||
case kTfLiteActRelu1:
|
||||
*output_min = -1.0f;
|
||||
*output_max = +1.0f;
|
||||
return kTfLiteOk;
|
||||
case kTfLiteActRelu6:
|
||||
*output_min = 0.0f;
|
||||
*output_max = 6.0f;
|
||||
return kTfLiteOk;
|
||||
case kTfLiteActTanh:
|
||||
if (context) {
|
||||
context->ReportError(
|
||||
context, "unsupported fused activation (Tanh) in node #%d",
|
||||
node_index);
|
||||
}
|
||||
return kTfLiteError;
|
||||
case kTfLiteActSignBit:
|
||||
if (context) {
|
||||
context->ReportError(
|
||||
context, "unsupported fused activation (Sign) in node #%d",
|
||||
node_index);
|
||||
}
|
||||
return kTfLiteError;
|
||||
case kTfLiteActSigmoid:
|
||||
if (context) {
|
||||
context->ReportError(
|
||||
context, "unsupported fused activation (Sigmoid) in node #%d",
|
||||
node_index);
|
||||
}
|
||||
return kTfLiteError;
|
||||
default:
|
||||
if (context) {
|
||||
context->ReportError(context,
|
||||
"invalid fused activation (%d) in node #%d",
|
||||
static_cast<int>(activation), node_index);
|
||||
}
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
|
||||
static TfLiteStatus CheckConvolutionParams(TfLiteContext* context,
|
||||
const TfLiteConvParams* params,
|
||||
int node_index) {
|
||||
if (params->stride_width <= 0) {
|
||||
if (context) {
|
||||
context->ReportError(context, "invalid stride width %d in node #%d",
|
||||
params->stride_width, node_index);
|
||||
}
|
||||
return kTfLiteError;
|
||||
}
|
||||
if (params->stride_height <= 0) {
|
||||
if (context) {
|
||||
context->ReportError(context, "invalid stride height %d in node #%d",
|
||||
params->stride_height, node_index);
|
||||
}
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
if (params->dilation_width_factor <= 0) {
|
||||
if (context) {
|
||||
context->ReportError(context,
|
||||
"invalid dilation width factor %d in node #%d",
|
||||
params->dilation_width_factor, node_index);
|
||||
}
|
||||
return kTfLiteError;
|
||||
}
|
||||
if (params->dilation_height_factor <= 0) {
|
||||
if (context) {
|
||||
context->ReportError(context,
|
||||
"invalid dilation height factor %d in node #%d",
|
||||
params->dilation_height_factor, node_index);
|
||||
}
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
static TfLiteStatus CheckDepthwiseConvolutionParams(
|
||||
TfLiteContext* context, const TfLiteDepthwiseConvParams* params,
|
||||
int output_channels, int node_index) {
|
||||
if (params->stride_width <= 0) {
|
||||
if (context) {
|
||||
context->ReportError(context, "invalid stride width %d in node #%d",
|
||||
params->stride_width, node_index);
|
||||
}
|
||||
return kTfLiteError;
|
||||
}
|
||||
if (params->stride_height <= 0) {
|
||||
if (context) {
|
||||
context->ReportError(context, "invalid stride height %d in node #%d",
|
||||
params->stride_height, node_index);
|
||||
}
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
if (params->depth_multiplier <= 0) {
|
||||
if (context) {
|
||||
context->ReportError(context, "invalid depth multiplier %d in node #%d",
|
||||
params->depth_multiplier, node_index);
|
||||
}
|
||||
return kTfLiteError;
|
||||
}
|
||||
if (output_channels % params->depth_multiplier != 0) {
|
||||
if (context) {
|
||||
context->ReportError(context,
|
||||
"depth multiplier %d is incompatible with "
|
||||
"number of output channels %d in node #%d",
|
||||
params->depth_multiplier, output_channels,
|
||||
node_index);
|
||||
}
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
if (params->dilation_width_factor <= 0) {
|
||||
if (context) {
|
||||
context->ReportError(context,
|
||||
"invalid dilation width factor %d in node #%d",
|
||||
params->dilation_width_factor, node_index);
|
||||
}
|
||||
return kTfLiteError;
|
||||
}
|
||||
if (params->dilation_height_factor <= 0) {
|
||||
if (context) {
|
||||
context->ReportError(context,
|
||||
"invalid dilation height factor %d in node #%d",
|
||||
params->dilation_height_factor, node_index);
|
||||
}
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
static TfLiteStatus CheckNumInputsAndOutputs(TfLiteContext* context,
|
||||
TfLiteNode* node,
|
||||
int expected_num_inputs,
|
||||
int expected_num_outputs,
|
||||
int node_index) {
|
||||
if (node->inputs->size != expected_num_inputs) {
|
||||
if (context) {
|
||||
context->ReportError(
|
||||
context, "unexpected number of inputs (%d != %d) in node #%d",
|
||||
node->inputs->size, expected_num_inputs, node_index);
|
||||
}
|
||||
return kTfLiteError;
|
||||
}
|
||||
if (node->outputs->size != expected_num_outputs) {
|
||||
if (context) {
|
||||
context->ReportError(
|
||||
context, "unexpected number of output (%d != %d) in node #%d",
|
||||
node->outputs->size, expected_num_outputs, node_index);
|
||||
}
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
static TfLiteStatus CheckTensorFloatType(TfLiteContext* context,
|
||||
const TfLiteTensor& tensor,
|
||||
int tensor_index, int node_index) {
|
||||
if (tensor.type != kTfLiteFloat32) {
|
||||
if (context) {
|
||||
context->ReportError(
|
||||
context, "unsupported type %s in tensor #%d in node #%d",
|
||||
TfLiteTypeGetName(tensor.type), tensor_index, node_index);
|
||||
}
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
  static TfLiteStatus CheckTensorShape(TfLiteContext* context,
                                       const TfLiteTensor& tensor,
                                       int expected_num_dims,
                                       int tensor_index) {
    if (tensor.dims->size != expected_num_dims) {
      if (context) {
        context->ReportError(
            context,
            "unexpected number of shape dimensions (%d != %d) in tensor #%d",
            tensor.dims->size, expected_num_dims, tensor_index);
      }
      return kTfLiteError;
    }
    for (int i = 0; i < tensor.dims->size; i++) {
      if (tensor.dims->data[i] <= 0) {
        // Guard the error report: context is null during the detection pass.
        if (context) {
          context->ReportError(context,
                               "invalid dimension #%d (%d) in tensor #%d", i,
                               tensor.dims->data[i], tensor_index);
        }
        return kTfLiteError;
      }
    }
    return kTfLiteOk;
  }

  static TfLiteStatus CheckTensorStaticAllocation(TfLiteContext* context,
                                                  const TfLiteTensor& tensor,
                                                  int tensor_index,
                                                  int node_index) {
    if (tensor.allocation_type != kTfLiteMmapRo ||
        tensor.data.raw_const == nullptr) {
      if (context) {
        context->ReportError(
            context,
            "invalid allocation type in tensor #%d in node #%d: "
            "expected static read-only tensor",
            tensor_index, node_index);
      }
      return kTfLiteError;
    }
    return kTfLiteOk;
  }

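  // kTfLiteMmapRo marks tensor data that is constant and already available
  // (typically weights and biases read from the model flatbuffer); the
  // convolution visitors below require it for filter and bias tensors because
  // XNNPACK needs those values when the graph is defined, not at run time.
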
  static TfLiteStatus VisitNode(xnn_subgraph_t subgraph, TfLiteContext* context,
                                TfLiteRegistration* registration,
                                TfLiteNode* node, int node_index,
                                const std::vector<uint32_t>& xnnpack_tensors) {
    // TFLite context used for logging purposes. When we create a new node
    // (subgraph is non-null), the logging context is the same as context, and
    // error messages are passed to TFLite. When we detect supported operations
    // (subgraph is null), the logging context is null, and error messages are
    // suppressed.
    TfLiteContext* logging_context = subgraph == nullptr ? nullptr : context;
    switch (registration->builtin_code) {
      case kTfLiteBuiltinConv2d: {
        const TfLiteConvParams* conv_params =
            static_cast<const TfLiteConvParams*>(node->builtin_data);

        return VisitConv2DNode(subgraph, logging_context, node_index, node,
                               context->tensors, conv_params, xnnpack_tensors);
      }
      case kTfLiteBuiltinDepthwiseConv2d: {
        const TfLiteDepthwiseConvParams* dwconv_params =
            static_cast<const TfLiteDepthwiseConvParams*>(node->builtin_data);

        return VisitDepthwiseConv2DNode(subgraph, logging_context, node_index,
                                        node, context->tensors, dwconv_params,
                                        xnnpack_tensors);
      }
      default:
        return kTfLiteError;
    }
  }

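  // The two-pass pattern described in the comment above, in action:
  // GetOpsToReplace (defined below) probes every node with a null subgraph to
  // test delegatability without emitting errors, while subgraph construction
  // passes a real xnn_subgraph_t so that failures are reported to TFLite:
  //
  //   // Detection pass -- errors suppressed:
  //   VisitNode(/*subgraph=*/nullptr, context, registration, node, node_index,
  //             std::vector<uint32_t>());
  //   // Definition pass -- errors reported:
  //   VisitNode(subgraph, context, registration, node, node_index,
  //             xnnpack_tensors);
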
  static TfLiteStatus VisitConv2DNode(
      xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
      TfLiteNode* node, const TfLiteTensor* tensors,
      const TfLiteConvParams* conv_params,
      const std::vector<uint32_t>& xnnpack_tensors) {
    TF_LITE_ENSURE_STATUS(
        CheckConvolutionParams(logging_context, conv_params, node_index));

    TF_LITE_ENSURE_STATUS(
        CheckNumInputsAndOutputs(logging_context, node, 3, 1, node_index));

    const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
        logging_context, input_tensor, node->inputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_tensor, 4,
                                           node->inputs->data[0]));

    const TfLiteTensor& filter_tensor = tensors[node->inputs->data[1]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
        logging_context, filter_tensor, node->inputs->data[1], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, filter_tensor, 4,
                                           node->inputs->data[1]));
    TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
        logging_context, filter_tensor, node->inputs->data[1], node_index));

    const TfLiteTensor& bias_tensor = tensors[node->inputs->data[2]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
        logging_context, bias_tensor, node->inputs->data[2], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, bias_tensor, 1,
                                           node->inputs->data[2]));
    TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
        logging_context, bias_tensor, node->inputs->data[2], node_index));

    const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
        logging_context, output_tensor, node->outputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_tensor, 4,
                                           node->outputs->data[0]));

    const int output_channels = filter_tensor.dims->data[0];
    const int kernel_height = filter_tensor.dims->data[1];
    const int kernel_width = filter_tensor.dims->data[2];
    const int input_channels = filter_tensor.dims->data[3];

    uint32_t flags = 0;
    TF_LITE_ENSURE_STATUS(CalculatePadding(
        logging_context, conv_params->padding, &flags, node_index));

    float output_min = -std::numeric_limits<float>::infinity();
    float output_max = +std::numeric_limits<float>::infinity();
    TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange(
        logging_context, node_index, conv_params->activation, &output_min,
        &output_max));

    if (subgraph) {
      const xnn_status status = xnn_define_convolution_2d(
          subgraph,
          /*input_padding_top=*/0,
          /*input_padding_right=*/0,
          /*input_padding_bottom=*/0,
          /*input_padding_left=*/0, static_cast<uint32_t>(kernel_height),
          static_cast<uint32_t>(kernel_width),
          static_cast<uint32_t>(conv_params->stride_height),
          static_cast<uint32_t>(conv_params->stride_width),
          static_cast<uint32_t>(conv_params->dilation_height_factor),
          static_cast<uint32_t>(conv_params->dilation_width_factor),
          /*groups=*/1, static_cast<size_t>(input_channels),
          static_cast<size_t>(output_channels), output_min, output_max,
          /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
          /*filter_id=*/xnnpack_tensors[node->inputs->data[1]],
          /*bias_id=*/xnnpack_tensors[node->inputs->data[2]],
          /*output_id=*/xnnpack_tensors[node->outputs->data[0]], flags);
      if (status != xnn_status_success) {
        logging_context->ReportError(
            logging_context, "failed to delegate Convolution 2D node #%d",
            node_index);
        return kTfLiteError;
      }
    }

    return kTfLiteOk;
  }

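  // For reference, assuming the usual TFLite activation semantics (the
  // ConvertActivationToOutputRange helper is defined earlier in this file):
  // kTfLiteActNone keeps the (-inf, +inf) range set above, kTfLiteActRelu
  // yields [0, +inf), and kTfLiteActRelu6 yields [0, 6], so the fused
  // activation becomes XNNPACK's output clamping range.
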
  static TfLiteStatus VisitDepthwiseConv2DNode(
      xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
      TfLiteNode* node, const TfLiteTensor* tensors,
      const TfLiteDepthwiseConvParams* dwconv_params,
      const std::vector<uint32_t>& xnnpack_tensors) {
    TF_LITE_ENSURE_STATUS(
        CheckNumInputsAndOutputs(logging_context, node, 3, 1, node_index));

    const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
        logging_context, input_tensor, node->inputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_tensor, 4,
                                           node->inputs->data[0]));

    const TfLiteTensor& filter_tensor = tensors[node->inputs->data[1]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
        logging_context, filter_tensor, node->inputs->data[1], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, filter_tensor, 4,
                                           node->inputs->data[1]));
    TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
        logging_context, filter_tensor, node->inputs->data[1], node_index));

    const TfLiteTensor& bias_tensor = tensors[node->inputs->data[2]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
        logging_context, bias_tensor, node->inputs->data[2], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, bias_tensor, 1,
                                           node->inputs->data[2]));
    TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
        logging_context, bias_tensor, node->inputs->data[2], node_index));

    const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
        logging_context, output_tensor, node->outputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_tensor, 4,
                                           node->outputs->data[0]));

    const int kernel_height = filter_tensor.dims->data[1];
    const int kernel_width = filter_tensor.dims->data[2];
    const int output_channels = filter_tensor.dims->data[3];

    TF_LITE_ENSURE_STATUS(CheckDepthwiseConvolutionParams(
        logging_context, dwconv_params, output_channels, node_index));

    uint32_t flags = 0;
    TF_LITE_ENSURE_STATUS(CalculatePadding(
        logging_context, dwconv_params->padding, &flags, node_index));

    float output_min = -std::numeric_limits<float>::infinity();
    float output_max = +std::numeric_limits<float>::infinity();
    TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange(
        logging_context, node_index, dwconv_params->activation, &output_min,
        &output_max));

    if (subgraph) {
      const xnn_status status = xnn_define_depthwise_convolution_2d(
          subgraph,
          /*input_padding_top=*/0,
          /*input_padding_right=*/0,
          /*input_padding_bottom=*/0,
          /*input_padding_left=*/0, static_cast<uint32_t>(kernel_height),
          static_cast<uint32_t>(kernel_width),
          static_cast<uint32_t>(dwconv_params->stride_height),
          static_cast<uint32_t>(dwconv_params->stride_width),
          static_cast<uint32_t>(dwconv_params->dilation_height_factor),
          static_cast<uint32_t>(dwconv_params->dilation_width_factor),
          static_cast<uint32_t>(dwconv_params->depth_multiplier),
          /*input_channels=*/
          static_cast<uint32_t>(output_channels /
                                dwconv_params->depth_multiplier),
          output_min, output_max,
          /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
          /*filter_id=*/xnnpack_tensors[node->inputs->data[1]],
          /*bias_id=*/xnnpack_tensors[node->inputs->data[2]],
          /*output_id=*/xnnpack_tensors[node->outputs->data[0]], flags);
      if (status != xnn_status_success) {
        logging_context->ReportError(
            logging_context,
            "failed to delegate Depthwise Convolution 2D node #%d", node_index);
        return kTfLiteError;
      }
    }

    return kTfLiteOk;
  }

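  // Note the filter layouts the two visitors rely on: a regular Conv2D
  // filter is [output_channels, kernel_height, kernel_width, input_channels]
  // (dims 0..3 read in VisitConv2DNode above), while a depthwise filter is
  // [1, kernel_height, kernel_width, output_channels], which is why
  // VisitDepthwiseConv2DNode reads output_channels from dims->data[3] and
  // derives input_channels as output_channels / depth_multiplier.
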
 private:
  Subgraph(xnn_runtime_t runtime, std::unordered_set<int>&& externals)
      : runtime_(runtime, &xnn_delete_runtime),
        externals_(std::move(externals)) {}

  // XNNPACK Runtime (subgraph + workspace) with smart-pointer for lifetime
  // management.
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> runtime_{
      nullptr, &xnn_delete_runtime};
  // TFLite Tensor IDs == XNNPACK Value IDs of input/output tensors for the
  // delegated subgraph.
  std::unordered_set<int> externals_;
  bool first_run_{true};
};

TfLiteIntArray* GetOpsToReplace(TfLiteContext* context) {
  TfLiteIntArray* execution_plan = nullptr;
  if (context->GetExecutionPlan(context, &execution_plan) != kTfLiteOk) {
    context->ReportError(context, "Unable to get graph execution plan.");
    return nullptr;
  }

  TfLiteIntArray* nodes_to_replace = TfLiteIntArrayCreate(execution_plan->size);
  nodes_to_replace->size = 0;
  for (int i = 0; i < execution_plan->size; ++i) {
    const int node_index = execution_plan->data[i];

    // Check if the TFLite node can be delegated to XNNPACK.
    TfLiteNode* node = nullptr;
    TfLiteRegistration* registration = nullptr;
    if (context->GetNodeAndRegistration(context, node_index, &node,
                                        &registration) != kTfLiteOk) {
      context->ReportError(context,
                           "Unable to get node and registration for node %d.",
                           node_index);
      continue;  // Soft error (skip this node).
    }

    if (Subgraph::VisitNode(/*subgraph=*/nullptr, context, registration, node,
                            node_index, std::vector<uint32_t>()) != kTfLiteOk) {
      // Non-delegatable node is not an error.
      continue;
    }

    nodes_to_replace->data[nodes_to_replace->size++] = node_index;
  }
  return nodes_to_replace;
}

void* SubgraphInit(TfLiteContext* context, const char* buffer, size_t length) {
  const TfLiteDelegateParams* params =
      reinterpret_cast<const TfLiteDelegateParams*>(buffer);

  return static_cast<void*>(Subgraph::Create(context, params));
}

TfLiteStatus SubgraphPrepare(TfLiteContext* context, TfLiteNode* node) {
  return static_cast<Subgraph*>(node->user_data)->Prepare(context);
}

TfLiteStatus SubgraphInvoke(TfLiteContext* context, TfLiteNode* node) {
  return static_cast<Subgraph*>(node->user_data)->Invoke(context);
}

void SubgraphFree(TfLiteContext* context, void* buffer) {
  if (buffer != nullptr) {
    delete static_cast<Subgraph*>(buffer);
  }
}

const TfLiteRegistration kSubgraphRegistration = {
    /*.init=*/SubgraphInit,
    /*.free=*/SubgraphFree,
    /*.prepare=*/SubgraphPrepare,
    /*.invoke=*/SubgraphInvoke,
    /*.profiling_string=*/nullptr,
    /*.builtin_code=*/0,
    /*.custom_name=*/"TfLiteXNNPackDelegate",
    /*.version=*/2,
};

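// Lifecycle of the registration above (for orientation): TFLite calls .init
// once per delegated partition (SubgraphInit wraps Subgraph::Create),
// .prepare whenever tensors are (re)allocated, .invoke on every inference,
// and .free when the partition is released (SubgraphFree deletes the
// Subgraph).
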
TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
  TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
  if (ops_to_replace == nullptr) {
    return kTfLiteError;
  }
  const TfLiteStatus status = context->ReplaceNodeSubsetsWithDelegateKernels(
      context, kSubgraphRegistration, ops_to_replace, delegate);
  TfLiteIntArrayFree(ops_to_replace);
  return status;
}

}  // namespace
}  // namespace xnnpack
}  // namespace tflite

TfLiteXNNPackDelegateOptions TfLiteXNNPackDelegateOptionsDefault() {
  TfLiteXNNPackDelegateOptions options = {0};
  return options;
}

TfLiteDelegate* TfLiteXNNPackDelegateCreate(
    const TfLiteXNNPackDelegateOptions* options) {
  xnn_status status = xnn_initialize(/*allocator=*/nullptr);
  if (status != xnn_status_success) {
    return nullptr;
  }

  auto* xnnpack_delegate = new ::tflite::xnnpack::Delegate(options);
  return xnnpack_delegate ? xnnpack_delegate->tflite_delegate() : nullptr;
}

void TfLiteXNNPackDelegateDelete(TfLiteDelegate* delegate) {
  if (delegate != nullptr) {
    delete reinterpret_cast<::tflite::xnnpack::Delegate*>(delegate);
  }
}
47
tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h
Normal file
@ -0,0 +1,47 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_XNNPACK_XNNPACK_DELEGATE_H_
#define TENSORFLOW_LITE_DELEGATES_XNNPACK_XNNPACK_DELEGATE_H_

#include "tensorflow/lite/c/common.h"

#ifdef __cplusplus
extern "C" {
#endif  // __cplusplus

typedef struct {
  // Number of threads to use in the thread pool.
  // A value of 0 or less means no thread pool is used.
  int32_t num_threads;
} TfLiteXNNPackDelegateOptions;

// Returns a structure with the default XNNPack delegate options.
TfLiteXNNPackDelegateOptions TfLiteXNNPackDelegateOptionsDefault();

// Creates a new delegate instance that needs to be destroyed with
// `TfLiteXNNPackDelegateDelete` when it is no longer used by TFLite.
// When `options` is set to `nullptr`, the default values (as returned by
// `TfLiteXNNPackDelegateOptionsDefault`) are used.
TfLiteDelegate* TfLiteXNNPackDelegateCreate(
    const TfLiteXNNPackDelegateOptions* options);

// Destroys a delegate created with the `TfLiteXNNPackDelegateCreate` call.
void TfLiteXNNPackDelegateDelete(TfLiteDelegate* delegate);

#ifdef __cplusplus
}  // extern "C"
#endif  // __cplusplus

#endif  // TENSORFLOW_LITE_DELEGATES_XNNPACK_XNNPACK_DELEGATE_H_
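A minimal usage sketch for the API above (illustrative only: "model.tflite"
is a placeholder path, input/output handling is omitted, and error checking
is trimmed). The delegate must outlive the interpreter that uses it, so it is
deleted only after the interpreter is destroyed:

#include <memory>

#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

int main() {
  auto model = tflite::FlatBufferModel::BuildFromFile("model.tflite");
  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);

  // Create the delegate with default options (num_threads = 0, i.e. no
  // thread pool) and hand the supported subgraphs over to XNNPACK.
  TfLiteXNNPackDelegateOptions options = TfLiteXNNPackDelegateOptionsDefault();
  TfLiteDelegate* delegate = TfLiteXNNPackDelegateCreate(&options);
  interpreter->ModifyGraphWithDelegate(delegate);

  interpreter->AllocateTensors();
  interpreter->Invoke();

  // Destroy the interpreter before deleting the delegate it references.
  interpreter.reset();
  TfLiteXNNPackDelegateDelete(delegate);
  return 0;
}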
@ -144,11 +144,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
 
     tf_http_archive(
         name = "XNNPACK",
-        sha256 = "24b6285c679dece8805d2a7d63cc567413b7670279bc0c66a99e555123fe4700",
-        strip_prefix = "XNNPACK-9a88efe2d84fef93eb2b8acb6f0ac8f3cacee8b5",
+        sha256 = "8afdbfd2e71c14dc3251b55e0cd0c799079bb747ac0fe08462d8d009c912cb42",
+        strip_prefix = "XNNPACK-7278a95e3cfae6eac73f363c4fda5db53e1b2a87",
         urls = [
-            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/XNNPACK/archive/9a88efe2d84fef93eb2b8acb6f0ac8f3cacee8b5.zip",
-            "https://github.com/google/XNNPACK/archive/9a88efe2d84fef93eb2b8acb6f0ac8f3cacee8b5.zip",
+            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/XNNPACK/archive/7278a95e3cfae6eac73f363c4fda5db53e1b2a87.zip",
+            "https://github.com/google/XNNPACK/archive/7278a95e3cfae6eac73f363c4fda5db53e1b2a87.zip",
         ],
     )
 
1
third_party/clog/BUILD.bazel
vendored
@ -25,6 +25,7 @@ cc_library(
         "//conditions:default": [
         ],
     }),
+    linkstatic = True,
     strip_include_prefix = "deps/clog/include",
 )
 
1
third_party/cpuinfo/BUILD.bazel
vendored
@ -16,6 +16,7 @@ C99OPTS = [
 # Source code common to all platforms.
 COMMON_SRCS = [
     "src/api.c",
+    "src/cache.c",
     "src/init.c",
 ]
 
8
third_party/cpuinfo/workspace.bzl
vendored
@ -5,11 +5,11 @@ load("//third_party:repo.bzl", "third_party_http_archive")
 def repo():
     third_party_http_archive(
         name = "cpuinfo",
-        strip_prefix = "cpuinfo-d5e37adf1406cf899d7d9ec1d317c47506ccb970",
-        sha256 = "3f2dc1970f397a0e59db72f9fca6ff144b216895c1d606f6c94a507c1e53a025",
+        strip_prefix = "cpuinfo-e39a5790059b6b8274ed91f7b5b5b13641dff267",
+        sha256 = "e5caa8b7c58f1623eed88f4d5147e3753ff19cde821526bc9aa551b004f751fe",
         urls = [
-            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/pytorch/cpuinfo/archive/d5e37adf1406cf899d7d9ec1d317c47506ccb970.tar.gz",
-            "https://github.com/pytorch/cpuinfo/archive/d5e37adf1406cf899d7d9ec1d317c47506ccb970.tar.gz",
+            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/pytorch/cpuinfo/archive/e39a5790059b6b8274ed91f7b5b5b13641dff267.tar.gz",
+            "https://github.com/pytorch/cpuinfo/archive/e39a5790059b6b8274ed91f7b5b5b13641dff267.tar.gz",
         ],
         build_file = "//third_party/cpuinfo:BUILD.bazel",
     )
8
third_party/psimd/workspace.bzl
vendored
@ -5,11 +5,11 @@ load("//third_party:repo.bzl", "third_party_http_archive")
 def repo():
     third_party_http_archive(
         name = "psimd",
-        strip_prefix = "psimd-8fd2884b88848180904a40c452a362d1ee429ad5",
-        sha256 = "9d4f05bc5a93a0ab8bcef12027ebe54cfddd0050d4862442449c8de11b4e8c17",
+        strip_prefix = "psimd-10b4ffc6ea9e2e11668f86969586f88bc82aaefa",
+        sha256 = "1fefd66702cb2eb3462b962f33d4fb23d59a55d5889ee6372469d286c4512df4",
         urls = [
-            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/Maratyszcza/psimd/archive/8fd2884b88848180904a40c452a362d1ee429ad5.tar.gz",
-            "https://github.com/Maratyszcza/psimd/archive/8fd2884b88848180904a40c452a362d1ee429ad5.tar.gz",
+            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/Maratyszcza/psimd/archive/10b4ffc6ea9e2e11668f86969586f88bc82aaefa.tar.gz",
+            "https://github.com/Maratyszcza/psimd/archive/10b4ffc6ea9e2e11668f86969586f88bc82aaefa.tar.gz",
         ],
         build_file = "//third_party/psimd:BUILD.bazel",
     )