Elementwise ops in Metal converted to the new style (using arguments and tensors).

PiperOrigin-RevId: 350402612
Change-Id: If770f4edff502c39886bf3c285d32f165cee69b3
This commit is contained in:
Raman Sarokin 2021-01-06 12:10:50 -08:00 committed by TensorFlower Gardener
parent bcd5dd0148
commit 7d841e13c4
10 changed files with 128 additions and 236 deletions

View File

@ -188,6 +188,7 @@ objc_library(
deps = [ deps = [
":buffer", ":buffer",
":gpu_object", ":gpu_object",
":metal_spatial_tensor",
"//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common:util", "//tensorflow/lite/delegates/gpu/common:util",
"//tensorflow/lite/delegates/gpu/common/task:arguments", "//tensorflow/lite/delegates/gpu/common/task:arguments",

View File

@ -256,8 +256,7 @@ std::vector<ValueId> ComputeTask::GetInputIds() const {
void ComputeTask::SetSrcTensor(const MetalSpatialTensor& tensor, int index) { void ComputeTask::SetSrcTensor(const MetalSpatialTensor& tensor, int index) {
input_buffers_[index].metal_handle = tensor.GetBufferHandle(); input_buffers_[index].metal_handle = tensor.GetBufferHandle();
if (tensors_as_args_ && if (absl::StrContains(src_tensors_names_[index], "_buffer")) {
absl::StrContains(src_tensors_names_[index], "_buffer")) {
auto name = src_tensors_names_[index]; auto name = src_tensors_names_[index];
// extracting tensor_name from "tensor_name_buffer"; // extracting tensor_name from "tensor_name_buffer";
name = name.substr(0, name.size() - 7); name = name.substr(0, name.size() - 7);
@ -267,7 +266,7 @@ void ComputeTask::SetSrcTensor(const MetalSpatialTensor& tensor, int index) {
void ComputeTask::SetDstTensor(const MetalSpatialTensor& tensor, int index) { void ComputeTask::SetDstTensor(const MetalSpatialTensor& tensor, int index) {
output_buffers_[index].metal_handle = tensor.GetBufferHandle(); output_buffers_[index].metal_handle = tensor.GetBufferHandle();
if (tensors_as_args_) { if (absl::StrContains(dst_tensors_names_[index], "_buffer")) {
auto name = dst_tensors_names_[index]; auto name = dst_tensors_names_[index];
// extracting tensor_name from "tensor_name_buffer"; // extracting tensor_name from "tensor_name_buffer";
name = name.substr(0, name.size() - 7); name = name.substr(0, name.size() - 7);

View File

@ -426,38 +426,29 @@ NodeDescriptor NonLinkableStub(int operation_id, ValueId input_id,
ValueId output_id, ValueId output_id,
const OperationDef& definition) { const OperationDef& definition) {
auto desc = std::make_shared<ComputeTaskDescriptor>(); auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->is_linkable = false; desc->tensors_as_args = true;
desc->shader_source = R"( desc->shader_source = R"(
#include <metal_stdlib> #include <metal_stdlib>
using namespace metal; using namespace metal;
$0 $0
kernel void ComputeFunction($1 kernel void ComputeFunction($1
uint3 gid[[thread_position_in_grid]]) { uint3 gid[[thread_position_in_grid]]) {
if (int(gid.x) >= size.x || int(gid.y) >= size.y) { if (int(gid.x) >= args.dst_tensor.Width() || int(gid.y) >= args.dst_tensor.Height()) {
return; return;
} }
const int linear_index = (gid.z * size.y + gid.y) * size.x + gid.x; FLT4 value = args.src_tensor.Read(gid.x, gid.y, gid.z);
FLT4 value = src_tensor[linear_index]; args.dst_tensor.GetAddress(linear_index, gid.x, gid.y, gid.z);
$2 $2
dst_tensor[linear_index] = value; args.dst_tensor.Write(value, gid.x, gid.y, gid.z);
} }
)"; )";
desc->AddSrcTensor("src_tensor", definition.src_tensors[0]); desc->AddSrcTensor("src_tensor", definition.src_tensors[0]);
desc->AddDstTensor("dst_tensor", definition.dst_tensors[0]); desc->AddDstTensor("dst_tensor", definition.dst_tensors[0]);
desc->uniform_buffers = {
{"constant int2& size",
[](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) {
return GetByteBuffer(
std::vector<int>{src_shapes[0].w, src_shapes[0].h});
}},
};
desc->resize_function = [](const std::vector<BHWC>& src_shapes, desc->resize_function = [](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) { const std::vector<BHWC>& dst_shapes) {
uint3 groups_size{16, 16, 1}; uint3 groups_size{8, 8, 1};
uint3 groups_count{DivideRoundUp(dst_shapes[0].w, groups_size.x), uint3 groups_count{DivideRoundUp(dst_shapes[0].w, groups_size.x),
DivideRoundUp(dst_shapes[0].h, groups_size.y), DivideRoundUp(dst_shapes[0].h, groups_size.y),
DivideRoundUp(dst_shapes[0].c, 4)}; DivideRoundUp(dst_shapes[0].c, 4)};
@ -472,60 +463,47 @@ NodeDescriptor NonLinkableStub(int operation_id, ValueId input_id,
return node_desc; return node_desc;
} }
void MergeNodes(const NodeDescriptor* src, NodeDescriptor* dst, absl::Status MergeNodes(const NodeDescriptor* src, NodeDescriptor* dst,
std::string link_name) { std::string link_name) {
std::string call_arguments;
dst->dst_tensors_ids[0] = src->dst_tensors_ids[0]; dst->dst_tensors_ids[0] = src->dst_tensors_ids[0];
dst->description += " linked : " + src->description; dst->description += " linked : " + src->description;
for (int i = 0; i < src->task->src_tensors_names.size(); ++i) { for (int i = 0; i < src->task->src_tensors_names.size(); ++i) {
std::string tensor_name = src->task->src_tensors_names[i] + link_name; std::string tensor_name = src->task->src_tensors_names[i];
call_arguments += ", " + tensor_name; dst->task->src_tensors_names.push_back(tensor_name + link_name + "_buffer");
dst->task->src_tensors_names.push_back(tensor_name); auto desc_new = absl::make_unique<TensorDescriptor>(
// dst->task->AddSrcTensor(tensor_name, src->task->definition.src_tensors[i + 1]);
// src->task->definition.src_tensors[i + 1]); src->task->args.AddObjectRef(tensor_name, AccessType::READ,
std::move(desc_new));
dst->task->definition.src_tensors.push_back(
src->task->definition.src_tensors[i + 1]);
dst->src_tensors_ids.push_back(src->src_tensors_ids[i + 1]); dst->src_tensors_ids.push_back(src->src_tensors_ids[i + 1]);
} }
for (int i = 0; i < src->task->immutable_buffers.size(); ++i) { std::string code = src->task->shader_source;
auto buffer = src->task->immutable_buffers[i]; src->task->args.RenameArgs(link_name, &code);
const std::string buffer_name = "ibuffer" + std::to_string(i) + link_name;
buffer.declaration += " " + buffer_name; RETURN_IF_ERROR(dst->task->args.Merge(std::move(src->task->args), link_name));
call_arguments += ", " + buffer_name;
dst->task->immutable_buffers.push_back(buffer); dst->task->shader_source = absl::Substitute(dst->task->shader_source, "$0",
"$1", "{\n" + code + "\n}\n$2");
return absl::OkStatus();
} }
for (int i = 0; i < src->task->uniform_buffers.size(); ++i) { absl::Status FuseChain(const FusionSequence& chain, NodeDescriptor* node_desc) {
auto buffer = src->task->uniform_buffers[i];
const std::string buffer_name = "ubuffer" + std::to_string(i) + link_name;
buffer.declaration += " " + buffer_name;
call_arguments += ", " + buffer_name;
dst->task->uniform_buffers.push_back(buffer);
}
std::string function_code =
absl::Substitute(src->task->shader_source, link_name) + "\n";
std::string call_code =
absl::Substitute("value = linkable$0(value, linear_index, gid$1);\n",
link_name, call_arguments);
dst->task->shader_source = absl::Substitute(
dst->task->shader_source, function_code + "$0", "$1", call_code + "$2");
}
NodeDescriptor FuseChain(const FusionSequence& chain) {
NodeDescriptor node_desc;
if (chain.front().task->is_linkable) { if (chain.front().task->is_linkable) {
node_desc = NonLinkableStub( *node_desc = NonLinkableStub(
chain.front().id, chain.front().src_tensors_ids[0], chain.front().id, chain.front().src_tensors_ids[0],
chain.front().dst_tensors_ids[0], chain.front().task->definition); chain.front().dst_tensors_ids[0], chain.front().task->definition);
MergeNodes(&chain.front(), &node_desc, "_link0"); RETURN_IF_ERROR(MergeNodes(&chain.front(), node_desc, "_link0"));
} else { } else {
node_desc = chain.front(); *node_desc = chain.front();
} }
for (int j = 1; j < chain.size(); ++j) { for (int j = 1; j < chain.size(); ++j) {
MergeNodes(&chain[j], &node_desc, "_link" + std::to_string(j)); RETURN_IF_ERROR(
MergeNodes(&chain[j], node_desc, "_link" + std::to_string(j)));
} }
return node_desc; return absl::OkStatus();
} }
} // namespace } // namespace
@ -697,8 +675,11 @@ absl::Status InferenceContext::ValidateOptimizeModel(
std::to_string(info.missing_output_buffer_ids.size()); std::to_string(info.missing_output_buffer_ids.size());
return absl::InternalError(message); return absl::InternalError(message);
} }
for (const auto& chain : sorted_chains) for (const auto& chain : sorted_chains) {
output_model->nodes.push_back(FuseChain(chain)); NodeDescriptor fused_node;
RETURN_IF_ERROR(FuseChain(chain, &fused_node));
output_model->nodes.push_back(std::move(fused_node));
}
output_model->tensor_shapes = input_model.tensor_shapes; output_model->tensor_shapes = input_model.tensor_shapes;
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -876,6 +876,7 @@ objc_library(
"padding_test.mm", "padding_test.mm",
"pooling_test.mm", "pooling_test.mm",
"prelu_test.mm", "prelu_test.mm",
"quantize_and_dequantize_test.mm",
"relu_test.mm", "relu_test.mm",
"reshape_test.mm", "reshape_test.mm",
"resize_test.mm", "resize_test.mm",
@ -897,6 +898,7 @@ objc_library(
"//tensorflow/lite/delegates/gpu/common:util", "//tensorflow/lite/delegates/gpu/common:util",
"//tensorflow/lite/delegates/gpu/metal:common", "//tensorflow/lite/delegates/gpu/metal:common",
"//tensorflow/lite/delegates/gpu/metal:inference_context", "//tensorflow/lite/delegates/gpu/metal:inference_context",
"//tensorflow/lite/kernels/internal:quantization_util",
], ],
) )

View File

@ -31,31 +31,16 @@ limitations under the License.
namespace tflite { namespace tflite {
namespace gpu { namespace gpu {
namespace metal { namespace metal {
namespace {
std::string GetAddTableCodeFused(int src_count) {
std::string code = "FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid";
for (int i = 0; i < src_count; ++i) {
code += ", device FLT4* const src_buf" + std::to_string(i);
}
code += ") {\n";
for (int i = 0; i < src_count; ++i) {
code += " value += src_buf" + std::to_string(i) + "[linear_index];\n";
}
code += " return value;\n";
code += "}\n";
return code;
}
} // namespace
ComputeTaskDescriptor Add(const OperationDef& definition) { ComputeTaskDescriptor Add(const OperationDef& definition) {
ComputeTaskDescriptor desc(definition); ComputeTaskDescriptor desc(definition);
desc.is_linkable = true; desc.is_linkable = true;
desc.shader_source = GetAddTableCodeFused(definition.src_tensors.size() - 1);
for (int i = 1; i < definition.src_tensors.size(); ++i) { for (int i = 1; i < definition.src_tensors.size(); ++i) {
desc.AddSrcTensor("src_tensor_" + std::to_string(i), const std::string tensor_name = "src_tensor_" + std::to_string(i);
definition.src_tensors[i]); desc.AddSrcTensor(tensor_name, definition.src_tensors[i]);
desc.shader_source +=
" value += args." + tensor_name + ".Read(gid.x, gid.y, gid.z);\n";
} }
return desc; return desc;

View File

@ -88,38 +88,22 @@ ComputeTaskDescriptor ElementwiseWithTwoInputs(const OperationDef& definition,
OperationType op_type) { OperationType op_type) {
ComputeTaskDescriptor desc(definition); ComputeTaskDescriptor desc(definition);
desc.is_linkable = true; desc.is_linkable = true;
const std::string x_coord = second_shape.w == 1 ? "0" : "int(gid.x)"; const std::string x_coord = second_shape.w == 1 ? "0" : "gid.x";
const std::string y_coord = second_shape.h == 1 ? "0" : "int(gid.y)"; const std::string y_coord = second_shape.h == 1 ? "0" : "gid.y";
const std::string s_coord = second_shape.c == 1 ? "0" : "int(gid.z)"; const std::string s_coord = second_shape.c == 1 ? "0" : "gid.z";
std::string code = std::string code;
"FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid, device FLT4* " code = " FLT4 src_1 = args.second_tensor.Read(" + x_coord + ", " + y_coord +
"const second_tensor, int2 second_size) {\n"; ", " + s_coord + ");\n";
code += " int second_index = (" + s_coord + " * second_size.y + " + y_coord +
") * second_size.x + " + x_coord + ";\n";
code += " FLT4 src_1 = second_tensor[second_index];\n";
if (second_shape.c == 1) { if (second_shape.c == 1) {
code += " src_1.y = src_1.x;\n"; code += " src_1.y = src_1.x;\n";
code += " src_1.z = src_1.x;\n"; code += " src_1.z = src_1.x;\n";
code += " src_1.w = src_1.x;\n"; code += " src_1.w = src_1.x;\n";
} }
code += " return " + TwoInputFunctor(op_type, "value", "src_1") + ";\n"; code += " value = " + TwoInputFunctor(op_type, "value", "src_1") + ";\n";
code += "}\n";
desc.shader_source = code; desc.shader_source = code;
desc.AddSrcTensor("second_tensor", definition.src_tensors[1]); desc.AddSrcTensor("second_tensor", definition.src_tensors[1]);
desc.uniform_buffers = {
{"constant int2&",
[](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) {
std::vector<int> uniform_params{
src_shapes[1].w,
src_shapes[1].h,
};
return GetByteBuffer(uniform_params);
}},
};
return desc; return desc;
} }
@ -128,10 +112,7 @@ ComputeTaskDescriptor ElementwiseWithOneInput(const OperationDef& definition,
ComputeTaskDescriptor desc(definition); ComputeTaskDescriptor desc(definition);
desc.is_linkable = true; desc.is_linkable = true;
desc.shader_source = desc.shader_source =
"FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid) {\n"; " value = " + OneInputFunctor(op_type, "value") + ";\n";
desc.shader_source +=
" return " + OneInputFunctor(op_type, "value") + ";\n";
desc.shader_source += " }";
return desc; return desc;
} }
@ -141,32 +122,35 @@ ComputeTaskDescriptor ElementwiseWithOneInputAndConstantArguent(
auto scalar = absl::get_if<float>(&attr); auto scalar = absl::get_if<float>(&attr);
auto linear_buf = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr); auto linear_buf = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr);
auto hwc_buf = absl::get_if<Tensor<HWC, DataType::FLOAT32>>(&attr); auto hwc_buf = absl::get_if<Tensor<HWC, DataType::FLOAT32>>(&attr);
std::string param_desc;
if (scalar) {
param_desc += ", float scalar_val";
}
if (linear_buf) {
param_desc += ", device FLT4* const linear_buf";
}
if (hwc_buf) {
param_desc += ", device FLT4* const hwc_buf, int2 hwc_size";
}
ComputeTaskDescriptor desc(definition); ComputeTaskDescriptor desc(definition);
desc.is_linkable = true; desc.is_linkable = true;
desc.shader_source =
"FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid" + param_desc +
") {\n";
if (scalar) { if (scalar) {
desc.shader_source += " FLT4 second_arg = FLT4(scalar_val);\n"; desc.args.AddFloat("scalar_val", *scalar);
desc.shader_source += " FLT4 second_arg = FLT4(args.scalar_val);\n";
} else if (linear_buf) { } else if (linear_buf) {
desc.shader_source += " FLT4 second_arg = linear_buf[gid.z];\n"; auto data_type = DeduceDataTypeFromPrecision(definition.precision);
const int dst_channels_aligned = AlignByN(linear_buf->shape.v, 4);
BufferDescriptor linear_desc;
linear_desc.element_type = data_type;
linear_desc.element_size = 4;
linear_desc.data = GetByteBufferConvertedResized(
linear_buf->data, data_type, dst_channels_aligned);
linear_desc.size = linear_desc.data.size();
desc.args.AddObject(
"linear", absl::make_unique<BufferDescriptor>(std::move(linear_desc)));
desc.shader_source += " FLT4 second_arg = args.linear.Read(gid.z);\n";
} else if (hwc_buf) { } else if (hwc_buf) {
const std::string x_coord = hwc_buf->shape.w == 1 ? "0" : "int(gid.x)"; TensorDescriptor hwc_desc{definition.GetDataType(),
const std::string y_coord = hwc_buf->shape.h == 1 ? "0" : "int(gid.y)"; TensorStorageType::BUFFER, Layout::HWC};
const std::string s_coord = hwc_buf->shape.c == 1 ? "0" : "int(gid.z)"; hwc_desc.UploadData(*hwc_buf);
std::string index = "(" + s_coord + " * hwc_size.y + " + y_coord + desc.args.AddObject(
") * hwc_size.x + " + x_coord; "hwc", absl::make_unique<TensorDescriptor>(std::move(hwc_desc)));
desc.shader_source += " FLT4 second_arg = hwc_buf[" + index + "];\n";
const std::string x_coord = hwc_buf->shape.w == 1 ? "0" : "gid.x";
const std::string y_coord = hwc_buf->shape.h == 1 ? "0" : "gid.y";
const std::string s_coord = hwc_buf->shape.c == 1 ? "0" : "gid.z";
desc.shader_source += " FLT4 second_arg = args.hwc.Read(" + x_coord +
", " + y_coord + ", " + s_coord + ");\n";
if (hwc_buf->shape.c == 1) { if (hwc_buf->shape.c == 1) {
desc.shader_source += " second_arg.y = second_arg.x;\n"; desc.shader_source += " second_arg.y = second_arg.x;\n";
desc.shader_source += " second_arg.z = second_arg.x;\n"; desc.shader_source += " second_arg.z = second_arg.x;\n";
@ -174,40 +158,7 @@ ComputeTaskDescriptor ElementwiseWithOneInputAndConstantArguent(
} }
} }
desc.shader_source += desc.shader_source +=
" return " + TwoInputFunctor(op_type, "value", "second_arg") + ";\n"; " value = " + TwoInputFunctor(op_type, "value", "second_arg") + ";\n";
desc.shader_source += " }";
auto data_type = DeduceDataTypeFromPrecision(definition.precision);
if (scalar) {
std::vector<uint8_t> scalar_bits =
GetByteBuffer(std::vector<float>{*scalar});
desc.uniform_buffers = {
{"constant float&",
[scalar_bits](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) {
return scalar_bits;
}},
};
} else if (linear_buf) {
desc.immutable_buffers = {
{"device FLT4* const",
GetByteBufferConverted(linear_buf->data, data_type)},
};
} else if (hwc_buf) {
std::vector<uint8_t> size_bits =
GetByteBuffer(std::vector<int>{hwc_buf->shape.w, hwc_buf->shape.h});
desc.uniform_buffers = {
{"constant int2&",
[size_bits](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) {
return size_bits;
}},
};
desc.immutable_buffers = {
{"device FLT4* const",
GetByteBufferConverted(ConvertToPHWC4(*hwc_buf), data_type)},
};
}
return desc; return desc;
} }

View File

@ -43,34 +43,27 @@ ComputeTaskDescriptor PReLU(const OperationDef& definition,
ComputeTaskDescriptor desc(definition); ComputeTaskDescriptor desc(definition);
desc.is_linkable = true; desc.is_linkable = true;
if (attr.clip != 0) { if (attr.clip != 0) {
desc.args.AddFloat("clip", attr.clip);
desc.shader_source = desc.shader_source =
R"(FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid, R"(
device FLT4* const alphas, float clip) { value = FLT4(clamp(value, FLT4(0.0f), FLT4(args.clip)) + args.alpha.Read(gid.z) * min(FLT4(0.0f), value));
return FLT4(clamp(value, FLT4(0.0f), FLT4(clip)) + alphas[gid.z] * min(FLT4(0.0f), value)); )";
})";
} else { } else {
desc.shader_source = desc.shader_source =
R"(FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid, R"(
device FLT4* const alphas) { value = FLT4(max(FLT4(0.0f), value) + args.alpha.Read(gid.z) * min(FLT4(0.0f), value));
return FLT4(max(FLT4(0.0f), value) + alphas[gid.z] * min(FLT4(0.0f), value)); )";
})";
} }
auto data_type = DeduceDataTypeFromPrecision(definition.precision); auto data_type = DeduceDataTypeFromPrecision(definition.precision);
desc.immutable_buffers = { const int dst_channels_aligned = AlignByN(alpha_buffer->shape.v, 4);
{"device FLT4* const", BufferDescriptor alpha_desc;
GetByteBufferConverted(alpha_buffer->data, data_type)}, alpha_desc.element_type = data_type;
}; alpha_desc.element_size = 4;
if (attr.clip != 0) { alpha_desc.data = GetByteBufferConvertedResized(alpha_buffer->data, data_type,
desc.uniform_buffers = { dst_channels_aligned);
{"constant float&", alpha_desc.size = alpha_desc.data.size();
[attr](const std::vector<BHWC>& src_shapes, desc.args.AddObject(
const std::vector<BHWC>& dst_shapes) { "alpha", absl::make_unique<BufferDescriptor>(std::move(alpha_desc)));
std::vector<uint8_t> attr_clip =
GetByteBuffer(std::vector<float>{attr.clip});
return attr_clip;
}},
};
}
return desc; return desc;
} }
@ -83,34 +76,22 @@ ComputeTaskDescriptor PReLUFull(const OperationDef& definition,
ComputeTaskDescriptor desc(definition); ComputeTaskDescriptor desc(definition);
desc.is_linkable = true; desc.is_linkable = true;
if (attr.clip != 0) { if (attr.clip != 0) {
desc.args.AddFloat("clip", attr.clip);
desc.shader_source = desc.shader_source =
R"(FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid, R"(
device FLT4* const alphas, float clip) { value = FLT4(clamp(value, FLT4(0.0f), FLT4(args.clip)) + args.alpha.Read(gid.x, gid.y, gid.z) * min(FLT4(0.0f), value));
return FLT4(clamp(value, FLT4(0.0f), FLT4(clip)) + alphas[linear_index] * min(FLT4(0.0f), value)); )";
})";
} else { } else {
desc.shader_source = desc.shader_source =
R"(FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid, R"(
device FLT4* const alphas) { value = FLT4(max(FLT4(0.0f), value) + args.alpha.Read(gid.x, gid.y, gid.z) * min(FLT4(0.0f), value));
return FLT4(max(FLT4(0.0f), value) + alphas[linear_index] * min(FLT4(0.0f), value)); )";
})";
}
auto data_type = DeduceDataTypeFromPrecision(definition.precision);
desc.immutable_buffers = {
{"device FLT4* const",
GetByteBufferConverted(ConvertToPHWC4(*alpha), data_type)},
};
if (attr.clip != 0) {
desc.uniform_buffers = {
{"constant float&",
[attr](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) {
std::vector<uint8_t> attr_clip =
GetByteBuffer(std::vector<float>{attr.clip});
return attr_clip;
}},
};
} }
TensorDescriptor alpha_desc{definition.GetDataType(),
TensorStorageType::BUFFER, Layout::HWC};
alpha_desc.UploadData(*alpha);
desc.args.AddObject(
"alpha", absl::make_unique<TensorDescriptor>(std::move(alpha_desc)));
return desc; return desc;
} }

View File

@ -29,20 +29,12 @@ ComputeTaskDescriptor QuantizeAndDequantize(
ComputeTaskDescriptor desc(definition); ComputeTaskDescriptor desc(definition);
desc.is_linkable = true; desc.is_linkable = true;
desc.shader_source = R"( desc.shader_source = R"(
FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid, float3 params) { value = clamp(value, FLT4(args.qmin), FLT4(args.qmax));
value = clamp(value, FLT4(params.x), FLT4(params.y)); value = (value - FLT4(args.qmin)) / FLT4(args.qscale);
value = (value - FLT4(params.x)) / FLT4(params.z); value = round(value) * FLT4(args.qscale) + FLT4(args.qmin);)";
return round(value) * FLT4(params.z) + FLT4(params.x); desc.args.AddFloat("qmax", attr.max);
} desc.args.AddFloat("qmin", attr.min);
)"; desc.args.AddFloat("qscale", attr.scale);
desc.uniform_buffers = {
{"constant float3&",
[attr](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) {
return GetByteBuffer(
std::vector<float>{attr.min, attr.max, attr.scale});
}},
};
return desc; return desc;
} }

View File

@ -35,24 +35,15 @@ ComputeTaskDescriptor ReLU(const OperationDef& definition,
ComputeTaskDescriptor desc(definition); ComputeTaskDescriptor desc(definition);
desc.is_linkable = true; desc.is_linkable = true;
const std::string min_func = const std::string min_func =
attr.alpha == 0 ? "FLT4(0.0f)" : "min(value * params.x, 0.0f)"; attr.alpha == 0 ? "FLT4(0.0f)" : "min(value * args.alpha, 0.0f)";
const std::string parameters =
"FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid, float2 params) "
"{\n";
if (attr.clip != 0.0) { if (attr.clip != 0.0) {
desc.shader_source = parameters + " return FLT4(clamp(value, " + min_func +
", FLT4(params.y)));\n}";
} else {
desc.shader_source = desc.shader_source =
parameters + " return FLT4(max(value, " + min_func + "));\n}"; "value = FLT4(clamp(value, " + min_func + ", FLT4(args.clip)));";
} else {
desc.shader_source = "value = FLT4(max(value, " + min_func + "));";
} }
desc.uniform_buffers = { desc.args.AddFloat("alpha", attr.alpha);
{"constant float2&", desc.args.AddFloat("clip", attr.clip);
[attr](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) {
return GetByteBuffer(std::vector<float>{attr.alpha, attr.clip});
}},
};
return desc; return desc;
} }

View File

@ -17,9 +17,10 @@ limitations under the License.
#include <string> #include <string>
#include "absl/strings/substitute.h" #include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/metal/buffer.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/common/task/util.h" #include "tensorflow/lite/delegates/gpu/common/task/util.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/buffer.h"
#include "tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h"
namespace tflite { namespace tflite {
namespace gpu { namespace gpu {
@ -128,6 +129,14 @@ absl::Status CreateMetalObject(id<MTLDevice> device, GPUObjectDescriptor* desc,
return absl::OkStatus(); return absl::OkStatus();
} }
const auto* tensor_desc = dynamic_cast<const TensorDescriptor*>(desc);
if (tensor_desc) {
MetalSpatialTensor gpu_tensor;
RETURN_IF_ERROR(gpu_tensor.CreateFromDescriptor(*tensor_desc, device));
*result = absl::make_unique<MetalSpatialTensor>(std::move(gpu_tensor));
return absl::OkStatus();
}
return absl::InvalidArgumentError("Unknown GPU descriptor."); return absl::InvalidArgumentError("Unknown GPU descriptor.");
} }
} // namespace } // namespace
@ -140,10 +149,10 @@ absl::Status MetalArguments::Init(id<MTLDevice> device, int buffer_offset,
RETURN_IF_ERROR(AllocateObjects(*args, device)); RETURN_IF_ERROR(AllocateObjects(*args, device));
RETURN_IF_ERROR(AddObjectArgs(args)); RETURN_IF_ERROR(AddObjectArgs(args));
RETURN_IF_ERROR(ResolveSelectorsPass(*args, {}, code)); RETURN_IF_ERROR(ResolveSelectorsPass(*args, {}, code));
RETURN_IF_ERROR(SetObjectsResources(*args));
object_refs_ = std::move(args->object_refs_); object_refs_ = std::move(args->object_refs_);
args->GetActiveArguments(kArgsPrefix, *code); args->GetActiveArguments(kArgsPrefix, *code);
std::string struct_desc = ScalarArgumentsToStructWithVec4Fields(args, code); std::string struct_desc = ScalarArgumentsToStructWithVec4Fields(args, code);
RETURN_IF_ERROR(SetObjectsResources(*args));
ResolveArgsPass(code); ResolveArgsPass(code);
*code = absl::Substitute(*code, struct_desc, GetListOfArgs(buffer_offset)); *code = absl::Substitute(*code, struct_desc, GetListOfArgs(buffer_offset));
return absl::OkStatus(); return absl::OkStatus();