Chains fusion simplified.

Fusion is now performed as an iterative process of merging nodes one at a time.

PiperOrigin-RevId: 350109962
Change-Id: I50ec1b9f97916829286026954e50fe920923d09a
parent 32449628f6
commit b08f45a554
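
The diff below replaces the single-pass assembly in FuseChain with pairwise merging: a non-linkable base task absorbs each following linkable task via a new MergeNodes helper. A minimal sketch of that control flow, using simplified stand-in types rather than the actual TFLite GPU classes:

    // Sketch only: Node and Merge are invented stand-ins for
    // NodeDescriptor and MergeNodes, showing the iterative shape of the
    // new FuseChain rather than the real fusion logic.
    #include <string>
    #include <vector>

    struct Node {
      bool is_linkable = false;
      std::string description;
    };

    // Pairwise merge: fold src's work into dst.
    void Merge(const Node& src, Node* dst) {
      dst->description += " linked : " + src.description;
    }

    Node FuseChain(const std::vector<Node>& chain) {
      Node fused = chain.front();  // base node (stubbed out if linkable)
      for (size_t j = 1; j < chain.size(); ++j) {
        Merge(chain[j], &fused);  // one node at a time
      }
      return fused;
    }
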
@@ -119,6 +119,7 @@ objc_library(
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:types",
         "//tensorflow/lite/delegates/gpu/common:util",
+        "@com_google_absl//absl/strings",
     ],
 )
@@ -16,10 +16,12 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/metal/compute_task.h"

 #include <Availability.h>

 #include <string>
 #include <tuple>

+#include "absl/strings/match.h"
+#include "absl/strings/substitute.h"
 #include "tensorflow/lite/delegates/gpu/metal/metal_arguments.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
@@ -34,10 +36,33 @@ namespace metal {
 absl::Status ComputeTask::CompileWithDevice(id<MTLDevice> device,
                                             const NodeDescriptor& desc,
                                             CalculationsPrecision precision) {
-  size_t offset = desc.src_tensors_ids.size() +
-                  desc.task->uniform_buffers.size() +
-                  desc.task->immutable_buffers.size() + 1;
-  RETURN_IF_ERROR(metal_args_.Init(device, offset, &desc.task->args,
+  std::string args_declarations;
+  int bind_index = 0;
+  for (const auto& dst_name : desc.task->dst_tensors_names) {
+    args_declarations += "device FLT4* " + dst_name + "[[buffer(" +
+                         std::to_string(bind_index) + ")]],\n";
+    bind_index++;
+  }
+  for (const auto& src_name : desc.task->src_tensors_names) {
+    args_declarations += "device FLT4* " + src_name + "[[buffer(" +
+                         std::to_string(bind_index) + ")]],\n";
+    bind_index++;
+  }
+  for (const auto& buffer : desc.task->immutable_buffers) {
+    args_declarations += buffer.declaration + "[[buffer(" +
+                         std::to_string(bind_index) + ")]],\n";
+    bind_index++;
+  }
+
+  for (const auto& buffer : desc.task->uniform_buffers) {
+    args_declarations += buffer.declaration + "[[buffer(" +
+                         std::to_string(bind_index) + ")]],\n";
+    bind_index++;
+  }
+  desc.task->shader_source = absl::Substitute(desc.task->shader_source, "$0",
+                                              args_declarations + "$1", "");
+
+  RETURN_IF_ERROR(metal_args_.Init(device, bind_index, &desc.task->args,
                                    &desc.task->shader_source));
   NSString* barrier;
   // simdgroup_barrier is supported on macOS 10.13+ and Metal shading language
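
The added loop builds one "device FLT4* name[[buffer(N)]]" declaration per bound resource and splices the result into the shader's $1 slot. A toy demonstration of that Substitute step (invented names and skeleton, not the real generated source; requires Abseil):

    #include <iostream>
    #include <string>
    #include "absl/strings/substitute.h"

    int main() {
      // Same placeholder layout as the real shaders: $0 for linked
      // function code, $1 for argument declarations, $2 for link calls.
      std::string shader_source =
          "$0\nkernel void ComputeFunction($1\n"
          "    uint3 gid[[thread_position_in_grid]]) { $2 }\n";
      std::string args_declarations =
          "device FLT4* dst_tensor_buffer[[buffer(0)]],\n"
          "device FLT4* src_tensor_buffer[[buffer(1)]],\n";
      // Same call shape as CompileWithDevice: keep "$0" alive for a later
      // pass, splice declarations ahead of a fresh "$1", clear "$2".
      shader_source = absl::Substitute(shader_source, "$0",
                                       args_declarations + "$1", "");
      std::cout << shader_source;
    }
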
@@ -231,10 +256,11 @@ std::vector<ValueId> ComputeTask::GetInputIds() const {

 void ComputeTask::SetSrcTensor(const MetalSpatialTensor& tensor, int index) {
   input_buffers_[index].metal_handle = tensor.GetBufferHandle();
-  if (tensors_as_args_ && index < src_tensors_names_.size()) {
+  if (tensors_as_args_ &&
+      absl::StrContains(src_tensors_names_[index], "_buffer")) {
     auto name = src_tensors_names_[index];
-    // extracting tensor_name from "device FLT4* tensor_name_buffer";
-    name = name.substr(13, name.size() - 20);
+    // extracting tensor_name from "tensor_name_buffer";
+    name = name.substr(0, name.size() - 7);
     auto status = metal_args_.SetObjectRef(name, tensor);
   }
 }
@@ -243,8 +269,8 @@ void ComputeTask::SetDstTensor(const MetalSpatialTensor& tensor, int index) {
   output_buffers_[index].metal_handle = tensor.GetBufferHandle();
   if (tensors_as_args_) {
     auto name = dst_tensors_names_[index];
-    // extracting tensor_name from "device FLT4* tensor_name_buffer";
-    name = name.substr(13, name.size() - 20);
+    // extracting tensor_name from "tensor_name_buffer";
+    name = name.substr(0, name.size() - 7);
     auto status = metal_args_.SetObjectRef(name, tensor);
   }
 }
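
Both setters rely on the same naming convention: tensor arguments are now registered as "<tensor_name>_buffer", so stripping the 7-character "_buffer" suffix recovers the name under which the object ref was added (the old code skipped a 13-character "device FLT4* " prefix as well, hence substr(13, size - 20)). A one-line check of the new arithmetic, with a hypothetical name:

    #include <cassert>
    #include <string>

    int main() {
      std::string name = "src_tensor_buffer";  // as stored by AddSrcTensor
      name = name.substr(0, name.size() - 7);  // drop "_buffer" (7 chars)
      assert(name == "src_tensor");
    }
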
@@ -63,22 +63,22 @@ ComputeTaskDescriptor::ComputeTaskDescriptor(const OperationDef& def)
 void ComputeTaskDescriptor::AddSrcTensor(const std::string& tensor_name,
                                          const TensorDescriptor& desc) {
   if (tensors_as_args) {
-    src_tensors_names.push_back("device FLT4* " + tensor_name + "_buffer");
+    src_tensors_names.push_back(tensor_name + "_buffer");
     auto desc_new = absl::make_unique<TensorDescriptor>(desc);
     args.AddObjectRef(tensor_name, AccessType::READ, std::move(desc_new));
   } else {
-    src_tensors_names.push_back("device FLT4* " + tensor_name);
+    src_tensors_names.push_back(tensor_name);
   }
 }

 void ComputeTaskDescriptor::AddDstTensor(const std::string& tensor_name,
                                          const TensorDescriptor& desc) {
   if (tensors_as_args) {
-    dst_tensors_names.push_back("device FLT4* " + tensor_name + "_buffer");
+    dst_tensors_names.push_back(tensor_name + "_buffer");
     auto desc_new = absl::make_unique<TensorDescriptor>(desc);
     args.AddObjectRef(tensor_name, AccessType::WRITE, std::move(desc_new));
   } else {
-    dst_tensors_names.push_back("device FLT4* " + tensor_name);
+    dst_tensors_names.push_back(tensor_name);
   }
 }

@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/metal/inference_context.h"

 #include <map>
 #include <string>
 #include <vector>

 #include "absl/strings/substitute.h"
@@ -422,15 +423,15 @@ void RemoveInputProxies(std::list<FusionSequence>* chains) {
 }

 NodeDescriptor NonLinkableStub(int operation_id, ValueId input_id,
-                               ValueId output_id) {
+                               ValueId output_id,
+                               const OperationDef& definition) {
   auto desc = std::make_shared<ComputeTaskDescriptor>();
   desc->is_linkable = false;
   desc->shader_source = R"(
     #include <metal_stdlib>
     using namespace metal;
     $0
-    kernel void ComputeFunction(
-                                $1
+    kernel void ComputeFunction($1
     uint3 gid[[thread_position_in_grid]]) {
       if (int(gid.x) >= size.x || int(gid.y) >= size.y) {
         return;
@@ -442,8 +443,8 @@ NodeDescriptor NonLinkableStub(int operation_id, ValueId input_id,
   }
 )";

-  desc->AddSrcTensor("src_tensor", {});
-  desc->AddDstTensor("dst_tensor", {});
+  desc->AddSrcTensor("src_tensor", definition.src_tensors[0]);
+  desc->AddDstTensor("dst_tensor", definition.dst_tensors[0]);

   desc->uniform_buffers = {
       {"constant int2& size",
@@ -471,119 +472,58 @@ NodeDescriptor NonLinkableStub(int operation_id, ValueId input_id,
   return node_desc;
 }

+void MergeNodes(const NodeDescriptor* src, NodeDescriptor* dst,
+                std::string link_name) {
+  std::string call_arguments;
+  dst->dst_tensors_ids[0] = src->dst_tensors_ids[0];
+  dst->description += " linked : " + src->description;
+  for (int i = 0; i < src->task->src_tensors_names.size(); ++i) {
+    std::string tensor_name = src->task->src_tensors_names[i] + link_name;
+    call_arguments += ", " + tensor_name;
+    dst->task->AddSrcTensor(tensor_name,
+                            src->task->definition.src_tensors[i + 1]);
+    dst->src_tensors_ids.push_back(src->src_tensors_ids[i + 1]);
+  }
+
+  for (int i = 0; i < src->task->immutable_buffers.size(); ++i) {
+    auto buffer = src->task->immutable_buffers[i];
+    const std::string buffer_name = "ibuffer" + std::to_string(i) + link_name;
+    buffer.declaration += " " + buffer_name;
+    call_arguments += ", " + buffer_name;
+    dst->task->immutable_buffers.push_back(buffer);
+  }
+
+  for (int i = 0; i < src->task->uniform_buffers.size(); ++i) {
+    auto buffer = src->task->uniform_buffers[i];
+    const std::string buffer_name = "ubuffer" + std::to_string(i) + link_name;
+    buffer.declaration += " " + buffer_name;
+    call_arguments += ", " + buffer_name;
+    dst->task->uniform_buffers.push_back(buffer);
+  }
+
+  std::string function_code =
+      absl::Substitute(src->task->shader_source, link_name) + "\n";
+  std::string call_code =
+      absl::Substitute("value = linkable$0(value, linear_index, gid$1);\n",
+                       link_name, call_arguments);
+
+  dst->task->shader_source = absl::Substitute(
+      dst->task->shader_source, function_code + "$0", "$1", call_code + "$2");
+}
+
 NodeDescriptor FuseChain(const FusionSequence& chain) {
   NodeDescriptor node_desc;
-  auto fused_descriptor = std::make_shared<ComputeTaskDescriptor>();
-  FusionSequence sequence;
   if (chain.front().task->is_linkable) {
-    // The first task is linkable so it contains only linkable code. Insert
-    // unlinkable meta-task with remaining shader code.
-    sequence.push_back(NonLinkableStub(-1, chain.front().src_tensors_ids[0],
-                                       chain.front().src_tensors_ids[0]));
+    node_desc = NonLinkableStub(
+        chain.front().id, chain.front().src_tensors_ids[0],
+        chain.front().dst_tensors_ids[0], chain.front().task->definition);
+    MergeNodes(&chain.front(), &node_desc, "_link0");
+  } else {
+    node_desc = chain.front();
   }
-  sequence.insert(sequence.end(), chain.begin(), chain.end());
-
-  // Count buffers to calculate proper indices then.
-  int num_outputs = 1;
-  int num_inputs = 0;
-  int num_immutables = 0;
-  bool invalid_id = true;
-  ValueId fused_id;
-  for (const auto& desc : sequence) {
-    for (const auto& id : desc.src_tensors_ids) {
-      if (invalid_id || id != fused_id) {
-        num_inputs++;
-      }
-    }
-    fused_id = desc.dst_tensors_ids[0];
-    invalid_id = false;
-    num_immutables += desc.task->immutable_buffers.size();
-  }
-
-  int output_index = 0;
-  int input_index = num_outputs;
-  int immutable_index = num_outputs + num_inputs;
-  int uniform_index = num_outputs + num_inputs + num_immutables;
-
-  int function_index = 0;
-  std::string function_code;
-  std::string buffer_declarations;
-  std::string call_code;
-  invalid_id = true;
-  for (const auto& desc : sequence) {
-    if (desc.task->is_linkable) {
-      function_code +=
-          absl::Substitute(desc.task->shader_source, function_index) + "\n";
-    } else {
-      // Declare output buffer only for the first unlinkable task.
-      buffer_declarations +=
-          desc.task->dst_tensors_names[0] + "[[buffer(0)]],\n";
-      output_index++;
-    }
-
-    std::string call_arguments;
-    for (int i = 0; i < desc.task->src_tensors_names.size(); ++i) {
-      if (invalid_id || desc.src_tensors_ids[i] != fused_id) {
-        std::string index = std::to_string(input_index);
-        std::string name = (desc.task->is_linkable ? (" buffer" + index) : "");
-        buffer_declarations += desc.task->src_tensors_names[i] + name +
-                               "[[buffer(" + index + ")]],\n";
-        call_arguments += ", buffer" + index;
-        input_index++;
-        fused_descriptor->AddSrcTensor("", {});
-        node_desc.src_tensors_ids.push_back(desc.src_tensors_ids[i]);
-      }
-    }
-    // We have an output id that is the input for the next task.
-    fused_id = desc.dst_tensors_ids[0];
-    invalid_id = false;
-
-    for (const auto& buffer : desc.task->immutable_buffers) {
-      std::string index = std::to_string(immutable_index);
-      std::string name = (desc.task->is_linkable ? (" buffer" + index) : "");
-      buffer_declarations +=
-          buffer.declaration + name + "[[buffer(" + index + ")]],\n";
-      call_arguments += ", buffer" + index;
-      immutable_index++;
-      fused_descriptor->immutable_buffers.push_back(buffer);
-    }
-
-    for (const auto& buffer : desc.task->uniform_buffers) {
-      std::string index = std::to_string(uniform_index);
-      std::string name = (desc.task->is_linkable ? (" buffer" + index) : "");
-      buffer_declarations +=
-          buffer.declaration + name + "[[buffer(" + index + ")]],\n";
-      call_arguments += ", buffer" + index;
-      uniform_index++;
-      fused_descriptor->uniform_buffers.push_back({"", buffer.data_function});
-    }
-
-    if (desc.task->is_linkable) {
-      call_code +=
-          absl::Substitute("value = linkable$0(value, linear_index, gid$1);\n",
-                           function_index, call_arguments);
-      function_index++;
-    }
-  }
-  fused_descriptor->args = std::move(sequence.front().task->args);
-
-  auto& non_linkable = sequence.front();
-  fused_descriptor->shader_source =
-      absl::Substitute(non_linkable.task->shader_source, function_code + "$0",
-                       buffer_declarations + "$1", call_code);
-  fused_descriptor->AddDstTensor("", {});
-  fused_descriptor->src_tensors_names = non_linkable.task->src_tensors_names;
-  fused_descriptor->dst_tensors_names = non_linkable.task->dst_tensors_names;
-  fused_descriptor->tensors_as_args = non_linkable.task->tensors_as_args;
-  fused_descriptor->resize_function = non_linkable.task->resize_function;
-  node_desc.dst_tensors_ids = {fused_id};
-  node_desc.task = fused_descriptor;
-  // The id of fused descriptor is the id of the first descriptor in the list.
-  node_desc.id = chain.front().id;
-  for (const auto& desc : sequence) {
-    node_desc.description += desc.description + "_";
-  }
-
+  for (int j = 1; j < chain.size(); ++j) {
+    MergeNodes(&chain[j], &node_desc, "_link" + std::to_string(j));
+  }
   return node_desc;
 }
 }  // namespace
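
MergeNodes keeps the fused shader's three placeholders alive across merges: $0 accumulates linkable helper functions, $1 the argument declarations, and $2 the chained calls; substituting function_code + "$0" (and likewise for the other slots) appends to a slot while re-opening it for the next merge. A toy illustration with invented strings, not the real task sources (requires Abseil):

    #include <iostream>
    #include <string>
    #include "absl/strings/substitute.h"

    int main() {
      std::string fused = "$0 kernel(args$1) { FLT4 value = in[i]; $2 }";
      // Merge one linkable node under the link name "_link1".
      std::string function_code = "FLT4 linkable_link1(FLT4 v) { return v; }\n";
      std::string call_code = "value = linkable_link1(value);\n";
      // Append to $0 and $2, keep $1 untouched; all slots stay open for
      // the next MergeNodes call.
      fused = absl::Substitute(fused, function_code + "$0", "$1",
                               call_code + "$2");
      std::cout << fused << "\n";
    }
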
@@ -41,8 +41,8 @@ std::string GetAddTableCodeFused(int src_count) {
   code += ") {\n";
   for (int i = 0; i < src_count; ++i) {
     code += " value += src_buf" + std::to_string(i) + "[linear_index];\n";
-    code += " return value;\n";
   }
+  code += " return value;\n";
   code += "}\n";
   return code;
 }
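
With the return emitted after the loop instead of inside it, every source buffer now contributes before the function returns. For src_count == 2, and with the $0 placeholder later filled in by MergeNodes with the link name, the emitted helper looks roughly like this (the argument-list lines are generated outside this hunk, so their exact shape here is an assumption):

    FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid,
                    device FLT4* const src_buf0,
                    device FLT4* const src_buf1) {
     value += src_buf0[linear_index];
     value += src_buf1[linear_index];
     return value;
    }

Previously the return sat inside the loop, so every buffer after src_buf0 was dead code — harmless for the old two-input case (src_count == 1, a single loop iteration) but wrong for the three-input fusion exercised by the new test below.
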
@@ -53,10 +53,10 @@ ComputeTaskDescriptor Add(const OperationDef& definition) {
   desc.is_linkable = true;
   desc.shader_source = GetAddTableCodeFused(definition.src_tensors.size() - 1);

-  for (int i = 0; i < definition.src_tensors.size(); ++i) {
-    desc.AddSrcTensor("", definition.src_tensors[i]);
+  for (int i = 1; i < definition.src_tensors.size(); ++i) {
+    desc.AddSrcTensor("src_tensor_" + std::to_string(i),
+                      definition.src_tensors[i]);
   }
-  desc.AddDstTensor("", definition.dst_tensors[0]);

   return desc;
 }
@@ -46,27 +46,32 @@ using ::tflite::gpu::metal::SingleOpModel;
   [super setUp];
 }

-- (void)testTwoInputTensorsOfTheSameShape {
-  TensorRef<BHWC> augend, addend, output;
-  augend.type = DataType::FLOAT32;
-  augend.ref = 0;
-  augend.shape = BHWC(1, 2, 2, 1);
+- (void)testThreeInputTensorsOfTheSameShape {
+  TensorRef<BHWC> a, b, c, output;
+  a.type = DataType::FLOAT32;
+  a.ref = 0;
+  a.shape = BHWC(1, 2, 2, 1);

-  addend.type = DataType::FLOAT32;
-  addend.ref = 1;
-  addend.shape = BHWC(1, 2, 2, 1);
+  b.type = DataType::FLOAT32;
+  b.ref = 1;
+  b.shape = BHWC(1, 2, 2, 1);
+
+  c.type = DataType::FLOAT32;
+  c.ref = 2;
+  c.shape = BHWC(1, 2, 2, 1);

   output.type = DataType::FLOAT32;
-  output.ref = 2;
+  output.ref = 3;
   output.shape = BHWC(1, 2, 2, 1);

   ElementwiseAttributes attr;
-  SingleOpModel model({ToString(OperationType::ADD), std::move(attr)}, {augend, addend}, {output});
+  SingleOpModel model({ToString(OperationType::ADD), std::move(attr)}, {a, b, c}, {output});
   XCTAssertTrue(model.PopulateTensor(0, {-2.0, 0.2, 0.7, 0.8}));
   XCTAssertTrue(model.PopulateTensor(1, {0.1, 0.2, 0.3, 0.5}));
+  XCTAssertTrue(model.PopulateTensor(2, {2.1, 1.2, 3.3, 4.5}));
   auto status = model.Invoke();
   XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
-  status = CompareVectors({-1.9, 0.4, 1.0, 1.3}, model.GetOutput(0), 1e-6f);
+  status = CompareVectors({0.2, 1.6, 4.3, 5.8}, model.GetOutput(0), 1e-6f);
   XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
 }

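
The expected vector in the updated test is the elementwise sum of the three inputs: -2.0 + 0.1 + 2.1 = 0.2, 0.2 + 0.2 + 1.2 = 1.6, 0.7 + 0.3 + 3.3 = 4.3, and 0.8 + 0.5 + 4.5 = 5.8.
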
@@ -107,9 +107,7 @@ ComputeTaskDescriptor ElementwiseWithTwoInputs(const OperationDef& definition,

   desc.shader_source = code;

-  desc.AddSrcTensor("", definition.src_tensors[0]);
-  desc.AddSrcTensor("", definition.src_tensors[1]);
-  desc.AddDstTensor("", definition.dst_tensors[0]);
+  desc.AddSrcTensor("second_tensor", definition.src_tensors[1]);

   desc.uniform_buffers = {
       {"constant int2&",
@@ -134,9 +132,6 @@ ComputeTaskDescriptor ElementwiseWithOneInput(const OperationDef& definition,
   desc.shader_source +=
       " return " + OneInputFunctor(op_type, "value") + ";\n";
   desc.shader_source += " }";
-
-  desc.AddSrcTensor("", definition.src_tensors[0]);
-  desc.AddDstTensor("", definition.dst_tensors[0]);
   return desc;
 }

@@ -182,8 +177,6 @@ ComputeTaskDescriptor ElementwiseWithOneInputAndConstantArguent(
       " return " + TwoInputFunctor(op_type, "value", "second_arg") + ";\n";
   desc.shader_source += " }";

-  desc.AddSrcTensor("", definition.src_tensors[0]);
-  desc.AddDstTensor("", definition.dst_tensors[0]);
   auto data_type = DeduceDataTypeFromPrecision(definition.precision);
   if (scalar) {
     std::vector<uint8_t> scalar_bits =
@@ -55,8 +55,6 @@ ComputeTaskDescriptor PReLU(const OperationDef& definition,
     return FLT4(max(FLT4(0.0f), value) + alphas[gid.z] * min(FLT4(0.0f), value));
 })";
   }
-  desc.AddSrcTensor("", definition.src_tensors[0]);
-  desc.AddDstTensor("", definition.dst_tensors[0]);
   auto data_type = DeduceDataTypeFromPrecision(definition.precision);
   desc.immutable_buffers = {
       {"device FLT4* const",
@@ -97,8 +95,6 @@ ComputeTaskDescriptor PReLUFull(const OperationDef& definition,
     return FLT4(max(FLT4(0.0f), value) + alphas[linear_index] * min(FLT4(0.0f), value));
 })";
   }
-  desc.AddSrcTensor("", definition.src_tensors[0]);
-  desc.AddDstTensor("", definition.dst_tensors[0]);
   auto data_type = DeduceDataTypeFromPrecision(definition.precision);
   desc.immutable_buffers = {
       {"device FLT4* const",
@@ -35,9 +35,6 @@ ComputeTaskDescriptor QuantizeAndDequantize(
     return round(value) * FLT4(params.z) + FLT4(params.x);
   }
 )";
-
-  desc.AddSrcTensor("", definition.src_tensors[0]);
-  desc.AddDstTensor("", definition.dst_tensors[0]);
   desc.uniform_buffers = {
       {"constant float3&",
       [attr](const std::vector<BHWC>& src_shapes,
@@ -46,8 +46,6 @@ ComputeTaskDescriptor ReLU(const OperationDef& definition,
     desc.shader_source =
         parameters + " return FLT4(max(value, " + min_func + "));\n}";
   }
-  desc.AddSrcTensor("", definition.src_tensors[0]);
-  desc.AddDstTensor("", definition.dst_tensors[0]);
   desc.uniform_buffers = {
       {"constant float2&",
       [attr](const std::vector<BHWC>& src_shapes,
@@ -235,32 +235,6 @@ absl::Status MetalExecutionEnvironment::ExecuteGPUOperation(
   metal_node.description = "test_op";
   metal_node.id = 0;

-  std::string buffer_declarations;
-  int index = 0;
-  for (int i = 0; i < metal_node.task->dst_tensors_names.size(); ++i) {
-    buffer_declarations += metal_node.task->dst_tensors_names[i] + "[[buffer(" +
-                           std::to_string(index) + ")]],\n";
-    index++;
-  }
-  for (int i = 0; i < metal_node.task->src_tensors_names.size(); ++i) {
-    buffer_declarations += metal_node.task->src_tensors_names[i] + "[[buffer(" +
-                           std::to_string(index) + ")]],\n";
-    index++;
-  }
-  for (const auto& buffer : metal_node.task->immutable_buffers) {
-    buffer_declarations +=
-        buffer.declaration + "[[buffer(" + std::to_string(index) + ")]],\n";
-    index++;
-  }
-  for (const auto& buffer : metal_node.task->uniform_buffers) {
-    buffer_declarations +=
-        buffer.declaration + "[[buffer(" + std::to_string(index) + ")]],\n";
-    index++;
-  }
-
-  metal_node.task->shader_source = absl::Substitute(
-      metal_node.task->shader_source, "$0", buffer_declarations + "$1", "");
-
   ComputeTask gpu_task;
   RETURN_IF_ERROR(
       gpu_task.CompileWithDevice(device_, metal_node, op_def.precision));