Fusion of nodes simplified.

Task fusion logic moved to compute_task_descriptor.

PiperOrigin-RevId: 351205320
Change-Id: Iaa632b1b6bb7bfc302debc831ac43b7f9c6a9d0b
Raman Sarokin 2021-01-11 12:01:17 -08:00 committed by TensorFlower Gardener
parent 7618f357a8
commit 99d561b15e
5 changed files with 96 additions and 76 deletions
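
This commit replaces the per-backend shader stitching with two ComputeTaskDescriptor methods: AddTask accumulates the code of each linked elementwise task, and AssembleCode later splices the accumulated code into the base kernel. A minimal standalone sketch of that two-step flow (not the TFLite code itself; DescriptorSketch and the conv/relu snippets are illustrative):

#include <iostream>
#include <string>

struct DescriptorSketch {
  std::string shader_source;     // base kernel body with a $2 hook
  std::string elementwise_code;  // accumulated linked code
  int linkable_count = 0;

  // Mirrors ComputeTaskDescriptor::AddTask: wrap each linked task's code in
  // braces and append it. (The real method also renames and merges the
  // task's arguments under a unique "_linkN" postfix.)
  void AddTask(const DescriptorSketch& other) {
    ++linkable_count;
    elementwise_code += "{\n" + other.shader_source + "\n}\n";
  }

  // Mirrors the final step of AssembleCode: splice the accumulated code into
  // the base kernel. (The real method uses absl::Substitute and keeps $0/$1
  // for a later argument-binding pass.)
  void AssembleCode() {
    auto pos = shader_source.find("$2");
    if (pos != std::string::npos) {
      shader_source.replace(pos, 2, "{\n" + elementwise_code + "}");
    }
  }
};

int main() {
  DescriptorSketch conv{"value = conv3x3(value);\n$2"};
  DescriptorSketch relu{"value = max(value, FLT4(0.0f));"};
  conv.AddTask(relu);
  conv.AssembleCode();
  std::cout << conv.shader_source << "\n";  // conv body with relu spliced in
}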

tensorflow/lite/delegates/gpu/metal/BUILD

@@ -137,6 +137,7 @@ objc_library(
         "//tensorflow/lite/delegates/gpu/common/task:arguments",
         "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
         "@FP16",
+        "@com_google_absl//absl/strings",
     ],
 )

tensorflow/lite/delegates/gpu/metal/compute_task.mm

@@ -36,13 +36,9 @@ namespace metal {
 absl::Status ComputeTask::CompileWithDevice(id<MTLDevice> device,
                                             const NodeDescriptor& desc,
                                             CalculationsPrecision precision) {
-  std::string args_declarations;
-  int bind_index = 0;
-  desc.task->shader_source = absl::Substitute(desc.task->shader_source, "$0",
-                                              args_declarations + "$1", "");
-  RETURN_IF_ERROR(metal_args_.Init(device, bind_index, &desc.task->args,
-                                   &desc.task->shader_source));
+  desc.task->AssembleCode();
+  RETURN_IF_ERROR(
+      metal_args_.Init(device, 0, &desc.task->args, &desc.task->shader_source));
   NSString* barrier;
   // simdgroup_barrier is supported on macOS 10.13+ and Metal shading language
   // version 2.0
tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.cc

@@ -16,9 +16,12 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
 
 #include <cstdint>
+#include <string>
 #include <vector>
 
 #include <fp16.h>
 
+#include "absl/strings/str_cat.h"
+#include "absl/strings/substitute.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
@@ -27,6 +30,28 @@ limitations under the License.
 namespace tflite {
 namespace gpu {
 namespace metal {
+namespace {
+std::string GetElementWiseCode(const OperationDef& op_def) {
+  return R"(
+#include <metal_stdlib>
+using namespace metal;
+$0
+kernel void ComputeFunction($1
+                            uint3 gid[[thread_position_in_grid]]) {
+  int X = static_cast<int>(gid.x);
+  int Y = static_cast<int>(gid.y);
+  int Z = static_cast<int>(gid.z);
+  if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || Z >= args.dst_tensor.Slices()) {
+    return;
+  }
+  FLT4 value = args.src_tensor.Read(X, Y, Z);
+  $2
+  args.dst_tensor.Write(value, X, Y, Z);
+}
+)";
+}
+}  // namespace
+
 /// Converts float to destination type (if needed) and stores as bytes array.
 std::vector<uint8_t> GetByteBufferConverted(
@@ -74,6 +99,59 @@ void ComputeTaskDescriptor::AddDstTensor(const std::string& tensor_name,
   args.AddObjectRef(tensor_name, AccessType::WRITE, std::move(desc_new));
 }
 
+absl::Status ComputeTaskDescriptor::AddTask(ComputeTaskDescriptor* task_desc) {
+  linkable_count += 1;
+  std::string code = task_desc->shader_source;
+  std::string unique_postfix = absl::StrCat("_link", linkable_count);
+  task_desc->args.RenameArgs(unique_postfix, &code);
+  elementwise_code += "{\n" + code + "\n}\n";
+  RETURN_IF_ERROR(args.Merge(std::move(task_desc->args), unique_postfix));
+  for (int i = 0; i < task_desc->src_tensors_names.size(); ++i) {
+    definition.src_tensors.push_back(task_desc->definition.src_tensors[i + 1]);
+    src_tensors_names.push_back(task_desc->src_tensors_names[i] +
+                                unique_postfix);
+  }
+  for (int i = 0; i < task_desc->dst_tensors_names.size(); ++i) {
+    dst_tensors_names.push_back(task_desc->dst_tensors_names[i] +
+                                unique_postfix);
+  }
+  return absl::OkStatus();
+}
+
+void ComputeTaskDescriptor::AssembleCode() {
+  if (is_linkable) {
+    auto src_desc =
+        absl::make_unique<TensorDescriptor>(definition.src_tensors[0]);
+    if (definition.IsBatchSupported()) {
+      src_desc->SetStateVar("BatchedWidth", "true");
+    }
+    src_tensors_names.insert(src_tensors_names.begin(), "src_tensor");
+    args.AddObjectRef("src_tensor", AccessType::READ, std::move(src_desc));
+
+    auto dst_desc =
+        absl::make_unique<TensorDescriptor>(definition.dst_tensors[0]);
+    if (definition.IsBatchSupported()) {
+      dst_desc->SetStateVar("BatchedWidth", "true");
+    }
+    dst_tensors_names.insert(dst_tensors_names.begin(), "dst_tensor");
+    args.AddObjectRef("dst_tensor", AccessType::WRITE, std::move(dst_desc));
+
+    elementwise_code = "{\n" + shader_source + "\n}\n" + elementwise_code;
+    shader_source = GetElementWiseCode(definition);
+
+    resize_function = [](const std::vector<BHWC>& src_shapes,
+                         const std::vector<BHWC>& dst_shapes) {
+      uint3 groups_size{8, 8, 1};
+      uint3 groups_count{DivideRoundUp(dst_shapes[0].w, groups_size.x),
+                         DivideRoundUp(dst_shapes[0].h, groups_size.y),
+                         DivideRoundUp(dst_shapes[0].c, 4)};
+      return std::make_pair(groups_size, groups_count);
+    };
+  }
+  shader_source = absl::Substitute(shader_source, "$0", "$1",
+                                   "{\n" + elementwise_code + "\n}\n");
+}
+
 }  // namespace metal
 }  // namespace gpu
 }  // namespace tflite
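
AddTask renames the incoming task's arguments with a unique "_linkN" postfix before merging, so two linked tasks can declare arguments with the same name without colliding. A standalone sketch of the idea (the Args map and RenameArgs helper below are illustrative stand-ins for the real Arguments class, which also rewrites the shader text that references the renamed arguments):

#include <iostream>
#include <map>
#include <string>

using Args = std::map<std::string, float>;

// Append the suffix to every argument name.
Args RenameArgs(const Args& args, const std::string& suffix) {
  Args renamed;
  for (const auto& [name, value] : args) renamed[name + suffix] = value;
  return renamed;
}

int main() {
  Args base = {{"scale", 1.0f}};
  Args linked = {{"scale", 0.5f}};  // would collide with base's "scale"
  Args merged = base;
  for (const auto& kv : RenameArgs(linked, "_link1")) merged.insert(kv);
  for (const auto& [name, value] : merged)
    std::cout << name << " = " << value << "\n";  // scale, scale_link1
}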

tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h

@@ -91,6 +91,13 @@ struct ComputeTaskDescriptor {
                     const TensorDescriptor& desc);
   void AddDstTensor(const std::string& tensor_name,
                     const TensorDescriptor& desc);
+
+  absl::Status AddTask(ComputeTaskDescriptor* task_desc);
+  void AssembleCode();
+
+ private:
+  int linkable_count = 0;        // temporary, used during op construction
+  std::string elementwise_code;  // temporary, used during op construction
 };
 
 using ComputeTaskDescriptorPtr = std::shared_ptr<ComputeTaskDescriptor>;
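
For reference, the resize_function that AssembleCode installs (see the .cc hunk above) dispatches the destination shape in 8x8x1 workgroups, with channels packed four per FLT4 slice. A standalone worked example of that grid arithmetic, using a hypothetical 100x60x9 destination tensor:

#include <cstdio>

// Same rounding helper as tflite::gpu::DivideRoundUp.
int DivideRoundUp(int n, int d) { return (n + d - 1) / d; }

int main() {
  // Hypothetical destination shape: width 100, height 60, channels 9.
  const int w = 100, h = 60, c = 9;
  const int groups_size[3] = {8, 8, 1};
  const int groups_count[3] = {
      DivideRoundUp(w, groups_size[0]),  // ceil(100 / 8) = 13
      DivideRoundUp(h, groups_size[1]),  // ceil(60 / 8)  = 8
      DivideRoundUp(c, 4)};              // 9 channels -> 3 slices of FLT4
  std::printf("groups_count = {%d, %d, %d}\n", groups_count[0],
              groups_count[1], groups_count[2]);
}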

tensorflow/lite/delegates/gpu/metal/api.cc

@@ -422,81 +422,19 @@ void RemoveInputProxies(std::list<FusionSequence>* chains) {
   }
 }
 
-NodeDescriptor NonLinkableStub(int operation_id, ValueId input_id,
-                               ValueId output_id,
-                               const OperationDef& definition) {
-  auto desc = std::make_shared<ComputeTaskDescriptor>();
-  desc->shader_source = R"(
-#include <metal_stdlib>
-using namespace metal;
-$0
-kernel void ComputeFunction($1
-                            uint3 gid[[thread_position_in_grid]]) {
-  if (int(gid.x) >= args.dst_tensor.Width() || int(gid.y) >= args.dst_tensor.Height()) {
-    return;
-  }
-  FLT4 value = args.src_tensor.Read(gid.x, gid.y, gid.z);
-  args.dst_tensor.GetAddress(linear_index, gid.x, gid.y, gid.z);
-  $2
-  args.dst_tensor.Write(value, gid.x, gid.y, gid.z);
-}
-)";
-
-  desc->AddSrcTensor("src_tensor", definition.src_tensors[0]);
-  desc->AddDstTensor("dst_tensor", definition.dst_tensors[0]);
-
-  desc->resize_function = [](const std::vector<BHWC>& src_shapes,
-                             const std::vector<BHWC>& dst_shapes) {
-    uint3 groups_size{8, 8, 1};
-    uint3 groups_count{DivideRoundUp(dst_shapes[0].w, groups_size.x),
-                       DivideRoundUp(dst_shapes[0].h, groups_size.y),
-                       DivideRoundUp(dst_shapes[0].c, 4)};
-    return std::make_pair(groups_size, groups_count);
-  };
-
-  NodeDescriptor node_desc;
-  node_desc.task = desc;
-  node_desc.id = operation_id;
-  node_desc.src_tensors_ids = {input_id};
-  node_desc.dst_tensors_ids = {output_id};
-  return node_desc;
-}
-
-absl::Status MergeNodes(const NodeDescriptor* src, NodeDescriptor* dst,
-                        std::string link_name) {
+absl::Status MergeNodes(const NodeDescriptor* src, NodeDescriptor* dst) {
+  for (int j = 1; j < src->src_tensors_ids.size(); ++j) {
+    dst->src_tensors_ids.push_back(src->src_tensors_ids[j]);
+  }
   dst->dst_tensors_ids[0] = src->dst_tensors_ids[0];
   dst->description += " linked : " + src->description;
-  for (int i = 0; i < src->task->src_tensors_names.size(); ++i) {
-    std::string tensor_name = src->task->src_tensors_names[i];
-    dst->task->src_tensors_names.push_back(tensor_name + link_name);
-    dst->task->definition.src_tensors.push_back(
-        src->task->definition.src_tensors[i + 1]);
-    dst->src_tensors_ids.push_back(src->src_tensors_ids[i + 1]);
-  }
-  std::string code = src->task->shader_source;
-  src->task->args.RenameArgs(link_name, &code);
-  RETURN_IF_ERROR(dst->task->args.Merge(std::move(src->task->args), link_name));
-  dst->task->shader_source = absl::Substitute(dst->task->shader_source, "$0",
-                                              "$1", "{\n" + code + "\n}\n$2");
-  return absl::OkStatus();
+  return dst->task->AddTask(src->task.get());
 }
 
 absl::Status FuseChain(const FusionSequence& chain, NodeDescriptor* node_desc) {
-  if (chain.front().task->is_linkable) {
-    *node_desc = NonLinkableStub(
-        chain.front().id, chain.front().src_tensors_ids[0],
-        chain.front().dst_tensors_ids[0], chain.front().task->definition);
-    RETURN_IF_ERROR(MergeNodes(&chain.front(), node_desc, "_link0"));
-  } else {
-    *node_desc = chain.front();
-  }
+  *node_desc = chain.front();
   for (int j = 1; j < chain.size(); ++j) {
-    RETURN_IF_ERROR(
-        MergeNodes(&chain[j], node_desc, "_link" + std::to_string(j)));
+    RETURN_IF_ERROR(MergeNodes(&chain[j], node_desc));
   }
   return absl::OkStatus();
 }
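
With NonLinkableStub gone, FuseChain treats every chain uniformly: it starts from chain.front() and folds each later node in with MergeNodes, which now only rewires tensor ids and descriptions before delegating shader fusion to AddTask. A standalone sketch of that fold (NodeSketch is an illustrative stand-in for NodeDescriptor):

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

struct NodeSketch {
  std::vector<int> src_tensors_ids;
  std::vector<int> dst_tensors_ids;
  std::string description;
};

void MergeNodes(const NodeSketch& src, NodeSketch* dst) {
  // Extra inputs of the merged node (beyond the linking edge) become inputs
  // of the fused node.
  for (std::size_t j = 1; j < src.src_tensors_ids.size(); ++j) {
    dst->src_tensors_ids.push_back(src.src_tensors_ids[j]);
  }
  dst->dst_tensors_ids[0] = src.dst_tensors_ids[0];  // output moves forward
  dst->description += " linked : " + src.description;
  // The real MergeNodes then returns dst->task->AddTask(src->task.get()).
}

int main() {
  // Hypothetical chain: conv -> relu -> add (add has a second input, id 5).
  std::vector<NodeSketch> chain = {
      {{0}, {1}, "conv"}, {{1}, {2}, "relu"}, {{2, 5}, {3}, "add"}};
  NodeSketch fused = chain.front();
  for (std::size_t j = 1; j < chain.size(); ++j) MergeNodes(chain[j], &fused);
  std::cout << fused.description << "\n";  // conv linked : relu linked : add
  std::cout << "output id: " << fused.dst_tensors_ids[0] << "\n";  // 3
}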