Add SquaredDifference op in hexagon delegate. The op is not supported natively so we lower the op to Mul(Sub(a,b), Sub(a,b))
PiperOrigin-RevId: 349301984 Change-Id: I19c184de2018ca6a417fd02da51885b1538f82bf
This commit is contained in:
parent
40dad0a283
commit
ea9f1bf016
@ -99,6 +99,7 @@ are verified in `IsNodeSupportedByHexagon`:
|
||||
* SoftMax
|
||||
* SpaceToDepth
|
||||
* Split
|
||||
* SquaredDifference
|
||||
* Strided Slice
|
||||
* Sub (Support relu activations)
|
||||
* Tanh
|
||||
|
@ -35,6 +35,7 @@ cc_library(
|
||||
"softmax_builder.cc",
|
||||
"space_to_depth_builder.cc",
|
||||
"split_builder.cc",
|
||||
"squared_difference.cc",
|
||||
"strided_slice_builder.cc",
|
||||
"transpose_builder.cc",
|
||||
"transpose_conv_2d_builder.cc",
|
||||
|
@ -159,6 +159,8 @@ OpBuilder* GraphBuilder::CreateOpBuilderFromTfLiteOp(int op_type,
|
||||
return CreatePackBuilder(this, OP_QuantizedPack_8);
|
||||
case kTfLiteBuiltinStridedSlice:
|
||||
return CreateStridedSliceBuilder(this, OP_QuantizedStridedSlice_8);
|
||||
case kTfLiteBuiltinSquaredDifference:
|
||||
return CreateSquaredDifferenceOpBuilder(this, OP_QuantizedSub_8p8to8);
|
||||
default:
|
||||
context_->ReportError(context_, "Op not supported: %d", op_type);
|
||||
return nullptr;
|
||||
|
@ -60,6 +60,8 @@ OpBuilder* CreateSliceOpBuilder(GraphBuilder* graph_builder, int op_type);
|
||||
OpBuilder* CreatePackBuilder(GraphBuilder* graph_builder, int op_type);
|
||||
OpBuilder* CreateMatMulOpBuilder(GraphBuilder* graph_builder, int op_type);
|
||||
OpBuilder* CreateStridedSliceBuilder(GraphBuilder* graph_builder, int op_type);
|
||||
OpBuilder* CreateSquaredDifferenceOpBuilder(GraphBuilder* graph_builder,
|
||||
int op_type);
|
||||
|
||||
} // namespace hexagon
|
||||
} // namespace delegates
|
||||
|
104
tensorflow/lite/delegates/hexagon/builders/squared_difference.cc
Normal file
104
tensorflow/lite/delegates/hexagon/builders/squared_difference.cc
Normal file
@ -0,0 +1,104 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/delegates/hexagon/builders/op_builder.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace delegates {
|
||||
namespace hexagon {
|
||||
// Builder for SquaredDifference op by computing Mul(Sub(A,B), Sub(A,B))
|
||||
class SquaredDifferenceOpBuilder : public OpBuilder {
|
||||
public:
|
||||
explicit SquaredDifferenceOpBuilder(GraphBuilder* graph_builder, int op_type)
|
||||
: OpBuilder(graph_builder, op_type) {}
|
||||
TfLiteStatus PopulateSubGraph(const TfLiteIntArray* inputs,
|
||||
const TfLiteIntArray* outputs,
|
||||
TfLiteContext* context) override;
|
||||
|
||||
TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs,
|
||||
TfLiteContext* context) override;
|
||||
|
||||
private:
|
||||
TensorID node_output_;
|
||||
};
|
||||
|
||||
TfLiteStatus SquaredDifferenceOpBuilder::PopulateSubGraph(
|
||||
const TfLiteIntArray* inputs, const TfLiteIntArray* outputs,
|
||||
TfLiteContext* context) {
|
||||
// We model Squared Diff as Mul(Sub(a,b), Sub(a,b))
|
||||
|
||||
// Add first Sub op.
|
||||
const int tensor_a_index = inputs->data[0];
|
||||
const int tensor_b_index = inputs->data[1];
|
||||
const auto& tensor_a = context->tensors[tensor_a_index];
|
||||
const auto& tensor_b = context->tensors[tensor_b_index];
|
||||
AddInput(graph_builder_->GetHexagonTensorId(tensor_a_index));
|
||||
AddInput(graph_builder_->GetHexagonTensorId(tensor_b_index));
|
||||
// Inputs min/max
|
||||
TF_LITE_ENSURE_STATUS(ComputeAndAddMinAndMax(context, tensor_a));
|
||||
TF_LITE_ENSURE_STATUS(ComputeAndAddMinAndMax(context, tensor_b));
|
||||
// Output details.
|
||||
float output_min = -1, output_max = -1;
|
||||
TF_LITE_ENSURE_STATUS(ComputeMinAndMaxQuantValues(
|
||||
context->tensors[outputs->data[0]], &output_min, &output_max));
|
||||
auto* output_min_const = graph_builder_->AddConstNodeWithData(
|
||||
kScalarShape, reinterpret_cast<char*>(&output_min), sizeof(output_min));
|
||||
auto* output_max_const = graph_builder_->AddConstNodeWithData(
|
||||
kScalarShape, reinterpret_cast<char*>(&output_max), sizeof(output_max));
|
||||
int output_batch_size, output_height_size, output_width_size,
|
||||
output_depth_size;
|
||||
GetDims(&output_batch_size, &output_height_size, &output_width_size,
|
||||
&output_depth_size, context->tensors[outputs->data[0]].dims);
|
||||
|
||||
auto sub_out = AddOutput(sizeof(uint8_t), 4,
|
||||
{output_batch_size, output_height_size,
|
||||
output_width_size, output_depth_size});
|
||||
auto sub_min = AddOutput(sizeof(float), 4, kScalarShape);
|
||||
auto sub_max = AddOutput(sizeof(float), 4, kScalarShape);
|
||||
|
||||
// Add Mul
|
||||
auto* mul_op = graph_builder_->AddNode(GetTFLiteNodeID());
|
||||
mul_op->SetOpType(OP_QuantizedMul_8x8to8);
|
||||
mul_op->AddInput(sub_out);
|
||||
mul_op->AddInput(sub_out);
|
||||
mul_op->AddInput(sub_min);
|
||||
mul_op->AddInput(sub_max);
|
||||
mul_op->AddInput(sub_min);
|
||||
mul_op->AddInput(sub_max);
|
||||
mul_op->AddInput(TensorID(output_min_const->GetID(), 0));
|
||||
mul_op->AddInput(TensorID(output_max_const->GetID(), 0));
|
||||
node_output_ = mul_op->AddOutput(sizeof(uint8_t), 4,
|
||||
{output_batch_size, output_height_size,
|
||||
output_width_size, output_depth_size});
|
||||
mul_op->AddOutput(sizeof(float), 4, kScalarShape);
|
||||
mul_op->AddOutput(sizeof(float), 4, kScalarShape);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus SquaredDifferenceOpBuilder::RegisterOutputs(
|
||||
const TfLiteIntArray* outputs, TfLiteContext* context) {
|
||||
graph_builder_->AddTensorWithID(outputs->data[0], node_output_.first,
|
||||
node_output_.second);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
OpBuilder* CreateSquaredDifferenceOpBuilder(GraphBuilder* graph_builder,
|
||||
int op_type) {
|
||||
return new SquaredDifferenceOpBuilder(graph_builder, op_type);
|
||||
}
|
||||
|
||||
} // namespace hexagon
|
||||
} // namespace delegates
|
||||
} // namespace tflite
|
@ -49,6 +49,7 @@ hexagon_op_tests(
|
||||
"softmax_test.cc",
|
||||
"space_to_depth_test.cc",
|
||||
"split_test.cc",
|
||||
"squared_difference_test.cc",
|
||||
"strided_slice_test.cc",
|
||||
"transpose_conv_test.cc",
|
||||
"transpose_test.cc",
|
||||
|
@ -0,0 +1,106 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/delegates/hexagon/builders/tests/hexagon_delegate_op_model.h"
|
||||
|
||||
namespace tflite {
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
class SquaredDifferenceOpModel : public SingleOpModelWithHexagon {
|
||||
public:
|
||||
SquaredDifferenceOpModel(const TensorData& input1, const TensorData& input2,
|
||||
const TensorData& output) {
|
||||
input1_ = AddInput(input1);
|
||||
input2_ = AddInput(input2);
|
||||
output_ = AddOutput(output);
|
||||
SetBuiltinOp(BuiltinOperator_SQUARED_DIFFERENCE,
|
||||
BuiltinOptions_SquaredDifferenceOptions,
|
||||
CreateSquaredDifferenceOptions(builder_).Union());
|
||||
BuildInterpreter({GetShape(input1_), GetShape(input2_)});
|
||||
}
|
||||
|
||||
int input1() { return input1_; }
|
||||
int input2() { return input2_; }
|
||||
|
||||
template <typename integer_dtype>
|
||||
std::vector<float> GetDequantizedOutput() {
|
||||
return Dequantize<int8_t>(ExtractVector<int8_t>(output_), GetScale(output_),
|
||||
GetZeroPoint(output_));
|
||||
}
|
||||
|
||||
protected:
|
||||
int input1_;
|
||||
int input2_;
|
||||
int output_;
|
||||
};
|
||||
|
||||
float GetTolerance(int min, int max) {
|
||||
float kQuantizedStep = (max - min) / 255.0;
|
||||
return kQuantizedStep;
|
||||
}
|
||||
|
||||
TEST(QuantizedSquaredDifferenceOpTest, Quantized_SameShape) {
|
||||
float kQuantizedTolerance = GetTolerance(0, 1);
|
||||
SquaredDifferenceOpModel m({TensorType_INT8, {1, 2, 2, 1}, -1.2, 0.8},
|
||||
{TensorType_INT8, {1, 2, 2, 1}, -1.5, 0.5},
|
||||
{TensorType_INT8, {}, 0.0, 0.5});
|
||||
m.QuantizeAndPopulate<int8_t>(m.input1(), {-0.2, 0.2, -1.2, 0.8});
|
||||
m.QuantizeAndPopulate<int8_t>(m.input2(), {0.5, 0.2, -1.5, 0.5});
|
||||
m.ApplyDelegateAndInvoke();
|
||||
EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
|
||||
ElementsAreArray(ArrayFloatNear({0.49, 0.0, 0.09, 0.09},
|
||||
kQuantizedTolerance)));
|
||||
}
|
||||
|
||||
TEST(QuantizedSquaredDifferenceOpTest, Quantized_VariousInputShapes) {
|
||||
// NOTE: the min/max are 0 and 9. We use larger threshold for accuracy
|
||||
// issue in Hexagon.
|
||||
float kQuantizedTolerance = GetTolerance(0, 10);
|
||||
std::vector<std::vector<int>> test_shapes = {
|
||||
{6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
|
||||
for (int i = 0; i < test_shapes.size(); ++i) {
|
||||
SquaredDifferenceOpModel m({TensorType_INT8, test_shapes[i], -2.0, 1.7},
|
||||
{TensorType_INT8, test_shapes[i], -1.0, 1.0},
|
||||
{TensorType_INT8, {}, 0.0, 9.0});
|
||||
m.QuantizeAndPopulate<int8_t>(m.input1(), {-2.0, 0.2, 0.3, 0.8, 1.1, -2.0});
|
||||
m.QuantizeAndPopulate<int8_t>(m.input2(), {1.0, 0.2, 0.6, 0.4, -1.0, -0.0});
|
||||
m.ApplyDelegateAndInvoke();
|
||||
EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
|
||||
ElementsAreArray(ArrayFloatNear(
|
||||
{9.0, 0.0, 0.09, 0.16, 4.41, 4.0}, kQuantizedTolerance)))
|
||||
<< "With shape number " << i;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(QuantizedSquaredDifferenceOpTest, Quantized_WithBroadcast) {
|
||||
std::vector<std::vector<int>> test_shapes = {
|
||||
{6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
|
||||
float kQuantizedTolerance = GetTolerance(0, 1);
|
||||
for (int i = 0; i < test_shapes.size(); ++i) {
|
||||
SquaredDifferenceOpModel m({TensorType_INT8, test_shapes[i], -0.2, 1.1},
|
||||
{TensorType_INT8, {}, 0.0, 0.1},
|
||||
{TensorType_INT8, {}, 0.0, 1.0});
|
||||
m.QuantizeAndPopulate<int8_t>(m.input1(), {-0.2, 0.2, 0.5, 0.8, 0.11, 1.1});
|
||||
m.QuantizeAndPopulate<int8_t>(m.input2(), {0.1});
|
||||
m.ApplyDelegateAndInvoke();
|
||||
EXPECT_THAT(
|
||||
m.GetDequantizedOutput<int8_t>(),
|
||||
ElementsAreArray(ArrayFloatNear({0.09, 0.01, 0.16, 0.49, 0.0001, 1.0},
|
||||
kQuantizedTolerance)))
|
||||
<< "With shape number " << i;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tflite
|
@ -101,6 +101,7 @@ bool CheckOpVersion(const TfLiteRegistration* registration) {
|
||||
case kTfLiteBuiltinTanh:
|
||||
case kTfLiteBuiltinTranspose:
|
||||
return registration->version <= 2;
|
||||
case kTfLiteBuiltinSquaredDifference:
|
||||
case kTfLiteBuiltinRelu:
|
||||
return registration->version == 2;
|
||||
case kTfLiteBuiltinConv2d:
|
||||
@ -426,6 +427,10 @@ bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration,
|
||||
// Hexagon doesn't support ellipsis/new-axis masks.
|
||||
return (params->ellipsis_mask == 0 && params->new_axis_mask == 0);
|
||||
}
|
||||
case kTfLiteBuiltinSquaredDifference: {
|
||||
return InputsWithCorrectTypes(node, context,
|
||||
{{kTfLiteInt8}, {kTfLiteInt8}});
|
||||
}
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user