Fusion of nodes simplified.
The task-fusion logic was moved to compute_task_descriptor. PiperOrigin-RevId: 351205320 Change-Id: Iaa632b1b6bb7bfc302debc831ac43b7f9c6a9d0b
This commit is contained in:
parent
7618f357a8
commit
99d561b15e
@ -137,6 +137,7 @@ objc_library(
|
||||
"//tensorflow/lite/delegates/gpu/common/task:arguments",
|
||||
"//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
|
||||
"@FP16",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -36,13 +36,9 @@ namespace metal {
|
||||
absl::Status ComputeTask::CompileWithDevice(id<MTLDevice> device,
|
||||
const NodeDescriptor& desc,
|
||||
CalculationsPrecision precision) {
|
||||
std::string args_declarations;
|
||||
int bind_index = 0;
|
||||
desc.task->shader_source = absl::Substitute(desc.task->shader_source, "$0",
|
||||
args_declarations + "$1", "");
|
||||
|
||||
RETURN_IF_ERROR(metal_args_.Init(device, bind_index, &desc.task->args,
|
||||
&desc.task->shader_source));
|
||||
desc.task->AssembleCode();
|
||||
RETURN_IF_ERROR(
|
||||
metal_args_.Init(device, 0, &desc.task->args, &desc.task->shader_source));
|
||||
NSString* barrier;
|
||||
// simdgroup_barrier is supported on macOS 10.13+ and Metal shading language
|
||||
// version 2.0
|
||||
|
@ -16,9 +16,12 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include <fp16.h>
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
||||
@ -27,6 +30,28 @@ limitations under the License.
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace metal {
|
||||
namespace {
|
||||
// Returns the Metal kernel source template for an element-wise task.
//
// The template defines `ComputeFunction`, which takes the grid position as
// (X, Y, Z), returns early when the position lies outside
// args.dst_tensor's Width()/Height()/Slices(), reads one FLT4 `value` from
// args.src_tensor, and writes `value` to args.dst_tensor at the same position.
// Three placeholders are left for later substitution: `$0` (before the kernel),
// `$1` (start of the kernel's parameter list) and `$2` (between the read and
// the write, where code operating on `value` is inserted).
//
// NOTE(review): `op_def` is not referenced in the body.
std::string GetElementWiseCode(const OperationDef& op_def) {
|
||||
return R"(
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
$0
|
||||
kernel void ComputeFunction($1
|
||||
uint3 gid[[thread_position_in_grid]]) {
|
||||
int X = static_cast<int>(gid.x);
|
||||
int Y = static_cast<int>(gid.y);
|
||||
int Z = static_cast<int>(gid.z);
|
||||
if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || Z >= args.dst_tensor.Slices()) {
|
||||
return;
|
||||
|
||||
}
|
||||
FLT4 value = args.src_tensor.Read(X, Y, Z);
|
||||
$2
|
||||
args.dst_tensor.Write(value, X, Y, Z);
|
||||
}
|
||||
)";
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Converts float to destination type (if needed) and stores as bytes array.
|
||||
std::vector<uint8_t> GetByteBufferConverted(
|
||||
@ -74,6 +99,59 @@ void ComputeTaskDescriptor::AddDstTensor(const std::string& tensor_name,
|
||||
args.AddObjectRef(tensor_name, AccessType::WRITE, std::move(desc_new));
|
||||
}
|
||||
|
||||
/// Links the code and arguments of `task_desc` into this descriptor.
///
/// Each linked task gets a unique "_link<N>" postfix (N counts the tasks added
/// so far, starting at 1). The postfix is applied to the linked task's
/// arguments and to the names of its source/destination tensors, so that they
/// do not collide with names already held by this descriptor. The linked
/// task's (renamed) code is appended to `elementwise_code` inside its own
/// `{ ... }` scope.
///
/// `task_desc->args` is moved from; the linked task's `shader_source` is only
/// copied.
///
/// @param task_desc Task to link in.
/// @return The error from merging the arguments, or OK on success.
absl::Status ComputeTaskDescriptor::AddTask(ComputeTaskDescriptor* task_desc) {
  ++linkable_count;
  const std::string link_postfix = absl::StrCat("_link", linkable_count);

  // Rename the arguments inside a copy of the linked code, then append that
  // copy as a separate scope.
  std::string linked_code = task_desc->shader_source;
  task_desc->args.RenameArgs(link_postfix, &linked_code);
  absl::StrAppend(&elementwise_code, "{\n", linked_code, "\n}\n");

  RETURN_IF_ERROR(args.Merge(std::move(task_desc->args), link_postfix));

  // The tensor definition for the linked task's i-th named source is taken
  // from index i + 1 of its definition; index 0 is skipped (presumably it is
  // the tensor handed over from the task being linked to; confirm with callers).
  for (size_t src = 0; src < task_desc->src_tensors_names.size(); ++src) {
    definition.src_tensors.push_back(
        task_desc->definition.src_tensors[src + 1]);
    src_tensors_names.push_back(task_desc->src_tensors_names[src] +
                                link_postfix);
  }

  for (size_t dst = 0; dst < task_desc->dst_tensors_names.size(); ++dst) {
    dst_tensors_names.push_back(task_desc->dst_tensors_names[dst] +
                                link_postfix);
  }

  return absl::OkStatus();
}
|
||||
|
||||
void ComputeTaskDescriptor::AssembleCode() {
|
||||
if (is_linkable) {
|
||||
auto src_desc =
|
||||
absl::make_unique<TensorDescriptor>(definition.src_tensors[0]);
|
||||
if (definition.IsBatchSupported()) {
|
||||
src_desc->SetStateVar("BatchedWidth", "true");
|
||||
}
|
||||
src_tensors_names.insert(src_tensors_names.begin(), "src_tensor");
|
||||
args.AddObjectRef("src_tensor", AccessType::READ, std::move(src_desc));
|
||||
|
||||
auto dst_desc =
|
||||
absl::make_unique<TensorDescriptor>(definition.dst_tensors[0]);
|
||||
if (definition.IsBatchSupported()) {
|
||||
dst_desc->SetStateVar("BatchedWidth", "true");
|
||||
}
|
||||
dst_tensors_names.insert(dst_tensors_names.begin(), "dst_tensor");
|
||||
args.AddObjectRef("dst_tensor", AccessType::WRITE, std::move(dst_desc));
|
||||
|
||||
elementwise_code = "{\n" + shader_source + "\n}\n" + elementwise_code;
|
||||
shader_source = GetElementWiseCode(definition);
|
||||
|
||||
resize_function = [](const std::vector<BHWC>& src_shapes,
|
||||
const std::vector<BHWC>& dst_shapes) {
|
||||
uint3 groups_size{8, 8, 1};
|
||||
uint3 groups_count{DivideRoundUp(dst_shapes[0].w, groups_size.x),
|
||||
DivideRoundUp(dst_shapes[0].h, groups_size.y),
|
||||
DivideRoundUp(dst_shapes[0].c, 4)};
|
||||
return std::make_pair(groups_size, groups_count);
|
||||
};
|
||||
}
|
||||
shader_source = absl::Substitute(shader_source, "$0", "$1",
|
||||
"{\n" + elementwise_code + "\n}\n");
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
@ -91,6 +91,13 @@ struct ComputeTaskDescriptor {
|
||||
const TensorDescriptor& desc);
|
||||
void AddDstTensor(const std::string& tensor_name,
|
||||
const TensorDescriptor& desc);
|
||||
|
||||
absl::Status AddTask(ComputeTaskDescriptor* task_desc);
|
||||
void AssembleCode();
|
||||
|
||||
private:
|
||||
int linkable_count = 0; // temporary, used during op construction
|
||||
std::string elementwise_code; // temporary, used during op construction
|
||||
};
|
||||
|
||||
using ComputeTaskDescriptorPtr = std::shared_ptr<ComputeTaskDescriptor>;
|
||||
|
@ -422,81 +422,19 @@ void RemoveInputProxies(std::list<FusionSequence>* chains) {
|
||||
}
|
||||
}
|
||||
|
||||
NodeDescriptor NonLinkableStub(int operation_id, ValueId input_id,
|
||||
ValueId output_id,
|
||||
const OperationDef& definition) {
|
||||
auto desc = std::make_shared<ComputeTaskDescriptor>();
|
||||
desc->shader_source = R"(
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
$0
|
||||
kernel void ComputeFunction($1
|
||||
uint3 gid[[thread_position_in_grid]]) {
|
||||
if (int(gid.x) >= args.dst_tensor.Width() || int(gid.y) >= args.dst_tensor.Height()) {
|
||||
return;
|
||||
}
|
||||
FLT4 value = args.src_tensor.Read(gid.x, gid.y, gid.z);
|
||||
args.dst_tensor.GetAddress(linear_index, gid.x, gid.y, gid.z);
|
||||
$2
|
||||
args.dst_tensor.Write(value, gid.x, gid.y, gid.z);
|
||||
}
|
||||
)";
|
||||
|
||||
desc->AddSrcTensor("src_tensor", definition.src_tensors[0]);
|
||||
desc->AddDstTensor("dst_tensor", definition.dst_tensors[0]);
|
||||
|
||||
desc->resize_function = [](const std::vector<BHWC>& src_shapes,
|
||||
const std::vector<BHWC>& dst_shapes) {
|
||||
uint3 groups_size{8, 8, 1};
|
||||
uint3 groups_count{DivideRoundUp(dst_shapes[0].w, groups_size.x),
|
||||
DivideRoundUp(dst_shapes[0].h, groups_size.y),
|
||||
DivideRoundUp(dst_shapes[0].c, 4)};
|
||||
return std::make_pair(groups_size, groups_count);
|
||||
};
|
||||
|
||||
NodeDescriptor node_desc;
|
||||
node_desc.task = desc;
|
||||
node_desc.id = operation_id;
|
||||
node_desc.src_tensors_ids = {input_id};
|
||||
node_desc.dst_tensors_ids = {output_id};
|
||||
return node_desc;
|
||||
}
|
||||
|
||||
absl::Status MergeNodes(const NodeDescriptor* src, NodeDescriptor* dst,
|
||||
std::string link_name) {
|
||||
absl::Status MergeNodes(const NodeDescriptor* src, NodeDescriptor* dst) {
|
||||
for (int j = 1; j < src->src_tensors_ids.size(); ++j) {
|
||||
dst->src_tensors_ids.push_back(src->src_tensors_ids[j]);
|
||||
}
|
||||
dst->dst_tensors_ids[0] = src->dst_tensors_ids[0];
|
||||
dst->description += " linked : " + src->description;
|
||||
for (int i = 0; i < src->task->src_tensors_names.size(); ++i) {
|
||||
std::string tensor_name = src->task->src_tensors_names[i];
|
||||
dst->task->src_tensors_names.push_back(tensor_name + link_name);
|
||||
dst->task->definition.src_tensors.push_back(
|
||||
src->task->definition.src_tensors[i + 1]);
|
||||
dst->src_tensors_ids.push_back(src->src_tensors_ids[i + 1]);
|
||||
}
|
||||
|
||||
std::string code = src->task->shader_source;
|
||||
src->task->args.RenameArgs(link_name, &code);
|
||||
|
||||
RETURN_IF_ERROR(dst->task->args.Merge(std::move(src->task->args), link_name));
|
||||
|
||||
dst->task->shader_source = absl::Substitute(dst->task->shader_source, "$0",
|
||||
"$1", "{\n" + code + "\n}\n$2");
|
||||
|
||||
return absl::OkStatus();
|
||||
return dst->task->AddTask(src->task.get());
|
||||
}
|
||||
|
||||
absl::Status FuseChain(const FusionSequence& chain, NodeDescriptor* node_desc) {
|
||||
if (chain.front().task->is_linkable) {
|
||||
*node_desc = NonLinkableStub(
|
||||
chain.front().id, chain.front().src_tensors_ids[0],
|
||||
chain.front().dst_tensors_ids[0], chain.front().task->definition);
|
||||
RETURN_IF_ERROR(MergeNodes(&chain.front(), node_desc, "_link0"));
|
||||
} else {
|
||||
*node_desc = chain.front();
|
||||
}
|
||||
*node_desc = chain.front();
|
||||
for (int j = 1; j < chain.size(); ++j) {
|
||||
RETURN_IF_ERROR(
|
||||
MergeNodes(&chain[j], node_desc, "_link" + std::to_string(j)));
|
||||
RETURN_IF_ERROR(MergeNodes(&chain[j], node_desc));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user