Elementwise ops in Metal converted to the new style (using arguments and tensors).

PiperOrigin-RevId: 350402612
Change-Id: If770f4edff502c39886bf3c285d32f165cee69b3
parent bcd5dd0148
commit 7d841e13c4

Files changed under tensorflow/lite/delegates/gpu/metal:
@@ -188,6 +188,7 @@ objc_library(
     deps = [
         ":buffer",
         ":gpu_object",
         ":metal_spatial_tensor",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:util",
+        "//tensorflow/lite/delegates/gpu/common/task:arguments",
@@ -256,8 +256,7 @@ std::vector<ValueId> ComputeTask::GetInputIds() const {
 
 void ComputeTask::SetSrcTensor(const MetalSpatialTensor& tensor, int index) {
   input_buffers_[index].metal_handle = tensor.GetBufferHandle();
-  if (tensors_as_args_ &&
-      absl::StrContains(src_tensors_names_[index], "_buffer")) {
+  if (absl::StrContains(src_tensors_names_[index], "_buffer")) {
     auto name = src_tensors_names_[index];
     // extracting tensor_name from "tensor_name_buffer";
     name = name.substr(0, name.size() - 7);
@@ -267,7 +266,7 @@ void ComputeTask::SetSrcTensor(const MetalSpatialTensor& tensor, int index) {
 
 void ComputeTask::SetDstTensor(const MetalSpatialTensor& tensor, int index) {
   output_buffers_[index].metal_handle = tensor.GetBufferHandle();
-  if (tensors_as_args_) {
+  if (absl::StrContains(dst_tensors_names_[index], "_buffer")) {
     auto name = dst_tensors_names_[index];
     // extracting tensor_name from "tensor_name_buffer";
     name = name.substr(0, name.size() - 7);
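(Note: the `name.size() - 7` above strips the 7-character "_buffer" suffix that the new-style codegen appends to tensor argument names. A standalone C++ sketch of just that step, with a hypothetical helper name, not TFLite API:)

#include <cassert>
#include <string>

// Hypothetical helper: mirrors the substr() call in
// ComputeTask::SetSrcTensor/SetDstTensor above.
std::string StripBufferSuffix(const std::string& full_name) {
  constexpr size_t kSuffixLen = 7;  // strlen("_buffer")
  return full_name.substr(0, full_name.size() - kSuffixLen);
}

int main() {
  assert(StripBufferSuffix("src_tensor_buffer") == "src_tensor");
  return 0;
}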
@@ -426,38 +426,29 @@ NodeDescriptor NonLinkableStub(int operation_id, ValueId input_id,
                                ValueId output_id,
                                const OperationDef& definition) {
   auto desc = std::make_shared<ComputeTaskDescriptor>();
   desc->is_linkable = false;
-  desc->tensors_as_args = true;
   desc->shader_source = R"(
 #include <metal_stdlib>
 using namespace metal;
 $0
 kernel void ComputeFunction($1
                             uint3 gid[[thread_position_in_grid]]) {
-  if (int(gid.x) >= size.x || int(gid.y) >= size.y) {
+  if (int(gid.x) >= args.dst_tensor.Width() || int(gid.y) >= args.dst_tensor.Height()) {
     return;
   }
-  const int linear_index = (gid.z * size.y + gid.y) * size.x + gid.x;
-  FLT4 value = src_tensor[linear_index];
+  FLT4 value = args.src_tensor.Read(gid.x, gid.y, gid.z);
+  args.dst_tensor.GetAddress(linear_index, gid.x, gid.y, gid.z);
   $2
-  dst_tensor[linear_index] = value;
+  args.dst_tensor.Write(value, gid.x, gid.y, gid.z);
 }
 )";
 
   desc->AddSrcTensor("src_tensor", definition.src_tensors[0]);
   desc->AddDstTensor("dst_tensor", definition.dst_tensors[0]);
 
-  desc->uniform_buffers = {
-      {"constant int2& size",
-       [](const std::vector<BHWC>& src_shapes,
-          const std::vector<BHWC>& dst_shapes) {
-         return GetByteBuffer(
-             std::vector<int>{src_shapes[0].w, src_shapes[0].h});
-       }},
-  };
-
   desc->resize_function = [](const std::vector<BHWC>& src_shapes,
                              const std::vector<BHWC>& dst_shapes) {
-    uint3 groups_size{16, 16, 1};
+    uint3 groups_size{8, 8, 1};
     uint3 groups_count{DivideRoundUp(dst_shapes[0].w, groups_size.x),
                        DivideRoundUp(dst_shapes[0].h, groups_size.y),
                        DivideRoundUp(dst_shapes[0].c, 4)};
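(Note: the workgroup shrinks from 16x16x1 to 8x8x1; the grid math is plain ceiling division. A standalone sketch of that sizing, assuming DivideRoundUp is the usual ceil-div as in gpu/common/util.h:)

#include <cstdio>

// Ceiling division, matching how DivideRoundUp is used above.
int DivideRoundUp(int n, int d) { return (n + d - 1) / d; }

int main() {
  // Hypothetical 33x17x8 (WxHxC) destination shape.
  const int w = 33, h = 17, c = 8;
  printf("groups: %d x %d x %d\n",
         DivideRoundUp(w, 8),   // 5 workgroups along x (8-wide groups)
         DivideRoundUp(h, 8),   // 3 along y
         DivideRoundUp(c, 4));  // 2 along z: channels packed as FLT4
  return 0;
}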
@@ -472,60 +463,47 @@ NodeDescriptor NonLinkableStub(int operation_id, ValueId input_id,
   return node_desc;
 }
 
-void MergeNodes(const NodeDescriptor* src, NodeDescriptor* dst,
-                std::string link_name) {
-  std::string call_arguments;
+absl::Status MergeNodes(const NodeDescriptor* src, NodeDescriptor* dst,
+                        std::string link_name) {
   dst->dst_tensors_ids[0] = src->dst_tensors_ids[0];
   dst->description += " linked : " + src->description;
   for (int i = 0; i < src->task->src_tensors_names.size(); ++i) {
-    std::string tensor_name = src->task->src_tensors_names[i] + link_name;
-    call_arguments += ", " + tensor_name;
-    dst->task->src_tensors_names.push_back(tensor_name);
-    // dst->task->AddSrcTensor(tensor_name,
-    //                         src->task->definition.src_tensors[i + 1]);
+    std::string tensor_name = src->task->src_tensors_names[i];
+    dst->task->src_tensors_names.push_back(tensor_name + link_name + "_buffer");
+    auto desc_new = absl::make_unique<TensorDescriptor>(
+        src->task->definition.src_tensors[i + 1]);
+    src->task->args.AddObjectRef(tensor_name, AccessType::READ,
+                                 std::move(desc_new));
     dst->task->definition.src_tensors.push_back(
         src->task->definition.src_tensors[i + 1]);
     dst->src_tensors_ids.push_back(src->src_tensors_ids[i + 1]);
   }
 
-  for (int i = 0; i < src->task->immutable_buffers.size(); ++i) {
-    auto buffer = src->task->immutable_buffers[i];
-    const std::string buffer_name = "ibuffer" + std::to_string(i) + link_name;
-    buffer.declaration += " " + buffer_name;
-    call_arguments += ", " + buffer_name;
-    dst->task->immutable_buffers.push_back(buffer);
-  }
+  std::string code = src->task->shader_source;
+  src->task->args.RenameArgs(link_name, &code);
 
-  for (int i = 0; i < src->task->uniform_buffers.size(); ++i) {
-    auto buffer = src->task->uniform_buffers[i];
-    const std::string buffer_name = "ubuffer" + std::to_string(i) + link_name;
-    buffer.declaration += " " + buffer_name;
-    call_arguments += ", " + buffer_name;
-    dst->task->uniform_buffers.push_back(buffer);
-  }
+  RETURN_IF_ERROR(dst->task->args.Merge(std::move(src->task->args), link_name));
 
-  std::string function_code =
-      absl::Substitute(src->task->shader_source, link_name) + "\n";
-  std::string call_code =
-      absl::Substitute("value = linkable$0(value, linear_index, gid$1);\n",
-                       link_name, call_arguments);
-
-  dst->task->shader_source = absl::Substitute(
-      dst->task->shader_source, function_code + "$0", "$1", call_code + "$2");
+  dst->task->shader_source = absl::Substitute(dst->task->shader_source, "$0",
+                                              "$1", "{\n" + code + "\n}\n$2");
+  return absl::OkStatus();
 }
 
-NodeDescriptor FuseChain(const FusionSequence& chain) {
-  NodeDescriptor node_desc;
+absl::Status FuseChain(const FusionSequence& chain, NodeDescriptor* node_desc) {
   if (chain.front().task->is_linkable) {
-    node_desc = NonLinkableStub(
+    *node_desc = NonLinkableStub(
         chain.front().id, chain.front().src_tensors_ids[0],
         chain.front().dst_tensors_ids[0], chain.front().task->definition);
-    MergeNodes(&chain.front(), &node_desc, "_link0");
+    RETURN_IF_ERROR(MergeNodes(&chain.front(), node_desc, "_link0"));
   } else {
-    node_desc = chain.front();
+    *node_desc = chain.front();
  }
   for (int j = 1; j < chain.size(); ++j) {
-    MergeNodes(&chain[j], &node_desc, "_link" + std::to_string(j));
+    RETURN_IF_ERROR(
+        MergeNodes(&chain[j], node_desc, "_link" + std::to_string(j)));
   }
-  return node_desc;
+  return absl::OkStatus();
 }
 }  // namespace
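(Note: the `$0`/`$1`/`$2` handling above is the heart of the linking scheme: the merged op's renamed code is spliced in ahead of the final Write, and fresh placeholders are re-inserted so the next MergeNodes call can splice again. A minimal standalone sketch of that placeholder threading, with toy shader strings instead of the NonLinkableStub source:)

#include <iostream>
#include <string>

#include "absl/strings/substitute.h"

int main() {
  // Toy stand-in for the stub shader: $0 = declarations, $1 = kernel
  // arguments, $2 = linked-code slot before the write.
  std::string shader =
      "$0 kernel($1) { FLT4 value = read(); $2 write(value); }";
  // Toy stand-in for a linked elementwise op after RenameArgs.
  std::string linked = "value = max(value, FLT4(0.0f));";
  // Splice the linked code into the $2 slot and re-insert $0/$1/$2 so a
  // further merge can splice again.
  shader = absl::Substitute(shader, "$0", "$1", "{\n" + linked + "\n}\n$2");
  std::cout << shader << "\n";
  return 0;
}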
@@ -697,8 +675,11 @@ absl::Status InferenceContext::ValidateOptimizeModel(
         std::to_string(info.missing_output_buffer_ids.size());
     return absl::InternalError(message);
   }
-  for (const auto& chain : sorted_chains)
-    output_model->nodes.push_back(FuseChain(chain));
+  for (const auto& chain : sorted_chains) {
+    NodeDescriptor fused_node;
+    RETURN_IF_ERROR(FuseChain(chain, &fused_node));
+    output_model->nodes.push_back(std::move(fused_node));
+  }
   output_model->tensor_shapes = input_model.tensor_shapes;
   return absl::OkStatus();
 }
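(Note: fusion can now fail cleanly instead of producing a silently broken node; errors propagate through the usual RETURN_IF_ERROR early exit. A minimal sketch of that pattern — the macro shape here is assumed, TFLite defines its own variant:)

#include "absl/status/status.h"

// Assumed macro shape, for illustration only.
#define RETURN_IF_ERROR(expr)              \
  do {                                     \
    const absl::Status _status = (expr);   \
    if (!_status.ok()) return _status;     \
  } while (false)

absl::Status FuseOne() { return absl::OkStatus(); }

absl::Status FuseAll() {
  RETURN_IF_ERROR(FuseOne());  // bails out on the first failing chain
  return absl::OkStatus();
}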
@@ -876,6 +876,7 @@ objc_library(
         "padding_test.mm",
         "pooling_test.mm",
         "prelu_test.mm",
+        "quantize_and_dequantize_test.mm",
         "relu_test.mm",
         "reshape_test.mm",
         "resize_test.mm",
@@ -897,6 +898,7 @@ objc_library(
         "//tensorflow/lite/delegates/gpu/common:util",
         "//tensorflow/lite/delegates/gpu/metal:common",
         "//tensorflow/lite/delegates/gpu/metal:inference_context",
+        "//tensorflow/lite/kernels/internal:quantization_util",
     ],
 )
@@ -31,31 +31,16 @@ limitations under the License.
 namespace tflite {
 namespace gpu {
 namespace metal {
 namespace {
 
-std::string GetAddTableCodeFused(int src_count) {
-  std::string code = "FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid";
-  for (int i = 0; i < src_count; ++i) {
-    code += ", device FLT4* const src_buf" + std::to_string(i);
-  }
-  code += ") {\n";
-  for (int i = 0; i < src_count; ++i) {
-    code += " value += src_buf" + std::to_string(i) + "[linear_index];\n";
-  }
-  code += " return value;\n";
-  code += "}\n";
-  return code;
-}
 }  // namespace
 
 ComputeTaskDescriptor Add(const OperationDef& definition) {
   ComputeTaskDescriptor desc(definition);
   desc.is_linkable = true;
-  desc.shader_source = GetAddTableCodeFused(definition.src_tensors.size() - 1);
 
   for (int i = 1; i < definition.src_tensors.size(); ++i) {
-    desc.AddSrcTensor("src_tensor_" + std::to_string(i),
-                      definition.src_tensors[i]);
+    const std::string tensor_name = "src_tensor_" + std::to_string(i);
+    desc.AddSrcTensor(tensor_name, definition.src_tensors[i]);
+    desc.shader_source +=
+        " value += args." + tensor_name + ".Read(gid.x, gid.y, gid.z);\n";
   }
 
   return desc;
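(Note: under the new style the whole Add op reduces to one `value += args.<name>.Read(...)` line per extra input. A standalone mirror of the loop above for a hypothetical three-input Add, printing the fragment it would generate:)

#include <iostream>
#include <string>

int main() {
  std::string shader_source;
  for (int i = 1; i < 3; ++i) {  // primary input plus two extra tensors
    const std::string tensor_name = "src_tensor_" + std::to_string(i);
    shader_source +=
        " value += args." + tensor_name + ".Read(gid.x, gid.y, gid.z);\n";
  }
  std::cout << shader_source;
  // Prints:
  //  value += args.src_tensor_1.Read(gid.x, gid.y, gid.z);
  //  value += args.src_tensor_2.Read(gid.x, gid.y, gid.z);
  return 0;
}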
@@ -88,38 +88,22 @@ ComputeTaskDescriptor ElementwiseWithTwoInputs(const OperationDef& definition,
                                                OperationType op_type) {
   ComputeTaskDescriptor desc(definition);
   desc.is_linkable = true;
-  const std::string x_coord = second_shape.w == 1 ? "0" : "int(gid.x)";
-  const std::string y_coord = second_shape.h == 1 ? "0" : "int(gid.y)";
-  const std::string s_coord = second_shape.c == 1 ? "0" : "int(gid.z)";
-  std::string code =
-      "FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid, device FLT4* "
-      "const second_tensor, int2 second_size) {\n";
-  code += " int second_index = (" + s_coord + " * second_size.y + " + y_coord +
-          ") * second_size.x + " + x_coord + ";\n";
-  code += " FLT4 src_1 = second_tensor[second_index];\n";
+  const std::string x_coord = second_shape.w == 1 ? "0" : "gid.x";
+  const std::string y_coord = second_shape.h == 1 ? "0" : "gid.y";
+  const std::string s_coord = second_shape.c == 1 ? "0" : "gid.z";
+  std::string code;
+  code = " FLT4 src_1 = args.second_tensor.Read(" + x_coord + ", " + y_coord +
+         ", " + s_coord + ");\n";
   if (second_shape.c == 1) {
     code += " src_1.y = src_1.x;\n";
     code += " src_1.z = src_1.x;\n";
     code += " src_1.w = src_1.x;\n";
   }
-  code += " return " + TwoInputFunctor(op_type, "value", "src_1") + ";\n";
-  code += "}\n";
+  code += " value = " + TwoInputFunctor(op_type, "value", "src_1") + ";\n";
 
   desc.shader_source = code;
 
   desc.AddSrcTensor("second_tensor", definition.src_tensors[1]);
 
-  desc.uniform_buffers = {
-      {"constant int2&",
-       [](const std::vector<BHWC>& src_shapes,
-          const std::vector<BHWC>& dst_shapes) {
-         std::vector<int> uniform_params{
-             src_shapes[1].w,
-             src_shapes[1].h,
-         };
-         return GetByteBuffer(uniform_params);
-       }},
-  };
   return desc;
 }
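(Note: the coordinate selection above is how broadcasting works in the new style: any dimension of size 1 in the second input reads coordinate 0, the rest follow gid. A standalone sketch of that string-building step, with a hypothetical helper name:)

#include <iostream>
#include <string>

struct Shape { int w, h, c; };  // stand-in for the BHWC second_shape

// Hypothetical helper: builds the broadcast-aware Read line, as the
// x_coord/y_coord/s_coord selection above does.
std::string BroadcastRead(const Shape& second) {
  const std::string x = second.w == 1 ? "0" : "gid.x";
  const std::string y = second.h == 1 ? "0" : "gid.y";
  const std::string s = second.c == 1 ? "0" : "gid.z";
  return " FLT4 src_1 = args.second_tensor.Read(" + x + ", " + y + ", " + s +
         ");";
}

int main() {
  std::cout << BroadcastRead({1, 1, 8}) << "\n";
  // Prints: FLT4 src_1 = args.second_tensor.Read(0, 0, gid.z);
  return 0;
}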
@@ -128,10 +112,7 @@ ComputeTaskDescriptor ElementwiseWithOneInput(const OperationDef& definition,
   ComputeTaskDescriptor desc(definition);
   desc.is_linkable = true;
-  desc.shader_source =
-      "FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid) {\n";
-  desc.shader_source +=
-      " return " + OneInputFunctor(op_type, "value") + ";\n";
-  desc.shader_source += " }";
+  desc.shader_source =
+      " value = " + OneInputFunctor(op_type, "value") + ";\n";
   return desc;
 }
@@ -141,32 +122,35 @@ ComputeTaskDescriptor ElementwiseWithOneInputAndConstantArguent(
   auto scalar = absl::get_if<float>(&attr);
   auto linear_buf = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr);
   auto hwc_buf = absl::get_if<Tensor<HWC, DataType::FLOAT32>>(&attr);
-  std::string param_desc;
-  if (scalar) {
-    param_desc += ", float scalar_val";
-  }
-  if (linear_buf) {
-    param_desc += ", device FLT4* const linear_buf";
-  }
-  if (hwc_buf) {
-    param_desc += ", device FLT4* const hwc_buf, int2 hwc_size";
-  }
   ComputeTaskDescriptor desc(definition);
   desc.is_linkable = true;
-  desc.shader_source =
-      "FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid" + param_desc +
-      ") {\n";
   if (scalar) {
-    desc.shader_source += " FLT4 second_arg = FLT4(scalar_val);\n";
+    desc.args.AddFloat("scalar_val", *scalar);
+    desc.shader_source += " FLT4 second_arg = FLT4(args.scalar_val);\n";
   } else if (linear_buf) {
-    desc.shader_source += " FLT4 second_arg = linear_buf[gid.z];\n";
+    auto data_type = DeduceDataTypeFromPrecision(definition.precision);
+    const int dst_channels_aligned = AlignByN(linear_buf->shape.v, 4);
+    BufferDescriptor linear_desc;
+    linear_desc.element_type = data_type;
+    linear_desc.element_size = 4;
+    linear_desc.data = GetByteBufferConvertedResized(
+        linear_buf->data, data_type, dst_channels_aligned);
+    linear_desc.size = linear_desc.data.size();
+    desc.args.AddObject(
+        "linear", absl::make_unique<BufferDescriptor>(std::move(linear_desc)));
+    desc.shader_source += " FLT4 second_arg = args.linear.Read(gid.z);\n";
   } else if (hwc_buf) {
-    const std::string x_coord = hwc_buf->shape.w == 1 ? "0" : "int(gid.x)";
-    const std::string y_coord = hwc_buf->shape.h == 1 ? "0" : "int(gid.y)";
-    const std::string s_coord = hwc_buf->shape.c == 1 ? "0" : "int(gid.z)";
-    std::string index = "(" + s_coord + " * hwc_size.y + " + y_coord +
-                        ") * hwc_size.x + " + x_coord;
-    desc.shader_source += " FLT4 second_arg = hwc_buf[" + index + "];\n";
+    TensorDescriptor hwc_desc{definition.GetDataType(),
+                              TensorStorageType::BUFFER, Layout::HWC};
+    hwc_desc.UploadData(*hwc_buf);
+    desc.args.AddObject(
+        "hwc", absl::make_unique<TensorDescriptor>(std::move(hwc_desc)));
+
+    const std::string x_coord = hwc_buf->shape.w == 1 ? "0" : "gid.x";
+    const std::string y_coord = hwc_buf->shape.h == 1 ? "0" : "gid.y";
+    const std::string s_coord = hwc_buf->shape.c == 1 ? "0" : "gid.z";
+    desc.shader_source += " FLT4 second_arg = args.hwc.Read(" + x_coord +
+                          ", " + y_coord + ", " + s_coord + ");\n";
     if (hwc_buf->shape.c == 1) {
       desc.shader_source += " second_arg.y = second_arg.x;\n";
       desc.shader_source += " second_arg.z = second_arg.x;\n";
@@ -174,40 +158,7 @@ ComputeTaskDescriptor ElementwiseWithOneInputAndConstantArguent(
     }
   }
   desc.shader_source +=
-      " return " + TwoInputFunctor(op_type, "value", "second_arg") + ";\n";
-  desc.shader_source += " }";
-
-  auto data_type = DeduceDataTypeFromPrecision(definition.precision);
-  if (scalar) {
-    std::vector<uint8_t> scalar_bits =
-        GetByteBuffer(std::vector<float>{*scalar});
-    desc.uniform_buffers = {
-        {"constant float&",
-         [scalar_bits](const std::vector<BHWC>& src_shapes,
-                       const std::vector<BHWC>& dst_shapes) {
-           return scalar_bits;
-         }},
-    };
-  } else if (linear_buf) {
-    desc.immutable_buffers = {
-        {"device FLT4* const",
-         GetByteBufferConverted(linear_buf->data, data_type)},
-    };
-  } else if (hwc_buf) {
-    std::vector<uint8_t> size_bits =
-        GetByteBuffer(std::vector<int>{hwc_buf->shape.w, hwc_buf->shape.h});
-    desc.uniform_buffers = {
-        {"constant int2&",
-         [size_bits](const std::vector<BHWC>& src_shapes,
-                     const std::vector<BHWC>& dst_shapes) {
-           return size_bits;
-         }},
-    };
-    desc.immutable_buffers = {
-        {"device FLT4* const",
-         GetByteBufferConverted(ConvertToPHWC4(*hwc_buf), data_type)},
-    };
-  }
+      " value = " + TwoInputFunctor(op_type, "value", "second_arg") + ";\n";
   return desc;
 }
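(Note: the linear-buffer branch pads the channel dimension to a multiple of 4 so the data can be read as FLT4 vectors. A standalone sketch of that alignment, assuming AlignByN rounds up as in gpu/common/util.h:)

#include <cassert>

// Round value up to a multiple of n, matching how AlignByN is used above.
int AlignByN(int value, int n) { return ((value + n - 1) / n) * n; }

int main() {
  assert(AlignByN(5, 4) == 8);  // 5 channels -> 2 FLT4 elements
  assert(AlignByN(8, 4) == 8);  // already aligned
  return 0;
}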
@@ -43,34 +43,27 @@ ComputeTaskDescriptor PReLU(const OperationDef& definition,
   ComputeTaskDescriptor desc(definition);
   desc.is_linkable = true;
   if (attr.clip != 0) {
+    desc.args.AddFloat("clip", attr.clip);
     desc.shader_source =
-        R"(FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid,
-                           device FLT4* const alphas, float clip) {
-  return FLT4(clamp(value, FLT4(0.0f), FLT4(clip)) + alphas[gid.z] * min(FLT4(0.0f), value));
-})";
+        R"(
+  value = FLT4(clamp(value, FLT4(0.0f), FLT4(args.clip)) + args.alpha.Read(gid.z) * min(FLT4(0.0f), value));
+)";
   } else {
     desc.shader_source =
-        R"(FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid,
-                           device FLT4* const alphas) {
-  return FLT4(max(FLT4(0.0f), value) + alphas[gid.z] * min(FLT4(0.0f), value));
-})";
+        R"(
+  value = FLT4(max(FLT4(0.0f), value) + args.alpha.Read(gid.z) * min(FLT4(0.0f), value));
+)";
   }
   auto data_type = DeduceDataTypeFromPrecision(definition.precision);
-  desc.immutable_buffers = {
-      {"device FLT4* const",
-       GetByteBufferConverted(alpha_buffer->data, data_type)},
-  };
-  if (attr.clip != 0) {
-    desc.uniform_buffers = {
-        {"constant float&",
-         [attr](const std::vector<BHWC>& src_shapes,
-                const std::vector<BHWC>& dst_shapes) {
-           std::vector<uint8_t> attr_clip =
-               GetByteBuffer(std::vector<float>{attr.clip});
-           return attr_clip;
-         }},
-    };
-  }
+  const int dst_channels_aligned = AlignByN(alpha_buffer->shape.v, 4);
+  BufferDescriptor alpha_desc;
+  alpha_desc.element_type = data_type;
+  alpha_desc.element_size = 4;
+  alpha_desc.data = GetByteBufferConvertedResized(alpha_buffer->data, data_type,
+                                                  dst_channels_aligned);
+  alpha_desc.size = alpha_desc.data.size();
+  desc.args.AddObject(
+      "alpha", absl::make_unique<BufferDescriptor>(std::move(alpha_desc)));
   return desc;
 }
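(Note: both shader variants above compute the same function; only the clip handling differs. A per-component C++ restatement, standalone and for illustration only:)

#include <algorithm>
#include <cstdio>

// Per-component PReLU as in the shaders above: the positive part is
// clamped to [0, clip] when clip != 0, the negative part is scaled by
// the per-channel alpha.
float PReLU(float value, float alpha, float clip) {
  const float pos = clip != 0.0f ? std::clamp(value, 0.0f, clip)
                                 : std::max(0.0f, value);
  return pos + alpha * std::min(0.0f, value);
}

int main() {
  printf("%f\n", PReLU(-2.0f, 0.1f, 6.0f));  // -0.2
  return 0;
}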
@@ -83,34 +76,22 @@ ComputeTaskDescriptor PReLUFull(const OperationDef& definition,
   ComputeTaskDescriptor desc(definition);
   desc.is_linkable = true;
   if (attr.clip != 0) {
+    desc.args.AddFloat("clip", attr.clip);
     desc.shader_source =
-        R"(FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid,
-                           device FLT4* const alphas, float clip) {
-  return FLT4(clamp(value, FLT4(0.0f), FLT4(clip)) + alphas[linear_index] * min(FLT4(0.0f), value));
-})";
+        R"(
+  value = FLT4(clamp(value, FLT4(0.0f), FLT4(args.clip)) + args.alpha.Read(gid.x, gid.y, gid.z) * min(FLT4(0.0f), value));
+)";
   } else {
     desc.shader_source =
-        R"(FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid,
-                           device FLT4* const alphas) {
-  return FLT4(max(FLT4(0.0f), value) + alphas[linear_index] * min(FLT4(0.0f), value));
-})";
+        R"(
+  value = FLT4(max(FLT4(0.0f), value) + args.alpha.Read(gid.x, gid.y, gid.z) * min(FLT4(0.0f), value));
+)";
   }
-  auto data_type = DeduceDataTypeFromPrecision(definition.precision);
-  desc.immutable_buffers = {
-      {"device FLT4* const",
-       GetByteBufferConverted(ConvertToPHWC4(*alpha), data_type)},
-  };
-  if (attr.clip != 0) {
-    desc.uniform_buffers = {
-        {"constant float&",
-         [attr](const std::vector<BHWC>& src_shapes,
-                const std::vector<BHWC>& dst_shapes) {
-           std::vector<uint8_t> attr_clip =
-               GetByteBuffer(std::vector<float>{attr.clip});
-           return attr_clip;
-         }},
-    };
-  }
+  TensorDescriptor alpha_desc{definition.GetDataType(),
+                              TensorStorageType::BUFFER, Layout::HWC};
+  alpha_desc.UploadData(*alpha);
+  desc.args.AddObject(
+      "alpha", absl::make_unique<TensorDescriptor>(std::move(alpha_desc)));
   return desc;
 }
@@ -29,20 +29,12 @@ ComputeTaskDescriptor QuantizeAndDequantize(
   ComputeTaskDescriptor desc(definition);
   desc.is_linkable = true;
   desc.shader_source = R"(
-FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid, float3 params) {
-  value = clamp(value, FLT4(params.x), FLT4(params.y));
-  value = (value - FLT4(params.x)) / FLT4(params.z);
-  return round(value) * FLT4(params.z) + FLT4(params.x);
-}
-)";
-  desc.uniform_buffers = {
-      {"constant float3&",
-       [attr](const std::vector<BHWC>& src_shapes,
-              const std::vector<BHWC>& dst_shapes) {
-         return GetByteBuffer(
-             std::vector<float>{attr.min, attr.max, attr.scale});
-       }},
-  };
+  value = clamp(value, FLT4(args.qmin), FLT4(args.qmax));
+  value = (value - FLT4(args.qmin)) / FLT4(args.qscale);
+  value = round(value) * FLT4(args.qscale) + FLT4(args.qmin);)";
+  desc.args.AddFloat("qmax", attr.max);
+  desc.args.AddFloat("qmin", attr.min);
+  desc.args.AddFloat("qscale", attr.scale);
   return desc;
 }
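(Note: the shader simulates quantization in float: clamp to [qmin, qmax], snap to the nearest qscale step, map back. A per-component C++ restatement, standalone:)

#include <cmath>
#include <cstdio>

// Mirrors the fused quantize/dequantize shader above, using the
// qmin/qmax/qscale scalars now passed through args.
float QuantDequant(float value, float qmin, float qmax, float qscale) {
  value = std::fmin(std::fmax(value, qmin), qmax);  // clamp
  value = (value - qmin) / qscale;                  // quantize to steps
  return std::round(value) * qscale + qmin;         // back to float range
}

int main() {
  printf("%f\n", QuantDequant(0.53f, 0.0f, 1.0f, 1.0f / 255.0f));
  return 0;
}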
@@ -35,24 +35,15 @@ ComputeTaskDescriptor ReLU(const OperationDef& definition,
   ComputeTaskDescriptor desc(definition);
   desc.is_linkable = true;
   const std::string min_func =
-      attr.alpha == 0 ? "FLT4(0.0f)" : "min(value * params.x, 0.0f)";
-  const std::string parameters =
-      "FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid, float2 params) "
-      "{\n";
+      attr.alpha == 0 ? "FLT4(0.0f)" : "min(value * args.alpha, 0.0f)";
   if (attr.clip != 0.0) {
-    desc.shader_source = parameters + " return FLT4(clamp(value, " + min_func +
-                         ", FLT4(params.y)));\n}";
+    desc.shader_source =
+        "value = FLT4(clamp(value, " + min_func + ", FLT4(args.clip)));";
   } else {
-    desc.shader_source =
-        parameters + " return FLT4(max(value, " + min_func + "));\n}";
+    desc.shader_source = "value = FLT4(max(value, " + min_func + "));";
   }
-  desc.uniform_buffers = {
-      {"constant float2&",
-       [attr](const std::vector<BHWC>& src_shapes,
-              const std::vector<BHWC>& dst_shapes) {
-         return GetByteBuffer(std::vector<float>{attr.alpha, attr.clip});
-       }},
-  };
+  desc.args.AddFloat("alpha", attr.alpha);
+  desc.args.AddFloat("clip", attr.clip);
   return desc;
 }
@@ -17,9 +17,10 @@ limitations under the License.
 #include <string>
 
 #include "absl/strings/substitute.h"
-#include "tensorflow/lite/delegates/gpu/metal/buffer.h"
-#include "tensorflow/lite/delegates/gpu/common/util.h"
+#include "tensorflow/lite/delegates/gpu/common/task/util.h"
+#include "tensorflow/lite/delegates/gpu/common/util.h"
+#include "tensorflow/lite/delegates/gpu/metal/buffer.h"
+#include "tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h"
 
 namespace tflite {
 namespace gpu {
@@ -128,6 +129,14 @@ absl::Status CreateMetalObject(id<MTLDevice> device, GPUObjectDescriptor* desc,
     return absl::OkStatus();
   }
 
+  const auto* tensor_desc = dynamic_cast<const TensorDescriptor*>(desc);
+  if (tensor_desc) {
+    MetalSpatialTensor gpu_tensor;
+    RETURN_IF_ERROR(gpu_tensor.CreateFromDescriptor(*tensor_desc, device));
+    *result = absl::make_unique<MetalSpatialTensor>(std::move(gpu_tensor));
+    return absl::OkStatus();
+  }
+
   return absl::InvalidArgumentError("Unknown GPU descriptor.");
 }
 }  // namespace
@@ -140,10 +149,10 @@ absl::Status MetalArguments::Init(id<MTLDevice> device, int buffer_offset,
   RETURN_IF_ERROR(AllocateObjects(*args, device));
   RETURN_IF_ERROR(AddObjectArgs(args));
   RETURN_IF_ERROR(ResolveSelectorsPass(*args, {}, code));
-  RETURN_IF_ERROR(SetObjectsResources(*args));
   object_refs_ = std::move(args->object_refs_);
   args->GetActiveArguments(kArgsPrefix, *code);
   std::string struct_desc = ScalarArgumentsToStructWithVec4Fields(args, code);
+  RETURN_IF_ERROR(SetObjectsResources(*args));
   ResolveArgsPass(code);
   *code = absl::Substitute(*code, struct_desc, GetListOfArgs(buffer_offset));
   return absl::OkStatus();
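(Note: CreateMetalObject dispatches on the concrete descriptor type with dynamic_cast, trying each known subclass in turn. A minimal standalone sketch of that pattern, with toy types instead of the TFLite classes:)

#include <iostream>

// Toy stand-ins for GPUObjectDescriptor and its subclasses.
struct Descriptor { virtual ~Descriptor() = default; };
struct BufferDesc : Descriptor {};
struct TensorDesc : Descriptor {};

const char* Create(const Descriptor* desc) {
  // First matching concrete type wins.
  if (dynamic_cast<const BufferDesc*>(desc)) return "buffer";
  if (dynamic_cast<const TensorDesc*>(desc)) return "tensor";
  return "unknown descriptor";  // maps to InvalidArgumentError above
}

int main() {
  TensorDesc t;
  std::cout << Create(&t) << "\n";  // tensor
  return 0;
}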