Added TensorLinearDescriptor (GPUObjectDescriptor subclass)
Implemented GPUObject interface for LinearStorage. Added selector resolve pass for arguments. Used linear storage as gpu argument in Winograd4x4To36. PiperOrigin-RevId: 313795928 Change-Id: I6f03d06fc6464cd8f5b93814ad16f23cf59b4e27
This commit is contained in:
parent
2244921925
commit
a475c198ec
|
@ -54,6 +54,23 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Unit test target for :arguments; sources live in arguments_test.cc.
cc_test(
    name = "arguments_test",
    srcs = ["arguments_test.cc"],
    linkstatic = True,
    tags = tf_gpu_tests_tags() + ["linux", "local"],
    deps = [
        ":arguments",
        ":gpu_object",
        "//tensorflow/lite/delegates/gpu/common:data_type",
        "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest_main",
    ],
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "buffer",
|
name = "buffer",
|
||||||
srcs = ["buffer.cc"],
|
srcs = ["buffer.cc"],
|
||||||
|
@ -354,6 +371,7 @@ cc_library(
|
||||||
hdrs = ["linear_storage.h"],
|
hdrs = ["linear_storage.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":buffer",
|
":buffer",
|
||||||
|
":gpu_object",
|
||||||
":opencl_wrapper",
|
":opencl_wrapper",
|
||||||
":tensor_type",
|
":tensor_type",
|
||||||
":texture2d",
|
":texture2d",
|
||||||
|
|
|
@ -17,6 +17,8 @@ limitations under the License.
|
||||||
|
|
||||||
#include "absl/strings/ascii.h"
|
#include "absl/strings/ascii.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
|
#include "absl/strings/str_replace.h"
|
||||||
|
#include "absl/strings/str_split.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
|
@ -36,6 +38,55 @@ std::string GetNextWord(const std::string& code, size_t first_position) {
|
||||||
}
|
}
|
||||||
return code.substr(first_position, pos - first_position);
|
return code.substr(first_position, pos - first_position);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Finds the end of a bracketed region whose opening bracket has already been
// consumed.
//
// `text`      - the string to scan.
// `first_pos` - index of the first character after the opening bracket;
//               scanning starts here with a nesting depth of one.
// `bracket`   - the opening bracket character: '(', '{' or '['.
//
// Returns the index one past the matching closing bracket, or
// std::string::npos when `bracket` is not one of the supported opening
// brackets or when the region is never closed before the end of `text`.
size_t FindEnclosingBracket(const std::string& text, size_t first_pos,
                            char bracket) {
  char closing = '\0';
  switch (bracket) {
    case '(':
      closing = ')';
      break;
    case '{':
      closing = '}';
      break;
    case '[':
      closing = ']';
      break;
    default:
      return std::string::npos;
  }

  // One bracket is already open when scanning begins.
  int depth = 1;
  for (size_t pos = first_pos; pos < text.size(); ++pos) {
    const char symbol = text[pos];
    if (symbol == bracket) {
      ++depth;
    } else if (symbol == closing) {
      --depth;
      if (depth == 0) {
        // Callers expect the position just past the closing bracket.
        return pos + 1;
      }
    }
  }
  return std::string::npos;
}
|
||||||
|
|
||||||
|
void ReplaceAllWords(const std::string& old_word, const std::string& new_word,
|
||||||
|
std::string* str) {
|
||||||
|
size_t position = str->find(old_word);
|
||||||
|
while (position != std::string::npos) {
|
||||||
|
char prev = position == 0 ? '.' : (*str)[position - 1];
|
||||||
|
char next = position + old_word.size() < str->size()
|
||||||
|
? (*str)[position + old_word.size()]
|
||||||
|
: '.';
|
||||||
|
if (IsWordSymbol(prev) || IsWordSymbol(next)) {
|
||||||
|
position = str->find(old_word, position + 1);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
str->replace(position, old_word.size(), new_word);
|
||||||
|
position = str->find(old_word, position + new_word.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Arguments::Arguments(Arguments&& args)
|
Arguments::Arguments(Arguments&& args)
|
||||||
|
@ -45,6 +96,7 @@ Arguments::Arguments(Arguments&& args)
|
||||||
shared_float4s_data_(std::move(args.shared_float4s_data_)),
|
shared_float4s_data_(std::move(args.shared_float4s_data_)),
|
||||||
buffers_(std::move(args.buffers_)),
|
buffers_(std::move(args.buffers_)),
|
||||||
images2d_(std::move(args.images2d_)),
|
images2d_(std::move(args.images2d_)),
|
||||||
|
object_refs_(std::move(args.object_refs_)),
|
||||||
objects_(std::move(args.objects_)) {}
|
objects_(std::move(args.objects_)) {}
|
||||||
Arguments& Arguments::operator=(Arguments&& args) {
|
Arguments& Arguments::operator=(Arguments&& args) {
|
||||||
if (this != &args) {
|
if (this != &args) {
|
||||||
|
@ -54,6 +106,7 @@ Arguments& Arguments::operator=(Arguments&& args) {
|
||||||
shared_float4s_data_ = std::move(args.shared_float4s_data_);
|
shared_float4s_data_ = std::move(args.shared_float4s_data_);
|
||||||
buffers_ = std::move(args.buffers_);
|
buffers_ = std::move(args.buffers_);
|
||||||
images2d_ = std::move(args.images2d_);
|
images2d_ = std::move(args.images2d_);
|
||||||
|
object_refs_ = std::move(args.object_refs_);
|
||||||
objects_ = std::move(args.objects_);
|
objects_ = std::move(args.objects_);
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
|
@ -74,6 +127,11 @@ void Arguments::AddImage2D(const std::string& name,
|
||||||
images2d_[name] = desc;
|
images2d_[name] = desc;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Arguments::AddObjectRef(const std::string& name,
|
||||||
|
GPUObjectDescriptorPtr&& descriptor_ptr) {
|
||||||
|
object_refs_[name] = {AccessType::READ, std::move(descriptor_ptr)};
|
||||||
|
}
|
||||||
|
|
||||||
void Arguments::AddObject(const std::string& name, GPUObjectPtr&& object) {
|
void Arguments::AddObject(const std::string& name, GPUObjectPtr&& object) {
|
||||||
objects_[name] = {AccessType::READ, std::move(object)};
|
objects_[name] = {AccessType::READ, std::move(object)};
|
||||||
}
|
}
|
||||||
|
@ -159,6 +217,7 @@ absl::Status Arguments::SetGPUResources(
|
||||||
|
|
||||||
absl::Status Arguments::TransformToCLCode(std::string* code) {
|
absl::Status Arguments::TransformToCLCode(std::string* code) {
|
||||||
RETURN_IF_ERROR(AddObjectArgs());
|
RETURN_IF_ERROR(AddObjectArgs());
|
||||||
|
RETURN_IF_ERROR(ResolveSelectorsPass(code));
|
||||||
ResolveArgsPass(code);
|
ResolveArgsPass(code);
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
@ -260,18 +319,17 @@ std::string Arguments::AddActiveArgument(const std::string& arg_name) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void Arguments::ResolveArgsPass(std::string* code) {
|
void Arguments::ResolveArgsPass(std::string* code) {
|
||||||
constexpr char kPrefix[] = "args.";
|
|
||||||
std::string result;
|
std::string result;
|
||||||
size_t position = 0;
|
size_t position = 0;
|
||||||
size_t next_position = code->find(kPrefix);
|
size_t next_position = code->find(kArgsPrefix);
|
||||||
while (next_position != std::string::npos) {
|
while (next_position != std::string::npos) {
|
||||||
size_t arg_pos = next_position;
|
size_t arg_pos = next_position;
|
||||||
next_position += strlen(kPrefix);
|
next_position += strlen(kArgsPrefix);
|
||||||
std::string object_name = GetNextWord(*code, next_position);
|
std::string object_name = GetNextWord(*code, next_position);
|
||||||
std::string new_name = AddActiveArgument(object_name);
|
std::string new_name = AddActiveArgument(object_name);
|
||||||
code->replace(arg_pos, object_name.size() + strlen(kPrefix), new_name);
|
code->replace(arg_pos, object_name.size() + strlen(kArgsPrefix), new_name);
|
||||||
position = arg_pos + new_name.size();
|
position = arg_pos + new_name.size();
|
||||||
next_position = code->find(kPrefix, position);
|
next_position = code->find(kArgsPrefix, position);
|
||||||
}
|
}
|
||||||
|
|
||||||
int shared_int4s_aligned_size = AlignByN(shared_int4s_data_.size(), 4);
|
int shared_int4s_aligned_size = AlignByN(shared_int4s_data_.size(), 4);
|
||||||
|
@ -280,6 +338,86 @@ void Arguments::ResolveArgsPass(std::string* code) {
|
||||||
shared_float4s_data_.resize(shared_float4s_aligned_size);
|
shared_float4s_data_.resize(shared_float4s_aligned_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Arguments::ResolveObjectNames(const std::string& object_name,
|
||||||
|
const std::vector<std::string>& member_names,
|
||||||
|
std::string* code) {
|
||||||
|
for (const auto& member_name : member_names) {
|
||||||
|
const std::string new_name = "args." + object_name + "_" + member_name;
|
||||||
|
ReplaceAllWords(member_name, new_name, code);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generates the code for one selector call on a registered object.
//
// `object_name` - name the object was registered under; object references
//                 (object_refs_) are searched first, then objects_.
// `selector`    - selector name forwarded to the descriptor.
// `args`        - selector arguments forwarded to the descriptor.
// `result`      - receives the code produced by the descriptor's
//                 PerformSelector(), with the descriptor's resource names
//                 rewritten by ResolveObjectNames().
//
// Returns NotFoundError when no object is registered under `object_name`, or
// the error reported by PerformSelector().
absl::Status Arguments::ResolveSelector(const std::string& object_name,
                                        const std::string& selector,
                                        const std::vector<std::string>& args,
                                        std::string* result) {
  const GPUObjectDescriptor* desc_ptr = nullptr;
  const auto ref_it = object_refs_.find(object_name);
  if (ref_it != object_refs_.end()) {
    desc_ptr = ref_it->second.descriptor.get();
  } else {
    const auto obj_it = objects_.find(object_name);
    if (obj_it == objects_.end()) {
      return absl::NotFoundError(
          absl::StrCat("No object with name - ", object_name));
    }
    desc_ptr = obj_it->second.obj_ptr->GetGPUDescriptor();
  }
  RETURN_IF_ERROR(desc_ptr->PerformSelector(selector, args, result));
  // The descriptor emits bare resource names; qualify them for this object.
  const auto names = desc_ptr->GetGPUResources().GetNames();
  ResolveObjectNames(object_name, names, result);
  return absl::OkStatus();
}
|
||||||
|
|
||||||
|
// Expands every selector call of the form `args.<object>.<Selector>(<args>)`
// in `*code`, in place.
//
// Each call is replaced by the text produced by ResolveSelector(). The
// argument list is split on ',' and each piece is whitespace-trimmed; empty
// pieces are dropped. NOTE(review): because of the plain ',' split, an
// argument that itself contains a comma is passed on as several arguments.
// Occurrences of `args.<name>` that are not followed by '.' are left
// untouched. Scanning resumes after the inserted replacement text.
//
// Returns NotFoundError when a selector name is not followed by '(' or its
// closing ')' cannot be found, or the error reported by ResolveSelector().
absl::Status Arguments::ResolveSelectorsPass(std::string* code) {
  const size_t prefix_len = strlen(kArgsPrefix);
  size_t search_from = 0;
  size_t arg_pos = code->find(kArgsPrefix);
  while (arg_pos != std::string::npos) {
    size_t cursor = arg_pos + prefix_len;
    const std::string object_name = GetNextWord(*code, cursor);
    cursor += object_name.size();
    if ((*code)[cursor] == '.') {
      cursor += 1;  // Step over the '.' between object and selector.
      const std::string selector_name = GetNextWord(*code, cursor);
      cursor += selector_name.size();
      if ((*code)[cursor] != '(') {
        return absl::NotFoundError(
            absl::StrCat("Expected ( after function ", selector_name, " call"));
      }
      cursor += 1;  // Step over '('; cursor is now at the first argument char.
      // bracket_pos is the index one past the matching ')'.
      const size_t bracket_pos = FindEnclosingBracket(*code, cursor, '(');
      if (bracket_pos == std::string::npos) {
        return absl::NotFoundError(
            absl::StrCat("Not found enclosing bracket for function ",
                         selector_name, " call"));
      }
      // Text strictly between '(' and ')'.
      const std::string str_args =
          code->substr(cursor, bracket_pos - cursor - 1);
      const std::vector<absl::string_view> words =
          absl::StrSplit(str_args, ',');
      std::vector<std::string> args;
      args.reserve(words.size());
      for (const auto& word : words) {
        const absl::string_view arg = absl::StripAsciiWhitespace(word);
        if (!arg.empty()) {
          args.push_back(std::string(arg));
        }
      }
      std::string patch;
      RETURN_IF_ERROR(
          ResolveSelector(object_name, selector_name, args, &patch));
      // Replace the whole call, from "args." through the closing ')'.
      code->replace(arg_pos, bracket_pos - arg_pos, patch);
      search_from = arg_pos + patch.size();
    } else {
      search_from = arg_pos + prefix_len;
    }
    arg_pos = code->find(kArgsPrefix, search_from);
  }
  return absl::OkStatus();
}
|
||||||
|
|
||||||
absl::Status Arguments::AddObjectArgs() {
|
absl::Status Arguments::AddObjectArgs() {
|
||||||
for (auto& t : objects_) {
|
for (auto& t : objects_) {
|
||||||
AddGPUResources(t.first,
|
AddGPUResources(t.first,
|
||||||
|
@ -287,6 +425,9 @@ absl::Status Arguments::AddObjectArgs() {
|
||||||
RETURN_IF_ERROR(
|
RETURN_IF_ERROR(
|
||||||
SetGPUResources(t.first, t.second.obj_ptr->GetGPUResources()));
|
SetGPUResources(t.first, t.second.obj_ptr->GetGPUResources()));
|
||||||
}
|
}
|
||||||
|
for (auto& t : object_refs_) {
|
||||||
|
AddGPUResources(t.first, t.second.descriptor->GetGPUResources());
|
||||||
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -40,6 +40,8 @@ class Arguments {
|
||||||
void AddBuffer(const std::string& name, const GPUBufferDescriptor& desc);
|
void AddBuffer(const std::string& name, const GPUBufferDescriptor& desc);
|
||||||
void AddImage2D(const std::string& name, const GPUImage2DDescriptor& desc);
|
void AddImage2D(const std::string& name, const GPUImage2DDescriptor& desc);
|
||||||
|
|
||||||
|
void AddObjectRef(const std::string& name,
|
||||||
|
GPUObjectDescriptorPtr&& descriptor_ptr);
|
||||||
void AddObject(const std::string& name, GPUObjectPtr&& object);
|
void AddObject(const std::string& name, GPUObjectPtr&& object);
|
||||||
|
|
||||||
absl::Status SetInt(const std::string& name, int value);
|
absl::Status SetInt(const std::string& name, int value);
|
||||||
|
@ -69,6 +71,18 @@ class Arguments {
|
||||||
absl::Status AddObjectArgs();
|
absl::Status AddObjectArgs();
|
||||||
|
|
||||||
void ResolveArgsPass(std::string* code);
|
void ResolveArgsPass(std::string* code);
|
||||||
|
absl::Status ResolveSelectorsPass(std::string* code);
|
||||||
|
|
||||||
|
absl::Status ResolveSelector(const std::string& object_name,
|
||||||
|
const std::string& selector,
|
||||||
|
const std::vector<std::string>& args,
|
||||||
|
std::string* result);
|
||||||
|
|
||||||
|
void ResolveObjectNames(const std::string& object_name,
|
||||||
|
const std::vector<std::string>& member_names,
|
||||||
|
std::string* code);
|
||||||
|
|
||||||
|
static constexpr char kArgsPrefix[] = "args.";
|
||||||
|
|
||||||
struct IntValue {
|
struct IntValue {
|
||||||
int value;
|
int value;
|
||||||
|
@ -99,6 +113,12 @@ class Arguments {
|
||||||
std::map<std::string, GPUBufferDescriptor> buffers_;
|
std::map<std::string, GPUBufferDescriptor> buffers_;
|
||||||
std::map<std::string, GPUImage2DDescriptor> images2d_;
|
std::map<std::string, GPUImage2DDescriptor> images2d_;
|
||||||
|
|
||||||
|
// Entry of object_refs_: an access type paired with an owned descriptor.
// Field order matters: AddObjectRef() brace-initializes this struct
// positionally as {access_type, descriptor}.
struct ObjectRefArg {
|
||||||
|
AccessType access_type;
|
||||||
|
GPUObjectDescriptorPtr descriptor;
|
||||||
|
};
|
||||||
|
std::map<std::string, ObjectRefArg> object_refs_;
|
||||||
|
|
||||||
struct ObjectArg {
|
struct ObjectArg {
|
||||||
AccessType access_type;
|
AccessType access_type;
|
||||||
GPUObjectPtr obj_ptr;
|
GPUObjectPtr obj_ptr;
|
||||||
|
|
|
@ -0,0 +1,96 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/lite/delegates/gpu/cl/arguments.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include <gmock/gmock.h>
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include "tensorflow/lite/delegates/gpu/cl/gpu_object.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace gpu {
|
||||||
|
namespace cl {
|
||||||
|
namespace {
|
||||||
|
struct TestDescriptor : public GPUObjectDescriptor {
|
||||||
|
absl::Status PerformSelector(const std::string& selector,
|
||||||
|
const std::vector<std::string>& args,
|
||||||
|
std::string* result) const override {
|
||||||
|
if (selector == "Length") {
|
||||||
|
*result = "length";
|
||||||
|
return absl::OkStatus();
|
||||||
|
} else if (selector == "Read") {
|
||||||
|
if (args.size() != 1) {
|
||||||
|
return absl::NotFoundError(
|
||||||
|
absl::StrCat("TestDescriptor Read require one argument, but ",
|
||||||
|
args.size(), " was passed"));
|
||||||
|
}
|
||||||
|
*result = absl::StrCat("buffer[", args[0], "]");
|
||||||
|
return absl::OkStatus();
|
||||||
|
} else {
|
||||||
|
return absl::NotFoundError(absl::StrCat(
|
||||||
|
"TestDescriptor don't have selector with name - ", selector));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
GPUResources GetGPUResources() const override {
|
||||||
|
GPUResources resources;
|
||||||
|
resources.ints.push_back("length");
|
||||||
|
GPUBufferDescriptor desc;
|
||||||
|
desc.data_type = DataType::FLOAT32;
|
||||||
|
desc.element_size = 4;
|
||||||
|
resources.buffers.push_back({"buffer", desc});
|
||||||
|
return resources;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// A `Read` selector call on a registered object reference must be expanded to
// an indexed access of the object's buffer, and that buffer must show up in
// the generated kernel argument list.
TEST(ArgumentsTest, TestSelectorResolve) {
  TestDescriptor descriptor;
  Arguments args;
  args.AddObjectRef("object", absl::make_unique<TestDescriptor>(descriptor));
  std::string sample_code = R"(
if (a < 3) {
value = args.object.Read(id);
}
)";
  const std::string expected_result = R"(
if (a < 3) {
value = object_buffer[id];
}
)";
  // ASSERT_OK is not provided by googletest or by any header included here;
  // check the status explicitly instead.
  ASSERT_TRUE(args.TransformToCLCode(&sample_code).ok());
  EXPECT_EQ(sample_code, expected_result);

  std::string cl_arguments = args.GetListOfArgs();
  EXPECT_TRUE(cl_arguments.find("__global float4* object_buffer") !=
              std::string::npos);
}
|
||||||
|
|
||||||
|
// Calling a selector the descriptor does not implement ("Write") must make
// the transformation fail.
TEST(ArgumentsTest, TestNoSelector) {
  Arguments args;
  {
    TestDescriptor descriptor;
    args.AddObjectRef("object", absl::make_unique<TestDescriptor>(descriptor));
  }
  std::string sample_code = R"(
if (a < 3) {
value = args.object.Write(id);
}
)";
  const bool transformed = args.TransformToCLCode(&sample_code).ok();
  EXPECT_FALSE(transformed);
}
|
||||||
|
|
||||||
|
} // namespace cl
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
|
|
@ -99,6 +99,8 @@ class GPUObjectDescriptor {
|
||||||
mutable std::map<std::string, std::string> state_vars_;
|
mutable std::map<std::string, std::string> state_vars_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
using GPUObjectDescriptorPtr = std::unique_ptr<GPUObjectDescriptor>;
|
||||||
|
|
||||||
class GPUObject {
|
class GPUObject {
|
||||||
public:
|
public:
|
||||||
GPUObject() = default;
|
GPUObject() = default;
|
||||||
|
|
|
@ -1385,6 +1385,7 @@ cc_library(
|
||||||
":gpu_operation",
|
":gpu_operation",
|
||||||
":util",
|
":util",
|
||||||
":work_group_picking",
|
":work_group_picking",
|
||||||
|
"//tensorflow/lite/delegates/gpu/cl:arguments",
|
||||||
"//tensorflow/lite/delegates/gpu/cl:cl_device",
|
"//tensorflow/lite/delegates/gpu/cl:cl_device",
|
||||||
"//tensorflow/lite/delegates/gpu/cl:cl_kernel",
|
"//tensorflow/lite/delegates/gpu/cl:cl_kernel",
|
||||||
"//tensorflow/lite/delegates/gpu/cl:linear_storage",
|
"//tensorflow/lite/delegates/gpu/cl:linear_storage",
|
||||||
|
@ -1395,6 +1396,7 @@ cc_library(
|
||||||
"//tensorflow/lite/delegates/gpu/common:shape",
|
"//tensorflow/lite/delegates/gpu/common:shape",
|
||||||
"//tensorflow/lite/delegates/gpu/common:status",
|
"//tensorflow/lite/delegates/gpu/common:status",
|
||||||
"//tensorflow/lite/delegates/gpu/common:winograd_util",
|
"//tensorflow/lite/delegates/gpu/common:winograd_util",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/strings:str_format",
|
"@com_google_absl//absl/strings:str_format",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/strings/str_format.h"
|
#include "absl/strings/str_format.h"
|
||||||
|
#include "absl/strings/substitute.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
|
#include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
|
#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
|
#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
|
||||||
|
@ -34,8 +35,9 @@ namespace cl {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
std::string GetWinograd4x4To36Code(
|
std::string GetWinograd4x4To36Code(
|
||||||
const OperationDef& op_def, const LinearStorage& bt_arr,
|
const OperationDef& op_def,
|
||||||
const std::vector<ElementwiseOperation*>& linked_operations) {
|
const std::vector<ElementwiseOperation*>& linked_operations,
|
||||||
|
Arguments* args) {
|
||||||
TensorCodeGenerator src_tensor(
|
TensorCodeGenerator src_tensor(
|
||||||
"src_data",
|
"src_data",
|
||||||
WHSBPoint{"src_size.x", "src_size.y", "src_size.z", "src_size.w"},
|
WHSBPoint{"src_size.x", "src_size.y", "src_size.z", "src_size.w"},
|
||||||
|
@ -78,31 +80,31 @@ std::string GetWinograd4x4To36Code(
|
||||||
}
|
}
|
||||||
c += "};\n";
|
c += "};\n";
|
||||||
|
|
||||||
|
args->AddInt("padding_x");
|
||||||
|
args->AddInt("padding_y");
|
||||||
|
args->AddInt("tiles_total");
|
||||||
|
args->AddInt("tiles_x");
|
||||||
|
|
||||||
c += "__kernel void main_function(\n";
|
c += "__kernel void main_function(\n";
|
||||||
c += src_tensor.GetDeclaration(AccessType::READ) + ",\n";
|
c += src_tensor.GetDeclaration(AccessType::READ);
|
||||||
c += bt_arr.GetDeclaration();
|
|
||||||
c += GetArgsDeclaration(linked_operations);
|
c += GetArgsDeclaration(linked_operations);
|
||||||
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
|
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
|
||||||
c += " int4 src_size, \n";
|
c += " int4 src_size, \n";
|
||||||
c += " int4 dst_size, \n";
|
c += " int4 dst_size";
|
||||||
c += " int2 padding, \n";
|
c += "$0) {\n";
|
||||||
c += " int tiles_total, \n";
|
|
||||||
c += " int tiles_x \n";
|
|
||||||
c += ") {\n";
|
|
||||||
c += " int DST_X = get_global_id(0);\n";
|
c += " int DST_X = get_global_id(0);\n";
|
||||||
c += " int DST_Y = get_global_id(1);\n";
|
c += " int DST_Y = get_global_id(1);\n";
|
||||||
c += " int DST_Z = get_global_id(2);\n";
|
c += " int DST_Z = get_global_id(2);\n";
|
||||||
c += " if (DST_X >= tiles_total || DST_Y >= 6 || DST_Z >= dst_size.z) {\n";
|
c += " if (DST_X >= args.tiles_total || DST_Y >= 6 || DST_Z >= dst_size.z) "
|
||||||
|
"{\n";
|
||||||
c += " return; \n";
|
c += " return; \n";
|
||||||
c += " }\n";
|
c += " }\n";
|
||||||
c += " int tile_x = (DST_X % tiles_x) * 4;\n";
|
c += " int tile_x = (DST_X % args.tiles_x) * 4;\n";
|
||||||
c += " int tile_y = (DST_X / tiles_x) * 4;\n";
|
c += " int tile_y = (DST_X / args.tiles_x) * 4;\n";
|
||||||
c += " ACCUM_FLT4 I0, I1, I2, I3, I4, I5;\n";
|
c += " ACCUM_FLT4 I0, I1, I2, I3, I4, I5;\n";
|
||||||
c += " ACCUM_FLT bt_ar[6];\n";
|
c += " ACCUM_FLT bt_ar[6];\n";
|
||||||
c += " ACCUM_FLT4 t0 = TO_ACCUM_TYPE(" +
|
c += " ACCUM_FLT4 t0 = TO_ACCUM_TYPE(args.bt.Read(DST_Y * 2 + 0));\n";
|
||||||
bt_arr.ReadLinearFLT4("DST_Y * 2 + 0") + ");\n";
|
c += " ACCUM_FLT4 t1 = TO_ACCUM_TYPE(args.bt.Read(DST_Y * 2 + 1));\n";
|
||||||
c += " ACCUM_FLT4 t1 = TO_ACCUM_TYPE(" +
|
|
||||||
bt_arr.ReadLinearFLT4("DST_Y * 2 + 1") + ");\n";
|
|
||||||
c += " DST_Y *= 6;\n";
|
c += " DST_Y *= 6;\n";
|
||||||
c += " bt_ar[0] = t0.x;\n";
|
c += " bt_ar[0] = t0.x;\n";
|
||||||
c += " bt_ar[1] = t0.y;\n";
|
c += " bt_ar[1] = t0.y;\n";
|
||||||
|
@ -121,15 +123,16 @@ std::string GetWinograd4x4To36Code(
|
||||||
" * m" + xs + "_x;\n";
|
" * m" + xs + "_x;\n";
|
||||||
} else {
|
} else {
|
||||||
c += " ACCUM_FLT4 " + src + " = " +
|
c += " ACCUM_FLT4 " + src + " = " +
|
||||||
src_tensor.ReadAsTypeWHSB(accum_type, "tile_x + padding.x + " + xs,
|
src_tensor.ReadAsTypeWHSB(accum_type,
|
||||||
"yc", "DST_Z", batch_id) +
|
"tile_x + args.padding_x + " + xs, "yc",
|
||||||
|
"DST_Z", batch_id) +
|
||||||
";\n";
|
";\n";
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
if (is_buffer || is_image_buffer) {
|
if (is_buffer || is_image_buffer) {
|
||||||
for (int x = 0; x < 6; ++x) {
|
for (int x = 0; x < 6; ++x) {
|
||||||
const std::string xs = std::to_string(x);
|
const std::string xs = std::to_string(x);
|
||||||
c += " int xc" + xs + " = tile_x + padding.x + " + xs + ";\n";
|
c += " int xc" + xs + " = tile_x + args.padding_x + " + xs + ";\n";
|
||||||
c += " ACCUM_FLT m" + xs + "_x = (ACCUM_FLT)(xc" + xs + " >= 0 && xc" +
|
c += " ACCUM_FLT m" + xs + "_x = (ACCUM_FLT)(xc" + xs + " >= 0 && xc" +
|
||||||
xs + " < src_size.x);\n";
|
xs + " < src_size.x);\n";
|
||||||
c += " bool inx" + xs + " = (xc" + xs + " >= 0 && xc" + xs +
|
c += " bool inx" + xs + " = (xc" + xs + " >= 0 && xc" + xs +
|
||||||
|
@ -144,7 +147,7 @@ std::string GetWinograd4x4To36Code(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c += " {\n";
|
c += " {\n";
|
||||||
c += " int yc = tile_y + padding.y;\n";
|
c += " int yc = tile_y + args.padding_y;\n";
|
||||||
if (is_buffer || is_image_buffer) {
|
if (is_buffer || is_image_buffer) {
|
||||||
c += " bool iny = (yc >= 0 && yc < src_size.y);\n";
|
c += " bool iny = (yc >= 0 && yc < src_size.y);\n";
|
||||||
c += " int offset = select(0, yc * src_size.x, iny);\n";
|
c += " int offset = select(0, yc * src_size.x, iny);\n";
|
||||||
|
@ -162,7 +165,7 @@ std::string GetWinograd4x4To36Code(
|
||||||
for (int y = 1; y < 6; ++y) {
|
for (int y = 1; y < 6; ++y) {
|
||||||
const std::string ys = std::to_string(y);
|
const std::string ys = std::to_string(y);
|
||||||
c += " {\n";
|
c += " {\n";
|
||||||
c += " int yc = tile_y + padding.y + (" + ys + ");\n";
|
c += " int yc = tile_y + args.padding_y + (" + ys + ");\n";
|
||||||
if (is_buffer || is_image_buffer) {
|
if (is_buffer || is_image_buffer) {
|
||||||
c += " bool iny = (yc >= 0 && yc < src_size.y);\n";
|
c += " bool iny = (yc >= 0 && yc < src_size.y);\n";
|
||||||
c += " int offset = select(0, yc * src_size.x, iny);\n";
|
c += " int offset = select(0, yc * src_size.x, iny);\n";
|
||||||
|
@ -223,7 +226,6 @@ std::string GetWinograd4x4To36Code(
|
||||||
c += " DST_Y++;\n";
|
c += " DST_Y++;\n";
|
||||||
c += " }\n";
|
c += " }\n";
|
||||||
c += "}\n";
|
c += "}\n";
|
||||||
// std::cout << c << std::endl;
|
|
||||||
return c;
|
return c;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -366,15 +368,15 @@ std::string GetWinograd36To4x4Code(
|
||||||
|
|
||||||
Winograd4x4To36::Winograd4x4To36(Winograd4x4To36&& operation)
|
Winograd4x4To36::Winograd4x4To36(Winograd4x4To36&& operation)
|
||||||
: GPUOperation(std::move(operation)),
|
: GPUOperation(std::move(operation)),
|
||||||
bt_(std::move(operation.bt_)),
|
|
||||||
padding_(operation.padding_),
|
padding_(operation.padding_),
|
||||||
|
args_(std::move(operation.args_)),
|
||||||
kernel_(std::move(operation.kernel_)),
|
kernel_(std::move(operation.kernel_)),
|
||||||
work_group_size_(operation.work_group_size_) {}
|
work_group_size_(operation.work_group_size_) {}
|
||||||
|
|
||||||
Winograd4x4To36& Winograd4x4To36::operator=(Winograd4x4To36&& operation) {
|
Winograd4x4To36& Winograd4x4To36::operator=(Winograd4x4To36&& operation) {
|
||||||
if (this != &operation) {
|
if (this != &operation) {
|
||||||
bt_ = std::move(operation.bt_);
|
|
||||||
std::swap(padding_, operation.padding_);
|
std::swap(padding_, operation.padding_);
|
||||||
|
args_ = std::move(operation.args_);
|
||||||
kernel_ = std::move(operation.kernel_);
|
kernel_ = std::move(operation.kernel_);
|
||||||
std::swap(work_group_size_, operation.work_group_size_);
|
std::swap(work_group_size_, operation.work_group_size_);
|
||||||
GPUOperation::operator=(std::move(operation));
|
GPUOperation::operator=(std::move(operation));
|
||||||
|
@ -392,8 +394,10 @@ absl::Status Winograd4x4To36::Compile(const CreationContext& creation_context) {
|
||||||
options.push_back(CompilerOptions::POWERVR_FP16);
|
options.push_back(CompilerOptions::POWERVR_FP16);
|
||||||
}
|
}
|
||||||
RETURN_IF_ERROR(UploadBt(creation_context.context));
|
RETURN_IF_ERROR(UploadBt(creation_context.context));
|
||||||
const auto code =
|
std::string code =
|
||||||
GetWinograd4x4To36Code(definition_, bt_, linked_operations_);
|
GetWinograd4x4To36Code(definition_, linked_operations_, &args_);
|
||||||
|
RETURN_IF_ERROR(args_.TransformToCLCode(&code));
|
||||||
|
code = absl::Substitute(code, args_.GetListOfArgs());
|
||||||
RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel(
|
RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel(
|
||||||
code, "main_function", options, *creation_context.context,
|
code, "main_function", options, *creation_context.context,
|
||||||
*creation_context.device, &kernel_));
|
*creation_context.device, &kernel_));
|
||||||
|
@ -418,7 +422,11 @@ absl::Status Winograd4x4To36::UploadBt(CLContext* context) {
|
||||||
create_info.storage_type = LinearStorageType::TEXTURE_2D;
|
create_info.storage_type = LinearStorageType::TEXTURE_2D;
|
||||||
create_info.data_type = definition_.GetDataType();
|
create_info.data_type = definition_.GetDataType();
|
||||||
create_info.name = "bt_arr";
|
create_info.name = "bt_arr";
|
||||||
return CreateLinearStorage(create_info, bt_aligned, context, &bt_);
|
|
||||||
|
LinearStorage lt;
|
||||||
|
RETURN_IF_ERROR(CreateLinearStorage(create_info, bt_aligned, context, &lt));
|
||||||
|
args_.AddObject("bt", absl::make_unique<LinearStorage>(std::move(lt)));
|
||||||
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
int3 Winograd4x4To36::SelectBestWorkGroup() {
|
int3 Winograd4x4To36::SelectBestWorkGroup() {
|
||||||
|
@ -429,22 +437,22 @@ int3 Winograd4x4To36::SelectBestWorkGroup() {
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status Winograd4x4To36::BindArguments() {
|
absl::Status Winograd4x4To36::BindArguments() {
|
||||||
kernel_.ResetBindingCounter();
|
|
||||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
|
|
||||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(bt_.GetMemoryPtr()));
|
|
||||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
|
||||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
|
||||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB()));
|
|
||||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB()));
|
|
||||||
const int tiles_x = DivideRoundUp(
|
const int tiles_x = DivideRoundUp(
|
||||||
src_[0]->Width() + padding_.prepended.w + padding_.appended.w - 2, 4);
|
src_[0]->Width() + padding_.prepended.w + padding_.appended.w - 2, 4);
|
||||||
const int tiles_y = DivideRoundUp(
|
const int tiles_y = DivideRoundUp(
|
||||||
src_[0]->Height() + padding_.prepended.h + padding_.appended.h - 2, 4);
|
src_[0]->Height() + padding_.prepended.h + padding_.appended.h - 2, 4);
|
||||||
const int tiles_total = tiles_x * tiles_y;
|
const int tiles_total = tiles_x * tiles_y;
|
||||||
RETURN_IF_ERROR(
|
RETURN_IF_ERROR(args_.SetInt("padding_x", -padding_.prepended.w));
|
||||||
kernel_.SetBytesAuto(int2(-padding_.prepended.w, -padding_.prepended.h)));
|
RETURN_IF_ERROR(args_.SetInt("padding_y", -padding_.prepended.h));
|
||||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(tiles_total));
|
RETURN_IF_ERROR(args_.SetInt("tiles_total", tiles_total));
|
||||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(tiles_x));
|
RETURN_IF_ERROR(args_.SetInt("tiles_x", tiles_x));
|
||||||
|
kernel_.ResetBindingCounter();
|
||||||
|
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
|
||||||
|
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||||
|
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||||
|
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB()));
|
||||||
|
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB()));
|
||||||
|
RETURN_IF_ERROR(args_.Bind(kernel_.kernel(), kernel_.GetBindingCounter()));
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||||
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_WINOGRAD_H_
|
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_WINOGRAD_H_
|
||||||
#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_WINOGRAD_H_
|
#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_WINOGRAD_H_
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/cl/arguments.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
|
#include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
|
#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/linear_storage.h"
|
#include "tensorflow/lite/delegates/gpu/cl/linear_storage.h"
|
||||||
|
@ -59,9 +60,9 @@ class Winograd4x4To36 : public GPUOperation {
|
||||||
absl::Status BindArguments();
|
absl::Status BindArguments();
|
||||||
int3 GetGridSize() const;
|
int3 GetGridSize() const;
|
||||||
|
|
||||||
LinearStorage bt_;
|
|
||||||
Padding2D padding_;
|
Padding2D padding_;
|
||||||
|
|
||||||
|
Arguments args_;
|
||||||
CLKernel kernel_;
|
CLKernel kernel_;
|
||||||
int3 work_group_size_ = int3(128, 1, 1);
|
int3 work_group_size_ = int3(128, 1, 1);
|
||||||
};
|
};
|
||||||
|
|
|
@ -15,24 +15,79 @@ limitations under the License.
|
||||||
|
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/linear_storage.h"
|
#include "tensorflow/lite/delegates/gpu/cl/linear_storage.h"
|
||||||
|
|
||||||
|
#include "absl/strings/str_cat.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace gpu {
|
namespace gpu {
|
||||||
namespace cl {
|
namespace cl {
|
||||||
|
|
||||||
|
GPUResources TensorLinearDescriptor::GetGPUResources() const {
|
||||||
|
GPUResources resources;
|
||||||
|
resources.ints.push_back("length");
|
||||||
|
if (storage_type == LinearStorageType::BUFFER) {
|
||||||
|
GPUBufferDescriptor desc;
|
||||||
|
desc.data_type = element_type;
|
||||||
|
desc.element_size = 4;
|
||||||
|
resources.buffers.push_back({"buffer", desc});
|
||||||
|
} else {
|
||||||
|
GPUImage2DDescriptor desc;
|
||||||
|
desc.data_type = element_type;
|
||||||
|
resources.images2d.push_back({"tex2d", desc});
|
||||||
|
}
|
||||||
|
return resources;
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status TensorLinearDescriptor::PerformSelector(
|
||||||
|
const std::string& selector, const std::vector<std::string>& args,
|
||||||
|
std::string* result) const {
|
||||||
|
if (selector == "Length") {
|
||||||
|
*result = "length";
|
||||||
|
return absl::OkStatus();
|
||||||
|
} else if (selector == "Read") {
|
||||||
|
return PerformReadSelector(args, result);
|
||||||
|
} else {
|
||||||
|
return absl::NotFoundError(absl::StrCat(
|
||||||
|
"TensorLinearDescriptor don't have selector with name - ", selector));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status TensorLinearDescriptor::PerformReadSelector(
|
||||||
|
const std::vector<std::string>& args, std::string* result) const {
|
||||||
|
if (args.size() != 1) {
|
||||||
|
return absl::NotFoundError(
|
||||||
|
absl::StrCat("TensorLinearDescriptor Read require one argument, but ",
|
||||||
|
args.size(), " was passed"));
|
||||||
|
}
|
||||||
|
if (storage_type == LinearStorageType::BUFFER) {
|
||||||
|
*result = absl::StrCat("buffer[", args[0], "]");
|
||||||
|
return absl::OkStatus();
|
||||||
|
} else {
|
||||||
|
const std::string read =
|
||||||
|
element_type == DataType::FLOAT16 ? "read_imageh" : "read_imagef";
|
||||||
|
*result = absl::StrCat(read, "(tex2d, smp_none, (int2)(", args[0], ", 0))");
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
LinearStorage::LinearStorage(int depth, LinearStorageType storage_type,
|
LinearStorage::LinearStorage(int depth, LinearStorageType storage_type,
|
||||||
DataType data_type)
|
DataType data_type)
|
||||||
: depth_(depth), storage_type_(storage_type), data_type_(data_type) {}
|
: depth_(depth), storage_type_(storage_type), data_type_(data_type) {
|
||||||
|
desc_.storage_type = storage_type;
|
||||||
|
desc_.element_type = data_type;
|
||||||
|
}
|
||||||
|
|
||||||
LinearStorage::LinearStorage(LinearStorage&& storage)
|
LinearStorage::LinearStorage(LinearStorage&& storage)
|
||||||
: texture_storage_(std::move(storage.texture_storage_)),
|
: GPUObject(std::move(storage)),
|
||||||
|
texture_storage_(std::move(storage.texture_storage_)),
|
||||||
buffer_storage_(std::move(storage.buffer_storage_)),
|
buffer_storage_(std::move(storage.buffer_storage_)),
|
||||||
memory_(storage.memory_),
|
memory_(storage.memory_),
|
||||||
depth_(storage.depth_),
|
depth_(storage.depth_),
|
||||||
name_(std::move(storage.name_)),
|
name_(std::move(storage.name_)),
|
||||||
storage_type_(storage.storage_type_),
|
storage_type_(storage.storage_type_),
|
||||||
data_type_(storage.data_type_) {
|
data_type_(storage.data_type_),
|
||||||
|
desc_(storage.desc_) {
|
||||||
storage.memory_ = nullptr;
|
storage.memory_ = nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -45,6 +100,8 @@ LinearStorage& LinearStorage::operator=(LinearStorage&& storage) {
|
||||||
name_ = std::move(storage.name_);
|
name_ = std::move(storage.name_);
|
||||||
std::swap(storage_type_, storage.storage_type_);
|
std::swap(storage_type_, storage.storage_type_);
|
||||||
std::swap(data_type_, storage.data_type_);
|
std::swap(data_type_, storage.data_type_);
|
||||||
|
desc_ = storage.desc_;
|
||||||
|
GPUObject::operator=(std::move(storage));
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
@ -66,6 +123,19 @@ std::string LinearStorage::GetDeclaration() const {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
GPUResourcesWithValue LinearStorage::GetGPUResources() const {
|
||||||
|
GPUResourcesWithValue resources;
|
||||||
|
resources.ints.push_back({"length", depth_});
|
||||||
|
|
||||||
|
if (storage_type_ == LinearStorageType::BUFFER) {
|
||||||
|
resources.buffers.push_back({"buffer", memory_});
|
||||||
|
} else {
|
||||||
|
resources.images2d.push_back({"tex2d", memory_});
|
||||||
|
}
|
||||||
|
|
||||||
|
return resources;
|
||||||
|
}
|
||||||
|
|
||||||
LinearStorageType DeduceLinearStorageType(
|
LinearStorageType DeduceLinearStorageType(
|
||||||
TensorStorageType tensor_storage_type) {
|
TensorStorageType tensor_storage_type) {
|
||||||
if (tensor_storage_type == TensorStorageType::BUFFER) {
|
if (tensor_storage_type == TensorStorageType::BUFFER) {
|
||||||
|
|
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/buffer.h"
|
#include "tensorflow/lite/delegates/gpu/cl/buffer.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/cl/gpu_object.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h"
|
#include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
|
#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/texture2d.h"
|
#include "tensorflow/lite/delegates/gpu/cl/texture2d.h"
|
||||||
|
@ -36,6 +37,33 @@ namespace cl {
|
||||||
|
|
||||||
enum class LinearStorageType { BUFFER, TEXTURE_2D };
|
enum class LinearStorageType { BUFFER, TEXTURE_2D };
|
||||||
|
|
||||||
|
struct TensorLinearDescriptor : public GPUObjectDescriptor {
|
||||||
|
LinearStorageType storage_type;
|
||||||
|
DataType element_type; // FLOAT32 or FLOAT16
|
||||||
|
|
||||||
|
TensorLinearDescriptor() = default;
|
||||||
|
TensorLinearDescriptor(const TensorLinearDescriptor& desc)
|
||||||
|
: GPUObjectDescriptor(desc),
|
||||||
|
storage_type(desc.storage_type),
|
||||||
|
element_type(desc.element_type) {}
|
||||||
|
TensorLinearDescriptor& operator=(const TensorLinearDescriptor& desc) {
|
||||||
|
if (this != &desc) {
|
||||||
|
storage_type = desc.storage_type;
|
||||||
|
element_type = desc.element_type;
|
||||||
|
GPUObjectDescriptor::operator=(desc);
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status PerformSelector(const std::string& selector,
|
||||||
|
const std::vector<std::string>& args,
|
||||||
|
std::string* result) const override;
|
||||||
|
|
||||||
|
GPUResources GetGPUResources() const override;
|
||||||
|
absl::Status PerformReadSelector(const std::vector<std::string>& args,
|
||||||
|
std::string* result) const;
|
||||||
|
};
|
||||||
|
|
||||||
struct LinearStorageCreateInfo {
|
struct LinearStorageCreateInfo {
|
||||||
LinearStorageType storage_type;
|
LinearStorageType storage_type;
|
||||||
DataType data_type;
|
DataType data_type;
|
||||||
|
@ -48,7 +76,7 @@ LinearStorageType DeduceLinearStorageType(
|
||||||
|
|
||||||
// Represent GPU 1D-array of FLT4(float4/half4) values
|
// Represent GPU 1D-array of FLT4(float4/half4) values
|
||||||
// Can use inside texture2d or buffer
|
// Can use inside texture2d or buffer
|
||||||
class LinearStorage {
|
class LinearStorage : public GPUObject {
|
||||||
public:
|
public:
|
||||||
LinearStorage() {}
|
LinearStorage() {}
|
||||||
|
|
||||||
|
@ -63,6 +91,11 @@ class LinearStorage {
|
||||||
std::string ReadLinearFLT4(const std::string& z_coord) const;
|
std::string ReadLinearFLT4(const std::string& z_coord) const;
|
||||||
std::string GetDeclaration() const;
|
std::string GetDeclaration() const;
|
||||||
|
|
||||||
|
const GPUObjectDescriptor* GetGPUDescriptor() const override {
|
||||||
|
return &desc_;
|
||||||
|
}
|
||||||
|
GPUResourcesWithValue GetGPUResources() const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend absl::Status CreateTextureLinearStorage(int size, DataType data_type,
|
friend absl::Status CreateTextureLinearStorage(int size, DataType data_type,
|
||||||
void* data, CLContext* context,
|
void* data, CLContext* context,
|
||||||
|
@ -81,6 +114,7 @@ class LinearStorage {
|
||||||
std::string name_;
|
std::string name_;
|
||||||
LinearStorageType storage_type_;
|
LinearStorageType storage_type_;
|
||||||
DataType data_type_;
|
DataType data_type_;
|
||||||
|
TensorLinearDescriptor desc_;
|
||||||
};
|
};
|
||||||
|
|
||||||
absl::Status CreateBufferLinearStorage(int size, DataType data_type, void* data,
|
absl::Status CreateBufferLinearStorage(int size, DataType data_type, void* data,
|
||||||
|
|
Loading…
Reference in New Issue