Support models with FP16 weights in XNNPACK delegate
PiperOrigin-RevId: 313505742
Change-Id: Id21f7528741073e93a7132d529c3cd79957a73fb
Commit: a7048d89a1 (parent: fb86acf839)
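For context, a minimal sketch of how this feature is exercised (assuming an
`interpreter` already built from a model whose static weights are FLOAT16
tensors feeding DEQUANTIZE nodes; the delegate and interpreter calls are the
same ones used by the tests in this commit):

    #include <memory>
    #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"

    // Create the delegate with default options and hand it to the
    // interpreter; after this change the delegate pre-unpacks the FP16
    // weights to FP32 internally instead of rejecting the graph.
    std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
        xnnpack_delegate(TfLiteXNNPackDelegateCreate(/*options=*/nullptr),
                         TfLiteXNNPackDelegateDelete);
    if (interpreter->ModifyGraphWithDelegate(xnnpack_delegate.get()) !=
        kTfLiteOk) {
      // Delegation failed; the interpreter still runs on the default kernels.
    }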
tensorflow/lite/delegates/xnnpack/BUILD
@@ -24,6 +24,7 @@ cc_library(
         "//tensorflow/lite:util",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/schema:schema_fbs",
+        "@FP16",
         "@XNNPACK",
     ],
 )
@@ -39,6 +40,7 @@ cc_library(
         "//tensorflow/lite:util",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/schema:schema_fbs",
+        "@FP16",
         "@XNNPACK",
     ],
 )
@@ -56,6 +58,7 @@ cc_library(
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/kernels:builtin_ops",
         "//tensorflow/lite/schema:schema_fbs",
+        "@FP16",
         "@com_google_googletest//:gtest",
         "@flatbuffers",
     ],
@@ -72,6 +75,7 @@ cc_library(
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/kernels:builtin_ops",
         "//tensorflow/lite/schema:schema_fbs",
+        "@FP16",
         "@com_google_googletest//:gtest",
         "@flatbuffers",
     ],
@@ -88,6 +92,7 @@ cc_library(
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/kernels:builtin_ops",
         "//tensorflow/lite/schema:schema_fbs",
+        "@FP16",
         "@com_google_googletest//:gtest",
         "@flatbuffers",
     ],
@@ -215,6 +220,7 @@ cc_test(
         "//tensorflow/lite:schema_fbs_version",
         "//tensorflow/lite/kernels:builtin_ops",
         "//tensorflow/lite/schema:schema_fbs",
+        "@FP16",
         "@com_google_googletest//:gtest",
         "@flatbuffers",
     ],
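The new "@FP16" dependency provides the <fp16.h> conversion helpers used
throughout the testers below. A minimal round-trip sketch (both functions are
part of the FP16 library; the sample value is arbitrary):

    #include <fp16.h>

    #include <cstdint>
    #include <cstdio>

    int main() {
      const float value = 0.5f;
      // FP32 -> IEEE half-precision bit pattern, as used to fill the FLOAT16
      // weight buffers in the generated test models.
      const uint16_t half = fp16_ieee_from_fp32_value(value);
      // The inverse conversion, used when unpacking FP16 weights to FP32.
      const float restored = fp16_ieee_to_fp32_value(half);
      std::printf("%f -> 0x%04x -> %f\n", value, half, restored);
      return 0;
    }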
tensorflow/lite/delegates/xnnpack/add_test.cc
@@ -679,6 +679,35 @@ TEST(Add, 2DByStatic0D) {
       .Test(BuiltinOperator_ADD, xnnpack_delegate.get());
 }
 
+TEST(Add, FP16Weights) {
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto shape_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
+  const auto batch = shape_rng();
+  const auto height = shape_rng();
+  const auto width = shape_rng();
+  const auto channels = shape_rng();
+
+  BinaryElementwiseTester()
+      .Input1Shape({batch, height, width, channels})
+      .Input2Shape({batch, height, width, channels})
+      .Input1Static(true)
+      .FP16Weights()
+      .Test(BuiltinOperator_ADD, xnnpack_delegate.get());
+
+  BinaryElementwiseTester()
+      .Input1Shape({batch, height, width, channels})
+      .Input2Shape({batch, height, width, channels})
+      .Input2Static(true)
+      .FP16Weights()
+      .Test(BuiltinOperator_ADD, xnnpack_delegate.get());
+}
+
 TEST(Add, ReluActivation) {
   std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
       xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.cc
@@ -23,6 +23,7 @@ limitations under the License.
 #include <vector>
 
 #include <gtest/gtest.h>
+#include <fp16.h>
 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
 #include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/kernels/register.h"
@@ -62,6 +63,9 @@ void BinaryElementwiseTester::Test(tflite::BuiltinOperator binary_op,
   if (Input1Static()) {
     ASSERT_FALSE(Input2Static());
   }
+  if (FP16Weights()) {
+    ASSERT_TRUE(Input1Static() || Input2Static());
+  }
 
   std::random_device random_device;
   auto rng = std::mt19937(random_device());
@@ -180,8 +184,12 @@ std::vector<char> BinaryElementwiseTester::CreateTfLiteModel(
   auto input2_rng = std::bind(input2_distribution, std::ref(rng));
 
   flatbuffers::FlatBufferBuilder builder;
-  flatbuffers::Offset<OperatorCode> operator_code =
-      CreateOperatorCode(builder, binary_op);
+  std::vector<flatbuffers::Offset<OperatorCode>> operator_codes{
+      {CreateOperatorCode(builder, binary_op)}};
+  if (FP16Weights()) {
+    operator_codes.emplace_back(
+        CreateOperatorCode(builder, BuiltinOperator_DEQUANTIZE));
+  }
 
   std::vector<flatbuffers::Offset<Buffer>> buffers{{
       CreateBuffer(builder, builder.CreateVector({})),
@@ -189,43 +197,89 @@ std::vector<char> BinaryElementwiseTester::CreateTfLiteModel(
 
   int32_t input1_buffer = 0;
   if (Input1Static()) {
-    std::vector<float> input1_data(ComputeSize(Input1Shape()));
-    std::generate(input1_data.begin(), input1_data.end(), input1_rng);
+    if (FP16Weights()) {
+      std::vector<uint16_t> input1_data(ComputeSize(Input1Shape()));
+      std::generate(input1_data.begin(), input1_data.end(),
+                    std::bind(fp16_ieee_from_fp32_value, input1_rng));
 
-    input1_buffer = buffers.size();
-    buffers.push_back(CreateBuffer(
-        builder, builder.CreateVector(
-                     reinterpret_cast<const uint8_t*>(input1_data.data()),
-                     sizeof(float) * input1_data.size())));
+      input1_buffer = buffers.size();
+      buffers.push_back(CreateBuffer(
+          builder, builder.CreateVector(
+                       reinterpret_cast<const uint8_t*>(input1_data.data()),
+                       sizeof(uint16_t) * input1_data.size())));
+    } else {
+      std::vector<float> input1_data(ComputeSize(Input1Shape()));
+      std::generate(input1_data.begin(), input1_data.end(), input1_rng);
+
+      input1_buffer = buffers.size();
+      buffers.push_back(CreateBuffer(
+          builder, builder.CreateVector(
+                       reinterpret_cast<const uint8_t*>(input1_data.data()),
+                       sizeof(float) * input1_data.size())));
+    }
   }
 
   int32_t input2_buffer = 0;
   if (Input2Static()) {
-    std::vector<float> input2_data(ComputeSize(Input2Shape()));
-    std::generate(input2_data.begin(), input2_data.end(), input2_rng);
+    if (FP16Weights()) {
+      std::vector<uint16_t> input2_data(ComputeSize(Input2Shape()));
+      std::generate(input2_data.begin(), input2_data.end(),
+                    std::bind(fp16_ieee_from_fp32_value, input1_rng));
 
-    input2_buffer = buffers.size();
-    buffers.push_back(CreateBuffer(
-        builder, builder.CreateVector(
-                     reinterpret_cast<const uint8_t*>(input2_data.data()),
-                     sizeof(float) * input2_data.size())));
+      input2_buffer = buffers.size();
+      buffers.push_back(CreateBuffer(
+          builder, builder.CreateVector(
+                       reinterpret_cast<const uint8_t*>(input2_data.data()),
+                       sizeof(uint16_t) * input2_data.size())));
+    } else {
+      std::vector<float> input2_data(ComputeSize(Input2Shape()));
+      std::generate(input2_data.begin(), input2_data.end(), input2_rng);
+
+      input2_buffer = buffers.size();
+      buffers.push_back(CreateBuffer(
+          builder, builder.CreateVector(
+                       reinterpret_cast<const uint8_t*>(input2_data.data()),
+                       sizeof(float) * input2_data.size())));
+    }
   }
 
   const std::vector<int32_t> output_shape = OutputShape();
-  const std::array<flatbuffers::Offset<Tensor>, 3> tensors{{
-      CreateTensor(builder,
-                   builder.CreateVector<int32_t>(Input1Shape().data(),
-                                                 Input1Shape().size()),
-                   TensorType_FLOAT32, input1_buffer),
-      CreateTensor(builder,
-                   builder.CreateVector<int32_t>(Input2Shape().data(),
-                                                 Input2Shape().size()),
-                   TensorType_FLOAT32, input2_buffer),
-      CreateTensor(builder,
-                   builder.CreateVector<int32_t>(output_shape.data(),
-                                                 output_shape.size()),
-                   TensorType_FLOAT32),
-  }};
+  std::vector<flatbuffers::Offset<Tensor>> tensors;
+  std::vector<flatbuffers::Offset<Operator>> operators;
+  if (FP16Weights() && Input1Static()) {
+    tensors.emplace_back(
+        CreateTensor(builder,
+                     builder.CreateVector<int32_t>(Input1Shape().data(),
+                                                   Input1Shape().size()),
+                     TensorType_FLOAT16, 1));
+  }
+  if (FP16Weights() && Input2Static()) {
+    tensors.emplace_back(
+        CreateTensor(builder,
+                     builder.CreateVector<int32_t>(Input2Shape().data(),
+                                                   Input2Shape().size()),
+                     TensorType_FLOAT16, 1));
+  }
+  if (FP16Weights()) {
+    const std::array<int32_t, 1> dequantize_inputs{{0}};
+    const std::array<int32_t, 1> dequantize_outputs{{Input1Static() ? 1 : 2}};
+    operators.emplace_back(CreateOperator(
+        builder, /*opcode_index=*/1,
+        builder.CreateVector<int32_t>(dequantize_inputs.data(),
+                                      dequantize_inputs.size()),
+        builder.CreateVector<int32_t>(dequantize_outputs.data(),
+                                      dequantize_outputs.size())));
+  }
+  tensors.emplace_back(CreateTensor(
+      builder,
+      builder.CreateVector<int32_t>(Input1Shape().data(), Input1Shape().size()),
+      TensorType_FLOAT32, input1_buffer));
+  tensors.emplace_back(CreateTensor(
+      builder,
+      builder.CreateVector<int32_t>(Input2Shape().data(), Input2Shape().size()),
      TensorType_FLOAT32, input2_buffer));
+  tensors.emplace_back(CreateTensor(
+      builder,
+      builder.CreateVector<int32_t>(output_shape.data(), output_shape.size()),
+      TensorType_FLOAT32));
 
   tflite::BuiltinOptions builtin_options_type = tflite::BuiltinOptions_NONE;
   flatbuffers::Offset<void> builtin_options = 0;
@@ -250,35 +304,40 @@ std::vector<char> BinaryElementwiseTester::CreateTfLiteModel(
     EXPECT_EQ(Activation(), ActivationFunctionType_NONE);
   }
 
-  const std::array<int32_t, 2> op_inputs{{0, 1}};
-  const std::array<int32_t, 1> op_outputs{{2}};
-  flatbuffers::Offset<Operator> op = CreateOperator(
+  const std::array<int32_t, 2> op_inputs{
+      {static_cast<int>(tensors.size()) - 3,
+       static_cast<int>(tensors.size()) - 2}};
+  const std::array<int32_t, 1> op_outputs{
+      {static_cast<int>(tensors.size()) - 1}};
+  operators.emplace_back(CreateOperator(
       builder, /*opcode_index=*/0,
       builder.CreateVector<int32_t>(op_inputs.data(), op_inputs.size()),
      builder.CreateVector<int32_t>(op_outputs.data(), op_outputs.size()),
-      builtin_options_type, builtin_options);
+      builtin_options_type, builtin_options));
 
   std::vector<int32_t> subgraph_inputs;
   if (!Input1Static()) {
-    subgraph_inputs.push_back(0);
+    subgraph_inputs.push_back(tensors.size() - 3);
   }
   if (!Input2Static()) {
-    subgraph_inputs.push_back(1);
+    subgraph_inputs.push_back(tensors.size() - 2);
   }
-  const std::array<int32_t, 1> subgraph_outputs{{2}};
+  const std::array<int32_t, 1> subgraph_outputs{
+      {static_cast<int>(tensors.size()) - 1}};
   flatbuffers::Offset<SubGraph> subgraph = CreateSubGraph(
       builder, builder.CreateVector(tensors.data(), tensors.size()),
       builder.CreateVector<int32_t>(subgraph_inputs.data(),
                                     subgraph_inputs.size()),
       builder.CreateVector<int32_t>(subgraph_outputs.data(),
                                     subgraph_outputs.size()),
-      builder.CreateVector(&op, 1));
+      builder.CreateVector(operators.data(), operators.size()));
 
   flatbuffers::Offset<flatbuffers::String> description =
       builder.CreateString("Binary operator model");
 
   flatbuffers::Offset<Model> model_buffer = CreateModel(
-      builder, TFLITE_SCHEMA_VERSION, builder.CreateVector(&operator_code, 1),
+      builder, TFLITE_SCHEMA_VERSION,
+      builder.CreateVector(operator_codes.data(), operator_codes.size()),
       builder.CreateVector(&subgraph, 1), description,
       builder.CreateVector(buffers.data(), buffers.size()));
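For orientation, the graph this tester emits when FP16Weights() is combined
with a static input, as a sketch inferred from the code above (not generated
tool output):

    // Tensor/operator layout with FP16Weights() and Input1Static():
    //
    //   tensor 0: FLOAT16, static (buffer 1)  -- raw fp16 weights
    //   tensor 1: FLOAT32                     -- dequantized weights (op input 1)
    //   tensor 2: FLOAT32                     -- dynamic input (subgraph input)
    //   tensor 3: FLOAT32                     -- subgraph output
    //
    //   operator 0: DEQUANTIZE(t0) -> t1      (opcode_index 1)
    //   operator 1: ADD/MUL(t1, t2) -> t3     (opcode_index 0)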
tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.h
@@ -74,6 +74,13 @@ class BinaryElementwiseTester {
 
   inline bool Input2Static() const { return input2_static_; }
 
+  inline BinaryElementwiseTester& FP16Weights() {
+    fp16_weights_ = true;
+    return *this;
+  }
+
+  inline bool FP16Weights() const { return fp16_weights_; }
+
   inline BinaryElementwiseTester& ReluActivation() {
     activation_ = ::tflite::ActivationFunctionType_RELU;
     return *this;
@@ -114,6 +121,7 @@ class BinaryElementwiseTester {
   std::vector<int32_t> input2_shape_;
   bool input1_static_ = false;
   bool input2_static_ = false;
+  bool fp16_weights_ = false;
   ::tflite::ActivationFunctionType activation_ =
       ::tflite::ActivationFunctionType_NONE;
 };
tensorflow/lite/delegates/xnnpack/conv_2d_test.cc
@@ -13,12 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include <array>
 #include <cstdint>
 #include <functional>
 #include <random>
 #include <vector>
 
 #include <gtest/gtest.h>
+#include <fp16.h>
 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
 #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
 #include "tensorflow/lite/interpreter.h"
@@ -146,6 +148,13 @@ class Conv2DTester {
 
   int32_t DilationWidth() const { return dilation_width_; }
 
+  inline Conv2DTester& FP16Weights() {
+    fp16_weights_ = true;
+    return *this;
+  }
+
+  inline bool FP16Weights() const { return fp16_weights_; }
+
   Conv2DTester& SamePadding(bool same_padding) {
     same_padding_ = same_padding;
     return *this;
@@ -154,11 +163,7 @@ class Conv2DTester {
   bool SamePadding() const { return same_padding_; }
 
   void Test(TfLiteDelegate* delegate) const {
-    std::random_device random_device;
-    auto rng = std::mt19937(random_device());
-    auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);
-
-    std::vector<char> buffer = CreateTfLiteModel(std::ref(f32rng));
+    std::vector<char> buffer = CreateTfLiteModel();
     const Model* model = GetModel(buffer.data());
 
     std::unique_ptr<Interpreter> delegate_interpreter;
@@ -187,6 +192,10 @@ class Conv2DTester {
     ASSERT_EQ(delegate_interpreter->ModifyGraphWithDelegate(delegate),
               kTfLiteOk);
 
+    std::random_device random_device;
+    auto rng = std::mt19937(random_device());
+    auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);
+
     float* default_input_data = default_interpreter->typed_tensor<float>(
         default_interpreter->inputs()[0]);
     std::generate(default_input_data,
@@ -219,82 +228,149 @@ class Conv2DTester {
   }
 
  private:
-  std::vector<char> CreateTfLiteModel(std::function<float()> f32rng) const {
+  std::vector<char> CreateTfLiteModel() const {
+    std::random_device random_device;
+    auto rng = std::mt19937(random_device());
+    auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);
+
     flatbuffers::FlatBufferBuilder builder;
-    flatbuffers::Offset<OperatorCode> operator_code =
-        CreateOperatorCode(builder, BuiltinOperator_CONV_2D, 0);
+    std::vector<flatbuffers::Offset<OperatorCode>> operator_codes{
+        {CreateOperatorCode(builder, BuiltinOperator_CONV_2D, 0)}};
+    std::vector<flatbuffers::Offset<tflite::Operator>> operators;
+    std::vector<flatbuffers::Offset<tflite::Buffer>> buffers{
+        {CreateBuffer(builder, builder.CreateVector({}))}};
+
+    if (FP16Weights()) {
+      operator_codes.emplace_back(
+          CreateOperatorCode(builder, BuiltinOperator_DEQUANTIZE));
+
+      auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng);
+
+      std::vector<uint16_t> filter_data(OutputChannels() * KernelHeight() *
+                                        KernelWidth() * InputChannels());
+      std::vector<uint16_t> bias_data(OutputChannels());
+
+      std::generate(filter_data.begin(), filter_data.end(), f16rng);
+      std::generate(bias_data.begin(), bias_data.end(), f16rng);
+
+      buffers.emplace_back(CreateBuffer(
+          builder, builder.CreateVector(
+                       reinterpret_cast<const uint8_t*>(filter_data.data()),
+                       sizeof(uint16_t) * filter_data.size())));
+      buffers.emplace_back(CreateBuffer(
+          builder, builder.CreateVector(
+                       reinterpret_cast<const uint8_t*>(bias_data.data()),
+                       sizeof(uint16_t) * bias_data.size())));
+
+      const std::array<int32_t, 1> dequantize_filter_inputs{{0}};
+      const std::array<int32_t, 1> dequantize_filter_outputs{{3}};
+      operators.emplace_back(CreateOperator(
+          builder, /*opcode_index=*/1,
+          builder.CreateVector<int32_t>(dequantize_filter_inputs.data(),
+                                        dequantize_filter_inputs.size()),
+          builder.CreateVector<int32_t>(dequantize_filter_outputs.data(),
+                                        dequantize_filter_outputs.size())));
+      const std::array<int32_t, 1> dequantize_bias_inputs{{1}};
+      const std::array<int32_t, 1> dequantize_bias_outputs{{4}};
+      operators.emplace_back(CreateOperator(
+          builder, /*opcode_index=*/1,
+          builder.CreateVector<int32_t>(dequantize_bias_inputs.data(),
+                                        dequantize_bias_inputs.size()),
+          builder.CreateVector<int32_t>(dequantize_bias_outputs.data(),
+                                        dequantize_bias_outputs.size())));
+    } else {
+      std::vector<float> filter_data(OutputChannels() * KernelHeight() *
+                                     KernelWidth() * InputChannels());
+      std::vector<float> bias_data(OutputChannels());
+
+      std::generate(filter_data.begin(), filter_data.end(), f32rng);
+      std::generate(bias_data.begin(), bias_data.end(), f32rng);
+
+      buffers.emplace_back(CreateBuffer(
+          builder, builder.CreateVector(
+                       reinterpret_cast<const uint8_t*>(filter_data.data()),
+                       sizeof(float) * filter_data.size())));
+      buffers.emplace_back(CreateBuffer(
+          builder, builder.CreateVector(
+                       reinterpret_cast<const uint8_t*>(bias_data.data()),
+                       sizeof(float) * bias_data.size())));
+    }
+
+    const std::array<int32_t, 4> input_shape{
+        {BatchSize(), InputHeight(), InputWidth(), InputChannels()}};
+    const std::array<int32_t, 4> output_shape{
+        {BatchSize(), OutputHeight(), OutputWidth(), OutputChannels()}};
+    const std::array<int32_t, 4> filter_shape{
+        {OutputChannels(), KernelHeight(), KernelWidth(), InputChannels()}};
+    const std::array<int32_t, 1> bias_shape{{OutputChannels()}};
+
+    std::vector<flatbuffers::Offset<tflite::Tensor>> tensors;
+    if (FP16Weights()) {
+      tensors.emplace_back(
+          CreateTensor(builder,
+                       builder.CreateVector<int32_t>(filter_shape.data(),
+                                                     filter_shape.size()),
+                       TensorType_FLOAT16, /*buffer=*/1));
+      tensors.emplace_back(CreateTensor(
+          builder,
+          builder.CreateVector<int32_t>(bias_shape.data(), bias_shape.size()),
+          TensorType_FLOAT16, /*buffer=*/2));
+    }
+    tensors.emplace_back(CreateTensor(
+        builder,
+        builder.CreateVector<int32_t>(input_shape.data(), input_shape.size()),
+        TensorType_FLOAT32));
+    tensors.emplace_back(CreateTensor(
+        builder,
+        builder.CreateVector<int32_t>(filter_shape.data(), filter_shape.size()),
+        TensorType_FLOAT32, /*buffer=*/FP16Weights() ? 0 : 1));
+    tensors.emplace_back(CreateTensor(
+        builder,
+        builder.CreateVector<int32_t>(bias_shape.data(), bias_shape.size()),
+        TensorType_FLOAT32, /*buffer=*/FP16Weights() ? 0 : 2));
+    tensors.emplace_back(CreateTensor(
+        builder,
+        builder.CreateVector<int32_t>(output_shape.data(), output_shape.size()),
+        TensorType_FLOAT32));
+
+    const std::array<int32_t, 3> op_inputs{
+        {static_cast<int>(tensors.size()) - 4,
+         static_cast<int>(tensors.size()) - 3,
+         static_cast<int>(tensors.size()) - 2}};
+    const std::array<int32_t, 1> op_outputs{
+        {static_cast<int>(tensors.size()) - 1}};
 
     flatbuffers::Offset<Conv2DOptions> conv2d_options = CreateConv2DOptions(
         builder, SamePadding() ? tflite::Padding_SAME : tflite::Padding_VALID,
         StrideWidth(), StrideHeight(), ActivationFunctionType_NONE,
         DilationWidth(), DilationHeight());
 
-    std::vector<float> filter_data(OutputChannels() * KernelHeight() *
-                                   KernelWidth() * InputChannels());
-    std::vector<float> bias_data(OutputChannels());
-
-    std::generate(filter_data.begin(), filter_data.end(), f32rng);
-    std::generate(bias_data.begin(), bias_data.end(), f32rng);
-
-    flatbuffers::Offset<Buffer> buffers[3] = {
-        CreateBuffer(builder, builder.CreateVector({})),
-        CreateBuffer(builder,
-                     builder.CreateVector(
-                         reinterpret_cast<const uint8_t*>(filter_data.data()),
-                         sizeof(float) * filter_data.size())),
-        CreateBuffer(builder,
-                     builder.CreateVector(
-                         reinterpret_cast<const uint8_t*>(bias_data.data()),
-                         sizeof(float) * bias_data.size())),
-    };
-
-    const int32_t input_shape[4] = {BatchSize(), InputHeight(), InputWidth(),
-                                    InputChannels()};
-    const int32_t output_shape[4] = {BatchSize(), OutputHeight(), OutputWidth(),
-                                     OutputChannels()};
-    const int32_t filter_shape[4] = {OutputChannels(), KernelHeight(),
-                                     KernelWidth(), InputChannels()};
-    const int32_t bias_shape[1] = {OutputChannels()};
-
-    flatbuffers::Offset<Tensor> tensors[4] = {
-        CreateTensor(builder, builder.CreateVector<int32_t>(input_shape, 4),
-                     TensorType_FLOAT32, /*buffer=*/0,
-                     builder.CreateString("X")),
-        CreateTensor(builder, builder.CreateVector<int32_t>(filter_shape, 4),
-                     TensorType_FLOAT32, /*buffer=*/1,
-                     builder.CreateString("W")),
-        CreateTensor(builder, builder.CreateVector<int32_t>(bias_shape, 1),
-                     TensorType_FLOAT32, /*buffer=*/2,
-                     builder.CreateString("b")),
-        CreateTensor(builder, builder.CreateVector<int32_t>(output_shape, 4),
-                     TensorType_FLOAT32, /*buffer=*/0,
-                     builder.CreateString("Y")),
-    };
-
-    const int32_t op_inputs[3] = {0, 1, 2};
-    const int32_t op_outputs[1] = {3};
-
-    flatbuffers::Offset<Operator> op =
-        CreateOperator(builder, /*opcode_index=*/0,
-                       builder.CreateVector<int32_t>(op_inputs, 3),
-                       builder.CreateVector<int32_t>(op_outputs, 1),
-                       BuiltinOptions_Conv2DOptions, conv2d_options.Union());
-
-    int32_t subgraph_inputs[1] = {0};
-    int32_t subgraph_outputs[1] = {3};
-    flatbuffers::Offset<SubGraph> subgraph =
-        CreateSubGraph(builder, builder.CreateVector(tensors, 4),
-                       builder.CreateVector<int32_t>(subgraph_inputs, 1),
-                       builder.CreateVector<int32_t>(subgraph_outputs, 1),
-                       builder.CreateVector(&op, 1), /*name=*/0);
+    operators.emplace_back(CreateOperator(
+        builder, /*opcode_index=*/0,
+        builder.CreateVector<int32_t>(op_inputs.data(), op_inputs.size()),
+        builder.CreateVector<int32_t>(op_outputs.data(), op_outputs.size()),
+        BuiltinOptions_Conv2DOptions, conv2d_options.Union()));
+
+    const std::array<int32_t, 1> subgraph_inputs{
+        {static_cast<int>(tensors.size()) - 4}};
+    const std::array<int32_t, 1> subgraph_outputs{
+        {static_cast<int>(tensors.size()) - 1}};
+    flatbuffers::Offset<SubGraph> subgraph = CreateSubGraph(
+        builder, builder.CreateVector(tensors.data(), tensors.size()),
+        builder.CreateVector<int32_t>(subgraph_inputs.data(),
+                                      subgraph_inputs.size()),
+        builder.CreateVector<int32_t>(subgraph_outputs.data(),
+                                      subgraph_outputs.size()),
+        builder.CreateVector(operators.data(), operators.size()));
 
     flatbuffers::Offset<flatbuffers::String> description =
         builder.CreateString("Conv2D model");
 
     flatbuffers::Offset<Model> model_buffer = CreateModel(
-        builder, TFLITE_SCHEMA_VERSION, builder.CreateVector(&operator_code, 1),
+        builder, TFLITE_SCHEMA_VERSION,
+        builder.CreateVector(operator_codes.data(), operator_codes.size()),
        builder.CreateVector(&subgraph, 1), description,
-        builder.CreateVector(buffers, 3));
+        builder.CreateVector(buffers.data(), buffers.size()));
 
     builder.Finish(model_buffer);
 
@@ -313,6 +389,7 @@ class Conv2DTester {
   int32_t stride_width_ = 1;
   int32_t dilation_height_ = 1;
   int32_t dilation_width_ = 1;
+  bool fp16_weights_ = false;
   bool same_padding_ = true;
 };
 
@@ -506,5 +583,35 @@ TEST(Conv2D, DilationWithValidPadding) {
       .Test(xnnpack_delegate.get());
 }
 
+TEST(Conv2D, FP16Weights) {
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto input_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
+  auto kernel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(3, 5), std::ref(rng));
+  auto stride_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
+  auto channel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(1, 16), std::ref(rng));
+
+  Conv2DTester()
+      .InputHeight(input_rng())
+      .InputWidth(input_rng())
+      .InputChannels(channel_rng())
+      .OutputChannels(channel_rng())
+      .KernelHeight(kernel_rng())
+      .KernelWidth(kernel_rng())
+      .StrideHeight(stride_rng())
+      .StrideWidth(stride_rng())
+      .SamePadding(true)
+      .FP16Weights()
+      .Test(xnnpack_delegate.get());
+}
+
 }  // namespace xnnpack
 }  // namespace tflite
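The Conv2D variant follows the same pattern with two dequantized weight
tensors; a sketch of the emitted graph when FP16Weights() is set, inferred
from CreateTfLiteModel() above:

    // tensor 0: FLOAT16 filter, static (buffer 1)
    // tensor 1: FLOAT16 bias, static (buffer 2)
    // tensor 2: FLOAT32 input             (subgraph input)
    // tensor 3: FLOAT32 filter            (DEQUANTIZE output)
    // tensor 4: FLOAT32 bias              (DEQUANTIZE output)
    // tensor 5: FLOAT32 output            (subgraph output)
    //
    // operator 0: DEQUANTIZE(t0) -> t3
    // operator 1: DEQUANTIZE(t1) -> t4
    // operator 2: CONV_2D(t2, t3, t4) -> t5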
tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_test.cc
@@ -371,6 +371,37 @@ TEST(DepthwiseConv2D, DepthMultiplier) {
       .Test(xnnpack_delegate.get());
 }
 
+TEST(DepthwiseConv2D, FP16Weights) {
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto batch_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
+  auto input_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
+  auto kernel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(3, 5), std::ref(rng));
+  auto stride_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
+  auto channel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(3, 32), std::ref(rng));
+
+  DepthwiseConv2DTester()
+      .BatchSize(batch_rng())
+      .InputHeight(input_rng())
+      .InputWidth(input_rng())
+      .InputChannels(channel_rng())
+      .KernelHeight(kernel_rng())
+      .KernelWidth(kernel_rng())
+      .StrideHeight(stride_rng())
+      .StrideWidth(stride_rng())
+      .FP16Weights()
+      .Test(xnnpack_delegate.get());
+}
+
 TEST(DepthwiseConv2D, ReluActivation) {
   std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
      xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include <vector>
 
 #include <gtest/gtest.h>
+#include <fp16.h>
 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
 #include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/kernels/register.h"
@@ -107,56 +108,110 @@ void DepthwiseConv2DTester::Test(TfLiteDelegate* delegate) const {
 }
 
 std::vector<char> DepthwiseConv2DTester::CreateTfLiteModel() const {
-  flatbuffers::FlatBufferBuilder builder;
-  flatbuffers::Offset<OperatorCode> operator_code =
-      CreateOperatorCode(builder, BuiltinOperator_DEPTHWISE_CONV_2D);
-
-  flatbuffers::Offset<DepthwiseConv2DOptions> depthwise_conv2d_options =
-      CreateDepthwiseConv2DOptions(
-          builder, Padding(), StrideWidth(), StrideHeight(), DepthMultiplier(),
-          Activation(), DilationWidth(), DilationHeight());
-
-  std::vector<float> filter_data(KernelHeight() * KernelWidth() *
-                                 OutputChannels());
-  std::vector<float> bias_data(OutputChannels());
-
   std::random_device random_device;
   auto rng = std::mt19937(random_device());
   auto range_rng = std::bind(
       std::uniform_real_distribution<float>(-25.0f, 25.0f), std::ref(rng));
-  for (int32_t ic = 0; ic < InputChannels(); ic++) {
-    // Use the same range of all-positive or all-negative values to generate
-    // all pixels within the same batch index & channel, but different ranges
-    // for different channels or batches. This ensures that no catastrophic
-    // cancellation occur, but test covers both positive and negative inputs.
-    const float range = range_rng();
-    auto value_rng =
-        std::bind(std::uniform_real_distribution<float>(std::min(range, 0.0f),
-                                                        std::max(range, 0.0f)),
-                  std::ref(rng));
-    for (int32_t m = 0; m < DepthMultiplier(); m++) {
-      const int32_t oc = ic * DepthMultiplier() + m;
-      bias_data[oc] = value_rng();
-      for (int32_t y = 0; y < KernelHeight(); y++) {
-        for (int32_t x = 0; x < KernelWidth(); x++) {
-          const int32_t index = (y * KernelWidth() + x) * OutputChannels() + oc;
-          filter_data[index] = value_rng();
+
+  flatbuffers::FlatBufferBuilder builder;
+  std::vector<flatbuffers::Offset<OperatorCode>> operator_codes{
+      {CreateOperatorCode(builder, BuiltinOperator_DEPTHWISE_CONV_2D)}};
+  std::vector<flatbuffers::Offset<tflite::Operator>> operators;
+  std::vector<flatbuffers::Offset<tflite::Buffer>> buffers{
+      {CreateBuffer(builder, builder.CreateVector({}))}};
+
+  if (FP16Weights()) {
+    operator_codes.emplace_back(
+        CreateOperatorCode(builder, BuiltinOperator_DEQUANTIZE));
+
+    std::vector<uint16_t> filter_data(KernelHeight() * KernelWidth() *
+                                      OutputChannels());
+    std::vector<uint16_t> bias_data(OutputChannels());
+    for (int32_t ic = 0; ic < InputChannels(); ic++) {
+      // Use the same range of all-positive or all-negative values to generate
+      // all pixels within the same batch index & channel, but different ranges
+      // for different channels or batches. This ensures that no catastrophic
+      // cancellation occur, but test covers both positive and negative inputs.
+      const float range = range_rng();
+      auto value_rng =
+          std::bind(fp16_ieee_from_fp32_value,
+                    std::bind(std::uniform_real_distribution<float>(
+                                  std::min(range, 0.0f), std::max(range, 0.0f)),
+                              std::ref(rng)));
+      for (int32_t m = 0; m < DepthMultiplier(); m++) {
+        const int32_t oc = ic * DepthMultiplier() + m;
+        bias_data[oc] = value_rng();
+        for (int32_t y = 0; y < KernelHeight(); y++) {
+          for (int32_t x = 0; x < KernelWidth(); x++) {
+            const int32_t index =
+                (y * KernelWidth() + x) * OutputChannels() + oc;
+            filter_data[index] = value_rng();
+          }
         }
       }
     }
-  }
 
-  const std::array<flatbuffers::Offset<tflite::Buffer>, 3> buffers{{
-      CreateBuffer(builder, builder.CreateVector({})),
-      CreateBuffer(builder,
-                   builder.CreateVector(
-                       reinterpret_cast<const uint8_t*>(filter_data.data()),
-                       sizeof(float) * filter_data.size())),
-      CreateBuffer(builder,
-                   builder.CreateVector(
-                       reinterpret_cast<const uint8_t*>(bias_data.data()),
-                       sizeof(float) * bias_data.size())),
-  }};
+    buffers.emplace_back(CreateBuffer(
+        builder, builder.CreateVector(
+                     reinterpret_cast<const uint8_t*>(filter_data.data()),
+                     sizeof(uint16_t) * filter_data.size())));
+    buffers.emplace_back(CreateBuffer(
+        builder,
+        builder.CreateVector(reinterpret_cast<const uint8_t*>(bias_data.data()),
+                             sizeof(uint16_t) * bias_data.size())));
+
+    const std::array<int32_t, 1> dequantize_filter_inputs{{0}};
+    const std::array<int32_t, 1> dequantize_filter_outputs{{3}};
+    operators.emplace_back(CreateOperator(
+        builder, /*opcode_index=*/1,
+        builder.CreateVector<int32_t>(dequantize_filter_inputs.data(),
+                                      dequantize_filter_inputs.size()),
+        builder.CreateVector<int32_t>(dequantize_filter_outputs.data(),
+                                      dequantize_filter_outputs.size())));
+    const std::array<int32_t, 1> dequantize_bias_inputs{{1}};
+    const std::array<int32_t, 1> dequantize_bias_outputs{{4}};
+    operators.emplace_back(CreateOperator(
+        builder, /*opcode_index=*/1,
+        builder.CreateVector<int32_t>(dequantize_bias_inputs.data(),
+                                      dequantize_bias_inputs.size()),
+        builder.CreateVector<int32_t>(dequantize_bias_outputs.data(),
+                                      dequantize_bias_outputs.size())));
+  } else {
+    std::vector<float> filter_data(KernelHeight() * KernelWidth() *
+                                   OutputChannels());
+    std::vector<float> bias_data(OutputChannels());
+    for (int32_t ic = 0; ic < InputChannels(); ic++) {
+      // Use the same range of all-positive or all-negative values to generate
+      // all pixels within the same batch index & channel, but different ranges
+      // for different channels or batches. This ensures that no catastrophic
+      // cancellation occur, but test covers both positive and negative inputs.
+      const float range = range_rng();
+      auto value_rng =
+          std::bind(std::uniform_real_distribution<float>(
+                        std::min(range, 0.0f), std::max(range, 0.0f)),
+                    std::ref(rng));
+      for (int32_t m = 0; m < DepthMultiplier(); m++) {
+        const int32_t oc = ic * DepthMultiplier() + m;
+        bias_data[oc] = value_rng();
+        for (int32_t y = 0; y < KernelHeight(); y++) {
+          for (int32_t x = 0; x < KernelWidth(); x++) {
+            const int32_t index =
+                (y * KernelWidth() + x) * OutputChannels() + oc;
+            filter_data[index] = value_rng();
+          }
        }
      }
    }
+
+    buffers.emplace_back(CreateBuffer(
+        builder, builder.CreateVector(
+                     reinterpret_cast<const uint8_t*>(filter_data.data()),
+                     sizeof(float) * filter_data.size())));
+    buffers.emplace_back(CreateBuffer(
+        builder,
+        builder.CreateVector(reinterpret_cast<const uint8_t*>(bias_data.data()),
+                             sizeof(float) * bias_data.size())));
+  }
 
   const std::array<int32_t, 4> input_shape{
       {BatchSize(), InputHeight(), InputWidth(), InputChannels()}};
@@ -166,49 +221,69 @@ std::vector<char> DepthwiseConv2DTester::CreateTfLiteModel() const {
       {1, KernelHeight(), KernelWidth(), OutputChannels()}};
   const std::array<int32_t, 1> bias_shape{{OutputChannels()}};
 
-  const std::array<flatbuffers::Offset<tflite::Tensor>, 4> tensors{{
-      CreateTensor(
-          builder,
-          builder.CreateVector<int32_t>(input_shape.data(), input_shape.size()),
-          TensorType_FLOAT32),
-      CreateTensor(builder,
-                   builder.CreateVector<int32_t>(filter_shape.data(),
-                                                 filter_shape.size()),
-                   TensorType_FLOAT32, /*buffer=*/1),
-      CreateTensor(
-          builder,
-          builder.CreateVector<int32_t>(bias_shape.data(), bias_shape.size()),
-          TensorType_FLOAT32, /*buffer=*/2),
-      CreateTensor(builder,
-                   builder.CreateVector<int32_t>(output_shape.data(),
-                                                 output_shape.size()),
-                   TensorType_FLOAT32),
-  }};
+  std::vector<flatbuffers::Offset<tflite::Tensor>> tensors;
+  if (FP16Weights()) {
+    tensors.emplace_back(CreateTensor(
+        builder,
+        builder.CreateVector<int32_t>(filter_shape.data(), filter_shape.size()),
+        TensorType_FLOAT16, /*buffer=*/1));
+    tensors.emplace_back(CreateTensor(
+        builder,
+        builder.CreateVector<int32_t>(bias_shape.data(), bias_shape.size()),
+        TensorType_FLOAT16, /*buffer=*/2));
+  }
+  tensors.emplace_back(CreateTensor(
+      builder,
+      builder.CreateVector<int32_t>(input_shape.data(), input_shape.size()),
+      TensorType_FLOAT32));
+  tensors.emplace_back(CreateTensor(
+      builder,
+      builder.CreateVector<int32_t>(filter_shape.data(), filter_shape.size()),
+      TensorType_FLOAT32, /*buffer=*/FP16Weights() ? 0 : 1));
+  tensors.emplace_back(CreateTensor(
+      builder,
+      builder.CreateVector<int32_t>(bias_shape.data(), bias_shape.size()),
+      TensorType_FLOAT32, /*buffer=*/FP16Weights() ? 0 : 2));
+  tensors.emplace_back(CreateTensor(
+      builder,
+      builder.CreateVector<int32_t>(output_shape.data(), output_shape.size()),
+      TensorType_FLOAT32));
 
-  const std::array<int32_t, 3> op_inputs{{0, 1, 2}};
-  const std::array<int32_t, 1> op_outputs{{3}};
+  const std::array<int32_t, 3> op_inputs{
+      {static_cast<int>(tensors.size()) - 4,
+       static_cast<int>(tensors.size()) - 3,
+       static_cast<int>(tensors.size()) - 2}};
+  const std::array<int32_t, 1> op_outputs{
+      {static_cast<int>(tensors.size()) - 1}};
 
-  flatbuffers::Offset<tflite::Operator> op = CreateOperator(
+  flatbuffers::Offset<DepthwiseConv2DOptions> depthwise_conv2d_options =
+      CreateDepthwiseConv2DOptions(
+          builder, Padding(), StrideWidth(), StrideHeight(), DepthMultiplier(),
+          Activation(), DilationWidth(), DilationHeight());
+  operators.emplace_back(CreateOperator(
       builder, /*opcode_index=*/0,
      builder.CreateVector<int32_t>(op_inputs.data(), op_inputs.size()),
       builder.CreateVector<int32_t>(op_outputs.data(), op_outputs.size()),
-      BuiltinOptions_DepthwiseConv2DOptions, depthwise_conv2d_options.Union());
+      BuiltinOptions_DepthwiseConv2DOptions, depthwise_conv2d_options.Union()));
 
-  const std::array<int32_t, 1> subgraph_inputs{{0}};
-  const std::array<int32_t, 1> subgraph_outputs{{3}};
+  const std::array<int32_t, 1> subgraph_inputs{
+      {static_cast<int>(tensors.size()) - 4}};
+  const std::array<int32_t, 1> subgraph_outputs{
+      {static_cast<int>(tensors.size()) - 1}};
   flatbuffers::Offset<SubGraph> subgraph = CreateSubGraph(
      builder, builder.CreateVector(tensors.data(), tensors.size()),
       builder.CreateVector<int32_t>(subgraph_inputs.data(),
                                     subgraph_inputs.size()),
       builder.CreateVector<int32_t>(subgraph_outputs.data(),
                                     subgraph_outputs.size()),
-      builder.CreateVector(&op, 1));
+      builder.CreateVector(operators.data(), operators.size()));
 
   flatbuffers::Offset<flatbuffers::String> description =
       builder.CreateString("DepthwiseConv2D model");
 
   flatbuffers::Offset<Model> model_buffer = CreateModel(
-      builder, TFLITE_SCHEMA_VERSION, builder.CreateVector(&operator_code, 1),
+      builder, TFLITE_SCHEMA_VERSION,
+      builder.CreateVector(operator_codes.data(), operator_codes.size()),
       builder.CreateVector(&subgraph, 1), description,
       builder.CreateVector(buffers.data(), buffers.size()));
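The comment repeated in the loops above explains why each channel draws its
weights from an all-positive or all-negative range: summing mixed-sign terms
of similar magnitude cancels the leading bits and leaves mostly rounding
error, which would make the FP32-vs-delegate output comparison flaky. A
standalone illustration of the effect (assumed values, not from the tests):

    #include <cstdio>

    int main() {
      const float a = 1.0e8f;  // large-magnitude term
      const float b = 1.0f;    // small term of the opposite sign in the sum
      // 1.0e8f + 1.0f rounds back to 1.0e8f in FP32 (24-bit significand),
      // so the subtraction prints 0 instead of 1: the small contribution
      // was cancelled away entirely.
      std::printf("%g\n", (a + b) - a);
      return 0;
    }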
tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.h
@@ -152,6 +152,13 @@ class DepthwiseConv2DTester {
     return (KernelWidth() - 1) * DilationWidth() + 1;
   }
 
+  inline DepthwiseConv2DTester& FP16Weights() {
+    fp16_weights_ = true;
+    return *this;
+  }
+
+  inline bool FP16Weights() const { return fp16_weights_; }
+
   inline DepthwiseConv2DTester& SamePadding() {
     padding_ = ::tflite::Padding_SAME;
     return *this;
@@ -209,6 +216,7 @@ class DepthwiseConv2DTester {
   int32_t stride_width_ = 1;
   int32_t dilation_height_ = 1;
   int32_t dilation_width_ = 1;
+  bool fp16_weights_ = false;
   ::tflite::Padding padding_ = ::tflite::Padding_VALID;
   ::tflite::ActivationFunctionType activation_ =
       ::tflite::ActivationFunctionType_NONE;
tensorflow/lite/delegates/xnnpack/fully_connected_test.cc
@@ -228,6 +228,29 @@ TEST(FullyConnected, 4DKeepDims) {
       .Test(xnnpack_delegate.get());
 }
 
+TEST(FullyConnected, FP16Weights) {
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto batch_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
+  auto channels_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 9), std::ref(rng));
+  const auto batch = batch_rng();
+  const auto input_channels = channels_rng();
+  const auto output_channels = channels_rng();
+
+  FullyConnectedTester()
+      .InputShape({batch, input_channels})
+      .InputChannels(input_channels)
+      .OutputChannels(output_channels)
+      .FP16Weights()
+      .Test(xnnpack_delegate.get());
+}
+
 TEST(FullyConnected, ReluActivation) {
   std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
       xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
tensorflow/lite/delegates/xnnpack/fully_connected_tester.cc
@@ -23,6 +23,7 @@ limitations under the License.
 #include <vector>
 
 #include <gtest/gtest.h>
+#include <fp16.h>
 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
 #include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/kernels/register.h"
@@ -109,98 +110,165 @@ void FullyConnectedTester::Test(TfLiteDelegate* delegate) const {
 std::vector<char> FullyConnectedTester::CreateTfLiteModel() const {
   std::random_device random_device;
   auto rng = std::mt19937(random_device());
 
   auto range_rng = std::bind(
       std::uniform_real_distribution<float>(-25.0f, 25.0f), std::ref(rng));
 
   flatbuffers::FlatBufferBuilder builder;
-  flatbuffers::Offset<OperatorCode> operator_code =
-      CreateOperatorCode(builder, BuiltinOperator_FULLY_CONNECTED);
+  std::vector<flatbuffers::Offset<OperatorCode>> operator_codes{
+      {CreateOperatorCode(builder, BuiltinOperator_FULLY_CONNECTED)}};
+  std::vector<flatbuffers::Offset<Operator>> operators;
+  std::vector<flatbuffers::Offset<Buffer>> buffers{
+      {CreateBuffer(builder, builder.CreateVector({}))}};
 
-  std::vector<float> filter_data(InputChannels() * OutputChannels());
-  std::vector<float> bias_data(OutputChannels());
+  if (FP16Weights()) {
+    operator_codes.emplace_back(
+        CreateOperatorCode(builder, BuiltinOperator_DEQUANTIZE));
 
-  for (int32_t oc = 0; oc < OutputChannels(); oc++) {
-    // Use the same range of all-positive or all-negative values to generate
-    // all filter & bias weights within the same channel, but different ranges
-    // for different output channels. This ensures that no catastrophic
-    // cancellation occur, but test covers both positive and negative inputs.
-    const float range = range_rng();
-    auto value_rng =
-        std::bind(std::uniform_real_distribution<float>(std::min(range, 0.0f),
-                                                        std::max(range, 0.0f)),
-                  std::ref(rng));
+    std::vector<uint16_t> filter_data(InputChannels() * OutputChannels());
+    std::vector<uint16_t> bias_data(OutputChannels());
 
-    bias_data[oc] = value_rng();
-    for (int32_t ic = 0; ic < InputChannels(); ic++) {
-      filter_data[oc * InputChannels() + ic] = value_rng();
+    for (int32_t oc = 0; oc < OutputChannels(); oc++) {
+      // Use the same range of all-positive or all-negative values to generate
+      // all filter & bias weights within the same channel, but different ranges
+      // for different output channels. This ensures that no catastrophic
+      // cancellation occur, but test covers both positive and negative inputs.
+      const float range = range_rng();
+      auto value_rng =
+          std::bind(fp16_ieee_from_fp32_value,
+                    std::bind(std::uniform_real_distribution<float>(
+                                  std::min(range, 0.0f), std::max(range, 0.0f)),
+                              std::ref(rng)));
+
+      bias_data[oc] = value_rng();
+      for (int32_t ic = 0; ic < InputChannels(); ic++) {
+        filter_data[oc * InputChannels() + ic] = value_rng();
+      }
     }
-  }
 
-  std::array<flatbuffers::Offset<Buffer>, 3> buffers{{
-      CreateBuffer(builder, builder.CreateVector({})),
-      CreateBuffer(builder,
-                   builder.CreateVector(
-                       reinterpret_cast<const uint8_t*>(filter_data.data()),
-                       sizeof(float) * filter_data.size())),
-      CreateBuffer(builder,
-                   builder.CreateVector(
-                       reinterpret_cast<const uint8_t*>(bias_data.data()),
-                       sizeof(float) * bias_data.size())),
-  }};
+    buffers.emplace_back(CreateBuffer(
+        builder, builder.CreateVector(
+                     reinterpret_cast<const uint8_t*>(filter_data.data()),
+                     sizeof(uint16_t) * filter_data.size())));
+    buffers.emplace_back(CreateBuffer(
+        builder,
+        builder.CreateVector(reinterpret_cast<const uint8_t*>(bias_data.data()),
+                             sizeof(uint16_t) * bias_data.size())));
+
+    const std::array<int32_t, 1> dequantize_filter_inputs{{0}};
+    const std::array<int32_t, 1> dequantize_filter_outputs{{3}};
+    operators.emplace_back(CreateOperator(
+        builder, /*opcode_index=*/1,
+        builder.CreateVector<int32_t>(dequantize_filter_inputs.data(),
+                                      dequantize_filter_inputs.size()),
+        builder.CreateVector<int32_t>(dequantize_filter_outputs.data(),
+                                      dequantize_filter_outputs.size())));
+    const std::array<int32_t, 1> dequantize_bias_inputs{{1}};
+    const std::array<int32_t, 1> dequantize_bias_outputs{{4}};
+    operators.emplace_back(CreateOperator(
+        builder, /*opcode_index=*/1,
+        builder.CreateVector<int32_t>(dequantize_bias_inputs.data(),
+                                      dequantize_bias_inputs.size()),
+        builder.CreateVector<int32_t>(dequantize_bias_outputs.data(),
+                                      dequantize_bias_outputs.size())));
+  } else {
+    std::vector<float> filter_data(InputChannels() * OutputChannels());
+    std::vector<float> bias_data(OutputChannels());
+
+    for (int32_t oc = 0; oc < OutputChannels(); oc++) {
+      // Use the same range of all-positive or all-negative values to generate
+      // all filter & bias weights within the same channel, but different ranges
+      // for different output channels. This ensures that no catastrophic
+      // cancellation occur, but test covers both positive and negative inputs.
+      const float range = range_rng();
+      auto value_rng =
+          std::bind(std::uniform_real_distribution<float>(
+                        std::min(range, 0.0f), std::max(range, 0.0f)),
+                    std::ref(rng));
+
+      bias_data[oc] = value_rng();
+      for (int32_t ic = 0; ic < InputChannels(); ic++) {
+        filter_data[oc * InputChannels() + ic] = value_rng();
+      }
+    }
+
+    buffers.emplace_back(CreateBuffer(
+        builder, builder.CreateVector(
+                     reinterpret_cast<const uint8_t*>(filter_data.data()),
+                     sizeof(float) * filter_data.size())));
+    buffers.emplace_back(CreateBuffer(
+        builder,
+        builder.CreateVector(reinterpret_cast<const uint8_t*>(bias_data.data()),
+                             sizeof(float) * bias_data.size())));
+  }
 
   const std::array<int32_t, 2> filter_shape(
       {OutputChannels(), InputChannels()});
   const std::array<int32_t, 1> bias_shape({OutputChannels()});
 
   const std::vector<int32_t> output_shape = OutputShape();
-  const std::array<flatbuffers::Offset<Tensor>, 4> tensors{{
-      CreateTensor(builder,
-                   builder.CreateVector<int32_t>(InputShape().data(),
-                                                 InputShape().size()),
-                   TensorType_FLOAT32),
-      CreateTensor(builder,
-                   builder.CreateVector<int32_t>(filter_shape.data(),
-                                                 filter_shape.size()),
-                   TensorType_FLOAT32, /*buffer=*/1),
-      CreateTensor(
-          builder,
-          builder.CreateVector<int32_t>(bias_shape.data(), bias_shape.size()),
-          TensorType_FLOAT32, /*buffer=*/2),
-      CreateTensor(builder,
-                   builder.CreateVector<int32_t>(output_shape.data(),
-                                                 output_shape.size()),
-                   TensorType_FLOAT32),
-  }};
+  std::vector<flatbuffers::Offset<Tensor>> tensors;
+  if (FP16Weights()) {
+    tensors.emplace_back(CreateTensor(
+        builder,
+        builder.CreateVector<int32_t>(filter_shape.data(), filter_shape.size()),
+        TensorType_FLOAT16, /*buffer=*/1));
+    tensors.emplace_back(CreateTensor(
+        builder,
+        builder.CreateVector<int32_t>(bias_shape.data(), bias_shape.size()),
+        TensorType_FLOAT16, /*buffer=*/2));
+  }
+  tensors.emplace_back(CreateTensor(
+      builder,
+      builder.CreateVector<int32_t>(InputShape().data(), InputShape().size()),
+      TensorType_FLOAT32));
+  tensors.emplace_back(CreateTensor(
+      builder,
+      builder.CreateVector<int32_t>(filter_shape.data(), filter_shape.size()),
+      TensorType_FLOAT32, /*buffer=*/FP16Weights() ? 0 : 1));
+  tensors.emplace_back(CreateTensor(
+      builder,
+      builder.CreateVector<int32_t>(bias_shape.data(), bias_shape.size()),
+      TensorType_FLOAT32, /*buffer=*/FP16Weights() ? 0 : 2));
+  tensors.emplace_back(CreateTensor(
+      builder,
+      builder.CreateVector<int32_t>(output_shape.data(), output_shape.size()),
+      TensorType_FLOAT32));
 
   flatbuffers::Offset<FullyConnectedOptions> fully_connected_options =
       CreateFullyConnectedOptions(builder, Activation(),
                                   FullyConnectedOptionsWeightsFormat_DEFAULT,
                                   KeepDims());
 
-  const std::array<int32_t, 3> op_inputs{{0, 1, 2}};
-  const std::array<int32_t, 1> op_outputs{{3}};
-  flatbuffers::Offset<Operator> op = CreateOperator(
+  const std::array<int32_t, 3> op_inputs{
+      {static_cast<int>(tensors.size()) - 4,
+       static_cast<int>(tensors.size()) - 3,
+       static_cast<int>(tensors.size()) - 2}};
+  const std::array<int32_t, 1> op_outputs{
+      {static_cast<int>(tensors.size()) - 1}};
+  operators.emplace_back(CreateOperator(
      builder, /*opcode_index=*/0,
       builder.CreateVector<int32_t>(op_inputs.data(), op_inputs.size()),
       builder.CreateVector<int32_t>(op_outputs.data(), op_outputs.size()),
-      BuiltinOptions_FullyConnectedOptions, fully_connected_options.Union());
+      BuiltinOptions_FullyConnectedOptions, fully_connected_options.Union()));
 
-  const std::array<int32_t, 1> subgraph_inputs{{0}};
-  const std::array<int32_t, 1> subgraph_outputs{{3}};
+  const std::array<int32_t, 1> subgraph_inputs{
+      {static_cast<int>(tensors.size()) - 4}};
+  const std::array<int32_t, 1> subgraph_outputs{
+      {static_cast<int>(tensors.size()) - 1}};
   flatbuffers::Offset<SubGraph> subgraph = CreateSubGraph(
       builder, builder.CreateVector(tensors.data(), tensors.size()),
       builder.CreateVector<int32_t>(subgraph_inputs.data(),
                                     subgraph_inputs.size()),
       builder.CreateVector<int32_t>(subgraph_outputs.data(),
                                     subgraph_outputs.size()),
-      builder.CreateVector(&op, 1));
+      builder.CreateVector(operators.data(), operators.size()));
 
   flatbuffers::Offset<flatbuffers::String> description =
       builder.CreateString("Fully Connected model");
 
   flatbuffers::Offset<Model> model_buffer = CreateModel(
-      builder, TFLITE_SCHEMA_VERSION, builder.CreateVector(&operator_code, 1),
+      builder, TFLITE_SCHEMA_VERSION,
+      builder.CreateVector(operator_codes.data(), operator_codes.size()),
       builder.CreateVector(&subgraph, 1), description,
       builder.CreateVector(buffers.data(), buffers.size()));
tensorflow/lite/delegates/xnnpack/fully_connected_tester.h
@@ -71,6 +71,13 @@ class FullyConnectedTester {
 
   inline bool KeepDims() const { return keep_dims_; }
 
+  inline FullyConnectedTester& FP16Weights() {
+    fp16_weights_ = true;
+    return *this;
+  }
+
+  inline bool FP16Weights() const { return fp16_weights_; }
+
   inline FullyConnectedTester& ReluActivation() {
     activation_ = ::tflite::ActivationFunctionType_RELU;
     return *this;
@@ -102,6 +109,7 @@ class FullyConnectedTester {
   int32_t input_channels_ = 1;
   int32_t output_channels_ = 1;
   bool keep_dims_ = false;
+  bool fp16_weights_ = false;
   ::tflite::ActivationFunctionType activation_ =
       ::tflite::ActivationFunctionType_NONE;
 };
tensorflow/lite/delegates/xnnpack/mul_test.cc
@@ -679,6 +679,35 @@ TEST(Mul, 2DByStatic0D) {
       .Test(BuiltinOperator_MUL, xnnpack_delegate.get());
 }
 
+TEST(Mul, FP16Weights) {
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto shape_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
+  const auto batch = shape_rng();
+  const auto height = shape_rng();
+  const auto width = shape_rng();
+  const auto channels = shape_rng();
+
+  BinaryElementwiseTester()
+      .Input1Shape({batch, height, width, channels})
+      .Input2Shape({batch, height, width, channels})
+      .Input1Static(true)
+      .FP16Weights()
+      .Test(BuiltinOperator_MUL, xnnpack_delegate.get());
+
+  BinaryElementwiseTester()
+      .Input1Shape({batch, height, width, channels})
+      .Input2Shape({batch, height, width, channels})
+      .Input2Static(true)
+      .FP16Weights()
+      .Test(BuiltinOperator_MUL, xnnpack_delegate.get());
+}
+
 TEST(Mul, ReluActivation) {
   std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
       xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
@ -22,10 +22,12 @@ limitations under the License.
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include <fp16.h>
|
||||
#include <xnnpack.h>
|
||||
#include "tensorflow/lite/builtin_ops.h"
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
@ -39,6 +41,8 @@ namespace {
|
||||
TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate);
|
||||
|
||||
class Delegate {
|
||||
friend class Subgraph;
|
||||
|
||||
public:
|
||||
explicit Delegate(const TfLiteXNNPackDelegateOptions* options) {
|
||||
#if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__)
|
||||
@ -49,9 +53,10 @@ class Delegate {
|
||||
#endif
|
||||
}
|
||||
|
||||
TfLiteIntArray* PrepareOpsToDelegate(TfLiteContext* context);
|
||||
TfLiteDelegate* tflite_delegate() { return &delegate_; }
|
||||
|
||||
pthreadpool_t threadpool() {
|
||||
pthreadpool_t threadpool() const {
|
||||
#if defined(__EMSCRIPTEN__) && !defined(__EMSCRIPTEN_PTHREADS__)
|
||||
return nullptr;
|
||||
#else
|
||||
@ -69,6 +74,17 @@ class Delegate {
|
||||
kTfLiteDelegateFlagsNone, // .flags
|
||||
};
|
||||
|
||||
// Unpacked data for quasi-static tensors, i.e. tensors produced by
|
||||
// dequantizing or unpacking static buffers.
|
||||
std::vector<char> static_unpacked_data_;
|
||||
// Mapping from a tensor index for a quasi-static tensor to the offset to
|
||||
// its unpacked data within static_unpacked_data_.
|
||||
std::unordered_map<int, size_t> static_unpacked_data_map_;
|
||||
// Set of indices of nodes which unpack static data, e.g. Dequantize
|
||||
// operators which convert FP16 static weights to FP32. These nodes are simply
|
||||
// ignored in the delegate implementation, because their outputs are
|
||||
// pre-unpacked in DelegatePrepare.
|
||||
std::unordered_set<int> static_unpack_nodes_;
|
||||
#if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__)
|
||||
// Thread pool with smart-pointer for lifetime management.
|
||||
std::unique_ptr<pthreadpool, decltype(&pthreadpool_destroy)> threadpool_{
|
||||
@ -80,7 +96,7 @@ class Subgraph {
|
||||
public:
|
||||
static Subgraph* Create(TfLiteContext* context,
|
||||
const TfLiteDelegateParams* params,
|
||||
pthreadpool_t threadpool) {
|
||||
const Delegate* delegate) {
|
||||
// Convert subgraph inputs and outputs to hash sets for faster lookup.
|
||||
const std::unordered_set<int> inputs(
|
||||
¶ms->input_tensors->data[0],
|
||||
@ -113,11 +129,17 @@ class Subgraph {
    // filtered out and removed later.
    std::vector<int> tensors(context->tensors_size, -1);
    for (int i = 0; i < params->nodes_to_replace->size; i++) {
      const int node_index = params->nodes_to_replace->data[i];
      if (delegate->static_unpack_nodes_.count(node_index)) {
        // The node unpacks static input and can be skipped because its input
        // was pre-unpacked in DelegatePrepare.
        continue;
      }

      TfLiteNode* node = nullptr;
      TfLiteRegistration* registration = nullptr;
      if (context->GetNodeAndRegistration(context,
                                          params->nodes_to_replace->data[i],
                                          &node, &registration) != kTfLiteOk) {
      if (context->GetNodeAndRegistration(context, node_index, &node,
                                          &registration) != kTfLiteOk) {
        return nullptr;
      }

@ -164,6 +186,12 @@ class Subgraph {
      const void* data = nullptr;
      if (context->tensors[t].allocation_type == kTfLiteMmapRo) {
        data = context->tensors[t].data.raw_const;
      } else {
        // Check for quasi-static data.
        const auto it = delegate->static_unpacked_data_map_.find(t);
        if (it != delegate->static_unpacked_data_map_.end()) {
          data = delegate->static_unpacked_data_.data() + it->second;
        }
      }
      if (inputs.count(t) != 0) {
        flags |= XNN_VALUE_FLAG_EXTERNAL_INPUT;
@ -189,25 +217,38 @@ class Subgraph {
      }
    }

    // Create a set of quasi-static tensors for the VisitNode function.
    std::unordered_set<int> quasi_static_tensors;
    for (const std::pair<const int, size_t>& entry :
         delegate->static_unpacked_data_map_) {
      quasi_static_tensors.insert(entry.first);
    }

    // Create XNNPACK nodes for the TFLite delegate nodes.
    for (int i = 0; i < params->nodes_to_replace->size; i++) {
      const int node_index = params->nodes_to_replace->data[i];
      if (delegate->static_unpack_nodes_.count(node_index)) {
        // The node unpacks static input and can be skipped because its input
        // was pre-unpacked in DelegatePrepare.
        continue;
      }

      TfLiteNode* node = nullptr;
      TfLiteRegistration* registration = nullptr;
      if (context->GetNodeAndRegistration(context,
                                          params->nodes_to_replace->data[i],
                                          &node, &registration) != kTfLiteOk) {
      if (context->GetNodeAndRegistration(context, node_index, &node,
                                          &registration) != kTfLiteOk) {
        return nullptr;
      }

      if (VisitNode(subgraph.get(), context, registration, node, i,
                    xnnpack_tensors) != kTfLiteOk) {
      if (VisitNode(subgraph.get(), context, registration, node, node_index,
                    quasi_static_tensors, xnnpack_tensors) != kTfLiteOk) {
        return nullptr;
      }
    }

    xnn_runtime_t runtime_ptr = nullptr;
    status = xnn_create_runtime_v2(subgraph.get(), threadpool, /*flags=*/0,
                                   &runtime_ptr);
    status = xnn_create_runtime_v2(subgraph.get(), delegate->threadpool(),
                                   /*flags=*/0, &runtime_ptr);
    if (status != xnn_status_success) {
      TF_LITE_KERNEL_LOG(context, "failed to create XNNPACK runtime");
      return nullptr;
@ -707,10 +748,11 @@ class Subgraph {
    return kTfLiteOk;
  }

  static TfLiteStatus VisitNode(xnn_subgraph_t subgraph, TfLiteContext* context,
                                TfLiteRegistration* registration,
                                TfLiteNode* node, int node_index,
                                const std::vector<uint32_t>& xnnpack_tensors) {
  static TfLiteStatus VisitNode(
      xnn_subgraph_t subgraph, TfLiteContext* context,
      TfLiteRegistration* registration, TfLiteNode* node, int node_index,
      const std::unordered_set<int>& quasi_static_tensors,
      const std::vector<uint32_t>& xnnpack_tensors) {
    // TFLite context used for logging purposes. When we create a new node
    // (subgraph is non-null), logging context is the same as context, and error
    // messages are passed to TFLite. When we detect supported operations
@ -738,7 +780,8 @@ class Subgraph {
            static_cast<const TfLiteConvParams*>(node->builtin_data);

        return VisitConv2DNode(subgraph, logging_context, node_index, node,
                               context->tensors, conv_params, xnnpack_tensors);
                               context->tensors, conv_params,
                               quasi_static_tensors, xnnpack_tensors);
      }
      case kTfLiteBuiltinDepthwiseConv2d: {
        const TfLiteDepthwiseConvParams* dwconv_params =
@ -746,7 +789,7 @@ class Subgraph {

        return VisitDepthwiseConv2DNode(subgraph, logging_context, node_index,
                                        node, context->tensors, dwconv_params,
                                        xnnpack_tensors);
                                        quasi_static_tensors, xnnpack_tensors);
      }
      case kTfLiteBuiltinFullyConnected: {
        const TfLiteFullyConnectedParams* fc_params =
@ -754,7 +797,7 @@ class Subgraph {

        return VisitFullyConnectedNode(subgraph, logging_context, node_index,
                                       node, context->tensors, fc_params,
                                       xnnpack_tensors);
                                       quasi_static_tensors, xnnpack_tensors);
      }
      case kTfLiteBuiltinHardSwish:
        return VisitHardSwishNode(subgraph, logging_context, node_index, node,
@ -782,7 +825,8 @@ class Subgraph {
                                  context->tensors, xnnpack_tensors);
      case kTfLiteBuiltinPrelu:
        return VisitPreluNode(subgraph, logging_context, node_index, node,
                              context->tensors, xnnpack_tensors);
                              context->tensors, quasi_static_tensors,
                              xnnpack_tensors);
      case kTfLiteBuiltinRelu:
        return VisitReluNode(
            subgraph, logging_context, node_index, node, context->tensors, 0.0f,
@ -810,7 +854,7 @@ class Subgraph {

          return VisitMediaPipeDeconvolutionNode(
              subgraph, context, node_index, node, context->tensors,
              &deconv_params, xnnpack_tensors);
              &deconv_params, quasi_static_tensors, xnnpack_tensors);
        } else if (strcmp(registration->custom_name,
                          "MaxPoolingWithArgmax2D") == 0) {
          TfLitePoolParams pool_params = {kTfLitePaddingUnknown};
@ -948,6 +992,7 @@ class Subgraph {
      xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
      TfLiteNode* node, const TfLiteTensor* tensors,
      const TfLiteConvParams* conv_params,
      const std::unordered_set<int>& quasi_static_tensors,
      const std::vector<uint32_t>& xnnpack_tensors) {
    TF_LITE_ENSURE_STATUS(
        CheckConvolutionParams(logging_context, conv_params, node_index));
@ -968,16 +1013,20 @@ class Subgraph {
        logging_context, filter_tensor, node->inputs->data[1], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, filter_tensor, 4,
                                           node->inputs->data[1]));
    TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
        logging_context, filter_tensor, node->inputs->data[1], node_index));
    if (quasi_static_tensors.count(node->inputs->data[1]) == 0) {
      TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
          logging_context, filter_tensor, node->inputs->data[1], node_index));
    }

    const TfLiteTensor& bias_tensor = tensors[node->inputs->data[2]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
        logging_context, bias_tensor, node->inputs->data[2], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, bias_tensor, 1,
                                           node->inputs->data[2]));
    TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
        logging_context, bias_tensor, node->inputs->data[2], node_index));
    if (quasi_static_tensors.count(node->inputs->data[2]) == 0) {
      TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
          logging_context, bias_tensor, node->inputs->data[2], node_index));
    }

    const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
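The filter and bias checks above repeat, with the same skip-if-quasi-static guard, for every op type below. A hypothetical helper expressing the shared rule (not part of this change; the name is invented):

// Enforces static (kTfLiteMmapRo) allocation only for tensors that were not
// pre-unpacked by the delegate; quasi-static tensors live in
// static_unpacked_data_ rather than the flatbuffer, so the plain
// static-allocation check would reject them.
static TfLiteStatus CheckStaticAllocationUnlessQuasiStatic(
    TfLiteContext* logging_context, const TfLiteTensor& tensor,
    int tensor_index, int node_index,
    const std::unordered_set<int>& quasi_static_tensors) {
  if (quasi_static_tensors.count(tensor_index) == 0) {
    TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
        logging_context, tensor, tensor_index, node_index));
  }
  return kTfLiteOk;
}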
@ -1034,6 +1083,7 @@ class Subgraph {
      xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
      TfLiteNode* node, const TfLiteTensor* tensors,
      const TfLiteDepthwiseConvParams* dwconv_params,
      const std::unordered_set<int>& quasi_static_tensors,
      const std::vector<uint32_t>& xnnpack_tensors) {
    TF_LITE_ENSURE_STATUS(
        CheckNumInputsAndOutputs(logging_context, node, 3, 1, node_index));
@ -1051,16 +1101,20 @@ class Subgraph {
        logging_context, filter_tensor, node->inputs->data[1], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, filter_tensor, 4,
                                           node->inputs->data[1]));
    TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
        logging_context, filter_tensor, node->inputs->data[1], node_index));
    if (quasi_static_tensors.count(node->inputs->data[1]) == 0) {
      TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
          logging_context, filter_tensor, node->inputs->data[1], node_index));
    }

    const TfLiteTensor& bias_tensor = tensors[node->inputs->data[2]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
        logging_context, bias_tensor, node->inputs->data[2], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, bias_tensor, 1,
                                           node->inputs->data[2]));
    TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
        logging_context, bias_tensor, node->inputs->data[2], node_index));
    if (quasi_static_tensors.count(node->inputs->data[2]) == 0) {
      TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
          logging_context, bias_tensor, node->inputs->data[2], node_index));
    }

    const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
@ -1123,6 +1177,7 @@ class Subgraph {
      xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
      TfLiteNode* node, const TfLiteTensor* tensors,
      const TfLiteFullyConnectedParams* fc_params,
      const std::unordered_set<int>& quasi_static_tensors,
      const std::vector<uint32_t>& xnnpack_tensors) {
    TF_LITE_ENSURE_STATUS(
        CheckFullyConnectedParams(logging_context, fc_params, node_index));
@ -1141,16 +1196,20 @@ class Subgraph {
        logging_context, filter_tensor, node->inputs->data[1], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, filter_tensor, 2,
                                           node->inputs->data[1]));
    TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
        logging_context, filter_tensor, node->inputs->data[1], node_index));
    if (quasi_static_tensors.count(node->inputs->data[1]) == 0) {
      TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
          logging_context, filter_tensor, node->inputs->data[1], node_index));
    }

    const TfLiteTensor& bias_tensor = tensors[node->inputs->data[2]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
        logging_context, bias_tensor, node->inputs->data[2], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, bias_tensor, 1,
                                           node->inputs->data[2]));
    TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
        logging_context, bias_tensor, node->inputs->data[2], node_index));
    if (quasi_static_tensors.count(node->inputs->data[2]) == 0) {
      TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
          logging_context, bias_tensor, node->inputs->data[2], node_index));
    }

    const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
@ -1387,6 +1446,7 @@ class Subgraph {
      xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
      TfLiteNode* node, const TfLiteTensor* tensors,
      const TfLiteTransposeConvParams* deconv_params,
      const std::unordered_set<int>& quasi_static_tensors,
      const std::vector<uint32_t>& xnnpack_tensors) {
    TF_LITE_ENSURE_STATUS(
        CheckNumInputsAndOutputs(logging_context, node, 3, 1, node_index));
@ -1404,16 +1464,20 @@ class Subgraph {
        logging_context, filter_tensor, node->inputs->data[1], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, filter_tensor, 4,
                                           node->inputs->data[1]));
    TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
        logging_context, filter_tensor, node->inputs->data[1], node_index));
    if (quasi_static_tensors.count(node->inputs->data[1]) == 0) {
      TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
          logging_context, filter_tensor, node->inputs->data[1], node_index));
    }

    const TfLiteTensor& bias_tensor = tensors[node->inputs->data[2]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
        logging_context, bias_tensor, node->inputs->data[2], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, bias_tensor, 1,
                                           node->inputs->data[2]));
    TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
        logging_context, bias_tensor, node->inputs->data[2], node_index));
    if (quasi_static_tensors.count(node->inputs->data[2]) == 0) {
      TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
          logging_context, bias_tensor, node->inputs->data[2], node_index));
    }

    const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
@ -1735,6 +1799,7 @@ class Subgraph {
  static TfLiteStatus VisitPreluNode(
      xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
      TfLiteNode* node, const TfLiteTensor* tensors,
      const std::unordered_set<int>& quasi_static_tensors,
      const std::vector<uint32_t>& xnnpack_tensors) {
    TF_LITE_ENSURE_STATUS(
        CheckNumInputsAndOutputs(logging_context, node, 2, 1, node_index));
@ -1752,8 +1817,10 @@ class Subgraph {
        logging_context, slope_tensor, node->inputs->data[1], node_index));
    TF_LITE_ENSURE_STATUS(CheckSlopeTensorShape(
        logging_context, slope_tensor, node->inputs->data[1], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
        logging_context, slope_tensor, node->inputs->data[1], node_index));
    if (quasi_static_tensors.count(node->inputs->data[1]) == 0) {
      TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
          logging_context, slope_tensor, node->inputs->data[1], node_index));
    }

    const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
@ -1869,15 +1936,29 @@ class Subgraph {
  bool first_run_{true};
};

TfLiteIntArray* GetOpsToReplace(TfLiteContext* context) {
TfLiteIntArray* Delegate::PrepareOpsToDelegate(TfLiteContext* context) {
  // Clear previous data, in case the delegate is reused without re-creation.
  static_unpacked_data_map_.clear();
  static_unpacked_data_.clear();
  static_unpack_nodes_.clear();

  TfLiteIntArray* execution_plan = nullptr;
  if (context->GetExecutionPlan(context, &execution_plan) != kTfLiteOk) {
    TF_LITE_KERNEL_LOG(context, "Unable to get graph execution plan.");
    return nullptr;
  }

  TfLiteIntArray* nodes_to_replace = TfLiteIntArrayCreate(execution_plan->size);
  nodes_to_replace->size = 0;
  // Mapping for quasi-static (unpacked from static) tensor index to the node
  // index that produced it.
  std::unordered_map<int, int> quasi_static_tensors_producers;
  // Set of all quasi-static tensors in the execution plan.
  std::unordered_set<int> quasi_static_tensors;
  // Set of quasi-static tensors consumed by the delegated nodes.
  std::unordered_set<int> quasi_static_tensors_to_unpack;

  TfLiteIntArray* nodes_to_delegate =
      TfLiteIntArrayCreate(execution_plan->size);
  nodes_to_delegate->size = 0;
  for (int i = 0; i < execution_plan->size; ++i) {
    const int node_index = execution_plan->data[i];

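The loop below detects the weight-dequantization pattern that this change targets; a schematic for orientation (illustrative, not code from this change):

// Graph pattern recognized below:
//
//   [FP16 weights, kTfLiteMmapRo] -> Dequantize -> [FP32 tensor] -> Conv2D/...
//
// The Dequantize node becomes a "static unpack" node: its output is a
// quasi-static tensor, converted once in PrepareOpsToDelegate instead of
// being evaluated by the interpreter on every invocation.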
@ -1892,15 +1973,142 @@ TfLiteIntArray* GetOpsToReplace(TfLiteContext* context) {
      continue;  // Soft error (skip this node).
    }

    if (registration->builtin_code == kTfLiteBuiltinDequantize &&
        node->inputs->size == 1 && node->outputs->size == 1) {
      const TfLiteTensor& input_tensor =
          context->tensors[node->inputs->data[0]];
      const TfLiteTensor& output_tensor =
          context->tensors[node->outputs->data[0]];
      if (input_tensor.allocation_type == kTfLiteMmapRo &&
          input_tensor.type == kTfLiteFloat16 &&
          output_tensor.type == kTfLiteFloat32) {
        static_unpack_nodes_.insert(node_index);
        quasi_static_tensors_producers[node->outputs->data[0]] = node_index;
        quasi_static_tensors.insert(node->outputs->data[0]);

        // Skip this node for now. If output of the node is consumed only by
        // delegated nodes, it will be added to nodes_to_delegate in the end.
        continue;
      }
    }

    if (Subgraph::VisitNode(/*subgraph=*/nullptr, context, registration, node,
                            node_index, std::vector<uint32_t>()) != kTfLiteOk) {
                            node_index, quasi_static_tensors,
                            std::vector<uint32_t>()) != kTfLiteOk) {
      // If a non-delegated node consumes output of a node that unpacks static
      // data, that node shouldn't be delegated.
      for (int j = 0; j < node->inputs->size; j++) {
        const auto it =
            quasi_static_tensors_producers.find(node->inputs->data[j]);
        if (it != quasi_static_tensors_producers.end()) {
          static_unpack_nodes_.erase(it->second);
        }
      }

      // Non-delegatable node is not an error.
      continue;
    }

    nodes_to_replace->data[nodes_to_replace->size++] = node_index;
    for (int j = 0; j < node->inputs->size; j++) {
      if (quasi_static_tensors.count(node->inputs->data[j]) != 0) {
        quasi_static_tensors_to_unpack.insert(node->inputs->data[j]);
      }
    }

    nodes_to_delegate->data[nodes_to_delegate->size++] = node_index;
  }

  // Unpack the static data of all quasi-static tensors consumed by the
  // delegated nodes.
  for (int t : quasi_static_tensors_to_unpack) {
    const int producer_index = quasi_static_tensors_producers[t];
    // Fetch the Dequantize node that produces this quasi-static tensor.
    TfLiteNode* node = nullptr;
    TfLiteRegistration* registration = nullptr;
    if (context->GetNodeAndRegistration(context, producer_index, &node,
                                        &registration) != kTfLiteOk) {
      TF_LITE_KERNEL_LOG(context,
                         "Unable to get node and registration for node %d.",
                         producer_index);
      TfLiteIntArrayFree(nodes_to_delegate);
      return nullptr;  // Hard error.
    }

    if (node->inputs->size != 1) {
      TF_LITE_KERNEL_LOG(context, "unexpected number of inputs (%d) in node %d",
                         node->inputs->size, producer_index);
      TfLiteIntArrayFree(nodes_to_delegate);
      return nullptr;  // Hard error.
    }

    if (node->outputs->size != 1) {
      TF_LITE_KERNEL_LOG(context,
                         "unexpected number of outputs (%d) in node %d",
                         node->outputs->size, producer_index);
      TfLiteIntArrayFree(nodes_to_delegate);
      return nullptr;  // Hard error.
    }

    const TfLiteTensor& input_tensor = context->tensors[node->inputs->data[0]];
    if (input_tensor.allocation_type != kTfLiteMmapRo) {
      TF_LITE_KERNEL_LOG(context,
                         "unexpected allocation type in tensor %d in node %d",
                         node->inputs->data[0], producer_index);
      TfLiteIntArrayFree(nodes_to_delegate);
      return nullptr;  // Hard error.
    }

    const TfLiteTensor& output_tensor = context->tensors[t];
    if (output_tensor.type != kTfLiteFloat32) {
      TF_LITE_KERNEL_LOG(context,
                         "unexpected datatype (%s) in tensor %d in node %d",
                         TfLiteTypeGetName(output_tensor.type),
                         node->outputs->data[0], producer_index);
      TfLiteIntArrayFree(nodes_to_delegate);
      return nullptr;  // Hard error.
    }
    const size_t tensor_elements = output_tensor.bytes / sizeof(float);

    // Align the offset to a multiple of XNN_EXTRA_BYTES.
    while (static_unpacked_data_.size() % XNN_EXTRA_BYTES != 0) {
      static_unpacked_data_.push_back(0);
    }
    const size_t tensor_offset = static_unpacked_data_.size();
    static_unpacked_data_.resize(tensor_offset + context->tensors[t].bytes);

    float* unpacked_data =
        reinterpret_cast<float*>(static_unpacked_data_.data() + tensor_offset);
    switch (input_tensor.type) {
      case kTfLiteFloat16: {
        const uint16_t* packed_data =
            static_cast<const uint16_t*>(input_tensor.data.data);
        for (size_t i = 0; i < tensor_elements; i++) {
          unpacked_data[i] = fp16_ieee_to_fp32_value(packed_data[i]);
        }
        break;
      }
      default:
        TF_LITE_KERNEL_LOG(context,
                           "unexpected datatype (%s) in tensor %d in node %d",
                           TfLiteTypeGetName(output_tensor.type),
                           node->outputs->data[0], producer_index);
        TfLiteIntArrayFree(nodes_to_delegate);
        return nullptr;  // Hard error.
    }

    static_unpacked_data_map_[t] = tensor_offset;
  }

  // Add nodes that unpack static data consumed by delegated nodes.
  // Note: this is done purely to avoid the overhead of running these nodes
  // again in the TFLite interpreter, which would allocate memory for their
  // outputs. We mark them as delegated, but the delegate simply ignores them
  // because the static weights are already unpacked.
  for (int node_index : static_unpack_nodes_) {
    nodes_to_delegate->data[nodes_to_delegate->size++] = node_index;
  }
  std::sort(&nodes_to_delegate->data[0],
            &nodes_to_delegate->data[nodes_to_delegate->size]);

#ifdef XNNPACK_DELEGATE_TEST_MODE
  // In the test mode build (used by unit tests), XNNPACK delegate claims to
  // support all operators in the execution plan to disable fallback to the
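A worked example of the unpack loop above, using the same FP16 library call (values chosen for easy verification by hand):

#include <cstddef>
#include <cstdint>

#include <fp16.h>

// IEEE FP16 0x3C00 encodes 1.0f (sign 0, exponent 01111, mantissa 0) and
// 0xC000 encodes -2.0f, so the loop writes {1.0f, -2.0f}.
void UnpackTwoHalfWords() {
  const uint16_t packed_data[2] = {0x3C00, 0xC000};
  float unpacked_data[2];
  for (size_t i = 0; i < 2; i++) {
    unpacked_data[i] = fp16_ieee_to_fp32_value(packed_data[i]);
  }
  // unpacked_data is now {1.0f, -2.0f}.
}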
@ -1908,24 +2116,22 @@ TfLiteIntArray* GetOpsToReplace(TfLiteContext* context) {
  // not supported by the delegate, they will cause a failure in
  // ::tflite::Interpreter::ModifyGraphWithDelegate, to be caught in the unit
  // tests.
  nodes_to_replace->size = execution_plan->size;
  nodes_to_delegate->size = execution_plan->size;
  std::copy(&execution_plan->data[0],
            &execution_plan->data[execution_plan->size],
            &nodes_to_replace->data[0]);
            &nodes_to_delegate->data[0]);
#endif

  return nodes_to_replace;
  return nodes_to_delegate;
}

void* SubgraphInit(TfLiteContext* context, const char* buffer, size_t length) {
  const TfLiteDelegateParams* params =
      reinterpret_cast<const TfLiteDelegateParams*>(buffer);

  pthreadpool_t threadpool =
      static_cast<::tflite::xnnpack::Delegate*>(params->delegate->data_)
          ->threadpool();

  return static_cast<void*>(Subgraph::Create(context, params, threadpool));
  return static_cast<void*>(Subgraph::Create(
      context, params,
      static_cast<::tflite::xnnpack::Delegate*>(params->delegate->data_)));
}

TfLiteStatus SubgraphPrepare(TfLiteContext* context, TfLiteNode* node) {
@ -1962,7 +2168,9 @@ const TfLiteRegistration kSubgraphRegistration = {
};

TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
  TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
  TfLiteIntArray* ops_to_replace =
      static_cast<::tflite::xnnpack::Delegate*>(delegate->data_)
          ->PrepareOpsToDelegate(context);
  const TfLiteStatus status = context->ReplaceNodeSubsetsWithDelegateKernels(
      context, kSubgraphRegistration, ops_to_replace, delegate);
  TfLiteIntArrayFree(ops_to_replace);
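Taken together, the change requires no new API from callers: a model whose FP32 weights were shrunk to FP16 plus Dequantize is delegated through the existing entry points. A usage sketch under that assumption (the helper name is invented; error handling elided):

#include <memory>

#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
#include "tensorflow/lite/interpreter.h"

// Attaches the XNNPACK delegate to a built interpreter. The returned smart
// pointer owns the delegate and must outlive the interpreter.
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
AttachXNNPackDelegate(tflite::Interpreter* interpreter) {
  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
      xnnpack_delegate(TfLiteXNNPackDelegateCreate(/*options=*/nullptr),
                       TfLiteXNNPackDelegateDelete);
  if (interpreter->ModifyGraphWithDelegate(xnnpack_delegate.get()) !=
      kTfLiteOk) {
    // Unsupported graphs fall back to the default TFLite kernels; the
    // delegate object is still returned so the caller controls its lifetime.
  }
  return xnnpack_delegate;
}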