Added syntax for template arguments for gpu objects selectors.

Now we can write args.src_tensor.Read<float>(...).
Demonstrated in Winograd kernel.

PiperOrigin-RevId: 314576804
Change-Id: I27cb6e7d251b05d489c9c1dbbe4df8640df05651
This commit is contained in:
Raman Sarokin 2020-06-03 12:02:54 -07:00 committed by TensorFlower Gardener
parent f805f48644
commit 929398ef01
8 changed files with 242 additions and 127 deletions

View File

@ -45,6 +45,7 @@ size_t FindEnclosingBracket(const std::string& text, size_t first_pos,
{'(', ')'}, {'(', ')'},
{'{', '}'}, {'{', '}'},
{'[', ']'}, {'[', ']'},
{'<', '>'},
}; };
char b_open = bracket; char b_open = bracket;
auto it = brackets.find(b_open); auto it = brackets.find(b_open);
@ -70,6 +71,28 @@ size_t FindEnclosingBracket(const std::string& text, size_t first_pos,
} }
} }
absl::Status ParseArgsInsideBrackets(const std::string& text,
size_t open_bracket_pos,
size_t* close_bracket_pos,
std::vector<std::string>* args) {
*close_bracket_pos =
FindEnclosingBracket(text, open_bracket_pos + 1, text[open_bracket_pos]);
if (*close_bracket_pos == -1) {
return absl::NotFoundError("Not found enclosing bracket");
}
std::string str_args = text.substr(open_bracket_pos + 1,
*close_bracket_pos - open_bracket_pos - 2);
std::vector<absl::string_view> words = absl::StrSplit(str_args, ',');
args->reserve(words.size());
for (const auto& word : words) {
absl::string_view arg = absl::StripAsciiWhitespace(word);
if (!arg.empty()) {
args->push_back(std::string(arg));
}
}
return absl::OkStatus();
}
void ReplaceAllWords(const std::string& old_word, const std::string& new_word, void ReplaceAllWords(const std::string& old_word, const std::string& new_word,
std::string* str) { std::string* str) {
size_t position = str->find(old_word); size_t position = str->find(old_word);
@ -534,10 +557,10 @@ void Arguments::ResolveObjectNames(const std::string& object_name,
} }
} }
absl::Status Arguments::ResolveSelector(const std::string& object_name, absl::Status Arguments::ResolveSelector(
const std::string& selector, const std::string& object_name, const std::string& selector,
const std::vector<std::string>& args, const std::vector<std::string>& args,
std::string* result) { const std::vector<std::string>& template_args, std::string* result) {
const GPUObjectDescriptor* desc_ptr; const GPUObjectDescriptor* desc_ptr;
AccessType access_type; AccessType access_type;
if (auto it = object_refs_.find(object_name); it != object_refs_.end()) { if (auto it = object_refs_.find(object_name); it != object_refs_.end()) {
@ -550,7 +573,8 @@ absl::Status Arguments::ResolveSelector(const std::string& object_name,
return absl::NotFoundError( return absl::NotFoundError(
absl::StrCat("No object with name - ", object_name)); absl::StrCat("No object with name - ", object_name));
} }
RETURN_IF_ERROR(desc_ptr->PerformSelector(selector, args, result)); RETURN_IF_ERROR(
desc_ptr->PerformSelector(selector, args, template_args, result));
auto names = desc_ptr->GetGPUResources(access_type).GetNames(); auto names = desc_ptr->GetGPUResources(access_type).GetNames();
ResolveObjectNames(object_name, names, result); ResolveObjectNames(object_name, names, result);
return absl::OkStatus(); return absl::OkStatus();
@ -570,32 +594,26 @@ absl::Status Arguments::ResolveSelectorsPass(std::string* code) {
std::string selector_name = GetNextWord(*code, next_position); std::string selector_name = GetNextWord(*code, next_position);
next_position += selector_name.size(); next_position += selector_name.size();
next = (*code)[next_position]; next = (*code)[next_position];
std::vector<std::string> template_args;
if (next == '<') {
size_t close_bracket_pos;
RETURN_IF_ERROR(ParseArgsInsideBrackets(
*code, next_position, &close_bracket_pos, &template_args));
next_position = close_bracket_pos;
next = (*code)[next_position];
}
if (next != '(') { if (next != '(') {
return absl::NotFoundError( return absl::NotFoundError(
absl::StrCat("Expected ( after function ", selector_name, " call")); absl::StrCat("Expected ( after function ", selector_name, " call"));
} }
next_position += 1;
size_t bracket_pos = FindEnclosingBracket(*code, next_position, '(');
if (bracket_pos == -1) {
return absl::NotFoundError(
absl::StrCat("Not found enclosing bracket for function ",
selector_name, " call"));
}
std::string str_args =
code->substr(next_position, bracket_pos - next_position - 1);
std::vector<absl::string_view> words = absl::StrSplit(str_args, ',');
std::vector<std::string> args; std::vector<std::string> args;
args.reserve(words.size()); size_t close_bracket_pos;
for (const auto& word : words) { RETURN_IF_ERROR(ParseArgsInsideBrackets(*code, next_position,
absl::string_view arg = absl::StripAsciiWhitespace(word); &close_bracket_pos, &args));
if (!arg.empty()) {
args.push_back(std::string(arg));
}
}
std::string patch; std::string patch;
RETURN_IF_ERROR( RETURN_IF_ERROR(ResolveSelector(object_name, selector_name, args,
ResolveSelector(object_name, selector_name, args, &patch)); template_args, &patch));
code->replace(arg_pos, bracket_pos - arg_pos, patch); code->replace(arg_pos, close_bracket_pos - arg_pos, patch);
position = arg_pos + patch.size(); position = arg_pos + patch.size();
} else { } else {
position = arg_pos + strlen(kArgsPrefix); position = arg_pos + strlen(kArgsPrefix);

View File

@ -88,6 +88,7 @@ class Arguments {
absl::Status ResolveSelector(const std::string& object_name, absl::Status ResolveSelector(const std::string& object_name,
const std::string& selector, const std::string& selector,
const std::vector<std::string>& args, const std::vector<std::string>& args,
const std::vector<std::string>& template_args,
std::string* result); std::string* result);
void ResolveObjectNames(const std::string& object_name, void ResolveObjectNames(const std::string& object_name,

View File

@ -123,9 +123,10 @@ class GPUObjectDescriptor {
return ""; return "";
} }
virtual absl::Status PerformSelector(const std::string& selector, virtual absl::Status PerformSelector(
const std::vector<std::string>& args, const std::string& selector, const std::vector<std::string>& args,
std::string* result) const { const std::vector<std::string>& template_args,
std::string* result) const {
*result = ""; *result = "";
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -38,16 +38,6 @@ std::string GetWinograd4x4To36Code(
const OperationDef& op_def, const OperationDef& op_def,
const std::vector<ElementwiseOperation*>& linked_operations, const std::vector<ElementwiseOperation*>& linked_operations,
Arguments* args) { Arguments* args) {
TensorCodeGenerator src_tensor(
"src_data",
WHSBPoint{"src_size.x", "src_size.y", "src_size.z", "src_size.w"},
op_def.src_tensors[0]);
TensorCodeGenerator dst_tensor(
"dst_data",
WHSBPoint{"dst_size.x", "dst_size.y", "dst_size.z", "dst_size.w"},
op_def.dst_tensors[0]);
const std::string batch_id = op_def.IsBatchSupported() ? "batch_id" : "";
std::string c = GetCommonDefines(op_def.precision); std::string c = GetCommonDefines(op_def.precision);
const auto src_tensor_type = op_def.src_tensors[0].storage_type; const auto src_tensor_type = op_def.src_tensors[0].storage_type;
@ -80,23 +70,30 @@ std::string GetWinograd4x4To36Code(
} }
c += "};\n"; c += "};\n";
std::string cl_type = accum_type == DataType::FLOAT16 ? "half" : "float";
auto src_desc = absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]);
src_desc->SetStateVar("ACCUM_FLT", cl_type);
args->AddObjectRef("src_tensor", AccessType::READ, std::move(src_desc));
args->AddObjectRef(
"dst_tensor", AccessType::WRITE,
absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]));
args->AddInt("padding_x"); args->AddInt("padding_x");
args->AddInt("padding_y"); args->AddInt("padding_y");
args->AddInt("tiles_total"); args->AddInt("tiles_total");
args->AddInt("tiles_x"); args->AddInt("tiles_x");
std::string linked_args = GetArgsDeclaration(linked_operations);
if (linked_args[0] == ',') {
linked_args[0] = ' ';
}
c += "__kernel void main_function(\n"; c += "__kernel void main_function(\n";
c += src_tensor.GetDeclaration(AccessType::READ); c += linked_args;
c += GetArgsDeclaration(linked_operations);
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
c += " int4 src_size, \n";
c += " int4 dst_size,\n ";
c += "$0) {\n"; c += "$0) {\n";
c += " int DST_X = get_global_id(0);\n"; c += " int DST_X = get_global_id(0);\n";
c += " int DST_Y = get_global_id(1);\n"; c += " int DST_Y = get_global_id(1);\n";
c += " int DST_Z = get_global_id(2);\n"; c += " int DST_Z = get_global_id(2);\n";
c += " if (DST_X >= args.tiles_total || DST_Y >= 6 || DST_Z >= dst_size.z) " c += " if (DST_X >= args.tiles_total || DST_Y >= 6 || DST_Z >= "
"{\n"; "args.dst_tensor.Slices()) {\n";
c += " return; \n"; c += " return; \n";
c += " }\n"; c += " }\n";
c += " int tile_x = (DST_X % args.tiles_x) * 4;\n"; c += " int tile_x = (DST_X % args.tiles_x) * 4;\n";
@ -114,19 +111,16 @@ std::string GetWinograd4x4To36Code(
c += " bt_ar[5] = t1.y;\n"; c += " bt_ar[5] = t1.y;\n";
auto read_src = [&](const std::string& src, const std::string& xs) { auto read_src = [&](const std::string& src, const std::string& xs) {
if (is_image_buffer) { if (is_image_buffer) {
c += " ACCUM_FLT4 " + src + " = " + c += " ACCUM_FLT4 " + src +
src_tensor.ReadAsType(accum_type, "src_a_" + xs + " + offset") + " = args.src_tensor.Read<ACCUM_FLT>(src_a_" + xs + " + offset);\n";
";\n";
} else if (is_buffer) { } else if (is_buffer) {
c += " ACCUM_FLT4 " + src + " = " + c += " ACCUM_FLT4 " + src +
src_tensor.ReadAsType(accum_type, "src_a_" + xs + " + offset") + " = args.src_tensor.Read<ACCUM_FLT>(src_a_" + xs + " + offset) * m" +
" * m" + xs + "_x;\n"; xs + "_x;\n";
} else { } else {
c += " ACCUM_FLT4 " + src + " = " + c += " ACCUM_FLT4 " + src +
src_tensor.ReadAsTypeWHSB(accum_type, " = args.src_tensor.Read<ACCUM_FLT>(tile_x + args.padding_x + " +
"tile_x + args.padding_x + " + xs, "yc", xs + ", yc, DST_Z);\n";
"DST_Z", batch_id) +
";\n";
} }
}; };
if (is_buffer || is_image_buffer) { if (is_buffer || is_image_buffer) {
@ -134,14 +128,17 @@ std::string GetWinograd4x4To36Code(
const std::string xs = std::to_string(x); const std::string xs = std::to_string(x);
c += " int xc" + xs + " = tile_x + args.padding_x + " + xs + ";\n"; c += " int xc" + xs + " = tile_x + args.padding_x + " + xs + ";\n";
c += " ACCUM_FLT m" + xs + "_x = (ACCUM_FLT)(xc" + xs + " >= 0 && xc" + c += " ACCUM_FLT m" + xs + "_x = (ACCUM_FLT)(xc" + xs + " >= 0 && xc" +
xs + " < src_size.x);\n"; xs + " < args.src_tensor.Width());\n";
c += " bool inx" + xs + " = (xc" + xs + " >= 0 && xc" + xs + c += " bool inx" + xs + " = (xc" + xs + " >= 0 && xc" + xs +
" < src_size.x);\n"; " < args.src_tensor.Width());\n";
c += " xc" + xs + " = clamp(xc" + xs + ", 0, src_size.x - 1);\n"; c += " xc" + xs + " = clamp(xc" + xs +
c += " " + src_tensor.GetAddressWHSB("src_a_" + xs, "xc" + xs, "0", ", 0, args.src_tensor.Width() - 1);\n";
"DST_Z", batch_id); c += " args.src_tensor.GetAddress(src_a_" + xs + ", xc" + xs +
", 0, DST_Z);\n";
if (is_image_buffer) { if (is_image_buffer) {
c += " src_a_" + xs + " = select(-src_size.x * src_size.y, src_a_" + c += " src_a_" + xs +
" = select(-args.src_tensor.Width() * args.src_tensor.Height(), "
"src_a_" +
xs + ", inx" + xs + ");\n"; xs + ", inx" + xs + ");\n";
} }
} }
@ -149,8 +146,8 @@ std::string GetWinograd4x4To36Code(
c += " {\n"; c += " {\n";
c += " int yc = tile_y + args.padding_y;\n"; c += " int yc = tile_y + args.padding_y;\n";
if (is_buffer || is_image_buffer) { if (is_buffer || is_image_buffer) {
c += " bool iny = (yc >= 0 && yc < src_size.y);\n"; c += " bool iny = (yc >= 0 && yc < args.src_tensor.Height());\n";
c += " int offset = select(0, yc * src_size.x, iny);\n"; c += " int offset = select(0, yc * args.src_tensor.Width(), iny);\n";
c += " ACCUM_FLT bt = bt_ar[0] * (ACCUM_FLT)(iny);\n"; c += " ACCUM_FLT bt = bt_ar[0] * (ACCUM_FLT)(iny);\n";
} else { } else {
c += " ACCUM_FLT bt = bt_ar[0];\n"; c += " ACCUM_FLT bt = bt_ar[0];\n";
@ -167,8 +164,8 @@ std::string GetWinograd4x4To36Code(
c += " {\n"; c += " {\n";
c += " int yc = tile_y + args.padding_y + (" + ys + ");\n"; c += " int yc = tile_y + args.padding_y + (" + ys + ");\n";
if (is_buffer || is_image_buffer) { if (is_buffer || is_image_buffer) {
c += " bool iny = (yc >= 0 && yc < src_size.y);\n"; c += " bool iny = (yc >= 0 && yc < args.src_tensor.Height());\n";
c += " int offset = select(0, yc * src_size.x, iny);\n"; c += " int offset = select(0, yc * args.src_tensor.Width(), iny);\n";
c += " ACCUM_FLT bt = bt_ar[" + ys + "] * (ACCUM_FLT)(iny);\n"; c += " ACCUM_FLT bt = bt_ar[" + ys + "] * (ACCUM_FLT)(iny);\n";
} else { } else {
c += " ACCUM_FLT bt = bt_ar[" + ys + "];\n"; c += " ACCUM_FLT bt = bt_ar[" + ys + "];\n";
@ -185,14 +182,14 @@ std::string GetWinograd4x4To36Code(
c += " {\n"; c += " {\n";
c += " FLT4 r0 = TO_FLT4(I0 + Bt[2] * I2 + Bt[4] * I4);\n"; c += " FLT4 r0 = TO_FLT4(I0 + Bt[2] * I2 + Bt[4] * I4);\n";
c += PostProcess(linked_operations, context); c += PostProcess(linked_operations, context);
c += " " + dst_tensor.WriteWHSB("r0", "DST_X", "DST_Y", "DST_Z", batch_id); c += " args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
c += " DST_Y++;\n"; c += " DST_Y++;\n";
c += " }\n"; c += " }\n";
c += " {\n"; c += " {\n";
c += " FLT4 r0 = TO_FLT4(Bt[7] * I1 + Bt[8] * I2 + Bt[9] * I3 + Bt[10] * " c += " FLT4 r0 = TO_FLT4(Bt[7] * I1 + Bt[8] * I2 + Bt[9] * I3 + Bt[10] * "
"I4);\n"; "I4);\n";
c += PostProcess(linked_operations, context); c += PostProcess(linked_operations, context);
c += " " + dst_tensor.WriteWHSB("r0", "DST_X", "DST_Y", "DST_Z", batch_id); c += " args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
c += " DST_Y++;\n"; c += " DST_Y++;\n";
c += " }\n"; c += " }\n";
c += " {\n"; c += " {\n";
@ -200,7 +197,7 @@ std::string GetWinograd4x4To36Code(
"* " "* "
"I4);\n"; "I4);\n";
c += PostProcess(linked_operations, context); c += PostProcess(linked_operations, context);
c += " " + dst_tensor.WriteWHSB("r0", "DST_X", "DST_Y", "DST_Z", batch_id); c += " args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
c += " DST_Y++;\n"; c += " DST_Y++;\n";
c += " }\n"; c += " }\n";
c += " {\n"; c += " {\n";
@ -208,7 +205,7 @@ std::string GetWinograd4x4To36Code(
"* " "* "
"I4);\n"; "I4);\n";
c += PostProcess(linked_operations, context); c += PostProcess(linked_operations, context);
c += " " + dst_tensor.WriteWHSB("r0", "DST_X", "DST_Y", "DST_Z", batch_id); c += " args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
c += " DST_Y++;\n"; c += " DST_Y++;\n";
c += " }\n"; c += " }\n";
c += " {\n"; c += " {\n";
@ -216,13 +213,13 @@ std::string GetWinograd4x4To36Code(
"* " "* "
"I4);\n"; "I4);\n";
c += PostProcess(linked_operations, context); c += PostProcess(linked_operations, context);
c += " " + dst_tensor.WriteWHSB("r0", "DST_X", "DST_Y", "DST_Z", batch_id); c += " args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
c += " DST_Y++;\n"; c += " DST_Y++;\n";
c += " }\n"; c += " }\n";
c += " {\n"; c += " {\n";
c += " FLT4 r0 = TO_FLT4(Bt[31] * I1 + Bt[33] * I3 + I5);\n"; c += " FLT4 r0 = TO_FLT4(Bt[31] * I1 + Bt[33] * I3 + I5);\n";
c += PostProcess(linked_operations, context); c += PostProcess(linked_operations, context);
c += " " + dst_tensor.WriteWHSB("r0", "DST_X", "DST_Y", "DST_Z", batch_id); c += " args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
c += " DST_Y++;\n"; c += " DST_Y++;\n";
c += " }\n"; c += " }\n";
c += "}\n"; c += "}\n";
@ -443,16 +440,14 @@ absl::Status Winograd4x4To36::BindArguments() {
const int tiles_y = DivideRoundUp( const int tiles_y = DivideRoundUp(
src_[0]->Height() + padding_.prepended.h + padding_.appended.h - 2, 4); src_[0]->Height() + padding_.prepended.h + padding_.appended.h - 2, 4);
const int tiles_total = tiles_x * tiles_y; const int tiles_total = tiles_x * tiles_y;
RETURN_IF_ERROR(args_.SetObjectRef("src_tensor", src_[0]));
RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", dst_[0]));
RETURN_IF_ERROR(args_.SetInt("padding_x", -padding_.prepended.w)); RETURN_IF_ERROR(args_.SetInt("padding_x", -padding_.prepended.w));
RETURN_IF_ERROR(args_.SetInt("padding_y", -padding_.prepended.h)); RETURN_IF_ERROR(args_.SetInt("padding_y", -padding_.prepended.h));
RETURN_IF_ERROR(args_.SetInt("tiles_total", tiles_total)); RETURN_IF_ERROR(args_.SetInt("tiles_total", tiles_total));
RETURN_IF_ERROR(args_.SetInt("tiles_x", tiles_x)); RETURN_IF_ERROR(args_.SetInt("tiles_x", tiles_x));
kernel_.ResetBindingCounter(); kernel_.ResetBindingCounter();
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB()));
RETURN_IF_ERROR(args_.Bind(kernel_.kernel(), kernel_.GetBindingCounter())); RETURN_IF_ERROR(args_.Bind(kernel_.kernel(), kernel_.GetBindingCounter()));
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -44,7 +44,7 @@ GPUResources TensorLinearDescriptor::GetGPUResources(
absl::Status TensorLinearDescriptor::PerformSelector( absl::Status TensorLinearDescriptor::PerformSelector(
const std::string& selector, const std::vector<std::string>& args, const std::string& selector, const std::vector<std::string>& args,
std::string* result) const { const std::vector<std::string>& template_args, std::string* result) const {
if (selector == "Length") { if (selector == "Length") {
*result = "length"; *result = "length";
return absl::OkStatus(); return absl::OkStatus();

View File

@ -57,6 +57,7 @@ struct TensorLinearDescriptor : public GPUObjectDescriptor {
absl::Status PerformSelector(const std::string& selector, absl::Status PerformSelector(const std::string& selector,
const std::vector<std::string>& args, const std::vector<std::string>& args,
const std::vector<std::string>& template_args,
std::string* result) const override; std::string* result) const override;
GPUResources GetGPUResources(AccessType access_type) const override; GPUResources GetGPUResources(AccessType access_type) const override;

View File

@ -227,7 +227,7 @@ GPUResources TensorDescriptor::GetGPUResources(AccessType access_type) const {
absl::Status TensorDescriptor::PerformSelector( absl::Status TensorDescriptor::PerformSelector(
const std::string& selector, const std::vector<std::string>& args, const std::string& selector, const std::vector<std::string>& args,
std::string* result) const { const std::vector<std::string>& template_args, std::string* result) const {
if (selector == "Width") { if (selector == "Width") {
*result = "width"; *result = "width";
return absl::OkStatus(); return absl::OkStatus();
@ -255,9 +255,11 @@ absl::Status TensorDescriptor::PerformSelector(
*result = ""; *result = "";
return absl::OkStatus(); return absl::OkStatus();
} else if (selector == "Read") { } else if (selector == "Read") {
return PerformReadSelector(args, result); return PerformReadSelector(args, template_args, result);
} else if (selector == "Write") { } else if (selector == "Write") {
return PerformWriteSelector(args, result); return PerformWriteSelector(args, result);
} else if (selector == "GetAddress") {
return PerformGetAddressSelector(args, result);
} else { } else {
return absl::NotFoundError(absl::StrCat( return absl::NotFoundError(absl::StrCat(
"TensorDescriptor don't have selector with name - ", selector)); "TensorDescriptor don't have selector with name - ", selector));
@ -265,7 +267,29 @@ absl::Status TensorDescriptor::PerformSelector(
} }
absl::Status TensorDescriptor::PerformReadSelector( absl::Status TensorDescriptor::PerformReadSelector(
const std::vector<std::string>& args, std::string* result) const { const std::vector<std::string>& args,
const std::vector<std::string>& template_args, std::string* result) const {
DataType read_as_type = data_type;
if (!template_args.empty()) {
if (template_args.size() != 1) {
return absl::NotFoundError(
"Unrecognized Read selector template arguments.");
} else {
RETURN_IF_ERROR(
GetDataTypeFromTemplateArgs(template_args[0], &read_as_type));
}
}
if (args.size() == 1) { // function overload for 1D linear types.
if (storage_type == TensorStorageType::BUFFER ||
storage_type == TensorStorageType::IMAGE_BUFFER) {
*result = Read(read_as_type, args[0]);
return absl::OkStatus();
} else {
return absl::InvalidArgumentError(
"Read selector with single argument can be used only with linear "
"storage types(BUFFER or IMAGE_BUFFER)");
}
}
std::string xc; std::string xc;
std::string yc; std::string yc;
std::string zc; std::string zc;
@ -276,24 +300,9 @@ absl::Status TensorDescriptor::PerformReadSelector(
return absl::NotFoundError("Unrecognized Read selector"); return absl::NotFoundError("Unrecognized Read selector");
} }
if (layout == Layout::HWC) { *result =
*result = Read(GetGlobalAddressNoDeclarationWHS(xc, yc, sc, storage_type)); Read(read_as_type, GetGlobalAddressNoDeclaration(xc, yc, zc, sc, bc));
return absl::OkStatus(); return absl::OkStatus();
} else if (layout == Layout::BHWC) {
*result =
Read(GetGlobalAddressNoDeclarationWHSB(xc, yc, sc, bc, storage_type));
return absl::OkStatus();
} else if (layout == Layout::HWDC) {
*result =
Read(GetGlobalAddressNoDeclarationWHDS(xc, yc, zc, sc, storage_type));
return absl::OkStatus();
} else if (layout == Layout::BHWDC) {
*result = Read(
GetGlobalAddressNoDeclarationWHDSB(xc, yc, zc, sc, bc, storage_type));
return absl::OkStatus();
} else {
return absl::NotFoundError("Unsupported layout");
}
} }
absl::Status TensorDescriptor::PerformWriteSelector( absl::Status TensorDescriptor::PerformWriteSelector(
@ -307,29 +316,14 @@ absl::Status TensorDescriptor::PerformWriteSelector(
if (args.size() < 2 || !parsed) { if (args.size() < 2 || !parsed) {
return absl::NotFoundError("Unrecognized Write selector"); return absl::NotFoundError("Unrecognized Write selector");
} }
*result = Write(args[0], GetGlobalAddressNoDeclaration(xc, yc, zc, sc, bc));
if (layout == Layout::HWC) { return absl::OkStatus();
*result = Write(args[0],
GetGlobalAddressNoDeclarationWHS(xc, yc, sc, storage_type));
return absl::OkStatus();
} else if (layout == Layout::BHWC) {
*result = Write(args[0], GetGlobalAddressNoDeclarationWHSB(xc, yc, sc, bc,
storage_type));
return absl::OkStatus();
} else if (layout == Layout::HWDC) {
*result = Write(args[0], GetGlobalAddressNoDeclarationWHDS(xc, yc, zc, sc,
storage_type));
return absl::OkStatus();
} else if (layout == Layout::BHWDC) {
*result = Write(args[0], GetGlobalAddressNoDeclarationWHDSB(
xc, yc, zc, sc, bc, storage_type));
return absl::OkStatus();
} else {
return absl::NotFoundError("Unsupported layout");
}
} }
std::string TensorDescriptor::Read(const std::string& global_address) const { std::string TensorDescriptor::Read(DataType read_as_type,
const std::string& global_address) const {
const std::string read_as =
read_as_type == DataType::FLOAT16 ? "read_imageh" : "read_imagef";
std::string image_type; std::string image_type;
if (storage_type == TensorStorageType::TEXTURE_2D || if (storage_type == TensorStorageType::TEXTURE_2D ||
storage_type == TensorStorageType::SINGLE_TEXTURE_2D) { storage_type == TensorStorageType::SINGLE_TEXTURE_2D) {
@ -341,16 +335,22 @@ std::string TensorDescriptor::Read(const std::string& global_address) const {
} }
switch (storage_type) { switch (storage_type) {
case TensorStorageType::BUFFER: case TensorStorageType::BUFFER:
return absl::StrCat("buffer[", global_address, "]"); if (read_as_type == data_type) {
return absl::StrCat("buffer[", global_address, "]");
} else {
const std::string conversion = read_as_type == DataType::FLOAT16
? "convert_half4"
: "convert_float4";
return absl::StrCat(conversion, "(buffer[", global_address, "])");
}
case TensorStorageType::TEXTURE_2D: case TensorStorageType::TEXTURE_2D:
case TensorStorageType::TEXTURE_3D: case TensorStorageType::TEXTURE_3D:
case TensorStorageType::SINGLE_TEXTURE_2D: case TensorStorageType::SINGLE_TEXTURE_2D:
case TensorStorageType::TEXTURE_ARRAY: case TensorStorageType::TEXTURE_ARRAY:
return absl::StrCat(GetReadImageFromDataType(data_type), "(", image_type, return absl::StrCat(read_as, "(", image_type, ", smp_none, ",
", smp_none, ", global_address, ")"); global_address, ")");
case TensorStorageType::IMAGE_BUFFER: case TensorStorageType::IMAGE_BUFFER:
return absl::StrCat(GetReadImageFromDataType(data_type), return absl::StrCat(read_as, "(image_buffer, ", global_address, ")");
"(image_buffer, ", global_address, ")");
case TensorStorageType::UNKNOWN: case TensorStorageType::UNKNOWN:
return ""; return "";
} }
@ -382,6 +382,85 @@ std::string TensorDescriptor::Write(const std::string& var_name,
} }
} }
absl::Status TensorDescriptor::PerformGetAddressSelector(
const std::vector<std::string>& args, std::string* result) const {
std::string xc;
std::string yc;
std::string zc;
std::string sc;
std::string bc;
bool parsed = ParseCoordsFromArgs(args, 1, &xc, &yc, &zc, &sc, &bc);
if (args.size() < 3 || !parsed) {
return absl::NotFoundError("Unrecognized GetAddress selector");
}
*result = DeclareAddress(args[0],
GetGlobalAddressNoDeclaration(xc, yc, zc, sc, bc));
return absl::OkStatus();
}
std::string TensorDescriptor::DeclareAddress(const std::string& var_name,
const std::string& address) const {
return absl::StrCat(StorageTypeToAddressType(), " ", var_name, " = ", address,
";");
}
std::string TensorDescriptor::StorageTypeToAddressType() const {
switch (storage_type) {
case TensorStorageType::BUFFER:
case TensorStorageType::IMAGE_BUFFER:
return "int";
case TensorStorageType::TEXTURE_2D:
case TensorStorageType::SINGLE_TEXTURE_2D:
return "int2";
case TensorStorageType::TEXTURE_ARRAY:
case TensorStorageType::TEXTURE_3D:
return "int4";
case TensorStorageType::UNKNOWN:
return "";
}
}
std::string TensorDescriptor::GetGlobalAddressNoDeclaration(
const std::string& xc, const std::string& yc, const std::string& zc,
const std::string& sc, const std::string& bc) const {
if (layout == Layout::HWC) {
return GetGlobalAddressNoDeclarationWHS(xc, yc, sc, storage_type);
} else if (layout == Layout::BHWC) {
return GetGlobalAddressNoDeclarationWHSB(xc, yc, sc, bc, storage_type);
} else if (layout == Layout::HWDC) {
return GetGlobalAddressNoDeclarationWHDS(xc, yc, zc, sc, storage_type);
} else if (layout == Layout::BHWDC) {
return GetGlobalAddressNoDeclarationWHDSB(xc, yc, zc, sc, bc, storage_type);
} else {
return "Unsupported layout";
}
}
absl::Status TensorDescriptor::GetDataTypeFromTemplateArgs(
const std::string& template_arg, DataType* result) const {
std::string read_type = template_arg;
if (read_type == "FLT" || read_type == "ACCUM_FLT") {
auto it = state_vars_.find(read_type);
if (it == state_vars_.end()) {
return absl::UnavailableError(absl::StrCat(
"Read selector template argument ", read_type, " uninitialized."));
} else {
read_type = it->second;
}
}
if (read_type == "half") {
*result = DataType::FLOAT16;
} else if (read_type == "float") {
*result = DataType::FLOAT32;
} else {
return absl::NotFoundError(absl::StrCat(
"Unrecognized Read selector template argument - ", read_type));
}
return absl::OkStatus();
}
bool TensorDescriptor::HasAxis(Axis axis) const { bool TensorDescriptor::HasAxis(Axis axis) const {
if (axis == Axis::WIDTH || axis == Axis::HEIGHT || axis == Axis::CHANNELS) { if (axis == Axis::WIDTH || axis == Axis::HEIGHT || axis == Axis::CHANNELS) {
return true; return true;

View File

@ -65,6 +65,7 @@ struct TensorDescriptor : public GPUObjectDescriptor {
absl::Status PerformSelector(const std::string& selector, absl::Status PerformSelector(const std::string& selector,
const std::vector<std::string>& args, const std::vector<std::string>& args,
const std::vector<std::string>& template_args,
std::string* result) const override; std::string* result) const override;
GPUResources GetGPUResources(AccessType access_type) const override; GPUResources GetGPUResources(AccessType access_type) const override;
@ -79,16 +80,35 @@ struct TensorDescriptor : public GPUObjectDescriptor {
Layout::UNKNOWN; // Supported layouts is HWC, BHWC, HWDC, BHWDC Layout::UNKNOWN; // Supported layouts is HWC, BHWC, HWDC, BHWDC
private: private:
absl::Status PerformReadSelector(const std::vector<std::string>& args, absl::Status PerformReadSelector(
std::string* result) const; const std::vector<std::string>& args,
const std::vector<std::string>& template_args, std::string* result) const;
absl::Status PerformGetAddressSelector(const std::vector<std::string>& args,
std::string* result) const;
std::string DeclareAddress(const std::string& var_name,
const std::string& address) const;
std::string StorageTypeToAddressType() const;
absl::Status PerformWriteSelector(const std::vector<std::string>& args, absl::Status PerformWriteSelector(const std::vector<std::string>& args,
std::string* result) const; std::string* result) const;
std::string Read(const std::string& global_address) const; std::string Read(DataType read_as_type,
const std::string& global_address) const;
std::string Write(const std::string& var_name, std::string Write(const std::string& var_name,
const std::string& global_address) const; const std::string& global_address) const;
absl::Status GetDataTypeFromTemplateArgs(const std::string& template_arg,
DataType* result) const;
std::string GetGlobalAddressNoDeclaration(const std::string& xc,
const std::string& yc,
const std::string& zc,
const std::string& sc,
const std::string& bc) const;
bool ParseCoordsFromArgs(const std::vector<std::string>& args, int offset, bool ParseCoordsFromArgs(const std::vector<std::string>& args, int offset,
std::string* xc, std::string* yc, std::string* zc, std::string* xc, std::string* yc, std::string* zc,
std::string* sc, std::string* bc) const; std::string* sc, std::string* bc) const;