Added batch support to ConcatXY.
PiperOrigin-RevId: 272306244
This commit is contained in:
parent
cb98da75f2
commit
5cbe5aef70
@ -28,22 +28,31 @@ namespace cl {
|
||||
namespace {
|
||||
|
||||
std::string GetConcatKernelCode(
|
||||
const OperationDef& definition, int tensors_count,
|
||||
const OperationDef& op_def, int tensors_count,
|
||||
const std::vector<ElementwiseOperation*>& linked_operations) {
|
||||
std::vector<std::shared_ptr<TensorCodeGenerator>> srcs(tensors_count);
|
||||
std::vector<TensorCodeGenerator> srcs(tensors_count);
|
||||
for (int i = 0; i < tensors_count; ++i) {
|
||||
const std::string tensor_name = "src_data_" + std::to_string(i);
|
||||
const std::string uniform_name = "src_size_" + std::to_string(i);
|
||||
srcs[i] = std::shared_ptr<TensorCodeGenerator>(new TensorCodeGenerator(
|
||||
tensor_name, uniform_name, definition.src_tensors[i]));
|
||||
srcs[i] =
|
||||
TensorCodeGenerator(tensor_name, uniform_name, op_def.src_tensors[i]);
|
||||
}
|
||||
TensorCodeGenerator dst("dst_data", "dst_size", definition.dst_tensors[0]);
|
||||
TensorCodeGenerator dst("dst_data", "dst_size", op_def.dst_tensors[0]);
|
||||
|
||||
std::string c = GetCommonDefines(definition.precision);
|
||||
auto read_src = [&](const TensorCodeGenerator& tensor, const std::string& x,
|
||||
const std::string& y, const std::string& z) {
|
||||
if (op_def.batch_support) {
|
||||
return tensor.Read4D(x, y, z, "B");
|
||||
} else {
|
||||
return tensor.Read3D(x, y, z, TextureAddressMode::DONT_CARE);
|
||||
}
|
||||
};
|
||||
|
||||
std::string c = GetCommonDefines(op_def.precision);
|
||||
|
||||
c += "__kernel void main_function(\n";
|
||||
for (const auto& src : srcs) {
|
||||
c += src->GetDeclaration(AccessType::READ) + ",\n";
|
||||
c += src.GetDeclaration(AccessType::READ) + ",\n";
|
||||
}
|
||||
c += dst.GetDeclaration(AccessType::WRITE);
|
||||
c += GetArgsDeclaration(linked_operations);
|
||||
@ -55,23 +64,35 @@ std::string GetConcatKernelCode(
|
||||
const std::string uniform_name = "dst_offset_" + std::to_string(i);
|
||||
c += " int2 " + uniform_name + ",\n";
|
||||
}
|
||||
if (op_def.batch_support) {
|
||||
c += " int BATCH_SIZE, \n";
|
||||
}
|
||||
c += " int4 dst_size \n";
|
||||
c += ") {\n";
|
||||
c += " int X = get_global_id(0);\n";
|
||||
c += " int Y = get_global_id(1);\n";
|
||||
c += " int Z = get_global_id(2);\n";
|
||||
c += " if (Z >= dst_size.w) return;\n";
|
||||
if (op_def.batch_support) {
|
||||
c += " int B = get_global_id(2) / dst_size.w;\n";
|
||||
c += " int Z = get_global_id(2) - B * dst_size.w;\n";
|
||||
c += " if (Z >= dst_size.w || B >= BATCH_SIZE) return;\n";
|
||||
} else {
|
||||
c += " int Z = get_global_id(2);\n";
|
||||
c += " if (Z >= dst_size.w) return;\n";
|
||||
}
|
||||
for (int i = 0; i < tensors_count; ++i) {
|
||||
const std::string offset_name = "dst_offset_" + std::to_string(i);
|
||||
const std::string size_name = "src_size_" + std::to_string(i);
|
||||
c += " if (X < " + size_name + ".x && Y < " + size_name + ".y) { \n";
|
||||
c += " FLT4 result = " +
|
||||
srcs[i]->Read3D("X", "Y", "Z", TextureAddressMode::DONT_CARE) + ";\n";
|
||||
c += " FLT4 result = " + read_src(srcs[i], "X", "Y", "Z") + ";\n";
|
||||
c += " int dst_x = X + " + offset_name + ".x;\n";
|
||||
c += " int dst_y = Y + " + offset_name + ".y;\n";
|
||||
const LinkingContext context{"result", "dst_x", "dst_y", "Z"};
|
||||
c += PostProcess(linked_operations, context);
|
||||
c += " " + dst.Write3D("result", "dst_x", "dst_y", "Z");
|
||||
if (op_def.batch_support) {
|
||||
c += " " + dst.Write4D("result", "dst_x", "dst_y", "Z", "B");
|
||||
} else {
|
||||
c += " " + dst.Write3D("result", "dst_x", "dst_y", "Z");
|
||||
}
|
||||
c += " } \n";
|
||||
}
|
||||
c += "}\n";
|
||||
@ -127,6 +148,9 @@ Status ConcatXY::BindArguments() {
|
||||
x_offset += attr_.axis == Axis::WIDTH ? src_[i]->Width() : 0;
|
||||
y_offset += attr_.axis == Axis::HEIGHT ? src_[i]->Height() : 0;
|
||||
}
|
||||
if (definition_.batch_support) {
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->Batch()));
|
||||
}
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetSizeWithDepth()));
|
||||
return OkStatus();
|
||||
}
|
||||
@ -141,7 +165,7 @@ int3 ConcatXY::GetGridSize() const {
|
||||
|
||||
const int grid_x = max_src_width;
|
||||
const int grid_y = max_src_height;
|
||||
const int grid_z = dst_[0]->Depth();
|
||||
const int grid_z = dst_[0]->Depth() * dst_[0]->Batch();
|
||||
|
||||
return int3(grid_x, grid_y, grid_z);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user