Added TensorLinearDescriptor (GPUObjectDescriptor subclass)

Implemented GPUObject interface for LinearStorage.
Added selector resolve pass for arguments.
Used linear storage as gpu argument in Winograd4x4To36.

PiperOrigin-RevId: 313795928
Change-Id: I6f03d06fc6464cd8f5b93814ad16f23cf59b4e27
This commit is contained in:
Raman Sarokin 2020-05-29 09:35:46 -07:00 committed by TensorFlower Gardener
parent 2244921925
commit a475c198ec
10 changed files with 440 additions and 48 deletions

View File

@ -54,6 +54,23 @@ cc_library(
],
)
# Test target built from arguments_test.cc. It links against the :arguments
# and :gpu_object libraries and carries the GPU test tags plus "linux" and
# "local".
cc_test(
name = "arguments_test",
srcs = ["arguments_test.cc"],
linkstatic = True,
tags = tf_gpu_tests_tags() + [
"linux",
"local",
],
deps = [
":arguments",
":gpu_object",
"//tensorflow/lite/delegates/gpu/common:data_type",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
],
)
cc_library(
name = "buffer",
srcs = ["buffer.cc"],
@ -354,6 +371,7 @@ cc_library(
hdrs = ["linear_storage.h"],
deps = [
":buffer",
":gpu_object",
":opencl_wrapper",
":tensor_type",
":texture2d",

View File

@ -17,6 +17,8 @@ limitations under the License.
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/str_split.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
namespace tflite {
@ -36,6 +38,55 @@ std::string GetNextWord(const std::string& code, size_t first_position) {
}
return code.substr(first_position, pos - first_position);
}
// Locates the bracket that closes an already opened `bracket`.
//
// `text` is scanned from `first_pos`, which is expected to point just past the
// opening bracket (one open bracket is assumed before scanning starts).
// `bracket` must be one of '(', '{' or '['. Nested pairs of the same bracket
// kind are balanced; other bracket kinds are ignored.
//
// Returns the index one past the matching closing bracket, or
// std::string::npos (i.e. size_t(-1)) when `bracket` is not a supported
// opening bracket or no matching closing bracket exists.
size_t FindEnclosingBracket(const std::string& text, size_t first_pos,
                            char bracket) {
  char closing = '\0';
  switch (bracket) {
    case '(':
      closing = ')';
      break;
    case '{':
      closing = '}';
      break;
    case '[':
      closing = ']';
      break;
    default:
      return std::string::npos;
  }

  // One bracket is already open when the scan begins.
  int depth = 1;
  for (size_t i = first_pos; i < text.size(); ++i) {
    const char current = text[i];
    if (current == bracket) {
      ++depth;
    } else if (current == closing) {
      --depth;
      if (depth == 0) {
        return i + 1;
      }
    }
  }
  return std::string::npos;
}
// Replaces every whole-word occurrence of `old_word` in `*str` with
// `new_word`, editing the string in place.
//
// An occurrence is a whole word only when neither the character before it nor
// the character after it satisfies IsWordSymbol(); the beginning and the end
// of the string are treated as word boundaries. Text inserted by a replacement
// is never rescanned, so `new_word` may itself contain `old_word`.
//
// An empty `old_word` is a no-op.
void ReplaceAllWords(const std::string& old_word, const std::string& new_word,
                     std::string* str) {
  // An empty pattern matches at every position: the loop below would insert
  // `new_word` over and over and is not guaranteed to terminate.
  if (old_word.empty()) {
    return;
  }
  size_t position = str->find(old_word);
  while (position != std::string::npos) {
    // '.' stands in for "no neighbour" at either end of the string.
    const char prev = position == 0 ? '.' : (*str)[position - 1];
    const size_t after = position + old_word.size();
    const char next = after < str->size() ? (*str)[after] : '.';
    if (IsWordSymbol(prev) || IsWordSymbol(next)) {
      // Part of a longer identifier: leave it alone and keep searching.
      position = str->find(old_word, position + 1);
      continue;
    }
    str->replace(position, old_word.size(), new_word);
    // Resume after the inserted text so the replacement is not matched again.
    position = str->find(old_word, position + new_word.size());
  }
}
} // namespace
Arguments::Arguments(Arguments&& args)
@ -45,6 +96,7 @@ Arguments::Arguments(Arguments&& args)
shared_float4s_data_(std::move(args.shared_float4s_data_)),
buffers_(std::move(args.buffers_)),
images2d_(std::move(args.images2d_)),
object_refs_(std::move(args.object_refs_)),
objects_(std::move(args.objects_)) {}
Arguments& Arguments::operator=(Arguments&& args) {
if (this != &args) {
@ -54,6 +106,7 @@ Arguments& Arguments::operator=(Arguments&& args) {
shared_float4s_data_ = std::move(args.shared_float4s_data_);
buffers_ = std::move(args.buffers_);
images2d_ = std::move(args.images2d_);
object_refs_ = std::move(args.object_refs_);
objects_ = std::move(args.objects_);
}
return *this;
@ -74,6 +127,11 @@ void Arguments::AddImage2D(const std::string& name,
images2d_[name] = desc;
}
// Registers (or overwrites) the object reference called `name`. Ownership of
// the descriptor is transferred to this Arguments instance and the entry is
// recorded with READ access.
void Arguments::AddObjectRef(const std::string& name,
                             GPUObjectDescriptorPtr&& descriptor_ptr) {
  auto& entry = object_refs_[name];
  entry = {AccessType::READ, std::move(descriptor_ptr)};
}
// Registers (or overwrites) the GPU object called `name`. Ownership of the
// object is transferred to this Arguments instance and the entry is recorded
// with READ access.
void Arguments::AddObject(const std::string& name, GPUObjectPtr&& object) {
  auto& entry = objects_[name];
  entry = {AccessType::READ, std::move(object)};
}
@ -159,6 +217,7 @@ absl::Status Arguments::SetGPUResources(
// Rewrites the `args.`-style placeholders in `*code` in place.
//
// Runs three steps in order: registers the resources of the added objects,
// expands `args.<object>.<selector>(...)` calls, then resolves the remaining
// `args.<name>` references. Returns the first error reported by the object
// registration or the selector pass, otherwise OK.
absl::Status Arguments::TransformToCLCode(std::string* code) {
RETURN_IF_ERROR(AddObjectArgs());
// Selector expansion runs before ResolveArgsPass because the text it emits
// still contains `args.<object>_<member>` references for that pass to resolve.
RETURN_IF_ERROR(ResolveSelectorsPass(code));
ResolveArgsPass(code);
return absl::OkStatus();
}
@ -260,18 +319,17 @@ std::string Arguments::AddActiveArgument(const std::string& arg_name) {
}
void Arguments::ResolveArgsPass(std::string* code) {
constexpr char kPrefix[] = "args.";
std::string result;
size_t position = 0;
size_t next_position = code->find(kPrefix);
size_t next_position = code->find(kArgsPrefix);
while (next_position != std::string::npos) {
size_t arg_pos = next_position;
next_position += strlen(kPrefix);
next_position += strlen(kArgsPrefix);
std::string object_name = GetNextWord(*code, next_position);
std::string new_name = AddActiveArgument(object_name);
code->replace(arg_pos, object_name.size() + strlen(kPrefix), new_name);
code->replace(arg_pos, object_name.size() + strlen(kArgsPrefix), new_name);
position = arg_pos + new_name.size();
next_position = code->find(kPrefix, position);
next_position = code->find(kArgsPrefix, position);
}
int shared_int4s_aligned_size = AlignByN(shared_int4s_data_.size(), 4);
@ -280,6 +338,86 @@ void Arguments::ResolveArgsPass(std::string* code) {
shared_float4s_data_.resize(shared_float4s_aligned_size);
}
// Qualifies the resource names of one object inside `*code`.
//
// Every whole-word occurrence of a name from `member_names` is rewritten to
// `args.<object_name>_<member>`. The members are processed in the order given.
void Arguments::ResolveObjectNames(const std::string& object_name,
                                   const std::vector<std::string>& member_names,
                                   std::string* code) {
  const std::string qualified_prefix = "args." + object_name + "_";
  for (size_t i = 0; i < member_names.size(); ++i) {
    const std::string& member = member_names[i];
    ReplaceAllWords(member, qualified_prefix + member, code);
  }
}
// Expands one selector call on a registered object into code.
//
// `object_name` is looked up first among the object references and then among
// the owned objects. The descriptor that is found generates the code for
// `selector` with the textual arguments `args`; afterwards every resource name
// of the descriptor that appears in the generated text is qualified as
// `args.<object_name>_<resource>`. Note that this renaming is applied to the
// whole generated text, including the caller-supplied argument strings.
//
// Returns NotFoundError when no object is registered under `object_name`, or
// whatever error the descriptor's PerformSelector reports. On success the
// generated code is stored in `*result`.
absl::Status Arguments::ResolveSelector(const std::string& object_name,
                                        const std::string& selector,
                                        const std::vector<std::string>& args,
                                        std::string* result) {
  const GPUObjectDescriptor* desc_ptr = nullptr;
  if (auto ref_it = object_refs_.find(object_name);
      ref_it != object_refs_.end()) {
    desc_ptr = ref_it->second.descriptor.get();
  } else if (auto obj_it = objects_.find(object_name);
             obj_it != objects_.end()) {
    desc_ptr = obj_it->second.obj_ptr->GetGPUDescriptor();
  } else {
    return absl::NotFoundError(
        absl::StrCat("No object with name - ", object_name));
  }
  RETURN_IF_ERROR(desc_ptr->PerformSelector(selector, args, result));
  auto names = desc_ptr->GetGPUResources().GetNames();
  ResolveObjectNames(object_name, names, result);
  return absl::OkStatus();
}
// Expands every `args.<object>.<selector>(<arguments>)` call found in `*code`.
//
// The code is scanned for the `args.` prefix. When the word after the prefix
// is followed by '.', the next word is taken as a selector name and must be
// followed directly by '('. The text up to the matching ')' is split on every
// ',' (so commas inside nested calls also split), each piece is stripped of
// surrounding ASCII whitespace, and empty pieces are dropped. The whole call
// is then replaced by the text produced by ResolveSelector(). Plain
// `args.<name>` references are left untouched.
//
// Returns NotFoundError when '(' does not follow the selector name or the
// closing bracket is missing, and forwards any error from ResolveSelector().
absl::Status Arguments::ResolveSelectorsPass(std::string* code) {
  const size_t prefix_size = strlen(kArgsPrefix);
  size_t position = 0;
  size_t next_position = code->find(kArgsPrefix);
  while (next_position != std::string::npos) {
    size_t arg_pos = next_position;
    next_position += prefix_size;
    std::string object_name = GetNextWord(*code, next_position);
    char next = (*code)[next_position + object_name.size()];
    if (next == '.') {
      // Skip the object name and the '.' that introduces the selector.
      next_position += object_name.size() + 1;
      std::string selector_name = GetNextWord(*code, next_position);
      next_position += selector_name.size();
      next = (*code)[next_position];
      if (next != '(') {
        return absl::NotFoundError(
            absl::StrCat("Expected ( after function ", selector_name, " call"));
      }
      next_position += 1;
      // bracket_pos is the index one past the ')' closing the call.
      size_t bracket_pos = FindEnclosingBracket(*code, next_position, '(');
      if (bracket_pos == std::string::npos) {
        return absl::NotFoundError(
            absl::StrCat("Not found enclosing bracket for function ",
                         selector_name, " call"));
      }
      // Text strictly between the parentheses.
      std::string str_args =
          code->substr(next_position, bracket_pos - next_position - 1);
      std::vector<absl::string_view> words = absl::StrSplit(str_args, ',');
      std::vector<std::string> args;
      args.reserve(words.size());
      for (const auto& word : words) {
        absl::string_view arg = absl::StripAsciiWhitespace(word);
        if (!arg.empty()) {
          args.push_back(std::string(arg));
        }
      }
      std::string patch;
      RETURN_IF_ERROR(
          ResolveSelector(object_name, selector_name, args, &patch));
      code->replace(arg_pos, bracket_pos - arg_pos, patch);
      // Continue after the inserted text; it is not scanned for selectors.
      position = arg_pos + patch.size();
    } else {
      position = arg_pos + prefix_size;
    }
    next_position = code->find(kArgsPrefix, position);
  }
  return absl::OkStatus();
}
absl::Status Arguments::AddObjectArgs() {
for (auto& t : objects_) {
AddGPUResources(t.first,
@ -287,6 +425,9 @@ absl::Status Arguments::AddObjectArgs() {
RETURN_IF_ERROR(
SetGPUResources(t.first, t.second.obj_ptr->GetGPUResources()));
}
for (auto& t : object_refs_) {
AddGPUResources(t.first, t.second.descriptor->GetGPUResources());
}
return absl::OkStatus();
}

View File

@ -40,6 +40,8 @@ class Arguments {
void AddBuffer(const std::string& name, const GPUBufferDescriptor& desc);
void AddImage2D(const std::string& name, const GPUImage2DDescriptor& desc);
void AddObjectRef(const std::string& name,
GPUObjectDescriptorPtr&& descriptor_ptr);
void AddObject(const std::string& name, GPUObjectPtr&& object);
absl::Status SetInt(const std::string& name, int value);
@ -69,6 +71,18 @@ class Arguments {
absl::Status AddObjectArgs();
void ResolveArgsPass(std::string* code);
absl::Status ResolveSelectorsPass(std::string* code);
absl::Status ResolveSelector(const std::string& object_name,
const std::string& selector,
const std::vector<std::string>& args,
std::string* result);
void ResolveObjectNames(const std::string& object_name,
const std::vector<std::string>& member_names,
std::string* code);
static constexpr char kArgsPrefix[] = "args.";
struct IntValue {
int value;
@ -99,6 +113,12 @@ class Arguments {
std::map<std::string, GPUBufferDescriptor> buffers_;
std::map<std::string, GPUImage2DDescriptor> images2d_;
// An object that is known only through its descriptor.
struct ObjectRefArg {
// Access mode recorded for the object; AddObjectRef() stores READ.
AccessType access_type;
// Owned descriptor, used for selector code generation and to list the
// resources the object needs.
GPUObjectDescriptorPtr descriptor;
};
std::map<std::string, ObjectRefArg> object_refs_;
struct ObjectArg {
AccessType access_type;
GPUObjectPtr obj_ptr;

View File

@ -0,0 +1,96 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/cl/arguments.h"
#include <string>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/delegates/gpu/cl/gpu_object.h"
namespace tflite {
namespace gpu {
namespace cl {
namespace {
// Minimal descriptor used by the tests below. It exposes an int resource
// "length" and a FLOAT32 buffer resource "buffer", and understands two
// selectors: "Length" (no arguments) and "Read" (exactly one argument).
struct TestDescriptor : public GPUObjectDescriptor {
  absl::Status PerformSelector(const std::string& selector,
                               const std::vector<std::string>& args,
                               std::string* result) const override {
    if (selector == "Length") {
      *result = "length";
      return absl::OkStatus();
    }
    if (selector != "Read") {
      return absl::NotFoundError(absl::StrCat(
          "TestDescriptor don't have selector with name - ", selector));
    }
    // "Read" indexes the buffer with its single argument.
    if (args.size() != 1) {
      return absl::NotFoundError(
          absl::StrCat("TestDescriptor Read require one argument, but ",
                       args.size(), " was passed"));
    }
    *result = absl::StrCat("buffer[", args[0], "]");
    return absl::OkStatus();
  }

  GPUResources GetGPUResources() const override {
    GPUBufferDescriptor buffer_desc;
    buffer_desc.data_type = DataType::FLOAT32;
    buffer_desc.element_size = 4;

    GPUResources res;
    res.ints.push_back("length");
    res.buffers.push_back({"buffer", buffer_desc});
    return res;
  }
};
} // namespace
// Checks that a `Read` selector on a registered object reference is expanded
// into an access to the object's buffer, and that the buffer shows up in the
// generated kernel argument list.
TEST(ArgumentsTest, TestSelectorResolve) {
TestDescriptor descriptor;
Arguments args;
args.AddObjectRef("object", absl::make_unique<TestDescriptor>(descriptor));
std::string sample_code = R"(
if (a < 3) {
value = args.object.Read(id);
}
)";
const std::string expected_result = R"(
if (a < 3) {
value = object_buffer[id];
}
)";
// TransformToCLCode rewrites sample_code in place.
ASSERT_OK(args.TransformToCLCode(&sample_code));
EXPECT_EQ(sample_code, expected_result);
std::string cl_arguments = args.GetListOfArgs();
EXPECT_TRUE(cl_arguments.find("__global float4* object_buffer") !=
std::string::npos);
}
// Checks that a selector the descriptor does not implement ("Write") makes
// TransformToCLCode return a non-OK status.
TEST(ArgumentsTest, TestNoSelector) {
TestDescriptor descriptor;
Arguments args;
args.AddObjectRef("object", absl::make_unique<TestDescriptor>(descriptor));
std::string sample_code = R"(
if (a < 3) {
value = args.object.Write(id);
}
)";
EXPECT_FALSE(args.TransformToCLCode(&sample_code).ok());
}
} // namespace cl
} // namespace gpu
} // namespace tflite

View File

@ -99,6 +99,8 @@ class GPUObjectDescriptor {
mutable std::map<std::string, std::string> state_vars_;
};
using GPUObjectDescriptorPtr = std::unique_ptr<GPUObjectDescriptor>;
class GPUObject {
public:
GPUObject() = default;

View File

@ -1385,6 +1385,7 @@ cc_library(
":gpu_operation",
":util",
":work_group_picking",
"//tensorflow/lite/delegates/gpu/cl:arguments",
"//tensorflow/lite/delegates/gpu/cl:cl_device",
"//tensorflow/lite/delegates/gpu/cl:cl_kernel",
"//tensorflow/lite/delegates/gpu/cl:linear_storage",
@ -1395,6 +1396,7 @@ cc_library(
"//tensorflow/lite/delegates/gpu/common:shape",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common:winograd_util",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
],
)

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <vector>
#include "absl/strings/str_format.h"
#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
@ -34,8 +35,9 @@ namespace cl {
namespace {
std::string GetWinograd4x4To36Code(
const OperationDef& op_def, const LinearStorage& bt_arr,
const std::vector<ElementwiseOperation*>& linked_operations) {
const OperationDef& op_def,
const std::vector<ElementwiseOperation*>& linked_operations,
Arguments* args) {
TensorCodeGenerator src_tensor(
"src_data",
WHSBPoint{"src_size.x", "src_size.y", "src_size.z", "src_size.w"},
@ -78,31 +80,31 @@ std::string GetWinograd4x4To36Code(
}
c += "};\n";
args->AddInt("padding_x");
args->AddInt("padding_y");
args->AddInt("tiles_total");
args->AddInt("tiles_x");
c += "__kernel void main_function(\n";
c += src_tensor.GetDeclaration(AccessType::READ) + ",\n";
c += bt_arr.GetDeclaration();
c += src_tensor.GetDeclaration(AccessType::READ);
c += GetArgsDeclaration(linked_operations);
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
c += " int4 src_size, \n";
c += " int4 dst_size, \n";
c += " int2 padding, \n";
c += " int tiles_total, \n";
c += " int tiles_x \n";
c += ") {\n";
c += " int4 dst_size";
c += "$0) {\n";
c += " int DST_X = get_global_id(0);\n";
c += " int DST_Y = get_global_id(1);\n";
c += " int DST_Z = get_global_id(2);\n";
c += " if (DST_X >= tiles_total || DST_Y >= 6 || DST_Z >= dst_size.z) {\n";
c += " if (DST_X >= args.tiles_total || DST_Y >= 6 || DST_Z >= dst_size.z) "
"{\n";
c += " return; \n";
c += " }\n";
c += " int tile_x = (DST_X % tiles_x) * 4;\n";
c += " int tile_y = (DST_X / tiles_x) * 4;\n";
c += " int tile_x = (DST_X % args.tiles_x) * 4;\n";
c += " int tile_y = (DST_X / args.tiles_x) * 4;\n";
c += " ACCUM_FLT4 I0, I1, I2, I3, I4, I5;\n";
c += " ACCUM_FLT bt_ar[6];\n";
c += " ACCUM_FLT4 t0 = TO_ACCUM_TYPE(" +
bt_arr.ReadLinearFLT4("DST_Y * 2 + 0") + ");\n";
c += " ACCUM_FLT4 t1 = TO_ACCUM_TYPE(" +
bt_arr.ReadLinearFLT4("DST_Y * 2 + 1") + ");\n";
c += " ACCUM_FLT4 t0 = TO_ACCUM_TYPE(args.bt.Read(DST_Y * 2 + 0));\n";
c += " ACCUM_FLT4 t1 = TO_ACCUM_TYPE(args.bt.Read(DST_Y * 2 + 1));\n";
c += " DST_Y *= 6;\n";
c += " bt_ar[0] = t0.x;\n";
c += " bt_ar[1] = t0.y;\n";
@ -121,15 +123,16 @@ std::string GetWinograd4x4To36Code(
" * m" + xs + "_x;\n";
} else {
c += " ACCUM_FLT4 " + src + " = " +
src_tensor.ReadAsTypeWHSB(accum_type, "tile_x + padding.x + " + xs,
"yc", "DST_Z", batch_id) +
src_tensor.ReadAsTypeWHSB(accum_type,
"tile_x + args.padding_x + " + xs, "yc",
"DST_Z", batch_id) +
";\n";
}
};
if (is_buffer || is_image_buffer) {
for (int x = 0; x < 6; ++x) {
const std::string xs = std::to_string(x);
c += " int xc" + xs + " = tile_x + padding.x + " + xs + ";\n";
c += " int xc" + xs + " = tile_x + args.padding_x + " + xs + ";\n";
c += " ACCUM_FLT m" + xs + "_x = (ACCUM_FLT)(xc" + xs + " >= 0 && xc" +
xs + " < src_size.x);\n";
c += " bool inx" + xs + " = (xc" + xs + " >= 0 && xc" + xs +
@ -144,7 +147,7 @@ std::string GetWinograd4x4To36Code(
}
}
c += " {\n";
c += " int yc = tile_y + padding.y;\n";
c += " int yc = tile_y + args.padding_y;\n";
if (is_buffer || is_image_buffer) {
c += " bool iny = (yc >= 0 && yc < src_size.y);\n";
c += " int offset = select(0, yc * src_size.x, iny);\n";
@ -162,7 +165,7 @@ std::string GetWinograd4x4To36Code(
for (int y = 1; y < 6; ++y) {
const std::string ys = std::to_string(y);
c += " {\n";
c += " int yc = tile_y + padding.y + (" + ys + ");\n";
c += " int yc = tile_y + args.padding_y + (" + ys + ");\n";
if (is_buffer || is_image_buffer) {
c += " bool iny = (yc >= 0 && yc < src_size.y);\n";
c += " int offset = select(0, yc * src_size.x, iny);\n";
@ -223,7 +226,6 @@ std::string GetWinograd4x4To36Code(
c += " DST_Y++;\n";
c += " }\n";
c += "}\n";
// std::cout << c << std::endl;
return c;
}
@ -366,15 +368,15 @@ std::string GetWinograd36To4x4Code(
Winograd4x4To36::Winograd4x4To36(Winograd4x4To36&& operation)
: GPUOperation(std::move(operation)),
bt_(std::move(operation.bt_)),
padding_(operation.padding_),
args_(std::move(operation.args_)),
kernel_(std::move(operation.kernel_)),
work_group_size_(operation.work_group_size_) {}
Winograd4x4To36& Winograd4x4To36::operator=(Winograd4x4To36&& operation) {
if (this != &operation) {
bt_ = std::move(operation.bt_);
std::swap(padding_, operation.padding_);
args_ = std::move(operation.args_);
kernel_ = std::move(operation.kernel_);
std::swap(work_group_size_, operation.work_group_size_);
GPUOperation::operator=(std::move(operation));
@ -392,8 +394,10 @@ absl::Status Winograd4x4To36::Compile(const CreationContext& creation_context) {
options.push_back(CompilerOptions::POWERVR_FP16);
}
RETURN_IF_ERROR(UploadBt(creation_context.context));
const auto code =
GetWinograd4x4To36Code(definition_, bt_, linked_operations_);
std::string code =
GetWinograd4x4To36Code(definition_, linked_operations_, &args_);
RETURN_IF_ERROR(args_.TransformToCLCode(&code));
code = absl::Substitute(code, args_.GetListOfArgs());
RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel(
code, "main_function", options, *creation_context.context,
*creation_context.device, &kernel_));
@ -418,7 +422,11 @@ absl::Status Winograd4x4To36::UploadBt(CLContext* context) {
create_info.storage_type = LinearStorageType::TEXTURE_2D;
create_info.data_type = definition_.GetDataType();
create_info.name = "bt_arr";
return CreateLinearStorage(create_info, bt_aligned, context, &bt_);
LinearStorage lt;
RETURN_IF_ERROR(CreateLinearStorage(create_info, bt_aligned, context, &lt));
args_.AddObject("bt", absl::make_unique<LinearStorage>(std::move(lt)));
return absl::OkStatus();
}
int3 Winograd4x4To36::SelectBestWorkGroup() {
@ -429,22 +437,22 @@ int3 Winograd4x4To36::SelectBestWorkGroup() {
}
absl::Status Winograd4x4To36::BindArguments() {
kernel_.ResetBindingCounter();
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
RETURN_IF_ERROR(kernel_.SetMemoryAuto(bt_.GetMemoryPtr()));
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB()));
const int tiles_x = DivideRoundUp(
src_[0]->Width() + padding_.prepended.w + padding_.appended.w - 2, 4);
const int tiles_y = DivideRoundUp(
src_[0]->Height() + padding_.prepended.h + padding_.appended.h - 2, 4);
const int tiles_total = tiles_x * tiles_y;
RETURN_IF_ERROR(
kernel_.SetBytesAuto(int2(-padding_.prepended.w, -padding_.prepended.h)));
RETURN_IF_ERROR(kernel_.SetBytesAuto(tiles_total));
RETURN_IF_ERROR(kernel_.SetBytesAuto(tiles_x));
RETURN_IF_ERROR(args_.SetInt("padding_x", -padding_.prepended.w));
RETURN_IF_ERROR(args_.SetInt("padding_y", -padding_.prepended.h));
RETURN_IF_ERROR(args_.SetInt("tiles_total", tiles_total));
RETURN_IF_ERROR(args_.SetInt("tiles_x", tiles_x));
kernel_.ResetBindingCounter();
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB()));
RETURN_IF_ERROR(args_.Bind(kernel_.kernel(), kernel_.GetBindingCounter()));
return absl::OkStatus();
}

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_WINOGRAD_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_WINOGRAD_H_
#include "tensorflow/lite/delegates/gpu/cl/arguments.h"
#include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
#include "tensorflow/lite/delegates/gpu/cl/linear_storage.h"
@ -59,9 +60,9 @@ class Winograd4x4To36 : public GPUOperation {
absl::Status BindArguments();
int3 GetGridSize() const;
LinearStorage bt_;
Padding2D padding_;
Arguments args_;
CLKernel kernel_;
int3 work_group_size_ = int3(128, 1, 1);
};

View File

@ -15,24 +15,79 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/cl/linear_storage.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
namespace tflite {
namespace gpu {
namespace cl {
// Lists the resources a linear storage needs: always an int named "length",
// plus either a buffer named "buffer" (BUFFER storage, element_size 4) or a
// 2D image named "tex2d" (any other storage type). The data type of the
// buffer/image is `element_type`.
GPUResources TensorLinearDescriptor::GetGPUResources() const {
  GPUResources res;
  res.ints.push_back("length");

  if (storage_type != LinearStorageType::BUFFER) {
    GPUImage2DDescriptor image_desc;
    image_desc.data_type = element_type;
    res.images2d.push_back({"tex2d", image_desc});
    return res;
  }

  GPUBufferDescriptor buffer_desc;
  buffer_desc.data_type = element_type;
  buffer_desc.element_size = 4;
  res.buffers.push_back({"buffer", buffer_desc});
  return res;
}
// Generates code for `selector`:
//   "Length" -> the text "length" (arguments are ignored);
//   "Read"   -> delegated to PerformReadSelector().
// Any other selector name yields NotFoundError.
absl::Status TensorLinearDescriptor::PerformSelector(
    const std::string& selector, const std::vector<std::string>& args,
    std::string* result) const {
  if (selector == "Read") {
    return PerformReadSelector(args, result);
  }
  if (selector != "Length") {
    return absl::NotFoundError(absl::StrCat(
        "TensorLinearDescriptor don't have selector with name - ", selector));
  }
  *result = "length";
  return absl::OkStatus();
}
// Generates the expression that reads one element at index `args[0]`.
//
// BUFFER storage produces `buffer[<index>]`. Any other storage type produces a
// `read_imageh` (FLOAT16 elements) or `read_imagef` (otherwise) call on
// `tex2d` at coordinate (<index>, 0). Exactly one argument is required;
// otherwise NotFoundError is returned and `*result` is left untouched.
absl::Status TensorLinearDescriptor::PerformReadSelector(
    const std::vector<std::string>& args, std::string* result) const {
  if (args.size() != 1) {
    return absl::NotFoundError(
        absl::StrCat("TensorLinearDescriptor Read require one argument, but ",
                     args.size(), " was passed"));
  }

  const std::string& index = args[0];
  if (storage_type != LinearStorageType::BUFFER) {
    const char* read_fn =
        element_type == DataType::FLOAT16 ? "read_imageh" : "read_imagef";
    *result =
        absl::StrCat(read_fn, "(tex2d, smp_none, (int2)(", index, ", 0))");
  } else {
    *result = absl::StrCat("buffer[", index, "]");
  }
  return absl::OkStatus();
}
LinearStorage::LinearStorage(int depth, LinearStorageType storage_type,
DataType data_type)
: depth_(depth), storage_type_(storage_type), data_type_(data_type) {}
: depth_(depth), storage_type_(storage_type), data_type_(data_type) {
desc_.storage_type = storage_type;
desc_.element_type = data_type;
}
LinearStorage::LinearStorage(LinearStorage&& storage)
: texture_storage_(std::move(storage.texture_storage_)),
: GPUObject(std::move(storage)),
texture_storage_(std::move(storage.texture_storage_)),
buffer_storage_(std::move(storage.buffer_storage_)),
memory_(storage.memory_),
depth_(storage.depth_),
name_(std::move(storage.name_)),
storage_type_(storage.storage_type_),
data_type_(storage.data_type_) {
data_type_(storage.data_type_),
desc_(storage.desc_) {
storage.memory_ = nullptr;
}
@ -45,6 +100,8 @@ LinearStorage& LinearStorage::operator=(LinearStorage&& storage) {
name_ = std::move(storage.name_);
std::swap(storage_type_, storage.storage_type_);
std::swap(data_type_, storage.data_type_);
desc_ = storage.desc_;
GPUObject::operator=(std::move(storage));
}
return *this;
}
@ -66,6 +123,19 @@ std::string LinearStorage::GetDeclaration() const {
}
}
// Returns the concrete values behind the descriptor's resources: "length" is
// bound to depth_, and memory_ is exposed as "buffer" for BUFFER storage or as
// "tex2d" for any other storage type.
GPUResourcesWithValue LinearStorage::GetGPUResources() const {
  GPUResourcesWithValue res;
  res.ints.push_back({"length", depth_});

  const bool is_buffer = storage_type_ == LinearStorageType::BUFFER;
  if (!is_buffer) {
    res.images2d.push_back({"tex2d", memory_});
    return res;
  }
  res.buffers.push_back({"buffer", memory_});
  return res;
}
LinearStorageType DeduceLinearStorageType(
TensorStorageType tensor_storage_type) {
if (tensor_storage_type == TensorStorageType::BUFFER) {

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "tensorflow/lite/delegates/gpu/cl/buffer.h"
#include "tensorflow/lite/delegates/gpu/cl/gpu_object.h"
#include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h"
#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
#include "tensorflow/lite/delegates/gpu/cl/texture2d.h"
@ -36,6 +37,33 @@ namespace cl {
enum class LinearStorageType { BUFFER, TEXTURE_2D };
// Descriptor of a linear (1D) storage used for kernel code generation.
//
// It records how the storage is backed and what its element type is, lists the
// resources the storage needs (GetGPUResources) and generates code for the
// "Length" and "Read" selectors (PerformSelector).
struct TensorLinearDescriptor : public GPUObjectDescriptor {
  // Both fields are value-initialised so that a default-constructed descriptor
  // never holds an indeterminate value (for the enum class storage type this
  // is its zero-valued enumerator).
  LinearStorageType storage_type{};
  DataType element_type{};  // FLOAT32 or FLOAT16

  TensorLinearDescriptor() = default;
  // Copying is member-wise (base class state included), which is exactly what
  // the compiler-generated operations do.
  TensorLinearDescriptor(const TensorLinearDescriptor& desc) = default;
  TensorLinearDescriptor& operator=(const TensorLinearDescriptor& desc) =
      default;

  // Generates code for `selector` called with the textual arguments `args`
  // and stores it in `*result`.
  absl::Status PerformSelector(const std::string& selector,
                               const std::vector<std::string>& args,
                               std::string* result) const override;

  // Lists the named resources (ints, buffers, 2D images) the storage needs.
  GPUResources GetGPUResources() const override;

  // Implementation of the "Read" selector; expects exactly one argument.
  absl::Status PerformReadSelector(const std::vector<std::string>& args,
                                   std::string* result) const;
};
struct LinearStorageCreateInfo {
LinearStorageType storage_type;
DataType data_type;
@ -48,7 +76,7 @@ LinearStorageType DeduceLinearStorageType(
// Represent GPU 1D-array of FLT4(float4/half4) values
// Can use inside texture2d or buffer
class LinearStorage {
class LinearStorage : public GPUObject {
public:
LinearStorage() {}
@ -63,6 +91,11 @@ class LinearStorage {
std::string ReadLinearFLT4(const std::string& z_coord) const;
std::string GetDeclaration() const;
// Returns a pointer to the descriptor stored inside this object (desc_), so
// the pointer is only usable while this LinearStorage is alive and not moved.
const GPUObjectDescriptor* GetGPUDescriptor() const override {
return &desc_;
}
GPUResourcesWithValue GetGPUResources() const override;
private:
friend absl::Status CreateTextureLinearStorage(int size, DataType data_type,
void* data, CLContext* context,
@ -81,6 +114,7 @@ class LinearStorage {
std::string name_;
LinearStorageType storage_type_;
DataType data_type_;
TensorLinearDescriptor desc_;
};
absl::Status CreateBufferLinearStorage(int size, DataType data_type, void* data,