Refactor custom op parsers.

PiperOrigin-RevId: 351263261
Change-Id: I1bee6760fa596658e340387209f8530b9f5d7e35
Juhyun Lee 2021-01-11 16:44:44 -08:00 committed by TensorFlower Gardener
parent d7f96c409b
commit 511c27580a
6 changed files with 118 additions and 331 deletions

View File

@@ -150,6 +150,7 @@ cc_library(
":model_builder_helper",
":model_transformer",
":object_reader",
":operation_parser",
":operations",
":shape",
":status",
@@ -372,6 +373,19 @@ cc_test(
],
)
cc_library(
name = "operation_parser",
hdrs = ["operation_parser.h"],
deps = [
":model",
":object_reader",
":status",
"//tensorflow/lite/c:common",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
],
)
cc_test(
name = "util_test",
srcs = ["util_test.cc"],

View File

@@ -15,16 +15,22 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CUSTOM_PARSERS_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CUSTOM_PARSERS_H_
#include <stdint.h>
#include <cstdint>
#include <memory>
#include "absl/strings/string_view.h"
#include "absl/types/any.h"
#include "tensorflow/lite/delegates/gpu/common/operation_parser.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
namespace tflite {
namespace gpu {
// Returns a parser for the provided custom op.
std::unique_ptr<TFLiteOperationParser> NewCustomOperationParser(
absl::string_view op_name);
// Matches the custom operation by the string name and parses attributes stored
// as flexbuffers.
absl::Status ParseCustomAttributes(absl::string_view op_name, int version,

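For context, a minimal usage sketch of the new factory declared above, assuming the TFLiteOperationParser interface from operation_parser.h; the surrounding variables (registration, context, tflite_node, graph, reader) are illustrative, not verbatim from this commit:

// Sketch: how a caller (e.g. the model builder) might use the factory. The
// default implementation returns a parser whose methods report
// absl::UnimplementedError carrying the op name.
std::unique_ptr<TFLiteOperationParser> parser =
    NewCustomOperationParser(registration->custom_name);
if (parser->IsSupported(context, tflite_node, registration).ok()) {
  RETURN_IF_ERROR(parser->Parse(tflite_node, registration, graph, reader));
}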
View File

@@ -8,6 +8,7 @@ cc_library(
srcs = ["custom_parsers.cc"],
hdrs = ["//tensorflow/lite/delegates/gpu/common:custom_parsers.h"],
deps = [
"//tensorflow/lite/delegates/gpu/common:operation_parser",
"//tensorflow/lite/delegates/gpu/common:shape",
"//tensorflow/lite/delegates/gpu/common:status",
"@com_google_absl//absl/strings",

View File

@@ -15,18 +15,49 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/custom_parsers.h"
#include <stdint.h>
#include <cstdint>
#include <memory>
#include <string>
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/any.h"
#include "tensorflow/lite/delegates/gpu/common/operation_parser.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
namespace tflite {
namespace gpu {
namespace {
class UnimplementedCustomOperationParser : public TFLiteOperationParser {
public:
explicit UnimplementedCustomOperationParser(absl::string_view op_name)
: op_name_(op_name) {}
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
return absl::UnimplementedError(op_name_);
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
return absl::UnimplementedError(op_name_);
}
private:
std::string op_name_;
};
} // namespace
std::unique_ptr<TFLiteOperationParser> NewCustomOperationParser(
absl::string_view op_name) {
return absl::make_unique<UnimplementedCustomOperationParser>(op_name);
}
absl::Status ParseCustomAttributes(absl::string_view op_name, int version,
const void* data, uint32_t data_size,

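The parser above is only the default fallback. Purely as an illustration of the intended extension point (this commit does not name any such target), an alternative custom_parsers.cc could hand back a real parser for ops it recognizes; "MyCustomOp" and MyCustomOpParser below are hypothetical:

// Hypothetical alternative implementation of NewCustomOperationParser; a real
// one would define its own fallback analogous to
// UnimplementedCustomOperationParser.
std::unique_ptr<TFLiteOperationParser> NewCustomOperationParser(
    absl::string_view op_name) {
  if (op_name == "MyCustomOp") {                   // hypothetical op name
    return absl::make_unique<MyCustomOpParser>();  // hypothetical parser class
  }
  return absl::make_unique<UnimplementedCustomOperationParser>(op_name);
}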
View File

@@ -43,6 +43,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
#include "tensorflow/lite/delegates/gpu/common/object_reader.h"
#include "tensorflow/lite/delegates/gpu/common/operation_parser.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
@@ -69,30 +70,6 @@ absl::Status CheckTensorIsAvailable(const TfLiteContext* context,
return absl::OkStatus();
}
// A parser responsible for parsing TFLite operation and adding it to a graph.
class TFLiteOperationParser {
public:
virtual ~TFLiteOperationParser() = default;
// Parses TFLite operation. This method allows expanding fused operations
// into more than one node.
virtual absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) = 0;
// Verifies whether passed tflite node may be built by GPU delegate or not.
virtual absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) = 0;
// Return the value ids in the graph that correspond to the updated values of
// the variable input tensor.
virtual absl::flat_hash_map<int, ValueId>
GetNewValueIdsForVariableInputNodes() {
return absl::flat_hash_map<int, ValueId>();
}
};
HW ToHW(int32_t h, int32_t w) { return HW(h > 0 ? h : 1, w > 0 ? w : 1); }
template <typename AttrT>
@@ -1438,17 +1415,10 @@ class Pooling2DOperationParser : public TFLiteOperationParser {
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
const TfLitePoolParams* tf_options;
auto status = RetrieveCustomInitialData(tflite_node, &tf_options);
if (status.ok()) { // custom case with indices as a second output
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1,
/*outputs=*/2));
} else { // common pooling with 1 output
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1,
/*outputs=*/1));
}
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1,
/*outputs=*/1));
RETURN_IF_ERROR(CheckKernelsAndStrides(
tf_options->filter_height, tf_options->filter_width,
tf_options->stride_height, tf_options->stride_width));
@@ -1471,28 +1441,12 @@ class Pooling2DOperationParser : public TFLiteOperationParser {
auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
// check whether there are custom options encoded. It happens if operation
// is MaxPoolingWithArgmax2D. There is no way to read
// tflite_node->builtin_code, so, simply check whether custom data is
// available.
const TfLitePoolParams* tf_options;
if (!RetrieveCustomInitialData(tflite_node, &tf_options).ok()) {
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
}
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
// Second output is optional. It is not required, it but must be added after
// MaybeAddFusedActivation function is called
reader->AddOutput(node, 1).IgnoreError();
// First output is the result of pooling operation, while second output is
// indices used for pooling.
auto outputs = graph->FindOutputs(node->id);
attr.output_indices = outputs.size() == 2;
if (attr.output_indices) {
// Fix data type for output indices. In the model it is set as float32.
outputs[1]->tensor.type = DataType::INT32;
}
attr.output_indices = false;
RETURN_IF_ERROR(ParsePoolingAttributes(tf_options, input_shape, &attr));
node->operation.attributes = attr;
return absl::OkStatus();
@@ -2205,45 +2159,6 @@ class TransposeConvBuiltinOperationParser : public TFLiteOperationParser {
}
};
// Custom op version of TRANSPOSE_CONV.
class TransposeConvCustomOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
const TfLiteTransposeConvParams* tf_options;
RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options));
RETURN_IF_ERROR(
CheckStrides(tf_options->stride_height, tf_options->stride_width));
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
auto* node = graph->NewNode();
node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutputs(node));
const TfLiteTransposeConvParams* tf_options;
auto status = RetrieveCustomInitialData(tflite_node, &tf_options);
ConvolutionTransposedAttributes attr;
attr.stride = status.ok()
? HW(tf_options->stride_height, tf_options->stride_width)
: HW(1, 1);
RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional
UpdatePadding(status.ok() ? tf_options->padding : kTfLitePaddingUnknown,
graph->FindInputs(node->id)[0]->tensor.shape, &attr);
node->operation.attributes = std::move(attr);
return absl::OkStatus();
}
};
class TransposeOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
@@ -2295,47 +2210,6 @@ class TransposeOperationParser : public TFLiteOperationParser {
}
};
class Unpooling2DOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/2, /*outputs=*/1));
const TfLitePoolParams* tf_options;
RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options));
RETURN_IF_ERROR(CheckKernelsAndStrides(
tf_options->filter_height, tf_options->filter_width,
tf_options->stride_height, tf_options->stride_width));
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::MAX_UNPOOLING_2D);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddInput(node, 1));
RETURN_IF_ERROR(reader->AddOutputs(node));
auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
MaxUnpooling2DAttributes attr;
const TfLitePoolParams* tf_options;
RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options));
attr.kernel = ToHW(tf_options->filter_height, tf_options->filter_width);
attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width);
UpdatePadding(tf_options->padding, input_shape, &attr);
node->operation.attributes = attr;
auto output_value = graph->FindOutputs(node->id)[0];
output_value->tensor.shape = CalculateOutputShape(input_shape, attr);
return absl::OkStatus();
}
};
// TODO(impjdi): BATCH_TO_SPACE/SPACE_TO_BATCH shouldn't be supported.
class BatchToSpaceOperationParser : public TFLiteOperationParser {
public:
@@ -2423,171 +2297,6 @@ class SpaceToBatchOperationParser : public TFLiteOperationParser {
}
};
class RoIToTransformMatrixOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1, /*outputs=*/1));
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
RETURN_IF_ERROR(reader->AddInput(node, 0)); // bbox
RETURN_IF_ERROR(reader->AddOutputs(node));
std::string op_name = "roi_to_transform_matrix";
node->operation.type = op_name;
BHWC output_shape;
RETURN_IF_ERROR(ParseCustomAttributes(
op_name, registration->version, tflite_node->custom_initial_data,
tflite_node->custom_initial_data_size, &(node->operation.attributes),
&output_shape));
auto output_value = graph->FindOutputs(node->id)[0];
output_value->tensor.shape = output_shape;
return absl::OkStatus();
}
};
class TransformTensorBilinearOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/2, /*outputs=*/1));
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
RETURN_IF_ERROR(reader->AddInput(node, 0)); // data
RETURN_IF_ERROR(reader->AddInput(node, 1)); // bbox
RETURN_IF_ERROR(reader->AddOutputs(node));
std::string op_name = "transform_tensor_bilinear";
node->operation.type = op_name;
BHWC output_shape;
RETURN_IF_ERROR(ParseCustomAttributes(
op_name, registration->version, tflite_node->custom_initial_data,
tflite_node->custom_initial_data_size, &(node->operation.attributes),
&output_shape));
auto output_value = graph->FindOutputs(node->id)[0];
output_value->tensor.shape =
BHWC(1, output_shape.h, output_shape.w,
graph->FindInputs(node->id)[0]->tensor.shape.c);
return absl::OkStatus();
}
};
class TransformLandmarksOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/2, /*outputs=*/1));
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
RETURN_IF_ERROR(reader->AddInput(node, 0)); // data
RETURN_IF_ERROR(reader->AddInput(node, 1)); // bbox
RETURN_IF_ERROR(reader->AddOutputs(node));
std::string op_name = "transform_landmarks";
node->operation.type = op_name;
BHWC output_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
RETURN_IF_ERROR(ParseCustomAttributes(
op_name, registration->version, tflite_node->custom_initial_data,
tflite_node->custom_initial_data_size, &(node->operation.attributes),
&output_shape));
auto output_value = graph->FindOutputs(node->id)[0];
output_value->tensor.shape = graph->FindInputs(node->id)[0]->tensor.shape;
return absl::OkStatus();
}
};
class Landmarks2TransformMatrixOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1,
/*outputs=*/1);
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
RETURN_IF_ERROR(reader->AddInput(node, 0)); // landmarks
RETURN_IF_ERROR(reader->AddOutputs(node)); // transform matrix
const std::string op_name = "landmarks_to_transform_matrix";
node->operation.type = op_name;
BHWC output_shape;
RETURN_IF_ERROR(ParseCustomAttributes(
op_name, registration->version, tflite_node->custom_initial_data,
tflite_node->custom_initial_data_size, &(node->operation.attributes),
&output_shape));
auto output_value = graph->FindOutputs(node->id)[0];
output_value->tensor.shape = output_shape;
return absl::OkStatus();
}
};
class AlignmentPointsToTransformMatrixOperationParser
: public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1,
/*outputs=*/1);
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
RETURN_IF_ERROR(reader->AddInput(node, 0)); // alignment points
RETURN_IF_ERROR(reader->AddOutputs(node)); // transform matrix
const std::string op_name = "alignment_points_to_transform_matrix";
node->operation.type = op_name;
BHWC output_shape;
RETURN_IF_ERROR(ParseCustomAttributes(
op_name, registration->version, tflite_node->custom_initial_data,
tflite_node->custom_initial_data_size, &(node->operation.attributes),
&output_shape));
auto output_value = graph->FindOutputs(node->id)[0];
output_value->tensor.shape = output_shape;
return absl::OkStatus();
}
private:
};
class MeanOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
@@ -2763,37 +2472,8 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
return std::make_unique<TransposeOperationParser>();
case kTfLiteBuiltinTransposeConv:
return std::make_unique<TransposeConvBuiltinOperationParser>();
case kTfLiteBuiltinCustom:
const absl::string_view custom_name = registration->custom_name;
if (custom_name == "Convolution2DTransposeBias") {
return std::make_unique<TransposeConvCustomOperationParser>();
}
if (custom_name == "MaxPoolingWithArgmax2D") {
return std::make_unique<Pooling2DOperationParser>(PoolingType::MAX);
}
if (custom_name == "MaxUnpooling2D") {
return std::make_unique<Unpooling2DOperationParser>();
}
if (custom_name == "RoIToTransformMatrix") {
return std::make_unique<RoIToTransformMatrixOperationParser>();
}
if (custom_name == "TransformTensor" /*for version 1*/ ||
custom_name == "TransformTensorBilinear" /*for version 2*/) {
return std::make_unique<TransformTensorBilinearOperationParser>();
}
if (custom_name == "TransformLandmarks") {
return std::make_unique<TransformLandmarksOperationParser>();
}
if (custom_name == "Landmarks2TransformMatrix" ||
custom_name == "Landmarks2TransformMatrixV2") {
return std::make_unique<Landmarks2TransformMatrixOperationParser>();
}
if (custom_name == "AlignmentPointsToTransformMatrix") {
return std::make_unique<
AlignmentPointsToTransformMatrixOperationParser>();
}
break;
return NewCustomOperationParser(registration->custom_name);
}
return std::make_unique<UnsupportedOperationParser>();
}

View File

@@ -0,0 +1,55 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_OPERATION_PARSER_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_OPERATION_PARSER_H_
#include "absl/container/flat_hash_map.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/object_reader.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
namespace tflite {
namespace gpu {
// Parses TFLite operation and updates provided GraphFloat32.
class TFLiteOperationParser {
public:
virtual ~TFLiteOperationParser() = default;
// Parses TFLite operation. This method allows expanding fused operations
// into more than one node.
virtual absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) = 0;
// Verifies whether the passed tflite node may be built by the GPU delegate.
virtual absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) = 0;
// Returns the value IDs in the graph that correspond to the updated values of
// the variable input tensor.
virtual absl::flat_hash_map<int, ValueId>
GetNewValueIdsForVariableInputNodes() {
return {};
}
};
} // namespace gpu
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_OPERATION_PARSER_H_
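To illustrate the interface, a minimal sketch of a concrete parser; the class, the "identity" operation string, and the one-node mapping are hypothetical, and real parsers (see model_builder.cc above) do far more validation:

// Hypothetical parser: maps a single-input, single-output custom op onto one
// graph node. Node, GraphFloat32, and ObjectReader come from the includes
// above; RETURN_IF_ERROR is provided by status.h.
class IdentityOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    return absl::OkStatus();  // a real parser validates inputs/outputs here
  }
  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = "identity";  // hypothetical operation name
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    return reader->AddOutputs(node);
  }
};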