Support int8 tensors with the Hexagon delegate. Validates using Softmax & AvgPool operators.
PiperOrigin-RevId: 302673336 Change-Id: I5d373b7d974c030b19203dafc2ab809b032fe327
This commit is contained in:
parent
99e10c109f
commit
c3e25fd2b3
@ -57,8 +57,11 @@ cc_library(
|
||||
"//tensorflow/lite:kernel_api",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/core/api",
|
||||
"//tensorflow/lite/delegates:utils",
|
||||
"//tensorflow/lite/experimental/delegates/hexagon/builders:op_builder",
|
||||
"//tensorflow/lite/experimental/delegates/hexagon/hexagon_nn:hexagon_nn_header",
|
||||
"//tensorflow/lite/kernels:kernel_util",
|
||||
"//tensorflow/lite/kernels/internal:optimized_base",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"@hexagon_nn//:hexagon_nn_ops",
|
||||
],
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_HEXAGON_BUILDERS_OP_BUILDER_H_
|
||||
#define TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_HEXAGON_BUILDERS_OP_BUILDER_H_
|
||||
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
@ -123,6 +124,20 @@ class OpBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
// Fills *min and *max for a quantized tensor, picking the integer limits from
// the tensor's element type.
//
// kTfLiteUInt8 tensors are forwarded to the templated overload with the
// numeric limits of uint8_t, kTfLiteInt8 tensors with the limits of int8_t;
// the status of that overload is returned as-is. Any other tensor type
// returns kTfLiteError without touching *min or *max.
TfLiteStatus ComputeMinAndMaxQuantValues(const TfLiteTensor& tensor,
                                         float* min, float* max) {
  switch (tensor.type) {
    case kTfLiteUInt8:
      return ComputeMinAndMaxQuantValues(
          tensor, min, max, std::numeric_limits<uint8_t>::min(),
          std::numeric_limits<uint8_t>::max());
    case kTfLiteInt8:
      return ComputeMinAndMaxQuantValues(
          tensor, min, max, std::numeric_limits<int8_t>::min(),
          std::numeric_limits<int8_t>::max());
    default:
      // Only 8-bit quantized tensors are handled here.
      return kTfLiteError;
  }
}
|
||||
|
||||
template <typename T>
|
||||
TfLiteStatus ComputeMinAndMaxQuantValues(const TfLiteTensor& tensor,
|
||||
float* min, float* max, T min_value,
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <limits>
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/hexagon_nn.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
|
||||
@ -34,9 +35,8 @@ TfLiteStatus Pool2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
|
||||
int tensor_id = inputs->data[0];
|
||||
const auto& data_tensor = context->tensors[tensor_id];
|
||||
AddInput(graph_builder_->GetHexagonTensorId(tensor_id));
|
||||
TF_LITE_ENSURE_STATUS(ComputeMinAndMaxQuantValues(
|
||||
data_tensor, &data_min_, &data_max_, std::numeric_limits<uint8_t>::min(),
|
||||
std::numeric_limits<uint8_t>::max()));
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
ComputeMinAndMaxQuantValues(data_tensor, &data_min_, &data_max_));
|
||||
auto* data_min_const = graph_builder_->AddConstNodeWithData(
|
||||
quant_bound_shape.data(), (char*)&data_min_, sizeof(data_min_));
|
||||
auto* data_max_const = graph_builder_->AddConstNodeWithData(
|
||||
@ -89,9 +89,7 @@ TfLiteStatus Pool2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
|
||||
|
||||
// Output min/max for requantization.
|
||||
TF_LITE_ENSURE_STATUS(ComputeMinAndMaxQuantValues(
|
||||
context->tensors[outputs->data[0]], &output_min_, &output_max_,
|
||||
std::numeric_limits<uint8_t>::min(),
|
||||
std::numeric_limits<uint8_t>::max()));
|
||||
context->tensors[outputs->data[0]], &output_min_, &output_max_));
|
||||
auto* output_min_const = graph_builder_->AddConstNodeWithData(
|
||||
quant_bound_shape.data(), (char*)&output_min_, sizeof(output_min_));
|
||||
auto* output_max_const = graph_builder_->AddConstNodeWithData(
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <limits>
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/hexagon_nn.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
|
||||
@ -36,9 +37,7 @@ TfLiteStatus SoftmaxOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
|
||||
const auto& input_tensor = context->tensors[tensor_id];
|
||||
AddInput(graph_builder_->GetHexagonTensorId(tensor_id));
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
ComputeMinAndMaxQuantValues(input_tensor, &input_min_, &input_max_,
|
||||
std::numeric_limits<uint8_t>::min(),
|
||||
std::numeric_limits<uint8_t>::max()));
|
||||
ComputeMinAndMaxQuantValues(input_tensor, &input_min_, &input_max_));
|
||||
auto* input_min_const = graph_builder_->AddConstNodeWithData(
|
||||
quant_bound_shape.data(), (char*)&input_min_, sizeof(input_min_));
|
||||
auto* input_max_const = graph_builder_->AddConstNodeWithData(
|
||||
|
@ -34,6 +34,7 @@ hexagon_op_tests(
|
||||
"pool_test.cc",
|
||||
"reduce_test.cc",
|
||||
"resize_bilinear_test.cc",
|
||||
"softmax_test.cc",
|
||||
"space_to_depth_test.cc",
|
||||
"split_test.cc",
|
||||
"transpose_conv_test.cc",
|
||||
|
@ -34,13 +34,15 @@ class AveragePoolingOpModel : public SingleOpModelWithHexagon {
|
||||
BuildInterpreter({GetShape(input_)});
|
||||
}
|
||||
|
||||
void SetInput(std::initializer_list<float> data) {
|
||||
QuantizeAndPopulate<uint8_t>(input_, data);
|
||||
template <typename T>
|
||||
void SetInput(const std::vector<float>& data) {
|
||||
QuantizeAndPopulate<T>(input_, data);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<float> GetDequantizedOutput() {
|
||||
return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
|
||||
GetScale(output_), GetZeroPoint(output_));
|
||||
return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_),
|
||||
GetZeroPoint(output_));
|
||||
}
|
||||
|
||||
private:
|
||||
@ -53,7 +55,7 @@ TEST(QuantizedPoolingOpTest, AveragePool) {
|
||||
/*input=*/{TensorType_UINT8, {1, 16, 8, 1}, 0, 10},
|
||||
/*filter_width=*/8, /*filter_height=*/8,
|
||||
/*output=*/{TensorType_UINT8, {}, 0, 10});
|
||||
m.SetInput({
|
||||
m.SetInput<uint8_t>({
|
||||
0, 6, 2, 4, 0, 6, 2, 4, //
|
||||
3, 2, 10, 7, 3, 2, 10, 7, //
|
||||
0, 6, 2, 4, 0, 6, 2, 4, //
|
||||
@ -73,9 +75,42 @@ TEST(QuantizedPoolingOpTest, AveragePool) {
|
||||
});
|
||||
m.ApplyDelegateAndInvoke();
|
||||
|
||||
EXPECT_THAT(m.GetDequantizedOutput(),
|
||||
EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
|
||||
ElementsAreArray(ArrayFloatNear(
|
||||
{4.58824, 4.58824, 4.90196, 4.58824, 4.27451})));
|
||||
}
|
||||
|
||||
// Int8 average pooling: the output produced after applying the delegate must
// match the output produced by a plain Invoke() on the same model and input.
TEST(QuantizedPoolingOpTest, AveragePool_Int8) {
  // 16 rows x 8 columns of input values, fed as a flat float list.
  const std::vector<float> input_values = {
      0, 6, 2,  4, 0, 6, 2,  4,  //
      3, 2, 10, 7, 3, 2, 10, 7,  //
      0, 6, 2,  4, 0, 6, 2,  4,  //
      3, 2, 10, 7, 3, 2, 10, 7,  //
      0, 6, 2,  4, 0, 6, 2,  4,  //
      3, 2, 10, 7, 3, 2, 10, 7,  //
      3, 2, 10, 7, 3, 2, 10, 7,  //
      3, 2, 10, 7, 3, 2, 10, 7,  //
      0, 6, 2,  4, 0, 6, 2,  4,  //
      3, 2, 10, 7, 3, 2, 10, 7,  //
      3, 2, 10, 7, 3, 2, 10, 7,  //
      3, 2, 10, 7, 3, 2, 10, 7,  //
      0, 6, 2,  4, 0, 6, 2,  4,  //
      0, 6, 2,  4, 0, 6, 2,  4,  //
      0, 6, 2,  4, 0, 6, 2,  4,  //
      3, 2, 10, 7, 3, 2, 10, 7,  //
  };

  AveragePoolingOpModel model(
      /*input=*/{TensorType_INT8, {1, 16, 8, 1}, 0, 10},
      /*filter_width=*/8, /*filter_height=*/8,
      /*output=*/{TensorType_INT8, {}, 0, 10});
  model.SetInput<int8_t>(input_values);

  // Reference data: invoke once before the delegate is applied.
  model.Invoke();
  const std::vector<float> reference_output =
      model.GetDequantizedOutput<int8_t>();

  // Same model and input, now run with the delegate applied.
  model.ApplyDelegateAndInvoke();
  const std::vector<float> delegate_output =
      model.GetDequantizedOutput<int8_t>();

  EXPECT_THAT(delegate_output,
              ElementsAreArray(ArrayFloatNear(reference_output)));
}
|
||||
|
||||
} // namespace tflite
|
||||
|
@ -0,0 +1,128 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/experimental/delegates/hexagon/builders/tests/hexagon_delegate_op_model.h"
|
||||
|
||||
namespace tflite {
|
||||
using testing::ElementsAreArray;
|
||||
|
||||
const float kTolerance = 2 * (1. / 256);
|
||||
|
||||
class SoftmaxOpModel : public SingleOpModelWithHexagon {
|
||||
public:
|
||||
SoftmaxOpModel(float softmax_beta, const TensorData& input) {
|
||||
input_ = AddInput(input);
|
||||
if (input.type == TensorType_UINT8) {
|
||||
output_ = AddOutput({input.type, {}, 0, 0, 1. / 256});
|
||||
} else if (input.type == TensorType_INT8) {
|
||||
output_ = AddOutput({TensorType_INT8, {}, 0, 0, 1. / 256, -128});
|
||||
}
|
||||
SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions,
|
||||
CreateSoftmaxOptions(builder_, softmax_beta).Union());
|
||||
BuildInterpreter({GetShape(input_)});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void SetInput(const std::vector<float>& data) {
|
||||
QuantizeAndPopulate<T>(input_, data);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<float> GetDequantizedOutput() {
|
||||
return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_),
|
||||
GetZeroPoint(output_));
|
||||
}
|
||||
|
||||
protected:
|
||||
int input_;
|
||||
int output_;
|
||||
};
|
||||
|
||||
// Uint8 softmax (beta = 0.1) on a {1, 2, 1, 4} input, run with the delegate
// applied and compared against precomputed values within kTolerance.
TEST(SoftmaxOpModel, Softmax4DUint8) {
  const std::vector<float> input_values = {
      0, -6, 2,  4,  // depth = 0
      3, -2, 10, 1,  // depth = 1
  };
  const std::vector<float> expected_output = {
      .23463, .12877, .28658, .35003,  //
      .22528, .13664, .45365, .18443,  //
  };

  SoftmaxOpModel model(0.1,
                       /*input=*/{TensorType_UINT8, {1, 2, 1, 4}, -10, 10});
  model.SetInput<uint8_t>(input_values);
  model.ApplyDelegateAndInvoke();

  EXPECT_THAT(model.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear(expected_output, kTolerance)));
}
|
||||
|
||||
// Uint8 softmax (beta = 0.1) on a {4, 1, 1, 2} input, i.e. four batches of two
// values each, run with the delegate applied and compared against precomputed
// values within kTolerance.
TEST(SoftmaxOpModel, Softmax4DUint8_MultipleBatch) {
  const std::vector<float> input_values = {
      0,  -6,  //
      2,  4,   //
      3,  -2,  //
      10, 1,   //
  };
  const std::vector<float> expected_output = {
      0.645656, 0.354344,  //
      0.450166, 0.549834,  //
      0.622459, 0.377541,  //
      0.710949, 0.28905,   //
  };

  SoftmaxOpModel model(0.1,
                       /*input=*/{TensorType_UINT8, {4, 1, 1, 2}, -10, 10});
  model.SetInput<uint8_t>(input_values);
  model.ApplyDelegateAndInvoke();

  EXPECT_THAT(model.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear(expected_output, kTolerance)));
}
|
||||
|
||||
// Int8 softmax (beta = 0.1) on a {1, 2, 1, 4} input, run with the delegate
// applied and compared against precomputed values within kTolerance.
TEST(SoftmaxOpModel, Softmax4DInt8) {
  const std::vector<float> input_values = {
      0, -6, 2,  4,  // depth = 0
      3, -2, 10, 1,  // depth = 1
  };
  const std::vector<float> expected_output = {
      .23463, .12877, .28658, .35003,  //
      .22528, .13664, .45365, .18443,  //
  };

  SoftmaxOpModel model(0.1,
                       /*input=*/{TensorType_INT8, {1, 2, 1, 4}, -10, 10});
  model.SetInput<int8_t>(input_values);
  model.ApplyDelegateAndInvoke();

  EXPECT_THAT(model.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(ArrayFloatNear(expected_output, kTolerance)));
}
|
||||
|
||||
// Int8 softmax (beta = 0.1) on a {4, 1, 1, 2} input, i.e. four batches of two
// values each, run with the delegate applied and compared against precomputed
// values within kTolerance.
TEST(SoftmaxOpModel, Softmax4DInt8_MultipleBatch) {
  const std::vector<float> input_values = {
      0,  -6,  //
      2,  4,   //
      3,  -2,  //
      10, 1,   //
  };
  const std::vector<float> expected_output = {
      0.645656, 0.354344,  //
      0.450166, 0.549834,  //
      0.622459, 0.377541,  //
      0.710949, 0.28905,   //
  };

  SoftmaxOpModel model(0.1,
                       /*input=*/{TensorType_INT8, {4, 1, 1, 2}, -10, 10});
  model.SetInput<int8_t>(input_values);
  model.ApplyDelegateAndInvoke();

  EXPECT_THAT(model.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(ArrayFloatNear(expected_output, kTolerance)));
}
|
||||
|
||||
} // namespace tflite
|
@ -20,12 +20,21 @@ limitations under the License.
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/context_util.h"
|
||||
#include "tensorflow/lite/delegates/utils.h"
|
||||
#include "tensorflow/lite/experimental/delegates/hexagon/hexagon_implementation.h"
|
||||
#include "tensorflow/lite/experimental/delegates/hexagon/utils.h"
|
||||
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
namespace {
|
||||
|
||||
// Used to convert int8 <-> uint8.
|
||||
constexpr int kSameScaleEffectiveMultiplier = 1 << 30;
|
||||
constexpr int kSameScaleEffectiveShift = 1;
|
||||
constexpr int kInt8Uint8ZeroPointDiff = 128;
|
||||
|
||||
inline const char* StateToString(
|
||||
HexagonDelegateKernel::HexagonKernelState state) {
|
||||
switch (state) {
|
||||
@ -126,13 +135,34 @@ TfLiteStatus HexagonDelegateKernel::Invoke(TfLiteContext* context,
|
||||
}
|
||||
// Allocate inputs.
|
||||
std::vector<hexagon_nn_tensordef> input_tensors;
|
||||
for (auto tensor_index : TfLiteIntArrayView(node->inputs)) {
|
||||
for (int input_idx = 0; input_idx < node->inputs->size; ++input_idx) {
|
||||
const auto tensor_index = node->inputs->data[input_idx];
|
||||
if (tensor_index == kTfLiteOptionalTensor) {
|
||||
continue;
|
||||
}
|
||||
TfLiteTensor* tensor = &context->tensors[tensor_index];
|
||||
// Const tensors should be added as const nodes during graph construction.
|
||||
// Const tensors should have been handled at delegation time.
|
||||
if (tensor->allocation_type != kTfLiteMmapRo) {
|
||||
char* data_ptr = tensor->data.raw;
|
||||
if (tensor->type == kTfLiteInt8) {
|
||||
// If input is int8, we first re-quantize it to uint8 for Hexagon.
|
||||
if (int8_to_uint8_tensors_.size() <= input_idx ||
|
||||
!int8_to_uint8_tensors_[input_idx]) {
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"Found int8 input %d with no uint8 version",
|
||||
tensor_index);
|
||||
return kTfLiteError;
|
||||
}
|
||||
TfLiteTensor* uint8_tensor = int8_to_uint8_tensors_[input_idx];
|
||||
optimized_ops::Requantize(
|
||||
tensor->data.int8, NumElements(tensor),
|
||||
kSameScaleEffectiveMultiplier, kSameScaleEffectiveShift,
|
||||
tensor->params.zero_point,
|
||||
tensor->params.zero_point + kInt8Uint8ZeroPointDiff,
|
||||
uint8_tensor->data.uint8);
|
||||
data_ptr = uint8_tensor->data.raw;
|
||||
}
|
||||
|
||||
if (tensor->dims->size > 4) {
|
||||
ReportError(context, HexagonKernelState::INPUT_RANK_NOT_SUPPORTED,
|
||||
"Only up to 4d tensor are supported.");
|
||||
@ -140,7 +170,7 @@ TfLiteStatus HexagonDelegateKernel::Invoke(TfLiteContext* context,
|
||||
}
|
||||
input_tensors.emplace_back();
|
||||
auto& input_tensor = input_tensors.back();
|
||||
input_tensor.data = reinterpret_cast<unsigned char*>(tensor->data.raw);
|
||||
input_tensor.data = reinterpret_cast<unsigned char*>(data_ptr);
|
||||
input_tensor.dataLen = tensor->bytes;
|
||||
input_tensor.data_valid_len = tensor->bytes;
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
@ -182,6 +212,20 @@ TfLiteStatus HexagonDelegateKernel::Invoke(TfLiteContext* context,
|
||||
"Failed to execute graph.");
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
// Requantize uint8->int8 for eligible output tensors.
|
||||
for (auto tensor_index : TfLiteIntArrayView(node->outputs)) {
|
||||
TfLiteTensor* tensor = &context->tensors[tensor_index];
|
||||
if (tensor->allocation_type != kTfLiteMmapRo &&
|
||||
tensor->type == kTfLiteInt8) {
|
||||
optimized_ops::Requantize(
|
||||
tensor->data.uint8, NumElements(tensor),
|
||||
kSameScaleEffectiveMultiplier, kSameScaleEffectiveShift,
|
||||
tensor->params.zero_point + kInt8Uint8ZeroPointDiff,
|
||||
tensor->params.zero_point, tensor->data.int8);
|
||||
}
|
||||
}
|
||||
|
||||
if (params_.print_graph_profile) {
|
||||
PrintPerformanceData(reinterpret_cast<Profiler*>(context->profiler));
|
||||
}
|
||||
@ -222,6 +266,35 @@ TfLiteStatus HexagonDelegateKernel::Prepare(TfLiteContext* context,
|
||||
}
|
||||
}
|
||||
|
||||
// Assign temporary tensors for any input int8 tensors.
|
||||
std::vector<int> temporary_tensors;
|
||||
int8_to_uint8_tensors_.clear();
|
||||
int8_to_uint8_tensors_.reserve(node->inputs->size);
|
||||
for (auto tensor_index : TfLiteIntArrayView(node->inputs)) {
|
||||
TfLiteTensor* tensor = &context->tensors[tensor_index];
|
||||
// For every int8 tensor, we need to create a new temporary uint8 tensor.
|
||||
if (tensor->allocation_type != kTfLiteMmapRo &&
|
||||
tensor->type == kTfLiteInt8) {
|
||||
TfLiteTensor* uint8_tensor;
|
||||
int uint8_tensor_index;
|
||||
TF_LITE_ENSURE_STATUS(delegates::CreateNewTensorWithDifferentType(
|
||||
context, tensor_index, kTfLiteUInt8, &uint8_tensor,
|
||||
&uint8_tensor_index));
|
||||
int8_to_uint8_tensors_.push_back(uint8_tensor);
|
||||
temporary_tensors.push_back(uint8_tensor_index);
|
||||
} else {
|
||||
int8_to_uint8_tensors_.push_back(nullptr);
|
||||
}
|
||||
}
|
||||
if (!temporary_tensors.empty()) {
|
||||
// This ensures the runtime allocates memory for every required temporary
|
||||
// tensor.
|
||||
node->temporaries = TfLiteIntArrayCreate(temporary_tensors.size());
|
||||
for (int i = 0; i < temporary_tensors.size(); ++i) {
|
||||
node->temporaries->data[i] = temporary_tensors[i];
|
||||
}
|
||||
}
|
||||
|
||||
if (params_.print_graph_debug) {
|
||||
PrintDebuggingGraph();
|
||||
}
|
||||
|
@ -95,6 +95,12 @@ class HexagonDelegateKernel {
|
||||
// Indices of nodes in the delegated TfLite subgraph.
|
||||
std::vector<int> nodes_;
|
||||
::TfLiteHexagonDelegateOptions params_;
|
||||
|
||||
// Used to support int8 TFLite *input* tensors.
|
||||
// This vector, for every node-input, contains:
|
||||
// 1. Pointer to Uint8 version if tensor is non-constant & type is Int8.
|
||||
// 2. nullptr otherwise.
|
||||
std::vector<TfLiteTensor*> int8_to_uint8_tensors_;
|
||||
};
|
||||
|
||||
} // namespace tflite
|
||||
|
@ -59,6 +59,19 @@ TfLiteStatus Get4DShape(unsigned int* batch_size, unsigned int* height_size,
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// We maintain an op-version whitelist here to ensure we don't accept unintended
|
||||
// ops.
|
||||
bool CheckOpVersion(const TfLiteRegistration* registration) {
|
||||
switch (registration->builtin_code) {
|
||||
case kTfLiteBuiltinAveragePool2d:
|
||||
case kTfLiteBuiltinDepthwiseConv2d:
|
||||
case kTfLiteBuiltinSoftmax:
|
||||
return registration->version <= 2;
|
||||
default:
|
||||
return registration->version == 1;
|
||||
}
|
||||
}
|
||||
|
||||
bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration,
|
||||
const TfLiteNode* node, TfLiteContext* context) {
|
||||
// Ensure all inputs & outputs have dim <= 4.
|
||||
@ -74,15 +87,7 @@ bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration,
|
||||
if (tensor.dims->size > 4) return false;
|
||||
}
|
||||
|
||||
// Most hexagon kernels are not compatible with op versions > 1.
|
||||
// We maintain a 'whitelist' here to ensure we don't accept unintended nodes.
|
||||
if (registration->version > 1) {
|
||||
if (registration->builtin_code == kTfLiteBuiltinDepthwiseConv2d &&
|
||||
registration->version == 2) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if (!CheckOpVersion(registration)) return false;
|
||||
|
||||
switch (registration->builtin_code) {
|
||||
case kTfLiteBuiltinAdd: {
|
||||
@ -154,8 +159,9 @@ bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration,
|
||||
return pool_params->activation == kTfLiteActNone;
|
||||
}
|
||||
case kTfLiteBuiltinAveragePool2d: {
|
||||
if (!InputsWithCorrectTypes(node, context, {kTfLiteUInt8})) return false;
|
||||
// AvgPool works fine for filter dim <=7.
|
||||
if (!InputsWithCorrectTypes(node, context, {kTfLiteUInt8}) &&
|
||||
!InputsWithCorrectTypes(node, context, {kTfLiteInt8}))
|
||||
return false;
|
||||
const TfLitePoolParams* pool_params =
|
||||
reinterpret_cast<const TfLitePoolParams*>(node->builtin_data);
|
||||
return (node->inputs->size == 1 &&
|
||||
@ -220,7 +226,10 @@ bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration,
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
case kTfLiteBuiltinSoftmax:
|
||||
case kTfLiteBuiltinSoftmax: {
|
||||
return (InputsWithCorrectTypes(node, context, {kTfLiteUInt8}) ||
|
||||
InputsWithCorrectTypes(node, context, {kTfLiteInt8}));
|
||||
}
|
||||
case kTfLiteBuiltinRelu:
|
||||
case kTfLiteBuiltinRelu6:
|
||||
case kTfLiteBuiltinTanh:
|
||||
|
Loading…
x
Reference in New Issue
Block a user