Branch TransformLandmarks operation to TransformLandmarksV2.

PiperOrigin-RevId: 310424827
Change-Id: I4e8771dfb8060438ed23ba97d7177633ad6aba76
This commit is contained in:
A. Unique TensorFlower 2020-05-07 13:23:53 -07:00 committed by TensorFlower Gardener
parent 3d0dac26f5
commit 2acaff3d89

View File

@ -2421,6 +2421,40 @@ class TransformLandmarksOperationParser : public TFLiteOperationParser {
private:
};
class TransformLandmarksV2OperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/2, /*outputs=*/1));
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
RETURN_IF_ERROR(reader->AddInput(node, 0)); // data
RETURN_IF_ERROR(reader->AddInput(node, 1)); // bbox
RETURN_IF_ERROR(reader->AddOutputs(node));
std::string op_name = "transform_landmarks_v2";
node->operation.type = op_name;
BHWC output_shape;
RETURN_IF_ERROR(
ParseCustomAttributes(op_name, tflite_node->custom_initial_data,
tflite_node->custom_initial_data_size,
&(node->operation.attributes), &output_shape));
auto output_value = graph->FindOutputs(node->id)[0];
output_value->tensor.shape = graph->FindInputs(node->id)[0]->tensor.shape;
return absl::OkStatus();
}
private:
};
class Landmarks2TransformMatrixOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
@ -2672,6 +2706,9 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
if (custom_name == "TransformLandmarks") {
return std::make_unique<TransformLandmarksOperationParser>();
}
if (custom_name == "TransformLandmarksV2") {
return std::make_unique<TransformLandmarksV2OperationParser>();
}
if (custom_name == "Landmarks2TransformMatrix") {
return std::make_unique<Landmarks2TransformMatrixOperationParser>();
}