Chain fusion simplified.

Fusion is now done as an iterative process of fusing nodes one at a time.

PiperOrigin-RevId: 350109962
Change-Id: I50ec1b9f97916829286026954e50fe920923d09a
This commit is contained in:
Raman Sarokin 2021-01-05 03:39:10 -08:00 committed by TensorFlower Gardener
parent 32449628f6
commit b08f45a554
11 changed files with 185 additions and 115 deletions

View File

@ -119,7 +119,6 @@ objc_library(
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common:types",
"//tensorflow/lite/delegates/gpu/common:util",
"@com_google_absl//absl/strings",
],
)

View File

@ -16,12 +16,10 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/metal/compute_task.h"
#include <Availability.h>
#include <string>
#include <tuple>
#include "absl/strings/match.h"
#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/metal/metal_arguments.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
@ -36,33 +34,10 @@ namespace metal {
absl::Status ComputeTask::CompileWithDevice(id<MTLDevice> device,
const NodeDescriptor& desc,
CalculationsPrecision precision) {
std::string args_declarations;
int bind_index = 0;
for (const auto& dst_name : desc.task->dst_tensors_names) {
args_declarations += "device FLT4* " + dst_name + "[[buffer(" +
std::to_string(bind_index) + ")]],\n";
bind_index++;
}
for (const auto& src_name : desc.task->src_tensors_names) {
args_declarations += "device FLT4* " + src_name + "[[buffer(" +
std::to_string(bind_index) + ")]],\n";
bind_index++;
}
for (const auto& buffer : desc.task->immutable_buffers) {
args_declarations += buffer.declaration + "[[buffer(" +
std::to_string(bind_index) + ")]],\n";
bind_index++;
}
for (const auto& buffer : desc.task->uniform_buffers) {
args_declarations += buffer.declaration + "[[buffer(" +
std::to_string(bind_index) + ")]],\n";
bind_index++;
}
desc.task->shader_source = absl::Substitute(desc.task->shader_source, "$0",
args_declarations + "$1", "");
RETURN_IF_ERROR(metal_args_.Init(device, bind_index, &desc.task->args,
size_t offset = desc.src_tensors_ids.size() +
desc.task->uniform_buffers.size() +
desc.task->immutable_buffers.size() + 1;
RETURN_IF_ERROR(metal_args_.Init(device, offset, &desc.task->args,
&desc.task->shader_source));
NSString* barrier;
// simdgroup_barrier is supported on macOS 10.13+ and Metal shading language
@ -256,11 +231,10 @@ std::vector<ValueId> ComputeTask::GetInputIds() const {
void ComputeTask::SetSrcTensor(const MetalSpatialTensor& tensor, int index) {
input_buffers_[index].metal_handle = tensor.GetBufferHandle();
if (tensors_as_args_ &&
absl::StrContains(src_tensors_names_[index], "_buffer")) {
if (tensors_as_args_ && index < src_tensors_names_.size()) {
auto name = src_tensors_names_[index];
// extracting tensor_name from "tensor_name_buffer";
name = name.substr(0, name.size() - 7);
// extracting tensor_name from "device FLT4* tensor_name_buffer";
name = name.substr(13, name.size() - 20);
auto status = metal_args_.SetObjectRef(name, tensor);
}
}
@ -269,8 +243,8 @@ void ComputeTask::SetDstTensor(const MetalSpatialTensor& tensor, int index) {
output_buffers_[index].metal_handle = tensor.GetBufferHandle();
if (tensors_as_args_) {
auto name = dst_tensors_names_[index];
// extracting tensor_name from "tensor_name_buffer";
name = name.substr(0, name.size() - 7);
// extracting tensor_name from "device FLT4* tensor_name_buffer";
name = name.substr(13, name.size() - 20);
auto status = metal_args_.SetObjectRef(name, tensor);
}
}

View File

@ -63,22 +63,22 @@ ComputeTaskDescriptor::ComputeTaskDescriptor(const OperationDef& def)
void ComputeTaskDescriptor::AddSrcTensor(const std::string& tensor_name,
const TensorDescriptor& desc) {
if (tensors_as_args) {
src_tensors_names.push_back(tensor_name + "_buffer");
src_tensors_names.push_back("device FLT4* " + tensor_name + "_buffer");
auto desc_new = absl::make_unique<TensorDescriptor>(desc);
args.AddObjectRef(tensor_name, AccessType::READ, std::move(desc_new));
} else {
src_tensors_names.push_back(tensor_name);
src_tensors_names.push_back("device FLT4* " + tensor_name);
}
}
void ComputeTaskDescriptor::AddDstTensor(const std::string& tensor_name,
const TensorDescriptor& desc) {
if (tensors_as_args) {
dst_tensors_names.push_back(tensor_name + "_buffer");
dst_tensors_names.push_back("device FLT4* " + tensor_name + "_buffer");
auto desc_new = absl::make_unique<TensorDescriptor>(desc);
args.AddObjectRef(tensor_name, AccessType::WRITE, std::move(desc_new));
} else {
dst_tensors_names.push_back(tensor_name);
dst_tensors_names.push_back("device FLT4* " + tensor_name);
}
}

View File

@ -16,7 +16,6 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/metal/inference_context.h"
#include <map>
#include <string>
#include <vector>
#include "absl/strings/substitute.h"
@ -423,15 +422,15 @@ void RemoveInputProxies(std::list<FusionSequence>* chains) {
}
NodeDescriptor NonLinkableStub(int operation_id, ValueId input_id,
ValueId output_id,
const OperationDef& definition) {
ValueId output_id) {
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->is_linkable = false;
desc->shader_source = R"(
#include <metal_stdlib>
using namespace metal;
$0
kernel void ComputeFunction($1
kernel void ComputeFunction(
$1
uint3 gid[[thread_position_in_grid]]) {
if (int(gid.x) >= size.x || int(gid.y) >= size.y) {
return;
@ -443,8 +442,8 @@ NodeDescriptor NonLinkableStub(int operation_id, ValueId input_id,
}
)";
desc->AddSrcTensor("src_tensor", definition.src_tensors[0]);
desc->AddDstTensor("dst_tensor", definition.dst_tensors[0]);
desc->AddSrcTensor("src_tensor", {});
desc->AddDstTensor("dst_tensor", {});
desc->uniform_buffers = {
{"constant int2& size",
@ -472,58 +471,119 @@ NodeDescriptor NonLinkableStub(int operation_id, ValueId input_id,
return node_desc;
}
void MergeNodes(const NodeDescriptor* src, NodeDescriptor* dst,
std::string link_name) {
std::string call_arguments;
dst->dst_tensors_ids[0] = src->dst_tensors_ids[0];
dst->description += " linked : " + src->description;
for (int i = 0; i < src->task->src_tensors_names.size(); ++i) {
std::string tensor_name = src->task->src_tensors_names[i] + link_name;
call_arguments += ", " + tensor_name;
dst->task->AddSrcTensor(tensor_name,
src->task->definition.src_tensors[i + 1]);
dst->src_tensors_ids.push_back(src->src_tensors_ids[i + 1]);
}
for (int i = 0; i < src->task->immutable_buffers.size(); ++i) {
auto buffer = src->task->immutable_buffers[i];
const std::string buffer_name = "ibuffer" + std::to_string(i) + link_name;
buffer.declaration += " " + buffer_name;
call_arguments += ", " + buffer_name;
dst->task->immutable_buffers.push_back(buffer);
}
for (int i = 0; i < src->task->uniform_buffers.size(); ++i) {
auto buffer = src->task->uniform_buffers[i];
const std::string buffer_name = "ubuffer" + std::to_string(i) + link_name;
buffer.declaration += " " + buffer_name;
call_arguments += ", " + buffer_name;
dst->task->uniform_buffers.push_back(buffer);
}
std::string function_code =
absl::Substitute(src->task->shader_source, link_name) + "\n";
std::string call_code =
absl::Substitute("value = linkable$0(value, linear_index, gid$1);\n",
link_name, call_arguments);
dst->task->shader_source = absl::Substitute(
dst->task->shader_source, function_code + "$0", "$1", call_code + "$2");
}
NodeDescriptor FuseChain(const FusionSequence& chain) {
NodeDescriptor node_desc;
auto fused_descriptor = std::make_shared<ComputeTaskDescriptor>();
FusionSequence sequence;
if (chain.front().task->is_linkable) {
node_desc = NonLinkableStub(
chain.front().id, chain.front().src_tensors_ids[0],
chain.front().dst_tensors_ids[0], chain.front().task->definition);
MergeNodes(&chain.front(), &node_desc, "_link0");
} else {
node_desc = chain.front();
// The first task is linkable so it contains only linkable code. Insert
// unlinkable meta-task with remaining shader code.
sequence.push_back(NonLinkableStub(-1, chain.front().src_tensors_ids[0],
chain.front().src_tensors_ids[0]));
}
for (int j = 1; j < chain.size(); ++j) {
MergeNodes(&chain[j], &node_desc, "_link" + std::to_string(j));
sequence.insert(sequence.end(), chain.begin(), chain.end());
// Count buffers to calculate proper indices then.
int num_outputs = 1;
int num_inputs = 0;
int num_immutables = 0;
bool invalid_id = true;
ValueId fused_id;
for (const auto& desc : sequence) {
for (const auto& id : desc.src_tensors_ids) {
if (invalid_id || id != fused_id) {
num_inputs++;
}
}
fused_id = desc.dst_tensors_ids[0];
invalid_id = false;
num_immutables += desc.task->immutable_buffers.size();
}
int output_index = 0;
int input_index = num_outputs;
int immutable_index = num_outputs + num_inputs;
int uniform_index = num_outputs + num_inputs + num_immutables;
int function_index = 0;
std::string function_code;
std::string buffer_declarations;
std::string call_code;
invalid_id = true;
for (const auto& desc : sequence) {
if (desc.task->is_linkable) {
function_code +=
absl::Substitute(desc.task->shader_source, function_index) + "\n";
} else {
// Declare output buffer only for the first unlinkable task.
buffer_declarations +=
desc.task->dst_tensors_names[0] + "[[buffer(0)]],\n";
output_index++;
}
std::string call_arguments;
for (int i = 0; i < desc.task->src_tensors_names.size(); ++i) {
if (invalid_id || desc.src_tensors_ids[i] != fused_id) {
std::string index = std::to_string(input_index);
std::string name = (desc.task->is_linkable ? (" buffer" + index) : "");
buffer_declarations += desc.task->src_tensors_names[i] + name +
"[[buffer(" + index + ")]],\n";
call_arguments += ", buffer" + index;
input_index++;
fused_descriptor->AddSrcTensor("", {});
node_desc.src_tensors_ids.push_back(desc.src_tensors_ids[i]);
}
}
// We have an output id that is the input for the next task.
fused_id = desc.dst_tensors_ids[0];
invalid_id = false;
for (const auto& buffer : desc.task->immutable_buffers) {
std::string index = std::to_string(immutable_index);
std::string name = (desc.task->is_linkable ? (" buffer" + index) : "");
buffer_declarations +=
buffer.declaration + name + "[[buffer(" + index + ")]],\n";
call_arguments += ", buffer" + index;
immutable_index++;
fused_descriptor->immutable_buffers.push_back(buffer);
}
for (const auto& buffer : desc.task->uniform_buffers) {
std::string index = std::to_string(uniform_index);
std::string name = (desc.task->is_linkable ? (" buffer" + index) : "");
buffer_declarations +=
buffer.declaration + name + "[[buffer(" + index + ")]],\n";
call_arguments += ", buffer" + index;
uniform_index++;
fused_descriptor->uniform_buffers.push_back({"", buffer.data_function});
}
if (desc.task->is_linkable) {
call_code +=
absl::Substitute("value = linkable$0(value, linear_index, gid$1);\n",
function_index, call_arguments);
function_index++;
}
}
fused_descriptor->args = std::move(sequence.front().task->args);
auto& non_linkable = sequence.front();
fused_descriptor->shader_source =
absl::Substitute(non_linkable.task->shader_source, function_code + "$0",
buffer_declarations + "$1", call_code);
fused_descriptor->AddDstTensor("", {});
fused_descriptor->src_tensors_names = non_linkable.task->src_tensors_names;
fused_descriptor->dst_tensors_names = non_linkable.task->dst_tensors_names;
fused_descriptor->tensors_as_args = non_linkable.task->tensors_as_args;
fused_descriptor->resize_function = non_linkable.task->resize_function;
node_desc.dst_tensors_ids = {fused_id};
node_desc.task = fused_descriptor;
// The id of fused descriptor is the id of the first descriptor in the list.
node_desc.id = chain.front().id;
for (const auto& desc : sequence) {
node_desc.description += desc.description + "_";
}
return node_desc;
}
} // namespace

View File

@ -41,8 +41,8 @@ std::string GetAddTableCodeFused(int src_count) {
code += ") {\n";
for (int i = 0; i < src_count; ++i) {
code += " value += src_buf" + std::to_string(i) + "[linear_index];\n";
code += " return value;\n";
}
code += " return value;\n";
code += "}\n";
return code;
}
@ -53,10 +53,10 @@ ComputeTaskDescriptor Add(const OperationDef& definition) {
desc.is_linkable = true;
desc.shader_source = GetAddTableCodeFused(definition.src_tensors.size() - 1);
for (int i = 1; i < definition.src_tensors.size(); ++i) {
desc.AddSrcTensor("src_tensor_" + std::to_string(i),
definition.src_tensors[i]);
for (int i = 0; i < definition.src_tensors.size(); ++i) {
desc.AddSrcTensor("", definition.src_tensors[i]);
}
desc.AddDstTensor("", definition.dst_tensors[0]);
return desc;
}

View File

@ -46,32 +46,27 @@ using ::tflite::gpu::metal::SingleOpModel;
[super setUp];
}
- (void)testThreeInputTensorsOfTheSameShape {
TensorRef<BHWC> a, b, c, output;
a.type = DataType::FLOAT32;
a.ref = 0;
a.shape = BHWC(1, 2, 2, 1);
- (void)testTwoInputTensorsOfTheSameShape {
TensorRef<BHWC> augend, addend, output;
augend.type = DataType::FLOAT32;
augend.ref = 0;
augend.shape = BHWC(1, 2, 2, 1);
b.type = DataType::FLOAT32;
b.ref = 1;
b.shape = BHWC(1, 2, 2, 1);
c.type = DataType::FLOAT32;
c.ref = 2;
c.shape = BHWC(1, 2, 2, 1);
addend.type = DataType::FLOAT32;
addend.ref = 1;
addend.shape = BHWC(1, 2, 2, 1);
output.type = DataType::FLOAT32;
output.ref = 3;
output.ref = 2;
output.shape = BHWC(1, 2, 2, 1);
ElementwiseAttributes attr;
SingleOpModel model({ToString(OperationType::ADD), std::move(attr)}, {a, b, c}, {output});
SingleOpModel model({ToString(OperationType::ADD), std::move(attr)}, {augend, addend}, {output});
XCTAssertTrue(model.PopulateTensor(0, {-2.0, 0.2, 0.7, 0.8}));
XCTAssertTrue(model.PopulateTensor(1, {0.1, 0.2, 0.3, 0.5}));
XCTAssertTrue(model.PopulateTensor(2, {2.1, 1.2, 3.3, 4.5}));
auto status = model.Invoke();
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
status = CompareVectors({0.2, 1.6, 4.3, 5.8}, model.GetOutput(0), 1e-6f);
status = CompareVectors({-1.9, 0.4, 1.0, 1.3}, model.GetOutput(0), 1e-6f);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}

View File

@ -107,7 +107,9 @@ ComputeTaskDescriptor ElementwiseWithTwoInputs(const OperationDef& definition,
desc.shader_source = code;
desc.AddSrcTensor("second_tensor", definition.src_tensors[1]);
desc.AddSrcTensor("", definition.src_tensors[0]);
desc.AddSrcTensor("", definition.src_tensors[1]);
desc.AddDstTensor("", definition.dst_tensors[0]);
desc.uniform_buffers = {
{"constant int2&",
@ -132,6 +134,9 @@ ComputeTaskDescriptor ElementwiseWithOneInput(const OperationDef& definition,
desc.shader_source +=
" return " + OneInputFunctor(op_type, "value") + ";\n";
desc.shader_source += " }";
desc.AddSrcTensor("", definition.src_tensors[0]);
desc.AddDstTensor("", definition.dst_tensors[0]);
return desc;
}
@ -177,6 +182,8 @@ ComputeTaskDescriptor ElementwiseWithOneInputAndConstantArguent(
" return " + TwoInputFunctor(op_type, "value", "second_arg") + ";\n";
desc.shader_source += " }";
desc.AddSrcTensor("", definition.src_tensors[0]);
desc.AddDstTensor("", definition.dst_tensors[0]);
auto data_type = DeduceDataTypeFromPrecision(definition.precision);
if (scalar) {
std::vector<uint8_t> scalar_bits =

View File

@ -55,6 +55,8 @@ ComputeTaskDescriptor PReLU(const OperationDef& definition,
return FLT4(max(FLT4(0.0f), value) + alphas[gid.z] * min(FLT4(0.0f), value));
})";
}
desc.AddSrcTensor("", definition.src_tensors[0]);
desc.AddDstTensor("", definition.dst_tensors[0]);
auto data_type = DeduceDataTypeFromPrecision(definition.precision);
desc.immutable_buffers = {
{"device FLT4* const",
@ -95,6 +97,8 @@ ComputeTaskDescriptor PReLUFull(const OperationDef& definition,
return FLT4(max(FLT4(0.0f), value) + alphas[linear_index] * min(FLT4(0.0f), value));
})";
}
desc.AddSrcTensor("", definition.src_tensors[0]);
desc.AddDstTensor("", definition.dst_tensors[0]);
auto data_type = DeduceDataTypeFromPrecision(definition.precision);
desc.immutable_buffers = {
{"device FLT4* const",

View File

@ -35,6 +35,9 @@ ComputeTaskDescriptor QuantizeAndDequantize(
return round(value) * FLT4(params.z) + FLT4(params.x);
}
)";
desc.AddSrcTensor("", definition.src_tensors[0]);
desc.AddDstTensor("", definition.dst_tensors[0]);
desc.uniform_buffers = {
{"constant float3&",
[attr](const std::vector<BHWC>& src_shapes,

View File

@ -46,6 +46,8 @@ ComputeTaskDescriptor ReLU(const OperationDef& definition,
desc.shader_source =
parameters + " return FLT4(max(value, " + min_func + "));\n}";
}
desc.AddSrcTensor("", definition.src_tensors[0]);
desc.AddDstTensor("", definition.dst_tensors[0]);
desc.uniform_buffers = {
{"constant float2&",
[attr](const std::vector<BHWC>& src_shapes,

View File

@ -235,6 +235,32 @@ absl::Status MetalExecutionEnvironment::ExecuteGPUOperation(
metal_node.description = "test_op";
metal_node.id = 0;
std::string buffer_declarations;
int index = 0;
for (int i = 0; i < metal_node.task->dst_tensors_names.size(); ++i) {
buffer_declarations += metal_node.task->dst_tensors_names[i] + "[[buffer(" +
std::to_string(index) + ")]],\n";
index++;
}
for (int i = 0; i < metal_node.task->src_tensors_names.size(); ++i) {
buffer_declarations += metal_node.task->src_tensors_names[i] + "[[buffer(" +
std::to_string(index) + ")]],\n";
index++;
}
for (const auto& buffer : metal_node.task->immutable_buffers) {
buffer_declarations +=
buffer.declaration + "[[buffer(" + std::to_string(index) + ")]],\n";
index++;
}
for (const auto& buffer : metal_node.task->uniform_buffers) {
buffer_declarations +=
buffer.declaration + "[[buffer(" + std::to_string(index) + ")]],\n";
index++;
}
metal_node.task->shader_source = absl::Substitute(
metal_node.task->shader_source, "$0", buffer_declarations + "$1", "");
ComputeTask gpu_task;
RETURN_IF_ERROR(
gpu_task.CompileWithDevice(device_, metal_node, op_def.precision));