diff --git a/tensorflow/lite/experimental/delegates/hexagon/README.md b/tensorflow/lite/experimental/delegates/hexagon/README.md index 6ad7d302bcc..198326d41de 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/README.md +++ b/tensorflow/lite/experimental/delegates/hexagon/README.md @@ -82,6 +82,7 @@ are verified in `IsNodeSupportedByHexagon`: * Mul (without any activation) (b/129276536) * Neg * Pad: Only supports 0 padding (b/139277813) +* Quantize (8-bit inputs & outputs only) * Relu * Relu6 * Reshape diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/BUILD b/tensorflow/lite/experimental/delegates/hexagon/builders/BUILD index 3c666e6a4fe..550748e9961 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/builders/BUILD +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/BUILD @@ -21,6 +21,7 @@ cc_library( "op_builder.cc", "pad_builder.cc", "pool_2d_builder.cc", + "quantize_builder.cc", "reduce_builder.cc", "reshape_builder.cc", "resize_bilinear_builder.cc", @@ -44,6 +45,7 @@ cc_library( "op_builder.h", "pad_builder.h", "pool_2d_builder.h", + "quantize_builder.h", "reduce_builder.h", "reshape_builder.h", "resize_bilinear_builder.h", diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/l2_normalization_builder.cc b/tensorflow/lite/experimental/delegates/hexagon/builders/l2_normalization_builder.cc index ab91f65e18c..924408912f9 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/builders/l2_normalization_builder.cc +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/l2_normalization_builder.cc @@ -36,9 +36,7 @@ TfLiteStatus L2NormalizationOpBuilder::PopulateSubGraph( const auto& input_tensor = context->tensors[tensor_id]; AddInput(graph_builder_->GetHexagonTensorId(tensor_id)); TF_LITE_ENSURE_STATUS( - ComputeMinAndMaxQuantValues(input_tensor, &input_min_, &input_max_, - std::numeric_limits::min(), - std::numeric_limits::max())); + ComputeMinAndMaxQuantValues(input_tensor, &input_min_, &input_max_)); auto* input_min_const = graph_builder_->AddConstNodeWithData( quant_bound_shape, reinterpret_cast(&input_min_), sizeof(input_min_)); diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.cc b/tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.cc index d737d4ab2fa..0cfe99994a2 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.cc +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.cc @@ -87,6 +87,8 @@ OpBuilder* GraphBuilder::CreateOpBuilderFromTfLiteOp(int op_type) { return CreateSpaceToDepthBuilder(this, OP_SpaceToDepth_8); case kTfLiteBuiltinDepthToSpace: return CreateSpaceToDepthBuilder(this, OP_DepthToSpace_8); + case kTfLiteBuiltinQuantize: + return CreateQuantizeBuilder(this, OP_Requantize_8to8); default: context_->ReportError(context_, "Op not supported: %d", op_type); return nullptr; diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/op_factory.h b/tensorflow/lite/experimental/delegates/hexagon/builders/op_factory.h index 47e63f5d468..e2a4ef9a0a3 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/builders/op_factory.h +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/op_factory.h @@ -50,6 +50,7 @@ OpBuilder* CreateBatchSeqBuilder(GraphBuilder* graph_builder, int op_type, int max_size_for_batch, TfLiteIntArray* input_batch_dimensions, TfLiteIntArray* output_batch_dimensions); +OpBuilder* CreateQuantizeBuilder(GraphBuilder* graph_builder, int op_type); } // namespace hexagon } // namespace delegates diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/quantize_builder.cc b/tensorflow/lite/experimental/delegates/hexagon/builders/quantize_builder.cc new file mode 100644 index 00000000000..e4258642bb1 --- /dev/null +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/quantize_builder.cc @@ -0,0 +1,91 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/delegates/hexagon/builders/quantize_builder.h" + +#include + +#include + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/hexagon_nn.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +namespace tflite { +namespace delegates { +namespace hexagon { +TfLiteStatus QuantizeOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs, + const TfLiteIntArray* outputs, + TfLiteContext* context) { + static int scalar_shape[] = {1, 1, 1, 1}; + + // Input. + float input_min = 0; + float input_max = 0; + const auto& input_tensor = context->tensors[inputs->data[0]]; + ComputeMinAndMaxQuantValues(input_tensor, &input_min, &input_max); + auto* input_min_const = graph_builder_->AddConstNodeWithData( + scalar_shape, reinterpret_cast(&input_min), sizeof(input_min)); + auto* input_max_const = graph_builder_->AddConstNodeWithData( + scalar_shape, reinterpret_cast(&input_max), sizeof(input_max)); + + // Output. + float output_min = 0; + float output_max = 0; + const auto& output_tensor = context->tensors[outputs->data[0]]; + TF_LITE_ENSURE_STATUS( + ComputeMinAndMaxQuantValues(output_tensor, &output_min, &output_max)); + int output_batch_size, output_height_size, output_width_size, + output_depth_size; + GetDims(&output_batch_size, &output_height_size, &output_width_size, + &output_depth_size, output_tensor.dims); + auto* requantized_min_const = graph_builder_->AddConstNodeWithData( + scalar_shape, reinterpret_cast(&output_min), sizeof(output_min)); + auto* requantized_max_const = graph_builder_->AddConstNodeWithData( + scalar_shape, reinterpret_cast(&output_max), sizeof(output_max)); + + AddInput(graph_builder_->GetHexagonTensorId(inputs->data[0])); + AddInput(TensorID(input_min_const->GetID(), 0)); + AddInput(TensorID(input_max_const->GetID(), 0)); + AddInput(TensorID(requantized_min_const->GetID(), 0)); + AddInput(TensorID(requantized_max_const->GetID(), 0)); + + // Hexagon outputs for this node. + node_output_ = AddOutput(sizeof(uint8_t), 4, + {output_batch_size, output_height_size, + output_width_size, output_depth_size}); + AddOutput(sizeof(float), 4, {1, 1, 1, 1}); + AddOutput(sizeof(float), 4, {1, 1, 1, 1}); + + return kTfLiteOk; +} + +TfLiteStatus QuantizeOpBuilder::RegisterOutputs(const TfLiteIntArray* outputs, + TfLiteContext* context) { + // Should be only 1 output. + graph_builder_->AddTensorWithID(outputs->data[0], node_output_.first, + node_output_.second); + + return kTfLiteOk; +} + +QuantizeOpBuilder::~QuantizeOpBuilder() {} + +OpBuilder* CreateQuantizeBuilder(GraphBuilder* graph_builder, int op_type) { + return new QuantizeOpBuilder(graph_builder, op_type); +} + +} // namespace hexagon +} // namespace delegates +} // namespace tflite diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/quantize_builder.h b/tensorflow/lite/experimental/delegates/hexagon/builders/quantize_builder.h new file mode 100644 index 00000000000..9851ce46f00 --- /dev/null +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/quantize_builder.h @@ -0,0 +1,48 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_HEXAGON_BUILDERS_QUANTIZE_BUILDER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_HEXAGON_BUILDERS_QUANTIZE_BUILDER_H_ + +#include "tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.h" + +namespace tflite { +namespace delegates { +namespace hexagon { + +class QuantizeOpBuilder : public OpBuilder { + public: + explicit QuantizeOpBuilder(GraphBuilder* graph_builder, int op_type) + : OpBuilder(graph_builder, op_type) {} + explicit QuantizeOpBuilder(GraphBuilder* graph_builder, int op_type, + int relu_value) + : OpBuilder(graph_builder, op_type) {} + TfLiteStatus PopulateSubGraph(const TfLiteIntArray* inputs, + const TfLiteIntArray* outputs, + TfLiteContext* context) override; + + TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs, + TfLiteContext* context) override; + + ~QuantizeOpBuilder() override; + + private: + TensorID node_output_; +}; + +} // namespace hexagon +} // namespace delegates +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_HEXAGON_BUILDERS_QUANTIZE_BUILDER_H_ diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/tests/BUILD b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/BUILD index 3bf9120b56a..b1df59c4098 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/builders/tests/BUILD +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/BUILD @@ -28,11 +28,13 @@ hexagon_op_tests( "arg_min_max_test.cc", "concat_test.cc", "conv_test.cc", + "l2_norm_test.cc", "matmul_test.cc", "mul_test.cc", "neg_test.cc", "pad_test.cc", "pool_test.cc", + "quantize_test.cc", "reduce_test.cc", "reshape_test.cc", "resize_test.cc", diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/tests/l2_norm_test.cc b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/l2_norm_test.cc new file mode 100644 index 00000000000..34d53d6e68f --- /dev/null +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/l2_norm_test.cc @@ -0,0 +1,122 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/lite/experimental/delegates/hexagon/builders/tests/hexagon_delegate_op_model.h" + +namespace tflite { +using testing::ElementsAreArray; + +class L2NormOpModel : public SingleOpModelWithHexagon { + public: + L2NormOpModel(const std::initializer_list input_shape, + const TensorType tensor_type) { + TensorData data = TensorData{tensor_type}; + data.min = -2.0; + data.max = 2.0; + data.scale = 2.0; + data.zero_point = 128; + input_ = AddInput(data); + + data.min = -1.0; + data.max = 127.0 / 128.0; + output_ = AddOutput(data); + + SetBuiltinOp( + BuiltinOperator_L2_NORMALIZATION, BuiltinOptions_L2NormOptions, + CreateL2NormOptions(builder_, ActivationFunctionType_NONE).Union()); + BuildInterpreter({input_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + template + std::vector GetOutput() { + return ExtractVector(output_); + } + + template + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), GetScale(output_), + GetZeroPoint(output_)); + } + + int input() const { return input_; } + + private: + int input_; + int output_; +}; + +TEST(L2NormOpTest, ZerosVectorUint8Test) { + L2NormOpModel m({1, 1, 1, 6}, TensorType_UINT8); + + m.QuantizeAndPopulate(m.input(), {0, 0, 0, 0, 0, 0}); + m.ApplyDelegateAndInvoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({0, 0, 0, 0, 0, 0}, 0.1))); +} + +TEST(L2NormOpTest, ZerosVectorInt8Test) { + L2NormOpModel m({1, 1, 1, 6}, TensorType_INT8); + + m.QuantizeAndPopulate(m.input(), {0, 0, 0, 0, 0, 0}); + m.ApplyDelegateAndInvoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({0, 0, 0, 0, 0, 0}, 0.1))); +} + +TEST(L2NormOpTest, MultipleBatchUint8Test) { + L2NormOpModel m({3, 1, 1, 6}, TensorType_UINT8); + + m.QuantizeAndPopulate(m.input(), + { + -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 1 + -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 2 + -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 3 + }); + m.ApplyDelegateAndInvoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 1 + -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 2 + -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 3 + }, + 0.1))); +} + +TEST(L2NormOpTest, MultipleBatchInt8Test) { + L2NormOpModel m({3, 1, 1, 6}, TensorType_INT8); + + m.QuantizeAndPopulate(m.input(), + { + -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 1 + -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 2 + -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 3 + }); + m.ApplyDelegateAndInvoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 1 + -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 2 + -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 3 + }, + 0.1))); +} + +} // namespace tflite diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/tests/quantize_test.cc b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/quantize_test.cc new file mode 100644 index 00000000000..93cd138f014 --- /dev/null +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/quantize_test.cc @@ -0,0 +1,170 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/lite/experimental/delegates/hexagon/builders/tests/hexagon_delegate_op_model.h" + +namespace tflite { +using testing::ElementsAreArray; + +class QuantizeOpModel : public SingleOpModelWithHexagon { + public: + explicit QuantizeOpModel(const TensorData& input, const TensorData& output) { + input_ = AddInput(input); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_QUANTIZE, BuiltinOptions_QuantizeOptions, + CreateQuantizeOptions(builder_).Union()); + BuildInterpreter({GetShape(input_)}); + } + + template + void SetInput(const std::vector& data) { + QuantizeAndPopulate(input_, data); + } + + template + std::vector GetOutput() { + return ExtractVector(output_); + } + + protected: + BuiltinOperator op_code_; + + int input_; + int output_; +}; + +// Input scale 0.500000, output scale 0.500000, input zeropoint 127, output +// zeropoint 127 +TEST(QuantizeOpTest, UInt8UInt8SameScale) { + QuantizeOpModel m({TensorType_UINT8, {1, 1, 2, 5}, -63.5, 64}, + {TensorType_UINT8, {1, 1, 2, 5}, -63.5, 64}); + + // Input will quantized to {129,131,133,135,137,139,141,143,145,147}. + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + m.ApplyDelegateAndInvoke(); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({129, 131, 133, 135, 137, 139, 141, 143, 145, 147})); +} + +// Input scale 0.500000, output scale 1.000000, input zeropoint 127, output +// zeropoint 127 +TEST(QuantizeOpTest, Uint8Uint8LargerScale) { + QuantizeOpModel m({TensorType_UINT8, {1, 1, 2, 5}, -63.5, 64}, + {TensorType_UINT8, {1, 1, 2, 5}, -127, 128}); + + // Input will quantized to {129,131,133,135,137,139,141,143,145,147}. + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + m.ApplyDelegateAndInvoke(); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({128, 129, 130, 131, 132, 133, 134, 135, 136, 137})); +} + +// Input scale 1.000000, output scale 0.500000, input zeropoint 127, output +// zeropoint 127 +TEST(QuantizeOpTest, Uint8Uint8SmallerScale) { + QuantizeOpModel m({TensorType_UINT8, {1, 1, 2, 5}, -127, 128}, + {TensorType_UINT8, {1, 1, 2, 5}, -63.5, 64}); + + // Input will quantized to {128, 129, 130, 131, 132, 133, 134, 135, 136, 137}. + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + m.ApplyDelegateAndInvoke(); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({129, 131, 133, 135, 137, 139, 141, 143, 145, 147})); +} + +// Input scale 1.000000, output scale 0.500000, input zeropoint -1, output +// zeropoint 127 +TEST(QuantizeOpTest, Int8Uint8SmallerScale) { + QuantizeOpModel m({TensorType_INT8, {1, 1, 2, 5}, -127, 128}, + {TensorType_UINT8, {1, 1, 2, 5}, -63.5, 64}); + + // Input will quantized to {0,1,2,3,4,5,6,7,8,9}. + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + m.ApplyDelegateAndInvoke(); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({129, 131, 133, 135, 137, 139, 141, 143, 145, 147})); +} + +// Input scale 1.000000, output scale 2.000000, input zeropoint -1, output +// zeropoint 127 +TEST(QuantizeOpTest, Int8Uint8LargerScale) { + QuantizeOpModel m({TensorType_INT8, {1, 1, 2, 5}, -127, 128}, + {TensorType_UINT8, {1, 1, 2, 5}, -254, 256}); + + // Input will quantized to {0,1,2,3,4,5,6,7,8,9}. + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + m.ApplyDelegateAndInvoke(); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({128, 128, 129, 129, 130, 130, 131, 131, 132, 132})); +} + +// input scale 0.500000, output scale 0.500000, input zeropoint 127, output +// zeropoint -1 +TEST(QuantizeOpTest, UInt8Int8SameScale128Diff) { + QuantizeOpModel m({TensorType_UINT8, {1, 1, 2, 5}, -127, 128}, + {TensorType_INT8, {1, 1, 2, 5}, -127, 128}); + + // Input will quantized to {128, 129, 130, 131, 132, 133, 134, 135, 136, 137}. + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + m.ApplyDelegateAndInvoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({0, 1, 2, 3, 4, 5, 6, 7, 8, 9})); +} + +// Input scale 0.500000, output scale 0.500000, input zeropoint -1, output +// zeropoint -1 +TEST(QuantizeOpTest, Int8Int8SameScale) { + QuantizeOpModel m({TensorType_INT8, {1, 1, 2, 5}, -63.5, 64}, + {TensorType_INT8, {1, 1, 2, 5}, -63.5, 64}); + + // Input will quantized to {1,3,5,7,9,11,13,15,17,19}. + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + m.ApplyDelegateAndInvoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({1, 3, 5, 7, 9, 11, 13, 15, 17, 19})); +} + +// Input scale 0.500000, output scale 1.000000, input zeropoint -1, output +// zeropoint -1 +TEST(QuantizeOpTest, Int8Int8LargerScale) { + QuantizeOpModel m({TensorType_INT8, {1, 1, 2, 5}, -63.5, 64}, + {TensorType_INT8, {1, 1, 2, 5}, -127, 128}); + + // Input will quantized to {1,3,5,7,9,11,13,15,17,19}. + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + m.ApplyDelegateAndInvoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({0, 1, 2, 3, 4, 5, 6, 7, 8, 9})); +} + +// Input scale 1.000000, output scale 0.500000, input zeropoint -1, output +// zeropoint -1 +TEST(QuantizeOpTest, Int8Int8SmallerScale) { + QuantizeOpModel m({TensorType_INT8, {1, 1, 2, 5}, -127, 128}, + {TensorType_INT8, {1, 1, 2, 5}, -63.5, 64}); + + // Input will quantized to {0,1,2,3,4,5,6,7,8,9}. + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + m.ApplyDelegateAndInvoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({1, 3, 5, 7, 9, 11, 13, 15, 17, 19})); +} + +} // namespace tflite diff --git a/tensorflow/lite/experimental/delegates/hexagon/utils.cc b/tensorflow/lite/experimental/delegates/hexagon/utils.cc index 35fce1bfebd..6ba1279e01d 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/utils.cc +++ b/tensorflow/lite/experimental/delegates/hexagon/utils.cc @@ -77,17 +77,19 @@ bool CheckOpVersion(const TfLiteRegistration* registration) { case kTfLiteBuiltinArgMin: case kTfLiteBuiltinAveragePool2d: case kTfLiteBuiltinConcatenation: + case kTfLiteBuiltinL2Normalization: case kTfLiteBuiltinLogistic: case kTfLiteBuiltinMaxPool2d: case kTfLiteBuiltinMul: case kTfLiteBuiltinPad: - case kTfLiteBuiltinSub: + case kTfLiteBuiltinQuantize: case kTfLiteBuiltinRelu6: case kTfLiteBuiltinResizeBilinear: case kTfLiteBuiltinResizeNearestNeighbor: case kTfLiteBuiltinSoftmax: case kTfLiteBuiltinSpaceToDepth: case kTfLiteBuiltinSplit: + case kTfLiteBuiltinSub: case kTfLiteBuiltinTanh: case kTfLiteBuiltinTranspose: case kTfLiteBuiltinTransposeConv: @@ -301,8 +303,7 @@ bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration, IsConstantTensor(GetInput(context, node, 1)); } case kTfLiteBuiltinL2Normalization: { - // TODO(b/142009955): Support int8. - if (!InputsWithCorrectTypes(node, context, {{kTfLiteUInt8}})) + if (!InputsWithCorrectTypes(node, context, {{kTfLiteUInt8, kTfLiteInt8}})) return false; const TfLiteL2NormParams* norm_params = reinterpret_cast(node->builtin_data); @@ -347,6 +348,10 @@ bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration, return InputsWithCorrectTypes(node, context, {{kTfLiteUInt8, kTfLiteInt8}}); } + case kTfLiteBuiltinQuantize: { + return InputsWithCorrectTypes(node, context, + {{kTfLiteUInt8, kTfLiteInt8}}); + } default: return false; }