From 696f2a8bd7dc0b80d94382c4cb42aa4ccdc2527c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 8 May 2020 12:46:14 -0700 Subject: [PATCH] Branch LandmarksToTransformMatrix operation to LandmarksToTransformMatrixV2. PiperOrigin-RevId: 310610035 Change-Id: I5a284b3b539ddff85ba15c6c08967cab600d3dc8 --- .../delegates/gpu/common/model_builder.cc | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc index c536e09d6b5..46856a70a7c 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc @@ -2485,6 +2485,37 @@ class Landmarks2TransformMatrixOperationParser : public TFLiteOperationParser { } }; +class Landmarks2TransformMatrixV2OperationParser + : public TFLiteOperationParser { + public: + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, + /*outputs=*/1); + } + + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { + Node* node = graph->NewNode(); + RETURN_IF_ERROR(reader->AddInput(node, 0)); // landmarks + RETURN_IF_ERROR(reader->AddOutputs(node)); // transform matrix + + const std::string op_name = "landmarks_to_transform_matrix_v2"; + node->operation.type = op_name; + BHWC output_shape; + RETURN_IF_ERROR( + ParseCustomAttributes(op_name, tflite_node->custom_initial_data, + tflite_node->custom_initial_data_size, + &(node->operation.attributes), &output_shape)); + + auto output_value = graph->FindOutputs(node->id)[0]; + output_value->tensor.shape = output_shape; + return absl::OkStatus(); + } +}; + class AlignmentPointsToTransformMatrixOperationParser : public TFLiteOperationParser { public: @@ -2712,6 +2743,9 @@ std::unique_ptr NewOperationParser( if (custom_name == "Landmarks2TransformMatrix") { return std::make_unique(); } + if (custom_name == "Landmarks2TransformMatrixV2") { + return std::make_unique(); + } if (custom_name == "AlignmentPointsToTransformMatrix") { return std::make_unique< AlignmentPointsToTransformMatrixOperationParser>();