Using IMAGE_BUFFER as buffer for writing.
PiperOrigin-RevId: 269903377
This commit is contained in:
parent
6274f037d4
commit
a6d8ab779d
@ -111,7 +111,7 @@ Status ConcatXY::BindArguments() {
|
||||
for (int i = 0; i < tensors_count_; ++i) {
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[i]->GetMemoryPtr()));
|
||||
}
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||
int max_src_width = 0;
|
||||
int max_src_height = 0;
|
||||
|
@ -189,7 +189,7 @@ Status ConcatZ::BindArguments() {
|
||||
for (int i = 0; i < channels_.size(); ++i) {
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[i]->GetMemoryPtr()));
|
||||
}
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||
for (int i = 0; i < channels_.size(); ++i) {
|
||||
int4 size(src_[i]->Width(), src_[i]->Height(), channels_[i],
|
||||
|
@ -217,7 +217,7 @@ Status ConvBuffer::BindArguments() {
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(biases_.GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
int4 src_size = int4(src_[0]->Width(), src_[0]->Height(),
|
||||
src_[0]->Width() * src_[0]->Height(), src_[0]->Depth());
|
||||
int4 dst_size = int4(dst_[0]->Width(), dst_[0]->Height(),
|
||||
|
@ -259,7 +259,7 @@ Status ConvBuffer1x1::BindArguments() {
|
||||
RETURN_IF_ERROR(kernel->SetMemoryAuto(weights_.GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel->SetMemoryAuto(biases_.GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(BindArgs(kernel, linked_operations_));
|
||||
RETURN_IF_ERROR(kernel->SetMemoryAuto(dst_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel->SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
int4 src_size = int4(src_[0]->Width(), src_[0]->Height(),
|
||||
GetGridWidth(src_[0]->Width()) * src_[0]->Height(),
|
||||
src_[0]->Depth());
|
||||
|
@ -234,7 +234,7 @@ Status ConvConstants::BindArguments() {
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(biases_.GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(stride_));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(padding_));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetSizeWithDepth()));
|
||||
|
@ -80,7 +80,7 @@ Status ConvPowerVR::BindArguments() {
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(biases_.GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
if (!conv_params_.x_kernel_is_1 || !conv_params_.y_kernel_is_1) {
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(stride_padding_));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(kernel_dilation_));
|
||||
|
@ -318,7 +318,7 @@ Status ConvTexture::BindArguments() {
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_3_.GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(biases_.GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetSizeWithDepth()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetSizeWithDepth()));
|
||||
if (!(kernel_size_.x == 1 && kernel_size_.y == 1)) {
|
||||
|
@ -230,7 +230,7 @@ Status ConvolutionTransposed::BindArguments() {
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(biases_.GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(kernel_size_));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(stride_));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(padding_));
|
||||
|
@ -211,7 +211,7 @@ Status ConvolutionTransposed3x3Thin::BindArguments() {
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(biases_.GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetSizeWithDepth()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetSizeWithDepth()));
|
||||
return OkStatus();
|
||||
|
@ -203,7 +203,7 @@ Status ConvolutionTransposedThin::BindArguments() {
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_buf_.GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetSizeWithDepth()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetSizeWithDepth()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(bias_value_));
|
||||
|
@ -222,7 +222,7 @@ Status DepthWiseConvolution::BindArguments() {
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(biases_.GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(kernel_size_));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(stride_));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(padding_));
|
||||
|
@ -312,7 +312,7 @@ Status DepthWiseConv3x3::BindArguments() {
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_));
|
||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetSizeWithDepth()));
|
||||
|
||||
return OkStatus();
|
||||
|
@ -158,7 +158,7 @@ Status FullyConnectedTexture::AddToQueue(CLCommandQueue* queue) {
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(biases_.GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetSizeWithDepth()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetSizeWithDepth()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_depth_x4));
|
||||
|
@ -143,7 +143,7 @@ Status ElementwiseOperation::BindArguments() {
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(BindArguments(&kernel_));
|
||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetSizeWithDepth()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetSizeWithDepth()));
|
||||
return OkStatus();
|
||||
|
@ -131,7 +131,7 @@ Status MaxUnpooling::BindArguments() {
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[1]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetSizeWithDepth()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetSizeWithDepth()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(kernel_size_));
|
||||
|
@ -118,7 +118,7 @@ Status Padding::BindArguments() {
|
||||
kernel_.ResetBindingCounter();
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetSizeWithDepth()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetSizeWithDepth()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(prepended_));
|
||||
|
@ -228,9 +228,9 @@ Status Pooling::BindArguments() {
|
||||
kernel_.ResetBindingCounter();
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
if (output_indices_) {
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[1]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[1]->GetMemoryPtrForWriting()));
|
||||
}
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetSizeWithDepth()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetSizeWithDepth()));
|
||||
|
@ -102,7 +102,7 @@ Status Reshape::BindArguments() {
|
||||
kernel_.ResetBindingCounter();
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetSizeWithDepth()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetSizeWithDepth()));
|
||||
const int2 plane_size = int2(src_[0]->Width() * src_[0]->Channels(),
|
||||
|
@ -87,7 +87,7 @@ Status Reshapex4::BindArguments() {
|
||||
kernel_.ResetBindingCounter();
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetSizeWithDepth()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetSizeWithDepth()));
|
||||
const int2 plane_size = int2(src_[0]->Width() * src_[0]->Depth(),
|
||||
|
@ -101,7 +101,7 @@ Status Softmax::BindArguments() {
|
||||
kernel_.ResetBindingCounter();
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetSizeWithDepth()));
|
||||
RETURN_IF_ERROR(
|
||||
kernel_.SetBytesAuto(GetMaskForLastPlane(src_[0]->Channels())));
|
||||
|
@ -121,7 +121,7 @@ Status Softmax1x1::AddToQueue(CLCommandQueue* queue) {
|
||||
kernel_.ResetBindingCounter();
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetSizeWithDepth()));
|
||||
const int depth = src_[0]->Depth();
|
||||
RETURN_IF_ERROR(
|
||||
|
@ -153,7 +153,7 @@ Status StridedSlice::BindArguments() {
|
||||
kernel_.ResetBindingCounter();
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
int3 offset = GetOffset(attributes_, src_[0]->Width(), src_[0]->Height(),
|
||||
src_[0]->Channels());
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(int4(offset.x, offset.y, offset.z, 1)));
|
||||
|
@ -105,7 +105,7 @@ Status Upsample::BindArguments() {
|
||||
kernel_.ResetBindingCounter();
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetSizeWithDepth()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetSizeWithDepth()));
|
||||
float2 scale_factor =
|
||||
|
@ -228,14 +228,10 @@ std::string WriteGlobalFLT4(TensorStorageType storage_type, DataType data_type,
|
||||
const std::string& y, const std::string& z) {
|
||||
switch (storage_type) {
|
||||
case TensorStorageType::BUFFER:
|
||||
case TensorStorageType::IMAGE_BUFFER:
|
||||
return absl::StrCat(tensor_name, "[((", z, ") * ", size_name, ".y + (", y,
|
||||
")) * ", size_name, ".x + (", x, ")] = ", var_name,
|
||||
";\n");
|
||||
case TensorStorageType::IMAGE_BUFFER:
|
||||
return absl::StrCat(GetWriteImageFromDataType(data_type), "(",
|
||||
tensor_name, ", ((", z, ") * ", size_name, ".y + (",
|
||||
y, ")) * ", size_name, ".x + (", x, "), ", var_name,
|
||||
");\n");
|
||||
case TensorStorageType::TEXTURE_2D:
|
||||
return absl::StrCat(GetWriteImageFromDataType(data_type), "(",
|
||||
tensor_name, ", (int2)((", x, "), (", y, ") * ",
|
||||
@ -259,12 +255,12 @@ std::string WriteGlobalFLT4(TensorStorageType storage_type, DataType data_type,
|
||||
const std::string& global_address) {
|
||||
switch (storage_type) {
|
||||
case TensorStorageType::BUFFER:
|
||||
case TensorStorageType::IMAGE_BUFFER:
|
||||
return absl::StrCat(tensor_name, "[", global_address, "] = ", var_name,
|
||||
";\n");
|
||||
case TensorStorageType::TEXTURE_2D:
|
||||
case TensorStorageType::SINGLE_TEXTURE_2D:
|
||||
case TensorStorageType::TEXTURE_ARRAY:
|
||||
case TensorStorageType::IMAGE_BUFFER:
|
||||
return absl::StrCat(GetWriteImageFromDataType(data_type), "(",
|
||||
tensor_name, ", ", global_address, ", ", var_name,
|
||||
");\n");
|
||||
@ -308,7 +304,11 @@ std::string GetTensorDeclaration(TensorStorageType storage_type,
|
||||
case TensorStorageType::TEXTURE_ARRAY:
|
||||
return GetImageModifier(access) + " image2d_array_t";
|
||||
case TensorStorageType::IMAGE_BUFFER:
|
||||
return GetImageModifier(access) + " image1d_buffer_t";
|
||||
if (access == AccessType::WRITE) {
|
||||
return absl::StrCat("__global ", GetDataType4(data_type), "*");
|
||||
} else {
|
||||
return GetImageModifier(access) + " image1d_buffer_t";
|
||||
}
|
||||
case TensorStorageType::UNKNOWN:
|
||||
return "";
|
||||
}
|
||||
|
@ -161,6 +161,8 @@ cl_mem Tensor::GetMemoryPtr() const {
|
||||
: memory_;
|
||||
}
|
||||
|
||||
cl_mem Tensor::GetMemoryPtrForWriting() const { return memory_; }
|
||||
|
||||
Status Tensor::WriteDataBHWC(absl::Span<const float> in,
|
||||
CLCommandQueue* queue) {
|
||||
if (in.size() != channels_ * width_ * height_) {
|
||||
|
@ -67,6 +67,10 @@ class Tensor {
|
||||
}
|
||||
cl_mem GetMemoryPtr() const;
|
||||
|
||||
// This function returns buffer memory ptr for IMAGE_BUFFER instead of image
|
||||
// memory ptr.
|
||||
cl_mem GetMemoryPtrForWriting() const;
|
||||
|
||||
Status WriteDataBHWC(absl::Span<const float> in, CLCommandQueue* queue);
|
||||
|
||||
Status ReadDataBHWC(absl::Span<float> out, CLCommandQueue* queue) const;
|
||||
|
Loading…
Reference in New Issue
Block a user