diff --git a/tensorflow/lite/delegates/gpu/cl/BUILD b/tensorflow/lite/delegates/gpu/cl/BUILD index 95b20bc6e81..d37f666f8a6 100644 --- a/tensorflow/lite/delegates/gpu/cl/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/BUILD @@ -54,6 +54,23 @@ cc_library( ], ) +cc_test( + name = "arguments_test", + srcs = ["arguments_test.cc"], + linkstatic = True, + tags = tf_gpu_tests_tags() + [ + "linux", + "local", + ], + deps = [ + ":arguments", + ":gpu_object", + "//tensorflow/lite/delegates/gpu/common:data_type", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "buffer", srcs = ["buffer.cc"], @@ -354,6 +371,7 @@ cc_library( hdrs = ["linear_storage.h"], deps = [ ":buffer", + ":gpu_object", ":opencl_wrapper", ":tensor_type", ":texture2d", diff --git a/tensorflow/lite/delegates/gpu/cl/arguments.cc b/tensorflow/lite/delegates/gpu/cl/arguments.cc index bdfae935f28..7b28ee215da 100644 --- a/tensorflow/lite/delegates/gpu/cl/arguments.cc +++ b/tensorflow/lite/delegates/gpu/cl/arguments.cc @@ -17,6 +17,8 @@ limitations under the License. #include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/str_split.h" #include "tensorflow/lite/delegates/gpu/common/status.h" namespace tflite { @@ -36,6 +38,55 @@ std::string GetNextWord(const std::string& code, size_t first_position) { } return code.substr(first_position, pos - first_position); } + +size_t FindEnclosingBracket(const std::string& text, size_t first_pos, + char bracket) { + const std::map brackets = { + {'(', ')'}, + {'{', '}'}, + {'[', ']'}, + }; + char b_open = bracket; + auto it = brackets.find(b_open); + if (it == brackets.end()) { + return -1; + } + char b_close = it->second; + size_t pos = first_pos; + int opened = 1; + int closed = 0; + while (opened != closed && pos < text.size()) { + if (text[pos] == b_open) { + opened++; + } else if (text[pos] == b_close) { + closed++; + } + pos++; + } + if (opened == closed) { + return pos; + } else { + return -1; + } +} + +void ReplaceAllWords(const std::string& old_word, const std::string& new_word, + std::string* str) { + size_t position = str->find(old_word); + while (position != std::string::npos) { + char prev = position == 0 ? '.' : (*str)[position - 1]; + char next = position + old_word.size() < str->size() + ? (*str)[position + old_word.size()] + : '.'; + if (IsWordSymbol(prev) || IsWordSymbol(next)) { + position = str->find(old_word, position + 1); + continue; + } + str->replace(position, old_word.size(), new_word); + position = str->find(old_word, position + new_word.size()); + } +} + } // namespace Arguments::Arguments(Arguments&& args) @@ -45,6 +96,7 @@ Arguments::Arguments(Arguments&& args) shared_float4s_data_(std::move(args.shared_float4s_data_)), buffers_(std::move(args.buffers_)), images2d_(std::move(args.images2d_)), + object_refs_(std::move(args.object_refs_)), objects_(std::move(args.objects_)) {} Arguments& Arguments::operator=(Arguments&& args) { if (this != &args) { @@ -54,6 +106,7 @@ Arguments& Arguments::operator=(Arguments&& args) { shared_float4s_data_ = std::move(args.shared_float4s_data_); buffers_ = std::move(args.buffers_); images2d_ = std::move(args.images2d_); + object_refs_ = std::move(args.object_refs_); objects_ = std::move(args.objects_); } return *this; @@ -74,6 +127,11 @@ void Arguments::AddImage2D(const std::string& name, images2d_[name] = desc; } +void Arguments::AddObjectRef(const std::string& name, + GPUObjectDescriptorPtr&& descriptor_ptr) { + object_refs_[name] = {AccessType::READ, std::move(descriptor_ptr)}; +} + void Arguments::AddObject(const std::string& name, GPUObjectPtr&& object) { objects_[name] = {AccessType::READ, std::move(object)}; } @@ -159,6 +217,7 @@ absl::Status Arguments::SetGPUResources( absl::Status Arguments::TransformToCLCode(std::string* code) { RETURN_IF_ERROR(AddObjectArgs()); + RETURN_IF_ERROR(ResolveSelectorsPass(code)); ResolveArgsPass(code); return absl::OkStatus(); } @@ -260,18 +319,17 @@ std::string Arguments::AddActiveArgument(const std::string& arg_name) { } void Arguments::ResolveArgsPass(std::string* code) { - constexpr char kPrefix[] = "args."; std::string result; size_t position = 0; - size_t next_position = code->find(kPrefix); + size_t next_position = code->find(kArgsPrefix); while (next_position != std::string::npos) { size_t arg_pos = next_position; - next_position += strlen(kPrefix); + next_position += strlen(kArgsPrefix); std::string object_name = GetNextWord(*code, next_position); std::string new_name = AddActiveArgument(object_name); - code->replace(arg_pos, object_name.size() + strlen(kPrefix), new_name); + code->replace(arg_pos, object_name.size() + strlen(kArgsPrefix), new_name); position = arg_pos + new_name.size(); - next_position = code->find(kPrefix, position); + next_position = code->find(kArgsPrefix, position); } int shared_int4s_aligned_size = AlignByN(shared_int4s_data_.size(), 4); @@ -280,6 +338,86 @@ void Arguments::ResolveArgsPass(std::string* code) { shared_float4s_data_.resize(shared_float4s_aligned_size); } +void Arguments::ResolveObjectNames(const std::string& object_name, + const std::vector& member_names, + std::string* code) { + for (const auto& member_name : member_names) { + const std::string new_name = "args." + object_name + "_" + member_name; + ReplaceAllWords(member_name, new_name, code); + } +} + +absl::Status Arguments::ResolveSelector(const std::string& object_name, + const std::string& selector, + const std::vector& args, + std::string* result) { + const GPUObjectDescriptor* desc_ptr; + AccessType access_type; + if (auto it = object_refs_.find(object_name); it != object_refs_.end()) { + desc_ptr = it->second.descriptor.get(); + access_type = it->second.access_type; + } else if (auto it = objects_.find(object_name); it != objects_.end()) { + desc_ptr = it->second.obj_ptr->GetGPUDescriptor(); + access_type = it->second.access_type; + } else { + return absl::NotFoundError( + absl::StrCat("No object with name - ", object_name)); + } + RETURN_IF_ERROR(desc_ptr->PerformSelector(selector, args, result)); + auto names = desc_ptr->GetGPUResources().GetNames(); + ResolveObjectNames(object_name, names, result); + return absl::OkStatus(); +} + +absl::Status Arguments::ResolveSelectorsPass(std::string* code) { + std::string result; + size_t position = 0; + size_t next_position = code->find(kArgsPrefix); + while (next_position != std::string::npos) { + size_t arg_pos = next_position; + next_position += strlen(kArgsPrefix); + std::string object_name = GetNextWord(*code, next_position); + char next = (*code)[next_position + object_name.size()]; + if (next == '.') { + next_position += object_name.size() + 1; + std::string selector_name = GetNextWord(*code, next_position); + next_position += selector_name.size(); + next = (*code)[next_position]; + if (next != '(') { + return absl::NotFoundError( + absl::StrCat("Expected ( after function ", selector_name, " call")); + } + next_position += 1; + size_t bracket_pos = FindEnclosingBracket(*code, next_position, '('); + if (bracket_pos == -1) { + return absl::NotFoundError( + absl::StrCat("Not found enclosing bracket for function ", + selector_name, " call")); + } + std::string str_args = + code->substr(next_position, bracket_pos - next_position - 1); + std::vector words = absl::StrSplit(str_args, ','); + std::vector args; + args.reserve(words.size()); + for (const auto& word : words) { + absl::string_view arg = absl::StripAsciiWhitespace(word); + if (!arg.empty()) { + args.push_back(std::string(arg)); + } + } + std::string patch; + RETURN_IF_ERROR( + ResolveSelector(object_name, selector_name, args, &patch)); + code->replace(arg_pos, bracket_pos - arg_pos, patch); + position = arg_pos + patch.size(); + } else { + position = arg_pos + strlen(kArgsPrefix); + } + next_position = code->find(kArgsPrefix, position); + } + return absl::OkStatus(); +} + absl::Status Arguments::AddObjectArgs() { for (auto& t : objects_) { AddGPUResources(t.first, @@ -287,6 +425,9 @@ absl::Status Arguments::AddObjectArgs() { RETURN_IF_ERROR( SetGPUResources(t.first, t.second.obj_ptr->GetGPUResources())); } + for (auto& t : object_refs_) { + AddGPUResources(t.first, t.second.descriptor->GetGPUResources()); + } return absl::OkStatus(); } diff --git a/tensorflow/lite/delegates/gpu/cl/arguments.h b/tensorflow/lite/delegates/gpu/cl/arguments.h index f1059e77c93..65c114b2cf6 100644 --- a/tensorflow/lite/delegates/gpu/cl/arguments.h +++ b/tensorflow/lite/delegates/gpu/cl/arguments.h @@ -40,6 +40,8 @@ class Arguments { void AddBuffer(const std::string& name, const GPUBufferDescriptor& desc); void AddImage2D(const std::string& name, const GPUImage2DDescriptor& desc); + void AddObjectRef(const std::string& name, + GPUObjectDescriptorPtr&& descriptor_ptr); void AddObject(const std::string& name, GPUObjectPtr&& object); absl::Status SetInt(const std::string& name, int value); @@ -69,6 +71,18 @@ class Arguments { absl::Status AddObjectArgs(); void ResolveArgsPass(std::string* code); + absl::Status ResolveSelectorsPass(std::string* code); + + absl::Status ResolveSelector(const std::string& object_name, + const std::string& selector, + const std::vector& args, + std::string* result); + + void ResolveObjectNames(const std::string& object_name, + const std::vector& member_names, + std::string* code); + + static constexpr char kArgsPrefix[] = "args."; struct IntValue { int value; @@ -99,6 +113,12 @@ class Arguments { std::map buffers_; std::map images2d_; + struct ObjectRefArg { + AccessType access_type; + GPUObjectDescriptorPtr descriptor; + }; + std::map object_refs_; + struct ObjectArg { AccessType access_type; GPUObjectPtr obj_ptr; diff --git a/tensorflow/lite/delegates/gpu/cl/arguments_test.cc b/tensorflow/lite/delegates/gpu/cl/arguments_test.cc new file mode 100644 index 00000000000..1a4c9fc9c00 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/cl/arguments_test.cc @@ -0,0 +1,96 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/delegates/gpu/cl/arguments.h" + +#include + +#include +#include +#include "tensorflow/lite/delegates/gpu/cl/gpu_object.h" + +namespace tflite { +namespace gpu { +namespace cl { +namespace { +struct TestDescriptor : public GPUObjectDescriptor { + absl::Status PerformSelector(const std::string& selector, + const std::vector& args, + std::string* result) const override { + if (selector == "Length") { + *result = "length"; + return absl::OkStatus(); + } else if (selector == "Read") { + if (args.size() != 1) { + return absl::NotFoundError( + absl::StrCat("TestDescriptor Read require one argument, but ", + args.size(), " was passed")); + } + *result = absl::StrCat("buffer[", args[0], "]"); + return absl::OkStatus(); + } else { + return absl::NotFoundError(absl::StrCat( + "TestDescriptor don't have selector with name - ", selector)); + } + } + + GPUResources GetGPUResources() const override { + GPUResources resources; + resources.ints.push_back("length"); + GPUBufferDescriptor desc; + desc.data_type = DataType::FLOAT32; + desc.element_size = 4; + resources.buffers.push_back({"buffer", desc}); + return resources; + } +}; +} // namespace + +TEST(ArgumentsTest, TestSelectorResolve) { + TestDescriptor descriptor; + Arguments args; + args.AddObjectRef("object", absl::make_unique(descriptor)); + std::string sample_code = R"( + if (a < 3) { + value = args.object.Read(id); + } +)"; + const std::string expected_result = R"( + if (a < 3) { + value = object_buffer[id]; + } +)"; + ASSERT_OK(args.TransformToCLCode(&sample_code)); + EXPECT_EQ(sample_code, expected_result); + + std::string cl_arguments = args.GetListOfArgs(); + EXPECT_TRUE(cl_arguments.find("__global float4* object_buffer") != + std::string::npos); +} + +TEST(ArgumentsTest, TestNoSelector) { + TestDescriptor descriptor; + Arguments args; + args.AddObjectRef("object", absl::make_unique(descriptor)); + std::string sample_code = R"( + if (a < 3) { + value = args.object.Write(id); + } +)"; + EXPECT_FALSE(args.TransformToCLCode(&sample_code).ok()); +} + +} // namespace cl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/gpu_object.h b/tensorflow/lite/delegates/gpu/cl/gpu_object.h index 5cc045c6fc7..23d1f210459 100644 --- a/tensorflow/lite/delegates/gpu/cl/gpu_object.h +++ b/tensorflow/lite/delegates/gpu/cl/gpu_object.h @@ -99,6 +99,8 @@ class GPUObjectDescriptor { mutable std::map state_vars_; }; +using GPUObjectDescriptorPtr = std::unique_ptr; + class GPUObject { public: GPUObject() = default; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/BUILD b/tensorflow/lite/delegates/gpu/cl/kernels/BUILD index b5510b3e8df..24a62e5a82f 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/kernels/BUILD @@ -1385,6 +1385,7 @@ cc_library( ":gpu_operation", ":util", ":work_group_picking", + "//tensorflow/lite/delegates/gpu/cl:arguments", "//tensorflow/lite/delegates/gpu/cl:cl_device", "//tensorflow/lite/delegates/gpu/cl:cl_kernel", "//tensorflow/lite/delegates/gpu/cl:linear_storage", @@ -1395,6 +1396,7 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:winograd_util", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", ], ) diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/winograd.cc b/tensorflow/lite/delegates/gpu/cl/kernels/winograd.cc index 6219952b9bf..e3c9306b80c 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/winograd.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/winograd.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/strings/str_format.h" +#include "absl/strings/substitute.h" #include "tensorflow/lite/delegates/gpu/cl/cl_device.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" @@ -34,8 +35,9 @@ namespace cl { namespace { std::string GetWinograd4x4To36Code( - const OperationDef& op_def, const LinearStorage& bt_arr, - const std::vector& linked_operations) { + const OperationDef& op_def, + const std::vector& linked_operations, + Arguments* args) { TensorCodeGenerator src_tensor( "src_data", WHSBPoint{"src_size.x", "src_size.y", "src_size.z", "src_size.w"}, @@ -78,31 +80,31 @@ std::string GetWinograd4x4To36Code( } c += "};\n"; + args->AddInt("padding_x"); + args->AddInt("padding_y"); + args->AddInt("tiles_total"); + args->AddInt("tiles_x"); + c += "__kernel void main_function(\n"; - c += src_tensor.GetDeclaration(AccessType::READ) + ",\n"; - c += bt_arr.GetDeclaration(); + c += src_tensor.GetDeclaration(AccessType::READ); c += GetArgsDeclaration(linked_operations); c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n"; c += " int4 src_size, \n"; - c += " int4 dst_size, \n"; - c += " int2 padding, \n"; - c += " int tiles_total, \n"; - c += " int tiles_x \n"; - c += ") {\n"; + c += " int4 dst_size"; + c += "$0) {\n"; c += " int DST_X = get_global_id(0);\n"; c += " int DST_Y = get_global_id(1);\n"; c += " int DST_Z = get_global_id(2);\n"; - c += " if (DST_X >= tiles_total || DST_Y >= 6 || DST_Z >= dst_size.z) {\n"; + c += " if (DST_X >= args.tiles_total || DST_Y >= 6 || DST_Z >= dst_size.z) " + "{\n"; c += " return; \n"; c += " }\n"; - c += " int tile_x = (DST_X % tiles_x) * 4;\n"; - c += " int tile_y = (DST_X / tiles_x) * 4;\n"; + c += " int tile_x = (DST_X % args.tiles_x) * 4;\n"; + c += " int tile_y = (DST_X / args.tiles_x) * 4;\n"; c += " ACCUM_FLT4 I0, I1, I2, I3, I4, I5;\n"; c += " ACCUM_FLT bt_ar[6];\n"; - c += " ACCUM_FLT4 t0 = TO_ACCUM_TYPE(" + - bt_arr.ReadLinearFLT4("DST_Y * 2 + 0") + ");\n"; - c += " ACCUM_FLT4 t1 = TO_ACCUM_TYPE(" + - bt_arr.ReadLinearFLT4("DST_Y * 2 + 1") + ");\n"; + c += " ACCUM_FLT4 t0 = TO_ACCUM_TYPE(args.bt.Read(DST_Y * 2 + 0));\n"; + c += " ACCUM_FLT4 t1 = TO_ACCUM_TYPE(args.bt.Read(DST_Y * 2 + 1));\n"; c += " DST_Y *= 6;\n"; c += " bt_ar[0] = t0.x;\n"; c += " bt_ar[1] = t0.y;\n"; @@ -121,15 +123,16 @@ std::string GetWinograd4x4To36Code( " * m" + xs + "_x;\n"; } else { c += " ACCUM_FLT4 " + src + " = " + - src_tensor.ReadAsTypeWHSB(accum_type, "tile_x + padding.x + " + xs, - "yc", "DST_Z", batch_id) + + src_tensor.ReadAsTypeWHSB(accum_type, + "tile_x + args.padding_x + " + xs, "yc", + "DST_Z", batch_id) + ";\n"; } }; if (is_buffer || is_image_buffer) { for (int x = 0; x < 6; ++x) { const std::string xs = std::to_string(x); - c += " int xc" + xs + " = tile_x + padding.x + " + xs + ";\n"; + c += " int xc" + xs + " = tile_x + args.padding_x + " + xs + ";\n"; c += " ACCUM_FLT m" + xs + "_x = (ACCUM_FLT)(xc" + xs + " >= 0 && xc" + xs + " < src_size.x);\n"; c += " bool inx" + xs + " = (xc" + xs + " >= 0 && xc" + xs + @@ -144,7 +147,7 @@ std::string GetWinograd4x4To36Code( } } c += " {\n"; - c += " int yc = tile_y + padding.y;\n"; + c += " int yc = tile_y + args.padding_y;\n"; if (is_buffer || is_image_buffer) { c += " bool iny = (yc >= 0 && yc < src_size.y);\n"; c += " int offset = select(0, yc * src_size.x, iny);\n"; @@ -162,7 +165,7 @@ std::string GetWinograd4x4To36Code( for (int y = 1; y < 6; ++y) { const std::string ys = std::to_string(y); c += " {\n"; - c += " int yc = tile_y + padding.y + (" + ys + ");\n"; + c += " int yc = tile_y + args.padding_y + (" + ys + ");\n"; if (is_buffer || is_image_buffer) { c += " bool iny = (yc >= 0 && yc < src_size.y);\n"; c += " int offset = select(0, yc * src_size.x, iny);\n"; @@ -223,7 +226,6 @@ std::string GetWinograd4x4To36Code( c += " DST_Y++;\n"; c += " }\n"; c += "}\n"; - // std::cout << c << std::endl; return c; } @@ -366,15 +368,15 @@ std::string GetWinograd36To4x4Code( Winograd4x4To36::Winograd4x4To36(Winograd4x4To36&& operation) : GPUOperation(std::move(operation)), - bt_(std::move(operation.bt_)), padding_(operation.padding_), + args_(std::move(operation.args_)), kernel_(std::move(operation.kernel_)), work_group_size_(operation.work_group_size_) {} Winograd4x4To36& Winograd4x4To36::operator=(Winograd4x4To36&& operation) { if (this != &operation) { - bt_ = std::move(operation.bt_); std::swap(padding_, operation.padding_); + args_ = std::move(operation.args_); kernel_ = std::move(operation.kernel_); std::swap(work_group_size_, operation.work_group_size_); GPUOperation::operator=(std::move(operation)); @@ -392,8 +394,10 @@ absl::Status Winograd4x4To36::Compile(const CreationContext& creation_context) { options.push_back(CompilerOptions::POWERVR_FP16); } RETURN_IF_ERROR(UploadBt(creation_context.context)); - const auto code = - GetWinograd4x4To36Code(definition_, bt_, linked_operations_); + std::string code = + GetWinograd4x4To36Code(definition_, linked_operations_, &args_); + RETURN_IF_ERROR(args_.TransformToCLCode(&code)); + code = absl::Substitute(code, args_.GetListOfArgs()); RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel( code, "main_function", options, *creation_context.context, *creation_context.device, &kernel_)); @@ -418,7 +422,11 @@ absl::Status Winograd4x4To36::UploadBt(CLContext* context) { create_info.storage_type = LinearStorageType::TEXTURE_2D; create_info.data_type = definition_.GetDataType(); create_info.name = "bt_arr"; - return CreateLinearStorage(create_info, bt_aligned, context, &bt_); + + LinearStorage lt; + RETURN_IF_ERROR(CreateLinearStorage(create_info, bt_aligned, context, <)); + args_.AddObject("bt", absl::make_unique(std::move(lt))); + return absl::OkStatus(); } int3 Winograd4x4To36::SelectBestWorkGroup() { @@ -429,22 +437,22 @@ int3 Winograd4x4To36::SelectBestWorkGroup() { } absl::Status Winograd4x4To36::BindArguments() { - kernel_.ResetBindingCounter(); - RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); - RETURN_IF_ERROR(kernel_.SetMemoryAuto(bt_.GetMemoryPtr())); - RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); - RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting())); - RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB())); - RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB())); const int tiles_x = DivideRoundUp( src_[0]->Width() + padding_.prepended.w + padding_.appended.w - 2, 4); const int tiles_y = DivideRoundUp( src_[0]->Height() + padding_.prepended.h + padding_.appended.h - 2, 4); const int tiles_total = tiles_x * tiles_y; - RETURN_IF_ERROR( - kernel_.SetBytesAuto(int2(-padding_.prepended.w, -padding_.prepended.h))); - RETURN_IF_ERROR(kernel_.SetBytesAuto(tiles_total)); - RETURN_IF_ERROR(kernel_.SetBytesAuto(tiles_x)); + RETURN_IF_ERROR(args_.SetInt("padding_x", -padding_.prepended.w)); + RETURN_IF_ERROR(args_.SetInt("padding_y", -padding_.prepended.h)); + RETURN_IF_ERROR(args_.SetInt("tiles_total", tiles_total)); + RETURN_IF_ERROR(args_.SetInt("tiles_x", tiles_x)); + kernel_.ResetBindingCounter(); + RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); + RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); + RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting())); + RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB())); + RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB())); + RETURN_IF_ERROR(args_.Bind(kernel_.kernel(), kernel_.GetBindingCounter())); return absl::OkStatus(); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/winograd.h b/tensorflow/lite/delegates/gpu/cl/kernels/winograd.h index c6a88773af3..02e3c268b28 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/winograd.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/winograd.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_WINOGRAD_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_WINOGRAD_H_ +#include "tensorflow/lite/delegates/gpu/cl/arguments.h" #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h" @@ -59,9 +60,9 @@ class Winograd4x4To36 : public GPUOperation { absl::Status BindArguments(); int3 GetGridSize() const; - LinearStorage bt_; Padding2D padding_; + Arguments args_; CLKernel kernel_; int3 work_group_size_ = int3(128, 1, 1); }; diff --git a/tensorflow/lite/delegates/gpu/cl/linear_storage.cc b/tensorflow/lite/delegates/gpu/cl/linear_storage.cc index 4fb21d0ec6a..ecf0e087427 100644 --- a/tensorflow/lite/delegates/gpu/cl/linear_storage.cc +++ b/tensorflow/lite/delegates/gpu/cl/linear_storage.cc @@ -15,24 +15,79 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/status.h" namespace tflite { namespace gpu { namespace cl { +GPUResources TensorLinearDescriptor::GetGPUResources() const { + GPUResources resources; + resources.ints.push_back("length"); + if (storage_type == LinearStorageType::BUFFER) { + GPUBufferDescriptor desc; + desc.data_type = element_type; + desc.element_size = 4; + resources.buffers.push_back({"buffer", desc}); + } else { + GPUImage2DDescriptor desc; + desc.data_type = element_type; + resources.images2d.push_back({"tex2d", desc}); + } + return resources; +} + +absl::Status TensorLinearDescriptor::PerformSelector( + const std::string& selector, const std::vector& args, + std::string* result) const { + if (selector == "Length") { + *result = "length"; + return absl::OkStatus(); + } else if (selector == "Read") { + return PerformReadSelector(args, result); + } else { + return absl::NotFoundError(absl::StrCat( + "TensorLinearDescriptor don't have selector with name - ", selector)); + } +} + +absl::Status TensorLinearDescriptor::PerformReadSelector( + const std::vector& args, std::string* result) const { + if (args.size() != 1) { + return absl::NotFoundError( + absl::StrCat("TensorLinearDescriptor Read require one argument, but ", + args.size(), " was passed")); + } + if (storage_type == LinearStorageType::BUFFER) { + *result = absl::StrCat("buffer[", args[0], "]"); + return absl::OkStatus(); + } else { + const std::string read = + element_type == DataType::FLOAT16 ? "read_imageh" : "read_imagef"; + *result = absl::StrCat(read, "(tex2d, smp_none, (int2)(", args[0], ", 0))"); + return absl::OkStatus(); + } +} + LinearStorage::LinearStorage(int depth, LinearStorageType storage_type, DataType data_type) - : depth_(depth), storage_type_(storage_type), data_type_(data_type) {} + : depth_(depth), storage_type_(storage_type), data_type_(data_type) { + desc_.storage_type = storage_type; + desc_.element_type = data_type; +} LinearStorage::LinearStorage(LinearStorage&& storage) - : texture_storage_(std::move(storage.texture_storage_)), + : GPUObject(std::move(storage)), + texture_storage_(std::move(storage.texture_storage_)), buffer_storage_(std::move(storage.buffer_storage_)), memory_(storage.memory_), depth_(storage.depth_), name_(std::move(storage.name_)), storage_type_(storage.storage_type_), - data_type_(storage.data_type_) { + data_type_(storage.data_type_), + desc_(storage.desc_) { storage.memory_ = nullptr; } @@ -45,6 +100,8 @@ LinearStorage& LinearStorage::operator=(LinearStorage&& storage) { name_ = std::move(storage.name_); std::swap(storage_type_, storage.storage_type_); std::swap(data_type_, storage.data_type_); + desc_ = storage.desc_; + GPUObject::operator=(std::move(storage)); } return *this; } @@ -66,6 +123,19 @@ std::string LinearStorage::GetDeclaration() const { } } +GPUResourcesWithValue LinearStorage::GetGPUResources() const { + GPUResourcesWithValue resources; + resources.ints.push_back({"length", depth_}); + + if (storage_type_ == LinearStorageType::BUFFER) { + resources.buffers.push_back({"buffer", memory_}); + } else { + resources.images2d.push_back({"tex2d", memory_}); + } + + return resources; +} + LinearStorageType DeduceLinearStorageType( TensorStorageType tensor_storage_type) { if (tensor_storage_type == TensorStorageType::BUFFER) { diff --git a/tensorflow/lite/delegates/gpu/cl/linear_storage.h b/tensorflow/lite/delegates/gpu/cl/linear_storage.h index f461b08ebec..a31094b4a47 100644 --- a/tensorflow/lite/delegates/gpu/cl/linear_storage.h +++ b/tensorflow/lite/delegates/gpu/cl/linear_storage.h @@ -22,6 +22,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "tensorflow/lite/delegates/gpu/cl/buffer.h" +#include "tensorflow/lite/delegates/gpu/cl/gpu_object.h" #include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h" #include "tensorflow/lite/delegates/gpu/cl/tensor_type.h" #include "tensorflow/lite/delegates/gpu/cl/texture2d.h" @@ -36,6 +37,33 @@ namespace cl { enum class LinearStorageType { BUFFER, TEXTURE_2D }; +struct TensorLinearDescriptor : public GPUObjectDescriptor { + LinearStorageType storage_type; + DataType element_type; // FLOAT32 or FLOAT16 + + TensorLinearDescriptor() = default; + TensorLinearDescriptor(const TensorLinearDescriptor& desc) + : GPUObjectDescriptor(desc), + storage_type(desc.storage_type), + element_type(desc.element_type) {} + TensorLinearDescriptor& operator=(const TensorLinearDescriptor& desc) { + if (this != &desc) { + storage_type = desc.storage_type; + element_type = desc.element_type; + GPUObjectDescriptor::operator=(desc); + } + return *this; + } + + absl::Status PerformSelector(const std::string& selector, + const std::vector& args, + std::string* result) const override; + + GPUResources GetGPUResources() const override; + absl::Status PerformReadSelector(const std::vector& args, + std::string* result) const; +}; + struct LinearStorageCreateInfo { LinearStorageType storage_type; DataType data_type; @@ -48,7 +76,7 @@ LinearStorageType DeduceLinearStorageType( // Represent GPU 1D-array of FLT4(float4/half4) values // Can use inside texture2d or buffer -class LinearStorage { +class LinearStorage : public GPUObject { public: LinearStorage() {} @@ -63,6 +91,11 @@ class LinearStorage { std::string ReadLinearFLT4(const std::string& z_coord) const; std::string GetDeclaration() const; + const GPUObjectDescriptor* GetGPUDescriptor() const override { + return &desc_; + } + GPUResourcesWithValue GetGPUResources() const override; + private: friend absl::Status CreateTextureLinearStorage(int size, DataType data_type, void* data, CLContext* context, @@ -81,6 +114,7 @@ class LinearStorage { std::string name_; LinearStorageType storage_type_; DataType data_type_; + TensorLinearDescriptor desc_; }; absl::Status CreateBufferLinearStorage(int size, DataType data_type, void* data,