Support models with FP16 weights in XNNPACK delegate

PiperOrigin-RevId: 313505742
Change-Id: Id21f7528741073e93a7132d529c3cd79957a73fb
This commit is contained in:
Marat Dukhan 2020-05-27 18:50:31 -07:00 committed by TensorFlower Gardener
parent fb86acf839
commit a7048d89a1
13 changed files with 942 additions and 283 deletions

View File

@ -24,6 +24,7 @@ cc_library(
"//tensorflow/lite:util",
"//tensorflow/lite/c:common",
"//tensorflow/lite/schema:schema_fbs",
"@FP16",
"@XNNPACK",
],
)
@ -39,6 +40,7 @@ cc_library(
"//tensorflow/lite:util",
"//tensorflow/lite/c:common",
"//tensorflow/lite/schema:schema_fbs",
"@FP16",
"@XNNPACK",
],
)
@ -56,6 +58,7 @@ cc_library(
"//tensorflow/lite/c:common",
"//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/schema:schema_fbs",
"@FP16",
"@com_google_googletest//:gtest",
"@flatbuffers",
],
@ -72,6 +75,7 @@ cc_library(
"//tensorflow/lite/c:common",
"//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/schema:schema_fbs",
"@FP16",
"@com_google_googletest//:gtest",
"@flatbuffers",
],
@ -88,6 +92,7 @@ cc_library(
"//tensorflow/lite/c:common",
"//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/schema:schema_fbs",
"@FP16",
"@com_google_googletest//:gtest",
"@flatbuffers",
],
@ -215,6 +220,7 @@ cc_test(
"//tensorflow/lite:schema_fbs_version",
"//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/schema:schema_fbs",
"@FP16",
"@com_google_googletest//:gtest",
"@flatbuffers",
],

View File

@ -679,6 +679,35 @@ TEST(Add, 2DByStatic0D) {
.Test(BuiltinOperator_ADD, xnnpack_delegate.get());
}
TEST(Add, FP16Weights) {
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
TfLiteXNNPackDelegateDelete);
std::random_device random_device;
auto rng = std::mt19937(random_device());
auto shape_rng =
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
const auto batch = shape_rng();
const auto height = shape_rng();
const auto width = shape_rng();
const auto channels = shape_rng();
BinaryElementwiseTester()
.Input1Shape({batch, height, width, channels})
.Input2Shape({batch, height, width, channels})
.Input1Static(true)
.FP16Weights()
.Test(BuiltinOperator_ADD, xnnpack_delegate.get());
BinaryElementwiseTester()
.Input1Shape({batch, height, width, channels})
.Input2Shape({batch, height, width, channels})
.Input2Static(true)
.FP16Weights()
.Test(BuiltinOperator_ADD, xnnpack_delegate.get());
}
TEST(Add, ReluActivation) {
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),

View File

