Supports QUANTIZE & L2 Norm with int8 tensors

PiperOrigin-RevId: 307491431
Change-Id: I9377bdb81008b9c18f93050f3c4bcafd4226266d
Author: Sachin Joglekar, committed by TensorFlower Gardener on 2020-04-20 15:43:04 -07:00
parent 5a674e06a9
commit 711627bacc
11 changed files with 448 additions and 6 deletions
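
For context, a minimal sketch of how a quantized model reaches these builders, assuming the delegate's public entry points in `hexagon_delegate.h` (`TfLiteHexagonInit`, `TfLiteHexagonDelegateOptions`, `TfLiteHexagonDelegateCreate`); those names and the error-free flow are assumptions for illustration, not part of this commit:

```cpp
#include <memory>

#include "tensorflow/lite/experimental/delegates/hexagon/hexagon_delegate.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

// Sketch only: run a quantized .tflite model with the Hexagon delegate.
// Nodes accepted by IsNodeSupportedByHexagon (now including Quantize and
// int8 L2 Normalization) are offloaded to the DSP; the rest stay on CPU.
std::unique_ptr<tflite::Interpreter> BuildWithHexagon(const char* model_path) {
  auto model = tflite::FlatBufferModel::BuildFromFile(model_path);
  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);

  TfLiteHexagonInit();  // Loads and prepares the Hexagon NN library.
  TfLiteHexagonDelegateOptions options = {0};
  TfLiteDelegate* delegate = TfLiteHexagonDelegateCreate(&options);
  // The delegate must outlive the interpreter; cleanup is elided here.
  interpreter->ModifyGraphWithDelegate(delegate);
  interpreter->AllocateTensors();
  return interpreter;
}
```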

tensorflow/lite/experimental/delegates/hexagon/README.md

@@ -82,6 +82,7 @@ are verified in `IsNodeSupportedByHexagon`:
 * Mul (without any activation) (b/129276536)
 * Neg
 * Pad: Only supports 0 padding (b/139277813)
+* Quantize (8-bit inputs & outputs only)
 * Relu
 * Relu6
 * Reshape

tensorflow/lite/experimental/delegates/hexagon/builders/BUILD

@@ -21,6 +21,7 @@ cc_library(
         "op_builder.cc",
         "pad_builder.cc",
         "pool_2d_builder.cc",
+        "quantize_builder.cc",
         "reduce_builder.cc",
         "reshape_builder.cc",
         "resize_bilinear_builder.cc",
@@ -44,6 +45,7 @@ cc_library(
         "op_builder.h",
         "pad_builder.h",
         "pool_2d_builder.h",
+        "quantize_builder.h",
         "reduce_builder.h",
         "reshape_builder.h",
         "resize_bilinear_builder.h",

tensorflow/lite/experimental/delegates/hexagon/builders/l2_normalization_builder.cc

@@ -36,9 +36,7 @@ TfLiteStatus L2NormalizationOpBuilder::PopulateSubGraph(
   const auto& input_tensor = context->tensors[tensor_id];
   AddInput(graph_builder_->GetHexagonTensorId(tensor_id));
   TF_LITE_ENSURE_STATUS(
-      ComputeMinAndMaxQuantValues(input_tensor, &input_min_, &input_max_,
-                                  std::numeric_limits<uint8_t>::min(),
-                                  std::numeric_limits<uint8_t>::max()));
+      ComputeMinAndMaxQuantValues(input_tensor, &input_min_, &input_max_));
   auto* input_min_const = graph_builder_->AddConstNodeWithData(
       quant_bound_shape, reinterpret_cast<char*>(&input_min_),
       sizeof(input_min_));

tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.cc

@@ -87,6 +87,8 @@ OpBuilder* GraphBuilder::CreateOpBuilderFromTfLiteOp(int op_type) {
       return CreateSpaceToDepthBuilder(this, OP_SpaceToDepth_8);
     case kTfLiteBuiltinDepthToSpace:
       return CreateSpaceToDepthBuilder(this, OP_DepthToSpace_8);
+    case kTfLiteBuiltinQuantize:
+      return CreateQuantizeBuilder(this, OP_Requantize_8to8);
     default:
       context_->ReportError(context_, "Op not supported: %d", op_type);
       return nullptr;

tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.h

@@ -50,6 +50,7 @@ OpBuilder* CreateBatchSeqBuilder(GraphBuilder* graph_builder, int op_type,
                                  int max_size_for_batch,
                                  TfLiteIntArray* input_batch_dimensions,
                                  TfLiteIntArray* output_batch_dimensions);
+OpBuilder* CreateQuantizeBuilder(GraphBuilder* graph_builder, int op_type);
 }  // namespace hexagon
 }  // namespace delegates

tensorflow/lite/experimental/delegates/hexagon/builders/quantize_builder.cc (new file)

@@ -0,0 +1,91 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/delegates/hexagon/builders/quantize_builder.h"

#include <stdint.h>

#include <limits>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/hexagon_nn.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace delegates {
namespace hexagon {

TfLiteStatus QuantizeOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
                                                 const TfLiteIntArray* outputs,
                                                 TfLiteContext* context) {
  static int scalar_shape[] = {1, 1, 1, 1};

  // Input.
  float input_min = 0;
  float input_max = 0;
  const auto& input_tensor = context->tensors[inputs->data[0]];
  ComputeMinAndMaxQuantValues(input_tensor, &input_min, &input_max);
  auto* input_min_const = graph_builder_->AddConstNodeWithData(
      scalar_shape, reinterpret_cast<char*>(&input_min), sizeof(input_min));
  auto* input_max_const = graph_builder_->AddConstNodeWithData(
      scalar_shape, reinterpret_cast<char*>(&input_max), sizeof(input_max));

  // Output.
  float output_min = 0;
  float output_max = 0;
  const auto& output_tensor = context->tensors[outputs->data[0]];
  TF_LITE_ENSURE_STATUS(
      ComputeMinAndMaxQuantValues(output_tensor, &output_min, &output_max));
  int output_batch_size, output_height_size, output_width_size,
      output_depth_size;
  GetDims(&output_batch_size, &output_height_size, &output_width_size,
          &output_depth_size, output_tensor.dims);
  auto* requantized_min_const = graph_builder_->AddConstNodeWithData(
      scalar_shape, reinterpret_cast<char*>(&output_min), sizeof(output_min));
  auto* requantized_max_const = graph_builder_->AddConstNodeWithData(
      scalar_shape, reinterpret_cast<char*>(&output_max), sizeof(output_max));

  AddInput(graph_builder_->GetHexagonTensorId(inputs->data[0]));
  AddInput(TensorID(input_min_const->GetID(), 0));
  AddInput(TensorID(input_max_const->GetID(), 0));
  AddInput(TensorID(requantized_min_const->GetID(), 0));
  AddInput(TensorID(requantized_max_const->GetID(), 0));

  // Hexagon outputs for this node.
  node_output_ = AddOutput(sizeof(uint8_t), 4,
                           {output_batch_size, output_height_size,
                            output_width_size, output_depth_size});
  AddOutput(sizeof(float), 4, {1, 1, 1, 1});
  AddOutput(sizeof(float), 4, {1, 1, 1, 1});

  return kTfLiteOk;
}

TfLiteStatus QuantizeOpBuilder::RegisterOutputs(const TfLiteIntArray* outputs,
                                                TfLiteContext* context) {
  // Should be only 1 output.
  graph_builder_->AddTensorWithID(outputs->data[0], node_output_.first,
                                  node_output_.second);
  return kTfLiteOk;
}

QuantizeOpBuilder::~QuantizeOpBuilder() {}

OpBuilder* CreateQuantizeBuilder(GraphBuilder* graph_builder, int op_type) {
  return new QuantizeOpBuilder(graph_builder, op_type);
}

}  // namespace hexagon
}  // namespace delegates
}  // namespace tflite
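
Since `kTfLiteBuiltinQuantize` is lowered to `OP_Requantize_8to8` with the four (min, max) scalars above as const inputs, the node effectively dequantizes with the input range and requantizes with the output range. A reference-only sketch of that arithmetic follows; the actual Hexagon NN kernel is fixed point and this is not the DSP code (for int8 tensors, the assumption here is that the delegate's usual shift-by-128 into uint8 space makes the same mapping apply):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Illustration of what Requantize_8to8 computes, using the float
// (min, max) bounds the builder passes as const inputs.
uint8_t Requantize8to8(uint8_t in, float in_min, float in_max,
                       float out_min, float out_max) {
  const float in_scale = (in_max - in_min) / 255.0f;
  const float out_scale = (out_max - out_min) / 255.0f;
  const float real = in_min + in * in_scale;                  // dequantize
  const float q = std::round((real - out_min) / out_scale);   // requantize
  return static_cast<uint8_t>(std::min(255.0f, std::max(0.0f, q)));
}
// E.g. with in/out range [-63.5, 64] (scale 0.5), input 129 represents 1.0
// and maps back to 129, matching the UInt8UInt8SameScale test below.
```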

tensorflow/lite/experimental/delegates/hexagon/builders/quantize_builder.h (new file)

@@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_HEXAGON_BUILDERS_QUANTIZE_BUILDER_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_HEXAGON_BUILDERS_QUANTIZE_BUILDER_H_

#include "tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.h"

namespace tflite {
namespace delegates {
namespace hexagon {

class QuantizeOpBuilder : public OpBuilder {
 public:
  explicit QuantizeOpBuilder(GraphBuilder* graph_builder, int op_type)
      : OpBuilder(graph_builder, op_type) {}
  explicit QuantizeOpBuilder(GraphBuilder* graph_builder, int op_type,
                             int relu_value)
      : OpBuilder(graph_builder, op_type) {}

  TfLiteStatus PopulateSubGraph(const TfLiteIntArray* inputs,
                                const TfLiteIntArray* outputs,
                                TfLiteContext* context) override;

  TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs,
                               TfLiteContext* context) override;

  ~QuantizeOpBuilder() override;

 private:
  TensorID node_output_;
};

}  // namespace hexagon
}  // namespace delegates
}  // namespace tflite

#endif  // TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_HEXAGON_BUILDERS_QUANTIZE_BUILDER_H_

tensorflow/lite/experimental/delegates/hexagon/builders/tests/BUILD

@@ -28,11 +28,13 @@ hexagon_op_tests(
         "arg_min_max_test.cc",
         "concat_test.cc",
         "conv_test.cc",
+        "l2_norm_test.cc",
         "matmul_test.cc",
         "mul_test.cc",
         "neg_test.cc",
         "pad_test.cc",
         "pool_test.cc",
+        "quantize_test.cc",
         "reduce_test.cc",
         "reshape_test.cc",
         "resize_test.cc",

tensorflow/lite/experimental/delegates/hexagon/builders/tests/l2_norm_test.cc (new file)

@@ -0,0 +1,122 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <gtest/gtest.h>
#include "tensorflow/lite/experimental/delegates/hexagon/builders/tests/hexagon_delegate_op_model.h"

namespace tflite {
using testing::ElementsAreArray;

class L2NormOpModel : public SingleOpModelWithHexagon {
 public:
  L2NormOpModel(const std::initializer_list<int> input_shape,
                const TensorType tensor_type) {
    TensorData data = TensorData{tensor_type};
    data.min = -2.0;
    data.max = 2.0;
    data.scale = 2.0;
    data.zero_point = 128;
    input_ = AddInput(data);

    data.min = -1.0;
    data.max = 127.0 / 128.0;
    output_ = AddOutput(data);

    SetBuiltinOp(
        BuiltinOperator_L2_NORMALIZATION, BuiltinOptions_L2NormOptions,
        CreateL2NormOptions(builder_, ActivationFunctionType_NONE).Union());
    BuildInterpreter({input_shape});
  }

  void SetInput(std::initializer_list<float> data) {
    PopulateTensor(input_, data);
  }

  template <typename T>
  std::vector<T> GetOutput() {
    return ExtractVector<T>(output_);
  }

  template <typename T>
  std::vector<float> GetDequantizedOutput() {
    return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_),
                         GetZeroPoint(output_));
  }

  int input() const { return input_; }

 private:
  int input_;
  int output_;
};

TEST(L2NormOpTest, ZerosVectorUint8Test) {
  L2NormOpModel m({1, 1, 1, 6}, TensorType_UINT8);
  m.QuantizeAndPopulate<uint8_t>(m.input(), {0, 0, 0, 0, 0, 0});
  m.ApplyDelegateAndInvoke();
  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear({0, 0, 0, 0, 0, 0}, 0.1)));
}

TEST(L2NormOpTest, ZerosVectorInt8Test) {
  L2NormOpModel m({1, 1, 1, 6}, TensorType_INT8);
  m.QuantizeAndPopulate<int8_t>(m.input(), {0, 0, 0, 0, 0, 0});
  m.ApplyDelegateAndInvoke();
  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(ArrayFloatNear({0, 0, 0, 0, 0, 0}, 0.1)));
}

TEST(L2NormOpTest, MultipleBatchUint8Test) {
  L2NormOpModel m({3, 1, 1, 6}, TensorType_UINT8);
  m.QuantizeAndPopulate<uint8_t>(m.input(),
                                 {
                                     -1.1, 0.6, 0.7, 1.2, -0.7, 0.1,  // batch 1
                                     -1.1, 0.6, 0.7, 1.2, -0.7, 0.1,  // batch 2
                                     -1.1, 0.6, 0.7, 1.2, -0.7, 0.1,  // batch 3
                                 });
  m.ApplyDelegateAndInvoke();
  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear(
                  {
                      -0.55, 0.3, 0.35, 0.6, -0.35, 0.05,  // batch 1
                      -0.55, 0.3, 0.35, 0.6, -0.35, 0.05,  // batch 2
                      -0.55, 0.3, 0.35, 0.6, -0.35, 0.05,  // batch 3
                  },
                  0.1)));
}

TEST(L2NormOpTest, MultipleBatchInt8Test) {
  L2NormOpModel m({3, 1, 1, 6}, TensorType_INT8);
  m.QuantizeAndPopulate<int8_t>(m.input(),
                                {
                                    -1.1, 0.6, 0.7, 1.2, -0.7, 0.1,  // batch 1
                                    -1.1, 0.6, 0.7, 1.2, -0.7, 0.1,  // batch 2
                                    -1.1, 0.6, 0.7, 1.2, -0.7, 0.1,  // batch 3
                                });
  m.ApplyDelegateAndInvoke();
  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(ArrayFloatNear(
                  {
                      -0.55, 0.3, 0.35, 0.6, -0.35, 0.05,  // batch 1
                      -0.55, 0.3, 0.35, 0.6, -0.35, 0.05,  // batch 2
                      -0.55, 0.3, 0.35, 0.6, -0.35, 0.05,  // batch 3
                  },
                  0.1)));
}

}  // namespace tflite
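
The expected values in the tests above are just the inputs divided by their L2 norm: for the MultipleBatch vectors, the squared sum per batch is 1.21 + 0.36 + 0.49 + 1.44 + 0.49 + 0.01 = 4.0, so the norm is 2.0 and every element is halved. A plain reference implementation for hand-checking the vectors (not the delegate kernel):

```cpp
#include <cmath>
#include <vector>

// Reference L2 normalization over a flat vector, matching the test math.
std::vector<float> L2Normalize(const std::vector<float>& v) {
  float sum_sq = 0.0f;
  for (float x : v) sum_sq += x * x;
  const float inv_norm = 1.0f / std::sqrt(sum_sq);
  std::vector<float> out;
  out.reserve(v.size());
  for (float x : v) out.push_back(x * inv_norm);  // e.g. -1.1 * 0.5 == -0.55
  return out;
}
```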

tensorflow/lite/experimental/delegates/hexagon/builders/tests/quantize_test.cc (new file)

@@ -0,0 +1,170 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <gtest/gtest.h>
#include "tensorflow/lite/experimental/delegates/hexagon/builders/tests/hexagon_delegate_op_model.h"

namespace tflite {
using testing::ElementsAreArray;

class QuantizeOpModel : public SingleOpModelWithHexagon {
 public:
  explicit QuantizeOpModel(const TensorData& input, const TensorData& output) {
    input_ = AddInput(input);
    output_ = AddOutput(output);
    SetBuiltinOp(BuiltinOperator_QUANTIZE, BuiltinOptions_QuantizeOptions,
                 CreateQuantizeOptions(builder_).Union());
    BuildInterpreter({GetShape(input_)});
  }

  template <typename T>
  void SetInput(const std::vector<float>& data) {
    QuantizeAndPopulate<T>(input_, data);
  }

  template <typename T>
  std::vector<T> GetOutput() {
    return ExtractVector<T>(output_);
  }

 protected:
  BuiltinOperator op_code_;

  int input_;
  int output_;
};

// Input scale 0.500000, output scale 0.500000, input zeropoint 127, output
// zeropoint 127
TEST(QuantizeOpTest, UInt8UInt8SameScale) {
  QuantizeOpModel m({TensorType_UINT8, {1, 1, 2, 5}, -63.5, 64},
                    {TensorType_UINT8, {1, 1, 2, 5}, -63.5, 64});

  // Input will be quantized to {129,131,133,135,137,139,141,143,145,147}.
  m.SetInput<uint8_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
  m.ApplyDelegateAndInvoke();
  EXPECT_THAT(
      m.GetOutput<uint8_t>(),
      ElementsAreArray({129, 131, 133, 135, 137, 139, 141, 143, 145, 147}));
}

// Input scale 0.500000, output scale 1.000000, input zeropoint 127, output
// zeropoint 127
TEST(QuantizeOpTest, Uint8Uint8LargerScale) {
  QuantizeOpModel m({TensorType_UINT8, {1, 1, 2, 5}, -63.5, 64},
                    {TensorType_UINT8, {1, 1, 2, 5}, -127, 128});

  // Input will be quantized to {129,131,133,135,137,139,141,143,145,147}.
  m.SetInput<uint8_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
  m.ApplyDelegateAndInvoke();
  EXPECT_THAT(
      m.GetOutput<uint8_t>(),
      ElementsAreArray({128, 129, 130, 131, 132, 133, 134, 135, 136, 137}));
}

// Input scale 1.000000, output scale 0.500000, input zeropoint 127, output
// zeropoint 127
TEST(QuantizeOpTest, Uint8Uint8SmallerScale) {
  QuantizeOpModel m({TensorType_UINT8, {1, 1, 2, 5}, -127, 128},
                    {TensorType_UINT8, {1, 1, 2, 5}, -63.5, 64});

  // Input will be quantized to {128,129,130,131,132,133,134,135,136,137}.
  m.SetInput<uint8_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
  m.ApplyDelegateAndInvoke();
  EXPECT_THAT(
      m.GetOutput<uint8_t>(),
      ElementsAreArray({129, 131, 133, 135, 137, 139, 141, 143, 145, 147}));
}

// Input scale 1.000000, output scale 0.500000, input zeropoint -1, output
// zeropoint 127
TEST(QuantizeOpTest, Int8Uint8SmallerScale) {
  QuantizeOpModel m({TensorType_INT8, {1, 1, 2, 5}, -127, 128},
                    {TensorType_UINT8, {1, 1, 2, 5}, -63.5, 64});

  // Input will be quantized to {0,1,2,3,4,5,6,7,8,9}.
  m.SetInput<int8_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
  m.ApplyDelegateAndInvoke();
  EXPECT_THAT(
      m.GetOutput<uint8_t>(),
      ElementsAreArray({129, 131, 133, 135, 137, 139, 141, 143, 145, 147}));
}

// Input scale 1.000000, output scale 2.000000, input zeropoint -1, output
// zeropoint 127
TEST(QuantizeOpTest, Int8Uint8LargerScale) {
  QuantizeOpModel m({TensorType_INT8, {1, 1, 2, 5}, -127, 128},
                    {TensorType_UINT8, {1, 1, 2, 5}, -254, 256});

  // Input will be quantized to {0,1,2,3,4,5,6,7,8,9}.
  m.SetInput<int8_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
  m.ApplyDelegateAndInvoke();
  EXPECT_THAT(
      m.GetOutput<uint8_t>(),
      ElementsAreArray({128, 128, 129, 129, 130, 130, 131, 131, 132, 132}));
}

// Input scale 1.000000, output scale 1.000000, input zeropoint 127, output
// zeropoint -1
TEST(QuantizeOpTest, UInt8Int8SameScale128Diff) {
  QuantizeOpModel m({TensorType_UINT8, {1, 1, 2, 5}, -127, 128},
                    {TensorType_INT8, {1, 1, 2, 5}, -127, 128});

  // Input will be quantized to {128,129,130,131,132,133,134,135,136,137}.
  m.SetInput<uint8_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
  m.ApplyDelegateAndInvoke();
  EXPECT_THAT(m.GetOutput<int8_t>(),
              ElementsAreArray({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}));
}

// Input scale 0.500000, output scale 0.500000, input zeropoint -1, output
// zeropoint -1
TEST(QuantizeOpTest, Int8Int8SameScale) {
  QuantizeOpModel m({TensorType_INT8, {1, 1, 2, 5}, -63.5, 64},
                    {TensorType_INT8, {1, 1, 2, 5}, -63.5, 64});

  // Input will be quantized to {1,3,5,7,9,11,13,15,17,19}.
  m.SetInput<int8_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
  m.ApplyDelegateAndInvoke();
  EXPECT_THAT(m.GetOutput<int8_t>(),
              ElementsAreArray({1, 3, 5, 7, 9, 11, 13, 15, 17, 19}));
}

// Input scale 0.500000, output scale 1.000000, input zeropoint -1, output
// zeropoint -1
TEST(QuantizeOpTest, Int8Int8LargerScale) {
  QuantizeOpModel m({TensorType_INT8, {1, 1, 2, 5}, -63.5, 64},
                    {TensorType_INT8, {1, 1, 2, 5}, -127, 128});

  // Input will be quantized to {1,3,5,7,9,11,13,15,17,19}.
  m.SetInput<int8_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
  m.ApplyDelegateAndInvoke();
  EXPECT_THAT(m.GetOutput<int8_t>(),
              ElementsAreArray({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}));
}

// Input scale 1.000000, output scale 0.500000, input zeropoint -1, output
// zeropoint -1
TEST(QuantizeOpTest, Int8Int8SmallerScale) {
  QuantizeOpModel m({TensorType_INT8, {1, 1, 2, 5}, -127, 128},
                    {TensorType_INT8, {1, 1, 2, 5}, -63.5, 64});

  // Input will be quantized to {0,1,2,3,4,5,6,7,8,9}.
  m.SetInput<int8_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
  m.ApplyDelegateAndInvoke();
  EXPECT_THAT(m.GetOutput<int8_t>(),
              ElementsAreArray({1, 3, 5, 7, 9, 11, 13, 15, 17, 19}));
}

}  // namespace tflite
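
The expected arrays all follow from q = round(x / scale) + zero_point on each side. Taking Int8Int8SmallerScale as an example: with input scale 1.0 and zero point -1, the float inputs {1, ..., 10} land at {0, ..., 9}; re-expressed at output scale 0.5 with the same zero point they become {1, 3, ..., 19}. A hand-check helper (illustrative only, not delegate code):

```cpp
#include <cmath>
#include <cstdint>

// Quantize a real value with the given scale and zero point (int8 range).
int8_t QuantizeValue(float x, float scale, int zero_point) {
  return static_cast<int8_t>(std::lround(x / scale) + zero_point);
}
// QuantizeValue(10.0f, 1.0f, -1) == 9;  QuantizeValue(10.0f, 0.5f, -1) == 19.
```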

tensorflow/lite/experimental/delegates/hexagon/utils.cc

@@ -77,17 +77,19 @@ bool CheckOpVersion(const TfLiteRegistration* registration) {
     case kTfLiteBuiltinArgMin:
     case kTfLiteBuiltinAveragePool2d:
     case kTfLiteBuiltinConcatenation:
+    case kTfLiteBuiltinL2Normalization:
     case kTfLiteBuiltinLogistic:
     case kTfLiteBuiltinMaxPool2d:
     case kTfLiteBuiltinMul:
     case kTfLiteBuiltinPad:
-    case kTfLiteBuiltinSub:
+    case kTfLiteBuiltinQuantize:
     case kTfLiteBuiltinRelu6:
     case kTfLiteBuiltinResizeBilinear:
     case kTfLiteBuiltinResizeNearestNeighbor:
     case kTfLiteBuiltinSoftmax:
     case kTfLiteBuiltinSpaceToDepth:
     case kTfLiteBuiltinSplit:
+    case kTfLiteBuiltinSub:
     case kTfLiteBuiltinTanh:
     case kTfLiteBuiltinTranspose:
     case kTfLiteBuiltinTransposeConv:
@@ -301,8 +303,7 @@ bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration,
              IsConstantTensor(GetInput(context, node, 1));
     }
     case kTfLiteBuiltinL2Normalization: {
-      // TODO(b/142009955): Support int8.
-      if (!InputsWithCorrectTypes(node, context, {{kTfLiteUInt8}}))
+      if (!InputsWithCorrectTypes(node, context, {{kTfLiteUInt8, kTfLiteInt8}}))
         return false;
       const TfLiteL2NormParams* norm_params =
           reinterpret_cast<const TfLiteL2NormParams*>(node->builtin_data);
@@ -347,6 +348,10 @@ bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration,
       return InputsWithCorrectTypes(node, context,
                                     {{kTfLiteUInt8, kTfLiteInt8}});
     }
+    case kTfLiteBuiltinQuantize: {
+      return InputsWithCorrectTypes(node, context,
+                                    {{kTfLiteUInt8, kTfLiteInt8}});
+    }
     default:
       return false;
   }
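
Judging by the call sites above, `InputsWithCorrectTypes` takes one list of allowed `TfLiteType`s per node input, so `{{kTfLiteUInt8, kTfLiteInt8}}` reads "exactly one input, of type uint8 or int8". A paraphrase of that check, under that assumption about the helper's semantics rather than the file's actual code:

```cpp
#include <vector>

#include "tensorflow/lite/c/common.h"

// Paraphrase (hypothetical name): each outer entry corresponds to one node
// input; the inner list holds the TfLiteTypes accepted for it.
bool InputsHaveCorrectTypes(
    const TfLiteNode* node, TfLiteContext* context,
    const std::vector<std::vector<TfLiteType>>& per_input_types) {
  if (node->inputs->size != static_cast<int>(per_input_types.size()))
    return false;
  for (int i = 0; i < node->inputs->size; ++i) {
    const TfLiteType type = context->tensors[node->inputs->data[i]].type;
    bool allowed = false;
    for (TfLiteType t : per_input_types[i]) allowed |= (type == t);
    if (!allowed) return false;
  }
  return true;
}
```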