From 929398ef01d2057b59be1484da0c8bc91a58dcaf Mon Sep 17 00:00:00 2001
From: Raman Sarokin
Date: Wed, 3 Jun 2020 12:02:54 -0700
Subject: [PATCH] Added syntax for template arguments for gpu objects
 selectors. Now we can write args.src_tensor.Read(...). Demonstrated in
 Winograd kernel.

PiperOrigin-RevId: 314576804
Change-Id: I27cb6e7d251b05d489c9c1dbbe4df8640df05651
---
 tensorflow/lite/delegates/gpu/cl/arguments.cc |  68 ++++---
 tensorflow/lite/delegates/gpu/cl/arguments.h  |   1 +
 tensorflow/lite/delegates/gpu/cl/gpu_object.h |   7 +-
 .../lite/delegates/gpu/cl/kernels/winograd.cc |  91 +++++----
 .../lite/delegates/gpu/cl/linear_storage.cc   |   2 +-
 .../lite/delegates/gpu/cl/linear_storage.h    |   1 +
 .../lite/delegates/gpu/cl/tensor_type.cc      | 173 +++++++++++++-----
 .../lite/delegates/gpu/cl/tensor_type.h       |  26 ++-
 8 files changed, 242 insertions(+), 127 deletions(-)

diff --git a/tensorflow/lite/delegates/gpu/cl/arguments.cc b/tensorflow/lite/delegates/gpu/cl/arguments.cc
index d5c0f73469b..0419b158efe 100644
--- a/tensorflow/lite/delegates/gpu/cl/arguments.cc
+++ b/tensorflow/lite/delegates/gpu/cl/arguments.cc
@@ -45,6 +45,7 @@ size_t FindEnclosingBracket(const std::string& text, size_t first_pos,
       {'(', ')'},
       {'{', '}'},
       {'[', ']'},
+      {'<', '>'},
   };
   char b_open = bracket;
   auto it = brackets.find(b_open);
@@ -70,6 +71,28 @@ size_t FindEnclosingBracket(const std::string& text, size_t first_pos,
   }
 }
 
+absl::Status ParseArgsInsideBrackets(const std::string& text,
+                                     size_t open_bracket_pos,
+                                     size_t* close_bracket_pos,
+                                     std::vector<std::string>* args) {
+  *close_bracket_pos =
+      FindEnclosingBracket(text, open_bracket_pos + 1, text[open_bracket_pos]);
+  if (*close_bracket_pos == -1) {
+    return absl::NotFoundError("Not found enclosing bracket");
+  }
+  std::string str_args = text.substr(open_bracket_pos + 1,
+                                     *close_bracket_pos - open_bracket_pos - 2);
+  std::vector<std::string> words = absl::StrSplit(str_args, ',');
+  args->reserve(words.size());
+  for (const auto& word : words) {
+    absl::string_view arg = absl::StripAsciiWhitespace(word);
+    if (!arg.empty()) {
+      args->push_back(std::string(arg));
+    }
+  }
+  return absl::OkStatus();
+}
+
 void ReplaceAllWords(const std::string& old_word, const std::string& new_word,
                      std::string* str) {
   size_t position = str->find(old_word);
@@ -534,10 +557,10 @@ void Arguments::ResolveObjectNames(const std::string& object_name,
   }
 }
 
-absl::Status Arguments::ResolveSelector(const std::string& object_name,
-                                        const std::string& selector,
-                                        const std::vector<std::string>& args,
-                                        std::string* result) {
+absl::Status Arguments::ResolveSelector(
+    const std::string& object_name, const std::string& selector,
+    const std::vector<std::string>& args,
+    const std::vector<std::string>& template_args, std::string* result) {
   const GPUObjectDescriptor* desc_ptr;
   AccessType access_type;
   if (auto it = object_refs_.find(object_name); it != object_refs_.end()) {
@@ -550,7 +573,8 @@ absl::Status Arguments::ResolveSelector(const std::string& object_name,
     return absl::NotFoundError(
         absl::StrCat("No object with name - ", object_name));
   }
-  RETURN_IF_ERROR(desc_ptr->PerformSelector(selector, args, result));
+  RETURN_IF_ERROR(
+      desc_ptr->PerformSelector(selector, args, template_args, result));
   auto names = desc_ptr->GetGPUResources(access_type).GetNames();
   ResolveObjectNames(object_name, names, result);
   return absl::OkStatus();
@@ -570,32 +594,26 @@ absl::Status Arguments::ResolveSelectorsPass(std::string* code) {
       std::string selector_name = GetNextWord(*code, next_position);
       next_position += selector_name.size();
       next = (*code)[next_position];
+      std::vector<std::string> template_args;
+      if (next == '<') {
+        size_t close_bracket_pos;
+        RETURN_IF_ERROR(ParseArgsInsideBrackets(
+            *code, next_position, &close_bracket_pos, &template_args));
+        next_position = close_bracket_pos;
+        next = (*code)[next_position];
+      }
       if (next != '(') {
         return absl::NotFoundError(
             absl::StrCat("Expected ( after function ", selector_name, " call"));
       }
-      next_position += 1;
-      size_t bracket_pos = FindEnclosingBracket(*code, next_position, '(');
-      if (bracket_pos == -1) {
-        return absl::NotFoundError(
-            absl::StrCat("Not found enclosing bracket for function ",
-                         selector_name, " call"));
-      }
-      std::string str_args =
-          code->substr(next_position, bracket_pos - next_position - 1);
-      std::vector<std::string> words = absl::StrSplit(str_args, ',');
       std::vector<std::string> args;
-      args.reserve(words.size());
-      for (const auto& word : words) {
-        absl::string_view arg = absl::StripAsciiWhitespace(word);
-        if (!arg.empty()) {
-          args.push_back(std::string(arg));
-        }
-      }
+      size_t close_bracket_pos;
+      RETURN_IF_ERROR(ParseArgsInsideBrackets(*code, next_position,
+                                              &close_bracket_pos, &args));
       std::string patch;
-      RETURN_IF_ERROR(
-          ResolveSelector(object_name, selector_name, args, &patch));
-      code->replace(arg_pos, bracket_pos - arg_pos, patch);
+      RETURN_IF_ERROR(ResolveSelector(object_name, selector_name, args,
+                                      template_args, &patch));
+      code->replace(arg_pos, close_bracket_pos - arg_pos, patch);
       position = arg_pos + patch.size();
     } else {
       position = arg_pos + strlen(kArgsPrefix);
diff --git a/tensorflow/lite/delegates/gpu/cl/arguments.h b/tensorflow/lite/delegates/gpu/cl/arguments.h
index 17ef4353de8..453ffcb56b2 100644
--- a/tensorflow/lite/delegates/gpu/cl/arguments.h
+++ b/tensorflow/lite/delegates/gpu/cl/arguments.h
@@ -88,6 +88,7 @@ class Arguments {
   absl::Status ResolveSelector(const std::string& object_name,
                                const std::string& selector,
                                const std::vector<std::string>& args,
+                               const std::vector<std::string>& template_args,
                                std::string* result);
 
   void ResolveObjectNames(const std::string& object_name,
diff --git a/tensorflow/lite/delegates/gpu/cl/gpu_object.h b/tensorflow/lite/delegates/gpu/cl/gpu_object.h
index b936c1b01ee..fec8999e2bc 100644
--- a/tensorflow/lite/delegates/gpu/cl/gpu_object.h
+++ b/tensorflow/lite/delegates/gpu/cl/gpu_object.h
@@ -123,9 +123,10 @@ class GPUObjectDescriptor {
     return "";
   }
 
-  virtual absl::Status PerformSelector(const std::string& selector,
-                                       const std::vector<std::string>& args,
-                                       std::string* result) const {
+  virtual absl::Status PerformSelector(
+      const std::string& selector, const std::vector<std::string>& args,
+      const std::vector<std::string>& template_args,
+      std::string* result) const {
     *result = "";
     return absl::OkStatus();
   }
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/winograd.cc b/tensorflow/lite/delegates/gpu/cl/kernels/winograd.cc
index eeb95ebaff7..66687c40c6a 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/winograd.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/winograd.cc
@@ -38,16 +38,6 @@ std::string GetWinograd4x4To36Code(
     const OperationDef& op_def,
     const std::vector<ElementwiseOperation*>& linked_operations,
     Arguments* args) {
-  TensorCodeGenerator src_tensor(
-      "src_data",
-      WHSBPoint{"src_size.x", "src_size.y", "src_size.z", "src_size.w"},
-      op_def.src_tensors[0]);
-  TensorCodeGenerator dst_tensor(
-      "dst_data",
-      WHSBPoint{"dst_size.x", "dst_size.y", "dst_size.z", "dst_size.w"},
-      op_def.dst_tensors[0]);
-
-  const std::string batch_id = op_def.IsBatchSupported() ? "batch_id" : "";
"batch_id" : ""; std::string c = GetCommonDefines(op_def.precision); const auto src_tensor_type = op_def.src_tensors[0].storage_type; @@ -80,23 +70,30 @@ std::string GetWinograd4x4To36Code( } c += "};\n"; + std::string cl_type = accum_type == DataType::FLOAT16 ? "half" : "float"; + auto src_desc = absl::make_unique(op_def.src_tensors[0]); + src_desc->SetStateVar("ACCUM_FLT", cl_type); + args->AddObjectRef("src_tensor", AccessType::READ, std::move(src_desc)); + args->AddObjectRef( + "dst_tensor", AccessType::WRITE, + absl::make_unique(op_def.dst_tensors[0])); args->AddInt("padding_x"); args->AddInt("padding_y"); args->AddInt("tiles_total"); args->AddInt("tiles_x"); + std::string linked_args = GetArgsDeclaration(linked_operations); + if (linked_args[0] == ',') { + linked_args[0] = ' '; + } c += "__kernel void main_function(\n"; - c += src_tensor.GetDeclaration(AccessType::READ); - c += GetArgsDeclaration(linked_operations); - c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n"; - c += " int4 src_size, \n"; - c += " int4 dst_size,\n "; + c += linked_args; c += "$0) {\n"; c += " int DST_X = get_global_id(0);\n"; c += " int DST_Y = get_global_id(1);\n"; c += " int DST_Z = get_global_id(2);\n"; - c += " if (DST_X >= args.tiles_total || DST_Y >= 6 || DST_Z >= dst_size.z) " - "{\n"; + c += " if (DST_X >= args.tiles_total || DST_Y >= 6 || DST_Z >= " + "args.dst_tensor.Slices()) {\n"; c += " return; \n"; c += " }\n"; c += " int tile_x = (DST_X % args.tiles_x) * 4;\n"; @@ -114,19 +111,16 @@ std::string GetWinograd4x4To36Code( c += " bt_ar[5] = t1.y;\n"; auto read_src = [&](const std::string& src, const std::string& xs) { if (is_image_buffer) { - c += " ACCUM_FLT4 " + src + " = " + - src_tensor.ReadAsType(accum_type, "src_a_" + xs + " + offset") + - ";\n"; + c += " ACCUM_FLT4 " + src + + " = args.src_tensor.Read(src_a_" + xs + " + offset);\n"; } else if (is_buffer) { - c += " ACCUM_FLT4 " + src + " = " + - src_tensor.ReadAsType(accum_type, "src_a_" + xs + " + offset") + - " * m" + xs + "_x;\n"; + c += " ACCUM_FLT4 " + src + + " = args.src_tensor.Read(src_a_" + xs + " + offset) * m" + + xs + "_x;\n"; } else { - c += " ACCUM_FLT4 " + src + " = " + - src_tensor.ReadAsTypeWHSB(accum_type, - "tile_x + args.padding_x + " + xs, "yc", - "DST_Z", batch_id) + - ";\n"; + c += " ACCUM_FLT4 " + src + + " = args.src_tensor.Read(tile_x + args.padding_x + " + + xs + ", yc, DST_Z);\n"; } }; if (is_buffer || is_image_buffer) { @@ -134,14 +128,17 @@ std::string GetWinograd4x4To36Code( const std::string xs = std::to_string(x); c += " int xc" + xs + " = tile_x + args.padding_x + " + xs + ";\n"; c += " ACCUM_FLT m" + xs + "_x = (ACCUM_FLT)(xc" + xs + " >= 0 && xc" + - xs + " < src_size.x);\n"; + xs + " < args.src_tensor.Width());\n"; c += " bool inx" + xs + " = (xc" + xs + " >= 0 && xc" + xs + - " < src_size.x);\n"; - c += " xc" + xs + " = clamp(xc" + xs + ", 0, src_size.x - 1);\n"; - c += " " + src_tensor.GetAddressWHSB("src_a_" + xs, "xc" + xs, "0", - "DST_Z", batch_id); + " < args.src_tensor.Width());\n"; + c += " xc" + xs + " = clamp(xc" + xs + + ", 0, args.src_tensor.Width() - 1);\n"; + c += " args.src_tensor.GetAddress(src_a_" + xs + ", xc" + xs + + ", 0, DST_Z);\n"; if (is_image_buffer) { - c += " src_a_" + xs + " = select(-src_size.x * src_size.y, src_a_" + + c += " src_a_" + xs + + " = select(-args.src_tensor.Width() * args.src_tensor.Height(), " + "src_a_" + xs + ", inx" + xs + ");\n"; } } @@ -149,8 +146,8 @@ std::string GetWinograd4x4To36Code( c += " {\n"; c += " int yc = tile_y + args.padding_y;\n"; if 
   if (is_buffer || is_image_buffer) {
-    c += "    bool iny = (yc >= 0 && yc < src_size.y);\n";
-    c += "    int offset = select(0, yc * src_size.x, iny);\n";
+    c += "    bool iny = (yc >= 0 && yc < args.src_tensor.Height());\n";
+    c += "    int offset = select(0, yc * args.src_tensor.Width(), iny);\n";
     c += "    ACCUM_FLT bt = bt_ar[0] * (ACCUM_FLT)(iny);\n";
   } else {
     c += "    ACCUM_FLT bt = bt_ar[0];\n";
@@ -167,8 +164,8 @@ std::string GetWinograd4x4To36Code(
     c += "  {\n";
     c += "    int yc = tile_y + args.padding_y + (" + ys + ");\n";
     if (is_buffer || is_image_buffer) {
-      c += "    bool iny = (yc >= 0 && yc < src_size.y);\n";
-      c += "    int offset = select(0, yc * src_size.x, iny);\n";
+      c += "    bool iny = (yc >= 0 && yc < args.src_tensor.Height());\n";
+      c += "    int offset = select(0, yc * args.src_tensor.Width(), iny);\n";
       c += "    ACCUM_FLT bt = bt_ar[" + ys + "] * (ACCUM_FLT)(iny);\n";
     } else {
       c += "    ACCUM_FLT bt = bt_ar[" + ys + "];\n";
@@ -185,14 +182,14 @@ std::string GetWinograd4x4To36Code(
   c += "  {\n";
   c += "    FLT4 r0 = TO_FLT4(I0 + Bt[2] * I2 + Bt[4] * I4);\n";
   c += PostProcess(linked_operations, context);
-  c += "  " + dst_tensor.WriteWHSB("r0", "DST_X", "DST_Y", "DST_Z", batch_id);
+  c += "  args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
   c += "    DST_Y++;\n";
   c += "  }\n";
   c += "  {\n";
   c += "    FLT4 r0 = TO_FLT4(Bt[7] * I1 + Bt[8] * I2 + Bt[9] * I3 + Bt[10] * "
        "I4);\n";
   c += PostProcess(linked_operations, context);
-  c += "  " + dst_tensor.WriteWHSB("r0", "DST_X", "DST_Y", "DST_Z", batch_id);
+  c += "  args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
   c += "    DST_Y++;\n";
   c += "  }\n";
   c += "  {\n";
@@ -200,7 +197,7 @@ std::string GetWinograd4x4To36Code(
        "* "
        "I4);\n";
   c += PostProcess(linked_operations, context);
-  c += "  " + dst_tensor.WriteWHSB("r0", "DST_X", "DST_Y", "DST_Z", batch_id);
+  c += "  args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
   c += "    DST_Y++;\n";
   c += "  }\n";
   c += "  {\n";
@@ -208,7 +205,7 @@ std::string GetWinograd4x4To36Code(
        "* "
        "I4);\n";
   c += PostProcess(linked_operations, context);
-  c += "  " + dst_tensor.WriteWHSB("r0", "DST_X", "DST_Y", "DST_Z", batch_id);
+  c += "  args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
   c += "    DST_Y++;\n";
   c += "  }\n";
   c += "  {\n";
@@ -216,13 +213,13 @@ std::string GetWinograd4x4To36Code(
        "* "
        "I4);\n";
   c += PostProcess(linked_operations, context);
-  c += "  " + dst_tensor.WriteWHSB("r0", "DST_X", "DST_Y", "DST_Z", batch_id);
+  c += "  args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
   c += "    DST_Y++;\n";
   c += "  }\n";
   c += "  {\n";
   c += "    FLT4 r0 = TO_FLT4(Bt[31] * I1 + Bt[33] * I3 + I5);\n";
   c += PostProcess(linked_operations, context);
-  c += "  " + dst_tensor.WriteWHSB("r0", "DST_X", "DST_Y", "DST_Z", batch_id);
+  c += "  args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
   c += "    DST_Y++;\n";
   c += "  }\n";
   c += "}\n";
@@ -443,16 +440,14 @@ absl::Status Winograd4x4To36::BindArguments() {
   const int tiles_y = DivideRoundUp(
       src_[0]->Height() + padding_.prepended.h + padding_.appended.h - 2, 4);
   const int tiles_total = tiles_x * tiles_y;
+  RETURN_IF_ERROR(args_.SetObjectRef("src_tensor", src_[0]));
+  RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", dst_[0]));
   RETURN_IF_ERROR(args_.SetInt("padding_x", -padding_.prepended.w));
   RETURN_IF_ERROR(args_.SetInt("padding_y", -padding_.prepended.h));
   RETURN_IF_ERROR(args_.SetInt("tiles_total", tiles_total));
   RETURN_IF_ERROR(args_.SetInt("tiles_x", tiles_x));
   kernel_.ResetBindingCounter();
-  RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
   RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
-  RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
-  RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB()));
-  RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB()));
   RETURN_IF_ERROR(args_.Bind(kernel_.kernel(), kernel_.GetBindingCounter()));
   return absl::OkStatus();
 }
diff --git a/tensorflow/lite/delegates/gpu/cl/linear_storage.cc b/tensorflow/lite/delegates/gpu/cl/linear_storage.cc
index 7edf83f57ff..84d91b9136e 100644
--- a/tensorflow/lite/delegates/gpu/cl/linear_storage.cc
+++ b/tensorflow/lite/delegates/gpu/cl/linear_storage.cc
@@ -44,7 +44,7 @@ GPUResources TensorLinearDescriptor::GetGPUResources(
 
 absl::Status TensorLinearDescriptor::PerformSelector(
     const std::string& selector, const std::vector<std::string>& args,
-    std::string* result) const {
+    const std::vector<std::string>& template_args, std::string* result) const {
   if (selector == "Length") {
     *result = "length";
     return absl::OkStatus();
diff --git a/tensorflow/lite/delegates/gpu/cl/linear_storage.h b/tensorflow/lite/delegates/gpu/cl/linear_storage.h
index 83c41e2c833..474c5652db2 100644
--- a/tensorflow/lite/delegates/gpu/cl/linear_storage.h
+++ b/tensorflow/lite/delegates/gpu/cl/linear_storage.h
@@ -57,6 +57,7 @@ struct TensorLinearDescriptor : public GPUObjectDescriptor {
 
   absl::Status PerformSelector(const std::string& selector,
                                const std::vector<std::string>& args,
+                               const std::vector<std::string>& template_args,
                                std::string* result) const override;
 
   GPUResources GetGPUResources(AccessType access_type) const override;
diff --git a/tensorflow/lite/delegates/gpu/cl/tensor_type.cc b/tensorflow/lite/delegates/gpu/cl/tensor_type.cc
index 11e1ca2ca07..0421e304afc 100644
--- a/tensorflow/lite/delegates/gpu/cl/tensor_type.cc
+++ b/tensorflow/lite/delegates/gpu/cl/tensor_type.cc
@@ -227,7 +227,7 @@ GPUResources TensorDescriptor::GetGPUResources(AccessType access_type) const {
 
 absl::Status TensorDescriptor::PerformSelector(
     const std::string& selector, const std::vector<std::string>& args,
-    std::string* result) const {
+    const std::vector<std::string>& template_args, std::string* result) const {
   if (selector == "Width") {
     *result = "width";
     return absl::OkStatus();
@@ -255,9 +255,11 @@ absl::Status TensorDescriptor::PerformSelector(
     *result = "";
     return absl::OkStatus();
   } else if (selector == "Read") {
-    return PerformReadSelector(args, result);
+    return PerformReadSelector(args, template_args, result);
   } else if (selector == "Write") {
     return PerformWriteSelector(args, result);
+  } else if (selector == "GetAddress") {
+    return PerformGetAddressSelector(args, result);
   } else {
     return absl::NotFoundError(absl::StrCat(
         "TensorDescriptor don't have selector with name - ", selector));
@@ -265,7 +267,29 @@ absl::Status TensorDescriptor::PerformSelector(
 }
 
 absl::Status TensorDescriptor::PerformReadSelector(
-    const std::vector<std::string>& args, std::string* result) const {
+    const std::vector<std::string>& args,
+    const std::vector<std::string>& template_args, std::string* result) const {
+  DataType read_as_type = data_type;
+  if (!template_args.empty()) {
+    if (template_args.size() != 1) {
+      return absl::NotFoundError(
+          "Unrecognized Read selector template arguments.");
+    } else {
+      RETURN_IF_ERROR(
+          GetDataTypeFromTemplateArgs(template_args[0], &read_as_type));
+    }
+  }
+  if (args.size() == 1) {  // function overload for 1D linear types.
+    if (storage_type == TensorStorageType::BUFFER ||
+        storage_type == TensorStorageType::IMAGE_BUFFER) {
+      *result = Read(read_as_type, args[0]);
+      return absl::OkStatus();
+    } else {
+      return absl::InvalidArgumentError(
+          "Read selector with single argument can be used only with linear "
+          "storage types(BUFFER or IMAGE_BUFFER)");
+    }
+  }
   std::string xc;
   std::string yc;
   std::string zc;
@@ -276,24 +300,9 @@ absl::Status TensorDescriptor::PerformReadSelector(
     return absl::NotFoundError("Unrecognized Read selector");
   }
 
-  if (layout == Layout::HWC) {
-    *result = Read(GetGlobalAddressNoDeclarationWHS(xc, yc, sc, storage_type));
-    return absl::OkStatus();
-  } else if (layout == Layout::BHWC) {
-    *result =
-        Read(GetGlobalAddressNoDeclarationWHSB(xc, yc, sc, bc, storage_type));
-    return absl::OkStatus();
-  } else if (layout == Layout::HWDC) {
-    *result =
-        Read(GetGlobalAddressNoDeclarationWHDS(xc, yc, zc, sc, storage_type));
-    return absl::OkStatus();
-  } else if (layout == Layout::BHWDC) {
-    *result = Read(
-        GetGlobalAddressNoDeclarationWHDSB(xc, yc, zc, sc, bc, storage_type));
-    return absl::OkStatus();
-  } else {
-    return absl::NotFoundError("Unsupported layout");
-  }
+  *result =
+      Read(read_as_type, GetGlobalAddressNoDeclaration(xc, yc, zc, sc, bc));
+  return absl::OkStatus();
 }
 
 absl::Status TensorDescriptor::PerformWriteSelector(
@@ -307,29 +316,14 @@ absl::Status TensorDescriptor::PerformWriteSelector(
   if (args.size() < 2 || !parsed) {
     return absl::NotFoundError("Unrecognized Write selector");
   }
-
-  if (layout == Layout::HWC) {
-    *result = Write(args[0],
-                    GetGlobalAddressNoDeclarationWHS(xc, yc, sc, storage_type));
-    return absl::OkStatus();
-  } else if (layout == Layout::BHWC) {
-    *result = Write(args[0], GetGlobalAddressNoDeclarationWHSB(xc, yc, sc, bc,
-                                                               storage_type));
-    return absl::OkStatus();
-  } else if (layout == Layout::HWDC) {
-    *result = Write(args[0], GetGlobalAddressNoDeclarationWHDS(xc, yc, zc, sc,
-                                                               storage_type));
-    return absl::OkStatus();
-  } else if (layout == Layout::BHWDC) {
-    *result = Write(args[0], GetGlobalAddressNoDeclarationWHDSB(
-                                 xc, yc, zc, sc, bc, storage_type));
-    return absl::OkStatus();
-  } else {
-    return absl::NotFoundError("Unsupported layout");
-  }
+  *result = Write(args[0], GetGlobalAddressNoDeclaration(xc, yc, zc, sc, bc));
+  return absl::OkStatus();
 }
 
-std::string TensorDescriptor::Read(const std::string& global_address) const {
+std::string TensorDescriptor::Read(DataType read_as_type,
+                                   const std::string& global_address) const {
+  const std::string read_as =
+      read_as_type == DataType::FLOAT16 ? "read_imageh" : "read_imagef";
   std::string image_type;
   if (storage_type == TensorStorageType::TEXTURE_2D ||
       storage_type == TensorStorageType::SINGLE_TEXTURE_2D) {
@@ -341,16 +335,22 @@ std::string TensorDescriptor::Read(const std::string& global_address) const {
   }
   switch (storage_type) {
     case TensorStorageType::BUFFER:
-      return absl::StrCat("buffer[", global_address, "]");
+      if (read_as_type == data_type) {
+        return absl::StrCat("buffer[", global_address, "]");
+      } else {
+        const std::string conversion = read_as_type == DataType::FLOAT16
"convert_half4" + : "convert_float4"; + return absl::StrCat(conversion, "(buffer[", global_address, "])"); + } case TensorStorageType::TEXTURE_2D: case TensorStorageType::TEXTURE_3D: case TensorStorageType::SINGLE_TEXTURE_2D: case TensorStorageType::TEXTURE_ARRAY: - return absl::StrCat(GetReadImageFromDataType(data_type), "(", image_type, - ", smp_none, ", global_address, ")"); + return absl::StrCat(read_as, "(", image_type, ", smp_none, ", + global_address, ")"); case TensorStorageType::IMAGE_BUFFER: - return absl::StrCat(GetReadImageFromDataType(data_type), - "(image_buffer, ", global_address, ")"); + return absl::StrCat(read_as, "(image_buffer, ", global_address, ")"); case TensorStorageType::UNKNOWN: return ""; } @@ -382,6 +382,85 @@ std::string TensorDescriptor::Write(const std::string& var_name, } } +absl::Status TensorDescriptor::PerformGetAddressSelector( + const std::vector& args, std::string* result) const { + std::string xc; + std::string yc; + std::string zc; + std::string sc; + std::string bc; + bool parsed = ParseCoordsFromArgs(args, 1, &xc, &yc, &zc, &sc, &bc); + if (args.size() < 3 || !parsed) { + return absl::NotFoundError("Unrecognized GetAddress selector"); + } + + *result = DeclareAddress(args[0], + GetGlobalAddressNoDeclaration(xc, yc, zc, sc, bc)); + return absl::OkStatus(); +} + +std::string TensorDescriptor::DeclareAddress(const std::string& var_name, + const std::string& address) const { + return absl::StrCat(StorageTypeToAddressType(), " ", var_name, " = ", address, + ";"); +} + +std::string TensorDescriptor::StorageTypeToAddressType() const { + switch (storage_type) { + case TensorStorageType::BUFFER: + case TensorStorageType::IMAGE_BUFFER: + return "int"; + case TensorStorageType::TEXTURE_2D: + case TensorStorageType::SINGLE_TEXTURE_2D: + return "int2"; + case TensorStorageType::TEXTURE_ARRAY: + case TensorStorageType::TEXTURE_3D: + return "int4"; + case TensorStorageType::UNKNOWN: + return ""; + } +} + +std::string TensorDescriptor::GetGlobalAddressNoDeclaration( + const std::string& xc, const std::string& yc, const std::string& zc, + const std::string& sc, const std::string& bc) const { + if (layout == Layout::HWC) { + return GetGlobalAddressNoDeclarationWHS(xc, yc, sc, storage_type); + } else if (layout == Layout::BHWC) { + return GetGlobalAddressNoDeclarationWHSB(xc, yc, sc, bc, storage_type); + } else if (layout == Layout::HWDC) { + return GetGlobalAddressNoDeclarationWHDS(xc, yc, zc, sc, storage_type); + } else if (layout == Layout::BHWDC) { + return GetGlobalAddressNoDeclarationWHDSB(xc, yc, zc, sc, bc, storage_type); + } else { + return "Unsupported layout"; + } +} + +absl::Status TensorDescriptor::GetDataTypeFromTemplateArgs( + const std::string& template_arg, DataType* result) const { + std::string read_type = template_arg; + if (read_type == "FLT" || read_type == "ACCUM_FLT") { + auto it = state_vars_.find(read_type); + if (it == state_vars_.end()) { + return absl::UnavailableError(absl::StrCat( + "Read selector template argument ", read_type, " uninitialized.")); + } else { + read_type = it->second; + } + } + + if (read_type == "half") { + *result = DataType::FLOAT16; + } else if (read_type == "float") { + *result = DataType::FLOAT32; + } else { + return absl::NotFoundError(absl::StrCat( + "Unrecognized Read selector template argument - ", read_type)); + } + return absl::OkStatus(); +} + bool TensorDescriptor::HasAxis(Axis axis) const { if (axis == Axis::WIDTH || axis == Axis::HEIGHT || axis == Axis::CHANNELS) { return true; diff --git 
diff --git a/tensorflow/lite/delegates/gpu/cl/tensor_type.h b/tensorflow/lite/delegates/gpu/cl/tensor_type.h
index 7d5ff888a85..71fa2d94880 100644
--- a/tensorflow/lite/delegates/gpu/cl/tensor_type.h
+++ b/tensorflow/lite/delegates/gpu/cl/tensor_type.h
@@ -65,6 +65,7 @@ struct TensorDescriptor : public GPUObjectDescriptor {
 
   absl::Status PerformSelector(const std::string& selector,
                                const std::vector<std::string>& args,
+                               const std::vector<std::string>& template_args,
                                std::string* result) const override;
 
   GPUResources GetGPUResources(AccessType access_type) const override;
@@ -79,16 +80,35 @@ struct TensorDescriptor : public GPUObjectDescriptor {
       Layout::UNKNOWN;  // Supported layouts is HWC, BHWC, HWDC, BHWDC
 
  private:
-  absl::Status PerformReadSelector(const std::vector<std::string>& args,
-                                   std::string* result) const;
+  absl::Status PerformReadSelector(
+      const std::vector<std::string>& args,
+      const std::vector<std::string>& template_args, std::string* result) const;
+
+  absl::Status PerformGetAddressSelector(const std::vector<std::string>& args,
+                                         std::string* result) const;
+
+  std::string DeclareAddress(const std::string& var_name,
+                             const std::string& address) const;
+
+  std::string StorageTypeToAddressType() const;
 
   absl::Status PerformWriteSelector(const std::vector<std::string>& args,
                                     std::string* result) const;
 
-  std::string Read(const std::string& global_address) const;
+  std::string Read(DataType read_as_type,
+                   const std::string& global_address) const;
   std::string Write(const std::string& var_name,
                     const std::string& global_address) const;
 
+  absl::Status GetDataTypeFromTemplateArgs(const std::string& template_arg,
+                                           DataType* result) const;
+
+  std::string GetGlobalAddressNoDeclaration(const std::string& xc,
+                                            const std::string& yc,
+                                            const std::string& zc,
+                                            const std::string& sc,
+                                            const std::string& bc) const;
+
   bool ParseCoordsFromArgs(const std::vector<std::string>& args, int offset,
                            std::string* xc, std::string* yc, std::string* zc,
                            std::string* sc, std::string* bc) const;
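
Note (not part of the patch): the sketch below illustrates the parsing idea behind ParseArgsInsideBrackets and the new '<...>' handling in ResolveSelectorsPass, i.e. splitting a call such as args.src_tensor.Read<ACCUM_FLT>(xc, yc, zc) into a selector name, optional template arguments, and call arguments. It is a standalone, hedged illustration using only the standard library; the function names (ParseCall-style helpers) and the simplified bracket matching are invented for the example and are not the TensorFlow Lite API, which instead reports errors through absl::Status and resolves the parsed pieces via GPUObjectDescriptor::PerformSelector.

// Standalone sketch (C++17): parse "Read<ACCUM_FLT>(xc, yc, zc)" into a
// selector name, template arguments, and arguments.
#include <cctype>
#include <iostream>
#include <string>
#include <vector>

namespace {

// Returns the position just past the bracket that closes text[open_pos],
// or std::string::npos if it is never closed. Handles '(' and '<' here.
size_t FindEnclosingBracket(const std::string& text, size_t open_pos) {
  const char open = text[open_pos];
  const char close = open == '(' ? ')' : '>';
  int depth = 0;
  for (size_t i = open_pos; i < text.size(); ++i) {
    if (text[i] == open) ++depth;
    if (text[i] == close) {
      --depth;
      if (depth == 0) return i + 1;
    }
  }
  return std::string::npos;
}

// Splits the comma-separated list between text[open_pos] and its closing
// bracket into trimmed, non-empty tokens; *end_pos is set just past the
// closing bracket (mirroring the close_bracket_pos convention in the patch).
std::vector<std::string> ParseArgsInsideBrackets(const std::string& text,
                                                 size_t open_pos,
                                                 size_t* end_pos) {
  *end_pos = FindEnclosingBracket(text, open_pos);
  std::vector<std::string> result;
  if (*end_pos == std::string::npos) return result;
  const std::string inner = text.substr(open_pos + 1, *end_pos - open_pos - 2);
  std::string token;
  for (char ch : inner) {
    if (ch == ',') {
      if (!token.empty()) result.push_back(token);
      token.clear();
    } else if (!std::isspace(static_cast<unsigned char>(ch))) {
      token += ch;
    }
  }
  if (!token.empty()) result.push_back(token);
  return result;
}

}  // namespace

int main() {
  const std::string call = "Read<ACCUM_FLT>(xc, yc, zc)";
  // The selector name runs up to the first '<' or '('.
  size_t pos = call.find_first_of("<(");
  const std::string selector = call.substr(0, pos);

  // Optional template arguments come first, then the regular arguments.
  std::vector<std::string> template_args;
  if (call[pos] == '<') {
    template_args = ParseArgsInsideBrackets(call, pos, &pos);
  }
  const std::vector<std::string> args = ParseArgsInsideBrackets(call, pos, &pos);

  std::cout << "selector: " << selector << "\n";
  for (const auto& t : template_args) std::cout << "template arg: " << t << "\n";
  for (const auto& a : args) std::cout << "arg: " << a << "\n";
  return 0;
}

In the real code path, the template argument ("ACCUM_FLT" above) is then mapped to a DataType by TensorDescriptor::GetDataTypeFromTemplateArgs, which is why the Winograd kernel sets the "ACCUM_FLT" state variable on the source tensor descriptor before emitting args.src_tensor.Read<ACCUM_FLT>(...).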