Added attributes for buffer.

Enhanced GetPtr method of Buffer object.

PiperOrigin-RevId: 318083432
Change-Id: I1fead5800bd03868296d93759598df378298ec71
This commit is contained in:
Raman Sarokin 2020-06-24 09:51:37 -07:00 committed by TensorFlower Gardener
parent fb9291d5a5
commit 05c284bcc3
7 changed files with 96 additions and 24 deletions

View File

@ -49,6 +49,7 @@ cc_library(
":tensor_type", ":tensor_type",
":util", ":util",
"//tensorflow/lite/delegates/gpu/common:access_type", "//tensorflow/lite/delegates/gpu/common:access_type",
"//tensorflow/lite/delegates/gpu/common:data_type",
"//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common:types", "//tensorflow/lite/delegates/gpu/common:types",
"//tensorflow/lite/delegates/gpu/common:util", "//tensorflow/lite/delegates/gpu/common:util",
@ -84,6 +85,7 @@ cc_library(
":gpu_object", ":gpu_object",
":opencl_wrapper", ":opencl_wrapper",
":util", ":util",
"//tensorflow/lite/delegates/gpu/common:data_type",
"//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:status",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
@ -330,6 +332,7 @@ cc_library(
cc_library( cc_library(
name = "gpu_object", name = "gpu_object",
srcs = ["gpu_object.cc"],
hdrs = ["gpu_object.h"], hdrs = ["gpu_object.h"],
deps = [ deps = [
":opencl_wrapper", ":opencl_wrapper",

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "absl/strings/str_split.h" #include "absl/strings/str_split.h"
#include "absl/strings/substitute.h" #include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h" #include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/status.h"
namespace tflite { namespace tflite {
@ -457,21 +458,15 @@ std::string Arguments::GetListOfArgs() {
for (auto& t : buffers_) { for (auto& t : buffers_) {
const std::string type_name = const std::string type_name =
t.second.data_type == DataType::FLOAT32 ? "float" : "half"; t.second.data_type == DataType::FLOAT32 ? "float" : "half";
std::string memory_type; std::string attributes;
switch (t.second.memory_type) { for (const auto& attr : t.second.attributes) {
case MemoryType::GLOBAL: attributes += absl::StrCat(" __attribute__((", attr, "))");
memory_type = "__global";
break;
case MemoryType::CONSTANT:
memory_type = "__constant";
break;
case MemoryType::LOCAL:
memory_type = "__local";
break;
} }
AppendArgument(absl::StrCat(memory_type, " ", type_name, AppendArgument(
t.second.element_size, "* ", t.first), absl::StrCat(MemoryTypeToCLType(t.second.memory_type), " ",
&result); ToCLDataType(t.second.data_type, t.second.element_size),
"* ", t.first, attributes),
&result);
} }
for (auto& t : image_buffers_) { for (auto& t : image_buffers_) {
AppendArgument(absl::StrCat(GetImageModifier(t.second.access_type), AppendArgument(absl::StrCat(GetImageModifier(t.second.access_type),

View File

@ -15,6 +15,9 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/cl/buffer.h" #include "tensorflow/lite/delegates/gpu/cl/buffer.h"
#include <string>
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/status.h"
namespace tflite { namespace tflite {
@ -51,6 +54,7 @@ GPUResources BufferDescriptor::GetGPUResources(AccessType access_type) const {
desc.access_type = access_type; desc.access_type = access_type;
desc.element_size = element_size; desc.element_size = element_size;
desc.memory_type = memory_type; desc.memory_type = memory_type;
desc.attributes = attributes;
resources.buffers.push_back({"buffer", desc}); resources.buffers.push_back({"buffer", desc});
return resources; return resources;
} }
@ -61,7 +65,7 @@ absl::Status BufferDescriptor::PerformSelector(
if (selector == "Read") { if (selector == "Read") {
return PerformReadSelector(args, result); return PerformReadSelector(args, result);
} else if (selector == "GetPtr") { } else if (selector == "GetPtr") {
return PerformGetPtrSelector(args, result); return PerformGetPtrSelector(args, template_args, result);
} else { } else {
return absl::NotFoundError(absl::StrCat( return absl::NotFoundError(absl::StrCat(
"BufferDescriptor don't have selector with name - ", selector)); "BufferDescriptor don't have selector with name - ", selector));
@ -80,13 +84,34 @@ absl::Status BufferDescriptor::PerformReadSelector(
} }
absl::Status BufferDescriptor::PerformGetPtrSelector( absl::Status BufferDescriptor::PerformGetPtrSelector(
const std::vector<std::string>& args, std::string* result) const { const std::vector<std::string>& args,
if (!args.empty()) { const std::vector<std::string>& template_args, std::string* result) const {
return absl::NotFoundError( if (args.size() > 1) {
absl::StrCat("BufferDescriptor GetPtr require zero arguments, but ", return absl::NotFoundError(absl::StrCat(
args.size(), " was passed")); "BufferDescriptor GetPtr require one or zero arguments, but ",
args.size(), " was passed"));
}
if (template_args.size() > 1) {
return absl::NotFoundError(
absl::StrCat("BufferDescriptor GetPtr require one or zero teemplate "
"arguments, but ",
template_args.size(), " was passed"));
}
std::string conversion;
if (template_args.size() == 1) {
const std::string type_name = ToCLDataType(element_type, element_size);
if (type_name != template_args[0]) {
conversion = absl::StrCat("(", MemoryTypeToCLType(memory_type), " ",
template_args[0], "*)&");
}
}
if (args.empty()) {
*result = absl::StrCat(conversion, "buffer");
} else if (conversion.empty()) {
*result = absl::StrCat("(buffer + ", args[0], ")");
} else {
*result = absl::StrCat(conversion, "buffer[", args[0], "]");
} }
*result = "buffer";
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -30,9 +30,10 @@ namespace gpu {
namespace cl { namespace cl {
struct BufferDescriptor : public GPUObjectDescriptor { struct BufferDescriptor : public GPUObjectDescriptor {
DataType element_type; // FLOAT32 or FLOAT16 DataType element_type;
int element_size; int element_size;
MemoryType memory_type = MemoryType::GLOBAL; MemoryType memory_type = MemoryType::GLOBAL;
std::vector<std::string> attributes;
absl::Status PerformSelector(const std::string& selector, absl::Status PerformSelector(const std::string& selector,
const std::vector<std::string>& args, const std::vector<std::string>& args,
@ -42,8 +43,9 @@ struct BufferDescriptor : public GPUObjectDescriptor {
GPUResources GetGPUResources(AccessType access_type) const override; GPUResources GetGPUResources(AccessType access_type) const override;
absl::Status PerformReadSelector(const std::vector<std::string>& args, absl::Status PerformReadSelector(const std::vector<std::string>& args,
std::string* result) const; std::string* result) const;
absl::Status PerformGetPtrSelector(const std::vector<std::string>& args, absl::Status PerformGetPtrSelector(
std::string* result) const; const std::vector<std::string>& args,
const std::vector<std::string>& template_args, std::string* result) const;
}; };
// Buffer represent linear GPU data storage with arbitrary data format. // Buffer represent linear GPU data storage with arbitrary data format.

View File

@ -0,0 +1,37 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/cl/gpu_object.h"
namespace tflite {
namespace gpu {
namespace cl {
std::string MemoryTypeToCLType(MemoryType type) {
switch (type) {
case MemoryType::GLOBAL:
return "__global";
case MemoryType::CONSTANT:
return "__constant";
break;
case MemoryType::LOCAL:
return "__local";
}
return "";
}
} // namespace cl
} // namespace gpu
} // namespace tflite

View File

@ -56,11 +56,14 @@ struct GPUImageBufferDescriptor {
enum class MemoryType { GLOBAL, CONSTANT, LOCAL }; enum class MemoryType { GLOBAL, CONSTANT, LOCAL };
std::string MemoryTypeToCLType(MemoryType type);
struct GPUBufferDescriptor { struct GPUBufferDescriptor {
DataType data_type; DataType data_type;
AccessType access_type; AccessType access_type;
int element_size; int element_size;
MemoryType memory_type = MemoryType::GLOBAL; MemoryType memory_type = MemoryType::GLOBAL;
std::vector<std::string> attributes;
cl_mem memory; cl_mem memory;
}; };

View File

@ -147,6 +147,13 @@ absl::Status TensorDescriptor::PerformSelector(
} else if (selector == "Slices") { } else if (selector == "Slices") {
*result = "slices"; *result = "slices";
return absl::OkStatus(); return absl::OkStatus();
} else if (selector == "SliceStride") {
if (IsBatchedWidth()) {
*result = "width_batched * height";
} else {
*result = "width * height";
}
return absl::OkStatus();
} else if (selector == "Channels") { } else if (selector == "Channels") {
*result = "channels"; *result = "channels";
return absl::OkStatus(); return absl::OkStatus();