From c3e25fd2b3b60ead7ce4695ee483bc41fad9fde9 Mon Sep 17 00:00:00 2001 From: Sachin Joglekar Date: Tue, 24 Mar 2020 08:46:10 -0700 Subject: [PATCH] Support int8 tensors with the Hexagon delegate. Validates using Softmax & AvgPool operators. PiperOrigin-RevId: 302673336 Change-Id: I5d373b7d974c030b19203dafc2ab809b032fe327 --- .../lite/experimental/delegates/hexagon/BUILD | 3 + .../delegates/hexagon/builders/op_builder.h | 15 ++ .../hexagon/builders/pool_2d_builder.cc | 10 +- .../hexagon/builders/softmax_builder.cc | 5 +- .../delegates/hexagon/builders/tests/BUILD | 1 + .../hexagon/builders/tests/pool_test.cc | 47 ++++++- .../hexagon/builders/tests/softmax_test.cc | 128 ++++++++++++++++++ .../hexagon/hexagon_delegate_kernel.cc | 79 ++++++++++- .../hexagon/hexagon_delegate_kernel.h | 6 + .../experimental/delegates/hexagon/utils.cc | 33 +++-- 10 files changed, 297 insertions(+), 30 deletions(-) create mode 100644 tensorflow/lite/experimental/delegates/hexagon/builders/tests/softmax_test.cc diff --git a/tensorflow/lite/experimental/delegates/hexagon/BUILD b/tensorflow/lite/experimental/delegates/hexagon/BUILD index 86c0f8cd39b..1028fd53582 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/BUILD +++ b/tensorflow/lite/experimental/delegates/hexagon/BUILD @@ -57,8 +57,11 @@ cc_library( "//tensorflow/lite:kernel_api", "//tensorflow/lite/c:common", "//tensorflow/lite/core/api", + "//tensorflow/lite/delegates:utils", "//tensorflow/lite/experimental/delegates/hexagon/builders:op_builder", "//tensorflow/lite/experimental/delegates/hexagon/hexagon_nn:hexagon_nn_header", + "//tensorflow/lite/kernels:kernel_util", + "//tensorflow/lite/kernels/internal:optimized_base", "//tensorflow/lite/schema:schema_fbs", "@hexagon_nn//:hexagon_nn_ops", ], diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.h b/tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.h index 30a92f1bc19..7c39d013d59 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.h +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_HEXAGON_BUILDERS_OP_BUILDER_H_ #define TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_HEXAGON_BUILDERS_OP_BUILDER_H_ +#include #include #include #include @@ -123,6 +124,20 @@ class OpBuilder { } } + TfLiteStatus ComputeMinAndMaxQuantValues(const TfLiteTensor& tensor, + float* min, float* max) { + if (tensor.type == kTfLiteUInt8) { + return ComputeMinAndMaxQuantValues(tensor, min, max, + std::numeric_limits::min(), + std::numeric_limits::max()); + } else if (tensor.type == kTfLiteInt8) { + return ComputeMinAndMaxQuantValues(tensor, min, max, + std::numeric_limits::min(), + std::numeric_limits::max()); + } + return kTfLiteError; + } + template TfLiteStatus ComputeMinAndMaxQuantValues(const TfLiteTensor& tensor, float* min, float* max, T min_value, diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/pool_2d_builder.cc b/tensorflow/lite/experimental/delegates/hexagon/builders/pool_2d_builder.cc index 96ce0bbb900..d7ab6614714 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/builders/pool_2d_builder.cc +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/pool_2d_builder.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/hexagon_nn.h" #include "tensorflow/lite/kernels/kernel_util.h" @@ -34,9 +35,8 @@ TfLiteStatus Pool2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs, int tensor_id = inputs->data[0]; const auto& data_tensor = context->tensors[tensor_id]; AddInput(graph_builder_->GetHexagonTensorId(tensor_id)); - TF_LITE_ENSURE_STATUS(ComputeMinAndMaxQuantValues( - data_tensor, &data_min_, &data_max_, std::numeric_limits::min(), - std::numeric_limits::max())); + TF_LITE_ENSURE_STATUS( + ComputeMinAndMaxQuantValues(data_tensor, &data_min_, &data_max_)); auto* data_min_const = graph_builder_->AddConstNodeWithData( quant_bound_shape.data(), (char*)&data_min_, sizeof(data_min_)); auto* data_max_const = graph_builder_->AddConstNodeWithData( @@ -89,9 +89,7 @@ TfLiteStatus Pool2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs, // Output min/max for requantization. TF_LITE_ENSURE_STATUS(ComputeMinAndMaxQuantValues( - context->tensors[outputs->data[0]], &output_min_, &output_max_, - std::numeric_limits::min(), - std::numeric_limits::max())); + context->tensors[outputs->data[0]], &output_min_, &output_max_)); auto* output_min_const = graph_builder_->AddConstNodeWithData( quant_bound_shape.data(), (char*)&output_min_, sizeof(output_min_)); auto* output_max_const = graph_builder_->AddConstNodeWithData( diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/softmax_builder.cc b/tensorflow/lite/experimental/delegates/hexagon/builders/softmax_builder.cc index d3c7f45199e..f02fd05da77 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/builders/softmax_builder.cc +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/softmax_builder.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/hexagon_nn.h" #include "tensorflow/lite/kernels/kernel_util.h" @@ -36,9 +37,7 @@ TfLiteStatus SoftmaxOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs, const auto& input_tensor = context->tensors[tensor_id]; AddInput(graph_builder_->GetHexagonTensorId(tensor_id)); TF_LITE_ENSURE_STATUS( - ComputeMinAndMaxQuantValues(input_tensor, &input_min_, &input_max_, - std::numeric_limits::min(), - std::numeric_limits::max())); + ComputeMinAndMaxQuantValues(input_tensor, &input_min_, &input_max_)); auto* input_min_const = graph_builder_->AddConstNodeWithData( quant_bound_shape.data(), (char*)&input_min_, sizeof(input_min_)); auto* input_max_const = graph_builder_->AddConstNodeWithData( diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/tests/BUILD b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/BUILD index 7270cfd06b0..b8c207ad8b6 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/builders/tests/BUILD +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/BUILD @@ -34,6 +34,7 @@ hexagon_op_tests( "pool_test.cc", "reduce_test.cc", "resize_bilinear_test.cc", + "softmax_test.cc", "space_to_depth_test.cc", "split_test.cc", "transpose_conv_test.cc", diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/tests/pool_test.cc b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/pool_test.cc index 60dac2c6304..6b0edd0f12d 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/builders/tests/pool_test.cc +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/pool_test.cc @@ -34,13 +34,15 @@ class AveragePoolingOpModel : public SingleOpModelWithHexagon { BuildInterpreter({GetShape(input_)}); } - void SetInput(std::initializer_list data) { - QuantizeAndPopulate(input_, data); + template + void SetInput(const std::vector& data) { + QuantizeAndPopulate(input_, data); } + template std::vector GetDequantizedOutput() { - return Dequantize(ExtractVector(output_), - GetScale(output_), GetZeroPoint(output_)); + return Dequantize(ExtractVector(output_), GetScale(output_), + GetZeroPoint(output_)); } private: @@ -53,7 +55,7 @@ TEST(QuantizedPoolingOpTest, AveragePool) { /*input=*/{TensorType_UINT8, {1, 16, 8, 1}, 0, 10}, /*filter_width=*/8, /*filter_height=*/8, /*output=*/{TensorType_UINT8, {}, 0, 10}); - m.SetInput({ + m.SetInput({ 0, 6, 2, 4, 0, 6, 2, 4, // 3, 2, 10, 7, 3, 2, 10, 7, // 0, 6, 2, 4, 0, 6, 2, 4, // @@ -73,9 +75,42 @@ TEST(QuantizedPoolingOpTest, AveragePool) { }); m.ApplyDelegateAndInvoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( {4.58824, 4.58824, 4.90196, 4.58824, 4.27451}))); } +TEST(QuantizedPoolingOpTest, AveragePool_Int8) { + AveragePoolingOpModel m( + /*input=*/{TensorType_INT8, {1, 16, 8, 1}, 0, 10}, + /*filter_width=*/8, /*filter_height=*/8, + /*output=*/{TensorType_INT8, {}, 0, 10}); + m.SetInput({ + 0, 6, 2, 4, 0, 6, 2, 4, // + 3, 2, 10, 7, 3, 2, 10, 7, // + 0, 6, 2, 4, 0, 6, 2, 4, // + 3, 2, 10, 7, 3, 2, 10, 7, // + 0, 6, 2, 4, 0, 6, 2, 4, // + 3, 2, 10, 7, 3, 2, 10, 7, // + 3, 2, 10, 7, 3, 2, 10, 7, // + 3, 2, 10, 7, 3, 2, 10, 7, // + 0, 6, 2, 4, 0, 6, 2, 4, // + 3, 2, 10, 7, 3, 2, 10, 7, // + 3, 2, 10, 7, 3, 2, 10, 7, // + 3, 2, 10, 7, 3, 2, 10, 7, // + 0, 6, 2, 4, 0, 6, 2, 4, // + 0, 6, 2, 4, 0, 6, 2, 4, // + 0, 6, 2, 4, 0, 6, 2, 4, // + 3, 2, 10, 7, 3, 2, 10, 7, // + }); + + // Reference data. + m.Invoke(); + auto reference_output = m.GetDequantizedOutput(); + + m.ApplyDelegateAndInvoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear(reference_output))); +} + } // namespace tflite diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/tests/softmax_test.cc b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/softmax_test.cc new file mode 100644 index 00000000000..094c5478997 --- /dev/null +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/softmax_test.cc @@ -0,0 +1,128 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/lite/experimental/delegates/hexagon/builders/tests/hexagon_delegate_op_model.h" + +namespace tflite { +using testing::ElementsAreArray; + +const float kTolerance = 2 * (1. / 256); + +class SoftmaxOpModel : public SingleOpModelWithHexagon { + public: + SoftmaxOpModel(float softmax_beta, const TensorData& input) { + input_ = AddInput(input); + if (input.type == TensorType_UINT8) { + output_ = AddOutput({input.type, {}, 0, 0, 1. / 256}); + } else if (input.type == TensorType_INT8) { + output_ = AddOutput({TensorType_INT8, {}, 0, 0, 1. / 256, -128}); + } + SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions, + CreateSoftmaxOptions(builder_, softmax_beta).Union()); + BuildInterpreter({GetShape(input_)}); + } + + template + void SetInput(const std::vector& data) { + QuantizeAndPopulate(input_, data); + } + + template + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), GetScale(output_), + GetZeroPoint(output_)); + } + + protected: + int input_; + int output_; +}; + +TEST(SoftmaxOpModel, Softmax4DUint8) { + SoftmaxOpModel m(0.1, + /*input=*/{TensorType_UINT8, {1, 2, 1, 4}, -10, 10}); + m.SetInput({ + 0, -6, 2, 4, // depth = 0 + 3, -2, 10, 1, // depth = 1 + }); + m.ApplyDelegateAndInvoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + .23463, .12877, .28658, .35003, // + .22528, .13664, .45365, .18443, // + }, + kTolerance))); +} + +TEST(SoftmaxOpModel, Softmax4DUint8_MultipleBatch) { + SoftmaxOpModel m(0.1, + /*input=*/{TensorType_UINT8, {4, 1, 1, 2}, -10, 10}); + m.SetInput({ + 0, -6, // + 2, 4, // + 3, -2, // + 10, 1, // + }); + m.ApplyDelegateAndInvoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }, + kTolerance))); +} + +TEST(SoftmaxOpModel, Softmax4DInt8) { + SoftmaxOpModel m(0.1, + /*input=*/{TensorType_INT8, {1, 2, 1, 4}, -10, 10}); + m.SetInput({ + 0, -6, 2, 4, // depth = 0 + 3, -2, 10, 1, // depth = 1 + }); + m.ApplyDelegateAndInvoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + .23463, .12877, .28658, .35003, // + .22528, .13664, .45365, .18443, // + }, + kTolerance))); +} + +TEST(SoftmaxOpModel, Softmax4DInt8_MultipleBatch) { + SoftmaxOpModel m(0.1, + /*input=*/{TensorType_INT8, {4, 1, 1, 2}, -10, 10}); + m.SetInput({ + 0, -6, // + 2, 4, // + 3, -2, // + 10, 1, // + }); + m.ApplyDelegateAndInvoke(); + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( + { + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }, + kTolerance))); +} + +} // namespace tflite diff --git a/tensorflow/lite/experimental/delegates/hexagon/hexagon_delegate_kernel.cc b/tensorflow/lite/experimental/delegates/hexagon/hexagon_delegate_kernel.cc index 3ec765ae888..529b8c59dd2 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/hexagon_delegate_kernel.cc +++ b/tensorflow/lite/experimental/delegates/hexagon/hexagon_delegate_kernel.cc @@ -20,12 +20,21 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/context_util.h" +#include "tensorflow/lite/delegates/utils.h" #include "tensorflow/lite/experimental/delegates/hexagon/hexagon_implementation.h" #include "tensorflow/lite/experimental/delegates/hexagon/utils.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/kernel_util.h" namespace tflite { namespace { + +// Used to convert int8 <-> uint8. +constexpr int kSameScaleEffectiveMultiplier = 1 << 30; +constexpr int kSameScaleEffectiveShift = 1; +constexpr int kInt8Uint8ZeroPointDiff = 128; + inline const char* StateToString( HexagonDelegateKernel::HexagonKernelState state) { switch (state) { @@ -126,13 +135,34 @@ TfLiteStatus HexagonDelegateKernel::Invoke(TfLiteContext* context, } // Allocate inputs. std::vector input_tensors; - for (auto tensor_index : TfLiteIntArrayView(node->inputs)) { + for (int input_idx = 0; input_idx < node->inputs->size; ++input_idx) { + const auto tensor_index = node->inputs->data[input_idx]; if (tensor_index == kTfLiteOptionalTensor) { continue; } TfLiteTensor* tensor = &context->tensors[tensor_index]; - // Const tensors should be added as const nodes during graph construction. + // Const tensors should have been handled at delegation time.. if (tensor->allocation_type != kTfLiteMmapRo) { + char* data_ptr = tensor->data.raw; + if (tensor->type == kTfLiteInt8) { + // If input is int8, we first re-quantize it to uint8 for Hexagon. + if (int8_to_uint8_tensors_.size() <= input_idx || + !int8_to_uint8_tensors_[input_idx]) { + TF_LITE_KERNEL_LOG(context, + "Found int8 input %d with no uint8 version", + tensor_index); + return kTfLiteError; + } + TfLiteTensor* uint8_tensor = int8_to_uint8_tensors_[input_idx]; + optimized_ops::Requantize( + tensor->data.int8, NumElements(tensor), + kSameScaleEffectiveMultiplier, kSameScaleEffectiveShift, + tensor->params.zero_point, + tensor->params.zero_point + kInt8Uint8ZeroPointDiff, + uint8_tensor->data.uint8); + data_ptr = uint8_tensor->data.raw; + } + if (tensor->dims->size > 4) { ReportError(context, HexagonKernelState::INPUT_RANK_NOT_SUPPORTED, "Only up to 4d tensor are supported."); @@ -140,7 +170,7 @@ TfLiteStatus HexagonDelegateKernel::Invoke(TfLiteContext* context, } input_tensors.emplace_back(); auto& input_tensor = input_tensors.back(); - input_tensor.data = reinterpret_cast(tensor->data.raw); + input_tensor.data = reinterpret_cast(data_ptr); input_tensor.dataLen = tensor->bytes; input_tensor.data_valid_len = tensor->bytes; TF_LITE_ENSURE_STATUS( @@ -182,6 +212,20 @@ TfLiteStatus HexagonDelegateKernel::Invoke(TfLiteContext* context, "Failed to execute graph."); return kTfLiteError; } + + // Requantize uint8->int8 for eligible output tensors. + for (auto tensor_index : TfLiteIntArrayView(node->outputs)) { + TfLiteTensor* tensor = &context->tensors[tensor_index]; + if (tensor->allocation_type != kTfLiteMmapRo && + tensor->type == kTfLiteInt8) { + optimized_ops::Requantize( + tensor->data.uint8, NumElements(tensor), + kSameScaleEffectiveMultiplier, kSameScaleEffectiveShift, + tensor->params.zero_point + kInt8Uint8ZeroPointDiff, + tensor->params.zero_point, tensor->data.int8); + } + } + if (params_.print_graph_profile) { PrintPerformanceData(reinterpret_cast(context->profiler)); } @@ -222,6 +266,35 @@ TfLiteStatus HexagonDelegateKernel::Prepare(TfLiteContext* context, } } + // Assign temporary tensors for any input int8 tensors. + std::vector temporary_tensors; + int8_to_uint8_tensors_.clear(); + int8_to_uint8_tensors_.reserve(node->inputs->size); + for (auto tensor_index : TfLiteIntArrayView(node->inputs)) { + TfLiteTensor* tensor = &context->tensors[tensor_index]; + // For every int8 tensor, we need to create a new temporary uint8 tensor. + if (tensor->allocation_type != kTfLiteMmapRo && + tensor->type == kTfLiteInt8) { + TfLiteTensor* uint8_tensor; + int uint8_tensor_index; + TF_LITE_ENSURE_STATUS(delegates::CreateNewTensorWithDifferentType( + context, tensor_index, kTfLiteUInt8, &uint8_tensor, + &uint8_tensor_index)); + int8_to_uint8_tensors_.push_back(uint8_tensor); + temporary_tensors.push_back(uint8_tensor_index); + } else { + int8_to_uint8_tensors_.push_back(nullptr); + } + } + if (!temporary_tensors.empty()) { + // This ensures the runtime allocates memory for every required temporary + // tensor. + node->temporaries = TfLiteIntArrayCreate(temporary_tensors.size()); + for (int i = 0; i < temporary_tensors.size(); ++i) { + node->temporaries->data[i] = temporary_tensors[i]; + } + } + if (params_.print_graph_debug) { PrintDebuggingGraph(); } diff --git a/tensorflow/lite/experimental/delegates/hexagon/hexagon_delegate_kernel.h b/tensorflow/lite/experimental/delegates/hexagon/hexagon_delegate_kernel.h index 1f56bcee4ea..91e36303574 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/hexagon_delegate_kernel.h +++ b/tensorflow/lite/experimental/delegates/hexagon/hexagon_delegate_kernel.h @@ -95,6 +95,12 @@ class HexagonDelegateKernel { // Indices of nodes in the delegated TfLite subgraph. std::vector nodes_; ::TfLiteHexagonDelegateOptions params_; + + // Used to support int8 TFLite *input* tensors. + // This vector, for every node-input, contains: + // 1. Pointer to Uint8 version if tensor is non-constant & type is Int8. + // 2. nullptr otherwise. + std::vector int8_to_uint8_tensors_; }; } // namespace tflite diff --git a/tensorflow/lite/experimental/delegates/hexagon/utils.cc b/tensorflow/lite/experimental/delegates/hexagon/utils.cc index feff2080eaa..f55890b1d7a 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/utils.cc +++ b/tensorflow/lite/experimental/delegates/hexagon/utils.cc @@ -59,6 +59,19 @@ TfLiteStatus Get4DShape(unsigned int* batch_size, unsigned int* height_size, return kTfLiteOk; } +// We maintain an op-version whitelist here to ensure we don't accept unintended +// ops. +bool CheckOpVersion(const TfLiteRegistration* registration) { + switch (registration->builtin_code) { + case kTfLiteBuiltinAveragePool2d: + case kTfLiteBuiltinDepthwiseConv2d: + case kTfLiteBuiltinSoftmax: + return registration->version <= 2; + default: + return registration->version == 1; + } +} + bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration, const TfLiteNode* node, TfLiteContext* context) { // Ensure all inputs & outputs have dim <= 4. @@ -74,15 +87,7 @@ bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration, if (tensor.dims->size > 4) return false; } - // Most hexagon kernels are not compatible with op versions > 1. - // We maintain a 'whitelist' here to ensure we don't accept unintended nodes. - if (registration->version > 1) { - if (registration->builtin_code == kTfLiteBuiltinDepthwiseConv2d && - registration->version == 2) { - return true; - } - return false; - } + if (!CheckOpVersion(registration)) return false; switch (registration->builtin_code) { case kTfLiteBuiltinAdd: { @@ -154,8 +159,9 @@ bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration, return pool_params->activation == kTfLiteActNone; } case kTfLiteBuiltinAveragePool2d: { - if (!InputsWithCorrectTypes(node, context, {kTfLiteUInt8})) return false; - // AvgPool works fine for filter dim <=7. + if (!InputsWithCorrectTypes(node, context, {kTfLiteUInt8}) && + !InputsWithCorrectTypes(node, context, {kTfLiteInt8})) + return false; const TfLitePoolParams* pool_params = reinterpret_cast(node->builtin_data); return (node->inputs->size == 1 && @@ -220,7 +226,10 @@ bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration, return false; return true; } - case kTfLiteBuiltinSoftmax: + case kTfLiteBuiltinSoftmax: { + return (InputsWithCorrectTypes(node, context, {kTfLiteUInt8}) || + InputsWithCorrectTypes(node, context, {kTfLiteInt8})); + } case kTfLiteBuiltinRelu: case kTfLiteBuiltinRelu6: case kTfLiteBuiltinTanh: