From 99d561b15ef47071b5efc14c00b8da7eeb5b9856 Mon Sep 17 00:00:00 2001 From: Raman Sarokin Date: Mon, 11 Jan 2021 12:01:17 -0800 Subject: [PATCH] Fusion of nodes simplified. Logic of tasks fusion moved to compute_task_descriptor. PiperOrigin-RevId: 351205320 Change-Id: Iaa632b1b6bb7bfc302debc831ac43b7f9c6a9d0b --- tensorflow/lite/delegates/gpu/metal/BUILD | 1 + .../lite/delegates/gpu/metal/compute_task.cc | 10 +-- .../gpu/metal/compute_task_descriptor.cc | 78 +++++++++++++++++++ .../gpu/metal/compute_task_descriptor.h | 7 ++ .../delegates/gpu/metal/inference_context.cc | 76 ++---------------- 5 files changed, 96 insertions(+), 76 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/metal/BUILD b/tensorflow/lite/delegates/gpu/metal/BUILD index 080414f68f6..0a5a0c774c4 100644 --- a/tensorflow/lite/delegates/gpu/metal/BUILD +++ b/tensorflow/lite/delegates/gpu/metal/BUILD @@ -137,6 +137,7 @@ objc_library( "//tensorflow/lite/delegates/gpu/common/task:arguments", "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", "@FP16", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/lite/delegates/gpu/metal/compute_task.cc b/tensorflow/lite/delegates/gpu/metal/compute_task.cc index a716fb45d43..f0152643934 100644 --- a/tensorflow/lite/delegates/gpu/metal/compute_task.cc +++ b/tensorflow/lite/delegates/gpu/metal/compute_task.cc @@ -36,13 +36,9 @@ namespace metal { absl::Status ComputeTask::CompileWithDevice(id<MTLDevice> device, const NodeDescriptor& desc, CalculationsPrecision precision) { - std::string args_declarations; - int bind_index = 0; - desc.task->shader_source = absl::Substitute(desc.task->shader_source, "$0", - args_declarations + "$1", ""); - - RETURN_IF_ERROR(metal_args_.Init(device, bind_index, &desc.task->args, - &desc.task->shader_source)); + desc.task->AssembleCode(); + RETURN_IF_ERROR( + metal_args_.Init(device, 0, &desc.task->args, &desc.task->shader_source)); NSString* barrier; // simdgroup_barrier is supported on macOS 10.13+ and Metal shading
language // version 2.0 diff --git a/tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.cc b/tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.cc index 86c1e3ed730..e1858afc4c8 100644 --- a/tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.cc +++ b/tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.cc @@ -16,9 +16,12 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" #include +#include #include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/substitute.h" #include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/types.h" @@ -27,6 +30,28 @@ limitations under the License. namespace tflite { namespace gpu { namespace metal { +namespace { +std::string GetElementWiseCode(const OperationDef& op_def) { + return R"( +#include <metal_stdlib> +using namespace metal; +$0 +kernel void ComputeFunction($1 + uint3 gid[[thread_position_in_grid]]) { + int X = static_cast<int>(gid.x); + int Y = static_cast<int>(gid.y); + int Z = static_cast<int>(gid.z); + if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || Z >= args.dst_tensor.Slices()) { + return; + } + FLT4 value = args.src_tensor.Read(X, Y, Z); + $2 + args.dst_tensor.Write(value, X, Y, Z); +} +)"; +} + +} // namespace /// Converts float to destination type (if needed) and stores as bytes array.
std::vector<uint8_t> GetByteBufferConverted( @@ -74,6 +99,59 @@ void ComputeTaskDescriptor::AddDstTensor(const std::string& tensor_name, args.AddObjectRef(tensor_name, AccessType::WRITE, std::move(desc_new)); } +absl::Status ComputeTaskDescriptor::AddTask(ComputeTaskDescriptor* task_desc) { + linkable_count += 1; + std::string code = task_desc->shader_source; + std::string unique_postfix = absl::StrCat("_link", linkable_count); + task_desc->args.RenameArgs(unique_postfix, &code); + elementwise_code += "{\n" + code + "\n}\n"; + RETURN_IF_ERROR(args.Merge(std::move(task_desc->args), unique_postfix)); + for (int i = 0; i < task_desc->src_tensors_names.size(); ++i) { + definition.src_tensors.push_back(task_desc->definition.src_tensors[i + 1]); + src_tensors_names.push_back(task_desc->src_tensors_names[i] + + unique_postfix); + } + for (int i = 0; i < task_desc->dst_tensors_names.size(); ++i) { + dst_tensors_names.push_back(task_desc->dst_tensors_names[i] + + unique_postfix); + } + return absl::OkStatus(); +} + +void ComputeTaskDescriptor::AssembleCode() { + if (is_linkable) { + auto src_desc = + absl::make_unique<TensorDescriptor>(definition.src_tensors[0]); + if (definition.IsBatchSupported()) { + src_desc->SetStateVar("BatchedWidth", "true"); + } + src_tensors_names.insert(src_tensors_names.begin(), "src_tensor"); + args.AddObjectRef("src_tensor", AccessType::READ, std::move(src_desc)); + + auto dst_desc = + absl::make_unique<TensorDescriptor>(definition.dst_tensors[0]); + if (definition.IsBatchSupported()) { + dst_desc->SetStateVar("BatchedWidth", "true"); + } + dst_tensors_names.insert(dst_tensors_names.begin(), "dst_tensor"); + args.AddObjectRef("dst_tensor", AccessType::WRITE, std::move(dst_desc)); + + elementwise_code = "{\n" + shader_source + "\n}\n" + elementwise_code; + shader_source = GetElementWiseCode(definition); + + resize_function = [](const std::vector<BHWC>& src_shapes, + const std::vector<BHWC>& dst_shapes) { + uint3 groups_size{8, 8, 1}; + uint3 groups_count{DivideRoundUp(dst_shapes[0].w, groups_size.x),
+ DivideRoundUp(dst_shapes[0].h, groups_size.y), + DivideRoundUp(dst_shapes[0].c, 4)}; + return std::make_pair(groups_size, groups_count); + }; + } + shader_source = absl::Substitute(shader_source, "$0", "$1", + "{\n" + elementwise_code + "\n}\n"); +} + } // namespace metal } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h b/tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h index 0a12c19e7b2..6827ab4073d 100644 --- a/tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h +++ b/tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h @@ -91,6 +91,13 @@ struct ComputeTaskDescriptor { const TensorDescriptor& desc); void AddDstTensor(const std::string& tensor_name, const TensorDescriptor& desc); + + absl::Status AddTask(ComputeTaskDescriptor* task_desc); + void AssembleCode(); + + private: + int linkable_count = 0; // temporary, used during op construction + std::string elementwise_code; // temporary, used during op construction }; using ComputeTaskDescriptorPtr = std::shared_ptr<ComputeTaskDescriptor>; diff --git a/tensorflow/lite/delegates/gpu/metal/inference_context.cc b/tensorflow/lite/delegates/gpu/metal/inference_context.cc index a4627f221d5..0e9081b9fe5 100644 --- a/tensorflow/lite/delegates/gpu/metal/inference_context.cc +++ b/tensorflow/lite/delegates/gpu/metal/inference_context.cc @@ -422,81 +422,19 @@ void RemoveInputProxies(std::list<FusionSequence>* chains) { } } -NodeDescriptor NonLinkableStub(int operation_id, ValueId input_id, - ValueId output_id, - const OperationDef& definition) { - auto desc = std::make_shared<ComputeTaskDescriptor>(); - desc->shader_source = R"( - #include <metal_stdlib> - using namespace metal; - $0 - kernel void ComputeFunction($1 - uint3 gid[[thread_position_in_grid]]) { - if (int(gid.x) >= args.dst_tensor.Width() || int(gid.y) >= args.dst_tensor.Height()) { - return; - } - FLT4 value = args.src_tensor.Read(gid.x, gid.y, gid.z); - args.dst_tensor.GetAddress(linear_index, gid.x, gid.y, gid.z); - $2 -
args.dst_tensor.Write(value, gid.x, gid.y, gid.z); - } - )"; - - desc->AddSrcTensor("src_tensor", definition.src_tensors[0]); - desc->AddDstTensor("dst_tensor", definition.dst_tensors[0]); - - desc->resize_function = [](const std::vector<BHWC>& src_shapes, - const std::vector<BHWC>& dst_shapes) { - uint3 groups_size{8, 8, 1}; - uint3 groups_count{DivideRoundUp(dst_shapes[0].w, groups_size.x), - DivideRoundUp(dst_shapes[0].h, groups_size.y), - DivideRoundUp(dst_shapes[0].c, 4)}; - return std::make_pair(groups_size, groups_count); - }; - - NodeDescriptor node_desc; - node_desc.task = desc; - node_desc.id = operation_id; - node_desc.src_tensors_ids = {input_id}; - node_desc.dst_tensors_ids = {output_id}; - return node_desc; -} - -absl::Status MergeNodes(const NodeDescriptor* src, NodeDescriptor* dst, - std::string link_name) { +absl::Status MergeNodes(const NodeDescriptor* src, NodeDescriptor* dst) { + for (int j = 1; j < src->src_tensors_ids.size(); ++j) { + dst->src_tensors_ids.push_back(src->src_tensors_ids[j]); + } dst->dst_tensors_ids[0] = src->dst_tensors_ids[0]; dst->description += " linked : " + src->description; - for (int i = 0; i < src->task->src_tensors_names.size(); ++i) { - std::string tensor_name = src->task->src_tensors_names[i]; - dst->task->src_tensors_names.push_back(tensor_name + link_name); - dst->task->definition.src_tensors.push_back( - src->task->definition.src_tensors[i + 1]); - dst->src_tensors_ids.push_back(src->src_tensors_ids[i + 1]); - } - - std::string code = src->task->shader_source; - src->task->args.RenameArgs(link_name, &code); - - RETURN_IF_ERROR(dst->task->args.Merge(std::move(src->task->args), link_name)); - - dst->task->shader_source = absl::Substitute(dst->task->shader_source, "$0", - "$1", "{\n" + code + "\n}\n$2"); - - return absl::OkStatus(); } absl::Status FuseChain(const FusionSequence& chain, NodeDescriptor* node_desc) { - if (chain.front().task->is_linkable) { - *node_desc =
NonLinkableStub( - chain.front().id, chain.front().src_tensors_ids[0], - chain.front().dst_tensors_ids[0], chain.front().task->definition); - RETURN_IF_ERROR(MergeNodes(&chain.front(), node_desc, "_link0")); - } else { - *node_desc = chain.front(); - } + *node_desc = chain.front(); for (int j = 1; j < chain.size(); ++j) { - RETURN_IF_ERROR( - MergeNodes(&chain[j], node_desc, "_link" + std::to_string(j))); + RETURN_IF_ERROR(MergeNodes(&chain[j], node_desc)); } return absl::OkStatus(); }