Added a new optimizing transformation to the Metal backend.

PiperOrigin-RevId: 352866499
Change-Id: I0681d78f0826b1fbe8934aafb8135d3d8cfb89fd
Raman Sarokin authored on 2021-01-20 13:44:13 -08:00; committed by TensorFlower Gardener
parent 39916c9b96
commit 9a426abe81
5 changed files with 48 additions and 2 deletions
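
Editor's note: the commit wires three existing graph transformations from tensorflow/lite/delegates/gpu/common/transformations into the Metal backend's initialization path. The sketch below is not part of the commit; it shows how a client drives the new entry point, using names from the diffs that follow. BuildGraph is a hypothetical stand-in for the model-parsing step.

    // Sketch: driving the new InitFromGraphWithTransforms entry point.
    GraphFloat32 graph;
    RETURN_IF_ERROR(BuildGraph(&graph));  // hypothetical model -> graph step

    InferenceContext::CreateInferenceInfo create_info;
    create_info.precision = CalculationsPrecision::F32;
    create_info.storage_type = TensorStorageType::BUFFER;

    InferenceContext inference_context;
    id<MTLDevice> device = MTLCreateSystemDefaultDevice();
    // Runs the Metal-specific graph transforms, then initializes the
    // context from the transformed graph.
    RETURN_IF_ERROR(inference_context.InitFromGraphWithTransforms(
        create_info, &graph, device));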


@@ -182,6 +182,9 @@ objc_library(
         "//tensorflow/lite/delegates/gpu/common:util",
         "//tensorflow/lite/delegates/gpu/common/task:storage_type_util",
         "//tensorflow/lite/delegates/gpu/common/task:tuning_type",
+        "//tensorflow/lite/delegates/gpu/common/transformations:add_bias",
+        "//tensorflow/lite/delegates/gpu/common/transformations:global_pooling_to_reduce_op",
+        "//tensorflow/lite/delegates/gpu/common/transformations:merge_padding_with",
         "//tensorflow/lite/delegates/gpu/metal/selectors:operation_selector",
         "//tensorflow/lite/delegates/gpu/metal/selectors:subgraph",
         "@com_google_absl//absl/container:flat_hash_map",


@@ -28,6 +28,9 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/common/task/storage_type_util.h"
+#include "tensorflow/lite/delegates/gpu/common/transformations/add_bias.h"
+#include "tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.h"
+#include "tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.h"
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
@@ -113,6 +116,14 @@ absl::Status MergeNodes(MetalNode* src, MetalNode* dst) {
   }
 }  // namespace
 
+absl::Status InferenceContext::InitFromGraphWithTransforms(
+    const CreateInferenceInfo& create_info, GraphFloat32* graph,
+    id<MTLDevice> device_id) {
+  RETURN_IF_ERROR(RunGraphTransforms(graph));
+  RETURN_IF_ERROR(InitFromGraph(create_info, *graph, device_id));
+  return absl::OkStatus();
+}
+
 absl::Status InferenceContext::InitFromGraph(
     const CreateInferenceInfo& create_info, const GraphFloat32& graph,
     id<MTLDevice> device_id) {
@@ -543,6 +554,24 @@ void InferenceContext::UpdatePreallocatedTensors(
   }
 }
 
+absl::Status RunGraphTransforms(GraphFloat32* graph) {
+  auto merge_padding_transform = NewMergePaddingWithAdd();
+  auto add_bias_transform = NewAddBias();
+  auto pooling_to_reduce_op = NewGlobalPoolingToReduceOp();
+  ModelTransformer transformer(graph, /*reporter=*/nullptr);
+  if (!transformer.Apply("add_bias", add_bias_transform.get())) {
+    return absl::InternalError("Invalid add_bias transform");
+  }
+  if (!transformer.Apply("merge_padding", merge_padding_transform.get())) {
+    return absl::InternalError("Invalid merge_padding transform");
+  }
+  if (!transformer.Apply("global pooling to mean",
+                         pooling_to_reduce_op.get())) {
+    return absl::InternalError("Invalid global pooling to mean transform");
+  }
+  return absl::OkStatus();
+}
+
 }  // namespace metal
 }  // namespace gpu
 }  // namespace tflite
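
Editor's note: all three transforms reuse the shared ModelTransformer machinery from tensorflow/lite/delegates/gpu/common. Broadly, add_bias gives ops such as convolutions an explicit zero bias when one is missing (so later fusions have something to fold into), merge_padding folds standalone PAD nodes into the ops that consume them, and global_pooling_to_reduce_op rewrites global pooling as a MEAN reduce (hence the "global pooling to mean" name above). As a sketch of the underlying interface, assuming the NodeTransformation API in common/model_transformer.h at this revision, a hypothetical transform would look like:

    // Hypothetical no-op transform, for illustration only.
    class SkipEverything : public NodeTransformation {
     public:
      TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final {
        // SKIPPED tells ModelTransformer to leave the node untouched.
        return {TransformStatus::SKIPPED, ""};
      }
    };

    // Wiring it up, mirroring RunGraphTransforms above:
    //   SkipEverything skip;
    //   ModelTransformer transformer(graph, /*reporter=*/nullptr);
    //   if (!transformer.Apply("skip_everything", &skip)) { ... }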


@@ -64,10 +64,19 @@ class InferenceContext {
   InferenceContext() = default;
 
+  // IMPORTANT: If InitFromGraph is used, RunGraphTransforms must be applied
+  // to the graph upfront; otherwise correct behavior is not guaranteed.
   absl::Status InitFromGraph(const CreateInferenceInfo& create_info,
                              const GraphFloat32& graph,
                              id<MTLDevice> device_id);
 
+  // Applies specific transformations to the graph before initialization.
+  // These transformations are either impossible or useless in other
+  // backends.
+  absl::Status InitFromGraphWithTransforms(
+      const CreateInferenceInfo& create_info, GraphFloat32* graph,
+      id<MTLDevice> device_id);
+
   /// Inserts all GPU compute tasks into the command encoder.
   /// @param inputOutputBuffers Must be created and passed into the method
   /// with pairs ID:buffer
@@ -195,6 +204,9 @@ class InferenceContext {
   // from _sharedBuffers
 };
 
+// Runs specific transforms for the graph.
+absl::Status RunGraphTransforms(GraphFloat32* graph);
+
 }  // namespace metal
 }  // namespace gpu
 }  // namespace tflite
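
Editor's note: the header makes the contract explicit. InitFromGraph takes a const graph and trusts that the transforms already ran, while InitFromGraphWithTransforms takes a mutable pointer because RunGraphTransforms rewrites the graph in place. The manual two-step equivalent, mirroring the implementation shown earlier:

    // Sketch: the two-step path that InitFromGraphWithTransforms wraps.
    RETURN_IF_ERROR(RunGraphTransforms(&graph));  // mutates the graph
    RETURN_IF_ERROR(inference_context.InitFromGraph(create_info, graph, device));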


@@ -87,7 +87,8 @@ absl::Status SingleOpModel::Invoke() {
   create_info.storage_type = TensorStorageType::BUFFER;
   InferenceContext inference_context;
   id<MTLDevice> device = MTLCreateSystemDefaultDevice();
-  RETURN_IF_ERROR(inference_context.InitFromGraph(create_info, graph_, device));
+  RETURN_IF_ERROR(inference_context.InitFromGraphWithTransforms(
+      create_info, &graph_, device));
 
   std::map<ValueId, BHWC> input_dimensions;
   std::map<ValueId, id<MTLBuffer>> input_buffers;


@@ -433,7 +433,8 @@ class Delegate {
     InferenceContext::CreateInferenceInfo create_info;
     create_info.precision = precision;
     create_info.storage_type = TensorStorageType::BUFFER;
-    RETURN_IF_ERROR(inference_context_.InitFromGraph(create_info, graph, metal_device_));
+    RETURN_IF_ERROR(
+        inference_context_.InitFromGraphWithTransforms(create_info, &graph, metal_device_));
     return absl::OkStatus();
   }
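
Editor's note: with both call sites (the single-op test harness and the delegate) switched to the transforming entry point, a failed transform now surfaces as the absl::InternalError built in RunGraphTransforms and propagates up through RETURN_IF_ERROR. A sketch of inspecting it at the delegate boundary; the logging is illustrative, not from the commit:

    // Sketch: surfacing a failed transform at the delegate boundary.
    absl::Status status = inference_context_.InitFromGraphWithTransforms(
        create_info, &graph, metal_device_);
    if (!status.ok()) {
      // e.g. "Invalid merge_padding transform" from RunGraphTransforms.
      NSLog(@"Metal delegate init failed: %s",
            std::string(status.message()).c_str());
    }
    return status;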