Elementwise ops in Metal converted to the new style (using arguments and tensors).

PiperOrigin-RevId: 350402612
Change-Id: If770f4edff502c39886bf3c285d32f165cee69b3
This commit is contained in:
Raman Sarokin 2021-01-06 12:10:50 -08:00 committed by TensorFlower Gardener
parent bcd5dd0148
commit 7d841e13c4
10 changed files with 128 additions and 236 deletions

View File

@ -188,6 +188,7 @@ objc_library(
deps = [
":buffer",
":gpu_object",
":metal_spatial_tensor",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common:util",
"//tensorflow/lite/delegates/gpu/common/task:arguments",

View File

@ -256,8 +256,7 @@ std::vector<ValueId> ComputeTask::GetInputIds() const {
void ComputeTask::SetSrcTensor(const MetalSpatialTensor& tensor, int index) {
input_buffers_[index].metal_handle = tensor.GetBufferHandle();
if (tensors_as_args_ &&
absl::StrContains(src_tensors_names_[index], "_buffer")) {
if (absl::StrContains(src_tensors_names_[index], "_buffer")) {
auto name = src_tensors_names_[index];
// extracting tensor_name from "tensor_name_buffer";
name = name.substr(0, name.size() - 7);
@ -267,7 +266,7 @@ void ComputeTask::SetSrcTensor(const MetalSpatialTensor& tensor, int index) {
void ComputeTask::SetDstTensor(const MetalSpatialTensor& tensor, int index) {
output_buffers_[index].metal_handle = tensor.GetBufferHandle();
if (tensors_as_args_) {
if (absl::StrContains(dst_tensors_names_[index], "_buffer")) {
auto name = dst_tensors_names_[index];
// extracting tensor_name from "tensor_name_buffer";
name = name.substr(0, name.size() - 7);

View File

@ -426,38 +426,29 @@ NodeDescriptor NonLinkableStub(int operation_id, ValueId input_id,
ValueId output_id,
const OperationDef& definition) {
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->is_linkable = false;
desc->tensors_as_args = true;
desc->shader_source = R"(
#include <metal_stdlib>
using namespace metal;
$0
kernel void ComputeFunction($1
uint3 gid[[thread_position_in_grid]]) {
if (int(gid.x) >= size.x || int(gid.y) >= size.y) {
if (int(gid.x) >= args.dst_tensor.Width() || int(gid.y) >= args.dst_tensor.Height()) {
return;
}
const int linear_index = (gid.z * size.y + gid.y) * size.x + gid.x;
FLT4 value = src_tensor[linear_index];
FLT4 value = args.src_tensor.Read(gid.x, gid.y, gid.z);
args.dst_tensor.GetAddress(linear_index, gid.x, gid.y, gid.z);
$2
dst_tensor[linear_index] = value;
args.dst_tensor.Write(value, gid.x, gid.y, gid.z);
}
)";
desc->AddSrcTensor("src_tensor", definition.src_tensors[0]);
desc->AddDstTensor("dst_tensor", definition.dst_tensors[0]);
desc->uniform_buffers = {
{"constant int2& size",
[](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) {
return GetByteBuffer(
std::vector<int>{src_shapes[0].w, src_shapes[0].h});
}},
};
desc->resize_function = [](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) {
uint3 groups_size{16, 16, 1};
uint3 groups_size{8, 8, 1};
uint3 groups_count{DivideRoundUp(dst_shapes[0].w, groups_size.x),
DivideRoundUp(dst_shapes[0].h, groups_size.y),
DivideRoundUp(dst_shapes[0].c, 4)};
@ -472,60 +463,47 @@ NodeDescriptor NonLinkableStub(int operation_id, ValueId input_id,
return node_desc;
}
void MergeNodes(const NodeDescriptor* src, NodeDescriptor* dst,
std::string link_name) {
std::string call_arguments;
absl::Status MergeNodes(const NodeDescriptor* src, NodeDescriptor* dst,
std::string link_name) {
dst->dst_tensors_ids[0] = src->dst_tensors_ids[0];
dst->description += " linked : " + src->description;
for (int i = 0; i < src->task->src_tensors_names.size(); ++i) {
std::string tensor_name = src->task->src_tensors_names[i] + link_name;
call_arguments += ", " + tensor_name;
dst->task->src_tensors_names.push_back(tensor_name);
// dst->task->AddSrcTensor(tensor_name,
// src->task->definition.src_tensors[i + 1]);
std::string tensor_name = src->task->src_tensors_names[i];
dst->task->src_tensors_names.push_back(tensor_name + link_name + "_buffer");
auto desc_new = absl::make_unique<TensorDescriptor>(
src->task->definition.src_tensors[i + 1]);
src->task->args.AddObjectRef(tensor_name, AccessType::READ,
std::move(desc_new));
dst->task->definition.src_tensors.push_back(
src->task->definition.src_tensors[i + 1]);
dst->src_tensors_ids.push_back(src->src_tensors_ids[i + 1]);
}
for (int i = 0; i < src->task->immutable_buffers.size(); ++i) {
auto buffer = src->task->immutable_buffers[i];
const std::string buffer_name = "ibuffer" + std::to_string(i) + link_name;
buffer.declaration += " " + buffer_name;
call_arguments += ", " + buffer_name;
dst->task->immutable_buffers.push_back(buffer);
}
std::string code = src->task->shader_source;
src->task->args.RenameArgs(link_name, &code);
for (int i = 0; i < src->task->uniform_buffers.size(); ++i) {
auto buffer = src->task->uniform_buffers[i];
const std::string buffer_name = "ubuffer" + std::to_string(i) + link_name;
buffer.declaration += " " + buffer_name;
call_arguments += ", " + buffer_name;
dst->task->uniform_buffers.push_back(buffer);
}
RETURN_IF_ERROR(dst->task->args.Merge(std::move(src->task->args), link_name));
std::string function_code =
absl::Substitute(src->task->shader_source, link_name) + "\n";
std::string call_code =
absl::Substitute("value = linkable$0(value, linear_index, gid$1);\n",
link_name, call_arguments);
dst->task->shader_source = absl::Substitute(dst->task->shader_source, "$0",
"$1", "{\n" + code + "\n}\n$2");
dst->task->shader_source = absl::Substitute(
dst->task->shader_source, function_code + "$0", "$1", call_code + "$2");
return absl::OkStatus();
}
NodeDescriptor FuseChain(const FusionSequence& chain) {
NodeDescriptor node_desc;
absl::Status FuseChain(const FusionSequence& chain, NodeDescriptor* node_desc) {
if (chain.front().task->is_linkable) {
node_desc = NonLinkableStub(
*node_desc = NonLinkableStub(
chain.front().id, chain.front().src_tensors_ids[0],
chain.front().dst_tensors_ids[0], chain.front().task->definition);
MergeNodes(&chain.front(), &node_desc, "_link0");
RETURN_IF_ERROR(MergeNodes(&chain.front(), node_desc, "_link0"));
} else {
node_desc = chain.front();
*node_desc = chain.front();
}
for (int j = 1; j < chain.size(); ++j) {
MergeNodes(&chain[j], &node_desc, "_link" + std::to_string(j));
RETURN_IF_ERROR(
MergeNodes(&chain[j], node_desc, "_link" + std::to_string(j)));
}
return node_desc;
return absl::OkStatus();
}
} // namespace
@ -697,8 +675,11 @@ absl::Status InferenceContext::ValidateOptimizeModel(
std::to_string(info.missing_output_buffer_ids.size());
return absl::InternalError(message);
}
for (const auto& chain : sorted_chains)
output_model->nodes.push_back(FuseChain(chain));
for (const auto& chain : sorted_chains) {
NodeDescriptor fused_node;
RETURN_IF_ERROR(FuseChain(chain, &fused_node));
output_model->nodes.push_back(std::move(fused_node));
}
output_model->tensor_shapes = input_model.tensor_shapes;
return absl::OkStatus();
}

View File

@ -876,6 +876,7 @@ objc_library(
"padding_test.mm",
"pooling_test.mm",
"prelu_test.mm",
"quantize_and_dequantize_test.mm",
"relu_test.mm",
"reshape_test.mm",
"resize_test.mm",
@ -897,6 +898,7 @@ objc_library(
"//tensorflow/lite/delegates/gpu/common:util",
"//tensorflow/lite/delegates/gpu/metal:common",
"//tensorflow/lite/delegates/gpu/metal:inference_context",
"//tensorflow/lite/kernels/internal:quantization_util",
],
)

View File

@ -31,31 +31,16 @@ limitations under the License.
namespace tflite {
namespace gpu {
namespace metal {
namespace {
std::string GetAddTableCodeFused(int src_count) {
std::string code = "FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid";
for (int i = 0; i < src_count; ++i) {
code += ", device FLT4* const src_buf" + std::to_string(i);
}
code += ") {\n";
for (int i = 0; i < src_count; ++i) {
code += " value += src_buf" + std::to_string(i) + "[linear_index];\n";
}
code += " return value;\n";
code += "}\n";
return code;
}
} // namespace
ComputeTaskDescriptor Add(const OperationDef& definition) {
ComputeTaskDescriptor desc(definition);
desc.is_linkable = true;
desc.shader_source = GetAddTableCodeFused(definition.src_tensors.size() - 1);
for (int i = 1; i < definition.src_tensors.size(); ++i) {
desc.AddSrcTensor("src_tensor_" + std::to_string(i),
definition.src_tensors[i]);
const std::string tensor_name = "src_tensor_" + std::to_string(i);
desc.AddSrcTensor(tensor_name, definition.src_tensors[i]);
desc.shader_source +=
" value += args." + tensor_name + ".Read(gid.x, gid.y, gid.z);\n";
}
return desc;

View File

@ -88,38 +88,22 @@ ComputeTaskDescriptor ElementwiseWithTwoInputs(const OperationDef& definition,
OperationType op_type) {
ComputeTaskDescriptor desc(definition);
desc.is_linkable = true;
const std::string x_coord = second_shape.w == 1 ? "0" : "int(gid.x)";
const std::string y_coord = second_shape.h == 1 ? "0" : "int(gid.y)";
const std::string s_coord = second_shape.c == 1 ? "0" : "int(gid.z)";
std::string code =
"FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid, device FLT4* "
"const second_tensor, int2 second_size) {\n";
code += " int second_index = (" + s_coord + " * second_size.y + " + y_coord +
") * second_size.x + " + x_coord + ";\n";
code += " FLT4 src_1 = second_tensor[second_index];\n";
const std::string x_coord = second_shape.w == 1 ? "0" : "gid.x";
const std::string y_coord = second_shape.h == 1 ? "0" : "gid.y";
const std::string s_coord = second_shape.c == 1 ? "0" : "gid.z";
std::string code;
code = " FLT4 src_1 = args.second_tensor.Read(" + x_coord + ", " + y_coord +
", " + s_coord + ");\n";
if (second_shape.c == 1) {
code += " src_1.y = src_1.x;\n";
code += " src_1.z = src_1.x;\n";
code += " src_1.w = src_1.x;\n";
}
code += " return " + TwoInputFunctor(op_type, "value", "src_1") + ";\n";
code += "}\n";
code += " value = " + TwoInputFunctor(op_type, "value", "src_1") + ";\n";
desc.shader_source = code;
desc.AddSrcTensor("second_tensor", definition.src_tensors[1]);
desc.uniform_buffers = {
{"constant int2&",
[](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) {
std::vector<int> uniform_params{
src_shapes[1].w,
src_shapes[1].h,
};
return GetByteBuffer(uniform_params);
}},
};
return desc;
}
@ -128,10 +112,7 @@ ComputeTaskDescriptor ElementwiseWithOneInput(const OperationDef& definition,
ComputeTaskDescriptor desc(definition);
desc.is_linkable = true;
desc.shader_source =
"FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid) {\n";
desc.shader_source +=
" return " + OneInputFunctor(op_type, "value") + ";\n";
desc.shader_source += " }";
" value = " + OneInputFunctor(op_type, "value") + ";\n";
return desc;
}
@ -141,32 +122,35 @@ ComputeTaskDescriptor ElementwiseWithOneInputAndConstantArguent(
auto scalar = absl::get_if<float>(&attr);
auto linear_buf = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr);
auto hwc_buf = absl::get_if<Tensor<HWC, DataType::FLOAT32>>(&attr);
std::string param_desc;
if (scalar) {
param_desc += ", float scalar_val";
}
if (linear_buf) {
param_desc += ", device FLT4* const linear_buf";
}
if (hwc_buf) {
param_desc += ", device FLT4* const hwc_buf, int2 hwc_size";
}
ComputeTaskDescriptor desc(definition);
desc.is_linkable = true;
desc.shader_source =
"FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid" + param_desc +
") {\n";
if (scalar) {
desc.shader_source += " FLT4 second_arg = FLT4(scalar_val);\n";
desc.args.AddFloat("scalar_val", *scalar);
desc.shader_source += " FLT4 second_arg = FLT4(args.scalar_val);\n";
} else if (linear_buf) {
desc.shader_source += " FLT4 second_arg = linear_buf[gid.z];\n";
auto data_type = DeduceDataTypeFromPrecision(definition.precision);
const int dst_channels_aligned = AlignByN(linear_buf->shape.v, 4);
BufferDescriptor linear_desc;
linear_desc.element_type = data_type;
linear_desc.element_size = 4;
linear_desc.data = GetByteBufferConvertedResized(
linear_buf->data, data_type, dst_channels_aligned);
linear_desc.size = linear_desc.data.size();
desc.args.AddObject(
"linear", absl::make_unique<BufferDescriptor>(std::move(linear_desc)));
desc.shader_source += " FLT4 second_arg = args.linear.Read(gid.z);\n";
} else if (hwc_buf) {
const std::string x_coord = hwc_buf->shape.w == 1 ? "0" : "int(gid.x)";
const std::string y_coord = hwc_buf->shape.h == 1 ? "0" : "int(gid.y)";
const std::string s_coord = hwc_buf->shape.c == 1 ? "0" : "int(gid.z)";
std::string index = "(" + s_coord + " * hwc_size.y + " + y_coord +
") * hwc_size.x + " + x_coord;
desc.shader_source += " FLT4 second_arg = hwc_buf[" + index + "];\n";
TensorDescriptor hwc_desc{definition.GetDataType(),
TensorStorageType::BUFFER, Layout::HWC};
hwc_desc.UploadData(*hwc_buf);
desc.args.AddObject(
"hwc", absl::make_unique<TensorDescriptor>(std::move(hwc_desc)));
const std::string x_coord = hwc_buf->shape.w == 1 ? "0" : "gid.x";
const std::string y_coord = hwc_buf->shape.h == 1 ? "0" : "gid.y";
const std::string s_coord = hwc_buf->shape.c == 1 ? "0" : "gid.z";
desc.shader_source += " FLT4 second_arg = args.hwc.Read(" + x_coord +
", " + y_coord + ", " + s_coord + ");\n";
if (hwc_buf->shape.c == 1) {
desc.shader_source += " second_arg.y = second_arg.x;\n";
desc.shader_source += " second_arg.z = second_arg.x;\n";
@ -174,40 +158,7 @@ ComputeTaskDescriptor ElementwiseWithOneInputAndConstantArguent(
}
}
desc.shader_source +=
" return " + TwoInputFunctor(op_type, "value", "second_arg") + ";\n";
desc.shader_source += " }";
auto data_type = DeduceDataTypeFromPrecision(definition.precision);
if (scalar) {
std::vector<uint8_t> scalar_bits =
GetByteBuffer(std::vector<float>{*scalar});
desc.uniform_buffers = {
{"constant float&",
[scalar_bits](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) {
return scalar_bits;
}},
};
} else if (linear_buf) {
desc.immutable_buffers = {
{"device FLT4* const",
GetByteBufferConverted(linear_buf->data, data_type)},
};
} else if (hwc_buf) {
std::vector<uint8_t> size_bits =
GetByteBuffer(std::vector<int>{hwc_buf->shape.w, hwc_buf->shape.h});
desc.uniform_buffers = {
{"constant int2&",
[size_bits](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) {
return size_bits;
}},
};
desc.immutable_buffers = {
{"device FLT4* const",
GetByteBufferConverted(ConvertToPHWC4(*hwc_buf), data_type)},
};
}
" value = " + TwoInputFunctor(op_type, "value", "second_arg") + ";\n";
return desc;
}

View File

@ -43,34 +43,27 @@ ComputeTaskDescriptor PReLU(const OperationDef& definition,
ComputeTaskDescriptor desc(definition);
desc.is_linkable = true;
if (attr.clip != 0) {
desc.args.AddFloat("clip", attr.clip);
desc.shader_source =
R"(FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid,
device FLT4* const alphas, float clip) {
return FLT4(clamp(value, FLT4(0.0f), FLT4(clip)) + alphas[gid.z] * min(FLT4(0.0f), value));
})";
R"(
value = FLT4(clamp(value, FLT4(0.0f), FLT4(args.clip)) + args.alpha.Read(gid.z) * min(FLT4(0.0f), value));
)";
} else {
desc.shader_source =
R"(FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid,
device FLT4* const alphas) {
return FLT4(max(FLT4(0.0f), value) + alphas[gid.z] * min(FLT4(0.0f), value));
})";
R"(
value = FLT4(max(FLT4(0.0f), value) + args.alpha.Read(gid.z) * min(FLT4(0.0f), value));
)";
}
auto data_type = DeduceDataTypeFromPrecision(definition.precision);
desc.immutable_buffers = {
{"device FLT4* const",
GetByteBufferConverted(alpha_buffer->data, data_type)},
};
if (attr.clip != 0) {
desc.uniform_buffers = {
{"constant float&",
[attr](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) {
std::vector<uint8_t> attr_clip =
GetByteBuffer(std::vector<float>{attr.clip});
return attr_clip;
}},
};
}
const int dst_channels_aligned = AlignByN(alpha_buffer->shape.v, 4);
BufferDescriptor alpha_desc;
alpha_desc.element_type = data_type;
alpha_desc.element_size = 4;
alpha_desc.data = GetByteBufferConvertedResized(alpha_buffer->data, data_type,
dst_channels_aligned);
alpha_desc.size = alpha_desc.data.size();
desc.args.AddObject(
"alpha", absl::make_unique<BufferDescriptor>(std::move(alpha_desc)));
return desc;
}
@ -83,34 +76,22 @@ ComputeTaskDescriptor PReLUFull(const OperationDef& definition,
ComputeTaskDescriptor desc(definition);
desc.is_linkable = true;
if (attr.clip != 0) {
desc.args.AddFloat("clip", attr.clip);
desc.shader_source =
R"(FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid,
device FLT4* const alphas, float clip) {
return FLT4(clamp(value, FLT4(0.0f), FLT4(clip)) + alphas[linear_index] * min(FLT4(0.0f), value));
})";
R"(
value = FLT4(clamp(value, FLT4(0.0f), FLT4(args.clip)) + args.alpha.Read(gid.x, gid.y, gid.z) * min(FLT4(0.0f), value));
)";
} else {
desc.shader_source =
R"(FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid,
device FLT4* const alphas) {
return FLT4(max(FLT4(0.0f), value) + alphas[linear_index] * min(FLT4(0.0f), value));
})";
}
auto data_type = DeduceDataTypeFromPrecision(definition.precision);
desc.immutable_buffers = {
{"device FLT4* const",
GetByteBufferConverted(ConvertToPHWC4(*alpha), data_type)},
};
if (attr.clip != 0) {
desc.uniform_buffers = {
{"constant float&",
[attr](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) {
std::vector<uint8_t> attr_clip =
GetByteBuffer(std::vector<float>{attr.clip});
return attr_clip;
}},
};
R"(
value = FLT4(max(FLT4(0.0f), value) + args.alpha.Read(gid.x, gid.y, gid.z) * min(FLT4(0.0f), value));
)";
}
TensorDescriptor alpha_desc{definition.GetDataType(),
TensorStorageType::BUFFER, Layout::HWC};
alpha_desc.UploadData(*alpha);
desc.args.AddObject(
"alpha", absl::make_unique<TensorDescriptor>(std::move(alpha_desc)));
return desc;
}

View File

@ -29,20 +29,12 @@ ComputeTaskDescriptor QuantizeAndDequantize(
ComputeTaskDescriptor desc(definition);
desc.is_linkable = true;
desc.shader_source = R"(
FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid, float3 params) {
value = clamp(value, FLT4(params.x), FLT4(params.y));
value = (value - FLT4(params.x)) / FLT4(params.z);
return round(value) * FLT4(params.z) + FLT4(params.x);
}
)";
desc.uniform_buffers = {
{"constant float3&",
[attr](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) {
return GetByteBuffer(
std::vector<float>{attr.min, attr.max, attr.scale});
}},
};
value = clamp(value, FLT4(args.qmin), FLT4(args.qmax));
value = (value - FLT4(args.qmin)) / FLT4(args.qscale);
value = round(value) * FLT4(args.qscale) + FLT4(args.qmin);)";
desc.args.AddFloat("qmax", attr.max);
desc.args.AddFloat("qmin", attr.min);
desc.args.AddFloat("qscale", attr.scale);
return desc;
}

View File

@ -35,24 +35,15 @@ ComputeTaskDescriptor ReLU(const OperationDef& definition,
ComputeTaskDescriptor desc(definition);
desc.is_linkable = true;
const std::string min_func =
attr.alpha == 0 ? "FLT4(0.0f)" : "min(value * params.x, 0.0f)";
const std::string parameters =
"FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid, float2 params) "
"{\n";
attr.alpha == 0 ? "FLT4(0.0f)" : "min(value * args.alpha, 0.0f)";
if (attr.clip != 0.0) {
desc.shader_source = parameters + " return FLT4(clamp(value, " + min_func +
", FLT4(params.y)));\n}";
} else {
desc.shader_source =
parameters + " return FLT4(max(value, " + min_func + "));\n}";
"value = FLT4(clamp(value, " + min_func + ", FLT4(args.clip)));";
} else {
desc.shader_source = "value = FLT4(max(value, " + min_func + "));";
}
desc.uniform_buffers = {
{"constant float2&",
[attr](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) {
return GetByteBuffer(std::vector<float>{attr.alpha, attr.clip});
}},
};
desc.args.AddFloat("alpha", attr.alpha);
desc.args.AddFloat("clip", attr.clip);
return desc;
}

View File

@ -17,9 +17,10 @@ limitations under the License.
#include <string>
#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/metal/buffer.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/common/task/util.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/buffer.h"
#include "tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h"
namespace tflite {
namespace gpu {
@ -128,6 +129,14 @@ absl::Status CreateMetalObject(id<MTLDevice> device, GPUObjectDescriptor* desc,
return absl::OkStatus();
}
const auto* tensor_desc = dynamic_cast<const TensorDescriptor*>(desc);
if (tensor_desc) {
MetalSpatialTensor gpu_tensor;
RETURN_IF_ERROR(gpu_tensor.CreateFromDescriptor(*tensor_desc, device));
*result = absl::make_unique<MetalSpatialTensor>(std::move(gpu_tensor));
return absl::OkStatus();
}
return absl::InvalidArgumentError("Unknown GPU descriptor.");
}
} // namespace
@ -140,10 +149,10 @@ absl::Status MetalArguments::Init(id<MTLDevice> device, int buffer_offset,
RETURN_IF_ERROR(AllocateObjects(*args, device));
RETURN_IF_ERROR(AddObjectArgs(args));
RETURN_IF_ERROR(ResolveSelectorsPass(*args, {}, code));
RETURN_IF_ERROR(SetObjectsResources(*args));
object_refs_ = std::move(args->object_refs_);
args->GetActiveArguments(kArgsPrefix, *code);
std::string struct_desc = ScalarArgumentsToStructWithVec4Fields(args, code);
RETURN_IF_ERROR(SetObjectsResources(*args));
ResolveArgsPass(code);
*code = absl::Substitute(*code, struct_desc, GetListOfArgs(buffer_offset));
return absl::OkStatus();