Added attributes for buffer.
Enhanced GetPtr method of Buffer object. PiperOrigin-RevId: 318083432 Change-Id: I1fead5800bd03868296d93759598df378298ec71
This commit is contained in:
parent
fb9291d5a5
commit
05c284bcc3
@ -49,6 +49,7 @@ cc_library(
|
||||
":tensor_type",
|
||||
":util",
|
||||
"//tensorflow/lite/delegates/gpu/common:access_type",
|
||||
"//tensorflow/lite/delegates/gpu/common:data_type",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
"//tensorflow/lite/delegates/gpu/common:types",
|
||||
"//tensorflow/lite/delegates/gpu/common:util",
|
||||
@ -84,6 +85,7 @@ cc_library(
|
||||
":gpu_object",
|
||||
":opencl_wrapper",
|
||||
":util",
|
||||
"//tensorflow/lite/delegates/gpu/common:data_type",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
@ -330,6 +332,7 @@ cc_library(
|
||||
|
||||
cc_library(
|
||||
name = "gpu_object",
|
||||
srcs = ["gpu_object.cc"],
|
||||
hdrs = ["gpu_object.h"],
|
||||
deps = [
|
||||
":opencl_wrapper",
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -457,21 +458,15 @@ std::string Arguments::GetListOfArgs() {
|
||||
for (auto& t : buffers_) {
|
||||
const std::string type_name =
|
||||
t.second.data_type == DataType::FLOAT32 ? "float" : "half";
|
||||
std::string memory_type;
|
||||
switch (t.second.memory_type) {
|
||||
case MemoryType::GLOBAL:
|
||||
memory_type = "__global";
|
||||
break;
|
||||
case MemoryType::CONSTANT:
|
||||
memory_type = "__constant";
|
||||
break;
|
||||
case MemoryType::LOCAL:
|
||||
memory_type = "__local";
|
||||
break;
|
||||
std::string attributes;
|
||||
for (const auto& attr : t.second.attributes) {
|
||||
attributes += absl::StrCat(" __attribute__((", attr, "))");
|
||||
}
|
||||
AppendArgument(absl::StrCat(memory_type, " ", type_name,
|
||||
t.second.element_size, "* ", t.first),
|
||||
&result);
|
||||
AppendArgument(
|
||||
absl::StrCat(MemoryTypeToCLType(t.second.memory_type), " ",
|
||||
ToCLDataType(t.second.data_type, t.second.element_size),
|
||||
"* ", t.first, attributes),
|
||||
&result);
|
||||
}
|
||||
for (auto& t : image_buffers_) {
|
||||
AppendArgument(absl::StrCat(GetImageModifier(t.second.access_type),
|
||||
|
@ -15,6 +15,9 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/cl/buffer.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -51,6 +54,7 @@ GPUResources BufferDescriptor::GetGPUResources(AccessType access_type) const {
|
||||
desc.access_type = access_type;
|
||||
desc.element_size = element_size;
|
||||
desc.memory_type = memory_type;
|
||||
desc.attributes = attributes;
|
||||
resources.buffers.push_back({"buffer", desc});
|
||||
return resources;
|
||||
}
|
||||
@ -61,7 +65,7 @@ absl::Status BufferDescriptor::PerformSelector(
|
||||
if (selector == "Read") {
|
||||
return PerformReadSelector(args, result);
|
||||
} else if (selector == "GetPtr") {
|
||||
return PerformGetPtrSelector(args, result);
|
||||
return PerformGetPtrSelector(args, template_args, result);
|
||||
} else {
|
||||
return absl::NotFoundError(absl::StrCat(
|
||||
"BufferDescriptor don't have selector with name - ", selector));
|
||||
@ -80,13 +84,34 @@ absl::Status BufferDescriptor::PerformReadSelector(
|
||||
}
|
||||
|
||||
// Builds the source expression for the "GetPtr" selector of a buffer.
//
// args:          at most one element; when present it is used as the element
//                offset into the buffer.
// template_args: at most one element; when present it names the pointee type
//                the caller wants. A cast is emitted only if that name differs
//                from ToCLDataType(element_type, element_size).
// result:        receives the generated expression on success.
//
// Generated forms:
//   no offset              -> "<cast>buffer"
//   offset, no cast        -> "(buffer + <offset>)"
//   offset, cast required  -> "<cast>buffer[<offset>]"
// where <cast> is "(<address space> <requested type>*)&" or empty.
//
// Returns NotFoundError when more than one argument or more than one template
// argument is supplied; OkStatus otherwise.
absl::Status BufferDescriptor::PerformGetPtrSelector(
    const std::vector<std::string>& args,
    const std::vector<std::string>& template_args, std::string* result) const {
  if (args.size() > 1) {
    return absl::NotFoundError(absl::StrCat(
        "BufferDescriptor GetPtr require one or zero arguments, but ",
        args.size(), " was passed"));
  }
  if (template_args.size() > 1) {
    return absl::NotFoundError(
        absl::StrCat("BufferDescriptor GetPtr require one or zero template "
                     "arguments, but ",
                     template_args.size(), " was passed"));
  }
  std::string conversion;
  if (template_args.size() == 1) {
    const std::string type_name = ToCLDataType(element_type, element_size);
    // Only cast when the requested type is not already the buffer's own type.
    if (type_name != template_args[0]) {
      conversion = absl::StrCat("(", MemoryTypeToCLType(memory_type), " ",
                                template_args[0], "*)&");
    }
  }
  if (args.empty()) {
    // NOTE(review): with a non-empty cast this yields "(... T*)&buffer", i.e.
    // the address of the pointer variable rather than of its first element --
    // confirm this is intended before relying on the no-offset cast form.
    *result = absl::StrCat(conversion, "buffer");
  } else if (conversion.empty()) {
    *result = absl::StrCat("(buffer + ", args[0], ")");
  } else {
    // The trailing '&' of the cast applies to the indexed element.
    *result = absl::StrCat(conversion, "buffer[", args[0], "]");
  }
  return absl::OkStatus();
}
|
||||
|
||||
|
@ -30,9 +30,10 @@ namespace gpu {
|
||||
namespace cl {
|
||||
|
||||
struct BufferDescriptor : public GPUObjectDescriptor {
|
||||
DataType element_type; // FLOAT32 or FLOAT16
|
||||
DataType element_type;
|
||||
int element_size;
|
||||
MemoryType memory_type = MemoryType::GLOBAL;
|
||||
std::vector<std::string> attributes;
|
||||
|
||||
absl::Status PerformSelector(const std::string& selector,
|
||||
const std::vector<std::string>& args,
|
||||
@ -42,8 +43,9 @@ struct BufferDescriptor : public GPUObjectDescriptor {
|
||||
GPUResources GetGPUResources(AccessType access_type) const override;
|
||||
absl::Status PerformReadSelector(const std::vector<std::string>& args,
|
||||
std::string* result) const;
|
||||
absl::Status PerformGetPtrSelector(const std::vector<std::string>& args,
|
||||
std::string* result) const;
|
||||
absl::Status PerformGetPtrSelector(
|
||||
const std::vector<std::string>& args,
|
||||
const std::vector<std::string>& template_args, std::string* result) const;
|
||||
};
|
||||
|
||||
// Buffer represent linear GPU data storage with arbitrary data format.
|
||||
|
37
tensorflow/lite/delegates/gpu/cl/gpu_object.cc
Normal file
37
tensorflow/lite/delegates/gpu/cl/gpu_object.cc
Normal file
@ -0,0 +1,37 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/cl/gpu_object.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace cl {
|
||||
|
||||
// Maps a MemoryType to the OpenCL C address-space qualifier string.
//
// Returns "__global", "__constant" or "__local" for the matching enumerator,
// and an empty string if `type` holds a value that is not one of the declared
// enumerators.
std::string MemoryTypeToCLType(MemoryType type) {
  switch (type) {
    case MemoryType::GLOBAL:
      return "__global";
    case MemoryType::CONSTANT:
      return "__constant";
    case MemoryType::LOCAL:
      return "__local";
  }
  // Fallback for out-of-range values; every declared enumerator returns above.
  return "";
}
|
||||
|
||||
} // namespace cl
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
@ -56,11 +56,14 @@ struct GPUImageBufferDescriptor {
|
||||
|
||||
enum class MemoryType { GLOBAL, CONSTANT, LOCAL };
|
||||
|
||||
std::string MemoryTypeToCLType(MemoryType type);
|
||||
|
||||
struct GPUBufferDescriptor {
|
||||
DataType data_type;
|
||||
AccessType access_type;
|
||||
int element_size;
|
||||
MemoryType memory_type = MemoryType::GLOBAL;
|
||||
std::vector<std::string> attributes;
|
||||
cl_mem memory;
|
||||
};
|
||||
|
||||
|
@ -147,6 +147,13 @@ absl::Status TensorDescriptor::PerformSelector(
|
||||
} else if (selector == "Slices") {
|
||||
*result = "slices";
|
||||
return absl::OkStatus();
|
||||
} else if (selector == "SliceStride") {
|
||||
if (IsBatchedWidth()) {
|
||||
*result = "width_batched * height";
|
||||
} else {
|
||||
*result = "width * height";
|
||||
}
|
||||
return absl::OkStatus();
|
||||
} else if (selector == "Channels") {
|
||||
*result = "channels";
|
||||
return absl::OkStatus();
|
||||
|
Loading…
Reference in New Issue
Block a user