Refactor custom op parsers.
PiperOrigin-RevId: 351263261 Change-Id: I1bee6760fa596658e340387209f8530b9f5d7e35
This commit is contained in:
parent
d7f96c409b
commit
511c27580a
@ -150,6 +150,7 @@ cc_library(
|
|||||||
":model_builder_helper",
|
":model_builder_helper",
|
||||||
":model_transformer",
|
":model_transformer",
|
||||||
":object_reader",
|
":object_reader",
|
||||||
|
":operation_parser",
|
||||||
":operations",
|
":operations",
|
||||||
":shape",
|
":shape",
|
||||||
":status",
|
":status",
|
||||||
@ -372,6 +373,19 @@ cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "operation_parser",
|
||||||
|
hdrs = ["operation_parser.h"],
|
||||||
|
deps = [
|
||||||
|
":model",
|
||||||
|
":object_reader",
|
||||||
|
":status",
|
||||||
|
"//tensorflow/lite/c:common",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_test(
|
cc_test(
|
||||||
name = "util_test",
|
name = "util_test",
|
||||||
srcs = ["util_test.cc"],
|
srcs = ["util_test.cc"],
|
||||||
|
@ -15,16 +15,22 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CUSTOM_PARSERS_H_
|
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CUSTOM_PARSERS_H_
|
||||||
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CUSTOM_PARSERS_H_
|
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CUSTOM_PARSERS_H_
|
||||||
|
|
||||||
#include <stdint.h>
|
#include <cstdint>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "absl/types/any.h"
|
#include "absl/types/any.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/operation_parser.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace gpu {
|
namespace gpu {
|
||||||
|
|
||||||
|
// Returns a parser for the provided custom op.
|
||||||
|
std::unique_ptr<TFLiteOperationParser> NewCustomOperationParser(
|
||||||
|
absl::string_view op_name);
|
||||||
|
|
||||||
// Matches the custom operation by the string name and parses attributes stored
|
// Matches the custom operation by the string name and parses attributes stored
|
||||||
// as flexbuffers.
|
// as flexbuffers.
|
||||||
absl::Status ParseCustomAttributes(absl::string_view op_name, int version,
|
absl::Status ParseCustomAttributes(absl::string_view op_name, int version,
|
||||||
|
@ -8,6 +8,7 @@ cc_library(
|
|||||||
srcs = ["custom_parsers.cc"],
|
srcs = ["custom_parsers.cc"],
|
||||||
hdrs = ["//tensorflow/lite/delegates/gpu/common:custom_parsers.h"],
|
hdrs = ["//tensorflow/lite/delegates/gpu/common:custom_parsers.h"],
|
||||||
deps = [
|
deps = [
|
||||||
|
"//tensorflow/lite/delegates/gpu/common:operation_parser",
|
||||||
"//tensorflow/lite/delegates/gpu/common:shape",
|
"//tensorflow/lite/delegates/gpu/common:shape",
|
||||||
"//tensorflow/lite/delegates/gpu/common:status",
|
"//tensorflow/lite/delegates/gpu/common:status",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
|
@ -15,18 +15,49 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/lite/delegates/gpu/common/custom_parsers.h"
|
#include "tensorflow/lite/delegates/gpu/common/custom_parsers.h"
|
||||||
|
|
||||||
#include <stdint.h>
|
#include <cstdint>
|
||||||
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#include "absl/memory/memory.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "absl/types/any.h"
|
#include "absl/types/any.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/operation_parser.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace gpu {
|
namespace gpu {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class UnimplementedCustomOperationParser : public TFLiteOperationParser {
|
||||||
|
public:
|
||||||
|
explicit UnimplementedCustomOperationParser(absl::string_view op_name)
|
||||||
|
: op_name_(op_name) {}
|
||||||
|
|
||||||
|
absl::Status IsSupported(const TfLiteContext* context,
|
||||||
|
const TfLiteNode* tflite_node,
|
||||||
|
const TfLiteRegistration* registration) final {
|
||||||
|
return absl::UnimplementedError(op_name_);
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status Parse(const TfLiteNode* tflite_node,
|
||||||
|
const TfLiteRegistration* registration,
|
||||||
|
GraphFloat32* graph, ObjectReader* reader) final {
|
||||||
|
return absl::UnimplementedError(op_name_);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::string op_name_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<TFLiteOperationParser> NewCustomOperationParser(
|
||||||
|
absl::string_view op_name) {
|
||||||
|
return absl::make_unique<UnimplementedCustomOperationParser>(op_name);
|
||||||
|
}
|
||||||
|
|
||||||
absl::Status ParseCustomAttributes(absl::string_view op_name, int version,
|
absl::Status ParseCustomAttributes(absl::string_view op_name, int version,
|
||||||
const void* data, uint32_t data_size,
|
const void* data, uint32_t data_size,
|
||||||
|
@ -43,6 +43,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
|
#include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
|
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/object_reader.h"
|
#include "tensorflow/lite/delegates/gpu/common/object_reader.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/operation_parser.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||||
@ -69,30 +70,6 @@ absl::Status CheckTensorIsAvailable(const TfLiteContext* context,
|
|||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
// A parser responsible for parsing TFLite operation and adding it to a graph.
|
|
||||||
class TFLiteOperationParser {
|
|
||||||
public:
|
|
||||||
virtual ~TFLiteOperationParser() = default;
|
|
||||||
|
|
||||||
// Parses TFLite operation. This method allows expanding fused operations
|
|
||||||
// into more than one node.
|
|
||||||
virtual absl::Status Parse(const TfLiteNode* tflite_node,
|
|
||||||
const TfLiteRegistration* registration,
|
|
||||||
GraphFloat32* graph, ObjectReader* reader) = 0;
|
|
||||||
|
|
||||||
// Verifies whether passed tflite node may be built by GPU delegate or not.
|
|
||||||
virtual absl::Status IsSupported(const TfLiteContext* context,
|
|
||||||
const TfLiteNode* tflite_node,
|
|
||||||
const TfLiteRegistration* registration) = 0;
|
|
||||||
|
|
||||||
// Return the value ids in the graph that correspond to the updated values of
|
|
||||||
// the variable input tensor.
|
|
||||||
virtual absl::flat_hash_map<int, ValueId>
|
|
||||||
GetNewValueIdsForVariableInputNodes() {
|
|
||||||
return absl::flat_hash_map<int, ValueId>();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
HW ToHW(int32_t h, int32_t w) { return HW(h > 0 ? h : 1, w > 0 ? w : 1); }
|
HW ToHW(int32_t h, int32_t w) { return HW(h > 0 ? h : 1, w > 0 ? w : 1); }
|
||||||
|
|
||||||
template <typename AttrT>
|
template <typename AttrT>
|
||||||
@ -1438,17 +1415,10 @@ class Pooling2DOperationParser : public TFLiteOperationParser {
|
|||||||
const TfLiteRegistration* registration) final {
|
const TfLiteRegistration* registration) final {
|
||||||
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
|
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
|
||||||
const TfLitePoolParams* tf_options;
|
const TfLitePoolParams* tf_options;
|
||||||
auto status = RetrieveCustomInitialData(tflite_node, &tf_options);
|
|
||||||
if (status.ok()) { // custom case with indices as a second output
|
|
||||||
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
|
|
||||||
/*runtime_inputs=*/1,
|
|
||||||
/*outputs=*/2));
|
|
||||||
} else { // common pooling with 1 output
|
|
||||||
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
||||||
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
|
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
|
||||||
/*runtime_inputs=*/1,
|
/*runtime_inputs=*/1,
|
||||||
/*outputs=*/1));
|
/*outputs=*/1));
|
||||||
}
|
|
||||||
RETURN_IF_ERROR(CheckKernelsAndStrides(
|
RETURN_IF_ERROR(CheckKernelsAndStrides(
|
||||||
tf_options->filter_height, tf_options->filter_width,
|
tf_options->filter_height, tf_options->filter_width,
|
||||||
tf_options->stride_height, tf_options->stride_width));
|
tf_options->stride_height, tf_options->stride_width));
|
||||||
@ -1471,28 +1441,12 @@ class Pooling2DOperationParser : public TFLiteOperationParser {
|
|||||||
|
|
||||||
auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
|
auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
|
||||||
|
|
||||||
// check whether there are custom options encoded. It happens if operation
|
|
||||||
// is MaxPoolingWithArgmax2D. There is no way to read
|
|
||||||
// tflite_node->builtin_code, so, simply check whether custom data is
|
|
||||||
// available.
|
|
||||||
const TfLitePoolParams* tf_options;
|
const TfLitePoolParams* tf_options;
|
||||||
if (!RetrieveCustomInitialData(tflite_node, &tf_options).ok()) {
|
|
||||||
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
|
RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
|
||||||
// Second output is optional. It is not required, it but must be added after
|
|
||||||
// MaybeAddFusedActivation function is called
|
|
||||||
reader->AddOutput(node, 1).IgnoreError();
|
|
||||||
|
|
||||||
// First output is the result of pooling operation, while second output is
|
attr.output_indices = false;
|
||||||
// indices used for pooling.
|
|
||||||
auto outputs = graph->FindOutputs(node->id);
|
|
||||||
attr.output_indices = outputs.size() == 2;
|
|
||||||
if (attr.output_indices) {
|
|
||||||
// Fix data type for output indices. In the model it is set as float32.
|
|
||||||
outputs[1]->tensor.type = DataType::INT32;
|
|
||||||
}
|
|
||||||
RETURN_IF_ERROR(ParsePoolingAttributes(tf_options, input_shape, &attr));
|
RETURN_IF_ERROR(ParsePoolingAttributes(tf_options, input_shape, &attr));
|
||||||
node->operation.attributes = attr;
|
node->operation.attributes = attr;
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
@ -2205,45 +2159,6 @@ class TransposeConvBuiltinOperationParser : public TFLiteOperationParser {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Custom op version of TRANSPOSE_CONV.
|
|
||||||
class TransposeConvCustomOperationParser : public TFLiteOperationParser {
|
|
||||||
public:
|
|
||||||
absl::Status IsSupported(const TfLiteContext* context,
|
|
||||||
const TfLiteNode* tflite_node,
|
|
||||||
const TfLiteRegistration* registration) final {
|
|
||||||
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
|
|
||||||
const TfLiteTransposeConvParams* tf_options;
|
|
||||||
RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options));
|
|
||||||
RETURN_IF_ERROR(
|
|
||||||
CheckStrides(tf_options->stride_height, tf_options->stride_width));
|
|
||||||
return absl::OkStatus();
|
|
||||||
}
|
|
||||||
|
|
||||||
absl::Status Parse(const TfLiteNode* tflite_node,
|
|
||||||
const TfLiteRegistration* registration,
|
|
||||||
GraphFloat32* graph, ObjectReader* reader) final {
|
|
||||||
auto* node = graph->NewNode();
|
|
||||||
node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED);
|
|
||||||
RETURN_IF_ERROR(reader->AddInput(node, 0));
|
|
||||||
RETURN_IF_ERROR(reader->AddOutputs(node));
|
|
||||||
|
|
||||||
const TfLiteTransposeConvParams* tf_options;
|
|
||||||
auto status = RetrieveCustomInitialData(tflite_node, &tf_options);
|
|
||||||
|
|
||||||
ConvolutionTransposedAttributes attr;
|
|
||||||
attr.stride = status.ok()
|
|
||||||
? HW(tf_options->stride_height, tf_options->stride_width)
|
|
||||||
: HW(1, 1);
|
|
||||||
RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
|
|
||||||
reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional
|
|
||||||
|
|
||||||
UpdatePadding(status.ok() ? tf_options->padding : kTfLitePaddingUnknown,
|
|
||||||
graph->FindInputs(node->id)[0]->tensor.shape, &attr);
|
|
||||||
node->operation.attributes = std::move(attr);
|
|
||||||
return absl::OkStatus();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
class TransposeOperationParser : public TFLiteOperationParser {
|
class TransposeOperationParser : public TFLiteOperationParser {
|
||||||
public:
|
public:
|
||||||
absl::Status IsSupported(const TfLiteContext* context,
|
absl::Status IsSupported(const TfLiteContext* context,
|
||||||
@ -2295,47 +2210,6 @@ class TransposeOperationParser : public TFLiteOperationParser {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class Unpooling2DOperationParser : public TFLiteOperationParser {
|
|
||||||
public:
|
|
||||||
absl::Status IsSupported(const TfLiteContext* context,
|
|
||||||
const TfLiteNode* tflite_node,
|
|
||||||
const TfLiteRegistration* registration) final {
|
|
||||||
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
|
|
||||||
/*runtime_inputs=*/2, /*outputs=*/1));
|
|
||||||
const TfLitePoolParams* tf_options;
|
|
||||||
RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options));
|
|
||||||
RETURN_IF_ERROR(CheckKernelsAndStrides(
|
|
||||||
tf_options->filter_height, tf_options->filter_width,
|
|
||||||
tf_options->stride_height, tf_options->stride_width));
|
|
||||||
return absl::OkStatus();
|
|
||||||
}
|
|
||||||
|
|
||||||
absl::Status Parse(const TfLiteNode* tflite_node,
|
|
||||||
const TfLiteRegistration* registration,
|
|
||||||
GraphFloat32* graph, ObjectReader* reader) final {
|
|
||||||
Node* node = graph->NewNode();
|
|
||||||
node->operation.type = ToString(OperationType::MAX_UNPOOLING_2D);
|
|
||||||
RETURN_IF_ERROR(reader->AddInput(node, 0));
|
|
||||||
RETURN_IF_ERROR(reader->AddInput(node, 1));
|
|
||||||
RETURN_IF_ERROR(reader->AddOutputs(node));
|
|
||||||
auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
|
|
||||||
MaxUnpooling2DAttributes attr;
|
|
||||||
|
|
||||||
const TfLitePoolParams* tf_options;
|
|
||||||
RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options));
|
|
||||||
|
|
||||||
attr.kernel = ToHW(tf_options->filter_height, tf_options->filter_width);
|
|
||||||
attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width);
|
|
||||||
UpdatePadding(tf_options->padding, input_shape, &attr);
|
|
||||||
|
|
||||||
node->operation.attributes = attr;
|
|
||||||
|
|
||||||
auto output_value = graph->FindOutputs(node->id)[0];
|
|
||||||
output_value->tensor.shape = CalculateOutputShape(input_shape, attr);
|
|
||||||
return absl::OkStatus();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// TODO(impjdi): BATCH_TO_SPACE/SPACE_TO_BATCH shouldn't be supported.
|
// TODO(impjdi): BATCH_TO_SPACE/SPACE_TO_BATCH shouldn't be supported.
|
||||||
class BatchToSpaceOperationParser : public TFLiteOperationParser {
|
class BatchToSpaceOperationParser : public TFLiteOperationParser {
|
||||||
public:
|
public:
|
||||||
@ -2423,171 +2297,6 @@ class SpaceToBatchOperationParser : public TFLiteOperationParser {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class RoIToTransformMatrixOperationParser : public TFLiteOperationParser {
|
|
||||||
public:
|
|
||||||
absl::Status IsSupported(const TfLiteContext* context,
|
|
||||||
const TfLiteNode* tflite_node,
|
|
||||||
const TfLiteRegistration* registration) final {
|
|
||||||
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
|
|
||||||
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
|
|
||||||
/*runtime_inputs=*/1, /*outputs=*/1));
|
|
||||||
return absl::OkStatus();
|
|
||||||
}
|
|
||||||
|
|
||||||
absl::Status Parse(const TfLiteNode* tflite_node,
|
|
||||||
const TfLiteRegistration* registration,
|
|
||||||
GraphFloat32* graph, ObjectReader* reader) final {
|
|
||||||
Node* node = graph->NewNode();
|
|
||||||
RETURN_IF_ERROR(reader->AddInput(node, 0)); // bbox
|
|
||||||
RETURN_IF_ERROR(reader->AddOutputs(node));
|
|
||||||
|
|
||||||
std::string op_name = "roi_to_transform_matrix";
|
|
||||||
node->operation.type = op_name;
|
|
||||||
BHWC output_shape;
|
|
||||||
RETURN_IF_ERROR(ParseCustomAttributes(
|
|
||||||
op_name, registration->version, tflite_node->custom_initial_data,
|
|
||||||
tflite_node->custom_initial_data_size, &(node->operation.attributes),
|
|
||||||
&output_shape));
|
|
||||||
|
|
||||||
auto output_value = graph->FindOutputs(node->id)[0];
|
|
||||||
output_value->tensor.shape = output_shape;
|
|
||||||
return absl::OkStatus();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
class TransformTensorBilinearOperationParser : public TFLiteOperationParser {
|
|
||||||
public:
|
|
||||||
absl::Status IsSupported(const TfLiteContext* context,
|
|
||||||
const TfLiteNode* tflite_node,
|
|
||||||
const TfLiteRegistration* registration) final {
|
|
||||||
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
|
|
||||||
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
|
|
||||||
/*runtime_inputs=*/2, /*outputs=*/1));
|
|
||||||
return absl::OkStatus();
|
|
||||||
}
|
|
||||||
|
|
||||||
absl::Status Parse(const TfLiteNode* tflite_node,
|
|
||||||
const TfLiteRegistration* registration,
|
|
||||||
GraphFloat32* graph, ObjectReader* reader) final {
|
|
||||||
Node* node = graph->NewNode();
|
|
||||||
RETURN_IF_ERROR(reader->AddInput(node, 0)); // data
|
|
||||||
RETURN_IF_ERROR(reader->AddInput(node, 1)); // bbox
|
|
||||||
RETURN_IF_ERROR(reader->AddOutputs(node));
|
|
||||||
|
|
||||||
std::string op_name = "transform_tensor_bilinear";
|
|
||||||
node->operation.type = op_name;
|
|
||||||
BHWC output_shape;
|
|
||||||
RETURN_IF_ERROR(ParseCustomAttributes(
|
|
||||||
op_name, registration->version, tflite_node->custom_initial_data,
|
|
||||||
tflite_node->custom_initial_data_size, &(node->operation.attributes),
|
|
||||||
&output_shape));
|
|
||||||
|
|
||||||
auto output_value = graph->FindOutputs(node->id)[0];
|
|
||||||
|
|
||||||
output_value->tensor.shape =
|
|
||||||
BHWC(1, output_shape.h, output_shape.w,
|
|
||||||
graph->FindInputs(node->id)[0]->tensor.shape.c);
|
|
||||||
return absl::OkStatus();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
class TransformLandmarksOperationParser : public TFLiteOperationParser {
|
|
||||||
public:
|
|
||||||
absl::Status IsSupported(const TfLiteContext* context,
|
|
||||||
const TfLiteNode* tflite_node,
|
|
||||||
const TfLiteRegistration* registration) final {
|
|
||||||
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
|
|
||||||
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
|
|
||||||
/*runtime_inputs=*/2, /*outputs=*/1));
|
|
||||||
return absl::OkStatus();
|
|
||||||
}
|
|
||||||
|
|
||||||
absl::Status Parse(const TfLiteNode* tflite_node,
|
|
||||||
const TfLiteRegistration* registration,
|
|
||||||
GraphFloat32* graph, ObjectReader* reader) final {
|
|
||||||
Node* node = graph->NewNode();
|
|
||||||
RETURN_IF_ERROR(reader->AddInput(node, 0)); // data
|
|
||||||
RETURN_IF_ERROR(reader->AddInput(node, 1)); // bbox
|
|
||||||
RETURN_IF_ERROR(reader->AddOutputs(node));
|
|
||||||
std::string op_name = "transform_landmarks";
|
|
||||||
node->operation.type = op_name;
|
|
||||||
BHWC output_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
|
|
||||||
RETURN_IF_ERROR(ParseCustomAttributes(
|
|
||||||
op_name, registration->version, tflite_node->custom_initial_data,
|
|
||||||
tflite_node->custom_initial_data_size, &(node->operation.attributes),
|
|
||||||
&output_shape));
|
|
||||||
|
|
||||||
auto output_value = graph->FindOutputs(node->id)[0];
|
|
||||||
|
|
||||||
output_value->tensor.shape = graph->FindInputs(node->id)[0]->tensor.shape;
|
|
||||||
return absl::OkStatus();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
class Landmarks2TransformMatrixOperationParser : public TFLiteOperationParser {
|
|
||||||
public:
|
|
||||||
absl::Status IsSupported(const TfLiteContext* context,
|
|
||||||
const TfLiteNode* tflite_node,
|
|
||||||
const TfLiteRegistration* registration) final {
|
|
||||||
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
|
|
||||||
return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1,
|
|
||||||
/*outputs=*/1);
|
|
||||||
}
|
|
||||||
|
|
||||||
absl::Status Parse(const TfLiteNode* tflite_node,
|
|
||||||
const TfLiteRegistration* registration,
|
|
||||||
GraphFloat32* graph, ObjectReader* reader) final {
|
|
||||||
Node* node = graph->NewNode();
|
|
||||||
RETURN_IF_ERROR(reader->AddInput(node, 0)); // landmarks
|
|
||||||
RETURN_IF_ERROR(reader->AddOutputs(node)); // transform matrix
|
|
||||||
|
|
||||||
const std::string op_name = "landmarks_to_transform_matrix";
|
|
||||||
node->operation.type = op_name;
|
|
||||||
BHWC output_shape;
|
|
||||||
RETURN_IF_ERROR(ParseCustomAttributes(
|
|
||||||
op_name, registration->version, tflite_node->custom_initial_data,
|
|
||||||
tflite_node->custom_initial_data_size, &(node->operation.attributes),
|
|
||||||
&output_shape));
|
|
||||||
|
|
||||||
auto output_value = graph->FindOutputs(node->id)[0];
|
|
||||||
output_value->tensor.shape = output_shape;
|
|
||||||
return absl::OkStatus();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
class AlignmentPointsToTransformMatrixOperationParser
|
|
||||||
: public TFLiteOperationParser {
|
|
||||||
public:
|
|
||||||
absl::Status IsSupported(const TfLiteContext* context,
|
|
||||||
const TfLiteNode* tflite_node,
|
|
||||||
const TfLiteRegistration* registration) final {
|
|
||||||
return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1,
|
|
||||||
/*outputs=*/1);
|
|
||||||
}
|
|
||||||
|
|
||||||
absl::Status Parse(const TfLiteNode* tflite_node,
|
|
||||||
const TfLiteRegistration* registration,
|
|
||||||
GraphFloat32* graph, ObjectReader* reader) final {
|
|
||||||
Node* node = graph->NewNode();
|
|
||||||
RETURN_IF_ERROR(reader->AddInput(node, 0)); // alignment points
|
|
||||||
RETURN_IF_ERROR(reader->AddOutputs(node)); // transform matrix
|
|
||||||
|
|
||||||
const std::string op_name = "alignment_points_to_transform_matrix";
|
|
||||||
node->operation.type = op_name;
|
|
||||||
BHWC output_shape;
|
|
||||||
RETURN_IF_ERROR(ParseCustomAttributes(
|
|
||||||
op_name, registration->version, tflite_node->custom_initial_data,
|
|
||||||
tflite_node->custom_initial_data_size, &(node->operation.attributes),
|
|
||||||
&output_shape));
|
|
||||||
|
|
||||||
auto output_value = graph->FindOutputs(node->id)[0];
|
|
||||||
output_value->tensor.shape = output_shape;
|
|
||||||
return absl::OkStatus();
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
};
|
|
||||||
|
|
||||||
class MeanOperationParser : public TFLiteOperationParser {
|
class MeanOperationParser : public TFLiteOperationParser {
|
||||||
public:
|
public:
|
||||||
absl::Status IsSupported(const TfLiteContext* context,
|
absl::Status IsSupported(const TfLiteContext* context,
|
||||||
@ -2763,37 +2472,8 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
|
|||||||
return std::make_unique<TransposeOperationParser>();
|
return std::make_unique<TransposeOperationParser>();
|
||||||
case kTfLiteBuiltinTransposeConv:
|
case kTfLiteBuiltinTransposeConv:
|
||||||
return std::make_unique<TransposeConvBuiltinOperationParser>();
|
return std::make_unique<TransposeConvBuiltinOperationParser>();
|
||||||
|
|
||||||
case kTfLiteBuiltinCustom:
|
case kTfLiteBuiltinCustom:
|
||||||
const absl::string_view custom_name = registration->custom_name;
|
return NewCustomOperationParser(registration->custom_name);
|
||||||
if (custom_name == "Convolution2DTransposeBias") {
|
|
||||||
return std::make_unique<TransposeConvCustomOperationParser>();
|
|
||||||
}
|
|
||||||
if (custom_name == "MaxPoolingWithArgmax2D") {
|
|
||||||
return std::make_unique<Pooling2DOperationParser>(PoolingType::MAX);
|
|
||||||
}
|
|
||||||
if (custom_name == "MaxUnpooling2D") {
|
|
||||||
return std::make_unique<Unpooling2DOperationParser>();
|
|
||||||
}
|
|
||||||
if (custom_name == "RoIToTransformMatrix") {
|
|
||||||
return std::make_unique<RoIToTransformMatrixOperationParser>();
|
|
||||||
}
|
|
||||||
if (custom_name == "TransformTensor" /*for version 1*/ ||
|
|
||||||
custom_name == "TransformTensorBilinear" /*for version 2*/) {
|
|
||||||
return std::make_unique<TransformTensorBilinearOperationParser>();
|
|
||||||
}
|
|
||||||
if (custom_name == "TransformLandmarks") {
|
|
||||||
return std::make_unique<TransformLandmarksOperationParser>();
|
|
||||||
}
|
|
||||||
if (custom_name == "Landmarks2TransformMatrix" ||
|
|
||||||
custom_name == "Landmarks2TransformMatrixV2") {
|
|
||||||
return std::make_unique<Landmarks2TransformMatrixOperationParser>();
|
|
||||||
}
|
|
||||||
if (custom_name == "AlignmentPointsToTransformMatrix") {
|
|
||||||
return std::make_unique<
|
|
||||||
AlignmentPointsToTransformMatrixOperationParser>();
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
return std::make_unique<UnsupportedOperationParser>();
|
return std::make_unique<UnsupportedOperationParser>();
|
||||||
}
|
}
|
||||||
|
55
tensorflow/lite/delegates/gpu/common/operation_parser.h
Normal file
55
tensorflow/lite/delegates/gpu/common/operation_parser.h
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_OPERATION_PARSER_H_
|
||||||
|
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_OPERATION_PARSER_H_
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
|
#include "tensorflow/lite/c/common.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/object_reader.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace gpu {
|
||||||
|
|
||||||
|
// Parses TFLite operation and updates provided GraphFloat32.
|
||||||
|
class TFLiteOperationParser {
|
||||||
|
public:
|
||||||
|
virtual ~TFLiteOperationParser() = default;
|
||||||
|
|
||||||
|
// Parses TFLite operation. This method allows expanding fused operations
|
||||||
|
// into more than one node.
|
||||||
|
virtual absl::Status Parse(const TfLiteNode* tflite_node,
|
||||||
|
const TfLiteRegistration* registration,
|
||||||
|
GraphFloat32* graph, ObjectReader* reader) = 0;
|
||||||
|
|
||||||
|
// Verifies whether passed tflite node may be built by GPU delegate or not.
|
||||||
|
virtual absl::Status IsSupported(const TfLiteContext* context,
|
||||||
|
const TfLiteNode* tflite_node,
|
||||||
|
const TfLiteRegistration* registration) = 0;
|
||||||
|
|
||||||
|
// Returns the value IDs in the graph that correspond to the updated values of
|
||||||
|
// the variable input tensor.
|
||||||
|
virtual absl::flat_hash_map<int, ValueId>
|
||||||
|
GetNewValueIdsForVariableInputNodes() {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_OPERATION_PARSER_H_
|
Loading…
x
Reference in New Issue
Block a user