Added a new optimizing transformation (global pooling → reduce op) to the Metal backend.
PiperOrigin-RevId: 352866499 Change-Id: I0681d78f0826b1fbe8934aafb8135d3d8cfb89fd
This commit is contained in:
parent
39916c9b96
commit
9a426abe81
@ -182,6 +182,9 @@ objc_library(
|
||||
"//tensorflow/lite/delegates/gpu/common:util",
|
||||
"//tensorflow/lite/delegates/gpu/common/task:storage_type_util",
|
||||
"//tensorflow/lite/delegates/gpu/common/task:tuning_type",
|
||||
"//tensorflow/lite/delegates/gpu/common/transformations:add_bias",
|
||||
"//tensorflow/lite/delegates/gpu/common/transformations:global_pooling_to_reduce_op",
|
||||
"//tensorflow/lite/delegates/gpu/common/transformations:merge_padding_with",
|
||||
"//tensorflow/lite/delegates/gpu/metal/selectors:operation_selector",
|
||||
"//tensorflow/lite/delegates/gpu/metal/selectors:subgraph",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
|
@ -28,6 +28,9 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/task/storage_type_util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/transformations/add_bias.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/compute_task.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
|
||||
@ -113,6 +116,14 @@ absl::Status MergeNodes(MetalNode* src, MetalNode* dst) {
|
||||
}
|
||||
} // namespace
|
||||
|
||||
absl::Status InferenceContext::InitFromGraphWithTransforms(
|
||||
const CreateInferenceInfo& create_info, GraphFloat32* graph,
|
||||
id<MTLDevice> device_id) {
|
||||
RETURN_IF_ERROR(RunGraphTransforms(graph));
|
||||
RETURN_IF_ERROR(InitFromGraph(create_info, *graph, device_id));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status InferenceContext::InitFromGraph(
|
||||
const CreateInferenceInfo& create_info, const GraphFloat32& graph,
|
||||
id<MTLDevice> device_id) {
|
||||
@ -543,6 +554,24 @@ void InferenceContext::UpdatePreallocatedTensors(
|
||||
}
|
||||
}
|
||||
|
||||
absl::Status RunGraphTransforms(GraphFloat32* graph) {
|
||||
auto merge_padding_transform = NewMergePaddingWithAdd();
|
||||
auto add_bias_transform = NewAddBias();
|
||||
auto pooling_to_reduce_op = NewGlobalPoolingToReduceOp();
|
||||
ModelTransformer transformer(graph, /*reporter=*/nullptr);
|
||||
if (!transformer.Apply("add_bias", add_bias_transform.get())) {
|
||||
return absl::InternalError("Invalid add_bias transform");
|
||||
}
|
||||
if (!transformer.Apply("merge_padding", merge_padding_transform.get())) {
|
||||
return absl::InternalError("Invalid merge_padding transform");
|
||||
}
|
||||
if (!transformer.Apply("global pooling to mean",
|
||||
pooling_to_reduce_op.get())) {
|
||||
return absl::InternalError("Invalid global pooling to mean transform");
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
@ -64,10 +64,19 @@ class InferenceContext {
|
||||
|
||||
InferenceContext() = default;
|
||||
|
||||
// IMPORTANT: If InitFromGraph used, RunGraphTransforms must be applied for
|
||||
// this graph upfront, otherwise not guaranteed correct behavior
|
||||
absl::Status InitFromGraph(const CreateInferenceInfo& create_info,
|
||||
const GraphFloat32& graph,
|
||||
id<MTLDevice> device_id);
|
||||
|
||||
// Applies specific transformations to the graph before the
|
||||
// initialization. These transformations are either impossible or useless in
|
||||
// other backends.
|
||||
absl::Status InitFromGraphWithTransforms(
|
||||
const CreateInferenceInfo& create_info, GraphFloat32* graph,
|
||||
id<MTLDevice> device_id);
|
||||
|
||||
/// Inserts all GPU compute tasks into the command encoder.
|
||||
/// @param inputOutputBuffers Must be created and passed into the method
|
||||
/// with pairs ID:buffer
|
||||
@ -195,6 +204,9 @@ class InferenceContext {
|
||||
// from _sharedBuffers
|
||||
};
|
||||
|
||||
// Runs specific transforms for the graph.
|
||||
absl::Status RunGraphTransforms(GraphFloat32* graph);
|
||||
|
||||
} // namespace metal
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
@ -87,7 +87,8 @@ absl::Status SingleOpModel::Invoke() {
|
||||
create_info.storage_type = TensorStorageType::BUFFER;
|
||||
InferenceContext inference_context;
|
||||
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
|
||||
RETURN_IF_ERROR(inference_context.InitFromGraph(create_info, graph_, device));
|
||||
RETURN_IF_ERROR(inference_context.InitFromGraphWithTransforms(
|
||||
create_info, &graph_, device));
|
||||
|
||||
std::map<ValueId, BHWC> input_dimensions;
|
||||
std::map<ValueId, id<MTLBuffer>> input_buffers;
|
||||
|
@ -433,7 +433,8 @@ class Delegate {
|
||||
InferenceContext::CreateInferenceInfo create_info;
|
||||
create_info.precision = precision;
|
||||
create_info.storage_type = TensorStorageType::BUFFER;
|
||||
RETURN_IF_ERROR(inference_context_.InitFromGraph(create_info, graph, metal_device_));
|
||||
RETURN_IF_ERROR(
|
||||
inference_context_.InitFromGraphWithTransforms(create_info, &graph, metal_device_));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user