Added syntax for template arguments in GPU object selectors.

Now we can write args.src_tensor.Read<float>(...).
Demonstrated in Winograd kernel.

PiperOrigin-RevId: 314576804
Change-Id: I27cb6e7d251b05d489c9c1dbbe4df8640df05651
This commit is contained in:
Raman Sarokin 2020-06-03 12:02:54 -07:00 committed by TensorFlower Gardener
parent f805f48644
commit 929398ef01
8 changed files with 242 additions and 127 deletions

View File

@ -45,6 +45,7 @@ size_t FindEnclosingBracket(const std::string& text, size_t first_pos,
{'(', ')'},
{'{', '}'},
{'[', ']'},
{'<', '>'},
};
char b_open = bracket;
auto it = brackets.find(b_open);
@ -70,6 +71,28 @@ size_t FindEnclosingBracket(const std::string& text, size_t first_pos,
}
}
// Extracts the comma-separated arguments enclosed by the bracket pair that
// opens at `open_bracket_pos` in `text`.
//
// Parameters:
//   text              - text being scanned.
//   open_bracket_pos  - index of the opening bracket; the character found
//                       there selects which bracket kind is matched.
//   close_bracket_pos - out: receives the value returned by
//                       FindEnclosingBracket (or the "not found" sentinel).
//                       NOTE(review): the length arithmetic below implies this
//                       is the index one past the closing bracket - confirm
//                       against FindEnclosingBracket.
//   args              - out: every non-empty, whitespace-stripped argument is
//                       appended, in order of appearance.
//
// Returns NotFoundError when no matching closing bracket exists (including
// the case where `open_bracket_pos` does not address a character of `text`),
// OkStatus otherwise.
//
// The split is a plain split on ',' - commas nested inside inner brackets are
// not treated specially.
absl::Status ParseArgsInsideBrackets(const std::string& text,
                                     size_t open_bracket_pos,
                                     size_t* close_bracket_pos,
                                     std::vector<std::string>* args) {
  // FindEnclosingBracket reports failure as -1 converted to size_t; spell the
  // sentinel out so the comparison does not rely on an implicit sign
  // conversion.
  constexpr size_t kNotFound = static_cast<size_t>(-1);

  // Indexing `text` past its end would be undefined behaviour; there cannot
  // be an enclosing bracket in that case, so report it the same way.
  if (open_bracket_pos >= text.size()) {
    *close_bracket_pos = kNotFound;
    return absl::NotFoundError("Not found enclosing bracket");
  }

  *close_bracket_pos =
      FindEnclosingBracket(text, open_bracket_pos + 1, text[open_bracket_pos]);
  if (*close_bracket_pos == kNotFound) {
    return absl::NotFoundError("Not found enclosing bracket");
  }

  // Text strictly between the two brackets.
  const std::string str_args = text.substr(
      open_bracket_pos + 1, *close_bracket_pos - open_bracket_pos - 2);
  const std::vector<absl::string_view> words = absl::StrSplit(str_args, ',');

  // `args` may already hold entries, so reserve relative to its current size.
  args->reserve(args->size() + words.size());
  for (const absl::string_view word : words) {
    const absl::string_view arg = absl::StripAsciiWhitespace(word);
    if (!arg.empty()) {
      args->emplace_back(arg);
    }
  }
  return absl::OkStatus();
}
void ReplaceAllWords(const std::string& old_word, const std::string& new_word,
std::string* str) {
size_t position = str->find(old_word);
@ -534,10 +557,10 @@ void Arguments::ResolveObjectNames(const std::string& object_name,
}
}
absl::Status Arguments::ResolveSelector(const std::string& object_name,
const std::string& selector,
const std::vector<std::string>& args,
std::string* result) {
absl::Status Arguments::ResolveSelector(
const std::string& object_name, const std::string& selector,
const std::vector<std::string>& args,
const std::vector<std::string>& template_args, std::string* result) {
const GPUObjectDescriptor* desc_ptr;
AccessType access_type;
if (auto it = object_refs_.find(object_name); it != object_refs_.end()) {
@ -550,7 +573,8 @@ absl::Status Arguments::ResolveSelector(const std::string& object_name,
return absl::NotFoundError(
absl::StrCat("No object with name - ", object_name));
}
RETURN_IF_ERROR(desc_ptr->PerformSelector(selector, args, result));
RETURN_IF_ERROR(
desc_ptr->PerformSelector(selector, args, template_args, result));
auto names = desc_ptr->GetGPUResources(access_type).GetNames();
ResolveObjectNames(object_name, names, result);
return absl::OkStatus();
@ -570,32 +594,26 @@ absl::Status Arguments::ResolveSelectorsPass(std::string* code) {
std::string selector_name = GetNextWord(*code, next_position);
next_position += selector_name.size();
next = (*code)[next_position];
std::vector<std::string> template_args;
if (next == '<') {
size_t close_bracket_pos;
RETURN_IF_ERROR(ParseArgsInsideBrackets(
*code, next_position, &close_bracket_pos, &template_args));
next_position = close_bracket_pos;
next = (*code)[next_position];
}
if (next != '(') {
return absl::NotFoundError(
absl::StrCat("Expected ( after function ", selector_name, " call"));
}
next_position += 1;
size_t bracket_pos = FindEnclosingBracket(*code, next_position, '(');
if (bracket_pos == -1) {
return absl::NotFoundError(
absl::StrCat("Not found enclosing bracket for function ",
selector_name, " call"));
}
std::string str_args =
code->substr(next_position, bracket_pos - next_position - 1);
std::vector<absl::string_view> words = absl::StrSplit(str_args, ',');
std::vector<std::string> args;
args.reserve(words.size());
for (const auto& word : words) {
absl::string_view arg = absl::StripAsciiWhitespace(word);
if (!arg.empty()) {
args.push_back(std::string(arg));
}
}
size_t close_bracket_pos;
RETURN_IF_ERROR(ParseArgsInsideBrackets(*code, next_position,
&close_bracket_pos, &args));
std::string patch;
RETURN_IF_ERROR(
ResolveSelector(object_name, selector_name, args, &patch));
code->replace(arg_pos, bracket_pos - arg_pos, patch);
RETURN_IF_ERROR(ResolveSelector(object_name, selector_name, args,
template_args, &patch));
code->replace(arg_pos, close_bracket_pos - arg_pos, patch);
position = arg_pos + patch.size();
} else {
position = arg_pos + strlen(kArgsPrefix);

View File

@ -88,6 +88,7 @@ class Arguments {
absl::Status ResolveSelector(const std::string& object_name,
const std::string& selector,
const std::vector<std::string>& args,
const std::vector<std::string>& template_args,
std::string* result);
void ResolveObjectNames(const std::string& object_name,

View File

@ -123,9 +123,10 @@ class GPUObjectDescriptor {
return "";
}
virtual absl::Status PerformSelector(const std::string& selector,
const std::vector<std::string>& args,
std::string* result) const {
virtual absl::Status PerformSelector(
const std::string& selector, const std::vector<std::string>& args,
const std::vector<std::string>& template_args,
std::string* result) const {
*result = "";
return absl::OkStatus();
}

View File

@ -38,16 +38,6 @@ std::string GetWinograd4x4To36Code(
const OperationDef& op_def,
const std::vector<ElementwiseOperation*>& linked_operations,
Arguments* args) {
TensorCodeGenerator src_tensor(
"src_data",
WHSBPoint{"src_size.x", "src_size.y", "src_size.z", "src_size.w"},
op_def.src_tensors[0]);
TensorCodeGenerator dst_tensor(
"dst_data",
WHSBPoint{"dst_size.x", "dst_size.y", "dst_size.z", "dst_size.w"},
op_def.dst_tensors[0]);
const std::string batch_id = op_def.IsBatchSupported() ? "batch_id" : "";
std::string c = GetCommonDefines(op_def.precision);
const auto src_tensor_type = op_def.src_tensors[0].storage_type;
@ -80,23 +70,30 @@ std::string GetWinograd4x4To36Code(
}
c += "};\n";
std::string cl_type = accum_type == DataType::FLOAT16 ? "half" : "float";
auto src_desc = absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]);
src_desc->SetStateVar("ACCUM_FLT", cl_type);
args->AddObjectRef("src_tensor", AccessType::READ, std::move(src_desc));
args->AddObjectRef(
"dst_tensor", AccessType::WRITE,
absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]));
args->AddInt("padding_x");
args->AddInt("padding_y");
args->AddInt("tiles_total");
args->AddInt("tiles_x");
std::string linked_args = GetArgsDeclaration(linked_operations);
if (linked_args[0] == ',') {
linked_args[0] = ' ';
}
c += "__kernel void main_function(\n";
c += src_tensor.GetDeclaration(AccessType::READ);
c += GetArgsDeclaration(linked_operations);
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
c += " int4 src_size, \n";
c += " int4 dst_size,\n ";
c += linked_args;
c += "$0) {\n";
c += " int DST_X = get_global_id(0);\n";
c += " int DST_Y = get_global_id(1);\n";
c += " int DST_Z = get_global_id(2);\n";
c += " if (DST_X >= args.tiles_total || DST_Y >= 6 || DST_Z >= dst_size.z) "
"{\n";
c += " if (DST_X >= args.tiles_total || DST_Y >= 6 || DST_Z >= "
"args.dst_tensor.Slices()) {\n";
c += " return; \n";
c += " }\n";
c += " int tile_x = (DST_X % args.tiles_x) * 4;\n";
@ -114,19 +111,16 @@ std::string GetWinograd4x4To36Code(
c += " bt_ar[5] = t1.y;\n";
auto read_src = [&](const std::string& src, const std::string& xs) {
if (is_image_buffer) {
c += " ACCUM_FLT4 " + src + " = " +
src_tensor.ReadAsType(accum_type, "src_a_" + xs + " + offset") +
";\n";
c += " ACCUM_FLT4 " + src +
" = args.src_tensor.Read<ACCUM_FLT>(src_a_" + xs + " + offset);\n";
} else if (is_buffer) {
c += " ACCUM_FLT4 " + src + " = " +
src_tensor.ReadAsType(accum_type, "src_a_" + xs + " + offset") +
" * m" + xs + "_x;\n";
c += " ACCUM_FLT4 " + src +
" = args.src_tensor.Read<ACCUM_FLT>(src_a_" + xs + " + offset) * m" +
xs + "_x;\n";
} else {
c += " ACCUM_FLT4 " + src + " = " +
src_tensor.ReadAsTypeWHSB(accum_type,
"tile_x + args.padding_x + " + xs, "yc",
"DST_Z", batch_id) +
";\n";
c += " ACCUM_FLT4 " + src +
" = args.src_tensor.Read<ACCUM_FLT>(tile_x + args.padding_x + " +
xs + ", yc, DST_Z);\n";
}
};
if (is_buffer || is_image_buffer) {
@ -134,14 +128,17 @@ std::string GetWinograd4x4To36Code(
const std::string xs = std::to_string(x);
c += " int xc" + xs + " = tile_x + args.padding_x + " + xs + ";\n";
c += " ACCUM_FLT m" + xs + "_x = (ACCUM_FLT)(xc" + xs + " >= 0 && xc" +
xs + " < src_size.x);\n";
xs + " < args.src_tensor.Width());\n";
c += " bool inx" + xs + " = (xc" + xs + " >= 0 && xc" + xs +
" < src_size.x);\n";
c += " xc" + xs + " = clamp(xc" + xs + ", 0, src_size.x - 1);\n";
c += " " + src_tensor.GetAddressWHSB("src_a_" + xs, "xc" + xs, "0",
"DST_Z", batch_id);
" < args.src_tensor.Width());\n";
c += " xc" + xs + " = clamp(xc" + xs +
", 0, args.src_tensor.Width() - 1);\n";
c += " args.src_tensor.GetAddress(src_a_" + xs + ", xc" + xs +
", 0, DST_Z);\n";
if (is_image_buffer) {
c += " src_a_" + xs + " = select(-src_size.x * src_size.y, src_a_" +
c += " src_a_" + xs +
" = select(-args.src_tensor.Width() * args.src_tensor.Height(), "
"src_a_" +
xs + ", inx" + xs + ");\n";
}
}
@ -149,8 +146,8 @@ std::string GetWinograd4x4To36Code(
c += " {\n";
c += " int yc = tile_y + args.padding_y;\n";
if (is_buffer || is_image_buffer) {
c += " bool iny = (yc >= 0 && yc < src_size.y);\n";
c += " int offset = select(0, yc * src_size.x, iny);\n";
c += " bool iny = (yc >= 0 && yc < args.src_tensor.Height());\n";
c += " int offset = select(0, yc * args.src_tensor.Width(), iny);\n";
c += " ACCUM_FLT bt = bt_ar[0] * (ACCUM_FLT)(iny);\n";
} else {
c += " ACCUM_FLT bt = bt_ar[0];\n";
@ -167,8 +164,8 @@ std::string GetWinograd4x4To36Code(
c += " {\n";
c += " int yc = tile_y + args.padding_y + (" + ys + ");\n";
if (is_buffer || is_image_buffer) {
c += " bool iny = (yc >= 0 && yc < src_size.y);\n";
c += " int offset = select(0, yc * src_size.x, iny);\n";
c += " bool iny = (yc >= 0 && yc < args.src_tensor.Height());\n";
c += " int offset = select(0, yc * args.src_tensor.Width(), iny);\n";
c += " ACCUM_FLT bt = bt_ar[" + ys + "] * (ACCUM_FLT)(iny);\n";
} else {
c += " ACCUM_FLT bt = bt_ar[" + ys + "];\n";
@ -185,14 +182,14 @@ std::string GetWinograd4x4To36Code(
c += " {\n";
c += " FLT4 r0 = TO_FLT4(I0 + Bt[2] * I2 + Bt[4] * I4);\n";
c += PostProcess(linked_operations, context);
c += " " + dst_tensor.WriteWHSB("r0", "DST_X", "DST_Y", "DST_Z", batch_id);
c += " args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
c += " DST_Y++;\n";
c += " }\n";
c += " {\n";
c += " FLT4 r0 = TO_FLT4(Bt[7] * I1 + Bt[8] * I2 + Bt[9] * I3 + Bt[10] * "
"I4);\n";
c += PostProcess(linked_operations, context);
c += " " + dst_tensor.WriteWHSB("r0", "DST_X", "DST_Y", "DST_Z", batch_id);
c += " args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
c += " DST_Y++;\n";
c += " }\n";
c += " {\n";
@ -200,7 +197,7 @@ std::string GetWinograd4x4To36Code(
"* "
"I4);\n";
c += PostProcess(linked_operations, context);
c += " " + dst_tensor.WriteWHSB("r0", "DST_X", "DST_Y", "DST_Z", batch_id);
c += " args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
c += " DST_Y++;\n";
c += " }\n";
c += " {\n";
@ -208,7 +205,7 @@ std::string GetWinograd4x4To36Code(
"* "
"I4);\n";
c += PostProcess(linked_operations, context);
c += " " + dst_tensor.WriteWHSB("r0", "DST_X", "DST_Y", "DST_Z", batch_id);
c += " args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
c += " DST_Y++;\n";
c += " }\n";
c += " {\n";
@ -216,13 +213,13 @@ std::string GetWinograd4x4To36Code(
"* "
"I4);\n";
c += PostProcess(linked_operations, context);
c += " " + dst_tensor.WriteWHSB("r0", "DST_X", "DST_Y", "DST_Z", batch_id);
c += " args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
c += " DST_Y++;\n";
c += " }\n";
c += " {\n";
c += " FLT4 r0 = TO_FLT4(Bt[31] * I1 + Bt[33] * I3 + I5);\n";
c += PostProcess(linked_operations, context);
c += " " + dst_tensor.WriteWHSB("r0", "DST_X", "DST_Y", "DST_Z", batch_id);
c += " args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
c += " DST_Y++;\n";
c += " }\n";
c += "}\n";
@ -443,16 +440,14 @@ absl::Status Winograd4x4To36::BindArguments() {
const int tiles_y = DivideRoundUp(
src_[0]->Height() + padding_.prepended.h + padding_.appended.h - 2, 4);
const int tiles_total = tiles_x * tiles_y;
RETURN_IF_ERROR(args_.SetObjectRef("src_tensor", src_[0]));
RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", dst_[0]));
RETURN_IF_ERROR(args_.SetInt("padding_x", -padding_.prepended.w));
RETURN_IF_ERROR(args_.SetInt("padding_y", -padding_.prepended.h));
RETURN_IF_ERROR(args_.SetInt("tiles_total", tiles_total));
RETURN_IF_ERROR(args_.SetInt("tiles_x", tiles_x));
kernel_.ResetBindingCounter();
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB()));
RETURN_IF_ERROR(args_.Bind(kernel_.kernel(), kernel_.GetBindingCounter()));
return absl::OkStatus();
}

View File

@ -44,7 +44,7 @@ GPUResources TensorLinearDescriptor::GetGPUResources(
absl::Status TensorLinearDescriptor::PerformSelector(
const std::string& selector, const std::vector<std::string>& args,
std::string* result) const {
const std::vector<std::string>& template_args, std::string* result) const {
if (selector == "Length") {
*result = "length";
return absl::OkStatus();

View File

@ -57,6 +57,7 @@ struct TensorLinearDescriptor : public GPUObjectDescriptor {
absl::Status PerformSelector(const std::string& selector,
const std::vector<std::string>& args,
const std::vector<std::string>& template_args,
std::string* result) const override;
GPUResources GetGPUResources(AccessType access_type) const override;

View File

@ -227,7 +227,7 @@ GPUResources TensorDescriptor::GetGPUResources(AccessType access_type) const {
absl::Status TensorDescriptor::PerformSelector(
const std::string& selector, const std::vector<std::string>& args,
std::string* result) const {
const std::vector<std::string>& template_args, std::string* result) const {
if (selector == "Width") {
*result = "width";
return absl::OkStatus();
@ -255,9 +255,11 @@ absl::Status TensorDescriptor::PerformSelector(
*result = "";
return absl::OkStatus();
} else if (selector == "Read") {
return PerformReadSelector(args, result);
return PerformReadSelector(args, template_args, result);
} else if (selector == "Write") {
return PerformWriteSelector(args, result);
} else if (selector == "GetAddress") {
return PerformGetAddressSelector(args, result);
} else {
return absl::NotFoundError(absl::StrCat(
"TensorDescriptor don't have selector with name - ", selector));
@ -265,7 +267,29 @@ absl::Status TensorDescriptor::PerformSelector(
}
absl::Status TensorDescriptor::PerformReadSelector(
const std::vector<std::string>& args, std::string* result) const {
const std::vector<std::string>& args,
const std::vector<std::string>& template_args, std::string* result) const {
DataType read_as_type = data_type;
if (!template_args.empty()) {
if (template_args.size() != 1) {
return absl::NotFoundError(
"Unrecognized Read selector template arguments.");
} else {
RETURN_IF_ERROR(
GetDataTypeFromTemplateArgs(template_args[0], &read_as_type));
}
}
if (args.size() == 1) { // function overload for 1D linear types.
if (storage_type == TensorStorageType::BUFFER ||
storage_type == TensorStorageType::IMAGE_BUFFER) {
*result = Read(read_as_type, args[0]);
return absl::OkStatus();
} else {
return absl::InvalidArgumentError(
"Read selector with single argument can be used only with linear "
"storage types(BUFFER or IMAGE_BUFFER)");
}
}
std::string xc;
std::string yc;
std::string zc;
@ -276,24 +300,9 @@ absl::Status TensorDescriptor::PerformReadSelector(
return absl::NotFoundError("Unrecognized Read selector");
}
if (layout == Layout::HWC) {
*result = Read(GetGlobalAddressNoDeclarationWHS(xc, yc, sc, storage_type));
return absl::OkStatus();
} else if (layout == Layout::BHWC) {
*result =
Read(GetGlobalAddressNoDeclarationWHSB(xc, yc, sc, bc, storage_type));
return absl::OkStatus();
} else if (layout == Layout::HWDC) {
*result =
Read(GetGlobalAddressNoDeclarationWHDS(xc, yc, zc, sc, storage_type));
return absl::OkStatus();
} else if (layout == Layout::BHWDC) {
*result = Read(
GetGlobalAddressNoDeclarationWHDSB(xc, yc, zc, sc, bc, storage_type));
return absl::OkStatus();
} else {
return absl::NotFoundError("Unsupported layout");
}
*result =
Read(read_as_type, GetGlobalAddressNoDeclaration(xc, yc, zc, sc, bc));
return absl::OkStatus();
}
absl::Status TensorDescriptor::PerformWriteSelector(
@ -307,29 +316,14 @@ absl::Status TensorDescriptor::PerformWriteSelector(
if (args.size() < 2 || !parsed) {
return absl::NotFoundError("Unrecognized Write selector");
}
if (layout == Layout::HWC) {
*result = Write(args[0],
GetGlobalAddressNoDeclarationWHS(xc, yc, sc, storage_type));
return absl::OkStatus();
} else if (layout == Layout::BHWC) {
*result = Write(args[0], GetGlobalAddressNoDeclarationWHSB(xc, yc, sc, bc,
storage_type));
return absl::OkStatus();
} else if (layout == Layout::HWDC) {
*result = Write(args[0], GetGlobalAddressNoDeclarationWHDS(xc, yc, zc, sc,
storage_type));
return absl::OkStatus();
} else if (layout == Layout::BHWDC) {
*result = Write(args[0], GetGlobalAddressNoDeclarationWHDSB(
xc, yc, zc, sc, bc, storage_type));
return absl::OkStatus();
} else {
return absl::NotFoundError("Unsupported layout");
}
*result = Write(args[0], GetGlobalAddressNoDeclaration(xc, yc, zc, sc, bc));
return absl::OkStatus();
}
std::string TensorDescriptor::Read(const std::string& global_address) const {
std::string TensorDescriptor::Read(DataType read_as_type,
const std::string& global_address) const {
const std::string read_as =
read_as_type == DataType::FLOAT16 ? "read_imageh" : "read_imagef";
std::string image_type;
if (storage_type == TensorStorageType::TEXTURE_2D ||
storage_type == TensorStorageType::SINGLE_TEXTURE_2D) {
@ -341,16 +335,22 @@ std::string TensorDescriptor::Read(const std::string& global_address) const {
}
switch (storage_type) {
case TensorStorageType::BUFFER:
return absl::StrCat("buffer[", global_address, "]");
if (read_as_type == data_type) {
return absl::StrCat("buffer[", global_address, "]");
} else {
const std::string conversion = read_as_type == DataType::FLOAT16
? "convert_half4"
: "convert_float4";
return absl::StrCat(conversion, "(buffer[", global_address, "])");
}
case TensorStorageType::TEXTURE_2D:
case TensorStorageType::TEXTURE_3D:
case TensorStorageType::SINGLE_TEXTURE_2D:
case TensorStorageType::TEXTURE_ARRAY:
return absl::StrCat(GetReadImageFromDataType(data_type), "(", image_type,
", smp_none, ", global_address, ")");
return absl::StrCat(read_as, "(", image_type, ", smp_none, ",
global_address, ")");
case TensorStorageType::IMAGE_BUFFER:
return absl::StrCat(GetReadImageFromDataType(data_type),
"(image_buffer, ", global_address, ")");
return absl::StrCat(read_as, "(image_buffer, ", global_address, ")");
case TensorStorageType::UNKNOWN:
return "";
}
@ -382,6 +382,85 @@ std::string TensorDescriptor::Write(const std::string& var_name,
}
}
// Handles the "GetAddress" selector.
//
// args[0] is the name of the address variable to declare; the coordinate
// expressions follow starting at index 1. On success `*result` holds a
// declaration statement that initialises that variable with the global
// address computed from the coordinates.
//
// Returns NotFoundError when fewer than three arguments are supplied or the
// coordinates cannot be parsed.
absl::Status TensorDescriptor::PerformGetAddressSelector(
    const std::vector<std::string>& args, std::string* result) const {
  std::string x_coord;
  std::string y_coord;
  std::string z_coord;
  std::string s_coord;
  std::string b_coord;
  // Offset 1 skips the variable name held in args[0].
  const bool coords_ok = ParseCoordsFromArgs(args, 1, &x_coord, &y_coord,
                                             &z_coord, &s_coord, &b_coord);
  if (!coords_ok || args.size() < 3) {
    return absl::NotFoundError("Unrecognized GetAddress selector");
  }

  const std::string address = GetGlobalAddressNoDeclaration(
      x_coord, y_coord, z_coord, s_coord, b_coord);
  *result = DeclareAddress(args[0], address);
  return absl::OkStatus();
}
// Builds a declaration statement of the form
//   "<address type> <var_name> = <address>;"
// where the address type is the one reported by StorageTypeToAddressType().
std::string TensorDescriptor::DeclareAddress(const std::string& var_name,
                                             const std::string& address) const {
  std::string declaration = StorageTypeToAddressType();
  declaration += " ";
  declaration += var_name;
  declaration += " = ";
  declaration += address;
  declaration += ";";
  return declaration;
}
// Returns the name of the type used to hold an address for this descriptor's
// storage type in generated code:
//   BUFFER, IMAGE_BUFFER            -> "int"
//   TEXTURE_2D, SINGLE_TEXTURE_2D   -> "int2"
//   TEXTURE_ARRAY, TEXTURE_3D       -> "int4"
//   UNKNOWN (or any other value)    -> ""
std::string TensorDescriptor::StorageTypeToAddressType() const {
  switch (storage_type) {
    case TensorStorageType::BUFFER:
    case TensorStorageType::IMAGE_BUFFER:
      return "int";
    case TensorStorageType::TEXTURE_2D:
    case TensorStorageType::SINGLE_TEXTURE_2D:
      return "int2";
    case TensorStorageType::TEXTURE_ARRAY:
    case TensorStorageType::TEXTURE_3D:
      return "int4";
    case TensorStorageType::UNKNOWN:
      return "";
  }
  // Every listed enumerator returns above. This fallback keeps the function
  // well-defined (no fall-off-the-end of a non-void function) if
  // `storage_type` ever holds a value outside the listed enumerators.
  return "";
}
// Dispatches to the layout-specific address builder for this descriptor and
// returns the resulting address expression. Only the coordinates relevant to
// the layout are forwarded:
//   HWC   -> (x, y, s)
//   BHWC  -> (x, y, s, b)
//   HWDC  -> (x, y, z, s)
//   BHWDC -> (x, y, z, s, b)
// For any other layout the literal text "Unsupported layout" is returned
// instead of an address; no error status is produced here.
std::string TensorDescriptor::GetGlobalAddressNoDeclaration(
    const std::string& xc, const std::string& yc, const std::string& zc,
    const std::string& sc, const std::string& bc) const {
  switch (layout) {
    case Layout::HWC:
      return GetGlobalAddressNoDeclarationWHS(xc, yc, sc, storage_type);
    case Layout::BHWC:
      return GetGlobalAddressNoDeclarationWHSB(xc, yc, sc, bc, storage_type);
    case Layout::HWDC:
      return GetGlobalAddressNoDeclarationWHDS(xc, yc, zc, sc, storage_type);
    case Layout::BHWDC:
      return GetGlobalAddressNoDeclarationWHDSB(xc, yc, zc, sc, bc,
                                                storage_type);
    default:
      return "Unsupported layout";
  }
}
// Translates a Read selector template argument into a DataType.
//
// "FLT" and "ACCUM_FLT" are first replaced by the value stored under that
// name in the descriptor's state variables; any other argument is used as
// written. The resulting name must be "half" (FLOAT16) or "float" (FLOAT32).
//
// Returns UnavailableError when the "FLT"/"ACCUM_FLT" state variable has not
// been set, NotFoundError for any other unrecognised type name. `*result` is
// written only on success.
absl::Status TensorDescriptor::GetDataTypeFromTemplateArgs(
    const std::string& template_arg, DataType* result) const {
  std::string type_name = template_arg;

  const bool is_state_alias =
      template_arg == "FLT" || template_arg == "ACCUM_FLT";
  if (is_state_alias) {
    const auto state_it = state_vars_.find(template_arg);
    if (state_it == state_vars_.end()) {
      return absl::UnavailableError(
          absl::StrCat("Read selector template argument ", template_arg,
                       " uninitialized."));
    }
    type_name = state_it->second;
  }

  if (type_name == "half") {
    *result = DataType::FLOAT16;
    return absl::OkStatus();
  }
  if (type_name == "float") {
    *result = DataType::FLOAT32;
    return absl::OkStatus();
  }
  return absl::NotFoundError(absl::StrCat(
      "Unrecognized Read selector template argument - ", type_name));
}
bool TensorDescriptor::HasAxis(Axis axis) const {
if (axis == Axis::WIDTH || axis == Axis::HEIGHT || axis == Axis::CHANNELS) {
return true;

View File

@ -65,6 +65,7 @@ struct TensorDescriptor : public GPUObjectDescriptor {
absl::Status PerformSelector(const std::string& selector,
const std::vector<std::string>& args,
const std::vector<std::string>& template_args,
std::string* result) const override;
GPUResources GetGPUResources(AccessType access_type) const override;
@ -79,16 +80,35 @@ struct TensorDescriptor : public GPUObjectDescriptor {
Layout::UNKNOWN; // Supported layouts is HWC, BHWC, HWDC, BHWDC
private:
absl::Status PerformReadSelector(const std::vector<std::string>& args,
std::string* result) const;
absl::Status PerformReadSelector(
const std::vector<std::string>& args,
const std::vector<std::string>& template_args, std::string* result) const;
absl::Status PerformGetAddressSelector(const std::vector<std::string>& args,
std::string* result) const;
std::string DeclareAddress(const std::string& var_name,
const std::string& address) const;
std::string StorageTypeToAddressType() const;
absl::Status PerformWriteSelector(const std::vector<std::string>& args,
std::string* result) const;
std::string Read(const std::string& global_address) const;
std::string Read(DataType read_as_type,
const std::string& global_address) const;
std::string Write(const std::string& var_name,
const std::string& global_address) const;
absl::Status GetDataTypeFromTemplateArgs(const std::string& template_arg,
DataType* result) const;
std::string GetGlobalAddressNoDeclaration(const std::string& xc,
const std::string& yc,
const std::string& zc,
const std::string& sc,
const std::string& bc) const;
bool ParseCoordsFromArgs(const std::vector<std::string>& args, int offset,
std::string* xc, std::string* yc, std::string* zc,
std::string* sc, std::string* bc) const;