Implement V2->v1 conversion for TransformLandmarks operation.
PiperOrigin-RevId: 329787506 Change-Id: I1a5737171b81a3fbbfacb239edb47e425a1e737c
This commit is contained in:
parent
b129d3dc1c
commit
98635e2d5b
@ -2265,6 +2265,7 @@ class TransformLandmarksOperationParser : public TFLiteOperationParser {
|
||||
absl::Status IsSupported(const TfLiteContext* context,
|
||||
const TfLiteNode* tflite_node,
|
||||
const TfLiteRegistration* registration) final {
|
||||
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
|
||||
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
|
||||
/*runtime_inputs=*/2, /*outputs=*/1));
|
||||
return absl::OkStatus();
|
||||
@ -2290,42 +2291,6 @@ class TransformLandmarksOperationParser : public TFLiteOperationParser {
|
||||
output_value->tensor.shape = graph->FindInputs(node->id)[0]->tensor.shape;
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
};
|
||||
|
||||
class TransformLandmarksV2OperationParser : public TFLiteOperationParser {
|
||||
public:
|
||||
absl::Status IsSupported(const TfLiteContext* context,
|
||||
const TfLiteNode* tflite_node,
|
||||
const TfLiteRegistration* registration) final {
|
||||
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
|
||||
/*runtime_inputs=*/2, /*outputs=*/1));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Parse(const TfLiteNode* tflite_node,
|
||||
const TfLiteRegistration* registration,
|
||||
GraphFloat32* graph, ObjectReader* reader) final {
|
||||
Node* node = graph->NewNode();
|
||||
RETURN_IF_ERROR(reader->AddInput(node, 0)); // data
|
||||
RETURN_IF_ERROR(reader->AddInput(node, 1)); // bbox
|
||||
RETURN_IF_ERROR(reader->AddOutputs(node));
|
||||
std::string op_name = "transform_landmarks_v2";
|
||||
node->operation.type = op_name;
|
||||
|
||||
auto output_value = graph->FindOutputs(node->id)[0];
|
||||
output_value->tensor.shape = graph->FindInputs(node->id)[0]->tensor.shape;
|
||||
BHWC output_shape = output_value->tensor.shape;
|
||||
RETURN_IF_ERROR(ParseCustomAttributes(
|
||||
op_name, registration->version, tflite_node->custom_initial_data,
|
||||
tflite_node->custom_initial_data_size, &(node->operation.attributes),
|
||||
&output_shape));
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
};
|
||||
|
||||
class Landmarks2TransformMatrixOperationParser : public TFLiteOperationParser {
|
||||
@ -2596,9 +2561,6 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
|
||||
if (custom_name == "TransformLandmarks") {
|
||||
return std::make_unique<TransformLandmarksOperationParser>();
|
||||
}
|
||||
if (custom_name == "TransformLandmarksV2") {
|
||||
return std::make_unique<TransformLandmarksV2OperationParser>();
|
||||
}
|
||||
if (custom_name == "Landmarks2TransformMatrix" ||
|
||||
custom_name == "Landmarks2TransformMatrixV2") {
|
||||
return std::make_unique<Landmarks2TransformMatrixOperationParser>();
|
||||
|
Loading…
Reference in New Issue
Block a user