@ -23,6 +23,7 @@ limitations under the License.
#include <vector>
#include <gtest/gtest.h>
#include <fp16.h>
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
@ -62,6 +63,9 @@ void BinaryElementwiseTester::Test(tflite::BuiltinOperator binary_op,
if (Input1Static()) {
ASSERT_FALSE(Input2Static());
}
if (FP16Weights()) {
ASSERT_TRUE(Input1Static() || Input2Static());
}
std::random_device random_device;
auto rng = std::mt19937(random_device());
@ -180,8 +184,12 @@ std::vector<char> BinaryElementwiseTester::CreateTfLiteModel(
auto input2_rng = std::bind(input2_distribution, std::ref(rng));
flatbuffers::FlatBufferBuilder builder;
flatbuffers::Offset<OperatorCode> operator_code =
CreateOperatorCode(builder, binary_op);
std::vector<flatbuffers::Offset<OperatorCode>> operator_codes{
{CreateOperatorCode(builder, binary_op)}};
if (FP16Weights()) {
operator_codes.emplace_back(
CreateOperatorCode(builder, BuiltinOperator_DEQUANTIZE));
}
std::vector<flatbuffers::Offset<Buffer>> buffers{{
CreateBuffer(builder, builder.CreateVector({})),
@ -189,43 +197,89 @@ std::vector<char> BinaryElementwiseTester::CreateTfLiteModel(
int32_t input1_buffer = 0;
if (Input1Static()) {
std::vector<float> input1_data(ComputeSize(Input1Shape()));
std::generate(input1_data.begin(), input1_data.end(), input1_rng);
if (FP16Weights()) {
std::vector<uint16_t> input1_data(ComputeSize(Input1Shape()));
std::generate(input1_data.begin(), input1_data.end(),
std::bind(fp16_ieee_from_fp32_value, input1_rng));
input1_buffer = buffers.size();
buffers.push_back(CreateBuffer(
builder, builder.CreateVector(
reinterpret_cast<const uint8_t*>(input1_data.data()),
sizeof(float) * input1_data.size())));
buffers.push_back(CreateBuffer(
builder, builder.CreateVector(
reinterpret_cast<const uint8_t*>(input1_data.data()),
sizeof(uint16_t) * input1_data.size())));
} else {
std::vector<float> input1_data(ComputeSize(Input1Shape()));
std::generate(input1_data.begin(), input1_data.end(), input1_rng);
input1_buffer = buffers.size();
buffers.push_back(CreateBuffer(
builder, builder.CreateVector(
reinterpret_cast<const uint8_t*>(input1_data.data()),
sizeof(float) * input1_data.size())));
}
}
int32_t input2_buffer = 0;
if (Input2Static()) {
std::vector<float> input2_data(ComputeSize(Input2Shape()));
std::generate(input2_data.begin(), input2_data.end(), input2_rng);
if (FP16Weights()) {
std::vector<uint16_t> input2_data(ComputeSize(Input2Shape()));
std::generate(input2_data.begin(), input2_data.end(),
std::bind(fp16_ieee_from_fp32_value, input1_rng));
input2_buffer = buffers.size();
buffers.push_back(CreateBuffer(
builder, builder.CreateVector(
reinterpret_cast<const uint8_t*>(input2_data.data()),
sizeof(float) * input2_data.size())));
buffers.push_back(CreateBuffer(
builder, builder.CreateVector(
reinterpret_cast<const uint8_t*>(input2_data.data()),
sizeof(uint16_t) * input2_data.size())));
} else {
std::vector<float> input2_data(ComputeSize(Input2Shape()));
std::generate(input2_data.begin(), input2_data.end(), input2_rng);
input2_buffer = buffers.size();
buffers.push_back(CreateBuffer(
builder, builder.CreateVector(
reinterpret_cast<const uint8_t*>(input2_data.data()),
sizeof(float) * input2_data.size())));
}
}
const std::vector<int32_t> output_shape = OutputShape();
const std::array<flatbuffers::Offset<Tensor>, 3> tensors{{
CreateTensor(builder,
builder.CreateVector<int32_t>(Input1Shape().data(),
Input1Shape().size()),
TensorType_FLOAT32, input1_buffer),
CreateTensor(builder,
builder.CreateVector<int32_t>(Input2Shape().data(),
Input2Shape().size()),
TensorType_FLOAT32, input2_buffer),
CreateTensor(builder,
builder.CreateVector<int32_t>(output_shape.data(),
output_shape.size()),
TensorType_FLOAT32),
}};
std::vector<flatbuffers::Offset<Tensor>> tensors;
std::vector<flatbuffers::Offset<Operator>> operators;
if (FP16Weights() && Input1Static()) {
tensors.emplace_back(
CreateTensor(builder,
builder.CreateVector<int32_t>(Input1Shape().data(),
Input1Shape().size()),
TensorType_FLOAT16, 1));
}
if (FP16Weights() && Input2Static()) {
tensors.emplace_back(
CreateTensor(builder,
builder.CreateVector<int32_t>(Input2Shape().data(),
Input2Shape().size()),
TensorType_FLOAT16, 1));
}
if (FP16Weights()) {
const std::array<int32_t, 1> dequantize_inputs{{0}};
const std::array<int32_t, 1> dequantize_outputs{{Input1Static() ? 1 : 2}};
operators.emplace_back(CreateOperator(
builder, /*opcode_index=*/1,
builder.CreateVector<int32_t>(dequantize_inputs.data(),
dequantize_inputs.size()),
builder.CreateVector<int32_t>(dequantize_outputs.data(),
dequantize_outputs.size())));
}
tensors.emplace_back(CreateTensor(
builder,
builder.CreateVector<int32_t>(Input1Shape().data(), Input1Shape().size()),
TensorType_FLOAT32, input1_buffer));
tensors.emplace_back(CreateTensor(
builder,
builder.CreateVector<int32_t>(Input2Shape().data(), Input2Shape().size()),
TensorType_FLOAT32, input2_buffer));
tensors.emplace_back(CreateTensor(
builder,
builder.CreateVector<int32_t>(output_shape.data(), output_shape.size()),
TensorType_FLOAT32));
tflite::BuiltinOptions builtin_options_type = tflite::BuiltinOptions_NONE;
flatbuffers::Offset<void> builtin_options = 0;
@ -250,35 +304,40 @@ std::vector<char> BinaryElementwiseTester::CreateTfLiteModel(
EXPECT_EQ(Activation(), ActivationFunctionType_NONE);
}
const std::array<int32_t, 2> op_inputs{{0, 1}};
const std::array<int32_t, 1> op_outputs{{2}};
flatbuffers::Offset<Operator> op = CreateOperator(
const std::array<int32_t, 2> op_inputs{
{static_cast<int>(tensors.size()) - 3,
static_cast<int>(tensors.size()) - 2}};
const std::array<int32_t, 1> op_outputs{
{static_cast<int>(tensors.size()) - 1}};
operators.emplace_back(CreateOperator(
builder, /*opcode_index=*/0,
builder.CreateVector<int32_t>(op_inputs.data(), op_inputs.size()),
builder.CreateVector<int32_t>(op_outputs.data(), op_outputs.size()),
builtin_options_type, builtin_options);
builtin_options_type, builtin_options));
std::vector<int32_t> subgraph_inputs;
if (!Input1Static()) {
subgraph_inputs.push_back(0);
subgraph_inputs.push_back(tensors.size() - 3);
}
if (!Input2Static()) {
subgraph_inputs.push_back(1);
subgraph_inputs.push_back(tensors.size() - 2);
}
const std::array<int32_t, 1> subgraph_outputs{{2}};
const std::array<int32_t, 1> subgraph_outputs{
{static_cast<int>(tensors.size()) - 1}};
flatbuffers::Offset<SubGraph> subgraph = CreateSubGraph(
builder, builder.CreateVector(tensors.data(), tensors.size()),
builder.CreateVector<int32_t>(subgraph_inputs.data(),
subgraph_inputs.size()),
builder.CreateVector<int32_t>(subgraph_outputs.data(),
subgraph_outputs.size()),
builder.CreateVector(&op, 1));
builder.CreateVector(operators.data(), operators.size()));
flatbuffers::Offset<flatbuffers::String> description =
builder.CreateString("Binary operator model");
flatbuffers::Offset<Model> model_buffer = CreateModel(
builder, TFLITE_SCHEMA_VERSION, builder.CreateVector(&operator_code, 1),
builder, TFLITE_SCHEMA_VERSION,
builder.CreateVector(operator_codes.data(), operator_codes.size()),
builder.CreateVector(&subgraph, 1), description,
builder.CreateVector(buffers.data(), buffers.size()));

View File

@ -74,6 +74,13 @@ class BinaryElementwiseTester {
inline bool Input2Static() const { return input2_static_; }
inline BinaryElementwiseTester& FP16Weights() {
fp16_weights_ = true;
return *this;
}
inline bool FP16Weights() const { return fp16_weights_; }
inline BinaryElementwiseTester& ReluActivation() {
activation_ = ::tflite::ActivationFunctionType_RELU;
return *this;
@ -114,6 +121,7 @@ class BinaryElementwiseTester {
std::vector<int32_t> input2_shape_;
bool input1_static_ = false;
bool input2_static_ = false;
bool fp16_weights_ = false;
::tflite::ActivationFunctionType activation_ =
::tflite::ActivationFunctionType_NONE;
};

View File

@ -13,12 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <array>
#include <cstdint>
#include <functional>
#include <random>
#include <vector>
#include <gtest/gtest.h>
#include <fp16.h>
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
#include "tensorflow/lite/interpreter.h"
@ -146,6 +148,13 @@ class Conv2DTester {
int32_t DilationWidth() const { return dilation_width_; }
inline Conv2DTester& FP16Weights() {
fp16_weights_ = true;
return *this;
}
inline bool FP16Weights() const { return fp16_weights_; }
Conv2DTester& SamePadding(bool same_padding) {
same_padding_ = same_padding;
return *this;
@ -154,11 +163,7 @@ class Conv2DTester {
bool SamePadding() const { return same_padding_; }
void Test(TfLiteDelegate* delegate) const {
std::random_device random_device;
auto rng = std::mt19937(random_device());
auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);
std::vector<char> buffer = CreateTfLiteModel(std::ref(f32rng));
std::vector<char> buffer = CreateTfLiteModel();
const Model* model = GetModel(buffer.data());
std::unique_ptr<Interpreter> delegate_interpreter;
@ -187,6 +192,10 @@ class Conv2DTester {
ASSERT_EQ(delegate_interpreter->ModifyGraphWithDelegate(delegate),
kTfLiteOk);
std::random_device random_device;
auto rng = std::mt19937(random_device());
auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);
float* default_input_data = default_interpreter->typed_tensor<float>(
default_interpreter->inputs()[0]);
std::generate(default_input_data,
@ -219,82 +228,149 @@ class Conv2DTester {
}
private:
std::vector<char> CreateTfLiteModel(std::function<float()> f32rng) const {
std::vector<char> CreateTfLiteModel() const {
std::random_device random_device;
auto rng = std::mt19937(random_device());
auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);
flatbuffers::FlatBufferBuilder builder;
flatbuffers::Offset<OperatorCode> operator_code =
CreateOperatorCode(builder, BuiltinOperator_CONV_2D, 0);
std::vector<flatbuffers::Offset<OperatorCode>> operator_codes{
{CreateOperatorCode(builder, BuiltinOperator_CONV_2D, 0)}};
std::vector<flatbuffers::Offset<tflite::Operator>> operators;
std::vector<flatbuffers::Offset<tflite::Buffer>> buffers{
{CreateBuffer(builder, builder.CreateVector({}))}};
if (FP16Weights()) {
operator_codes.emplace_back(
CreateOperatorCode(builder, BuiltinOperator_DEQUANTIZE));
auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng);
std::vector<uint16_t> filter_data(OutputChannels() * KernelHeight() *
KernelWidth() * InputChannels());
std::vector<uint16_t> bias_data(OutputChannels());
std::generate(filter_data.begin(), filter_data.end(), f16rng);
std::generate(bias_data.begin(), bias_data.end(), f16rng);
buffers.emplace_back(CreateBuffer(
builder, builder.CreateVector(
reinterpret_cast<const uint8_t*>(filter_data.data()),
sizeof(uint16_t) * filter_data.size())));
buffers.emplace_back(CreateBuffer(
builder, builder.CreateVector(
reinterpret_cast<const uint8_t*>(bias_data.data()),
sizeof(uint16_t) * bias_data.size())));
const std::array<int32_t, 1> dequantize_filter_inputs{{0}};
const std::array<int32_t, 1> dequantize_filter_outputs{{3}};
operators.emplace_back(CreateOperator(
builder, /*opcode_index=*/1,
builder.CreateVector<int32_t>(dequantize_filter_inputs.data(),
dequantize_filter_inputs.size()),
builder.CreateVector<int32_t>(dequantize_filter_outputs.data(),
dequantize_filter_outputs.size())));
const std::array<int32_t, 1> dequantize_bias_inputs{{1}};
const std::array<int32_t, 1> dequantize_bias_outputs{{4}};
operators.emplace_back(CreateOperator(
builder, /*opcode_index=*/1,
builder.CreateVector<int32_t>(dequantize_bias_inputs.data(),
dequantize_bias_inputs.size()),
builder.CreateVector<int32_t>(dequantize_bias_outputs.data(),
dequantize_bias_outputs.size())));
} else {
std::vector<float> filter_data(OutputChannels() * KernelHeight() *
KernelWidth() * InputChannels());
std::vector<float> bias_data(OutputChannels());
std::generate(filter_data.begin(), filter_data.end(), f32rng);
std::generate(bias_data.begin(), bias_data.end(), f32rng);
buffers.emplace_back(CreateBuffer(
builder, builder.CreateVector(
reinterpret_cast<const uint8_t*>(filter_data.data()),
sizeof(float) * filter_data.size())));
buffers.emplace_back(CreateBuffer(
builder, builder.CreateVector(
reinterpret_cast<const uint8_t*>(bias_data.data()),
sizeof(float) * bias_data.size())));
}
const std::array<int32_t, 4> input_shape{
{BatchSize(), InputHeight(), InputWidth(), InputChannels()}};
const std::array<int32_t, 4> output_shape{
{BatchSize(), OutputHeight(), OutputWidth(), OutputChannels()}};
const std::array<int32_t, 4> filter_shape{
{OutputChannels(), KernelHeight(), KernelWidth(), InputChannels()}};
const std::array<int32_t, 1> bias_shape{{OutputChannels()}};
std::vector<flatbuffers::Offset<tflite::Tensor>> tensors;
if (FP16Weights()) {
tensors.emplace_back(
CreateTensor(builder,
builder.CreateVector<int32_t>(filter_shape.data(),
filter_shape.size()),
TensorType_FLOAT16, /*buffer=*/1));
tensors.emplace_back(CreateTensor(
builder,
builder.CreateVector<int32_t>(bias_shape.data(), bias_shape.size()),
TensorType_FLOAT16, /*buffer=*/2));
}
tensors.emplace_back(CreateTensor(
builder,
builder.CreateVector<int32_t>(input_shape.data(), input_shape.size()),
TensorType_FLOAT32));
tensors.emplace_back(CreateTensor(
builder,
builder.CreateVector<int32_t>(filter_shape.data(), filter_shape.size()),
TensorType_FLOAT32, /*buffer=*/FP16Weights() ? 0 : 1));
tensors.emplace_back(CreateTensor(
builder,
builder.CreateVector<int32_t>(bias_shape.data(), bias_shape.size()),
TensorType_FLOAT32, /*buffer=*/FP16Weights() ? 0 : 2));
tensors.emplace_back(CreateTensor(
builder,
builder.CreateVector<int32_t>(output_shape.data(), output_shape.size()),
TensorType_FLOAT32));
const std::array<int32_t, 3> op_inputs{
{static_cast<int>(tensors.size()) - 4,
static_cast<int>(tensors.size()) - 3,
static_cast<int>(tensors.size()) - 2}};
const std::array<int32_t, 1> op_outputs{
{static_cast<int>(tensors.size()) - 1}};
flatbuffers::Offset<Conv2DOptions> conv2d_options = CreateConv2DOptions(
builder, SamePadding() ? tflite::Padding_SAME : tflite::Padding_VALID,
StrideWidth(), StrideHeight(), ActivationFunctionType_NONE,
DilationWidth(), DilationHeight());
std::vector<float> filter_data(OutputChannels() * KernelHeight() *
KernelWidth() * InputChannels());
std::vector<float> bias_data(OutputChannels());
operators.emplace_back(CreateOperator(
builder, /*opcode_index=*/0,
builder.CreateVector<int32_t>(op_inputs.data(), op_inputs.size()),
builder.CreateVector<int32_t>(op_outputs.data(), op_outputs.size()),
BuiltinOptions_Conv2DOptions, conv2d_options.Union()));
std::generate(filter_data.begin(), filter_data.end(), f32rng);
std::generate(bias_data.begin(), bias_data.end(), f32rng);
flatbuffers::Offset<Buffer> buffers[3] = {
CreateBuffer(builder, builder.CreateVector({})),
CreateBuffer(builder,
builder.CreateVector(
reinterpret_cast<const uint8_t*>(filter_data.data()),
sizeof(float) * filter_data.size())),
CreateBuffer(builder,
builder.CreateVector(
reinterpret_cast<const uint8_t*>(bias_data.data()),
sizeof(float) * bias_data.size())),
};
const int32_t input_shape[4] = {BatchSize(), InputHeight(), InputWidth(),
InputChannels()};
const int32_t output_shape[4] = {BatchSize(), OutputHeight(), OutputWidth(),
OutputChannels()};
const int32_t filter_shape[4] = {OutputChannels(), KernelHeight(),
KernelWidth(), InputChannels()};
const int32_t bias_shape[1] = {OutputChannels()};
flatbuffers::Offset<Tensor> tensors[4] = {
CreateTensor(builder, builder.CreateVector<int32_t>(input_shape, 4),
TensorType_FLOAT32, /*buffer=*/0,
builder.CreateString("X")),
CreateTensor(builder, builder.CreateVector<int32_t>(filter_shape, 4),
TensorType_FLOAT32, /*buffer=*/1,
builder.CreateString("W")),
CreateTensor(builder, builder.CreateVector<int32_t>(bias_shape, 1),
TensorType_FLOAT32, /*buffer=*/2,
builder.CreateString("b")),
CreateTensor(builder, builder.CreateVector<int32_t>(output_shape, 4),
TensorType_FLOAT32, /*buffer=*/0,
builder.CreateString("Y")),
};
const int32_t op_inputs[3] = {0, 1, 2};
const int32_t op_outputs[1] = {3};
flatbuffers::Offset<Operator> op =
CreateOperator(builder, /*opcode_index=*/0,
builder.CreateVector<int32_t>(op_inputs, 3),
builder.CreateVector<int32_t>(op_outputs, 1),
BuiltinOptions_Conv2DOptions, conv2d_options.Union());
int32_t subgraph_inputs[1] = {0};
int32_t subgraph_outputs[1] = {3};
flatbuffers::Offset<SubGraph> subgraph =
CreateSubGraph(builder, builder.CreateVector(tensors, 4),
builder.CreateVector<int32_t>(subgraph_inputs, 1),
builder.CreateVector<int32_t>(subgraph_outputs, 1),
builder.CreateVector(&op, 1), /*name=*/0);
const std::array<int32_t, 1> subgraph_inputs{
{static_cast<int>(tensors.size()) - 4}};
const std::array<int32_t, 1> subgraph_outputs{
{static_cast<int>(tensors.size()) - 1}};
flatbuffers::Offset<SubGraph> subgraph = CreateSubGraph(
builder, builder.CreateVector(tensors.data(), tensors.size()),
builder.CreateVector<int32_t>(subgraph_inputs.data(),
subgraph_inputs.size()),
builder.CreateVector<int32_t>(subgraph_outputs.data(),
subgraph_outputs.size()),
builder.CreateVector(operators.data(), operators.size()));
flatbuffers::Offset<flatbuffers::String> description =
builder.CreateString("Conv2D model");
flatbuffers::Offset<Model> model_buffer = CreateModel(
builder, TFLITE_SCHEMA_VERSION, builder.CreateVector(&operator_code, 1),
builder, TFLITE_SCHEMA_VERSION,
builder.CreateVector(operator_codes.data(), operator_codes.size()),
builder.CreateVector(&subgraph, 1), description,
builder.CreateVector(buffers, 3));
builder.CreateVector(buffers.data(), buffers.size()));
builder.Finish(model_buffer);
@ -313,6 +389,7 @@ class Conv2DTester {
int32_t stride_width_ = 1;
int32_t dilation_height_ = 1;
int32_t dilation_width_ = 1;
bool fp16_weights_ = false;
bool same_padding_ = true;
};
@ -506,5 +583,35 @@ TEST(Conv2D, DilationWithValidPadding) {
.Test(xnnpack_delegate.get());
}
TEST(Conv2D, FP16Weights) {
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
TfLiteXNNPackDelegateDelete);
std::random_device random_device;
auto rng = std::mt19937(random_device());
auto input_rng =
std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
auto kernel_rng =
std::bind(std::uniform_int_distribution<int32_t>(3, 5), std::ref(rng));
auto stride_rng =
std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
auto channel_rng =
std::bind(std::uniform_int_distribution<int32_t>(1, 16), std::ref(rng));
Conv2DTester()
.InputHeight(input_rng())
.InputWidth(input_rng())
.InputChannels(channel_rng())
.OutputChannels(channel_rng())
.KernelHeight(kernel_rng())
.KernelWidth(kernel_rng())
.StrideHeight(stride_rng())
.StrideWidth(stride_rng())
.SamePadding(true)
.FP16Weights()
.Test(xnnpack_delegate.get());
}
} // namespace xnnpack
} // namespace tflite

View File

@ -371,6 +371,37 @@ TEST(DepthwiseConv2D, DepthMultiplier) {
.Test(xnnpack_delegate.get());
}
TEST(DepthwiseConv2D, FP16Weights) {
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
TfLiteXNNPackDelegateDelete);
std::random_device random_device;
auto rng = std::mt19937(random_device());
auto batch_rng =
std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
auto input_rng =
std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
auto kernel_rng =
std::bind(std::uniform_int_distribution<int32_t>(3, 5), std::ref(rng));
auto stride_rng =
std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
auto channel_rng =
std::bind(std::uniform_int_distribution<int32_t>(3, 32), std::ref(rng));
DepthwiseConv2DTester()
.BatchSize(batch_rng())
.InputHeight(input_rng())
.InputWidth(input_rng())
.InputChannels(channel_rng())
.KernelHeight(kernel_rng())
.KernelWidth(kernel_rng())
.StrideHeight(stride_rng())
.StrideWidth(stride_rng())
.FP16Weights()
.Test(xnnpack_delegate.get());
}
TEST(DepthwiseConv2D, ReluActivation) {
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),

View File

@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include <gtest/gtest.h>
#include <fp16.h>
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
@ -107,56 +108,110 @@ void DepthwiseConv2DTester::Test(TfLiteDelegate* delegate) const {
}
std::vector<char> DepthwiseConv2DTester::CreateTfLiteModel() const {
flatbuffers::FlatBufferBuilder builder;
flatbuffers::Offset<OperatorCode> operator_code =
CreateOperatorCode(builder, BuiltinOperator_DEPTHWISE_CONV_2D);
flatbuffers::Offset<DepthwiseConv2DOptions> depthwise_conv2d_options =
CreateDepthwiseConv2DOptions(
builder, Padding(), StrideWidth(), StrideHeight(), DepthMultiplier(),
Activation(), DilationWidth(), DilationHeight());
std::vector<float> filter_data(KernelHeight() * KernelWidth() *
OutputChannels());
std::vector<float> bias_data(OutputChannels());
std::random_device random_device;
auto rng = std::mt19937(random_device());
auto range_rng = std::bind(
std::uniform_real_distribution<float>(-25.0f, 25.0f), std::ref(rng));
for (int32_t ic = 0; ic < InputChannels(); ic++) {
// Use the same range of all-positive or all-negative values to generate
// all pixels within the same batch index & channel, but different ranges
// for different channels or batches. This ensures that no catastrophic
// cancellation occur, but test covers both positive and negative inputs.
const float range = range_rng();
auto value_rng =
std::bind(std::uniform_real_distribution<float>(std::min(range, 0.0f),
std::max(range, 0.0f)),
std::ref(rng));
for (int32_t m = 0; m < DepthMultiplier(); m++) {
const int32_t oc = ic * DepthMultiplier() + m;
bias_data[oc] = value_rng();
for (int32_t y = 0; y < KernelHeight(); y++) {
for (int32_t x = 0; x < KernelWidth(); x++) {
const int32_t index = (y * KernelWidth() + x) * OutputChannels() + oc;
filter_data[index] = value_rng();
flatbuffers::FlatBufferBuilder builder;
std::vector<flatbuffers::Offset<OperatorCode>> operator_codes{
{CreateOperatorCode(builder, BuiltinOperator_DEPTHWISE_CONV_2D)}};
std::vector<flatbuffers::Offset<tflite::Operator>> operators;
std::vector<flatbuffers::Offset<tflite::Buffer>> buffers{
{CreateBuffer(builder, builder.CreateVector({}))}};
if (FP16Weights()) {
operator_codes.emplace_back(
CreateOperatorCode(builder, BuiltinOperator_DEQUANTIZE));
std::vector<uint16_t> filter_data(KernelHeight() * KernelWidth() *
OutputChannels());
std::vector<uint16_t> bias_data(OutputChannels());
for (int32_t ic = 0; ic < InputChannels(); ic++) {
// Use the same range of all-positive or all-negative values to generate
// all pixels within the same batch index & channel, but different ranges
// for different channels or batches. This ensures that no catastrophic
// cancellation occur, but test covers both positive and negative inputs.
const float range = range_rng();
auto value_rng =
std::bind(fp16_ieee_from_fp32_value,
std::bind(std::uniform_real_distribution<float>(
std::min(range, 0.0f), std::max(range, 0.0f)),
std::ref(rng)));
for (int32_t m = 0; m < DepthMultiplier(); m++) {
const int32_t oc = ic * DepthMultiplier() + m;
bias_data[oc] = value_rng();
for (int32_t y = 0; y < KernelHeight(); y++) {
for (int32_t x = 0; x < KernelWidth(); x++) {
const int32_t index =
(y * KernelWidth() + x) * OutputChannels() + oc;
filter_data[index] = value_rng();
}
}
}
}
}
const std::array<flatbuffers::Offset<tflite::Buffer>, 3> buffers{{
CreateBuffer(builder, builder.CreateVector({})),
CreateBuffer(builder,
builder.CreateVector(
reinterpret_cast<const uint8_t*>(filter_data.data()),
sizeof(float) * filter_data.size())),
CreateBuffer(builder,
builder.CreateVector(
reinterpret_cast<const uint8_t*>(bias_data.data()),
sizeof(float) * bias_data.size())),
}};
buffers.emplace_back(CreateBuffer(
builder, builder.CreateVector(
reinterpret_cast<const uint8_t*>(filter_data.data()),
sizeof(uint16_t) * filter_data.size())));
buffers.emplace_back(CreateBuffer(
builder,
builder.CreateVector(reinterpret_cast<const uint8_t*>(bias_data.data()),
sizeof(uint16_t) * bias_data.size())));
const std::array<int32_t, 1> dequantize_filter_inputs{{0}};
const std::array<int32_t, 1> dequantize_filter_outputs{{3}};
operators.emplace_back(CreateOperator(
builder, /*opcode_index=*/1,
builder.CreateVector<int32_t>(dequantize_filter_inputs.data(),
dequantize_filter_inputs.size()),
builder.CreateVector<int32_t>(dequantize_filter_outputs.data(),
dequantize_filter_outputs.size())));
const std::array<int32_t, 1> dequantize_bias_inputs{{1}};
const std::array<int32_t, 1> dequantize_bias_outputs{{4}};
operators.emplace_back(CreateOperator(
builder, /*opcode_index=*/1,
builder.CreateVector<int32_t>(dequantize_bias_inputs.data(),
dequantize_bias_inputs.size()),
builder.CreateVector<int32_t>(dequantize_bias_outputs.data(),
dequantize_bias_outputs.size())));
} else {
std::vector<float> filter_data(KernelHeight() * KernelWidth() *
OutputChannels());
std::vector<float> bias_data(OutputChannels());
for (int32_t ic = 0; ic < InputChannels(); ic++) {
// Use the same range of all-positive or all-negative values to generate
// all pixels within the same batch index & channel, but different ranges
// for different channels or batches. This ensures that no catastrophic
// cancellation occur, but test covers both positive and negative inputs.
const float range = range_rng();
auto value_rng =
std::bind(std::uniform_real_distribution<float>(
std::min(range, 0.0f), std::max(range, 0.0f)),
std::ref(rng));
for (int32_t m = 0; m < DepthMultiplier(); m++) {
const int32_t oc = ic * DepthMultiplier() + m;
bias_data[oc] = value_rng();
for (int32_t y = 0; y < KernelHeight(); y++) {
for (int32_t x = 0; x < KernelWidth(); x++) {
const int32_t index =
(y * KernelWidth() + x) * OutputChannels() + oc;
filter_data[index] = value_rng();
}
}
}
}
buffers.emplace_back(CreateBuffer(
builder, builder.CreateVector(
reinterpret_cast<const uint8_t*>(filter_data.data()),
sizeof(float) * filter_data.size())));
buffers.emplace_back(CreateBuffer(
builder,
builder.CreateVector(reinterpret_cast<const uint8_t*>(bias_data.data()),
sizeof(float) * bias_data.size())));
}
const std::array<int32_t, 4> input_shape{
{BatchSize(), InputHeight(), InputWidth(), InputChannels()}};
@ -166,49 +221,69 @@ std::vector<char> DepthwiseConv2DTester::CreateTfLiteModel() const {
{1, KernelHeight(), KernelWidth(), OutputChannels()}};
const std::array<int32_t, 1> bias_shape{{OutputChannels()}};
const std::array<flatbuffers::Offset<tflite::Tensor>, 4> tensors{{
CreateTensor(
builder,
builder.CreateVector<int32_t>(input_shape.data(), input_shape.size()),
TensorType_FLOAT32),
CreateTensor(builder,
builder.CreateVector<int32_t>(filter_shape.data(),
filter_shape.size()),
TensorType_FLOAT32, /*buffer=*/1),
CreateTensor(
builder,
builder.CreateVector<int32_t>(bias_shape.data(), bias_shape.size()),
TensorType_FLOAT32, /*buffer=*/2),
CreateTensor(builder,
builder.CreateVector<int32_t>(output_shape.data(),
output_shape.size()),
TensorType_FLOAT32),
}};
std::vector<flatbuffers::Offset<tflite::Tensor>> tensors;
if (FP16Weights()) {
tensors.emplace_back(CreateTensor(
builder,
builder.CreateVector<int32_t>(filter_shape.data(), filter_shape.size()),
TensorType_FLOAT16, /*buffer=*/1));
tensors.emplace_back(CreateTensor(
builder,
builder.CreateVector<int32_t>(bias_shape.data(), bias_shape.size()),
TensorType_FLOAT16, /*buffer=*/2));
}
tensors.emplace_back(CreateTensor(
builder,
builder.CreateVector<int32_t>(input_shape.data(), input_shape.size()),
TensorType_FLOAT32));
tensors.emplace_back(CreateTensor(
builder,
builder.CreateVector<int32_t>(filter_shape.data(), filter_shape.size()),
TensorType_FLOAT32, /*buffer=*/FP16Weights() ? 0 : 1));
tensors.emplace_back(CreateTensor(
builder,
builder.CreateVector<int32_t>(bias_shape.data(), bias_shape.size()),
TensorType_FLOAT32, /*buffer=*/FP16Weights() ? 0 : 2));
tensors.emplace_back(CreateTensor(
builder,
builder.CreateVector<int32_t>(output_shape.data(), output_shape.size()),
TensorType_FLOAT32));
const std::array<int32_t, 3> op_inputs{{0, 1, 2}};
const std::array<int32_t, 1> op_outputs{{3}};
const std::array<int32_t, 3> op_inputs{
{static_cast<int>(tensors.size()) - 4,
static_cast<int>(tensors.size()) - 3,
static_cast<int>(tensors.size()) - 2}};
const std::array<int32_t, 1> op_outputs{
{static_cast<int>(tensors.size()) - 1}};
flatbuffers::Offset<tflite::Operator> op = CreateOperator(
flatbuffers::Offset<DepthwiseConv2DOptions> depthwise_conv2d_options =
CreateDepthwiseConv2DOptions(
builder, Padding(), StrideWidth(), StrideHeight(), DepthMultiplier(),
Activation(), DilationWidth(), DilationHeight());
operators.emplace_back(CreateOperator(
builder, /*opcode_index=*/0,
builder.CreateVector<int32_t>(op_inputs.data(), op_inputs.size()),
builder.CreateVector<int32_t>(op_outputs.data(), op_outputs.size()),
BuiltinOptions_DepthwiseConv2DOptions, depthwise_conv2d_options.Union());
BuiltinOptions_DepthwiseConv2DOptions, depthwise_conv2d_options.Union()));
const std::array<int32_t, 1> subgraph_inputs{{0}};
const std::array<int32_t, 1> subgraph_outputs{{3}};
const std::array<int32_t, 1> subgraph_inputs{
{static_cast<int>(tensors.size()) - 4}};
const std::array<int32_t, 1> subgraph_outputs{
{static_cast<int>(tensors.size()) - 1}};
flatbuffers::Offset<SubGraph> subgraph = CreateSubGraph(
builder, builder.CreateVector(tensors.data(), tensors.size()),
builder.CreateVector<int32_t>(subgraph_inputs.data(),
subgraph_inputs.size()),
builder.CreateVector<int32_t>(subgraph_outputs.data(),
subgraph_outputs.size()),
builder.CreateVector(&op, 1));
builder.CreateVector(operators.data(), operators.size()));
flatbuffers::Offset<flatbuffers::String> description =
builder.CreateString("DepthwiseConv2D model");
flatbuffers::Offset<Model> model_buffer = CreateModel(
builder, TFLITE_SCHEMA_VERSION, builder.CreateVector(&operator_code, 1),
builder, TFLITE_SCHEMA_VERSION,
builder.CreateVector(operator_codes.data(), operator_codes.size()),
builder.CreateVector(&subgraph, 1), description,
builder.CreateVector(buffers.data(), buffers.size()));

View File

@ -152,6 +152,13 @@ class DepthwiseConv2DTester {
return (KernelWidth() - 1) * DilationWidth() + 1;
}
inline DepthwiseConv2DTester& FP16Weights() {
fp16_weights_ = true;
return *this;
}
inline bool FP16Weights() const { return fp16_weights_; }
inline DepthwiseConv2DTester& SamePadding() {
padding_ = ::tflite::Padding_SAME;
return *this;
@ -209,6 +216,7 @@ class DepthwiseConv2DTester {
int32_t stride_width_ = 1;
int32_t dilation_height_ = 1;
int32_t dilation_width_ = 1;
bool fp16_weights_ = false;
::tflite::Padding padding_ = ::tflite::Padding_VALID;
::tflite::ActivationFunctionType activation_ =
::tflite::ActivationFunctionType_NONE;

View File

@ -228,6 +228,29 @@ TEST(FullyConnected, 4DKeepDims) {
.Test(xnnpack_delegate.get());
}
TEST(FullyConnected, FP16Weights) {
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
TfLiteXNNPackDelegateDelete);
std::random_device random_device;
auto rng = std::mt19937(random_device());
auto batch_rng =
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
auto channels_rng =
std::bind(std::uniform_int_distribution<int32_t>(2, 9), std::ref(rng));
const auto batch = batch_rng();
const auto input_channels = channels_rng();
const auto output_channels = channels_rng();
FullyConnectedTester()
.InputShape({batch, input_channels})
.InputChannels(input_channels)
.OutputChannels(output_channels)
.FP16Weights()
.Test(xnnpack_delegate.get());
}
TEST(FullyConnected, ReluActivation) {
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),

View File

@ -23,6 +23,7 @@ limitations under the License.
#include <vector>
#include <gtest/gtest.h>
#include <fp16.h>
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
@ -109,98 +110,165 @@ void FullyConnectedTester::Test(TfLiteDelegate* delegate) const {
std::vector<char> FullyConnectedTester::CreateTfLiteModel() const {
std::random_device random_device;
auto rng = std::mt19937(random_device());
auto range_rng = std::bind(
std::uniform_real_distribution<float>(-25.0f, 25.0f), std::ref(rng));
flatbuffers::FlatBufferBuilder builder;
flatbuffers::Offset<OperatorCode> operator_code =
CreateOperatorCode(builder, BuiltinOperator_FULLY_CONNECTED);
std::vector<flatbuffers::Offset<OperatorCode>> operator_codes{
{CreateOperatorCode(builder, BuiltinOperator_FULLY_CONNECTED)}};
std::vector<flatbuffers::Offset<Operator>> operators;
std::vector<flatbuffers::Offset<Buffer>> buffers{
{CreateBuffer(builder, builder.CreateVector({}))}};
std::vector<float> filter_data(InputChannels() * OutputChannels());
std::vector<float> bias_data(OutputChannels());
if (FP16Weights()) {
operator_codes.emplace_back(
CreateOperatorCode(builder, BuiltinOperator_DEQUANTIZE));
for (int32_t oc = 0; oc < OutputChannels(); oc++) {
// Use the same range of all-positive or all-negative values to generate
// all filter & bias weights within the same channel, but different ranges
// for different output channels. This ensures that no catastrophic
// cancellation occur, but test covers both positive and negative inputs.
const float range = range_rng();
auto value_rng =
std::bind(std::uniform_real_distribution<float>(std::min(range, 0.0f),
std::max(range, 0.0f)),
std::ref(rng));
std::vector<uint16_t> filter_data(InputChannels() * OutputChannels());
std::vector<uint16_t> bias_data(OutputChannels());
bias_data[oc] = value_rng();
for (int32_t ic = 0; ic < InputChannels(); ic++) {
filter_data[oc * InputChannels() + ic] = value_rng();
for (int32_t oc = 0; oc < OutputChannels(); oc++) {
// Use the same range of all-positive or all-negative values to generate
// all filter & bias weights within the same channel, but different ranges
// for different output channels. This ensures that no catastrophic
// cancellation occur, but test covers both positive and negative inputs.
const float range = range_rng();
auto value_rng =
std::bind(fp16_ieee_from_fp32_value,
std::bind(std::uniform_real_distribution<float>(
std::min(range, 0.0f), std::max(range, 0.0f)),
std::ref(rng)));
bias_data[oc] = value_rng();
for (int32_t ic = 0; ic < InputChannels(); ic++) {
filter_data[oc * InputChannels() + ic] = value_rng();
}
}
}
std::array<flatbuffers::Offset<Buffer>, 3> buffers{{
CreateBuffer(builder, builder.CreateVector({})),
CreateBuffer(builder,
builder.CreateVector(
reinterpret_cast<const uint8_t*>(filter_data.data()),
sizeof(float) * filter_data.size())),
CreateBuffer(builder,
builder.CreateVector(
reinterpret_cast<const uint8_t*>(bias_data.data()),
sizeof(float) * bias_data.size())),
}};
buffers.emplace_back(CreateBuffer(
builder, builder.CreateVector(
reinterpret_cast<const uint8_t*>(filter_data.data()),
sizeof(uint16_t) * filter_data.size())));
buffers.emplace_back(CreateBuffer(
builder,
builder.CreateVector(reinterpret_cast<const uint8_t*>(bias_data.data()),
sizeof(uint16_t) * bias_data.size())));
const std::array<int32_t, 1> dequantize_filter_inputs{{0}};
const std::array<int32_t, 1> dequantize_filter_outputs{{3}};
operators.emplace_back(CreateOperator(
builder, /*opcode_index=*/1,
builder.CreateVector<int32_t>(dequantize_filter_inputs.data(),
dequantize_filter_inputs.size()),
builder.CreateVector<int32_t>(dequantize_filter_outputs.data(),
dequantize_filter_outputs.size())));
const std::array<int32_t, 1> dequantize_bias_inputs{{1}};
const std::array<int32_t, 1> dequantize_bias_outputs{{4}};
operators.emplace_back(CreateOperator(
builder, /*opcode_index=*/1,
builder.CreateVector<int32_t>(dequantize_bias_inputs.data(),
dequantize_bias_inputs.size()),
builder.CreateVector<int32_t>(dequantize_bias_outputs.data(),
dequantize_bias_outputs.size())));
} else {
std::vector<float> filter_data(InputChannels() * OutputChannels());
std::vector<float> bias_data(OutputChannels());
for (int32_t oc = 0; oc < OutputChannels(); oc++) {
// Use the same range of all-positive or all-negative values to generate
// all filter & bias weights within the same channel, but different ranges
// for different output channels. This ensures that no catastrophic
// cancellation occur, but test covers both positive and negative inputs.
const float range = range_rng();
auto value_rng =
std::bind(std::uniform_real_distribution<float>(
std::min(range, 0.0f), std::max(range, 0.0f)),
std::ref(rng));
bias_data[oc] = value_rng();
for (int32_t ic = 0; ic < InputChannels(); ic++) {
filter_data[oc * InputChannels() + ic] = value_rng();
}
}
buffers.emplace_back(CreateBuffer(
builder, builder.CreateVector(
reinterpret_cast<const uint8_t*>(filter_data.data()),
sizeof(float) * filter_data.size())));
buffers.emplace_back(CreateBuffer(
builder,
builder.CreateVector(reinterpret_cast<const uint8_t*>(bias_data.data()),
sizeof(float) * bias_data.size())));
}
const std::array<int32_t, 2> filter_shape(
{OutputChannels(), InputChannels()});
const std::array<int32_t, 1> bias_shape({OutputChannels()});
const std::vector<int32_t> output_shape = OutputShape();
const std::array<flatbuffers::Offset<Tensor>, 4> tensors{{
CreateTensor(builder,
builder.CreateVector<int32_t>(InputShape().data(),
InputShape().size()),
TensorType_FLOAT32),
CreateTensor(builder,
builder.CreateVector<int32_t>(filter_shape.data(),
filter_shape.size()),
TensorType_FLOAT32, /*buffer=*/1),
CreateTensor(
builder,
builder.CreateVector<int32_t>(bias_shape.data(), bias_shape.size()),
TensorType_FLOAT32, /*buffer=*/2),
CreateTensor(builder,
builder.CreateVector<int32_t>(output_shape.data(),
output_shape.size()),
TensorType_FLOAT32),
}};
std::vector<flatbuffers::Offset<Tensor>> tensors;
if (FP16Weights()) {
tensors.emplace_back(CreateTensor(
builder,
builder.CreateVector<int32_t>(filter_shape.data(), filter_shape.size()),
TensorType_FLOAT16, /*buffer=*/1));
tensors.emplace_back(CreateTensor(
builder,
builder.CreateVector<int32_t>(bias_shape.data(), bias_shape.size()),
TensorType_FLOAT16, /*buffer=*/2));
}
tensors.emplace_back(CreateTensor(
builder,
builder.CreateVector<int32_t>(InputShape().data(), InputShape().size()),
TensorType_FLOAT32));
tensors.emplace_back(CreateTensor(
builder,
builder.CreateVector<int32_t>(filter_shape.data(), filter_shape.size()),
TensorType_FLOAT32, /*buffer=*/FP16Weights() ? 0 : 1));
tensors.emplace_back(CreateTensor(
builder,
builder.CreateVector<int32_t>(bias_shape.data(), bias_shape.size()),
TensorType_FLOAT32, /*buffer=*/FP16Weights() ? 0 : 2));
tensors.emplace_back(CreateTensor(
builder,
builder.CreateVector<int32_t>(output_shape.data(), output_shape.size()),
TensorType_FLOAT32));
flatbuffers::Offset<FullyConnectedOptions> fully_connected_options =
CreateFullyConnectedOptions(builder, Activation(),
FullyConnectedOptionsWeightsFormat_DEFAULT,
KeepDims());
const std::array<int32_t, 3> op_inputs{{0, 1, 2}};
const std::array<int32_t, 1> op_outputs{{3}};
flatbuffers::Offset<Operator> op = CreateOperator(
const std::array<int32_t, 3> op_inputs{
{static_cast<int>(tensors.size()) - 4,
static_cast<int>(tensors.size()) - 3,
static_cast<int>(tensors.size()) - 2}};
const std::array<int32_t, 1> op_outputs{
{static_cast<int>(tensors.size()) - 1}};
operators.emplace_back(CreateOperator(
builder, /*opcode_index=*/0,
builder.CreateVector<int32_t>(op_inputs.data(), op_inputs.size()),
builder.CreateVector<int32_t>(op_outputs.data(), op_outputs.size()),
BuiltinOptions_FullyConnectedOptions, fully_connected_options.Union());
BuiltinOptions_FullyConnectedOptions, fully_connected_options.Union()));
const std::array<int32_t, 1> subgraph_inputs{{0}};
const std::array<int32_t, 1> subgraph_outputs{{3}};
const std::array<int32_t, 1> subgraph_inputs{
{static_cast<int>(tensors.size()) - 4}};
const std::array<int32_t, 1> subgraph_outputs{
{static_cast<int>(tensors.size()) - 1}};
flatbuffers::Offset<SubGraph> subgraph = CreateSubGraph(
builder, builder.CreateVector(tensors.data(), tensors.size()),
builder.CreateVector<int32_t>(subgraph_inputs.data(),
subgraph_inputs.size()),
builder.CreateVector<int32_t>(subgraph_outputs.data(),
subgraph_outputs.size()),
builder.CreateVector(&op, 1));
builder.CreateVector(operators.data(), operators.size()));
flatbuffers::Offset<flatbuffers::String> description =
builder.CreateString("Fully Connected model");
flatbuffers::Offset<Model> model_buffer = CreateModel(
builder, TFLITE_SCHEMA_VERSION, builder.CreateVector(&operator_code, 1),
builder, TFLITE_SCHEMA_VERSION,
builder.CreateVector(operator_codes.data(), operator_codes.size()),
builder.CreateVector(&subgraph, 1), description,
builder.CreateVector(buffers.data(), buffers.size()));

View File

@ -71,6 +71,13 @@ class FullyConnectedTester {
inline bool KeepDims() const { return keep_dims_; }
inline FullyConnectedTester& FP16Weights() {
fp16_weights_ = true;
return *this;
}
inline bool FP16Weights() const { return fp16_weights_; }
inline FullyConnectedTester& ReluActivation() {
activation_ = ::tflite::ActivationFunctionType_RELU;
return *this;
@ -102,6 +109,7 @@ class FullyConnectedTester {
int32_t input_channels_ = 1;
int32_t output_channels_ = 1;
bool keep_dims_ = false;
bool fp16_weights_ = false;
::tflite::ActivationFunctionType activation_ =
::tflite::ActivationFunctionType_NONE;
};

View File

@ -679,6 +679,35 @@ TEST(Mul, 2DByStatic0D) {
.Test(BuiltinOperator_MUL, xnnpack_delegate.get());
}
TEST(Mul, FP16Weights) {
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
TfLiteXNNPackDelegateDelete);
std::random_device random_device;
auto rng = std::mt19937(random_device());
auto shape_rng =
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
const auto batch = shape_rng();
const auto height = shape_rng();
const auto width = shape_rng();
const auto channels = shape_rng();
BinaryElementwiseTester()
.Input1Shape({batch, height, width, channels})
.Input2Shape({batch, height, width, channels})
.Input1Static(true)
.FP16Weights()
.Test(BuiltinOperator_MUL, xnnpack_delegate.get());
BinaryElementwiseTester()
.Input1Shape({batch, height, width, channels})
.Input2Shape({batch, height, width, channels})
.Input2Static(true)
.FP16Weights()
.Test(BuiltinOperator_MUL, xnnpack_delegate.get());
}
TEST(Mul, ReluActivation) {
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),

View File

@ -22,10 +22,12 @@ limitations under the License.
#include <limits>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <fp16.h>
#include <xnnpack.h>
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/builtin_op_data.h"
@ -39,6 +41,8 @@ namespace {
TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate);
class Delegate {
friend class Subgraph;
public:
explicit Delegate(const TfLiteXNNPackDelegateOptions* options) {
#if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__)
@ -49,9 +53,10 @@ class Delegate {
#endif
}
TfLiteIntArray* PrepareOpsToDelegate(TfLiteContext* context);
TfLiteDelegate* tflite_delegate() { return &delegate_; }
pthreadpool_t threadpool() {
pthreadpool_t threadpool() const {
#if defined(__EMSCRIPTEN__) && !defined(__EMSCRIPTEN_PTHREADS__)
return nullptr;
#else
@ -69,6 +74,17 @@ class Delegate {
kTfLiteDelegateFlagsNone, // .flags
};
// Unpacked data for quasi-static tensors, i.e. tensors produced by
// dequantizing or unpacking static buffers.
std::vector<char> static_unpacked_data_;
// Mapping from a tensor index for a quasi-static tensor to the offset to
// its unpacked data within static_unpacked_data_.
std::unordered_map<int, size_t> static_unpacked_data_map_;
// Set of indices of nodes which unpack static data, e.g. Dequantize
// operators which convert FP16 static weights to FP32. These nodes are simply
// ignored in the delegate implementation, because their outputs are
// pre-unpacked in DelegatePrepare.
std::unordered_set<int> static_unpack_nodes_;
#if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__)
// Thread pool with smart-pointer for lifetime management.
std::unique_ptr<pthreadpool, decltype(&pthreadpool_destroy)> threadpool_{
@ -80,7 +96,7 @@ class Subgraph {
public:
static Subgraph* Create(TfLiteContext* context,
const TfLiteDelegateParams* params,
pthreadpool_t threadpool) {
const Delegate* delegate) {
// Convert subgraph inputs and outputs to hash sets for faster lookup.
const std::unordered_set<int> inputs(
&params->input_tensors->data[0],
@ -113,11 +129,17 @@ class Subgraph {
// filtered out and removed later.
std::vector<int> tensors(context->tensors_size, -1);
for (int i = 0; i < params->nodes_to_replace->size; i++) {
const int node_index = params->nodes_to_replace->data[i];
if (delegate->static_unpack_nodes_.count(node_index)) {
// The node unpacks static input and can be skipped because its input
// was pre-unpacked in DelegatePrepare.
continue;
}
TfLiteNode* node = nullptr;
TfLiteRegistration* registration = nullptr;
if (context->GetNodeAndRegistration(context,
params->nodes_to_replace->data[i],
&node, &registration) != kTfLiteOk) {
if (context->GetNodeAndRegistration(context, node_index, &node,
&registration) != kTfLiteOk) {
return nullptr;
}
@ -164,6 +186,12 @@ class Subgraph {
const void* data = nullptr;
if (context->tensors[t].allocation_type == kTfLiteMmapRo) {
data = context->tensors[t].data.raw_const;
} else {
// Check for quasi-static data.
const auto it = delegate->static_unpacked_data_map_.find(t);
if (it != delegate->static_unpacked_data_map_.end()) {
data = delegate->static_unpacked_data_.data() + it->second;
}
}
if (inputs.count(t) != 0) {
flags |= XNN_VALUE_FLAG_EXTERNAL_INPUT;
@ -189,25 +217,38 @@ class Subgraph {
}
}
// Create a set of quasi-static tensors for VisitNode function
std::unordered_set<int> quasi_static_tensors;
for (const std::pair<const int, size_t>& entry :
delegate->static_unpacked_data_map_) {
quasi_static_tensors.insert(entry.first);
}
// Create XNNPACK nodes for TFLite delegate nodes
for (int i = 0; i < params->nodes_to_replace->size; i++) {
const int node_index = params->nodes_to_replace->data[i];
if (delegate->static_unpack_nodes_.count(node_index)) {
// The node unpacks static input and can be skipped because its input
// was pre-unpacked in DelegatePrepare.
continue;
}
TfLiteNode* node = nullptr;
TfLiteRegistration* registration = nullptr;
if (context->GetNodeAndRegistration(context,
params->nodes_to_replace->data[i],
&node, &registration) != kTfLiteOk) {
if (context->GetNodeAndRegistration(context, node_index, &node,
&registration) != kTfLiteOk) {
return nullptr;
}
if (VisitNode(subgraph.get(), context, registration, node, i,
xnnpack_tensors) != kTfLiteOk) {
if (VisitNode(subgraph.get(), context, registration, node, node_index,
quasi_static_tensors, xnnpack_tensors) != kTfLiteOk) {
return nullptr;
}
}
xnn_runtime_t runtime_ptr = nullptr;
status = xnn_create_runtime_v2(subgraph.get(), threadpool, /*flags=*/0,
&runtime_ptr);
status = xnn_create_runtime_v2(subgraph.get(), delegate->threadpool(),
/*flags=*/0, &runtime_ptr);
if (status != xnn_status_success) {
TF_LITE_KERNEL_LOG(context, "failed to create XNNPACK runtime");
return nullptr;
@ -707,10 +748,11 @@ class Subgraph {
return kTfLiteOk;
}
static TfLiteStatus VisitNode(xnn_subgraph_t subgraph, TfLiteContext* context,
TfLiteRegistration* registration,
TfLiteNode* node, int node_index,
const std::vector<uint32_t>& xnnpack_tensors) {
static TfLiteStatus VisitNode(
xnn_subgraph_t subgraph, TfLiteContext* context,
TfLiteRegistration* registration, TfLiteNode* node, int node_index,
const std::unordered_set<int>& quasi_static_tensors,
const std::vector<uint32_t>& xnnpack_tensors) {
// TFLite context used for logging purposes. When we create a new node
// (subgraph is non-null), logging context is the same as context, and error
// messages are passed to TFLite. When we detect supported operations
@ -738,7 +780,8 @@ class Subgraph {
static_cast<const TfLiteConvParams*>(node->builtin_data);
return VisitConv2DNode(subgraph, logging_context, node_index, node,
context->tensors, conv_params, xnnpack_tensors);
context->tensors, conv_params,
quasi_static_tensors, xnnpack_tensors);
}
case kTfLiteBuiltinDepthwiseConv2d: {
const TfLiteDepthwiseConvParams* dwconv_params =
@ -746,7 +789,7 @@ class Subgraph {
return VisitDepthwiseConv2DNode(subgraph, logging_context, node_index,
node, context->tensors, dwconv_params,
xnnpack_tensors);
quasi_static_tensors, xnnpack_tensors);
}
case kTfLiteBuiltinFullyConnected: {
const TfLiteFullyConnectedParams* fc_params =
@ -754,7 +797,7 @@ class Subgraph {
return VisitFullyConnectedNode(subgraph, logging_context, node_index,
node, context->tensors, fc_params,
xnnpack_tensors);
quasi_static_tensors, xnnpack_tensors);
}
case kTfLiteBuiltinHardSwish:
return VisitHardSwishNode(subgraph, logging_context, node_index, node,
@ -782,7 +825,8 @@ class Subgraph {
context->tensors, xnnpack_tensors);
case kTfLiteBuiltinPrelu:
return VisitPreluNode(subgraph, logging_context, node_index, node,
context->tensors, xnnpack_tensors);
context->tensors, quasi_static_tensors,
xnnpack_tensors);
case kTfLiteBuiltinRelu:
return VisitReluNode(
subgraph, logging_context, node_index, node, context->tensors, 0.0f,
@ -810,7 +854,7 @@ class Subgraph {
return VisitMediaPipeDeconvolutionNode(
subgraph, context, node_index, node, context->tensors,
&deconv_params, xnnpack_tensors);
&deconv_params, quasi_static_tensors, xnnpack_tensors);
} else if (strcmp(registration->custom_name,
"MaxPoolingWithArgmax2D") == 0) {
TfLitePoolParams pool_params = {kTfLitePaddingUnknown};
@ -948,6 +992,7 @@ class Subgraph {
xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
TfLiteNode* node, const TfLiteTensor* tensors,
const TfLiteConvParams* conv_params,
const std::unordered_set<int>& quasi_static_tensors,
const std::vector<uint32_t>& xnnpack_tensors) {
TF_LITE_ENSURE_STATUS(
CheckConvolutionParams(logging_context, conv_params, node_index));
@ -968,16 +1013,20 @@ class Subgraph {
logging_context, filter_tensor, node->inputs->data[1], node_index));
TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, filter_tensor, 4,
node->inputs->data[1]));
TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
logging_context, filter_tensor, node->inputs->data[1], node_index));
if (quasi_static_tensors.count(node->inputs->data[1]) == 0) {
TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
logging_context, filter_tensor, node->inputs->data[1], node_index));
}
const TfLiteTensor& bias_tensor = tensors[node->inputs->data[2]];
TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
logging_context, filter_tensor, node->inputs->data[2], node_index));
TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, bias_tensor, 1,
node->inputs->data[2]));
TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
logging_context, bias_tensor, node->inputs->data[2], node_index));
if (quasi_static_tensors.count(node->inputs->data[2]) == 0) {
TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
logging_context, bias_tensor, node->inputs->data[2], node_index));
}
const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
@ -1034,6 +1083,7 @@ class Subgraph {
xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
TfLiteNode* node, const TfLiteTensor* tensors,
const TfLiteDepthwiseConvParams* dwconv_params,
const std::unordered_set<int>& quasi_static_tensors,
const std::vector<uint32_t>& xnnpack_tensors) {
TF_LITE_ENSURE_STATUS(
CheckNumInputsAndOutputs(logging_context, node, 3, 1, node_index));
@ -1051,16 +1101,20 @@ class Subgraph {
logging_context, filter_tensor, node->inputs->data[1], node_index));
TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, filter_tensor, 4,
node->inputs->data[1]));
TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
logging_context, filter_tensor, node->inputs->data[1], node_index));
if (quasi_static_tensors.count(node->inputs->data[1]) == 0) {
TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
logging_context, filter_tensor, node->inputs->data[1], node_index));
}
const TfLiteTensor& bias_tensor = tensors[node->inputs->data[2]];
TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
logging_context, filter_tensor, node->inputs->data[2], node_index));
TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, bias_tensor, 1,
node->inputs->data[2]));
TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
logging_context, bias_tensor, node->inputs->data[2], node_index));
if (quasi_static_tensors.count(node->inputs->data[2]) == 0) {
TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
logging_context, bias_tensor, node->inputs->data[2], node_index));
}
const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
@ -1123,6 +1177,7 @@ class Subgraph {
xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
TfLiteNode* node, const TfLiteTensor* tensors,
const TfLiteFullyConnectedParams* fc_params,
const std::unordered_set<int>& quasi_static_tensors,
const std::vector<uint32_t>& xnnpack_tensors) {
TF_LITE_ENSURE_STATUS(
CheckFullyConnectedParams(logging_context, fc_params, node_index));
@ -1141,16 +1196,20 @@ class Subgraph {
logging_context, filter_tensor, node->inputs->data[1], node_index));
TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, filter_tensor, 2,
node->inputs->data[1]));
TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
logging_context, filter_tensor, node->inputs->data[1], node_index));
if (quasi_static_tensors.count(node->inputs->data[1]) == 0) {
TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
logging_context, filter_tensor, node->inputs->data[1], node_index));
}
const TfLiteTensor& bias_tensor = tensors[node->inputs->data[2]];
TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
logging_context, filter_tensor, node->inputs->data[2], node_index));
TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, bias_tensor, 1,
node->inputs->data[2]));
TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
logging_context, bias_tensor, node->inputs->data[2], node_index));
if (quasi_static_tensors.count(node->inputs->data[2]) == 0) {
TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
logging_context, bias_tensor, node->inputs->data[2], node_index));
}
const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
@ -1387,6 +1446,7 @@ class Subgraph {
xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
TfLiteNode* node, const TfLiteTensor* tensors,
const TfLiteTransposeConvParams* deconv_params,
const std::unordered_set<int>& quasi_static_tensors,
const std::vector<uint32_t>& xnnpack_tensors) {
TF_LITE_ENSURE_STATUS(
CheckNumInputsAndOutputs(logging_context, node, 3, 1, node_index));
@ -1404,16 +1464,20 @@ class Subgraph {
logging_context, filter_tensor, node->inputs->data[1], node_index));
TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, filter_tensor, 4,
node->inputs->data[1]));
TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
logging_context, filter_tensor, node->inputs->data[1], node_index));
if (quasi_static_tensors.count(node->inputs->data[1]) == 0) {
TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
logging_context, filter_tensor, node->inputs->data[1], node_index));
}
const TfLiteTensor& bias_tensor = tensors[node->inputs->data[2]];
TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
logging_context, filter_tensor, node->inputs->data[2], node_index));
TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, bias_tensor, 1,
node->inputs->data[2]));
TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
logging_context, bias_tensor, node->inputs->data[2], node_index));
if (quasi_static_tensors.count(node->inputs->data[2]) == 0) {
TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
logging_context, bias_tensor, node->inputs->data[2], node_index));
}
const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
@ -1735,6 +1799,7 @@ class Subgraph {
static TfLiteStatus VisitPreluNode(
xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
TfLiteNode* node, const TfLiteTensor* tensors,
const std::unordered_set<int>& quasi_static_tensors,
const std::vector<uint32_t>& xnnpack_tensors) {
TF_LITE_ENSURE_STATUS(
CheckNumInputsAndOutputs(logging_context, node, 2, 1, node_index));
@ -1752,8 +1817,10 @@ class Subgraph {
logging_context, slope_tensor, node->inputs->data[1], node_index));
TF_LITE_ENSURE_STATUS(CheckSlopeTensorShape(
logging_context, slope_tensor, node->inputs->data[1], node_index));
TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
logging_context, slope_tensor, node->inputs->data[1], node_index));
if (quasi_static_tensors.count(node->inputs->data[1]) == 0) {
TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
logging_context, slope_tensor, node->inputs->data[1], node_index));
}
const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
@ -1869,15 +1936,29 @@ class Subgraph {
bool first_run_{true};
};
TfLiteIntArray* GetOpsToReplace(TfLiteContext* context) {
TfLiteIntArray* Delegate::PrepareOpsToDelegate(TfLiteContext* context) {
// Clear previous data, in case the delegate is reused without re-creation.
static_unpacked_data_map_.clear();
static_unpacked_data_.clear();
static_unpack_nodes_.clear();
TfLiteIntArray* execution_plan = nullptr;
if (context->GetExecutionPlan(context, &execution_plan) != kTfLiteOk) {
TF_LITE_KERNEL_LOG(context, "Unable to get graph execution plan.");
return nullptr;
}
TfLiteIntArray* nodes_to_replace = TfLiteIntArrayCreate(execution_plan->size);
nodes_to_replace->size = 0;
// Mapping for quasi-static (unpacked from static) tensor index to the node
// index that produced it.
std::unordered_map<int, int> quasi_static_tensors_producers;
// Set of all quasi-static tensors in the execution plan.
std::unordered_set<int> quasi_static_tensors;
// Set of quasi-static tensors consumed by the delegated nodes.
std::unordered_set<int> quasi_static_tensors_to_unpack;
TfLiteIntArray* nodes_to_delegate =
TfLiteIntArrayCreate(execution_plan->size);
nodes_to_delegate->size = 0;
for (int i = 0; i < execution_plan->size; ++i) {
const int node_index = execution_plan->data[i];
@ -1892,15 +1973,142 @@ TfLiteIntArray* GetOpsToReplace(TfLiteContext* context) {
continue; // Soft error (skip this node).
}
if (registration->builtin_code == kTfLiteBuiltinDequantize &&
node->inputs->size == 1 && node->outputs->size == 1) {
const TfLiteTensor& input_tensor =
context->tensors[node->inputs->data[0]];
const TfLiteTensor& output_tensor =
context->tensors[node->outputs->data[0]];
if (input_tensor.allocation_type == kTfLiteMmapRo &&
input_tensor.type == kTfLiteFloat16 &&
output_tensor.type == kTfLiteFloat32) {
static_unpack_nodes_.insert(i);
quasi_static_tensors_producers[node->outputs->data[0]] = i;
quasi_static_tensors.insert(node->outputs->data[0]);
// Skip this node for now. If output of the node is consumed only by
// delegated nodes, it will be added to nodes_to_delegate in the end.
continue;
}
}
if (Subgraph::VisitNode(/*subgraph=*/nullptr, context, registration, node,
node_index, std::vector<uint32_t>()) != kTfLiteOk) {
node_index, quasi_static_tensors,
std::vector<uint32_t>()) != kTfLiteOk) {
// If a non-delegated node consumes output of a node that unpacks static
// data, that node shouldn't be delegated.
for (int j = 0; j < node->inputs->size; j++) {
const auto it =
quasi_static_tensors_producers.find(node->inputs->data[j]);
if (it != quasi_static_tensors_producers.end()) {
static_unpack_nodes_.erase(it->second);
}
}
// Non-delegatable node is not an error.
continue;
}
nodes_to_replace->data[nodes_to_replace->size++] = node_index;
for (int j = 0; j < node->inputs->size; j++) {
if (quasi_static_tensors.count(node->inputs->data[j]) != 0) {
quasi_static_tensors_to_unpack.insert(node->inputs->data[j]);
}
}
nodes_to_delegate->data[nodes_to_delegate->size++] = node_index;
}
// Unpack static data of all tensors
for (int t : quasi_static_tensors_to_unpack) {
const int producer_index = quasi_static_tensors_producers[t];
// Check if TFLite nodes can be delegated to XNNPACK
TfLiteNode* node = nullptr;
TfLiteRegistration* registration = nullptr;
if (context->GetNodeAndRegistration(context, producer_index, &node,
&registration) != kTfLiteOk) {
TF_LITE_KERNEL_LOG(context,
"Unable to get node and registration for node %d.",
producer_index);
TfLiteIntArrayFree(nodes_to_delegate);
return nullptr; // Hard error.
}
if (node->inputs->size != 1) {
TF_LITE_KERNEL_LOG(context, "unexpected number of inputs (%d) in node %d",
node->inputs->size, producer_index);
TfLiteIntArrayFree(nodes_to_delegate);
return nullptr; // Hard error.
}
if (node->outputs->size != 1) {
TF_LITE_KERNEL_LOG(context,
"unexpected number of outputs (%d) in node %d",
node->outputs->size, producer_index);
TfLiteIntArrayFree(nodes_to_delegate);
return nullptr; // Hard error.
}
const TfLiteTensor& input_tensor = context->tensors[node->inputs->data[0]];
if (input_tensor.allocation_type != kTfLiteMmapRo) {
TF_LITE_KERNEL_LOG(context,
"unexpected allocation type in tensor %d in node %d",
node->inputs->data[0], producer_index);
TfLiteIntArrayFree(nodes_to_delegate);
return nullptr; // Hard error.
}
const TfLiteTensor& output_tensor = context->tensors[t];
if (output_tensor.type != kTfLiteFloat32) {
TF_LITE_KERNEL_LOG(context,
"unexpected datatype (%s) in tensor %d in node %d",
TfLiteTypeGetName(output_tensor.type),
node->outputs->data[0], producer_index);
TfLiteIntArrayFree(nodes_to_delegate);
return nullptr; // Hard error.
}
const size_t tensor_elements = output_tensor.bytes / sizeof(float);
// Align to XNN_EXTRA_BYTES bytes
while (static_unpacked_data_.size() % XNN_EXTRA_BYTES != 0) {
static_unpacked_data_.push_back(0);
}
const size_t tensor_offset = static_unpacked_data_.size();
static_unpacked_data_.resize(tensor_offset + context->tensors[t].bytes);
float* unpacked_data =
reinterpret_cast<float*>(static_unpacked_data_.data() + tensor_offset);
switch (input_tensor.type) {
case kTfLiteFloat16: {
const uint16_t* packed_data =
static_cast<const uint16_t*>(input_tensor.data.data);
for (size_t i = 0; i < tensor_elements; i++) {
unpacked_data[i] = fp16_ieee_to_fp32_value(packed_data[i]);
}
break;
}
default:
TF_LITE_KERNEL_LOG(context,
"unexpected datatype (%s) in tensor %d in node %d",
TfLiteTypeGetName(output_tensor.type),
node->outputs->data[0], producer_index);
TfLiteIntArrayFree(nodes_to_delegate);
return nullptr; // Hard error.
}
static_unpacked_data_map_[t] = tensor_offset;
}
// Add nodes that unpack static data consumed by delegated nodes.
// Note: this is done purely to avoid the overhead of running these nodes
// again in TFLite interpreter which would allocate memory for their outputs.
// We mark them as delegated, but the delegate would simply ignore these nodes
// as the static weights are already unpacked.
for (int node_index : static_unpack_nodes_) {
nodes_to_delegate->data[nodes_to_delegate->size++] = node_index;
}
std::sort(&nodes_to_delegate->data[0],
&nodes_to_delegate->data[nodes_to_delegate->size]);
#ifdef XNNPACK_DELEGATE_TEST_MODE
// In the test mode build (used by unit tests), XNNPACK delegate claims to
// support all operators in the execution plan to disable fallback to the
@ -1908,24 +2116,22 @@ TfLiteIntArray* GetOpsToReplace(TfLiteContext* context) {
// not supported by the delegate, they will cause a failure in
// ::tflite::Interpreter::ModifyGraphWithDelegate, to be caught in the unit
// tests.
nodes_to_replace->size = execution_plan->size;
nodes_to_delegate->size = execution_plan->size;
std::copy(&execution_plan->data[0],
&execution_plan->data[execution_plan->size],
&nodes_to_replace->data[0]);
&nodes_to_delegate->data[0]);
#endif
return nodes_to_replace;
return nodes_to_delegate;
}
void* SubgraphInit(TfLiteContext* context, const char* buffer, size_t length) {
const TfLiteDelegateParams* params =
reinterpret_cast<const TfLiteDelegateParams*>(buffer);
pthreadpool_t threadpool =
static_cast<::tflite::xnnpack::Delegate*>(params->delegate->data_)
->threadpool();
return static_cast<void*>(Subgraph::Create(context, params, threadpool));
return static_cast<void*>(Subgraph::Create(
context, params,
static_cast<::tflite::xnnpack::Delegate*>(params->delegate->data_)));
}
TfLiteStatus SubgraphPrepare(TfLiteContext* context, TfLiteNode* node) {
@ -1962,7 +2168,9 @@ const TfLiteRegistration kSubgraphRegistration = {
};
TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
TfLiteIntArray* ops_to_replace =
static_cast<::tflite::xnnpack::Delegate*>(delegate->data_)
->PrepareOpsToDelegate(context);
const TfLiteStatus status = context->ReplaceNodeSubsetsWithDelegateKernels(
context, kSubgraphRegistration, ops_to_replace, delegate);
TfLiteIntArrayFree(ops_to_replace);