Reduced memory consumption for BUFFER and IMAGE_BUFFER storage types.
PiperOrigin-RevId: 271245834
parent 47c3266831
commit 430e8770de
tensorflow/lite/delegates/gpu/cl
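The change replaces per-tensor allocations for BUFFER and IMAGE_BUFFER tensors with a pool of shared OpenCL buffers: usage intervals are collected per tensor, tensors with non-overlapping lifetimes are assigned to the same shared object (sized for the largest user), and each tensor then becomes a non-owning view over its shared buffer's cl_mem. The sketch below illustrates only this interval-sharing idea in standalone C++; UsageRecord, GreedyAssign, and the sizes are illustrative stand-ins, not the delegate's actual AssignObjectsToTensors implementation.

// Minimal sketch of lifetime-based buffer sharing, assuming hypothetical names.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

struct UsageRecord {
  size_t size_bytes;  // required size of the tensor
  int first_task;     // first op (inclusive) that reads or writes it
  int last_task;      // last op (inclusive)
};

struct Assignment {
  std::vector<int> object_ids;       // tensor index -> shared object index
  std::vector<size_t> object_sizes;  // byte size of each shared object
};

// Greedy assignment: reuse a shared object whose current users' lifetimes do
// not intersect the new record; otherwise open a new object.
Assignment GreedyAssign(const std::vector<UsageRecord>& records) {
  Assignment result;
  result.object_ids.assign(records.size(), -1);
  std::vector<std::vector<int>> users;  // object index -> tensor indices
  for (int t = 0; t < static_cast<int>(records.size()); ++t) {
    int chosen = -1;
    for (int obj = 0; obj < static_cast<int>(users.size()); ++obj) {
      bool overlaps = false;
      for (int other : users[obj]) {
        if (records[t].first_task <= records[other].last_task &&
            records[other].first_task <= records[t].last_task) {
          overlaps = true;
          break;
        }
      }
      if (!overlaps) {
        chosen = obj;
        break;
      }
    }
    if (chosen == -1) {
      chosen = static_cast<int>(users.size());
      users.emplace_back();
      result.object_sizes.push_back(0);
    }
    users[chosen].push_back(t);
    result.object_sizes[chosen] =
        std::max(result.object_sizes[chosen], records[t].size_bytes);
    result.object_ids[t] = chosen;
  }
  return result;
}

int main() {
  // Three intermediate tensors; the first and third never coexist, so they
  // can share one buffer: two objects totaling 3072 bytes instead of 3584.
  std::vector<UsageRecord> records = {{1024, 0, 1}, {2048, 1, 2}, {512, 2, 3}};
  Assignment a = GreedyAssign(records);
  size_t total = 0;
  for (size_t s : a.object_sizes) total += s;
  std::cout << "shared objects: " << a.object_sizes.size()
            << ", total bytes: " << total << "\n";
  return 0;
}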
@@ -301,6 +301,7 @@ cc_library(
    srcs = ["inference_context.cc"],
    hdrs = ["inference_context.h"],
    deps = [
        ":buffer",
        ":cl_command_queue",
        ":cl_device",
        ":environment",
@@ -318,6 +319,7 @@ cc_library(
        "//tensorflow/lite/delegates/gpu/common:status",
        "//tensorflow/lite/delegates/gpu/common:tensor",
        "//tensorflow/lite/delegates/gpu/common:types",
        "//tensorflow/lite/delegates/gpu/common:util",
        "//tensorflow/lite/delegates/gpu/common/transformations:add_bias",
        "//tensorflow/lite/delegates/gpu/common/transformations:merge_padding_with",
    ],

@@ -43,6 +43,9 @@ class Buffer {

  ~Buffer();

  // for profiling and memory statistics
  uint64_t GetMemorySizeInBytes() const { return size_; }

  cl_mem GetMemoryPtr() const { return buffer_; }

  // Writes data to a buffer. Data should point to a region that
@@ -58,7 +61,7 @@ class Buffer {
  void Release();

  cl_mem buffer_ = nullptr;
  int size_;
  size_t size_;
};

Status CreateReadOnlyBuffer(size_t size_in_bytes, CLContext* context,

@@ -17,12 +17,14 @@ limitations under the License.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

#include "tensorflow/lite/delegates/gpu/cl/buffer.h"
#include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
#include "tensorflow/lite/delegates/gpu/cl/model_hints.h"
@@ -37,6 +39,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/transformations/add_bias.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"

namespace tflite {
namespace gpu {
@@ -175,6 +178,13 @@ void GetTensorDescriptors(
  }
}

// returns true if actual memory for this storage type will be allocated with
// clCreateBuffer.
bool IsBufferBased(const TensorStorageType& type) {
  return type == TensorStorageType::BUFFER ||
         type == TensorStorageType::IMAGE_BUFFER;
}

}  // namespace

CLNode::CLNode(CLNode&& node)
@@ -373,16 +383,95 @@ void InferenceContext::Merge() {
  }
}

Status InferenceContext::AllocateMemory(
    const GraphFloat32& graph, const CLDevice& device, CLContext* context,
    const std::unordered_map<ValueId, TensorDescriptor>& tensor_descriptors) {
  std::map<ValueId, int2> usages;
void InferenceContext::GetUsages(
    const std::function<bool(const TensorDescriptor&)>& functor,
    const std::unordered_map<ValueId, TensorDescriptor>& tensor_descriptors,
    std::map<ValueId, int2>* usages) const {
  for (int op_index = 0; op_index < nodes_.size(); ++op_index) {
    auto tensors = GetCLNodeTensors(nodes_[op_index]);
    for (auto& tensor : tensors) {
      AddUsage(tensor.first, op_index, &usages);
      if (functor(tensor.second)) {
        AddUsage(tensor.first, op_index, usages);
      }
    }
  }
  for (auto& out_id : output_ids_) {
    const auto& desc = tensor_descriptors.find(out_id)->second;
    if (functor(desc)) {
      AddUsage(out_id, nodes_.size(), usages);
    }
  }
}

Status InferenceContext::AllocateMemory(
    const GraphFloat32& graph, const CLDevice& device, CLContext* context,
    const std::unordered_map<ValueId, TensorDescriptor>& tensor_descriptors) {
  RETURN_IF_ERROR(
      AllocateMemoryForBuffers(graph, device, context, tensor_descriptors));
  RETURN_IF_ERROR(AllocateMemoryForStrongShapes(graph, device, context,
                                                tensor_descriptors));
  return OkStatus();
}

Status InferenceContext::AllocateMemoryForBuffers(
    const GraphFloat32& graph, const CLDevice& device, CLContext* context,
    const std::unordered_map<ValueId, TensorDescriptor>& tensor_descriptors) {
  std::map<ValueId, int2> buffer_usages;
  GetUsages(
      [](const TensorDescriptor& t) { return IsBufferBased(t.storage_type); },
      tensor_descriptors, &buffer_usages);

  std::vector<TensorUsageRecord<size_t>> buffer_usage_records;
  for (auto& usage : buffer_usages) {
    const auto& shape = graph.GetValue(usage.first)->tensor.shape;
    const auto& descriptor = tensor_descriptors.find(usage.first)->second;
    const size_t element_size =
        descriptor.data_type == DataType::FLOAT32 ? 4 : 2;
    const size_t buffer_size =
        shape.w * shape.h * AlignByN(shape.c, 4) * element_size;
    graph_ids_to_shared_buffer_tensors_[usage.first] =
        buffer_usage_records.size();
    buffer_usage_records.push_back({buffer_size,
                                    static_cast<TaskId>(usage.second.x),
                                    static_cast<TaskId>(usage.second.y)});
  }

  ObjectsAssignment<size_t> buffer_assignment;
  RETURN_IF_ERROR(AssignObjectsToTensors(
      buffer_usage_records, MemoryStrategy::GREEDY_BEST, &buffer_assignment));

  shared_buffers_.resize(buffer_assignment.object_sizes.size());
  for (int i = 0; i < buffer_assignment.object_sizes.size(); ++i) {
    RETURN_IF_ERROR(CreateReadWriteBuffer(buffer_assignment.object_sizes[i],
                                          context, &shared_buffers_[i]));
  }

  std::vector<bool> created_tensors(buffer_usage_records.size(), false);
  shared_buffer_tensors_.resize(buffer_usage_records.size());
  for (auto& node : nodes_) {
    auto tensors = GetCLNodeTensors(node);
    for (auto& t : tensors) {
      if (!IsBufferBased(t.second.storage_type)) continue;
      const int tensor_index = graph_ids_to_shared_buffer_tensors_[t.first];
      if (created_tensors[tensor_index]) continue;
      const auto& shape = graph.GetValue(t.first)->tensor.shape;
      const int buffer_index = buffer_assignment.object_ids[tensor_index];
      RETURN_IF_ERROR(CreateSharedTensor(
          *context, device, shared_buffers_[buffer_index].GetMemoryPtr(), shape,
          t.second, &shared_buffer_tensors_[tensor_index]));
      created_tensors[tensor_index] = true;
    }
  }
  return OkStatus();
}

Status InferenceContext::AllocateMemoryForStrongShapes(
    const GraphFloat32& graph, const CLDevice& device, CLContext* context,
    const std::unordered_map<ValueId, TensorDescriptor>& tensor_descriptors) {
  std::map<ValueId, int2> usages;
  GetUsages(
      [](const TensorDescriptor& t) { return !IsBufferBased(t.storage_type); },
      tensor_descriptors, &usages);

  struct TensorDesc {
    BHWC shape;
@@ -410,28 +499,17 @@ Status InferenceContext::AllocateMemory(
  RETURN_IF_ERROR(AssignObjectsToTensors(
      usage_records, MemoryStrategy::EQUALITY, &assignment));

  for (auto& node : nodes_) {
    for (auto& id : node.inputs) {
      ValueId new_id = assignment.object_ids[remap_from_graph_ids[id]];
      remap_from_graph_ids_to_shared_[id] = new_id;
      id = new_id;
    }
    for (auto& id : node.outputs) {
      ValueId new_id = assignment.object_ids[remap_from_graph_ids[id]];
      remap_from_graph_ids_to_shared_[id] = new_id;
      id = new_id;
    }
  }

  for (auto& node : nodes_) {
    auto tensors = GetCLNodeTensors(node);
    for (auto& tensor : tensors) {
      const auto& it = tensors_.find(tensor.first);
      if (it == tensors_.end()) {
        const auto& desc = assignment.object_sizes[tensor.first];
        Tensor* t = &tensors_[tensor.first];
        RETURN_IF_ERROR(
            CreateTensor(*context, device, desc.shape, tensor.second, t));
    for (auto& t : tensors) {
      if (IsBufferBased(t.second.storage_type)) continue;
      const auto& shape = graph.GetValue(t.first)->tensor.shape;
      const auto id = assignment.object_ids[remap_from_graph_ids[t.first]];
      graph_ids_to_strong_shape_tensors_[t.first] = id;
      const auto& it = strong_shape_tensors_.find(id);
      if (it == strong_shape_tensors_.end()) {
        RETURN_IF_ERROR(CreateTensor(*context, device, shape, t.second,
                                     &strong_shape_tensors_[id]));
      }
    }
  }
@@ -442,23 +520,17 @@ void InferenceContext::BindMemoryToOperations() {
  for (auto& node : nodes_) {
    const auto& first_range = node.ranges[0];
    for (int k = first_range.x; k < first_range.y; ++k) {
      auto id = node.inputs[k];
      const auto& it = tensors_.find(id);
      node.operations[0]->SetSrc(&it->second, k - first_range.x);
      node.operations[0]->SetSrc(GetTensor(node.inputs[k]), k - first_range.x);
    }
    for (int i = 1; i < node.ranges.size(); ++i) {
      const auto& range = node.ranges[i];
      for (int k = range.x; k < range.y; ++k) {
        auto id = node.inputs[k];
        const auto& it = tensors_.find(id);
        node.operations[i]->SetSrc(&it->second, k - range.x + 1);
        node.operations[i]->SetSrc(GetTensor(node.inputs[k]), k - range.x + 1);
      }
    }

    for (int i = 0; i < node.outputs.size(); ++i) {
      auto id = node.outputs[i];
      const auto& it = tensors_.find(id);
      node.operations[0]->SetDst(&it->second, i);
      node.operations[0]->SetDst(GetTensor(node.outputs[i]), i);
    }
  }
}
@@ -505,8 +577,26 @@ Status InferenceContext::Profile(ProfilingCommandQueue* queue,
  return OkStatus();
}

uint64_t InferenceContext::GetSizeOfMemoryAllocatedForIntermediateTensors()
    const {
  uint64_t total_memory = 0;
  for (const auto& t : strong_shape_tensors_) {
    total_memory += t.second.GetMemorySizeInBytes();
  }
  for (const auto& b : shared_buffers_) {
    total_memory += b.GetMemorySizeInBytes();
  }

  return total_memory;
}

Tensor* InferenceContext::GetTensor(ValueId id) {
  return &tensors_[remap_from_graph_ids_to_shared_[id]];
  if (graph_ids_to_shared_buffer_tensors_.find(id) !=
      graph_ids_to_shared_buffer_tensors_.end()) {
    return &shared_buffer_tensors_[graph_ids_to_shared_buffer_tensors_[id]];
  } else {
    return &strong_shape_tensors_[graph_ids_to_strong_shape_tensors_[id]];
  }
}

Status InferenceContext::SetInputTensor(ValueId id, const TensorFloat32& tensor,

@@ -17,11 +17,13 @@ limitations under the License.
#define TENSORFLOW_LITE_DELEGATES_GPU_CL_INFERENCE_CONTEXT_H_

#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <unordered_map>
#include <vector>

#include "tensorflow/lite/delegates/gpu/cl/buffer.h"
#include "tensorflow/lite/delegates/gpu/cl/cl_command_queue.h"
#include "tensorflow/lite/delegates/gpu/cl/environment.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
@@ -74,6 +76,8 @@ class InferenceContext {

  Status AddToQueue(CLCommandQueue* queue);
  Status Profile(ProfilingCommandQueue* queue, ProfilingInfo* result);
  // for profiling and memory statistics
  uint64_t GetSizeOfMemoryAllocatedForIntermediateTensors() const;

  Status SetInputTensor(ValueId id, const TensorFloat32& tensor,
                        CLCommandQueue* queue);
@@ -96,6 +100,21 @@ class InferenceContext {
  Status AllocateMemory(
      const GraphFloat32& graph, const CLDevice& device, CLContext* context,
      const std::unordered_map<ValueId, TensorDescriptor>& tensor_descriptors);

  Status AllocateMemoryForBuffers(
      const GraphFloat32& graph, const CLDevice& device, CLContext* context,
      const std::unordered_map<ValueId, TensorDescriptor>& tensor_descriptors);

  Status AllocateMemoryForStrongShapes(
      const GraphFloat32& graph, const CLDevice& device, CLContext* context,
      const std::unordered_map<ValueId, TensorDescriptor>& tensor_descriptors);

  // utility function
  void GetUsages(
      const std::function<bool(const TensorDescriptor&)>& functor,
      const std::unordered_map<ValueId, TensorDescriptor>& tensor_descriptors,
      std::map<ValueId, int2>* usages) const;

  void BindMemoryToOperations();
  Status Compile(const CreationContext& creation_context);
  Status Tune(const TuningParameters& tuning_parameters);
@@ -118,8 +137,14 @@ class InferenceContext {
  // Memory is allocated only once, in ConvertOperations, and is not modified
  // anywhere.
  std::vector<CLNode> nodes_;
  std::map<ValueId, Tensor> tensors_;
  std::map<ValueId, ValueId> remap_from_graph_ids_to_shared_;

  std::vector<Buffer> shared_buffers_;
  std::vector<Tensor>
      shared_buffer_tensors_;  // use references to memory from shared_buffers_
  std::map<ValueId, int> graph_ids_to_shared_buffer_tensors_;

  std::map<ValueId, Tensor> strong_shape_tensors_;
  std::map<ValueId, ValueId> graph_ids_to_strong_shape_tensors_;

  std::vector<ValueId> input_ids_;
  std::vector<ValueId> output_ids_;

@@ -51,21 +51,24 @@ Status CreateImageBufferFromBuffer(const CLContext& context, cl_mem memory,
}
}  // namespace

Tensor::Tensor(cl_mem memory, int width, int height, int channels,
               enum DataType data_type, TensorStorageType storage_type)
Tensor::Tensor(cl_mem memory, bool memory_owner, int width, int height,
               int channels, enum DataType data_type,
               TensorStorageType storage_type)
    : memory_(memory),
      image_buffer_memory_(nullptr),
      memory_owner_(memory_owner),
      width_(width),
      height_(height),
      channels_(channels),
      data_type_(data_type),
      storage_type_(storage_type) {}

Tensor::Tensor(cl_mem memory, cl_mem image_buffer_memory, int width, int height,
               int channels, enum DataType data_type,
Tensor::Tensor(cl_mem memory, bool memory_owner, cl_mem image_buffer_memory,
               int width, int height, int channels, enum DataType data_type,
               TensorStorageType storage_type)
    : memory_(memory),
      image_buffer_memory_(image_buffer_memory),
      memory_owner_(memory_owner),
      width_(width),
      height_(height),
      channels_(channels),
@@ -75,6 +78,7 @@ Tensor::Tensor(cl_mem memory, cl_mem image_buffer_memory, int width, int height,
Tensor::Tensor(Tensor&& tensor)
    : memory_(tensor.memory_),
      image_buffer_memory_(tensor.image_buffer_memory_),
      memory_owner_(tensor.memory_owner_),
      width_(tensor.width_),
      height_(tensor.height_),
      channels_(tensor.channels_),
@@ -88,6 +92,7 @@ Tensor& Tensor::operator=(Tensor&& tensor) {
    Release();
    std::swap(memory_, tensor.memory_);
    std::swap(image_buffer_memory_, tensor.image_buffer_memory_);
    std::swap(memory_owner_, tensor.memory_owner_);
    std::swap(width_, tensor.width_);
    std::swap(height_, tensor.height_);
    std::swap(channels_, tensor.channels_);
@@ -102,7 +107,7 @@ void Tensor::Release() {
    clReleaseMemObject(image_buffer_memory_);
    memory_ = nullptr;
  }
  if (memory_) {
  if (memory_owner_ && memory_) {
    clReleaseMemObject(memory_);
    memory_ = nullptr;
  }
@@ -307,10 +312,10 @@ Status CreateTensor(const CLContext& context, const CLDevice& device, int width,
    RETURN_IF_ERROR(CreateImageBufferFromBuffer(
        context, memory.memory(), data_type,
        width * height * IntegralDivideRoundUp(channels, 4), &image_memory));
    *result = Tensor(memory.Release(), image_memory, width, height, channels,
                     data_type, storage_type);
    *result = Tensor(memory.Release(), true, image_memory, width, height,
                     channels, data_type, storage_type);
  } else {
    *result = Tensor(memory.Release(), width, height, channels, data_type,
    *result = Tensor(memory.Release(), true, width, height, channels, data_type,
                     storage_type);
  }
  return OkStatus();
@@ -331,10 +336,30 @@ Status CreateTensor(const CLContext& context, const CLDevice& device,
    RETURN_IF_ERROR(CreateImageBufferFromBuffer(
        context, memory.memory(), descriptor.data_type,
        shape.w * shape.h * IntegralDivideRoundUp(shape.c, 4), &image_memory));
    *result = Tensor(memory.Release(), image_memory, shape.w, shape.h, shape.c,
    *result = Tensor(memory.Release(), true, image_memory, shape.w, shape.h,
                     shape.c, descriptor.data_type, descriptor.storage_type);
  } else {
    *result = Tensor(memory.Release(), true, shape.w, shape.h, shape.c,
                     descriptor.data_type, descriptor.storage_type);
  }
  return OkStatus();
}

Status CreateSharedTensor(const CLContext& context, const CLDevice& device,
                          cl_mem memory, const BHWC& shape,
                          const TensorDescriptor& descriptor, Tensor* result) {
  if (shape.b != 1) {
    return UnimplementedError("Batch is not supported.");
  }
  if (descriptor.storage_type == TensorStorageType::IMAGE_BUFFER) {
    cl_mem image_memory;
    RETURN_IF_ERROR(CreateImageBufferFromBuffer(
        context, memory, descriptor.data_type,
        shape.w * shape.h * IntegralDivideRoundUp(shape.c, 4), &image_memory));
    *result = Tensor(memory, false, image_memory, shape.w, shape.h, shape.c,
                     descriptor.data_type, descriptor.storage_type);
  } else {
    *result = Tensor(memory.Release(), shape.w, shape.h, shape.c,
    *result = Tensor(memory, false, shape.w, shape.h, shape.c,
                     descriptor.data_type, descriptor.storage_type);
  }
  return OkStatus();
@@ -530,25 +555,15 @@ template void Tensor::DataToBHWC<float>(absl::Span<const float> src,
template void Tensor::DataToBHWC<half>(absl::Span<const half> src,
                                       absl::Span<float> dst) const;

TensorBHWC::TensorBHWC(TensorBHWC&& tensor)
    : Tensor(std::move(tensor)), owner_(tensor.owner_) {}
TensorBHWC::TensorBHWC(TensorBHWC&& tensor) : Tensor(std::move(tensor)) {}

TensorBHWC& TensorBHWC::operator=(TensorBHWC&& tensor) {
  if (this != &tensor) {
    ReleaseBHWC();
    owner_ = tensor.owner_;
    Tensor::operator=(std::move(tensor));
  }
  return *this;
}

void TensorBHWC::ReleaseBHWC() {
  // Base class is handling deletion if we are not owners
  if (!owner_ && memory_) {
    memory_ = nullptr;
  }
}

Status CreateTensorBHWC(const CLContext& context, const HWC& shape,
                        DataType data_type, void* data, Tensor* result) {
  const size_t data_size = shape.w * shape.h * shape.c * SizeOf(data_type);
@@ -563,7 +578,7 @@ Status CreateTensorBHWC(const CLContext& context, const HWC& shape,
                                      CLErrorCodeToString(error_code)));
  }

  *result = TensorBHWC(memory, shape.w, shape.h, shape.c, data_type,
  *result = TensorBHWC(memory, true, shape.w, shape.h, shape.c, data_type,
                       TensorStorageType::BUFFER);
  return OkStatus();
}
@@ -580,9 +595,8 @@ Status CreateTensorBHWCFromOpenGlObject(const CLContext& context,
        absl::StrCat("Unable to create CL buffer from GL buffer.",
                     CLErrorCodeToString(error_code)));
  }
  *tensor = TensorBHWC(cl_buffer, shape.w, shape.h, shape.c, DataType::FLOAT32,
                       TensorStorageType::BUFFER);
  tensor->owner_ = false;
  *tensor = TensorBHWC(cl_buffer, false, shape.w, shape.h, shape.c,
                       DataType::FLOAT32, TensorStorageType::BUFFER);
  return OkStatus();
}

@@ -38,11 +38,13 @@ namespace cl {

class Tensor {
 public:
  Tensor() : memory_(nullptr), image_buffer_memory_(nullptr) {}
  Tensor(cl_mem memory, int width, int height, int channels, DataType data_type,
  Tensor()
      : memory_(nullptr), image_buffer_memory_(nullptr), memory_owner_(true) {}
  Tensor(cl_mem memory, bool memory_owner, int width, int height, int channels,
         DataType data_type, TensorStorageType storage_type);
  Tensor(cl_mem memory, bool memory_owner, cl_mem image_buffer_memory,
         int width, int height, int channels, DataType data_type,
         TensorStorageType storage_type);
  Tensor(cl_mem memory, cl_mem image_buffer_memory, int width, int height,
         int channels, DataType data_type, TensorStorageType storage_type);

  // Move only
  Tensor(Tensor&& tensor);
@@ -107,6 +109,7 @@ class Tensor {

  cl_mem memory_;
  cl_mem image_buffer_memory_;  // for TensorStorageType::IMAGE_BUFFER only
  bool memory_owner_;
  int width_;
  int height_;
  int channels_;
@@ -117,9 +120,11 @@ class Tensor {
class TensorBHWC : public Tensor {
 public:
  TensorBHWC() = default;
  TensorBHWC(cl_mem memory, int width, int height, int channels,
             enum DataType data_type, TensorStorageType storage_type)
      : Tensor(memory, width, height, channels, data_type, storage_type) {}
  TensorBHWC(cl_mem memory, bool memory_owner, int width, int height,
             int channels, enum DataType data_type,
             TensorStorageType storage_type)
      : Tensor(memory, memory_owner, width, height, channels, data_type,
               storage_type) {}

  // Move only
  TensorBHWC(TensorBHWC&& tensor);
@@ -143,19 +148,12 @@ class TensorBHWC : public Tensor {
    return OkStatus();
  }

  ~TensorBHWC() override { ReleaseBHWC(); }

 private:
  friend Status CreateTensorBHWCFromOpenGlObject(const CLContext& context,
                                                 cl_int ssbo_id,
                                                 const HWC& shape,
                                                 bool is_readonly,
                                                 TensorBHWC* tensor);

  void ReleaseBHWC();

  // When object created from GL object it isn't owner
  bool owner_ = true;
};

using TensorPtr = std::shared_ptr<Tensor>;
@@ -177,6 +175,10 @@ Status CreateTensor(const CLContext& context, const CLDevice& device,
                    const BHWC& shape, const TensorDescriptor& descriptor,
                    Tensor* result);

Status CreateSharedTensor(const CLContext& context, const CLDevice& device,
                          cl_mem memory, const BHWC& shape,
                          const TensorDescriptor& descriptor, Tensor* result);

Status CreateTensorBHWC(const CLContext& context, const HWC& shape,
                        DataType data_type, void* data, Tensor* result);

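For context on the new ownership flag: a Tensor created through CreateSharedTensor does not own its cl_mem (memory_owner is false), so the shared Buffer stays responsible for releasing it. The sketch below is a hypothetical caller, not code from this commit; it assumes an initialized CLContext and CLDevice, the signatures visible in this diff (CreateReadWriteBuffer, Buffer::GetMemoryPtr, CreateSharedTensor, AlignByN), and the BHWC/TensorDescriptor field names shown above.

// Hedged usage sketch: two tensors with non-overlapping lifetimes viewing one
// owning buffer. Assumes BHWC is default-constructible with public b/h/w/c
// fields and TensorDescriptor exposes data_type/storage_type, as the diff suggests.
#include "tensorflow/lite/delegates/gpu/cl/buffer.h"
#include "tensorflow/lite/delegates/gpu/cl/tensor.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"

namespace tflite {
namespace gpu {
namespace cl {

// `shared_buffer` must outlive `a` and `b`; both tensors only reference its
// cl_mem and will not release it (memory_owner == false).
Status AllocateTwoTensorsInOneBuffer(const CLDevice& device, CLContext* context,
                                     Buffer* shared_buffer, Tensor* a,
                                     Tensor* b) {
  BHWC shape;
  shape.b = 1;
  shape.h = 16;
  shape.w = 16;
  shape.c = 8;

  TensorDescriptor desc;
  desc.data_type = DataType::FLOAT16;
  desc.storage_type = TensorStorageType::BUFFER;

  // Size the owning buffer for its largest user: 2 bytes per FLOAT16 element,
  // channels aligned to 4 (the same sizing rule as AllocateMemoryForBuffers).
  const size_t bytes = shape.w * shape.h * AlignByN(shape.c, 4) * 2;
  RETURN_IF_ERROR(CreateReadWriteBuffer(bytes, context, shared_buffer));

  // Both tensors are non-owning views over the same cl_mem.
  RETURN_IF_ERROR(CreateSharedTensor(*context, device,
                                     shared_buffer->GetMemoryPtr(), shape,
                                     desc, a));
  RETURN_IF_ERROR(CreateSharedTensor(*context, device,
                                     shared_buffer->GetMemoryPtr(), shape,
                                     desc, b));
  return OkStatus();
}

}  // namespace cl
}  // namespace gpu
}  // namespace tflite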