Added attributes for buffer.

Enhanced GetPtr method of Buffer object.

PiperOrigin-RevId: 318083432
Change-Id: I1fead5800bd03868296d93759598df378298ec71
This commit is contained in:
Raman Sarokin 2020-06-24 09:51:37 -07:00 committed by TensorFlower Gardener
parent fb9291d5a5
commit 05c284bcc3
7 changed files with 96 additions and 24 deletions

View File

@ -49,6 +49,7 @@ cc_library(
":tensor_type",
":util",
"//tensorflow/lite/delegates/gpu/common:access_type",
"//tensorflow/lite/delegates/gpu/common:data_type",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common:types",
"//tensorflow/lite/delegates/gpu/common:util",
@ -84,6 +85,7 @@ cc_library(
":gpu_object",
":opencl_wrapper",
":util",
"//tensorflow/lite/delegates/gpu/common:data_type",
"//tensorflow/lite/delegates/gpu/common:status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
@ -330,6 +332,7 @@ cc_library(
cc_library(
name = "gpu_object",
srcs = ["gpu_object.cc"],
hdrs = ["gpu_object.h"],
deps = [
":opencl_wrapper",

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "absl/strings/str_split.h"
#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
namespace tflite {
@ -457,21 +458,15 @@ std::string Arguments::GetListOfArgs() {
for (auto& t : buffers_) {
const std::string type_name =
t.second.data_type == DataType::FLOAT32 ? "float" : "half";
std::string memory_type;
switch (t.second.memory_type) {
case MemoryType::GLOBAL:
memory_type = "__global";
break;
case MemoryType::CONSTANT:
memory_type = "__constant";
break;
case MemoryType::LOCAL:
memory_type = "__local";
break;
std::string attributes;
for (const auto& attr : t.second.attributes) {
attributes += absl::StrCat(" __attribute__((", attr, "))");
}
AppendArgument(absl::StrCat(memory_type, " ", type_name,
t.second.element_size, "* ", t.first),
&result);
AppendArgument(
absl::StrCat(MemoryTypeToCLType(t.second.memory_type), " ",
ToCLDataType(t.second.data_type, t.second.element_size),
"* ", t.first, attributes),
&result);
}
for (auto& t : image_buffers_) {
AppendArgument(absl::StrCat(GetImageModifier(t.second.access_type),

View File

@ -15,6 +15,9 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/cl/buffer.h"
#include <string>
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
namespace tflite {
@ -51,6 +54,7 @@ GPUResources BufferDescriptor::GetGPUResources(AccessType access_type) const {
desc.access_type = access_type;
desc.element_size = element_size;
desc.memory_type = memory_type;
desc.attributes = attributes;
resources.buffers.push_back({"buffer", desc});
return resources;
}
@ -61,7 +65,7 @@ absl::Status BufferDescriptor::PerformSelector(
if (selector == "Read") {
return PerformReadSelector(args, result);
} else if (selector == "GetPtr") {
return PerformGetPtrSelector(args, result);
return PerformGetPtrSelector(args, template_args, result);
} else {
return absl::NotFoundError(absl::StrCat(
"BufferDescriptor don't have selector with name - ", selector));
@ -80,13 +84,34 @@ absl::Status BufferDescriptor::PerformReadSelector(
}
absl::Status BufferDescriptor::PerformGetPtrSelector(
const std::vector<std::string>& args, std::string* result) const {
if (!args.empty()) {
return absl::NotFoundError(
absl::StrCat("BufferDescriptor GetPtr require zero arguments, but ",
args.size(), " was passed"));
const std::vector<std::string>& args,
const std::vector<std::string>& template_args, std::string* result) const {
if (args.size() > 1) {
return absl::NotFoundError(absl::StrCat(
"BufferDescriptor GetPtr require one or zero arguments, but ",
args.size(), " was passed"));
}
if (template_args.size() > 1) {
return absl::NotFoundError(
absl::StrCat("BufferDescriptor GetPtr require one or zero teemplate "
"arguments, but ",
template_args.size(), " was passed"));
}
std::string conversion;
if (template_args.size() == 1) {
const std::string type_name = ToCLDataType(element_type, element_size);
if (type_name != template_args[0]) {
conversion = absl::StrCat("(", MemoryTypeToCLType(memory_type), " ",
template_args[0], "*)&");
}
}
if (args.empty()) {
*result = absl::StrCat(conversion, "buffer");
} else if (conversion.empty()) {
*result = absl::StrCat("(buffer + ", args[0], ")");
} else {
*result = absl::StrCat(conversion, "buffer[", args[0], "]");
}
*result = "buffer";
return absl::OkStatus();
}

View File

@ -30,9 +30,10 @@ namespace gpu {
namespace cl {
struct BufferDescriptor : public GPUObjectDescriptor {
DataType element_type; // FLOAT32 or FLOAT16
DataType element_type;
int element_size;
MemoryType memory_type = MemoryType::GLOBAL;
std::vector<std::string> attributes;
absl::Status PerformSelector(const std::string& selector,
const std::vector<std::string>& args,
@ -42,8 +43,9 @@ struct BufferDescriptor : public GPUObjectDescriptor {
GPUResources GetGPUResources(AccessType access_type) const override;
absl::Status PerformReadSelector(const std::vector<std::string>& args,
std::string* result) const;
absl::Status PerformGetPtrSelector(const std::vector<std::string>& args,
std::string* result) const;
absl::Status PerformGetPtrSelector(
const std::vector<std::string>& args,
const std::vector<std::string>& template_args, std::string* result) const;
};
// Buffer represent linear GPU data storage with arbitrary data format.

View File

@ -0,0 +1,37 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/cl/gpu_object.h"
namespace tflite {
namespace gpu {
namespace cl {
std::string MemoryTypeToCLType(MemoryType type) {
switch (type) {
case MemoryType::GLOBAL:
return "__global";
case MemoryType::CONSTANT:
return "__constant";
break;
case MemoryType::LOCAL:
return "__local";
}
return "";
}
} // namespace cl
} // namespace gpu
} // namespace tflite

View File

@ -56,11 +56,14 @@ struct GPUImageBufferDescriptor {
enum class MemoryType { GLOBAL, CONSTANT, LOCAL };
std::string MemoryTypeToCLType(MemoryType type);
struct GPUBufferDescriptor {
DataType data_type;
AccessType access_type;
int element_size;
MemoryType memory_type = MemoryType::GLOBAL;
std::vector<std::string> attributes;
cl_mem memory;
};

View File

@ -147,6 +147,13 @@ absl::Status TensorDescriptor::PerformSelector(
} else if (selector == "Slices") {
*result = "slices";
return absl::OkStatus();
} else if (selector == "SliceStride") {
if (IsBatchedWidth()) {
*result = "width_batched * height";
} else {
*result = "width * height";
}
return absl::OkStatus();
} else if (selector == "Channels") {
*result = "channels";
return absl::OkStatus();