Added syntax for template arguments for gpu objects selectors.
Now we can write args.src_tensor.Read<float>(...). Demonstrated in Winograd kernel. PiperOrigin-RevId: 314576804 Change-Id: I27cb6e7d251b05d489c9c1dbbe4df8640df05651
This commit is contained in:
parent
f805f48644
commit
929398ef01
@ -45,6 +45,7 @@ size_t FindEnclosingBracket(const std::string& text, size_t first_pos,
|
||||
{'(', ')'},
|
||||
{'{', '}'},
|
||||
{'[', ']'},
|
||||
{'<', '>'},
|
||||
};
|
||||
char b_open = bracket;
|
||||
auto it = brackets.find(b_open);
|
||||
@ -70,6 +71,28 @@ size_t FindEnclosingBracket(const std::string& text, size_t first_pos,
|
||||
}
|
||||
}
|
||||
|
||||
absl::Status ParseArgsInsideBrackets(const std::string& text,
|
||||
size_t open_bracket_pos,
|
||||
size_t* close_bracket_pos,
|
||||
std::vector<std::string>* args) {
|
||||
*close_bracket_pos =
|
||||
FindEnclosingBracket(text, open_bracket_pos + 1, text[open_bracket_pos]);
|
||||
if (*close_bracket_pos == -1) {
|
||||
return absl::NotFoundError("Not found enclosing bracket");
|
||||
}
|
||||
std::string str_args = text.substr(open_bracket_pos + 1,
|
||||
*close_bracket_pos - open_bracket_pos - 2);
|
||||
std::vector<absl::string_view> words = absl::StrSplit(str_args, ',');
|
||||
args->reserve(words.size());
|
||||
for (const auto& word : words) {
|
||||
absl::string_view arg = absl::StripAsciiWhitespace(word);
|
||||
if (!arg.empty()) {
|
||||
args->push_back(std::string(arg));
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void ReplaceAllWords(const std::string& old_word, const std::string& new_word,
|
||||
std::string* str) {
|
||||
size_t position = str->find(old_word);
|
||||
@ -534,10 +557,10 @@ void Arguments::ResolveObjectNames(const std::string& object_name,
|
||||
}
|
||||
}
|
||||
|
||||
absl::Status Arguments::ResolveSelector(const std::string& object_name,
|
||||
const std::string& selector,
|
||||
const std::vector<std::string>& args,
|
||||
std::string* result) {
|
||||
absl::Status Arguments::ResolveSelector(
|
||||
const std::string& object_name, const std::string& selector,
|
||||
const std::vector<std::string>& args,
|
||||
const std::vector<std::string>& template_args, std::string* result) {
|
||||
const GPUObjectDescriptor* desc_ptr;
|
||||
AccessType access_type;
|
||||
if (auto it = object_refs_.find(object_name); it != object_refs_.end()) {
|
||||
@ -550,7 +573,8 @@ absl::Status Arguments::ResolveSelector(const std::string& object_name,
|
||||
return absl::NotFoundError(
|
||||
absl::StrCat("No object with name - ", object_name));
|
||||
}
|
||||
RETURN_IF_ERROR(desc_ptr->PerformSelector(selector, args, result));
|
||||
RETURN_IF_ERROR(
|
||||
desc_ptr->PerformSelector(selector, args, template_args, result));
|
||||
auto names = desc_ptr->GetGPUResources(access_type).GetNames();
|
||||
ResolveObjectNames(object_name, names, result);
|
||||
return absl::OkStatus();
|
||||
@ -570,32 +594,26 @@ absl::Status Arguments::ResolveSelectorsPass(std::string* code) {
|
||||
std::string selector_name = GetNextWord(*code, next_position);
|
||||
next_position += selector_name.size();
|
||||
next = (*code)[next_position];
|
||||
std::vector<std::string> template_args;
|
||||
if (next == '<') {
|
||||
size_t close_bracket_pos;
|
||||
RETURN_IF_ERROR(ParseArgsInsideBrackets(
|
||||
*code, next_position, &close_bracket_pos, &template_args));
|
||||
next_position = close_bracket_pos;
|
||||
next = (*code)[next_position];
|
||||
}
|
||||
if (next != '(') {
|
||||
return absl::NotFoundError(
|
||||
absl::StrCat("Expected ( after function ", selector_name, " call"));
|
||||
}
|
||||
next_position += 1;
|
||||
size_t bracket_pos = FindEnclosingBracket(*code, next_position, '(');
|
||||
if (bracket_pos == -1) {
|
||||
return absl::NotFoundError(
|
||||
absl::StrCat("Not found enclosing bracket for function ",
|
||||
selector_name, " call"));
|
||||
}
|
||||
std::string str_args =
|
||||
code->substr(next_position, bracket_pos - next_position - 1);
|
||||
std::vector<absl::string_view> words = absl::StrSplit(str_args, ',');
|
||||
std::vector<std::string> args;
|
||||
args.reserve(words.size());
|
||||
for (const auto& word : words) {
|
||||
absl::string_view arg = absl::StripAsciiWhitespace(word);
|
||||
if (!arg.empty()) {
|
||||
args.push_back(std::string(arg));
|
||||
}
|
||||
}
|
||||
size_t close_bracket_pos;
|
||||
RETURN_IF_ERROR(ParseArgsInsideBrackets(*code, next_position,
|
||||
&close_bracket_pos, &args));
|
||||
std::string patch;
|
||||
RETURN_IF_ERROR(
|
||||
ResolveSelector(object_name, selector_name, args, &patch));
|
||||
code->replace(arg_pos, bracket_pos - arg_pos, patch);
|
||||
RETURN_IF_ERROR(ResolveSelector(object_name, selector_name, args,
|
||||
template_args, &patch));
|
||||
code->replace(arg_pos, close_bracket_pos - arg_pos, patch);
|
||||
position = arg_pos + patch.size();
|
||||
} else {
|
||||
position = arg_pos + strlen(kArgsPrefix);
|
||||
|
@ -88,6 +88,7 @@ class Arguments {
|
||||
absl::Status ResolveSelector(const std::string& object_name,
|
||||
const std::string& selector,
|
||||
const std::vector<std::string>& args,
|
||||
const std::vector<std::string>& template_args,
|
||||
std::string* result);
|
||||
|
||||
void ResolveObjectNames(const std::string& object_name,
|
||||
|
@ -123,9 +123,10 @@ class GPUObjectDescriptor {
|
||||
return "";
|
||||
}
|
||||
|
||||
virtual absl::Status PerformSelector(const std::string& selector,
|
||||
const std::vector<std::string>& args,
|
||||
std::string* result) const {
|
||||
virtual absl::Status PerformSelector(
|
||||
const std::string& selector, const std::vector<std::string>& args,
|
||||
const std::vector<std::string>& template_args,
|
||||
std::string* result) const {
|
||||
*result = "";
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -38,16 +38,6 @@ std::string GetWinograd4x4To36Code(
|
||||
const OperationDef& op_def,
|
||||
const std::vector<ElementwiseOperation*>& linked_operations,
|
||||
Arguments* args) {
|
||||
TensorCodeGenerator src_tensor(
|
||||
"src_data",
|
||||
WHSBPoint{"src_size.x", "src_size.y", "src_size.z", "src_size.w"},
|
||||
op_def.src_tensors[0]);
|
||||
TensorCodeGenerator dst_tensor(
|
||||
"dst_data",
|
||||
WHSBPoint{"dst_size.x", "dst_size.y", "dst_size.z", "dst_size.w"},
|
||||
op_def.dst_tensors[0]);
|
||||
|
||||
const std::string batch_id = op_def.IsBatchSupported() ? "batch_id" : "";
|
||||
std::string c = GetCommonDefines(op_def.precision);
|
||||
|
||||
const auto src_tensor_type = op_def.src_tensors[0].storage_type;
|
||||
@ -80,23 +70,30 @@ std::string GetWinograd4x4To36Code(
|
||||
}
|
||||
c += "};\n";
|
||||
|
||||
std::string cl_type = accum_type == DataType::FLOAT16 ? "half" : "float";
|
||||
auto src_desc = absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]);
|
||||
src_desc->SetStateVar("ACCUM_FLT", cl_type);
|
||||
args->AddObjectRef("src_tensor", AccessType::READ, std::move(src_desc));
|
||||
args->AddObjectRef(
|
||||
"dst_tensor", AccessType::WRITE,
|
||||
absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]));
|
||||
args->AddInt("padding_x");
|
||||
args->AddInt("padding_y");
|
||||
args->AddInt("tiles_total");
|
||||
args->AddInt("tiles_x");
|
||||
|
||||
std::string linked_args = GetArgsDeclaration(linked_operations);
|
||||
if (linked_args[0] == ',') {
|
||||
linked_args[0] = ' ';
|
||||
}
|
||||
c += "__kernel void main_function(\n";
|
||||
c += src_tensor.GetDeclaration(AccessType::READ);
|
||||
c += GetArgsDeclaration(linked_operations);
|
||||
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
|
||||
c += " int4 src_size, \n";
|
||||
c += " int4 dst_size,\n ";
|
||||
c += linked_args;
|
||||
c += "$0) {\n";
|
||||
c += " int DST_X = get_global_id(0);\n";
|
||||
c += " int DST_Y = get_global_id(1);\n";
|
||||
c += " int DST_Z = get_global_id(2);\n";
|
||||
c += " if (DST_X >= args.tiles_total || DST_Y >= 6 || DST_Z >= dst_size.z) "
|
||||
"{\n";
|
||||
c += " if (DST_X >= args.tiles_total || DST_Y >= 6 || DST_Z >= "
|
||||
"args.dst_tensor.Slices()) {\n";
|
||||
c += " return; \n";
|
||||
c += " }\n";
|
||||
c += " int tile_x = (DST_X % args.tiles_x) * 4;\n";
|
||||
@ -114,19 +111,16 @@ std::string GetWinograd4x4To36Code(
|
||||
c += " bt_ar[5] = t1.y;\n";
|
||||
auto read_src = [&](const std::string& src, const std::string& xs) {
|
||||
if (is_image_buffer) {
|
||||
c += " ACCUM_FLT4 " + src + " = " +
|
||||
src_tensor.ReadAsType(accum_type, "src_a_" + xs + " + offset") +
|
||||
";\n";
|
||||
c += " ACCUM_FLT4 " + src +
|
||||
" = args.src_tensor.Read<ACCUM_FLT>(src_a_" + xs + " + offset);\n";
|
||||
} else if (is_buffer) {
|
||||
c += " ACCUM_FLT4 " + src + " = " +
|
||||
src_tensor.ReadAsType(accum_type, "src_a_" + xs + " + offset") +
|
||||
" * m" + xs + "_x;\n";
|
||||
c += " ACCUM_FLT4 " + src +
|
||||
" = args.src_tensor.Read<ACCUM_FLT>(src_a_" + xs + " + offset) * m" +
|
||||
xs + "_x;\n";
|
||||
} else {
|
||||
c += " ACCUM_FLT4 " + src + " = " +
|
||||
src_tensor.ReadAsTypeWHSB(accum_type,
|
||||
"tile_x + args.padding_x + " + xs, "yc",
|
||||
"DST_Z", batch_id) +
|
||||
";\n";
|
||||
c += " ACCUM_FLT4 " + src +
|
||||
" = args.src_tensor.Read<ACCUM_FLT>(tile_x + args.padding_x + " +
|
||||
xs + ", yc, DST_Z);\n";
|
||||
}
|
||||
};
|
||||
if (is_buffer || is_image_buffer) {
|
||||
@ -134,14 +128,17 @@ std::string GetWinograd4x4To36Code(
|
||||
const std::string xs = std::to_string(x);
|
||||
c += " int xc" + xs + " = tile_x + args.padding_x + " + xs + ";\n";
|
||||
c += " ACCUM_FLT m" + xs + "_x = (ACCUM_FLT)(xc" + xs + " >= 0 && xc" +
|
||||
xs + " < src_size.x);\n";
|
||||
xs + " < args.src_tensor.Width());\n";
|
||||
c += " bool inx" + xs + " = (xc" + xs + " >= 0 && xc" + xs +
|
||||
" < src_size.x);\n";
|
||||
c += " xc" + xs + " = clamp(xc" + xs + ", 0, src_size.x - 1);\n";
|
||||
c += " " + src_tensor.GetAddressWHSB("src_a_" + xs, "xc" + xs, "0",
|
||||
"DST_Z", batch_id);
|
||||
" < args.src_tensor.Width());\n";
|
||||
c += " xc" + xs + " = clamp(xc" + xs +
|
||||
", 0, args.src_tensor.Width() - 1);\n";
|
||||
c += " args.src_tensor.GetAddress(src_a_" + xs + ", xc" + xs +
|
||||
", 0, DST_Z);\n";
|
||||
if (is_image_buffer) {
|
||||
c += " src_a_" + xs + " = select(-src_size.x * src_size.y, src_a_" +
|
||||
c += " src_a_" + xs +
|
||||
" = select(-args.src_tensor.Width() * args.src_tensor.Height(), "
|
||||
"src_a_" +
|
||||
xs + ", inx" + xs + ");\n";
|
||||
}
|
||||
}
|
||||
@ -149,8 +146,8 @@ std::string GetWinograd4x4To36Code(
|
||||
c += " {\n";
|
||||
c += " int yc = tile_y + args.padding_y;\n";
|
||||
if (is_buffer || is_image_buffer) {
|
||||
c += " bool iny = (yc >= 0 && yc < src_size.y);\n";
|
||||
c += " int offset = select(0, yc * src_size.x, iny);\n";
|
||||
c += " bool iny = (yc >= 0 && yc < args.src_tensor.Height());\n";
|
||||
c += " int offset = select(0, yc * args.src_tensor.Width(), iny);\n";
|
||||
c += " ACCUM_FLT bt = bt_ar[0] * (ACCUM_FLT)(iny);\n";
|
||||
} else {
|
||||
c += " ACCUM_FLT bt = bt_ar[0];\n";
|
||||
@ -167,8 +164,8 @@ std::string GetWinograd4x4To36Code(
|
||||
c += " {\n";
|
||||
c += " int yc = tile_y + args.padding_y + (" + ys + ");\n";
|
||||
if (is_buffer || is_image_buffer) {
|
||||
c += " bool iny = (yc >= 0 && yc < src_size.y);\n";
|
||||
c += " int offset = select(0, yc * src_size.x, iny);\n";
|
||||
c += " bool iny = (yc >= 0 && yc < args.src_tensor.Height());\n";
|
||||
c += " int offset = select(0, yc * args.src_tensor.Width(), iny);\n";
|
||||
c += " ACCUM_FLT bt = bt_ar[" + ys + "] * (ACCUM_FLT)(iny);\n";
|
||||
} else {
|
||||
c += " ACCUM_FLT bt = bt_ar[" + ys + "];\n";
|
||||
@ -185,14 +182,14 @@ std::string GetWinograd4x4To36Code(
|
||||
c += " {\n";
|
||||
c += " FLT4 r0 = TO_FLT4(I0 + Bt[2] * I2 + Bt[4] * I4);\n";
|
||||
c += PostProcess(linked_operations, context);
|
||||
c += " " + dst_tensor.WriteWHSB("r0", "DST_X", "DST_Y", "DST_Z", batch_id);
|
||||
c += " args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
|
||||
c += " DST_Y++;\n";
|
||||
c += " }\n";
|
||||
c += " {\n";
|
||||
c += " FLT4 r0 = TO_FLT4(Bt[7] * I1 + Bt[8] * I2 + Bt[9] * I3 + Bt[10] * "
|
||||
"I4);\n";
|
||||
c += PostProcess(linked_operations, context);
|
||||
c += " " + dst_tensor.WriteWHSB("r0", "DST_X", "DST_Y", "DST_Z", batch_id);
|
||||
c += " args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
|
||||
c += " DST_Y++;\n";
|
||||
c += " }\n";
|
||||
c += " {\n";
|
||||
@ -200,7 +197,7 @@ std::string GetWinograd4x4To36Code(
|
||||
"* "
|
||||
"I4);\n";
|
||||
c += PostProcess(linked_operations, context);
|
||||
c += " " + dst_tensor.WriteWHSB("r0", "DST_X", "DST_Y", "DST_Z", batch_id);
|
||||
c += " args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
|
||||
c += " DST_Y++;\n";
|
||||
c += " }\n";
|
||||
c += " {\n";
|
||||
@ -208,7 +205,7 @@ std::string GetWinograd4x4To36Code(
|
||||
"* "
|
||||
"I4);\n";
|
||||
c += PostProcess(linked_operations, context);
|
||||
c += " " + dst_tensor.WriteWHSB("r0", "DST_X", "DST_Y", "DST_Z", batch_id);
|
||||
c += " args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
|
||||
c += " DST_Y++;\n";
|
||||
c += " }\n";
|
||||
c += " {\n";
|
||||
@ -216,13 +213,13 @@ std::string GetWinograd4x4To36Code(
|
||||
"* "
|
||||
"I4);\n";
|
||||
c += PostProcess(linked_operations, context);
|
||||
c += " " + dst_tensor.WriteWHSB("r0", "DST_X", "DST_Y", "DST_Z", batch_id);
|
||||
c += " args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
|
||||
c += " DST_Y++;\n";
|
||||
c += " }\n";
|
||||
c += " {\n";
|
||||
c += " FLT4 r0 = TO_FLT4(Bt[31] * I1 + Bt[33] * I3 + I5);\n";
|
||||
c += PostProcess(linked_operations, context);
|
||||
c += " " + dst_tensor.WriteWHSB("r0", "DST_X", "DST_Y", "DST_Z", batch_id);
|
||||
c += " args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
|
||||
c += " DST_Y++;\n";
|
||||
c += " }\n";
|
||||
c += "}\n";
|
||||
@ -443,16 +440,14 @@ absl::Status Winograd4x4To36::BindArguments() {
|
||||
const int tiles_y = DivideRoundUp(
|
||||
src_[0]->Height() + padding_.prepended.h + padding_.appended.h - 2, 4);
|
||||
const int tiles_total = tiles_x * tiles_y;
|
||||
RETURN_IF_ERROR(args_.SetObjectRef("src_tensor", src_[0]));
|
||||
RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", dst_[0]));
|
||||
RETURN_IF_ERROR(args_.SetInt("padding_x", -padding_.prepended.w));
|
||||
RETURN_IF_ERROR(args_.SetInt("padding_y", -padding_.prepended.h));
|
||||
RETURN_IF_ERROR(args_.SetInt("tiles_total", tiles_total));
|
||||
RETURN_IF_ERROR(args_.SetInt("tiles_x", tiles_x));
|
||||
kernel_.ResetBindingCounter();
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB()));
|
||||
RETURN_IF_ERROR(args_.Bind(kernel_.kernel(), kernel_.GetBindingCounter()));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -44,7 +44,7 @@ GPUResources TensorLinearDescriptor::GetGPUResources(
|
||||
|
||||
absl::Status TensorLinearDescriptor::PerformSelector(
|
||||
const std::string& selector, const std::vector<std::string>& args,
|
||||
std::string* result) const {
|
||||
const std::vector<std::string>& template_args, std::string* result) const {
|
||||
if (selector == "Length") {
|
||||
*result = "length";
|
||||
return absl::OkStatus();
|
||||
|
@ -57,6 +57,7 @@ struct TensorLinearDescriptor : public GPUObjectDescriptor {
|
||||
|
||||
absl::Status PerformSelector(const std::string& selector,
|
||||
const std::vector<std::string>& args,
|
||||
const std::vector<std::string>& template_args,
|
||||
std::string* result) const override;
|
||||
|
||||
GPUResources GetGPUResources(AccessType access_type) const override;
|
||||
|
@ -227,7 +227,7 @@ GPUResources TensorDescriptor::GetGPUResources(AccessType access_type) const {
|
||||
|
||||
absl::Status TensorDescriptor::PerformSelector(
|
||||
const std::string& selector, const std::vector<std::string>& args,
|
||||
std::string* result) const {
|
||||
const std::vector<std::string>& template_args, std::string* result) const {
|
||||
if (selector == "Width") {
|
||||
*result = "width";
|
||||
return absl::OkStatus();
|
||||
@ -255,9 +255,11 @@ absl::Status TensorDescriptor::PerformSelector(
|
||||
*result = "";
|
||||
return absl::OkStatus();
|
||||
} else if (selector == "Read") {
|
||||
return PerformReadSelector(args, result);
|
||||
return PerformReadSelector(args, template_args, result);
|
||||
} else if (selector == "Write") {
|
||||
return PerformWriteSelector(args, result);
|
||||
} else if (selector == "GetAddress") {
|
||||
return PerformGetAddressSelector(args, result);
|
||||
} else {
|
||||
return absl::NotFoundError(absl::StrCat(
|
||||
"TensorDescriptor don't have selector with name - ", selector));
|
||||
@ -265,7 +267,29 @@ absl::Status TensorDescriptor::PerformSelector(
|
||||
}
|
||||
|
||||
absl::Status TensorDescriptor::PerformReadSelector(
|
||||
const std::vector<std::string>& args, std::string* result) const {
|
||||
const std::vector<std::string>& args,
|
||||
const std::vector<std::string>& template_args, std::string* result) const {
|
||||
DataType read_as_type = data_type;
|
||||
if (!template_args.empty()) {
|
||||
if (template_args.size() != 1) {
|
||||
return absl::NotFoundError(
|
||||
"Unrecognized Read selector template arguments.");
|
||||
} else {
|
||||
RETURN_IF_ERROR(
|
||||
GetDataTypeFromTemplateArgs(template_args[0], &read_as_type));
|
||||
}
|
||||
}
|
||||
if (args.size() == 1) { // function overload for 1D linear types.
|
||||
if (storage_type == TensorStorageType::BUFFER ||
|
||||
storage_type == TensorStorageType::IMAGE_BUFFER) {
|
||||
*result = Read(read_as_type, args[0]);
|
||||
return absl::OkStatus();
|
||||
} else {
|
||||
return absl::InvalidArgumentError(
|
||||
"Read selector with single argument can be used only with linear "
|
||||
"storage types(BUFFER or IMAGE_BUFFER)");
|
||||
}
|
||||
}
|
||||
std::string xc;
|
||||
std::string yc;
|
||||
std::string zc;
|
||||
@ -276,24 +300,9 @@ absl::Status TensorDescriptor::PerformReadSelector(
|
||||
return absl::NotFoundError("Unrecognized Read selector");
|
||||
}
|
||||
|
||||
if (layout == Layout::HWC) {
|
||||
*result = Read(GetGlobalAddressNoDeclarationWHS(xc, yc, sc, storage_type));
|
||||
return absl::OkStatus();
|
||||
} else if (layout == Layout::BHWC) {
|
||||
*result =
|
||||
Read(GetGlobalAddressNoDeclarationWHSB(xc, yc, sc, bc, storage_type));
|
||||
return absl::OkStatus();
|
||||
} else if (layout == Layout::HWDC) {
|
||||
*result =
|
||||
Read(GetGlobalAddressNoDeclarationWHDS(xc, yc, zc, sc, storage_type));
|
||||
return absl::OkStatus();
|
||||
} else if (layout == Layout::BHWDC) {
|
||||
*result = Read(
|
||||
GetGlobalAddressNoDeclarationWHDSB(xc, yc, zc, sc, bc, storage_type));
|
||||
return absl::OkStatus();
|
||||
} else {
|
||||
return absl::NotFoundError("Unsupported layout");
|
||||
}
|
||||
*result =
|
||||
Read(read_as_type, GetGlobalAddressNoDeclaration(xc, yc, zc, sc, bc));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status TensorDescriptor::PerformWriteSelector(
|
||||
@ -307,29 +316,14 @@ absl::Status TensorDescriptor::PerformWriteSelector(
|
||||
if (args.size() < 2 || !parsed) {
|
||||
return absl::NotFoundError("Unrecognized Write selector");
|
||||
}
|
||||
|
||||
if (layout == Layout::HWC) {
|
||||
*result = Write(args[0],
|
||||
GetGlobalAddressNoDeclarationWHS(xc, yc, sc, storage_type));
|
||||
return absl::OkStatus();
|
||||
} else if (layout == Layout::BHWC) {
|
||||
*result = Write(args[0], GetGlobalAddressNoDeclarationWHSB(xc, yc, sc, bc,
|
||||
storage_type));
|
||||
return absl::OkStatus();
|
||||
} else if (layout == Layout::HWDC) {
|
||||
*result = Write(args[0], GetGlobalAddressNoDeclarationWHDS(xc, yc, zc, sc,
|
||||
storage_type));
|
||||
return absl::OkStatus();
|
||||
} else if (layout == Layout::BHWDC) {
|
||||
*result = Write(args[0], GetGlobalAddressNoDeclarationWHDSB(
|
||||
xc, yc, zc, sc, bc, storage_type));
|
||||
return absl::OkStatus();
|
||||
} else {
|
||||
return absl::NotFoundError("Unsupported layout");
|
||||
}
|
||||
*result = Write(args[0], GetGlobalAddressNoDeclaration(xc, yc, zc, sc, bc));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
std::string TensorDescriptor::Read(const std::string& global_address) const {
|
||||
std::string TensorDescriptor::Read(DataType read_as_type,
|
||||
const std::string& global_address) const {
|
||||
const std::string read_as =
|
||||
read_as_type == DataType::FLOAT16 ? "read_imageh" : "read_imagef";
|
||||
std::string image_type;
|
||||
if (storage_type == TensorStorageType::TEXTURE_2D ||
|
||||
storage_type == TensorStorageType::SINGLE_TEXTURE_2D) {
|
||||
@ -341,16 +335,22 @@ std::string TensorDescriptor::Read(const std::string& global_address) const {
|
||||
}
|
||||
switch (storage_type) {
|
||||
case TensorStorageType::BUFFER:
|
||||
return absl::StrCat("buffer[", global_address, "]");
|
||||
if (read_as_type == data_type) {
|
||||
return absl::StrCat("buffer[", global_address, "]");
|
||||
} else {
|
||||
const std::string conversion = read_as_type == DataType::FLOAT16
|
||||
? "convert_half4"
|
||||
: "convert_float4";
|
||||
return absl::StrCat(conversion, "(buffer[", global_address, "])");
|
||||
}
|
||||
case TensorStorageType::TEXTURE_2D:
|
||||
case TensorStorageType::TEXTURE_3D:
|
||||
case TensorStorageType::SINGLE_TEXTURE_2D:
|
||||
case TensorStorageType::TEXTURE_ARRAY:
|
||||
return absl::StrCat(GetReadImageFromDataType(data_type), "(", image_type,
|
||||
", smp_none, ", global_address, ")");
|
||||
return absl::StrCat(read_as, "(", image_type, ", smp_none, ",
|
||||
global_address, ")");
|
||||
case TensorStorageType::IMAGE_BUFFER:
|
||||
return absl::StrCat(GetReadImageFromDataType(data_type),
|
||||
"(image_buffer, ", global_address, ")");
|
||||
return absl::StrCat(read_as, "(image_buffer, ", global_address, ")");
|
||||
case TensorStorageType::UNKNOWN:
|
||||
return "";
|
||||
}
|
||||
@ -382,6 +382,85 @@ std::string TensorDescriptor::Write(const std::string& var_name,
|
||||
}
|
||||
}
|
||||
|
||||
absl::Status TensorDescriptor::PerformGetAddressSelector(
|
||||
const std::vector<std::string>& args, std::string* result) const {
|
||||
std::string xc;
|
||||
std::string yc;
|
||||
std::string zc;
|
||||
std::string sc;
|
||||
std::string bc;
|
||||
bool parsed = ParseCoordsFromArgs(args, 1, &xc, &yc, &zc, &sc, &bc);
|
||||
if (args.size() < 3 || !parsed) {
|
||||
return absl::NotFoundError("Unrecognized GetAddress selector");
|
||||
}
|
||||
|
||||
*result = DeclareAddress(args[0],
|
||||
GetGlobalAddressNoDeclaration(xc, yc, zc, sc, bc));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
std::string TensorDescriptor::DeclareAddress(const std::string& var_name,
|
||||
const std::string& address) const {
|
||||
return absl::StrCat(StorageTypeToAddressType(), " ", var_name, " = ", address,
|
||||
";");
|
||||
}
|
||||
|
||||
std::string TensorDescriptor::StorageTypeToAddressType() const {
|
||||
switch (storage_type) {
|
||||
case TensorStorageType::BUFFER:
|
||||
case TensorStorageType::IMAGE_BUFFER:
|
||||
return "int";
|
||||
case TensorStorageType::TEXTURE_2D:
|
||||
case TensorStorageType::SINGLE_TEXTURE_2D:
|
||||
return "int2";
|
||||
case TensorStorageType::TEXTURE_ARRAY:
|
||||
case TensorStorageType::TEXTURE_3D:
|
||||
return "int4";
|
||||
case TensorStorageType::UNKNOWN:
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
std::string TensorDescriptor::GetGlobalAddressNoDeclaration(
|
||||
const std::string& xc, const std::string& yc, const std::string& zc,
|
||||
const std::string& sc, const std::string& bc) const {
|
||||
if (layout == Layout::HWC) {
|
||||
return GetGlobalAddressNoDeclarationWHS(xc, yc, sc, storage_type);
|
||||
} else if (layout == Layout::BHWC) {
|
||||
return GetGlobalAddressNoDeclarationWHSB(xc, yc, sc, bc, storage_type);
|
||||
} else if (layout == Layout::HWDC) {
|
||||
return GetGlobalAddressNoDeclarationWHDS(xc, yc, zc, sc, storage_type);
|
||||
} else if (layout == Layout::BHWDC) {
|
||||
return GetGlobalAddressNoDeclarationWHDSB(xc, yc, zc, sc, bc, storage_type);
|
||||
} else {
|
||||
return "Unsupported layout";
|
||||
}
|
||||
}
|
||||
|
||||
absl::Status TensorDescriptor::GetDataTypeFromTemplateArgs(
|
||||
const std::string& template_arg, DataType* result) const {
|
||||
std::string read_type = template_arg;
|
||||
if (read_type == "FLT" || read_type == "ACCUM_FLT") {
|
||||
auto it = state_vars_.find(read_type);
|
||||
if (it == state_vars_.end()) {
|
||||
return absl::UnavailableError(absl::StrCat(
|
||||
"Read selector template argument ", read_type, " uninitialized."));
|
||||
} else {
|
||||
read_type = it->second;
|
||||
}
|
||||
}
|
||||
|
||||
if (read_type == "half") {
|
||||
*result = DataType::FLOAT16;
|
||||
} else if (read_type == "float") {
|
||||
*result = DataType::FLOAT32;
|
||||
} else {
|
||||
return absl::NotFoundError(absl::StrCat(
|
||||
"Unrecognized Read selector template argument - ", read_type));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
bool TensorDescriptor::HasAxis(Axis axis) const {
|
||||
if (axis == Axis::WIDTH || axis == Axis::HEIGHT || axis == Axis::CHANNELS) {
|
||||
return true;
|
||||
|
@ -65,6 +65,7 @@ struct TensorDescriptor : public GPUObjectDescriptor {
|
||||
|
||||
absl::Status PerformSelector(const std::string& selector,
|
||||
const std::vector<std::string>& args,
|
||||
const std::vector<std::string>& template_args,
|
||||
std::string* result) const override;
|
||||
|
||||
GPUResources GetGPUResources(AccessType access_type) const override;
|
||||
@ -79,16 +80,35 @@ struct TensorDescriptor : public GPUObjectDescriptor {
|
||||
Layout::UNKNOWN; // Supported layouts is HWC, BHWC, HWDC, BHWDC
|
||||
|
||||
private:
|
||||
absl::Status PerformReadSelector(const std::vector<std::string>& args,
|
||||
std::string* result) const;
|
||||
absl::Status PerformReadSelector(
|
||||
const std::vector<std::string>& args,
|
||||
const std::vector<std::string>& template_args, std::string* result) const;
|
||||
|
||||
absl::Status PerformGetAddressSelector(const std::vector<std::string>& args,
|
||||
std::string* result) const;
|
||||
|
||||
std::string DeclareAddress(const std::string& var_name,
|
||||
const std::string& address) const;
|
||||
|
||||
std::string StorageTypeToAddressType() const;
|
||||
|
||||
absl::Status PerformWriteSelector(const std::vector<std::string>& args,
|
||||
std::string* result) const;
|
||||
|
||||
std::string Read(const std::string& global_address) const;
|
||||
std::string Read(DataType read_as_type,
|
||||
const std::string& global_address) const;
|
||||
std::string Write(const std::string& var_name,
|
||||
const std::string& global_address) const;
|
||||
|
||||
absl::Status GetDataTypeFromTemplateArgs(const std::string& template_arg,
|
||||
DataType* result) const;
|
||||
|
||||
std::string GetGlobalAddressNoDeclaration(const std::string& xc,
|
||||
const std::string& yc,
|
||||
const std::string& zc,
|
||||
const std::string& sc,
|
||||
const std::string& bc) const;
|
||||
|
||||
bool ParseCoordsFromArgs(const std::vector<std::string>& args, int offset,
|
||||
std::string* xc, std::string* yc, std::string* zc,
|
||||
std::string* sc, std::string* bc) const;
|
||||
|
Loading…
x
Reference in New Issue
Block a user