diff --git a/tensorflow/lite/delegates/gpu/metal/BUILD b/tensorflow/lite/delegates/gpu/metal/BUILD index 4d116105a7d..8a061452127 100644 --- a/tensorflow/lite/delegates/gpu/metal/BUILD +++ b/tensorflow/lite/delegates/gpu/metal/BUILD @@ -182,6 +182,9 @@ objc_library( "//tensorflow/lite/delegates/gpu/common:util", "//tensorflow/lite/delegates/gpu/common/task:storage_type_util", "//tensorflow/lite/delegates/gpu/common/task:tuning_type", + "//tensorflow/lite/delegates/gpu/common/transformations:add_bias", + "//tensorflow/lite/delegates/gpu/common/transformations:global_pooling_to_reduce_op", + "//tensorflow/lite/delegates/gpu/common/transformations:merge_padding_with", "//tensorflow/lite/delegates/gpu/metal/selectors:operation_selector", "//tensorflow/lite/delegates/gpu/metal/selectors:subgraph", "@com_google_absl//absl/container:flat_hash_map", diff --git a/tensorflow/lite/delegates/gpu/metal/inference_context.cc b/tensorflow/lite/delegates/gpu/metal/inference_context.cc index d7d1953a0e4..0a05184aaa0 100644 --- a/tensorflow/lite/delegates/gpu/metal/inference_context.cc +++ b/tensorflow/lite/delegates/gpu/metal/inference_context.cc @@ -28,6 +28,9 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/task/storage_type_util.h" +#include "tensorflow/lite/delegates/gpu/common/transformations/add_bias.h" +#include "tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.h" +#include "tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.h" #include "tensorflow/lite/delegates/gpu/common/util.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" @@ -113,6 +116,14 @@ absl::Status MergeNodes(MetalNode* src, MetalNode* dst) { } } // namespace +absl::Status InferenceContext::InitFromGraphWithTransforms( + const CreateInferenceInfo& create_info, GraphFloat32* graph, + id device_id) { + RETURN_IF_ERROR(RunGraphTransforms(graph)); + RETURN_IF_ERROR(InitFromGraph(create_info, *graph, device_id)); + return absl::OkStatus(); +} + absl::Status InferenceContext::InitFromGraph( const CreateInferenceInfo& create_info, const GraphFloat32& graph, id device_id) { @@ -543,6 +554,24 @@ void InferenceContext::UpdatePreallocatedTensors( } } +absl::Status RunGraphTransforms(GraphFloat32* graph) { + auto merge_padding_transform = NewMergePaddingWithAdd(); + auto add_bias_transform = NewAddBias(); + auto pooling_to_reduce_op = NewGlobalPoolingToReduceOp(); + ModelTransformer transformer(graph, /*reporter=*/nullptr); + if (!transformer.Apply("add_bias", add_bias_transform.get())) { + return absl::InternalError("Invalid add_bias transform"); + } + if (!transformer.Apply("merge_padding", merge_padding_transform.get())) { + return absl::InternalError("Invalid merge_padding transform"); + } + if (!transformer.Apply("global pooling to mean", + pooling_to_reduce_op.get())) { + return absl::InternalError("Invalid global pooling to mean transform"); + } + return absl::OkStatus(); +} + } // namespace metal } // 
namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/inference_context.h b/tensorflow/lite/delegates/gpu/metal/inference_context.h index 3ad65818855..2fdb54d6fed 100644 --- a/tensorflow/lite/delegates/gpu/metal/inference_context.h +++ b/tensorflow/lite/delegates/gpu/metal/inference_context.h @@ -64,10 +64,19 @@ class InferenceContext { InferenceContext() = default; + // IMPORTANT: If InitFromGraph is used, RunGraphTransforms must be applied + // to this graph beforehand; otherwise correct behavior is not guaranteed. absl::Status InitFromGraph(const CreateInferenceInfo& create_info, const GraphFloat32& graph, id device_id); + // Applies transformations specific to this backend to the graph before + // initialization. These transformations are either impossible or useless in + // other backends. + absl::Status InitFromGraphWithTransforms( + const CreateInferenceInfo& create_info, GraphFloat32* graph, + id device_id); + /// Inserts all GPU compute tasks into the command encoder. /// @param inputOutputBuffers Must be created and passed into the method /// with pairs ID:buffer @@ -195,6 +204,9 @@ class InferenceContext { // from _sharedBuffers }; +// Runs specific transforms for the graph. 
+absl::Status RunGraphTransforms(GraphFloat32* graph); + } // namespace metal } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/test_util.cc b/tensorflow/lite/delegates/gpu/metal/kernels/test_util.cc index 3c27a84c476..d49067d56a0 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/test_util.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/test_util.cc @@ -87,7 +87,8 @@ absl::Status SingleOpModel::Invoke() { create_info.storage_type = TensorStorageType::BUFFER; InferenceContext inference_context; id device = MTLCreateSystemDefaultDevice(); - RETURN_IF_ERROR(inference_context.InitFromGraph(create_info, graph_, device)); + RETURN_IF_ERROR(inference_context.InitFromGraphWithTransforms( + create_info, &graph_, device)); std::map input_dimensions; std::map> input_buffers; diff --git a/tensorflow/lite/delegates/gpu/metal_delegate.mm b/tensorflow/lite/delegates/gpu/metal_delegate.mm index 85a5354d203..6e059719b09 100644 --- a/tensorflow/lite/delegates/gpu/metal_delegate.mm +++ b/tensorflow/lite/delegates/gpu/metal_delegate.mm @@ -433,7 +433,8 @@ class Delegate { InferenceContext::CreateInferenceInfo create_info; create_info.precision = precision; create_info.storage_type = TensorStorageType::BUFFER; - RETURN_IF_ERROR(inference_context_.InitFromGraph(create_info, graph, metal_device_)); + RETURN_IF_ERROR( + inference_context_.InitFromGraphWithTransforms(create_info, &graph, metal_device_)); return absl::OkStatus(); }