Support ELU operator in XNNPACK delegate
PiperOrigin-RevId: 345181968 Change-Id: I4cd89a2871bb4087a9c1775f9d6e8f4c22aee4a3
This commit is contained in:
parent
70e4cec276
commit
3b8bbb8b11
@ -443,6 +443,21 @@ cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
# Test target built from elu_test.cc (the ELU tests for this delegate).
cc_test(
|
||||
name = "elu_test",
|
||||
srcs = ["elu_test.cc"],
|
||||
# Emscripten builds use EMSCRIPTEN_LINKOPTS; every other configuration adds
# no extra link options.
linkopts = select({
|
||||
"//tensorflow:emscripten": EMSCRIPTEN_LINKOPTS,
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
deps = [
|
||||
":test_main",
|
||||
":unary_elementwise_tester",
|
||||
":xnnpack_delegate_test_mode",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "fully_connected_test",
|
||||
srcs = ["fully_connected_test.cc"],
|
||||
|
@ -173,6 +173,10 @@ Below is the list of current operators and limitations:
|
||||
* Fused `NONE`, `RELU`, `RELU_N1_TO_1`, and `RELU6` activations are supported,
|
||||
but fused `TANH` and `SIGN_BIT` activations are not.
|
||||
|
||||
### `ELU`
|
||||
|
||||
* Inputs and outputs must be in 32-bit floating-point format.
|
||||
|
||||
### `FULLY_CONNECTED`
|
||||
|
||||
* Inputs and outputs must be in 32-bit floating-point format.
|
||||
|
120
tensorflow/lite/delegates/xnnpack/elu_test.cc
Normal file
120
tensorflow/lite/delegates/xnnpack/elu_test.cc
Normal file
@ -0,0 +1,120 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/delegates/xnnpack/unary_elementwise_tester.h"
|
||||
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace xnnpack {
|
||||
|
||||
// Runs the unary elementwise tester for ELU on a 4D shape
// {batch, height, width, channels} whose extents are chosen at random.
TEST(Elu, 4D) {
  // The delegate is created without an options struct (nullptr) and is
  // released through TfLiteXNNPackDelegateDelete when the pointer goes out of
  // scope.
  using DelegatePtr =
      std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>;
  DelegatePtr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
                               TfLiteXNNPackDelegateDelete);

  // Every extent is drawn uniformly from the closed range [2, 5] using a
  // Mersenne Twister seeded from std::random_device.
  std::random_device random_device;
  std::mt19937 rng(random_device());
  std::uniform_int_distribution<int32_t> extent_distribution(2, 5);
  const auto next_extent = [&rng, &extent_distribution]() {
    return extent_distribution(rng);
  };
  const int32_t batch = next_extent();
  const int32_t height = next_extent();
  const int32_t width = next_extent();
  const int32_t channels = next_extent();

  UnaryElementwiseTester()
      .Shape({batch, height, width, channels})
      .Test(BuiltinOperator_ELU, xnnpack_delegate.get());
}
|
||||
|
||||
// Runs the unary elementwise tester for ELU on a 3D shape
// {batch, width, channels} whose extents are chosen at random.
TEST(Elu, 3D) {
  // The delegate is created without an options struct (nullptr) and is
  // released through TfLiteXNNPackDelegateDelete when the pointer goes out of
  // scope.
  using DelegatePtr =
      std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>;
  DelegatePtr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
                               TfLiteXNNPackDelegateDelete);

  // Every extent is drawn uniformly from the closed range [2, 5] using a
  // Mersenne Twister seeded from std::random_device.
  std::random_device random_device;
  std::mt19937 rng(random_device());
  std::uniform_int_distribution<int32_t> extent_distribution(2, 5);
  const auto next_extent = [&rng, &extent_distribution]() {
    return extent_distribution(rng);
  };
  const int32_t batch = next_extent();
  const int32_t width = next_extent();
  const int32_t channels = next_extent();

  UnaryElementwiseTester()
      .Shape({batch, width, channels})
      .Test(BuiltinOperator_ELU, xnnpack_delegate.get());
}
|
||||
|
||||
// Runs the unary elementwise tester for ELU on a 2D shape {batch, channels}
// whose extents are chosen at random.
TEST(Elu, 2D) {
  // The delegate is created without an options struct (nullptr) and is
  // released through TfLiteXNNPackDelegateDelete when the pointer goes out of
  // scope.
  using DelegatePtr =
      std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>;
  DelegatePtr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
                               TfLiteXNNPackDelegateDelete);

  // Every extent is drawn uniformly from the closed range [2, 5] using a
  // Mersenne Twister seeded from std::random_device.
  std::random_device random_device;
  std::mt19937 rng(random_device());
  std::uniform_int_distribution<int32_t> extent_distribution(2, 5);
  const auto next_extent = [&rng, &extent_distribution]() {
    return extent_distribution(rng);
  };
  const int32_t batch = next_extent();
  const int32_t channels = next_extent();

  UnaryElementwiseTester()
      .Shape({batch, channels})
      .Test(BuiltinOperator_ELU, xnnpack_delegate.get());
}
|
||||
|
||||
// Runs the unary elementwise tester for ELU on a 1D shape {batch} whose
// extent is chosen at random.
TEST(Elu, 1D) {
  // The delegate is created without an options struct (nullptr) and is
  // released through TfLiteXNNPackDelegateDelete when the pointer goes out of
  // scope.
  using DelegatePtr =
      std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>;
  DelegatePtr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
                               TfLiteXNNPackDelegateDelete);

  // The extent is drawn uniformly from the closed range [2, 5] using a
  // Mersenne Twister seeded from std::random_device.
  std::random_device random_device;
  std::mt19937 rng(random_device());
  std::uniform_int_distribution<int32_t> extent_distribution(2, 5);
  const auto next_extent = [&rng, &extent_distribution]() {
    return extent_distribution(rng);
  };
  const int32_t batch = next_extent();

  UnaryElementwiseTester()
      .Shape({batch})
      .Test(BuiltinOperator_ELU, xnnpack_delegate.get());
}
|
||||
|
||||
// Same as the 4D case, but the delegate is created from an options struct
// whose num_threads field is set to 2.
TEST(Elu, MultiThreading) {
  // Start from the default options and override only the thread count.
  TfLiteXNNPackDelegateOptions delegate_options =
      TfLiteXNNPackDelegateOptionsDefault();
  delegate_options.num_threads = 2;

  // The delegate is released through TfLiteXNNPackDelegateDelete when the
  // pointer goes out of scope.
  using DelegatePtr =
      std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>;
  DelegatePtr xnnpack_delegate(TfLiteXNNPackDelegateCreate(&delegate_options),
                               TfLiteXNNPackDelegateDelete);

  // Every extent is drawn uniformly from the closed range [2, 5] using a
  // Mersenne Twister seeded from std::random_device.
  std::random_device random_device;
  std::mt19937 rng(random_device());
  std::uniform_int_distribution<int32_t> extent_distribution(2, 5);
  const auto next_extent = [&rng, &extent_distribution]() {
    return extent_distribution(rng);
  };
  const int32_t batch = next_extent();
  const int32_t height = next_extent();
  const int32_t width = next_extent();
  const int32_t channels = next_extent();

  UnaryElementwiseTester()
      .Shape({batch, height, width, channels})
      .Test(BuiltinOperator_ELU, xnnpack_delegate.get());
}
|
||||
|
||||
} // namespace xnnpack
|
||||
} // namespace tflite
|
@ -858,6 +858,9 @@ class Subgraph {
|
||||
return VisitDivNode(subgraph, logging_context, node_index, node,
|
||||
context->tensors, div_params, xnnpack_tensors);
|
||||
}
|
||||
case kTfLiteBuiltinElu:
|
||||
return VisitEluNode(subgraph, logging_context, node_index, node,
|
||||
context->tensors, xnnpack_tensors);
|
||||
case kTfLiteBuiltinFullyConnected: {
|
||||
// FullyConnected with sparse weight has version 8, which cannot be
|
||||
// delegated to XNNPack.
|
||||
@ -1496,6 +1499,41 @@ class Subgraph {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// Checks an ELU node and, when `subgraph` is non-null, defines the matching
// ELU operation in that subgraph.
//
// The node is passed to CheckNumInputsAndOutputs with counts (1, 1), and its
// input and output tensors are each passed to CheckTensorFloatType and
// CheckTensorNonDynamicAllocation; every check is wrapped in
// TF_LITE_ENSURE_STATUS. When `subgraph` is null no operation is defined and
// kTfLiteOk is returned after the checks.
//
// Returns kTfLiteError (after logging through `logging_context`) if
// xnn_define_elu does not report xnn_status_success, and kTfLiteOk otherwise.
static TfLiteStatus VisitEluNode(
    xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
    TfLiteNode* node, const TfLiteTensor* tensors,
    const std::vector<uint32_t>& xnnpack_tensors) {
  TF_LITE_ENSURE_STATUS(
      CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));

  // The index is read only after the input/output count check above.
  const int input_index = node->inputs->data[0];
  const TfLiteTensor& input_tensor = tensors[input_index];
  TF_LITE_ENSURE_STATUS(CheckTensorFloatType(logging_context, input_tensor,
                                             input_index, node_index));
  TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
      logging_context, input_tensor, input_index, node_index));

  const int output_index = node->outputs->data[0];
  const TfLiteTensor& output_tensor = tensors[output_index];
  TF_LITE_ENSURE_STATUS(CheckTensorFloatType(logging_context, output_tensor,
                                             output_index, node_index));
  TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
      logging_context, output_tensor, output_index, node_index));

  // Without a subgraph there is nothing to define; the checks are the result.
  if (subgraph == nullptr) {
    return kTfLiteOk;
  }

  // alpha is fixed at 1.0f; tensor indices are mapped to XNNPACK value ids
  // through `xnnpack_tensors`.
  const xnn_status define_status = xnn_define_elu(
      subgraph, /*alpha=*/1.0f,
      /*input_id=*/xnnpack_tensors[input_index],
      /*output_id=*/xnnpack_tensors[output_index],
      /*flags=*/0);
  if (define_status == xnn_status_success) {
    return kTfLiteOk;
  }

  TF_LITE_KERNEL_LOG(logging_context, "failed to delegate ELU node #%d",
                     node_index);
  return kTfLiteError;
}
|
||||
|
||||
static TfLiteStatus VisitFullyConnectedNode(
|
||||
xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
|
||||
TfLiteNode* node, const TfLiteTensor* tensors,
|
||||
|
@ -21,8 +21,8 @@ include(FetchContent)
|
||||
|
||||
OverridableFetchContent_Declare(
|
||||
xnnpack
|
||||
GIT_REPOSITORY https://github.com/google/xnnpack
|
||||
GIT_TAG 0a9c1200ccb49bba0170a46a62044b13714f39a3
|
||||
GIT_REPOSITORY https://github.com/google/XNNPACK
|
||||
GIT_TAG 1a803b6e9b48aad978b33d648b7db00ffc300f60
|
||||
GIT_PROGRESS TRUE
|
||||
PREFIX "${CMAKE_BINARY_DIR}"
|
||||
SOURCE_DIR "${CMAKE_BINARY_DIR}/xnnpack"
|
||||
|
@ -135,11 +135,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
|
||||
# and update the sha256 with the result.
|
||||
tf_http_archive(
|
||||
name = "XNNPACK",
|
||||
sha256 = "eb087959b684d2d3965f8914075032e3995e4726ac8ce9c09a367863ff184b99",
|
||||
strip_prefix = "XNNPACK-0a9c1200ccb49bba0170a46a62044b13714f39a3",
|
||||
sha256 = "b6badf61153584d28ee40c8f8c553b79a1ee4642008c28d953ffaea47e308511",
|
||||
strip_prefix = "XNNPACK-1a803b6e9b48aad978b33d648b7db00ffc300f60",
|
||||
urls = [
|
||||
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/XNNPACK/archive/0a9c1200ccb49bba0170a46a62044b13714f39a3.zip",
|
||||
"https://github.com/google/XNNPACK/archive/0a9c1200ccb49bba0170a46a62044b13714f39a3.zip",
|
||||
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/XNNPACK/archive/1a803b6e9b48aad978b33d648b7db00ffc300f60.zip",
|
||||
"https://github.com/google/XNNPACK/archive/1a803b6e9b48aad978b33d648b7db00ffc300f60.zip",
|
||||
],
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user