From a7048d89a11f7f7ef6234ca0d01b341b1e5780f7 Mon Sep 17 00:00:00 2001 From: Marat Dukhan Date: Wed, 27 May 2020 18:50:31 -0700 Subject: [PATCH] Support models with FP16 weights in XNNPACK delegate PiperOrigin-RevId: 313505742 Change-Id: Id21f7528741073e93a7132d529c3cd79957a73fb --- tensorflow/lite/delegates/xnnpack/BUILD | 6 + tensorflow/lite/delegates/xnnpack/add_test.cc | 29 ++ .../xnnpack/binary_elementwise_tester.cc | 137 +++++--- .../xnnpack/binary_elementwise_tester.h | 8 + .../lite/delegates/xnnpack/conv_2d_test.cc | 241 ++++++++++---- .../xnnpack/depthwise_conv_2d_test.cc | 31 ++ .../xnnpack/depthwise_conv_2d_tester.cc | 211 ++++++++---- .../xnnpack/depthwise_conv_2d_tester.h | 8 + .../delegates/xnnpack/fully_connected_test.cc | 23 ++ .../xnnpack/fully_connected_tester.cc | 180 ++++++---- .../xnnpack/fully_connected_tester.h | 8 + tensorflow/lite/delegates/xnnpack/mul_test.cc | 29 ++ .../delegates/xnnpack/xnnpack_delegate.cc | 314 +++++++++++++++--- 13 files changed, 942 insertions(+), 283 deletions(-) diff --git a/tensorflow/lite/delegates/xnnpack/BUILD b/tensorflow/lite/delegates/xnnpack/BUILD index 1cdba72b615..df70a314308 100644 --- a/tensorflow/lite/delegates/xnnpack/BUILD +++ b/tensorflow/lite/delegates/xnnpack/BUILD @@ -24,6 +24,7 @@ cc_library( "//tensorflow/lite:util", "//tensorflow/lite/c:common", "//tensorflow/lite/schema:schema_fbs", + "@FP16", "@XNNPACK", ], ) @@ -39,6 +40,7 @@ cc_library( "//tensorflow/lite:util", "//tensorflow/lite/c:common", "//tensorflow/lite/schema:schema_fbs", + "@FP16", "@XNNPACK", ], ) @@ -56,6 +58,7 @@ cc_library( "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/schema:schema_fbs", + "@FP16", "@com_google_googletest//:gtest", "@flatbuffers", ], @@ -72,6 +75,7 @@ cc_library( "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/schema:schema_fbs", + "@FP16", "@com_google_googletest//:gtest", "@flatbuffers", ], @@ -88,6 +92,7 @@ cc_library( "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/schema:schema_fbs", + "@FP16", "@com_google_googletest//:gtest", "@flatbuffers", ], @@ -215,6 +220,7 @@ cc_test( "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/schema:schema_fbs", + "@FP16", "@com_google_googletest//:gtest", "@flatbuffers", ], diff --git a/tensorflow/lite/delegates/xnnpack/add_test.cc b/tensorflow/lite/delegates/xnnpack/add_test.cc index dd2857e01ce..6bc8f8d6bca 100644 --- a/tensorflow/lite/delegates/xnnpack/add_test.cc +++ b/tensorflow/lite/delegates/xnnpack/add_test.cc @@ -679,6 +679,35 @@ TEST(Add, 2DByStatic0D) { .Test(BuiltinOperator_ADD, xnnpack_delegate.get()); } +TEST(Add, FP16Weights) { + std::unique_ptr + xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), + TfLiteXNNPackDelegateDelete); + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto shape_rng = + std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); + const auto batch = shape_rng(); + const auto height = shape_rng(); + const auto width = shape_rng(); + const auto channels = shape_rng(); + + BinaryElementwiseTester() + .Input1Shape({batch, height, width, channels}) + .Input2Shape({batch, height, width, channels}) + .Input1Static(true) + .FP16Weights() + .Test(BuiltinOperator_ADD, xnnpack_delegate.get()); + + BinaryElementwiseTester() + .Input1Shape({batch, height, width, channels}) + .Input2Shape({batch, height, width, channels}) + .Input2Static(true) + 
.FP16Weights() + .Test(BuiltinOperator_ADD, xnnpack_delegate.get()); +} + TEST(Add, ReluActivation) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), diff --git a/tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.cc b/tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.cc index e846cbeffe3..ad5b197d6fa 100644 --- a/tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" @@ -62,6 +63,9 @@ void BinaryElementwiseTester::Test(tflite::BuiltinOperator binary_op, if (Input1Static()) { ASSERT_FALSE(Input2Static()); } + if (FP16Weights()) { + ASSERT_TRUE(Input1Static() || Input2Static()); + } std::random_device random_device; auto rng = std::mt19937(random_device()); @@ -180,8 +184,12 @@ std::vector BinaryElementwiseTester::CreateTfLiteModel( auto input2_rng = std::bind(input2_distribution, std::ref(rng)); flatbuffers::FlatBufferBuilder builder; - flatbuffers::Offset operator_code = - CreateOperatorCode(builder, binary_op); + std::vector> operator_codes{ + {CreateOperatorCode(builder, binary_op)}}; + if (FP16Weights()) { + operator_codes.emplace_back( + CreateOperatorCode(builder, BuiltinOperator_DEQUANTIZE)); + } std::vector> buffers{{ CreateBuffer(builder, builder.CreateVector({})), @@ -189,43 +197,89 @@ std::vector BinaryElementwiseTester::CreateTfLiteModel( int32_t input1_buffer = 0; if (Input1Static()) { - std::vector input1_data(ComputeSize(Input1Shape())); - std::generate(input1_data.begin(), input1_data.end(), input1_rng); + if (FP16Weights()) { + std::vector input1_data(ComputeSize(Input1Shape())); + std::generate(input1_data.begin(), input1_data.end(), + std::bind(fp16_ieee_from_fp32_value, input1_rng)); - input1_buffer = buffers.size(); - buffers.push_back(CreateBuffer( - builder, builder.CreateVector( - reinterpret_cast(input1_data.data()), - sizeof(float) * input1_data.size()))); + buffers.push_back(CreateBuffer( + builder, builder.CreateVector( + reinterpret_cast(input1_data.data()), + sizeof(uint16_t) * input1_data.size()))); + } else { + std::vector input1_data(ComputeSize(Input1Shape())); + std::generate(input1_data.begin(), input1_data.end(), input1_rng); + + input1_buffer = buffers.size(); + buffers.push_back(CreateBuffer( + builder, builder.CreateVector( + reinterpret_cast(input1_data.data()), + sizeof(float) * input1_data.size()))); + } } int32_t input2_buffer = 0; if (Input2Static()) { - std::vector input2_data(ComputeSize(Input2Shape())); - std::generate(input2_data.begin(), input2_data.end(), input2_rng); + if (FP16Weights()) { + std::vector input2_data(ComputeSize(Input2Shape())); + std::generate(input2_data.begin(), input2_data.end(), + std::bind(fp16_ieee_from_fp32_value, input1_rng)); - input2_buffer = buffers.size(); - buffers.push_back(CreateBuffer( - builder, builder.CreateVector( - reinterpret_cast(input2_data.data()), - sizeof(float) * input2_data.size()))); + buffers.push_back(CreateBuffer( + builder, builder.CreateVector( + reinterpret_cast(input2_data.data()), + sizeof(uint16_t) * input2_data.size()))); + } else { + std::vector input2_data(ComputeSize(Input2Shape())); + std::generate(input2_data.begin(), input2_data.end(), input2_rng); + + input2_buffer = buffers.size(); + buffers.push_back(CreateBuffer( + builder, 
builder.CreateVector( + reinterpret_cast(input2_data.data()), + sizeof(float) * input2_data.size()))); + } } const std::vector output_shape = OutputShape(); - const std::array, 3> tensors{{ - CreateTensor(builder, - builder.CreateVector(Input1Shape().data(), - Input1Shape().size()), - TensorType_FLOAT32, input1_buffer), - CreateTensor(builder, - builder.CreateVector(Input2Shape().data(), - Input2Shape().size()), - TensorType_FLOAT32, input2_buffer), - CreateTensor(builder, - builder.CreateVector(output_shape.data(), - output_shape.size()), - TensorType_FLOAT32), - }}; + std::vector> tensors; + std::vector> operators; + if (FP16Weights() && Input1Static()) { + tensors.emplace_back( + CreateTensor(builder, + builder.CreateVector(Input1Shape().data(), + Input1Shape().size()), + TensorType_FLOAT16, 1)); + } + if (FP16Weights() && Input2Static()) { + tensors.emplace_back( + CreateTensor(builder, + builder.CreateVector(Input2Shape().data(), + Input2Shape().size()), + TensorType_FLOAT16, 1)); + } + if (FP16Weights()) { + const std::array dequantize_inputs{{0}}; + const std::array dequantize_outputs{{Input1Static() ? 1 : 2}}; + operators.emplace_back(CreateOperator( + builder, /*opcode_index=*/1, + builder.CreateVector(dequantize_inputs.data(), + dequantize_inputs.size()), + builder.CreateVector(dequantize_outputs.data(), + dequantize_outputs.size()))); + } + tensors.emplace_back(CreateTensor( + builder, + builder.CreateVector(Input1Shape().data(), Input1Shape().size()), + TensorType_FLOAT32, input1_buffer)); + tensors.emplace_back(CreateTensor( + builder, + builder.CreateVector(Input2Shape().data(), Input2Shape().size()), + TensorType_FLOAT32, input2_buffer)); + tensors.emplace_back(CreateTensor( + builder, + builder.CreateVector(output_shape.data(), output_shape.size()), + TensorType_FLOAT32)); tflite::BuiltinOptions builtin_options_type = tflite::BuiltinOptions_NONE; flatbuffers::Offset builtin_options = 0; @@ -250,35 +304,40 @@ std::vector BinaryElementwiseTester::CreateTfLiteModel( EXPECT_EQ(Activation(), ActivationFunctionType_NONE); } - const std::array op_inputs{{0, 1}}; - const std::array op_outputs{{2}}; - flatbuffers::Offset op = CreateOperator( + const std::array op_inputs{ + {static_cast(tensors.size()) - 3, + static_cast(tensors.size()) - 2}}; + const std::array op_outputs{ + {static_cast(tensors.size()) - 1}}; + operators.emplace_back(CreateOperator( builder, /*opcode_index=*/0, builder.CreateVector(op_inputs.data(), op_inputs.size()), builder.CreateVector(op_outputs.data(), op_outputs.size()), - builtin_options_type, builtin_options); + builtin_options_type, builtin_options)); std::vector subgraph_inputs; if (!Input1Static()) { - subgraph_inputs.push_back(0); + subgraph_inputs.push_back(tensors.size() - 3); } if (!Input2Static()) { - subgraph_inputs.push_back(1); + subgraph_inputs.push_back(tensors.size() - 2); } - const std::array subgraph_outputs{{2}}; + const std::array subgraph_outputs{ + {static_cast(tensors.size()) - 1}}; flatbuffers::Offset subgraph = CreateSubGraph( builder, builder.CreateVector(tensors.data(), tensors.size()), builder.CreateVector(subgraph_inputs.data(), subgraph_inputs.size()), builder.CreateVector(subgraph_outputs.data(), subgraph_outputs.size()), - builder.CreateVector(&op, 1)); + builder.CreateVector(operators.data(), operators.size())); flatbuffers::Offset description = builder.CreateString("Binary operator model"); flatbuffers::Offset model_buffer = CreateModel( - builder, TFLITE_SCHEMA_VERSION, builder.CreateVector(&operator_code, 1), + builder, 
TFLITE_SCHEMA_VERSION, + builder.CreateVector(operator_codes.data(), operator_codes.size()), builder.CreateVector(&subgraph, 1), description, builder.CreateVector(buffers.data(), buffers.size())); diff --git a/tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.h b/tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.h index 15c99c3148d..a0c2440f59a 100644 --- a/tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.h +++ b/tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.h @@ -74,6 +74,13 @@ class BinaryElementwiseTester { inline bool Input2Static() const { return input2_static_; } + inline BinaryElementwiseTester& FP16Weights() { + fp16_weights_ = true; + return *this; + } + + inline bool FP16Weights() const { return fp16_weights_; } + inline BinaryElementwiseTester& ReluActivation() { activation_ = ::tflite::ActivationFunctionType_RELU; return *this; @@ -114,6 +121,7 @@ class BinaryElementwiseTester { std::vector input2_shape_; bool input1_static_ = false; bool input2_static_ = false; + bool fp16_weights_ = false; ::tflite::ActivationFunctionType activation_ = ::tflite::ActivationFunctionType_NONE; }; diff --git a/tensorflow/lite/delegates/xnnpack/conv_2d_test.cc b/tensorflow/lite/delegates/xnnpack/conv_2d_test.cc index 95a358d1b9c..a8c6a1956bc 100644 --- a/tensorflow/lite/delegates/xnnpack/conv_2d_test.cc +++ b/tensorflow/lite/delegates/xnnpack/conv_2d_test.cc @@ -13,12 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include #include #include +#include #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" #include "tensorflow/lite/interpreter.h" @@ -146,6 +148,13 @@ class Conv2DTester { int32_t DilationWidth() const { return dilation_width_; } + inline Conv2DTester& FP16Weights() { + fp16_weights_ = true; + return *this; + } + + inline bool FP16Weights() const { return fp16_weights_; } + Conv2DTester& SamePadding(bool same_padding) { same_padding_ = same_padding; return *this; @@ -154,11 +163,7 @@ class Conv2DTester { bool SamePadding() const { return same_padding_; } void Test(TfLiteDelegate* delegate) const { - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto f32rng = std::bind(std::uniform_real_distribution(), rng); - - std::vector buffer = CreateTfLiteModel(std::ref(f32rng)); + std::vector buffer = CreateTfLiteModel(); const Model* model = GetModel(buffer.data()); std::unique_ptr delegate_interpreter; @@ -187,6 +192,10 @@ class Conv2DTester { ASSERT_EQ(delegate_interpreter->ModifyGraphWithDelegate(delegate), kTfLiteOk); + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto f32rng = std::bind(std::uniform_real_distribution(), rng); + float* default_input_data = default_interpreter->typed_tensor( default_interpreter->inputs()[0]); std::generate(default_input_data, @@ -219,82 +228,149 @@ class Conv2DTester { } private: - std::vector CreateTfLiteModel(std::function f32rng) const { + std::vector CreateTfLiteModel() const { + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto f32rng = std::bind(std::uniform_real_distribution(), rng); + flatbuffers::FlatBufferBuilder builder; - flatbuffers::Offset operator_code = - CreateOperatorCode(builder, BuiltinOperator_CONV_2D, 0); + std::vector> operator_codes{ + 
{CreateOperatorCode(builder, BuiltinOperator_CONV_2D, 0)}}; + std::vector> operators; + std::vector> buffers{ + {CreateBuffer(builder, builder.CreateVector({}))}}; + + if (FP16Weights()) { + operator_codes.emplace_back( + CreateOperatorCode(builder, BuiltinOperator_DEQUANTIZE)); + + auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng); + + std::vector filter_data(OutputChannels() * KernelHeight() * + KernelWidth() * InputChannels()); + std::vector bias_data(OutputChannels()); + + std::generate(filter_data.begin(), filter_data.end(), f16rng); + std::generate(bias_data.begin(), bias_data.end(), f16rng); + + buffers.emplace_back(CreateBuffer( + builder, builder.CreateVector( + reinterpret_cast(filter_data.data()), + sizeof(uint16_t) * filter_data.size()))); + buffers.emplace_back(CreateBuffer( + builder, builder.CreateVector( + reinterpret_cast(bias_data.data()), + sizeof(uint16_t) * bias_data.size()))); + + const std::array dequantize_filter_inputs{{0}}; + const std::array dequantize_filter_outputs{{3}}; + operators.emplace_back(CreateOperator( + builder, /*opcode_index=*/1, + builder.CreateVector(dequantize_filter_inputs.data(), + dequantize_filter_inputs.size()), + builder.CreateVector(dequantize_filter_outputs.data(), + dequantize_filter_outputs.size()))); + const std::array dequantize_bias_inputs{{1}}; + const std::array dequantize_bias_outputs{{4}}; + operators.emplace_back(CreateOperator( + builder, /*opcode_index=*/1, + builder.CreateVector(dequantize_bias_inputs.data(), + dequantize_bias_inputs.size()), + builder.CreateVector(dequantize_bias_outputs.data(), + dequantize_bias_outputs.size()))); + } else { + std::vector filter_data(OutputChannels() * KernelHeight() * + KernelWidth() * InputChannels()); + std::vector bias_data(OutputChannels()); + + std::generate(filter_data.begin(), filter_data.end(), f32rng); + std::generate(bias_data.begin(), bias_data.end(), f32rng); + + buffers.emplace_back(CreateBuffer( + builder, builder.CreateVector( + reinterpret_cast(filter_data.data()), + sizeof(float) * filter_data.size()))); + buffers.emplace_back(CreateBuffer( + builder, builder.CreateVector( + reinterpret_cast(bias_data.data()), + sizeof(float) * bias_data.size()))); + } + + const std::array input_shape{ + {BatchSize(), InputHeight(), InputWidth(), InputChannels()}}; + const std::array output_shape{ + {BatchSize(), OutputHeight(), OutputWidth(), OutputChannels()}}; + const std::array filter_shape{ + {OutputChannels(), KernelHeight(), KernelWidth(), InputChannels()}}; + const std::array bias_shape{{OutputChannels()}}; + + std::vector> tensors; + if (FP16Weights()) { + tensors.emplace_back( + CreateTensor(builder, + builder.CreateVector(filter_shape.data(), + filter_shape.size()), + TensorType_FLOAT16, /*buffer=*/1)); + tensors.emplace_back(CreateTensor( + builder, + builder.CreateVector(bias_shape.data(), bias_shape.size()), + TensorType_FLOAT16, /*buffer=*/2)); + } + tensors.emplace_back(CreateTensor( + builder, + builder.CreateVector(input_shape.data(), input_shape.size()), + TensorType_FLOAT32)); + tensors.emplace_back(CreateTensor( + builder, + builder.CreateVector(filter_shape.data(), filter_shape.size()), + TensorType_FLOAT32, /*buffer=*/FP16Weights() ? 0 : 1)); + tensors.emplace_back(CreateTensor( + builder, + builder.CreateVector(bias_shape.data(), bias_shape.size()), + TensorType_FLOAT32, /*buffer=*/FP16Weights() ? 
0 : 2)); + tensors.emplace_back(CreateTensor( + builder, + builder.CreateVector(output_shape.data(), output_shape.size()), + TensorType_FLOAT32)); + + const std::array op_inputs{ + {static_cast(tensors.size()) - 4, + static_cast(tensors.size()) - 3, + static_cast(tensors.size()) - 2}}; + const std::array op_outputs{ + {static_cast(tensors.size()) - 1}}; flatbuffers::Offset conv2d_options = CreateConv2DOptions( builder, SamePadding() ? tflite::Padding_SAME : tflite::Padding_VALID, StrideWidth(), StrideHeight(), ActivationFunctionType_NONE, DilationWidth(), DilationHeight()); - std::vector filter_data(OutputChannels() * KernelHeight() * - KernelWidth() * InputChannels()); - std::vector bias_data(OutputChannels()); + operators.emplace_back(CreateOperator( + builder, /*opcode_index=*/0, + builder.CreateVector(op_inputs.data(), op_inputs.size()), + builder.CreateVector(op_outputs.data(), op_outputs.size()), + BuiltinOptions_Conv2DOptions, conv2d_options.Union())); - std::generate(filter_data.begin(), filter_data.end(), f32rng); - std::generate(bias_data.begin(), bias_data.end(), f32rng); - - flatbuffers::Offset buffers[3] = { - CreateBuffer(builder, builder.CreateVector({})), - CreateBuffer(builder, - builder.CreateVector( - reinterpret_cast(filter_data.data()), - sizeof(float) * filter_data.size())), - CreateBuffer(builder, - builder.CreateVector( - reinterpret_cast(bias_data.data()), - sizeof(float) * bias_data.size())), - }; - - const int32_t input_shape[4] = {BatchSize(), InputHeight(), InputWidth(), - InputChannels()}; - const int32_t output_shape[4] = {BatchSize(), OutputHeight(), OutputWidth(), - OutputChannels()}; - const int32_t filter_shape[4] = {OutputChannels(), KernelHeight(), - KernelWidth(), InputChannels()}; - const int32_t bias_shape[1] = {OutputChannels()}; - - flatbuffers::Offset tensors[4] = { - CreateTensor(builder, builder.CreateVector(input_shape, 4), - TensorType_FLOAT32, /*buffer=*/0, - builder.CreateString("X")), - CreateTensor(builder, builder.CreateVector(filter_shape, 4), - TensorType_FLOAT32, /*buffer=*/1, - builder.CreateString("W")), - CreateTensor(builder, builder.CreateVector(bias_shape, 1), - TensorType_FLOAT32, /*buffer=*/2, - builder.CreateString("b")), - CreateTensor(builder, builder.CreateVector(output_shape, 4), - TensorType_FLOAT32, /*buffer=*/0, - builder.CreateString("Y")), - }; - - const int32_t op_inputs[3] = {0, 1, 2}; - const int32_t op_outputs[1] = {3}; - - flatbuffers::Offset op = - CreateOperator(builder, /*opcode_index=*/0, - builder.CreateVector(op_inputs, 3), - builder.CreateVector(op_outputs, 1), - BuiltinOptions_Conv2DOptions, conv2d_options.Union()); - - int32_t subgraph_inputs[1] = {0}; - int32_t subgraph_outputs[1] = {3}; - flatbuffers::Offset subgraph = - CreateSubGraph(builder, builder.CreateVector(tensors, 4), - builder.CreateVector(subgraph_inputs, 1), - builder.CreateVector(subgraph_outputs, 1), - builder.CreateVector(&op, 1), /*name=*/0); + const std::array subgraph_inputs{ + {static_cast(tensors.size()) - 4}}; + const std::array subgraph_outputs{ + {static_cast(tensors.size()) - 1}}; + flatbuffers::Offset subgraph = CreateSubGraph( + builder, builder.CreateVector(tensors.data(), tensors.size()), + builder.CreateVector(subgraph_inputs.data(), + subgraph_inputs.size()), + builder.CreateVector(subgraph_outputs.data(), + subgraph_outputs.size()), + builder.CreateVector(operators.data(), operators.size())); flatbuffers::Offset description = builder.CreateString("Conv2D model"); flatbuffers::Offset model_buffer = CreateModel( - builder, 
TFLITE_SCHEMA_VERSION, builder.CreateVector(&operator_code, 1), + builder, TFLITE_SCHEMA_VERSION, + builder.CreateVector(operator_codes.data(), operator_codes.size()), builder.CreateVector(&subgraph, 1), description, - builder.CreateVector(buffers, 3)); + builder.CreateVector(buffers.data(), buffers.size())); builder.Finish(model_buffer); @@ -313,6 +389,7 @@ class Conv2DTester { int32_t stride_width_ = 1; int32_t dilation_height_ = 1; int32_t dilation_width_ = 1; + bool fp16_weights_ = false; bool same_padding_ = true; }; @@ -506,5 +583,35 @@ TEST(Conv2D, DilationWithValidPadding) { .Test(xnnpack_delegate.get()); } +TEST(Conv2D, FP16Weights) { + std::unique_ptr + xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), + TfLiteXNNPackDelegateDelete); + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto input_rng = + std::bind(std::uniform_int_distribution(10, 25), std::ref(rng)); + auto kernel_rng = + std::bind(std::uniform_int_distribution(3, 5), std::ref(rng)); + auto stride_rng = + std::bind(std::uniform_int_distribution(2, 3), std::ref(rng)); + auto channel_rng = + std::bind(std::uniform_int_distribution(1, 16), std::ref(rng)); + + Conv2DTester() + .InputHeight(input_rng()) + .InputWidth(input_rng()) + .InputChannels(channel_rng()) + .OutputChannels(channel_rng()) + .KernelHeight(kernel_rng()) + .KernelWidth(kernel_rng()) + .StrideHeight(stride_rng()) + .StrideWidth(stride_rng()) + .SamePadding(true) + .FP16Weights() + .Test(xnnpack_delegate.get()); +} + } // namespace xnnpack } // namespace tflite diff --git a/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_test.cc b/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_test.cc index fd82e4fd83f..c9d274cbe01 100644 --- a/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_test.cc +++ b/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_test.cc @@ -371,6 +371,37 @@ TEST(DepthwiseConv2D, DepthMultiplier) { .Test(xnnpack_delegate.get()); } +TEST(DepthwiseConv2D, FP16Weights) { + std::unique_ptr + xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), + TfLiteXNNPackDelegateDelete); + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto batch_rng = + std::bind(std::uniform_int_distribution(2, 4), std::ref(rng)); + auto input_rng = + std::bind(std::uniform_int_distribution(10, 25), std::ref(rng)); + auto kernel_rng = + std::bind(std::uniform_int_distribution(3, 5), std::ref(rng)); + auto stride_rng = + std::bind(std::uniform_int_distribution(2, 3), std::ref(rng)); + auto channel_rng = + std::bind(std::uniform_int_distribution(3, 32), std::ref(rng)); + + DepthwiseConv2DTester() + .BatchSize(batch_rng()) + .InputHeight(input_rng()) + .InputWidth(input_rng()) + .InputChannels(channel_rng()) + .KernelHeight(kernel_rng()) + .KernelWidth(kernel_rng()) + .StrideHeight(stride_rng()) + .StrideWidth(stride_rng()) + .FP16Weights() + .Test(xnnpack_delegate.get()); +} + TEST(DepthwiseConv2D, ReluActivation) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), diff --git a/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.cc b/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.cc index b6d1dfec69b..9b6749e42f6 100644 --- a/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include #include +#include #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" @@ -107,56 +108,110 @@ void DepthwiseConv2DTester::Test(TfLiteDelegate* delegate) const { } std::vector DepthwiseConv2DTester::CreateTfLiteModel() const { - flatbuffers::FlatBufferBuilder builder; - flatbuffers::Offset operator_code = - CreateOperatorCode(builder, BuiltinOperator_DEPTHWISE_CONV_2D); - - flatbuffers::Offset depthwise_conv2d_options = - CreateDepthwiseConv2DOptions( - builder, Padding(), StrideWidth(), StrideHeight(), DepthMultiplier(), - Activation(), DilationWidth(), DilationHeight()); - - std::vector filter_data(KernelHeight() * KernelWidth() * - OutputChannels()); - std::vector bias_data(OutputChannels()); - std::random_device random_device; auto rng = std::mt19937(random_device()); auto range_rng = std::bind( std::uniform_real_distribution(-25.0f, 25.0f), std::ref(rng)); - for (int32_t ic = 0; ic < InputChannels(); ic++) { - // Use the same range of all-positive or all-negative values to generate - // all pixels within the same batch index & channel, but different ranges - // for different channels or batches. This ensures that no catastrophic - // cancellation occur, but test covers both positive and negative inputs. - const float range = range_rng(); - auto value_rng = - std::bind(std::uniform_real_distribution(std::min(range, 0.0f), - std::max(range, 0.0f)), - std::ref(rng)); - for (int32_t m = 0; m < DepthMultiplier(); m++) { - const int32_t oc = ic * DepthMultiplier() + m; - bias_data[oc] = value_rng(); - for (int32_t y = 0; y < KernelHeight(); y++) { - for (int32_t x = 0; x < KernelWidth(); x++) { - const int32_t index = (y * KernelWidth() + x) * OutputChannels() + oc; - filter_data[index] = value_rng(); + + flatbuffers::FlatBufferBuilder builder; + std::vector> operator_codes{ + {CreateOperatorCode(builder, BuiltinOperator_DEPTHWISE_CONV_2D)}}; + std::vector> operators; + std::vector> buffers{ + {CreateBuffer(builder, builder.CreateVector({}))}}; + + if (FP16Weights()) { + operator_codes.emplace_back( + CreateOperatorCode(builder, BuiltinOperator_DEQUANTIZE)); + + std::vector filter_data(KernelHeight() * KernelWidth() * + OutputChannels()); + std::vector bias_data(OutputChannels()); + for (int32_t ic = 0; ic < InputChannels(); ic++) { + // Use the same range of all-positive or all-negative values to generate + // all pixels within the same batch index & channel, but different ranges + // for different channels or batches. This ensures that no catastrophic + // cancellation occur, but test covers both positive and negative inputs. 
+ const float range = range_rng(); + auto value_rng = + std::bind(fp16_ieee_from_fp32_value, + std::bind(std::uniform_real_distribution( + std::min(range, 0.0f), std::max(range, 0.0f)), + std::ref(rng))); + for (int32_t m = 0; m < DepthMultiplier(); m++) { + const int32_t oc = ic * DepthMultiplier() + m; + bias_data[oc] = value_rng(); + for (int32_t y = 0; y < KernelHeight(); y++) { + for (int32_t x = 0; x < KernelWidth(); x++) { + const int32_t index = + (y * KernelWidth() + x) * OutputChannels() + oc; + filter_data[index] = value_rng(); + } } } } - } - const std::array, 3> buffers{{ - CreateBuffer(builder, builder.CreateVector({})), - CreateBuffer(builder, - builder.CreateVector( - reinterpret_cast(filter_data.data()), - sizeof(float) * filter_data.size())), - CreateBuffer(builder, - builder.CreateVector( - reinterpret_cast(bias_data.data()), - sizeof(float) * bias_data.size())), - }}; + buffers.emplace_back(CreateBuffer( + builder, builder.CreateVector( + reinterpret_cast(filter_data.data()), + sizeof(uint16_t) * filter_data.size()))); + buffers.emplace_back(CreateBuffer( + builder, + builder.CreateVector(reinterpret_cast(bias_data.data()), + sizeof(uint16_t) * bias_data.size()))); + + const std::array dequantize_filter_inputs{{0}}; + const std::array dequantize_filter_outputs{{3}}; + operators.emplace_back(CreateOperator( + builder, /*opcode_index=*/1, + builder.CreateVector(dequantize_filter_inputs.data(), + dequantize_filter_inputs.size()), + builder.CreateVector(dequantize_filter_outputs.data(), + dequantize_filter_outputs.size()))); + const std::array dequantize_bias_inputs{{1}}; + const std::array dequantize_bias_outputs{{4}}; + operators.emplace_back(CreateOperator( + builder, /*opcode_index=*/1, + builder.CreateVector(dequantize_bias_inputs.data(), + dequantize_bias_inputs.size()), + builder.CreateVector(dequantize_bias_outputs.data(), + dequantize_bias_outputs.size()))); + } else { + std::vector filter_data(KernelHeight() * KernelWidth() * + OutputChannels()); + std::vector bias_data(OutputChannels()); + for (int32_t ic = 0; ic < InputChannels(); ic++) { + // Use the same range of all-positive or all-negative values to generate + // all pixels within the same batch index & channel, but different ranges + // for different channels or batches. This ensures that no catastrophic + // cancellation occur, but test covers both positive and negative inputs. 
+ const float range = range_rng(); + auto value_rng = + std::bind(std::uniform_real_distribution( + std::min(range, 0.0f), std::max(range, 0.0f)), + std::ref(rng)); + for (int32_t m = 0; m < DepthMultiplier(); m++) { + const int32_t oc = ic * DepthMultiplier() + m; + bias_data[oc] = value_rng(); + for (int32_t y = 0; y < KernelHeight(); y++) { + for (int32_t x = 0; x < KernelWidth(); x++) { + const int32_t index = + (y * KernelWidth() + x) * OutputChannels() + oc; + filter_data[index] = value_rng(); + } + } + } + } + + buffers.emplace_back(CreateBuffer( + builder, builder.CreateVector( + reinterpret_cast(filter_data.data()), + sizeof(float) * filter_data.size()))); + buffers.emplace_back(CreateBuffer( + builder, + builder.CreateVector(reinterpret_cast(bias_data.data()), + sizeof(float) * bias_data.size()))); + } const std::array input_shape{ {BatchSize(), InputHeight(), InputWidth(), InputChannels()}}; @@ -166,49 +221,69 @@ std::vector DepthwiseConv2DTester::CreateTfLiteModel() const { {1, KernelHeight(), KernelWidth(), OutputChannels()}}; const std::array bias_shape{{OutputChannels()}}; - const std::array, 4> tensors{{ - CreateTensor( - builder, - builder.CreateVector(input_shape.data(), input_shape.size()), - TensorType_FLOAT32), - CreateTensor(builder, - builder.CreateVector(filter_shape.data(), - filter_shape.size()), - TensorType_FLOAT32, /*buffer=*/1), - CreateTensor( - builder, - builder.CreateVector(bias_shape.data(), bias_shape.size()), - TensorType_FLOAT32, /*buffer=*/2), - CreateTensor(builder, - builder.CreateVector(output_shape.data(), - output_shape.size()), - TensorType_FLOAT32), - }}; + std::vector> tensors; + if (FP16Weights()) { + tensors.emplace_back(CreateTensor( + builder, + builder.CreateVector(filter_shape.data(), filter_shape.size()), + TensorType_FLOAT16, /*buffer=*/1)); + tensors.emplace_back(CreateTensor( + builder, + builder.CreateVector(bias_shape.data(), bias_shape.size()), + TensorType_FLOAT16, /*buffer=*/2)); + } + tensors.emplace_back(CreateTensor( + builder, + builder.CreateVector(input_shape.data(), input_shape.size()), + TensorType_FLOAT32)); + tensors.emplace_back(CreateTensor( + builder, + builder.CreateVector(filter_shape.data(), filter_shape.size()), + TensorType_FLOAT32, /*buffer=*/FP16Weights() ? 0 : 1)); + tensors.emplace_back(CreateTensor( + builder, + builder.CreateVector(bias_shape.data(), bias_shape.size()), + TensorType_FLOAT32, /*buffer=*/FP16Weights() ? 
0 : 2)); + tensors.emplace_back(CreateTensor( + builder, + builder.CreateVector(output_shape.data(), output_shape.size()), + TensorType_FLOAT32)); - const std::array op_inputs{{0, 1, 2}}; - const std::array op_outputs{{3}}; + const std::array op_inputs{ + {static_cast(tensors.size()) - 4, + static_cast(tensors.size()) - 3, + static_cast(tensors.size()) - 2}}; + const std::array op_outputs{ + {static_cast(tensors.size()) - 1}}; - flatbuffers::Offset op = CreateOperator( + flatbuffers::Offset depthwise_conv2d_options = + CreateDepthwiseConv2DOptions( + builder, Padding(), StrideWidth(), StrideHeight(), DepthMultiplier(), + Activation(), DilationWidth(), DilationHeight()); + operators.emplace_back(CreateOperator( builder, /*opcode_index=*/0, builder.CreateVector(op_inputs.data(), op_inputs.size()), builder.CreateVector(op_outputs.data(), op_outputs.size()), - BuiltinOptions_DepthwiseConv2DOptions, depthwise_conv2d_options.Union()); + BuiltinOptions_DepthwiseConv2DOptions, depthwise_conv2d_options.Union())); - const std::array subgraph_inputs{{0}}; - const std::array subgraph_outputs{{3}}; + const std::array subgraph_inputs{ + {static_cast(tensors.size()) - 4}}; + const std::array subgraph_outputs{ + {static_cast(tensors.size()) - 1}}; flatbuffers::Offset subgraph = CreateSubGraph( builder, builder.CreateVector(tensors.data(), tensors.size()), builder.CreateVector(subgraph_inputs.data(), subgraph_inputs.size()), builder.CreateVector(subgraph_outputs.data(), subgraph_outputs.size()), - builder.CreateVector(&op, 1)); + builder.CreateVector(operators.data(), operators.size())); flatbuffers::Offset description = builder.CreateString("DepthwiseConv2D model"); flatbuffers::Offset model_buffer = CreateModel( - builder, TFLITE_SCHEMA_VERSION, builder.CreateVector(&operator_code, 1), + builder, TFLITE_SCHEMA_VERSION, + builder.CreateVector(operator_codes.data(), operator_codes.size()), builder.CreateVector(&subgraph, 1), description, builder.CreateVector(buffers.data(), buffers.size())); diff --git a/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.h b/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.h index 16dc5920229..102c66af340 100644 --- a/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.h +++ b/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.h @@ -152,6 +152,13 @@ class DepthwiseConv2DTester { return (KernelWidth() - 1) * DilationWidth() + 1; } + inline DepthwiseConv2DTester& FP16Weights() { + fp16_weights_ = true; + return *this; + } + + inline bool FP16Weights() const { return fp16_weights_; } + inline DepthwiseConv2DTester& SamePadding() { padding_ = ::tflite::Padding_SAME; return *this; @@ -209,6 +216,7 @@ class DepthwiseConv2DTester { int32_t stride_width_ = 1; int32_t dilation_height_ = 1; int32_t dilation_width_ = 1; + bool fp16_weights_ = false; ::tflite::Padding padding_ = ::tflite::Padding_VALID; ::tflite::ActivationFunctionType activation_ = ::tflite::ActivationFunctionType_NONE; diff --git a/tensorflow/lite/delegates/xnnpack/fully_connected_test.cc b/tensorflow/lite/delegates/xnnpack/fully_connected_test.cc index a801ce141ed..0dffd1dee19 100644 --- a/tensorflow/lite/delegates/xnnpack/fully_connected_test.cc +++ b/tensorflow/lite/delegates/xnnpack/fully_connected_test.cc @@ -228,6 +228,29 @@ TEST(FullyConnected, 4DKeepDims) { .Test(xnnpack_delegate.get()); } +TEST(FullyConnected, FP16Weights) { + std::unique_ptr + xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), + TfLiteXNNPackDelegateDelete); + + std::random_device random_device; + auto 
rng = std::mt19937(random_device()); + auto batch_rng = + std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); + auto channels_rng = + std::bind(std::uniform_int_distribution(2, 9), std::ref(rng)); + const auto batch = batch_rng(); + const auto input_channels = channels_rng(); + const auto output_channels = channels_rng(); + + FullyConnectedTester() + .InputShape({batch, input_channels}) + .InputChannels(input_channels) + .OutputChannels(output_channels) + .FP16Weights() + .Test(xnnpack_delegate.get()); +} + TEST(FullyConnected, ReluActivation) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), diff --git a/tensorflow/lite/delegates/xnnpack/fully_connected_tester.cc b/tensorflow/lite/delegates/xnnpack/fully_connected_tester.cc index 05716bf18fb..8962b8ba7ba 100644 --- a/tensorflow/lite/delegates/xnnpack/fully_connected_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/fully_connected_tester.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" @@ -109,98 +110,165 @@ void FullyConnectedTester::Test(TfLiteDelegate* delegate) const { std::vector FullyConnectedTester::CreateTfLiteModel() const { std::random_device random_device; auto rng = std::mt19937(random_device()); - auto range_rng = std::bind( std::uniform_real_distribution(-25.0f, 25.0f), std::ref(rng)); flatbuffers::FlatBufferBuilder builder; - flatbuffers::Offset operator_code = - CreateOperatorCode(builder, BuiltinOperator_FULLY_CONNECTED); + std::vector> operator_codes{ + {CreateOperatorCode(builder, BuiltinOperator_FULLY_CONNECTED)}}; + std::vector> operators; + std::vector> buffers{ + {CreateBuffer(builder, builder.CreateVector({}))}}; - std::vector filter_data(InputChannels() * OutputChannels()); - std::vector bias_data(OutputChannels()); + if (FP16Weights()) { + operator_codes.emplace_back( + CreateOperatorCode(builder, BuiltinOperator_DEQUANTIZE)); - for (int32_t oc = 0; oc < OutputChannels(); oc++) { - // Use the same range of all-positive or all-negative values to generate - // all filter & bias weights within the same channel, but different ranges - // for different output channels. This ensures that no catastrophic - // cancellation occur, but test covers both positive and negative inputs. - const float range = range_rng(); - auto value_rng = - std::bind(std::uniform_real_distribution(std::min(range, 0.0f), - std::max(range, 0.0f)), - std::ref(rng)); + std::vector filter_data(InputChannels() * OutputChannels()); + std::vector bias_data(OutputChannels()); - bias_data[oc] = value_rng(); - for (int32_t ic = 0; ic < InputChannels(); ic++) { - filter_data[oc * InputChannels() + ic] = value_rng(); + for (int32_t oc = 0; oc < OutputChannels(); oc++) { + // Use the same range of all-positive or all-negative values to generate + // all filter & bias weights within the same channel, but different ranges + // for different output channels. This ensures that no catastrophic + // cancellation occur, but test covers both positive and negative inputs. 
+ const float range = range_rng(); + auto value_rng = + std::bind(fp16_ieee_from_fp32_value, + std::bind(std::uniform_real_distribution( + std::min(range, 0.0f), std::max(range, 0.0f)), + std::ref(rng))); + + bias_data[oc] = value_rng(); + for (int32_t ic = 0; ic < InputChannels(); ic++) { + filter_data[oc * InputChannels() + ic] = value_rng(); + } } - } - std::array, 3> buffers{{ - CreateBuffer(builder, builder.CreateVector({})), - CreateBuffer(builder, - builder.CreateVector( - reinterpret_cast(filter_data.data()), - sizeof(float) * filter_data.size())), - CreateBuffer(builder, - builder.CreateVector( - reinterpret_cast(bias_data.data()), - sizeof(float) * bias_data.size())), - }}; + buffers.emplace_back(CreateBuffer( + builder, builder.CreateVector( + reinterpret_cast(filter_data.data()), + sizeof(uint16_t) * filter_data.size()))); + buffers.emplace_back(CreateBuffer( + builder, + builder.CreateVector(reinterpret_cast(bias_data.data()), + sizeof(uint16_t) * bias_data.size()))); + + const std::array dequantize_filter_inputs{{0}}; + const std::array dequantize_filter_outputs{{3}}; + operators.emplace_back(CreateOperator( + builder, /*opcode_index=*/1, + builder.CreateVector(dequantize_filter_inputs.data(), + dequantize_filter_inputs.size()), + builder.CreateVector(dequantize_filter_outputs.data(), + dequantize_filter_outputs.size()))); + const std::array dequantize_bias_inputs{{1}}; + const std::array dequantize_bias_outputs{{4}}; + operators.emplace_back(CreateOperator( + builder, /*opcode_index=*/1, + builder.CreateVector(dequantize_bias_inputs.data(), + dequantize_bias_inputs.size()), + builder.CreateVector(dequantize_bias_outputs.data(), + dequantize_bias_outputs.size()))); + } else { + std::vector filter_data(InputChannels() * OutputChannels()); + std::vector bias_data(OutputChannels()); + + for (int32_t oc = 0; oc < OutputChannels(); oc++) { + // Use the same range of all-positive or all-negative values to generate + // all filter & bias weights within the same channel, but different ranges + // for different output channels. This ensures that no catastrophic + // cancellation occur, but test covers both positive and negative inputs. 
+ const float range = range_rng(); + auto value_rng = + std::bind(std::uniform_real_distribution( + std::min(range, 0.0f), std::max(range, 0.0f)), + std::ref(rng)); + + bias_data[oc] = value_rng(); + for (int32_t ic = 0; ic < InputChannels(); ic++) { + filter_data[oc * InputChannels() + ic] = value_rng(); + } + } + + buffers.emplace_back(CreateBuffer( + builder, builder.CreateVector( + reinterpret_cast(filter_data.data()), + sizeof(float) * filter_data.size()))); + buffers.emplace_back(CreateBuffer( + builder, + builder.CreateVector(reinterpret_cast(bias_data.data()), + sizeof(float) * bias_data.size()))); + } const std::array filter_shape( {OutputChannels(), InputChannels()}); const std::array bias_shape({OutputChannels()}); const std::vector output_shape = OutputShape(); - const std::array, 4> tensors{{ - CreateTensor(builder, - builder.CreateVector(InputShape().data(), - InputShape().size()), - TensorType_FLOAT32), - CreateTensor(builder, - builder.CreateVector(filter_shape.data(), - filter_shape.size()), - TensorType_FLOAT32, /*buffer=*/1), - CreateTensor( - builder, - builder.CreateVector(bias_shape.data(), bias_shape.size()), - TensorType_FLOAT32, /*buffer=*/2), - CreateTensor(builder, - builder.CreateVector(output_shape.data(), - output_shape.size()), - TensorType_FLOAT32), - }}; + std::vector> tensors; + if (FP16Weights()) { + tensors.emplace_back(CreateTensor( + builder, + builder.CreateVector(filter_shape.data(), filter_shape.size()), + TensorType_FLOAT16, /*buffer=*/1)); + tensors.emplace_back(CreateTensor( + builder, + builder.CreateVector(bias_shape.data(), bias_shape.size()), + TensorType_FLOAT16, /*buffer=*/2)); + } + tensors.emplace_back(CreateTensor( + builder, + builder.CreateVector(InputShape().data(), InputShape().size()), + TensorType_FLOAT32)); + tensors.emplace_back(CreateTensor( + builder, + builder.CreateVector(filter_shape.data(), filter_shape.size()), + TensorType_FLOAT32, /*buffer=*/FP16Weights() ? 0 : 1)); + tensors.emplace_back(CreateTensor( + builder, + builder.CreateVector(bias_shape.data(), bias_shape.size()), + TensorType_FLOAT32, /*buffer=*/FP16Weights() ? 
0 : 2)); + tensors.emplace_back(CreateTensor( + builder, + builder.CreateVector(output_shape.data(), output_shape.size()), + TensorType_FLOAT32)); flatbuffers::Offset fully_connected_options = CreateFullyConnectedOptions(builder, Activation(), FullyConnectedOptionsWeightsFormat_DEFAULT, KeepDims()); - const std::array op_inputs{{0, 1, 2}}; - const std::array op_outputs{{3}}; - flatbuffers::Offset op = CreateOperator( + const std::array op_inputs{ + {static_cast(tensors.size()) - 4, + static_cast(tensors.size()) - 3, + static_cast(tensors.size()) - 2}}; + const std::array op_outputs{ + {static_cast(tensors.size()) - 1}}; + operators.emplace_back(CreateOperator( builder, /*opcode_index=*/0, builder.CreateVector(op_inputs.data(), op_inputs.size()), builder.CreateVector(op_outputs.data(), op_outputs.size()), - BuiltinOptions_FullyConnectedOptions, fully_connected_options.Union()); + BuiltinOptions_FullyConnectedOptions, fully_connected_options.Union())); - const std::array subgraph_inputs{{0}}; - const std::array subgraph_outputs{{3}}; + const std::array subgraph_inputs{ + {static_cast(tensors.size()) - 4}}; + const std::array subgraph_outputs{ + {static_cast(tensors.size()) - 1}}; flatbuffers::Offset subgraph = CreateSubGraph( builder, builder.CreateVector(tensors.data(), tensors.size()), builder.CreateVector(subgraph_inputs.data(), subgraph_inputs.size()), builder.CreateVector(subgraph_outputs.data(), subgraph_outputs.size()), - builder.CreateVector(&op, 1)); + builder.CreateVector(operators.data(), operators.size())); flatbuffers::Offset description = builder.CreateString("Fully Connected model"); flatbuffers::Offset model_buffer = CreateModel( - builder, TFLITE_SCHEMA_VERSION, builder.CreateVector(&operator_code, 1), + builder, TFLITE_SCHEMA_VERSION, + builder.CreateVector(operator_codes.data(), operator_codes.size()), builder.CreateVector(&subgraph, 1), description, builder.CreateVector(buffers.data(), buffers.size())); diff --git a/tensorflow/lite/delegates/xnnpack/fully_connected_tester.h b/tensorflow/lite/delegates/xnnpack/fully_connected_tester.h index cf1d5513d46..6350bc8d739 100644 --- a/tensorflow/lite/delegates/xnnpack/fully_connected_tester.h +++ b/tensorflow/lite/delegates/xnnpack/fully_connected_tester.h @@ -71,6 +71,13 @@ class FullyConnectedTester { inline bool KeepDims() const { return keep_dims_; } + inline FullyConnectedTester& FP16Weights() { + fp16_weights_ = true; + return *this; + } + + inline bool FP16Weights() const { return fp16_weights_; } + inline FullyConnectedTester& ReluActivation() { activation_ = ::tflite::ActivationFunctionType_RELU; return *this; @@ -102,6 +109,7 @@ class FullyConnectedTester { int32_t input_channels_ = 1; int32_t output_channels_ = 1; bool keep_dims_ = false; + bool fp16_weights_ = false; ::tflite::ActivationFunctionType activation_ = ::tflite::ActivationFunctionType_NONE; }; diff --git a/tensorflow/lite/delegates/xnnpack/mul_test.cc b/tensorflow/lite/delegates/xnnpack/mul_test.cc index 6c0475e2b64..2dbb2663b80 100644 --- a/tensorflow/lite/delegates/xnnpack/mul_test.cc +++ b/tensorflow/lite/delegates/xnnpack/mul_test.cc @@ -679,6 +679,35 @@ TEST(Mul, 2DByStatic0D) { .Test(BuiltinOperator_MUL, xnnpack_delegate.get()); } +TEST(Mul, FP16Weights) { + std::unique_ptr + xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), + TfLiteXNNPackDelegateDelete); + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto shape_rng = + std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); + const auto batch = 
shape_rng(); + const auto height = shape_rng(); + const auto width = shape_rng(); + const auto channels = shape_rng(); + + BinaryElementwiseTester() + .Input1Shape({batch, height, width, channels}) + .Input2Shape({batch, height, width, channels}) + .Input1Static(true) + .FP16Weights() + .Test(BuiltinOperator_MUL, xnnpack_delegate.get()); + + BinaryElementwiseTester() + .Input1Shape({batch, height, width, channels}) + .Input2Shape({batch, height, width, channels}) + .Input2Static(true) + .FP16Weights() + .Test(BuiltinOperator_MUL, xnnpack_delegate.get()); +} + TEST(Mul, ReluActivation) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc index 2beaa16255d..32fcbee4c22 100644 --- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc +++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc @@ -22,10 +22,12 @@ limitations under the License. #include #include #include +#include #include #include #include +#include #include #include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/c/builtin_op_data.h" @@ -39,6 +41,8 @@ namespace { TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate); class Delegate { + friend class Subgraph; + public: explicit Delegate(const TfLiteXNNPackDelegateOptions* options) { #if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__) @@ -49,9 +53,10 @@ class Delegate { #endif } + TfLiteIntArray* PrepareOpsToDelegate(TfLiteContext* context); TfLiteDelegate* tflite_delegate() { return &delegate_; } - pthreadpool_t threadpool() { + pthreadpool_t threadpool() const { #if defined(__EMSCRIPTEN__) && !defined(__EMSCRIPTEN_PTHREADS__) return nullptr; #else @@ -69,6 +74,17 @@ class Delegate { kTfLiteDelegateFlagsNone, // .flags }; + // Unpacked data for quasi-static tensors, i.e. tensors produced by + // dequantizing or unpacking static buffers. + std::vector static_unpacked_data_; + // Mapping from a tensor index for a quasi-static tensor to the offset to + // its unpacked data within static_unpacked_data_. + std::unordered_map static_unpacked_data_map_; + // Set of indices of nodes which unpack static data, e.g. Dequantize + // operators which convert FP16 static weights to FP32. These nodes are simply + // ignored in the delegate implementation, because their outputs are + // pre-unpacked in DelegatePrepare. + std::unordered_set static_unpack_nodes_; #if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__) // Thread pool with smart-pointer for lifetime management. std::unique_ptr threadpool_{ @@ -80,7 +96,7 @@ class Subgraph { public: static Subgraph* Create(TfLiteContext* context, const TfLiteDelegateParams* params, - pthreadpool_t threadpool) { + const Delegate* delegate) { // Convert subgraph inputs and outputs to hash sets for faster lookup. const std::unordered_set inputs( ¶ms->input_tensors->data[0], @@ -113,11 +129,17 @@ class Subgraph { // filtered out and removed later. std::vector tensors(context->tensors_size, -1); for (int i = 0; i < params->nodes_to_replace->size; i++) { + const int node_index = params->nodes_to_replace->data[i]; + if (delegate->static_unpack_nodes_.count(node_index)) { + // The node unpacks static input and can be skipped because its input + // was pre-unpacked in DelegatePrepare. 
+ continue; + } + TfLiteNode* node = nullptr; TfLiteRegistration* registration = nullptr; - if (context->GetNodeAndRegistration(context, - params->nodes_to_replace->data[i], - &node, ®istration) != kTfLiteOk) { + if (context->GetNodeAndRegistration(context, node_index, &node, + ®istration) != kTfLiteOk) { return nullptr; } @@ -164,6 +186,12 @@ class Subgraph { const void* data = nullptr; if (context->tensors[t].allocation_type == kTfLiteMmapRo) { data = context->tensors[t].data.raw_const; + } else { + // Check for quasi-static data. + const auto it = delegate->static_unpacked_data_map_.find(t); + if (it != delegate->static_unpacked_data_map_.end()) { + data = delegate->static_unpacked_data_.data() + it->second; + } } if (inputs.count(t) != 0) { flags |= XNN_VALUE_FLAG_EXTERNAL_INPUT; @@ -189,25 +217,38 @@ class Subgraph { } } + // Create a set of quasi-static tensors for VisitNode function + std::unordered_set quasi_static_tensors; + for (const std::pair& entry : + delegate->static_unpacked_data_map_) { + quasi_static_tensors.insert(entry.first); + } + // Create XNNPACK nodes for TFLite delegate nodes for (int i = 0; i < params->nodes_to_replace->size; i++) { + const int node_index = params->nodes_to_replace->data[i]; + if (delegate->static_unpack_nodes_.count(node_index)) { + // The node unpacks static input and can be skipped because its input + // was pre-unpacked in DelegatePrepare. + continue; + } + TfLiteNode* node = nullptr; TfLiteRegistration* registration = nullptr; - if (context->GetNodeAndRegistration(context, - params->nodes_to_replace->data[i], - &node, ®istration) != kTfLiteOk) { + if (context->GetNodeAndRegistration(context, node_index, &node, + ®istration) != kTfLiteOk) { return nullptr; } - if (VisitNode(subgraph.get(), context, registration, node, i, - xnnpack_tensors) != kTfLiteOk) { + if (VisitNode(subgraph.get(), context, registration, node, node_index, + quasi_static_tensors, xnnpack_tensors) != kTfLiteOk) { return nullptr; } } xnn_runtime_t runtime_ptr = nullptr; - status = xnn_create_runtime_v2(subgraph.get(), threadpool, /*flags=*/0, - &runtime_ptr); + status = xnn_create_runtime_v2(subgraph.get(), delegate->threadpool(), + /*flags=*/0, &runtime_ptr); if (status != xnn_status_success) { TF_LITE_KERNEL_LOG(context, "failed to create XNNPACK runtime"); return nullptr; @@ -707,10 +748,11 @@ class Subgraph { return kTfLiteOk; } - static TfLiteStatus VisitNode(xnn_subgraph_t subgraph, TfLiteContext* context, - TfLiteRegistration* registration, - TfLiteNode* node, int node_index, - const std::vector& xnnpack_tensors) { + static TfLiteStatus VisitNode( + xnn_subgraph_t subgraph, TfLiteContext* context, + TfLiteRegistration* registration, TfLiteNode* node, int node_index, + const std::unordered_set& quasi_static_tensors, + const std::vector& xnnpack_tensors) { // TFLite context used for logging purposes. When we create a new node // (subgraph is non-null), logging context is the same as context, and error // messages are passed to TFLite. 
When we detect supported operations @@ -738,7 +780,8 @@ class Subgraph { static_cast(node->builtin_data); return VisitConv2DNode(subgraph, logging_context, node_index, node, - context->tensors, conv_params, xnnpack_tensors); + context->tensors, conv_params, + quasi_static_tensors, xnnpack_tensors); } case kTfLiteBuiltinDepthwiseConv2d: { const TfLiteDepthwiseConvParams* dwconv_params = @@ -746,7 +789,7 @@ class Subgraph { return VisitDepthwiseConv2DNode(subgraph, logging_context, node_index, node, context->tensors, dwconv_params, - xnnpack_tensors); + quasi_static_tensors, xnnpack_tensors); } case kTfLiteBuiltinFullyConnected: { const TfLiteFullyConnectedParams* fc_params = @@ -754,7 +797,7 @@ class Subgraph { return VisitFullyConnectedNode(subgraph, logging_context, node_index, node, context->tensors, fc_params, - xnnpack_tensors); + quasi_static_tensors, xnnpack_tensors); } case kTfLiteBuiltinHardSwish: return VisitHardSwishNode(subgraph, logging_context, node_index, node, @@ -782,7 +825,8 @@ class Subgraph { context->tensors, xnnpack_tensors); case kTfLiteBuiltinPrelu: return VisitPreluNode(subgraph, logging_context, node_index, node, - context->tensors, xnnpack_tensors); + context->tensors, quasi_static_tensors, + xnnpack_tensors); case kTfLiteBuiltinRelu: return VisitReluNode( subgraph, logging_context, node_index, node, context->tensors, 0.0f, @@ -810,7 +854,7 @@ class Subgraph { return VisitMediaPipeDeconvolutionNode( subgraph, context, node_index, node, context->tensors, - &deconv_params, xnnpack_tensors); + &deconv_params, quasi_static_tensors, xnnpack_tensors); } else if (strcmp(registration->custom_name, "MaxPoolingWithArgmax2D") == 0) { TfLitePoolParams pool_params = {kTfLitePaddingUnknown}; @@ -948,6 +992,7 @@ class Subgraph { xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index, TfLiteNode* node, const TfLiteTensor* tensors, const TfLiteConvParams* conv_params, + const std::unordered_set& quasi_static_tensors, const std::vector& xnnpack_tensors) { TF_LITE_ENSURE_STATUS( CheckConvolutionParams(logging_context, conv_params, node_index)); @@ -968,16 +1013,20 @@ class Subgraph { logging_context, filter_tensor, node->inputs->data[1], node_index)); TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, filter_tensor, 4, node->inputs->data[1])); - TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation( - logging_context, filter_tensor, node->inputs->data[1], node_index)); + if (quasi_static_tensors.count(node->inputs->data[1]) == 0) { + TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation( + logging_context, filter_tensor, node->inputs->data[1], node_index)); + } const TfLiteTensor& bias_tensor = tensors[node->inputs->data[2]]; TF_LITE_ENSURE_STATUS(CheckTensorFloatType( logging_context, filter_tensor, node->inputs->data[2], node_index)); TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, bias_tensor, 1, node->inputs->data[2])); - TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation( - logging_context, bias_tensor, node->inputs->data[2], node_index)); + if (quasi_static_tensors.count(node->inputs->data[2]) == 0) { + TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation( + logging_context, bias_tensor, node->inputs->data[2], node_index)); + } const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloatType( @@ -1034,6 +1083,7 @@ class Subgraph { xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index, TfLiteNode* node, const TfLiteTensor* tensors, const TfLiteDepthwiseConvParams* dwconv_params, + const 
std::unordered_set& quasi_static_tensors, const std::vector& xnnpack_tensors) { TF_LITE_ENSURE_STATUS( CheckNumInputsAndOutputs(logging_context, node, 3, 1, node_index)); @@ -1051,16 +1101,20 @@ class Subgraph { logging_context, filter_tensor, node->inputs->data[1], node_index)); TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, filter_tensor, 4, node->inputs->data[1])); - TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation( - logging_context, filter_tensor, node->inputs->data[1], node_index)); + if (quasi_static_tensors.count(node->inputs->data[1]) == 0) { + TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation( + logging_context, filter_tensor, node->inputs->data[1], node_index)); + } const TfLiteTensor& bias_tensor = tensors[node->inputs->data[2]]; TF_LITE_ENSURE_STATUS(CheckTensorFloatType( logging_context, filter_tensor, node->inputs->data[2], node_index)); TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, bias_tensor, 1, node->inputs->data[2])); - TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation( - logging_context, bias_tensor, node->inputs->data[2], node_index)); + if (quasi_static_tensors.count(node->inputs->data[2]) == 0) { + TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation( + logging_context, bias_tensor, node->inputs->data[2], node_index)); + } const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloatType( @@ -1123,6 +1177,7 @@ class Subgraph { xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index, TfLiteNode* node, const TfLiteTensor* tensors, const TfLiteFullyConnectedParams* fc_params, + const std::unordered_set& quasi_static_tensors, const std::vector& xnnpack_tensors) { TF_LITE_ENSURE_STATUS( CheckFullyConnectedParams(logging_context, fc_params, node_index)); @@ -1141,16 +1196,20 @@ class Subgraph { logging_context, filter_tensor, node->inputs->data[1], node_index)); TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, filter_tensor, 2, node->inputs->data[1])); - TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation( - logging_context, filter_tensor, node->inputs->data[1], node_index)); + if (quasi_static_tensors.count(node->inputs->data[1]) == 0) { + TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation( + logging_context, filter_tensor, node->inputs->data[1], node_index)); + } const TfLiteTensor& bias_tensor = tensors[node->inputs->data[2]]; TF_LITE_ENSURE_STATUS(CheckTensorFloatType( logging_context, filter_tensor, node->inputs->data[2], node_index)); TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, bias_tensor, 1, node->inputs->data[2])); - TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation( - logging_context, bias_tensor, node->inputs->data[2], node_index)); + if (quasi_static_tensors.count(node->inputs->data[2]) == 0) { + TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation( + logging_context, bias_tensor, node->inputs->data[2], node_index)); + } const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloatType( @@ -1387,6 +1446,7 @@ class Subgraph { xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index, TfLiteNode* node, const TfLiteTensor* tensors, const TfLiteTransposeConvParams* deconv_params, + const std::unordered_set& quasi_static_tensors, const std::vector& xnnpack_tensors) { TF_LITE_ENSURE_STATUS( CheckNumInputsAndOutputs(logging_context, node, 3, 1, node_index)); @@ -1404,16 +1464,20 @@ class Subgraph { logging_context, filter_tensor, node->inputs->data[1], node_index)); 
     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, filter_tensor, 4,
                                            node->inputs->data[1]));
-    TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
-        logging_context, filter_tensor, node->inputs->data[1], node_index));
+    if (quasi_static_tensors.count(node->inputs->data[1]) == 0) {
+      TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
+          logging_context, filter_tensor, node->inputs->data[1], node_index));
+    }
 
     const TfLiteTensor& bias_tensor = tensors[node->inputs->data[2]];
     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
         logging_context, filter_tensor, node->inputs->data[2], node_index));
     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, bias_tensor, 1,
                                            node->inputs->data[2]));
-    TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
-        logging_context, bias_tensor, node->inputs->data[2], node_index));
+    if (quasi_static_tensors.count(node->inputs->data[2]) == 0) {
+      TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
+          logging_context, bias_tensor, node->inputs->data[2], node_index));
+    }
 
     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
@@ -1735,6 +1799,7 @@ class Subgraph {
   static TfLiteStatus VisitPreluNode(
       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
       TfLiteNode* node, const TfLiteTensor* tensors,
+      const std::unordered_set<int>& quasi_static_tensors,
       const std::vector<uint32_t>& xnnpack_tensors) {
     TF_LITE_ENSURE_STATUS(
         CheckNumInputsAndOutputs(logging_context, node, 2, 1, node_index));
@@ -1752,8 +1817,10 @@ class Subgraph {
         logging_context, slope_tensor, node->inputs->data[1], node_index));
     TF_LITE_ENSURE_STATUS(CheckSlopeTensorShape(
         logging_context, slope_tensor, node->inputs->data[1], node_index));
-    TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
-        logging_context, slope_tensor, node->inputs->data[1], node_index));
+    if (quasi_static_tensors.count(node->inputs->data[1]) == 0) {
+      TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
+          logging_context, slope_tensor, node->inputs->data[1], node_index));
+    }
 
     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
     TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
@@ -1869,15 +1936,29 @@ class Subgraph {
   bool first_run_{true};
 };
 
-TfLiteIntArray* GetOpsToReplace(TfLiteContext* context) {
+TfLiteIntArray* Delegate::PrepareOpsToDelegate(TfLiteContext* context) {
+  // Clear previous data, in case the delegate is reused without re-creation.
+  static_unpacked_data_map_.clear();
+  static_unpacked_data_.clear();
+  static_unpack_nodes_.clear();
+
   TfLiteIntArray* execution_plan = nullptr;
   if (context->GetExecutionPlan(context, &execution_plan) != kTfLiteOk) {
     TF_LITE_KERNEL_LOG(context, "Unable to get graph execution plan.");
     return nullptr;
   }
 
-  TfLiteIntArray* nodes_to_replace = TfLiteIntArrayCreate(execution_plan->size);
-  nodes_to_replace->size = 0;
+  // Mapping for quasi-static (unpacked from static) tensor index to the node
+  // index that produced it.
+  std::unordered_map<int, int> quasi_static_tensors_producers;
+  // Set of all quasi-static tensors in the execution plan.
+  std::unordered_set<int> quasi_static_tensors;
+  // Set of quasi-static tensors consumed by the delegated nodes.
+  std::unordered_set<int> quasi_static_tensors_to_unpack;
+
+  TfLiteIntArray* nodes_to_delegate =
+      TfLiteIntArrayCreate(execution_plan->size);
+  nodes_to_delegate->size = 0;
 
   for (int i = 0; i < execution_plan->size; ++i) {
     const int node_index = execution_plan->data[i];
@@ -1892,15 +1973,142 @@ TfLiteIntArray* GetOpsToReplace(TfLiteContext* context) {
       continue;  // Soft error (skip this node).
     }
 
+    if (registration->builtin_code == kTfLiteBuiltinDequantize &&
+        node->inputs->size == 1 && node->outputs->size == 1) {
+      const TfLiteTensor& input_tensor =
+          context->tensors[node->inputs->data[0]];
+      const TfLiteTensor& output_tensor =
+          context->tensors[node->outputs->data[0]];
+      if (input_tensor.allocation_type == kTfLiteMmapRo &&
+          input_tensor.type == kTfLiteFloat16 &&
+          output_tensor.type == kTfLiteFloat32) {
+        static_unpack_nodes_.insert(i);
+        quasi_static_tensors_producers[node->outputs->data[0]] = i;
+        quasi_static_tensors.insert(node->outputs->data[0]);
+
+        // Skip this node for now. If output of the node is consumed only by
+        // delegated nodes, it will be added to nodes_to_delegate in the end.
+        continue;
+      }
+    }
+
     if (Subgraph::VisitNode(/*subgraph=*/nullptr, context, registration, node,
-                            node_index, std::vector<uint32_t>()) != kTfLiteOk) {
+                            node_index, quasi_static_tensors,
+                            std::vector<uint32_t>()) != kTfLiteOk) {
+      // If a non-delegated node consumes output of a node that unpacks static
+      // data, that node shouldn't be delegated.
+      for (int j = 0; j < node->inputs->size; j++) {
+        const auto it =
+            quasi_static_tensors_producers.find(node->inputs->data[j]);
+        if (it != quasi_static_tensors_producers.end()) {
+          static_unpack_nodes_.erase(it->second);
+        }
+      }
+
+      // Non-delegatable node is not an error.
       continue;
     }
 
-    nodes_to_replace->data[nodes_to_replace->size++] = node_index;
+    for (int j = 0; j < node->inputs->size; j++) {
+      if (quasi_static_tensors.count(node->inputs->data[j]) != 0) {
+        quasi_static_tensors_to_unpack.insert(node->inputs->data[j]);
+      }
+    }
+
+    nodes_to_delegate->data[nodes_to_delegate->size++] = node_index;
   }
 
+  // Unpack static data of all tensors
+  for (int t : quasi_static_tensors_to_unpack) {
+    const int producer_index = quasi_static_tensors_producers[t];
+    // Check if TFLite nodes can be delegated to XNNPACK
+    TfLiteNode* node = nullptr;
+    TfLiteRegistration* registration = nullptr;
+    if (context->GetNodeAndRegistration(context, producer_index, &node,
+                                        &registration) != kTfLiteOk) {
+      TF_LITE_KERNEL_LOG(context,
+                         "Unable to get node and registration for node %d.",
+                         producer_index);
+      TfLiteIntArrayFree(nodes_to_delegate);
+      return nullptr;  // Hard error.
+    }
+
+    if (node->inputs->size != 1) {
+      TF_LITE_KERNEL_LOG(context, "unexpected number of inputs (%d) in node %d",
+                         node->inputs->size, producer_index);
+      TfLiteIntArrayFree(nodes_to_delegate);
+      return nullptr;  // Hard error.
+    }
+
+    if (node->outputs->size != 1) {
+      TF_LITE_KERNEL_LOG(context,
+                         "unexpected number of outputs (%d) in node %d",
+                         node->outputs->size, producer_index);
+      TfLiteIntArrayFree(nodes_to_delegate);
+      return nullptr;  // Hard error.
+    }
+
+    const TfLiteTensor& input_tensor = context->tensors[node->inputs->data[0]];
+    if (input_tensor.allocation_type != kTfLiteMmapRo) {
+      TF_LITE_KERNEL_LOG(context,
+                         "unexpected allocation type in tensor %d in node %d",
+                         node->inputs->data[0], producer_index);
+      TfLiteIntArrayFree(nodes_to_delegate);
+      return nullptr;  // Hard error.
+    }
+
+    const TfLiteTensor& output_tensor = context->tensors[t];
+    if (output_tensor.type != kTfLiteFloat32) {
+      TF_LITE_KERNEL_LOG(context,
+                         "unexpected datatype (%s) in tensor %d in node %d",
+                         TfLiteTypeGetName(output_tensor.type),
+                         node->outputs->data[0], producer_index);
+      TfLiteIntArrayFree(nodes_to_delegate);
+      return nullptr;  // Hard error.
+    }
+    const size_t tensor_elements = output_tensor.bytes / sizeof(float);
+
+    // Align to XNN_EXTRA_BYTES bytes
+    while (static_unpacked_data_.size() % XNN_EXTRA_BYTES != 0) {
+      static_unpacked_data_.push_back(0);
+    }
+    const size_t tensor_offset = static_unpacked_data_.size();
+    static_unpacked_data_.resize(tensor_offset + context->tensors[t].bytes);
+
+    float* unpacked_data =
+        reinterpret_cast<float*>(static_unpacked_data_.data() + tensor_offset);
+    switch (input_tensor.type) {
+      case kTfLiteFloat16: {
+        const uint16_t* packed_data =
+            static_cast<const uint16_t*>(input_tensor.data.data);
+        for (size_t i = 0; i < tensor_elements; i++) {
+          unpacked_data[i] = fp16_ieee_to_fp32_value(packed_data[i]);
+        }
+        break;
+      }
+      default:
+        TF_LITE_KERNEL_LOG(context,
+                           "unexpected datatype (%s) in tensor %d in node %d",
+                           TfLiteTypeGetName(output_tensor.type),
+                           node->outputs->data[0], producer_index);
+        TfLiteIntArrayFree(nodes_to_delegate);
+        return nullptr;  // Hard error.
+    }
+
+    static_unpacked_data_map_[t] = tensor_offset;
+  }
+
+  // Add nodes that unpack static data consumed by delegated nodes.
+  // Note: this is done purely to avoid the overhead of running these nodes
+  // again in TFLite interpreter which would allocate memory for their outputs.
+  // We mark them as delegated, but the delegate would simply ignore these nodes
+  // as the static weights are already unpacked.
+  for (int node_index : static_unpack_nodes_) {
+    nodes_to_delegate->data[nodes_to_delegate->size++] = node_index;
+  }
+  std::sort(&nodes_to_delegate->data[0],
+            &nodes_to_delegate->data[nodes_to_delegate->size]);
+
 #ifdef XNNPACK_DELEGATE_TEST_MODE
   // In the test mode build (used by unit tests), XNNPACK delegate claims to
   // support all operators in the execution plan to disable fallback to the
@@ -1908,24 +2116,22 @@ TfLiteIntArray* GetOpsToReplace(TfLiteContext* context) {
   // not supported by the delegate, they will cause a failure in
   // ::tflite::Interpreter::ModifyGraphWithDelegate, to be caught in the unit
   // tests.
-  nodes_to_replace->size = execution_plan->size;
+  nodes_to_delegate->size = execution_plan->size;
   std::copy(&execution_plan->data[0],
             &execution_plan->data[execution_plan->size],
-            &nodes_to_replace->data[0]);
+            &nodes_to_delegate->data[0]);
 #endif
 
-  return nodes_to_replace;
+  return nodes_to_delegate;
 }
 
 void* SubgraphInit(TfLiteContext* context, const char* buffer, size_t length) {
   const TfLiteDelegateParams* params =
      reinterpret_cast<const TfLiteDelegateParams*>(buffer);
 
-  pthreadpool_t threadpool =
-      static_cast<::tflite::xnnpack::Delegate*>(params->delegate->data_)
-          ->threadpool();
-
-  return static_cast<void*>(Subgraph::Create(context, params, threadpool));
+  return static_cast<void*>(Subgraph::Create(
+      context, params,
+      static_cast<::tflite::xnnpack::Delegate*>(params->delegate->data_)));
 }
 
 TfLiteStatus SubgraphPrepare(TfLiteContext* context, TfLiteNode* node) {
@@ -1962,7 +2168,9 @@ const TfLiteRegistration kSubgraphRegistration = {
 };
 
 TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
-  TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
+  TfLiteIntArray* ops_to_replace =
+      static_cast<::tflite::xnnpack::Delegate*>(delegate->data_)
+          ->PrepareOpsToDelegate(context);
   const TfLiteStatus status = context->ReplaceNodeSubsetsWithDelegateKernels(
       context, kSubgraphRegistration, ops_to_replace, delegate);
   TfLiteIntArrayFree(ops_to_replace);
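For reference, the quasi-static unpacking added above amounts to widening each IEEE half-precision weight to a 32-bit float before XNNPACK ever sees the tensor. A minimal standalone sketch of that conversion, not part of the patch itself, assuming only the FP16 library's <fp16.h> header for fp16_ieee_to_fp32_value and using an illustrative function name and return type of my own choosing, could look like:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    #include <fp16.h>  // provides fp16_ieee_to_fp32_value()

    // Widen a buffer of IEEE FP16 values to FP32, element by element.
    std::vector<float> UnpackFp16Weights(const uint16_t* packed_data,
                                         size_t tensor_elements) {
      std::vector<float> unpacked(tensor_elements);
      for (size_t i = 0; i < tensor_elements; i++) {
        unpacked[i] = fp16_ieee_to_fp32_value(packed_data[i]);
      }
      return unpacked;
    }

In the delegate the destination is the persistent static_unpacked_data_ buffer rather than a fresh vector: each tensor is placed at an offset aligned to XNN_EXTRA_BYTES, and that offset is recorded in static_unpacked_data_map_ so later node setup can find the unpacked weights.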