Added the missing resource types (Image2DArray, Image3D, ImageBuffer) to Arguments.
PiperOrigin-RevId: 313858546 Change-Id: I5a83491728c7f6709994464186725649ad81e3c7
This commit is contained in:
parent
a5bd187cce
commit
a6a3a48679
@ -87,6 +87,13 @@ void ReplaceAllWords(const std::string& old_word, const std::string& new_word,
|
||||
}
|
||||
}
|
||||
|
||||
// Appends `arg` to the argument list accumulated in `args`. When the list
// already has content, the ",\n " separator is written before `arg`; the very
// first argument is appended without a leading separator.
void AppendArgument(const std::string& arg, std::string* args) {
  const char* separator = args->empty() ? "" : ",\n ";
  absl::StrAppend(args, separator, arg);
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Arguments::Arguments(Arguments&& args)
|
||||
@ -96,6 +103,9 @@ Arguments::Arguments(Arguments&& args)
|
||||
shared_float4s_data_(std::move(args.shared_float4s_data_)),
|
||||
buffers_(std::move(args.buffers_)),
|
||||
images2d_(std::move(args.images2d_)),
|
||||
image2d_arrays_(std::move(args.image2d_arrays_)),
|
||||
images3d_(std::move(args.images3d_)),
|
||||
image_buffers_(std::move(args.image_buffers_)),
|
||||
object_refs_(std::move(args.object_refs_)),
|
||||
objects_(std::move(args.objects_)) {}
|
||||
Arguments& Arguments::operator=(Arguments&& args) {
|
||||
@ -106,6 +116,9 @@ Arguments& Arguments::operator=(Arguments&& args) {
|
||||
shared_float4s_data_ = std::move(args.shared_float4s_data_);
|
||||
buffers_ = std::move(args.buffers_);
|
||||
images2d_ = std::move(args.images2d_);
|
||||
image2d_arrays_ = std::move(args.image2d_arrays_);
|
||||
images3d_ = std::move(args.images3d_);
|
||||
image_buffers_ = std::move(args.image_buffers_);
|
||||
object_refs_ = std::move(args.object_refs_);
|
||||
objects_ = std::move(args.objects_);
|
||||
}
|
||||
@ -127,6 +140,21 @@ void Arguments::AddImage2D(const std::string& name,
|
||||
images2d_[name] = desc;
|
||||
}
|
||||
|
||||
void Arguments::AddImage2DArray(const std::string& name,
|
||||
const GPUImage2DArrayDescriptor& desc) {
|
||||
image2d_arrays_[name] = desc;
|
||||
}
|
||||
|
||||
void Arguments::AddImage3D(const std::string& name,
|
||||
const GPUImage3DDescriptor& desc) {
|
||||
images3d_[name] = desc;
|
||||
}
|
||||
|
||||
void Arguments::AddImageBuffer(const std::string& name,
|
||||
const GPUImageBufferDescriptor& desc) {
|
||||
image_buffers_[name] = desc;
|
||||
}
|
||||
|
||||
void Arguments::AddObjectRef(const std::string& name,
|
||||
GPUObjectDescriptorPtr&& descriptor_ptr) {
|
||||
object_refs_[name] = {AccessType::READ, std::move(descriptor_ptr)};
|
||||
@ -150,6 +178,15 @@ void Arguments::AddGPUResources(const std::string& name,
|
||||
for (const auto& r : resources.images2d) {
|
||||
AddImage2D(absl::StrCat(name, "_", r.first), r.second);
|
||||
}
|
||||
for (const auto& r : resources.image2d_arrays) {
|
||||
AddImage2DArray(absl::StrCat(name, "_", r.first), r.second);
|
||||
}
|
||||
for (const auto& r : resources.images3d) {
|
||||
AddImage3D(absl::StrCat(name, "_", r.first), r.second);
|
||||
}
|
||||
for (const auto& r : resources.image_buffers) {
|
||||
AddImageBuffer(absl::StrCat(name, "_", r.first), r.second);
|
||||
}
|
||||
}
|
||||
|
||||
absl::Status Arguments::SetInt(const std::string& name, int value) {
|
||||
@ -179,12 +216,12 @@ absl::Status Arguments::SetFloat(const std::string& name, float value) {
|
||||
}
|
||||
|
||||
// Stores `memory` as the cl_mem handle of the image2D argument registered
// under `name`.
//
// Returns absl::NotFoundError when no image2D argument with that name has been
// added; otherwise returns absl::OkStatus().
absl::Status Arguments::SetImage2D(const std::string& name, cl_mem memory) {
  auto it = images2d_.find(name);
  if (it == images2d_.end()) {
    return absl::NotFoundError(
        absl::StrCat("No image2D argument with name - ", name));
  }
  it->second.memory = memory;
  return absl::OkStatus();
}
|
||||
|
||||
@ -198,6 +235,47 @@ absl::Status Arguments::SetBuffer(const std::string& name, cl_mem memory) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Stores `memory` as the cl_mem handle of the image2D-array argument
// registered under `name`.
//
// Returns absl::NotFoundError when no such argument has been added; otherwise
// returns absl::OkStatus().
absl::Status Arguments::SetImage2DArray(const std::string& name,
                                        cl_mem memory) {
  const auto entry = image2d_arrays_.find(name);
  if (entry != image2d_arrays_.end()) {
    entry->second.memory = memory;
    return absl::OkStatus();
  }
  return absl::NotFoundError(
      absl::StrCat("No image2D array argument with name - ", name));
}
|
||||
|
||||
// Stores `memory` as the cl_mem handle of the image3D argument registered
// under `name`.
//
// Returns absl::NotFoundError when no such argument has been added; otherwise
// returns absl::OkStatus().
absl::Status Arguments::SetImage3D(const std::string& name, cl_mem memory) {
  const auto entry = images3d_.find(name);
  if (entry != images3d_.end()) {
    entry->second.memory = memory;
    return absl::OkStatus();
  }
  return absl::NotFoundError(
      absl::StrCat("No image3D argument with name - ", name));
}
|
||||
|
||||
// Stores `memory` as the cl_mem handle of the image-buffer argument registered
// under `name`.
//
// Returns absl::NotFoundError when no such argument has been added; otherwise
// returns absl::OkStatus().
absl::Status Arguments::SetImageBuffer(const std::string& name, cl_mem memory) {
  const auto entry = image_buffers_.find(name);
  if (entry != image_buffers_.end()) {
    entry->second.memory = memory;
    return absl::OkStatus();
  }
  return absl::NotFoundError(
      absl::StrCat("No image buffer argument with name - ", name));
}
|
||||
|
||||
// Binds the resources exposed by `object` to the object reference registered
// under `name`, by forwarding object->GetGPUResources() to SetGPUResources().
//
// Returns absl::NotFoundError when no object ref named `name` has been added,
// absl::InvalidArgumentError when `object` is null, and otherwise whatever
// SetGPUResources() returns.
absl::Status Arguments::SetObjectRef(const std::string& name,
                                     const GPUObject* object) {
  if (object_refs_.find(name) == object_refs_.end()) {
    return absl::NotFoundError(
        absl::StrCat("No object ref with name - ", name));
  }
  // Guard the dereference below: a null object is reported as an error status
  // instead of invoking undefined behavior.
  if (object == nullptr) {
    return absl::InvalidArgumentError(
        absl::StrCat("Null object passed for object ref with name - ", name));
  }
  return SetGPUResources(name, object->GetGPUResources());
}
|
||||
|
||||
absl::Status Arguments::SetGPUResources(
|
||||
const std::string& name, const GPUResourcesWithValue& resources) {
|
||||
for (const auto& r : resources.ints) {
|
||||
@ -212,6 +290,16 @@ absl::Status Arguments::SetGPUResources(
|
||||
for (const auto& r : resources.images2d) {
|
||||
RETURN_IF_ERROR(SetImage2D(absl::StrCat(name, "_", r.first), r.second));
|
||||
}
|
||||
for (const auto& r : resources.image2d_arrays) {
|
||||
RETURN_IF_ERROR(
|
||||
SetImage2DArray(absl::StrCat(name, "_", r.first), r.second));
|
||||
}
|
||||
for (const auto& r : resources.images3d) {
|
||||
RETURN_IF_ERROR(SetImage3D(absl::StrCat(name, "_", r.first), r.second));
|
||||
}
|
||||
for (const auto& r : resources.image_buffers) {
|
||||
RETURN_IF_ERROR(SetImageBuffer(absl::StrCat(name, "_", r.first), r.second));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
@ -227,17 +315,29 @@ std::string Arguments::GetListOfArgs() {
|
||||
for (auto& t : buffers_) {
|
||||
const std::string type_name =
|
||||
t.second.data_type == DataType::FLOAT32 ? "float" : "half";
|
||||
absl::StrAppend(&result, ",\n __global ", type_name, t.second.element_size,
|
||||
"* ", t.first);
|
||||
AppendArgument(absl::StrCat("__global ", type_name, t.second.element_size,
|
||||
"* ", t.first),
|
||||
&result);
|
||||
}
|
||||
for (auto& t : image_buffers_) {
|
||||
AppendArgument(absl::StrCat("__read_only image1d_buffer_t ", t.first),
|
||||
&result);
|
||||
}
|
||||
for (auto& t : images2d_) {
|
||||
absl::StrAppend(&result, ",\n __read_only image2d_t ", t.first);
|
||||
AppendArgument(absl::StrCat("__read_only image2d_t ", t.first), &result);
|
||||
}
|
||||
for (auto& t : image2d_arrays_) {
|
||||
AppendArgument(absl::StrCat("__read_only image2d_array_t ", t.first),
|
||||
&result);
|
||||
}
|
||||
for (auto& t : images3d_) {
|
||||
AppendArgument(absl::StrCat("__read_only image3d_t ", t.first), &result);
|
||||
}
|
||||
for (int i = 0; i < shared_int4s_data_.size() / 4; ++i) {
|
||||
absl::StrAppend(&result, ",\n int4 shared_int4_", i);
|
||||
AppendArgument(absl::StrCat("int4 shared_int4_", i), &result);
|
||||
}
|
||||
for (int i = 0; i < shared_float4s_data_.size() / 4; ++i) {
|
||||
absl::StrAppend(&result, ",\n float4 shared_float4_", i);
|
||||
AppendArgument(absl::StrCat("float4 shared_float4_", i), &result);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@ -253,6 +353,16 @@ absl::Status Arguments::Bind(cl_kernel kernel, int offset) {
|
||||
}
|
||||
offset++;
|
||||
}
|
||||
for (auto& t : image_buffers_) {
|
||||
const int error_code =
|
||||
clSetKernelArg(kernel, offset, sizeof(cl_mem), &t.second.memory);
|
||||
if (error_code != CL_SUCCESS) {
|
||||
return absl::UnknownError(absl::StrCat(
|
||||
"Failed to set kernel arguments - ", CLErrorCodeToString(error_code),
|
||||
"(at index - ", offset, ")"));
|
||||
}
|
||||
offset++;
|
||||
}
|
||||
for (auto& t : images2d_) {
|
||||
const int error_code =
|
||||
clSetKernelArg(kernel, offset, sizeof(cl_mem), &t.second.memory);
|
||||
@ -263,6 +373,26 @@ absl::Status Arguments::Bind(cl_kernel kernel, int offset) {
|
||||
}
|
||||
offset++;
|
||||
}
|
||||
for (auto& t : image2d_arrays_) {
|
||||
const int error_code =
|
||||
clSetKernelArg(kernel, offset, sizeof(cl_mem), &t.second.memory);
|
||||
if (error_code != CL_SUCCESS) {
|
||||
return absl::UnknownError(absl::StrCat(
|
||||
"Failed to set kernel arguments - ", CLErrorCodeToString(error_code),
|
||||
"(at index - ", offset, ")"));
|
||||
}
|
||||
offset++;
|
||||
}
|
||||
for (auto& t : images3d_) {
|
||||
const int error_code =
|
||||
clSetKernelArg(kernel, offset, sizeof(cl_mem), &t.second.memory);
|
||||
if (error_code != CL_SUCCESS) {
|
||||
return absl::UnknownError(absl::StrCat(
|
||||
"Failed to set kernel arguments - ", CLErrorCodeToString(error_code),
|
||||
"(at index - ", offset, ")"));
|
||||
}
|
||||
offset++;
|
||||
}
|
||||
for (int i = 0; i < shared_int4s_data_.size() / 4; ++i) {
|
||||
const int error_code = clSetKernelArg(kernel, offset, sizeof(int32_t) * 4,
|
||||
&shared_int4s_data_[i * 4]);
|
||||
@ -342,7 +472,7 @@ void Arguments::ResolveObjectNames(const std::string& object_name,
|
||||
const std::vector<std::string>& member_names,
|
||||
std::string* code) {
|
||||
for (const auto& member_name : member_names) {
|
||||
const std::string new_name = "args." + object_name + "_" + member_name;
|
||||
const std::string new_name = kArgsPrefix + object_name + "_" + member_name;
|
||||
ReplaceAllWords(member_name, new_name, code);
|
||||
}
|
||||
}
|
||||
|
@ -39,6 +39,11 @@ class Arguments {
|
||||
void AddInt(const std::string& name, int value = 0);
|
||||
void AddBuffer(const std::string& name, const GPUBufferDescriptor& desc);
|
||||
void AddImage2D(const std::string& name, const GPUImage2DDescriptor& desc);
|
||||
void AddImage2DArray(const std::string& name,
|
||||
const GPUImage2DArrayDescriptor& desc);
|
||||
void AddImage3D(const std::string& name, const GPUImage3DDescriptor& desc);
|
||||
void AddImageBuffer(const std::string& name,
|
||||
const GPUImageBufferDescriptor& desc);
|
||||
|
||||
void AddObjectRef(const std::string& name,
|
||||
GPUObjectDescriptorPtr&& descriptor_ptr);
|
||||
@ -48,6 +53,10 @@ class Arguments {
|
||||
absl::Status SetFloat(const std::string& name, float value);
|
||||
absl::Status SetImage2D(const std::string& name, cl_mem memory);
|
||||
absl::Status SetBuffer(const std::string& name, cl_mem memory);
|
||||
absl::Status SetImage2DArray(const std::string& name, cl_mem memory);
|
||||
absl::Status SetImage3D(const std::string& name, cl_mem memory);
|
||||
absl::Status SetImageBuffer(const std::string& name, cl_mem memory);
|
||||
absl::Status SetObjectRef(const std::string& name, const GPUObject* object);
|
||||
|
||||
std::string GetListOfArgs();
|
||||
|
||||
@ -112,6 +121,9 @@ class Arguments {
|
||||
|
||||
std::map<std::string, GPUBufferDescriptor> buffers_;
|
||||
std::map<std::string, GPUImage2DDescriptor> images2d_;
|
||||
std::map<std::string, GPUImage2DArrayDescriptor> image2d_arrays_;
|
||||
std::map<std::string, GPUImage3DDescriptor> images3d_;
|
||||
std::map<std::string, GPUImageBufferDescriptor> image_buffers_;
|
||||
|
||||
struct ObjectRefArg {
|
||||
AccessType access_type;
|
||||
|
@ -34,6 +34,21 @@ struct GPUImage2DDescriptor {
|
||||
cl_mem memory;
|
||||
};
|
||||
|
||||
struct GPUImage3DDescriptor {
|
||||
DataType data_type;
|
||||
cl_mem memory;
|
||||
};
|
||||
|
||||
struct GPUImage2DArrayDescriptor {
|
||||
DataType data_type;
|
||||
cl_mem memory;
|
||||
};
|
||||
|
||||
struct GPUImageBufferDescriptor {
|
||||
DataType data_type;
|
||||
cl_mem memory;
|
||||
};
|
||||
|
||||
struct GPUBufferDescriptor {
|
||||
DataType data_type;
|
||||
int element_size;
|
||||
@ -45,6 +60,9 @@ struct GPUResources {
|
||||
std::vector<std::string> floats;
|
||||
std::vector<std::pair<std::string, GPUBufferDescriptor>> buffers;
|
||||
std::vector<std::pair<std::string, GPUImage2DDescriptor>> images2d;
|
||||
std::vector<std::pair<std::string, GPUImage2DArrayDescriptor>> image2d_arrays;
|
||||
std::vector<std::pair<std::string, GPUImage3DDescriptor>> images3d;
|
||||
std::vector<std::pair<std::string, GPUImageBufferDescriptor>> image_buffers;
|
||||
|
||||
std::vector<std::string> GetNames() const {
|
||||
std::vector<std::string> names = ints;
|
||||
@ -55,6 +73,15 @@ struct GPUResources {
|
||||
for (const auto& obj : images2d) {
|
||||
names.push_back(obj.first);
|
||||
}
|
||||
for (const auto& obj : image2d_arrays) {
|
||||
names.push_back(obj.first);
|
||||
}
|
||||
for (const auto& obj : images3d) {
|
||||
names.push_back(obj.first);
|
||||
}
|
||||
for (const auto& obj : image_buffers) {
|
||||
names.push_back(obj.first);
|
||||
}
|
||||
return names;
|
||||
}
|
||||
};
|
||||
@ -64,6 +91,9 @@ struct GPUResourcesWithValue {
|
||||
std::vector<std::pair<std::string, float>> floats;
|
||||
std::vector<std::pair<std::string, cl_mem>> buffers;
|
||||
std::vector<std::pair<std::string, cl_mem>> images2d;
|
||||
std::vector<std::pair<std::string, cl_mem>> image2d_arrays;
|
||||
std::vector<std::pair<std::string, cl_mem>> images3d;
|
||||
std::vector<std::pair<std::string, cl_mem>> image_buffers;
|
||||
};
|
||||
|
||||
class GPUObjectDescriptor {
|
||||
|
@ -55,7 +55,7 @@ std::string GetTransposeCode(
|
||||
c += "__kernel void main_function(\n";
|
||||
c += src_tensor.GetDeclaration(AccessType::READ);
|
||||
c += GetArgsDeclaration(linked_operations);
|
||||
c += dst_tensor.GetDeclaration(AccessType::WRITE);
|
||||
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n ";
|
||||
c += "$0) {\n";
|
||||
if (op_def.IsBatchSupported()) {
|
||||
c += " int linear_id = get_global_id(0);\n";
|
||||
|
@ -90,7 +90,7 @@ std::string GetWinograd4x4To36Code(
|
||||
c += GetArgsDeclaration(linked_operations);
|
||||
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
|
||||
c += " int4 src_size, \n";
|
||||
c += " int4 dst_size";
|
||||
c += " int4 dst_size,\n ";
|
||||
c += "$0) {\n";
|
||||
c += " int DST_X = get_global_id(0);\n";
|
||||
c += " int DST_Y = get_global_id(1);\n";
|
||||
|
Loading…
Reference in New Issue
Block a user