Added new optimizing transformation to Metal backend.
PiperOrigin-RevId: 352866499 Change-Id: I0681d78f0826b1fbe8934aafb8135d3d8cfb89fd
This commit is contained in:
parent
39916c9b96
commit
9a426abe81
@ -182,6 +182,9 @@ objc_library(
|
|||||||
"//tensorflow/lite/delegates/gpu/common:util",
|
"//tensorflow/lite/delegates/gpu/common:util",
|
||||||
"//tensorflow/lite/delegates/gpu/common/task:storage_type_util",
|
"//tensorflow/lite/delegates/gpu/common/task:storage_type_util",
|
||||||
"//tensorflow/lite/delegates/gpu/common/task:tuning_type",
|
"//tensorflow/lite/delegates/gpu/common/task:tuning_type",
|
||||||
|
"//tensorflow/lite/delegates/gpu/common/transformations:add_bias",
|
||||||
|
"//tensorflow/lite/delegates/gpu/common/transformations:global_pooling_to_reduce_op",
|
||||||
|
"//tensorflow/lite/delegates/gpu/common/transformations:merge_padding_with",
|
||||||
"//tensorflow/lite/delegates/gpu/metal/selectors:operation_selector",
|
"//tensorflow/lite/delegates/gpu/metal/selectors:operation_selector",
|
||||||
"//tensorflow/lite/delegates/gpu/metal/selectors:subgraph",
|
"//tensorflow/lite/delegates/gpu/metal/selectors:subgraph",
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
|
@ -28,6 +28,9 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/task/storage_type_util.h"
|
#include "tensorflow/lite/delegates/gpu/common/task/storage_type_util.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/transformations/add_bias.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/util.h"
|
#include "tensorflow/lite/delegates/gpu/common/util.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/metal/compute_task.h"
|
#include "tensorflow/lite/delegates/gpu/metal/compute_task.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
|
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
|
||||||
@ -113,6 +116,14 @@ absl::Status MergeNodes(MetalNode* src, MetalNode* dst) {
|
|||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
// Convenience entry point: runs RunGraphTransforms on `graph` and then
// initializes this context from the resulting graph via InitFromGraph.
// `graph` is taken by non-const pointer because it is handed to
// RunGraphTransforms, which receives a mutable GraphFloat32*.
// Returns absl::OkStatus() once both steps have completed; each step is
// wrapped in RETURN_IF_ERROR.
absl::Status InferenceContext::InitFromGraphWithTransforms(
|
||||||
|
const CreateInferenceInfo& create_info, GraphFloat32* graph,
|
||||||
|
id<MTLDevice> device_id) {
|
||||||
|
// Transforms are applied first so that InitFromGraph sees the rewritten graph.
RETURN_IF_ERROR(RunGraphTransforms(graph));
|
||||||
|
RETURN_IF_ERROR(InitFromGraph(create_info, *graph, device_id));
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
absl::Status InferenceContext::InitFromGraph(
|
absl::Status InferenceContext::InitFromGraph(
|
||||||
const CreateInferenceInfo& create_info, const GraphFloat32& graph,
|
const CreateInferenceInfo& create_info, const GraphFloat32& graph,
|
||||||
id<MTLDevice> device_id) {
|
id<MTLDevice> device_id) {
|
||||||
@ -543,6 +554,24 @@ void InferenceContext::UpdatePreallocatedTensors(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Applies three transformations to `graph` through a single ModelTransformer
// (constructed with a null reporter), in this order:
//   1. "add_bias"               (NewAddBias)
//   2. "merge_padding"          (NewMergePaddingWithAdd)
//   3. "global pooling to mean" (NewGlobalPoolingToReduceOp)
// Returns absl::InternalError naming the transform as soon as one Apply call
// returns false; later transforms are then not attempted. Returns
// absl::OkStatus() when all three Apply calls return true.
absl::Status RunGraphTransforms(GraphFloat32* graph) {
|
||||||
|
// Note: the transform objects are created in a different order than they
// are applied below; only the Apply order matters for the graph.
auto merge_padding_transform = NewMergePaddingWithAdd();
|
||||||
|
auto add_bias_transform = NewAddBias();
|
||||||
|
auto pooling_to_reduce_op = NewGlobalPoolingToReduceOp();
|
||||||
|
ModelTransformer transformer(graph, /*reporter=*/nullptr);
|
||||||
|
if (!transformer.Apply("add_bias", add_bias_transform.get())) {
|
||||||
|
return absl::InternalError("Invalid add_bias transform");
|
||||||
|
}
|
||||||
|
if (!transformer.Apply("merge_padding", merge_padding_transform.get())) {
|
||||||
|
return absl::InternalError("Invalid merge_padding transform");
|
||||||
|
}
|
||||||
|
if (!transformer.Apply("global pooling to mean",
|
||||||
|
pooling_to_reduce_op.get())) {
|
||||||
|
return absl::InternalError("Invalid global pooling to mean transform");
|
||||||
|
}
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace metal
|
} // namespace metal
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -64,10 +64,19 @@ class InferenceContext {
|
|||||||
|
|
||||||
InferenceContext() = default;
|
InferenceContext() = default;
|
||||||
|
|
||||||
|
// IMPORTANT: If InitFromGraph is used, RunGraphTransforms must be applied to
|
||||||
|
// this graph upfront; otherwise correct behavior is not guaranteed.
|
||||||
absl::Status InitFromGraph(const CreateInferenceInfo& create_info,
|
absl::Status InitFromGraph(const CreateInferenceInfo& create_info,
|
||||||
const GraphFloat32& graph,
|
const GraphFloat32& graph,
|
||||||
id<MTLDevice> device_id);
|
id<MTLDevice> device_id);
|
||||||
|
|
||||||
|
// Applies specific transformations to the graph before the
|
||||||
|
// initialization. These transformations are either impossible or useless in
|
||||||
|
// other backends.
|
||||||
|
absl::Status InitFromGraphWithTransforms(
|
||||||
|
const CreateInferenceInfo& create_info, GraphFloat32* graph,
|
||||||
|
id<MTLDevice> device_id);
|
||||||
|
|
||||||
/// Inserts all GPU compute tasks into the command encoder.
|
/// Inserts all GPU compute tasks into the command encoder.
|
||||||
/// @param inputOutputBuffers Must be created and passed into the method
|
/// @param inputOutputBuffers Must be created and passed into the method
|
||||||
/// with pairs ID:buffer
|
/// with pairs ID:buffer
|
||||||
@ -195,6 +204,9 @@ class InferenceContext {
|
|||||||
// from _sharedBuffers
|
// from _sharedBuffers
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Runs specific transforms for the graph.
|
||||||
|
absl::Status RunGraphTransforms(GraphFloat32* graph);
|
||||||
|
|
||||||
} // namespace metal
|
} // namespace metal
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -87,7 +87,8 @@ absl::Status SingleOpModel::Invoke() {
|
|||||||
create_info.storage_type = TensorStorageType::BUFFER;
|
create_info.storage_type = TensorStorageType::BUFFER;
|
||||||
InferenceContext inference_context;
|
InferenceContext inference_context;
|
||||||
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
|
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
|
||||||
RETURN_IF_ERROR(inference_context.InitFromGraph(create_info, graph_, device));
|
RETURN_IF_ERROR(inference_context.InitFromGraphWithTransforms(
|
||||||
|
create_info, &graph_, device));
|
||||||
|
|
||||||
std::map<ValueId, BHWC> input_dimensions;
|
std::map<ValueId, BHWC> input_dimensions;
|
||||||
std::map<ValueId, id<MTLBuffer>> input_buffers;
|
std::map<ValueId, id<MTLBuffer>> input_buffers;
|
||||||
|
@ -433,7 +433,8 @@ class Delegate {
|
|||||||
InferenceContext::CreateInferenceInfo create_info;
|
InferenceContext::CreateInferenceInfo create_info;
|
||||||
create_info.precision = precision;
|
create_info.precision = precision;
|
||||||
create_info.storage_type = TensorStorageType::BUFFER;
|
create_info.storage_type = TensorStorageType::BUFFER;
|
||||||
RETURN_IF_ERROR(inference_context_.InitFromGraph(create_info, graph, metal_device_));
|
RETURN_IF_ERROR(
|
||||||
|
inference_context_.InitFromGraphWithTransforms(create_info, &graph, metal_device_));
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user