Added attributes for buffer.
Enhanced GetPtr method of Buffer object. PiperOrigin-RevId: 318083432 Change-Id: I1fead5800bd03868296d93759598df378298ec71
This commit is contained in:
parent
fb9291d5a5
commit
05c284bcc3
tensorflow/lite/delegates/gpu/cl
@ -49,6 +49,7 @@ cc_library(
|
|||||||
":tensor_type",
|
":tensor_type",
|
||||||
":util",
|
":util",
|
||||||
"//tensorflow/lite/delegates/gpu/common:access_type",
|
"//tensorflow/lite/delegates/gpu/common:access_type",
|
||||||
|
"//tensorflow/lite/delegates/gpu/common:data_type",
|
||||||
"//tensorflow/lite/delegates/gpu/common:status",
|
"//tensorflow/lite/delegates/gpu/common:status",
|
||||||
"//tensorflow/lite/delegates/gpu/common:types",
|
"//tensorflow/lite/delegates/gpu/common:types",
|
||||||
"//tensorflow/lite/delegates/gpu/common:util",
|
"//tensorflow/lite/delegates/gpu/common:util",
|
||||||
@ -84,6 +85,7 @@ cc_library(
|
|||||||
":gpu_object",
|
":gpu_object",
|
||||||
":opencl_wrapper",
|
":opencl_wrapper",
|
||||||
":util",
|
":util",
|
||||||
|
"//tensorflow/lite/delegates/gpu/common:data_type",
|
||||||
"//tensorflow/lite/delegates/gpu/common:status",
|
"//tensorflow/lite/delegates/gpu/common:status",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
@ -330,6 +332,7 @@ cc_library(
|
|||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "gpu_object",
|
name = "gpu_object",
|
||||||
|
srcs = ["gpu_object.cc"],
|
||||||
hdrs = ["gpu_object.h"],
|
hdrs = ["gpu_object.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":opencl_wrapper",
|
":opencl_wrapper",
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include "absl/strings/str_split.h"
|
#include "absl/strings/str_split.h"
|
||||||
#include "absl/strings/substitute.h"
|
#include "absl/strings/substitute.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
|
#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
@ -457,21 +458,15 @@ std::string Arguments::GetListOfArgs() {
|
|||||||
for (auto& t : buffers_) {
|
for (auto& t : buffers_) {
|
||||||
const std::string type_name =
|
const std::string type_name =
|
||||||
t.second.data_type == DataType::FLOAT32 ? "float" : "half";
|
t.second.data_type == DataType::FLOAT32 ? "float" : "half";
|
||||||
std::string memory_type;
|
std::string attributes;
|
||||||
switch (t.second.memory_type) {
|
for (const auto& attr : t.second.attributes) {
|
||||||
case MemoryType::GLOBAL:
|
attributes += absl::StrCat(" __attribute__((", attr, "))");
|
||||||
memory_type = "__global";
|
|
||||||
break;
|
|
||||||
case MemoryType::CONSTANT:
|
|
||||||
memory_type = "__constant";
|
|
||||||
break;
|
|
||||||
case MemoryType::LOCAL:
|
|
||||||
memory_type = "__local";
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
AppendArgument(absl::StrCat(memory_type, " ", type_name,
|
AppendArgument(
|
||||||
t.second.element_size, "* ", t.first),
|
absl::StrCat(MemoryTypeToCLType(t.second.memory_type), " ",
|
||||||
&result);
|
ToCLDataType(t.second.data_type, t.second.element_size),
|
||||||
|
"* ", t.first, attributes),
|
||||||
|
&result);
|
||||||
}
|
}
|
||||||
for (auto& t : image_buffers_) {
|
for (auto& t : image_buffers_) {
|
||||||
AppendArgument(absl::StrCat(GetImageModifier(t.second.access_type),
|
AppendArgument(absl::StrCat(GetImageModifier(t.second.access_type),
|
||||||
|
@ -15,6 +15,9 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/buffer.h"
|
#include "tensorflow/lite/delegates/gpu/cl/buffer.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
@ -51,6 +54,7 @@ GPUResources BufferDescriptor::GetGPUResources(AccessType access_type) const {
|
|||||||
desc.access_type = access_type;
|
desc.access_type = access_type;
|
||||||
desc.element_size = element_size;
|
desc.element_size = element_size;
|
||||||
desc.memory_type = memory_type;
|
desc.memory_type = memory_type;
|
||||||
|
desc.attributes = attributes;
|
||||||
resources.buffers.push_back({"buffer", desc});
|
resources.buffers.push_back({"buffer", desc});
|
||||||
return resources;
|
return resources;
|
||||||
}
|
}
|
||||||
@ -61,7 +65,7 @@ absl::Status BufferDescriptor::PerformSelector(
|
|||||||
if (selector == "Read") {
|
if (selector == "Read") {
|
||||||
return PerformReadSelector(args, result);
|
return PerformReadSelector(args, result);
|
||||||
} else if (selector == "GetPtr") {
|
} else if (selector == "GetPtr") {
|
||||||
return PerformGetPtrSelector(args, result);
|
return PerformGetPtrSelector(args, template_args, result);
|
||||||
} else {
|
} else {
|
||||||
return absl::NotFoundError(absl::StrCat(
|
return absl::NotFoundError(absl::StrCat(
|
||||||
"BufferDescriptor don't have selector with name - ", selector));
|
"BufferDescriptor don't have selector with name - ", selector));
|
||||||
@ -80,13 +84,34 @@ absl::Status BufferDescriptor::PerformReadSelector(
|
|||||||
}
|
}
|
||||||
|
|
||||||
absl::Status BufferDescriptor::PerformGetPtrSelector(
|
absl::Status BufferDescriptor::PerformGetPtrSelector(
|
||||||
const std::vector<std::string>& args, std::string* result) const {
|
const std::vector<std::string>& args,
|
||||||
if (!args.empty()) {
|
const std::vector<std::string>& template_args, std::string* result) const {
|
||||||
return absl::NotFoundError(
|
if (args.size() > 1) {
|
||||||
absl::StrCat("BufferDescriptor GetPtr require zero arguments, but ",
|
return absl::NotFoundError(absl::StrCat(
|
||||||
args.size(), " was passed"));
|
"BufferDescriptor GetPtr require one or zero arguments, but ",
|
||||||
|
args.size(), " was passed"));
|
||||||
|
}
|
||||||
|
if (template_args.size() > 1) {
|
||||||
|
return absl::NotFoundError(
|
||||||
|
absl::StrCat("BufferDescriptor GetPtr require one or zero teemplate "
|
||||||
|
"arguments, but ",
|
||||||
|
template_args.size(), " was passed"));
|
||||||
|
}
|
||||||
|
std::string conversion;
|
||||||
|
if (template_args.size() == 1) {
|
||||||
|
const std::string type_name = ToCLDataType(element_type, element_size);
|
||||||
|
if (type_name != template_args[0]) {
|
||||||
|
conversion = absl::StrCat("(", MemoryTypeToCLType(memory_type), " ",
|
||||||
|
template_args[0], "*)&");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (args.empty()) {
|
||||||
|
*result = absl::StrCat(conversion, "buffer");
|
||||||
|
} else if (conversion.empty()) {
|
||||||
|
*result = absl::StrCat("(buffer + ", args[0], ")");
|
||||||
|
} else {
|
||||||
|
*result = absl::StrCat(conversion, "buffer[", args[0], "]");
|
||||||
}
|
}
|
||||||
*result = "buffer";
|
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -30,9 +30,10 @@ namespace gpu {
|
|||||||
namespace cl {
|
namespace cl {
|
||||||
|
|
||||||
struct BufferDescriptor : public GPUObjectDescriptor {
|
struct BufferDescriptor : public GPUObjectDescriptor {
|
||||||
DataType element_type; // FLOAT32 or FLOAT16
|
DataType element_type;
|
||||||
int element_size;
|
int element_size;
|
||||||
MemoryType memory_type = MemoryType::GLOBAL;
|
MemoryType memory_type = MemoryType::GLOBAL;
|
||||||
|
std::vector<std::string> attributes;
|
||||||
|
|
||||||
absl::Status PerformSelector(const std::string& selector,
|
absl::Status PerformSelector(const std::string& selector,
|
||||||
const std::vector<std::string>& args,
|
const std::vector<std::string>& args,
|
||||||
@ -42,8 +43,9 @@ struct BufferDescriptor : public GPUObjectDescriptor {
|
|||||||
GPUResources GetGPUResources(AccessType access_type) const override;
|
GPUResources GetGPUResources(AccessType access_type) const override;
|
||||||
absl::Status PerformReadSelector(const std::vector<std::string>& args,
|
absl::Status PerformReadSelector(const std::vector<std::string>& args,
|
||||||
std::string* result) const;
|
std::string* result) const;
|
||||||
absl::Status PerformGetPtrSelector(const std::vector<std::string>& args,
|
absl::Status PerformGetPtrSelector(
|
||||||
std::string* result) const;
|
const std::vector<std::string>& args,
|
||||||
|
const std::vector<std::string>& template_args, std::string* result) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Buffer represent linear GPU data storage with arbitrary data format.
|
// Buffer represent linear GPU data storage with arbitrary data format.
|
||||||
|
37
tensorflow/lite/delegates/gpu/cl/gpu_object.cc
Normal file
37
tensorflow/lite/delegates/gpu/cl/gpu_object.cc
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/cl/gpu_object.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace gpu {
|
||||||
|
namespace cl {
|
||||||
|
|
||||||
|
// Maps a MemoryType enumerator to the qualifier string emitted in generated
// kernel source: GLOBAL -> "__global", CONSTANT -> "__constant",
// LOCAL -> "__local".
//
// Returns an empty string if `type` holds a value that matches none of the
// enumerators handled below.
std::string MemoryTypeToCLType(MemoryType type) {
  switch (type) {
    case MemoryType::GLOBAL:
      return "__global";
    case MemoryType::CONSTANT:
      return "__constant";
    case MemoryType::LOCAL:
      return "__local";
  }
  // No default label on purpose: keeping the switch exhaustive over the
  // enumerators lets the compiler warn when a new MemoryType is added.
  return "";
}
|
||||||
|
|
||||||
|
} // namespace cl
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
|
@ -56,11 +56,14 @@ struct GPUImageBufferDescriptor {
|
|||||||
|
|
||||||
enum class MemoryType { GLOBAL, CONSTANT, LOCAL };
|
enum class MemoryType { GLOBAL, CONSTANT, LOCAL };
|
||||||
|
|
||||||
|
std::string MemoryTypeToCLType(MemoryType type);
|
||||||
|
|
||||||
struct GPUBufferDescriptor {
|
struct GPUBufferDescriptor {
|
||||||
DataType data_type;
|
DataType data_type;
|
||||||
AccessType access_type;
|
AccessType access_type;
|
||||||
int element_size;
|
int element_size;
|
||||||
MemoryType memory_type = MemoryType::GLOBAL;
|
MemoryType memory_type = MemoryType::GLOBAL;
|
||||||
|
std::vector<std::string> attributes;
|
||||||
cl_mem memory;
|
cl_mem memory;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -147,6 +147,13 @@ absl::Status TensorDescriptor::PerformSelector(
|
|||||||
} else if (selector == "Slices") {
|
} else if (selector == "Slices") {
|
||||||
*result = "slices";
|
*result = "slices";
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
} else if (selector == "SliceStride") {
|
||||||
|
if (IsBatchedWidth()) {
|
||||||
|
*result = "width_batched * height";
|
||||||
|
} else {
|
||||||
|
*result = "width * height";
|
||||||
|
}
|
||||||
|
return absl::OkStatus();
|
||||||
} else if (selector == "Channels") {
|
} else if (selector == "Channels") {
|
||||||
*result = "channels";
|
*result = "channels";
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
Loading…
Reference in New Issue
Block a user