diff --git a/tensorflow/lite/delegates/gpu/common/BUILD b/tensorflow/lite/delegates/gpu/common/BUILD index ef43c76310b..918e5f6b827 100644 --- a/tensorflow/lite/delegates/gpu/common/BUILD +++ b/tensorflow/lite/delegates/gpu/common/BUILD @@ -150,6 +150,7 @@ cc_library( ":model_builder_helper", ":model_transformer", ":object_reader", + ":operation_parser", ":operations", ":shape", ":status", @@ -372,6 +373,19 @@ cc_test( ], ) +cc_library( + name = "operation_parser", + hdrs = ["operation_parser.h"], + deps = [ + ":model", + ":object_reader", + ":status", + "//tensorflow/lite/c:common", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + ], +) + cc_test( name = "util_test", srcs = ["util_test.cc"], diff --git a/tensorflow/lite/delegates/gpu/common/custom_parsers.h b/tensorflow/lite/delegates/gpu/common/custom_parsers.h index 2644864cb58..2d80fe39c59 100644 --- a/tensorflow/lite/delegates/gpu/common/custom_parsers.h +++ b/tensorflow/lite/delegates/gpu/common/custom_parsers.h @@ -15,16 +15,22 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CUSTOM_PARSERS_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CUSTOM_PARSERS_H_ -#include +#include +#include #include "absl/strings/string_view.h" #include "absl/types/any.h" +#include "tensorflow/lite/delegates/gpu/common/operation_parser.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" namespace tflite { namespace gpu { +// Returns a parser for the provided custom op. +std::unique_ptr NewCustomOperationParser( + absl::string_view op_name); + // Matches the custom operation by the string name and parses attributes stored // as flexbuffers. absl::Status ParseCustomAttributes(absl::string_view op_name, int version, diff --git a/tensorflow/lite/delegates/gpu/common/default/BUILD b/tensorflow/lite/delegates/gpu/common/default/BUILD index 91ce7e6c028..eccb99ce179 100644 --- a/tensorflow/lite/delegates/gpu/common/default/BUILD +++ b/tensorflow/lite/delegates/gpu/common/default/BUILD @@ -8,6 +8,7 @@ cc_library( srcs = ["custom_parsers.cc"], hdrs = ["//tensorflow/lite/delegates/gpu/common:custom_parsers.h"], deps = [ + "//tensorflow/lite/delegates/gpu/common:operation_parser", "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:status", "@com_google_absl//absl/strings", diff --git a/tensorflow/lite/delegates/gpu/common/default/custom_parsers.cc b/tensorflow/lite/delegates/gpu/common/default/custom_parsers.cc index a4981a9d459..e70802e16bb 100644 --- a/tensorflow/lite/delegates/gpu/common/default/custom_parsers.cc +++ b/tensorflow/lite/delegates/gpu/common/default/custom_parsers.cc @@ -15,18 +15,49 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/custom_parsers.h" -#include - +#include +#include #include +#include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/any.h" +#include "tensorflow/lite/delegates/gpu/common/operation_parser.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" namespace tflite { namespace gpu { +namespace { + +class UnimplementedCustomOperationParser : public TFLiteOperationParser { + public: + explicit UnimplementedCustomOperationParser(absl::string_view op_name) + : op_name_(op_name) {} + + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + return absl::UnimplementedError(op_name_); + } + + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { + return absl::UnimplementedError(op_name_); + } + + private: + std::string op_name_; +}; + +} // namespace + +std::unique_ptr NewCustomOperationParser( + absl::string_view op_name) { + return absl::make_unique(op_name); +} absl::Status ParseCustomAttributes(absl::string_view op_name, int version, const void* data, uint32_t data_size, diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc index 5fca6297171..fb158d13cdc 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc @@ -43,6 +43,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h" #include "tensorflow/lite/delegates/gpu/common/model_transformer.h" #include "tensorflow/lite/delegates/gpu/common/object_reader.h" +#include "tensorflow/lite/delegates/gpu/common/operation_parser.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" @@ -69,30 +70,6 @@ absl::Status CheckTensorIsAvailable(const TfLiteContext* context, return absl::OkStatus(); } -// A parser responsible for parsing TFLite operation and adding it to a graph. -class TFLiteOperationParser { - public: - virtual ~TFLiteOperationParser() = default; - - // Parses TFLite operation. This method allows expanding fused operations - // into more than one node. - virtual absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) = 0; - - // Verifies whether passed tflite node may be built by GPU delegate or not. - virtual absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) = 0; - - // Return the value ids in the graph that correspond to the updated values of - // the variable input tensor. - virtual absl::flat_hash_map - GetNewValueIdsForVariableInputNodes() { - return absl::flat_hash_map(); - } -}; - HW ToHW(int32_t h, int32_t w) { return HW(h > 0 ? h : 1, w > 0 ? w : 1); } template @@ -1438,17 +1415,10 @@ class Pooling2DOperationParser : public TFLiteOperationParser { const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); const TfLitePoolParams* tf_options; - auto status = RetrieveCustomInitialData(tflite_node, &tf_options); - if (status.ok()) { // custom case with indices as a second output - RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, - /*runtime_inputs=*/1, - /*outputs=*/2)); - } else { // common pooling with 1 output - RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); - RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, - /*runtime_inputs=*/1, - /*outputs=*/1)); - } + RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); + RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, + /*runtime_inputs=*/1, + /*outputs=*/1)); RETURN_IF_ERROR(CheckKernelsAndStrides( tf_options->filter_height, tf_options->filter_width, tf_options->stride_height, tf_options->stride_width)); @@ -1471,28 +1441,12 @@ class Pooling2DOperationParser : public TFLiteOperationParser { auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape; - // check whether there are custom options encoded. It happens if operation - // is MaxPoolingWithArgmax2D. There is no way to read - // tflite_node->builtin_code, so, simply check whether custom data is - // available. const TfLitePoolParams* tf_options; - if (!RetrieveCustomInitialData(tflite_node, &tf_options).ok()) { - RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); - } + RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node)); - // Second output is optional. It is not required, it but must be added after - // MaybeAddFusedActivation function is called - reader->AddOutput(node, 1).IgnoreError(); - // First output is the result of pooling operation, while second output is - // indices used for pooling. - auto outputs = graph->FindOutputs(node->id); - attr.output_indices = outputs.size() == 2; - if (attr.output_indices) { - // Fix data type for output indices. In the model it is set as float32. - outputs[1]->tensor.type = DataType::INT32; - } + attr.output_indices = false; RETURN_IF_ERROR(ParsePoolingAttributes(tf_options, input_shape, &attr)); node->operation.attributes = attr; return absl::OkStatus(); @@ -2205,45 +2159,6 @@ class TransposeConvBuiltinOperationParser : public TFLiteOperationParser { } }; -// Custom op version of TRANSPOSE_CONV. -class TransposeConvCustomOperationParser : public TFLiteOperationParser { - public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { - RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1)); - const TfLiteTransposeConvParams* tf_options; - RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options)); - RETURN_IF_ERROR( - CheckStrides(tf_options->stride_height, tf_options->stride_width)); - return absl::OkStatus(); - } - - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { - auto* node = graph->NewNode(); - node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED); - RETURN_IF_ERROR(reader->AddInput(node, 0)); - RETURN_IF_ERROR(reader->AddOutputs(node)); - - const TfLiteTransposeConvParams* tf_options; - auto status = RetrieveCustomInitialData(tflite_node, &tf_options); - - ConvolutionTransposedAttributes attr; - attr.stride = status.ok() - ? HW(tf_options->stride_height, tf_options->stride_width) - : HW(1, 1); - RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights)); - reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional - - UpdatePadding(status.ok() ? tf_options->padding : kTfLitePaddingUnknown, - graph->FindInputs(node->id)[0]->tensor.shape, &attr); - node->operation.attributes = std::move(attr); - return absl::OkStatus(); - } -}; - class TransposeOperationParser : public TFLiteOperationParser { public: absl::Status IsSupported(const TfLiteContext* context, @@ -2295,47 +2210,6 @@ class TransposeOperationParser : public TFLiteOperationParser { } }; -class Unpooling2DOperationParser : public TFLiteOperationParser { - public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { - RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, - /*runtime_inputs=*/2, /*outputs=*/1)); - const TfLitePoolParams* tf_options; - RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options)); - RETURN_IF_ERROR(CheckKernelsAndStrides( - tf_options->filter_height, tf_options->filter_width, - tf_options->stride_height, tf_options->stride_width)); - return absl::OkStatus(); - } - - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { - Node* node = graph->NewNode(); - node->operation.type = ToString(OperationType::MAX_UNPOOLING_2D); - RETURN_IF_ERROR(reader->AddInput(node, 0)); - RETURN_IF_ERROR(reader->AddInput(node, 1)); - RETURN_IF_ERROR(reader->AddOutputs(node)); - auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape; - MaxUnpooling2DAttributes attr; - - const TfLitePoolParams* tf_options; - RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options)); - - attr.kernel = ToHW(tf_options->filter_height, tf_options->filter_width); - attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width); - UpdatePadding(tf_options->padding, input_shape, &attr); - - node->operation.attributes = attr; - - auto output_value = graph->FindOutputs(node->id)[0]; - output_value->tensor.shape = CalculateOutputShape(input_shape, attr); - return absl::OkStatus(); - } -}; - // TODO(impjdi): BATCH_TO_SPACE/SPACE_TO_BATCH shouldn't be supported. class BatchToSpaceOperationParser : public TFLiteOperationParser { public: @@ -2423,171 +2297,6 @@ class SpaceToBatchOperationParser : public TFLiteOperationParser { } }; -class RoIToTransformMatrixOperationParser : public TFLiteOperationParser { - public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { - RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); - RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, - /*runtime_inputs=*/1, /*outputs=*/1)); - return absl::OkStatus(); - } - - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { - Node* node = graph->NewNode(); - RETURN_IF_ERROR(reader->AddInput(node, 0)); // bbox - RETURN_IF_ERROR(reader->AddOutputs(node)); - - std::string op_name = "roi_to_transform_matrix"; - node->operation.type = op_name; - BHWC output_shape; - RETURN_IF_ERROR(ParseCustomAttributes( - op_name, registration->version, tflite_node->custom_initial_data, - tflite_node->custom_initial_data_size, &(node->operation.attributes), - &output_shape)); - - auto output_value = graph->FindOutputs(node->id)[0]; - output_value->tensor.shape = output_shape; - return absl::OkStatus(); - } -}; - -class TransformTensorBilinearOperationParser : public TFLiteOperationParser { - public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { - RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); - RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, - /*runtime_inputs=*/2, /*outputs=*/1)); - return absl::OkStatus(); - } - - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { - Node* node = graph->NewNode(); - RETURN_IF_ERROR(reader->AddInput(node, 0)); // data - RETURN_IF_ERROR(reader->AddInput(node, 1)); // bbox - RETURN_IF_ERROR(reader->AddOutputs(node)); - - std::string op_name = "transform_tensor_bilinear"; - node->operation.type = op_name; - BHWC output_shape; - RETURN_IF_ERROR(ParseCustomAttributes( - op_name, registration->version, tflite_node->custom_initial_data, - tflite_node->custom_initial_data_size, &(node->operation.attributes), - &output_shape)); - - auto output_value = graph->FindOutputs(node->id)[0]; - - output_value->tensor.shape = - BHWC(1, output_shape.h, output_shape.w, - graph->FindInputs(node->id)[0]->tensor.shape.c); - return absl::OkStatus(); - } -}; - -class TransformLandmarksOperationParser : public TFLiteOperationParser { - public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { - RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); - RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, - /*runtime_inputs=*/2, /*outputs=*/1)); - return absl::OkStatus(); - } - - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { - Node* node = graph->NewNode(); - RETURN_IF_ERROR(reader->AddInput(node, 0)); // data - RETURN_IF_ERROR(reader->AddInput(node, 1)); // bbox - RETURN_IF_ERROR(reader->AddOutputs(node)); - std::string op_name = "transform_landmarks"; - node->operation.type = op_name; - BHWC output_shape = graph->FindOutputs(node->id)[0]->tensor.shape; - RETURN_IF_ERROR(ParseCustomAttributes( - op_name, registration->version, tflite_node->custom_initial_data, - tflite_node->custom_initial_data_size, &(node->operation.attributes), - &output_shape)); - - auto output_value = graph->FindOutputs(node->id)[0]; - - output_value->tensor.shape = graph->FindInputs(node->id)[0]->tensor.shape; - return absl::OkStatus(); - } -}; - -class Landmarks2TransformMatrixOperationParser : public TFLiteOperationParser { - public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { - RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); - return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, - /*outputs=*/1); - } - - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { - Node* node = graph->NewNode(); - RETURN_IF_ERROR(reader->AddInput(node, 0)); // landmarks - RETURN_IF_ERROR(reader->AddOutputs(node)); // transform matrix - - const std::string op_name = "landmarks_to_transform_matrix"; - node->operation.type = op_name; - BHWC output_shape; - RETURN_IF_ERROR(ParseCustomAttributes( - op_name, registration->version, tflite_node->custom_initial_data, - tflite_node->custom_initial_data_size, &(node->operation.attributes), - &output_shape)); - - auto output_value = graph->FindOutputs(node->id)[0]; - output_value->tensor.shape = output_shape; - return absl::OkStatus(); - } -}; - -class AlignmentPointsToTransformMatrixOperationParser - : public TFLiteOperationParser { - public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { - return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, - /*outputs=*/1); - } - - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { - Node* node = graph->NewNode(); - RETURN_IF_ERROR(reader->AddInput(node, 0)); // alignment points - RETURN_IF_ERROR(reader->AddOutputs(node)); // transform matrix - - const std::string op_name = "alignment_points_to_transform_matrix"; - node->operation.type = op_name; - BHWC output_shape; - RETURN_IF_ERROR(ParseCustomAttributes( - op_name, registration->version, tflite_node->custom_initial_data, - tflite_node->custom_initial_data_size, &(node->operation.attributes), - &output_shape)); - - auto output_value = graph->FindOutputs(node->id)[0]; - output_value->tensor.shape = output_shape; - return absl::OkStatus(); - } - - private: -}; - class MeanOperationParser : public TFLiteOperationParser { public: absl::Status IsSupported(const TfLiteContext* context, @@ -2763,37 +2472,8 @@ std::unique_ptr NewOperationParser( return std::make_unique(); case kTfLiteBuiltinTransposeConv: return std::make_unique(); - case kTfLiteBuiltinCustom: - const absl::string_view custom_name = registration->custom_name; - if (custom_name == "Convolution2DTransposeBias") { - return std::make_unique(); - } - if (custom_name == "MaxPoolingWithArgmax2D") { - return std::make_unique(PoolingType::MAX); - } - if (custom_name == "MaxUnpooling2D") { - return std::make_unique(); - } - if (custom_name == "RoIToTransformMatrix") { - return std::make_unique(); - } - if (custom_name == "TransformTensor" /*for version 1*/ || - custom_name == "TransformTensorBilinear" /*for version 2*/) { - return std::make_unique(); - } - if (custom_name == "TransformLandmarks") { - return std::make_unique(); - } - if (custom_name == "Landmarks2TransformMatrix" || - custom_name == "Landmarks2TransformMatrixV2") { - return std::make_unique(); - } - if (custom_name == "AlignmentPointsToTransformMatrix") { - return std::make_unique< - AlignmentPointsToTransformMatrixOperationParser>(); - } - break; + return NewCustomOperationParser(registration->custom_name); } return std::make_unique(); } diff --git a/tensorflow/lite/delegates/gpu/common/operation_parser.h b/tensorflow/lite/delegates/gpu/common/operation_parser.h new file mode 100644 index 00000000000..29244068e16 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/operation_parser.h @@ -0,0 +1,55 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_OPERATION_PARSER_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_OPERATION_PARSER_H_ + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/object_reader.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" + +namespace tflite { +namespace gpu { + +// Parses TFLite operation and updates provided GraphFloat32. +class TFLiteOperationParser { + public: + virtual ~TFLiteOperationParser() = default; + + // Parses TFLite operation. This method allows expanding fused operations + // into more than one node. + virtual absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) = 0; + + // Verifies whether passed tflite node may be built by GPU delegate or not. + virtual absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) = 0; + + // Returns the value IDs in the graph that correspond to the updated values of + // the variable input tensor. + virtual absl::flat_hash_map + GetNewValueIdsForVariableInputNodes() { + return {}; + } +}; + +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_OPERATION_PARSER_H_