Pooling converted to new style.

Merged 2D and 3D versions into one.

PiperOrigin-RevId: 316998142
Change-Id: I92c020476f085e6160a02282c1edafabdd72ca30
Raman Sarokin 2020-06-17 17:21:24 -07:00 committed by TensorFlower Gardener
parent 2a6d9a1e81
commit 73cf8263c7
6 changed files with 313 additions and 458 deletions
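
In brief, "new style" here means the Arguments-based code generation visible in the diff below: the operation registers tensor object refs and scalar ints on an Arguments instance, the generated OpenCL source refers to them as args.<name>, a "$0" placeholder stands in for the kernel's argument list, and args_.TransformToCLCode() / args_.Bind() handle expansion and binding. A condensed, hypothetical C++ sketch of that lifecycle (helper name and toy kernel body are illustrative only; the real calls appear in the pooling diff below):

// Hypothetical condensed helper -- not part of this commit -- assuming the
// TFLite GPU/CL headers (arguments.h, tensor_type.h, cl_kernel.h) are included.
absl::Status NewStyleLifecycleSketch(const OperationDef& op_def,
                                     const CLDevice& device,
                                     Tensor* src, Tensor* dst,
                                     Arguments* args, CLKernel* kernel) {
  // Compile time: register objects/scalars; generated code refers to args.*.
  auto src_desc = absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]);
  args->AddObjectRef("src_tensor", AccessType::READ, std::move(src_desc));
  auto dst_desc = absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]);
  args->AddObjectRef("dst_tensor", AccessType::WRITE, std::move(dst_desc));
  args->AddInt("stride_x");

  // Kernel text uses a $0 placeholder where the argument list will go.
  std::string code =
      "__kernel void main_function(\n$0) {\n"
      "  int X = get_global_id(0);\n"
      "  if (X >= args.dst_tensor.Width()) return;\n"
      "  args.dst_tensor.Write(args.src_tensor.Read(X, 0, 0), X, 0, 0);\n"
      "}\n";
  RETURN_IF_ERROR(args->TransformToCLCode(device.GetInfo(), {}, &code));
  // (Kernel creation from `code` via GetOrCreateCLKernel omitted.)

  // Dispatch time: bind concrete tensors/values, then bind to the CL kernel.
  RETURN_IF_ERROR(args->SetObjectRef("src_tensor", src));
  RETURN_IF_ERROR(args->SetObjectRef("dst_tensor", dst));
  RETURN_IF_ERROR(args->SetInt("stride_x", 1));
  return args->Bind(kernel->kernel());
}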

View File

@@ -25,366 +25,307 @@ namespace gpu {
namespace cl {
namespace {
std::string GetAveragePoolingKernelCode(
const OperationDef& op_def, bool stride_correction, const CLDevice& device,
const std::vector<ElementwiseOperation*>& linked_operations) {
TensorCodeGenerator src_tensor(
"src_data", WHSPoint{"src_size.x", "src_size.y", "src_size.z"},
op_def.src_tensors[0]);
TensorCodeGenerator dst_tensor(
"dst_data", WHSPoint{"dst_size.x", "dst_size.y", "dst_size.z"},
op_def.dst_tensors[0]);
std::string GetAveragePoolingKernelCode(const OperationDef& op_def,
bool stride_correction,
const CLDevice& device,
Arguments* args) {
auto src_desc = absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]);
src_desc->SetTextureAddressMode(GetFastestZeroMode(device));
if (op_def.IsBatchSupported()) {
src_desc->SetStateVar("BatchedWidth", "true");
}
args->AddObjectRef("src_tensor", AccessType::READ, std::move(src_desc));
auto dst_desc = absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]);
if (op_def.IsBatchSupported()) {
dst_desc->SetStateVar("BatchedWidth", "true");
}
args->AddObjectRef("dst_tensor", AccessType::WRITE, std::move(dst_desc));
if (op_def.dst_tensors[0].HasAxis(Axis::WIDTH)) {
args->AddInt("kernel_size_x");
args->AddInt("padding_x");
args->AddInt("stride_x");
}
if (op_def.dst_tensors[0].HasAxis(Axis::HEIGHT)) {
args->AddInt("kernel_size_y");
args->AddInt("padding_y");
args->AddInt("stride_y");
}
if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
args->AddInt("kernel_size_z");
args->AddInt("padding_z");
args->AddInt("stride_z");
}
const auto address_mode = GetFastestZeroMode(device);
std::map<Axis, std::string> axis_to_src_coord = {
{Axis::WIDTH, "x_c"}, {Axis::HEIGHT, "y_c"}, {Axis::DEPTH, "d_c"},
{Axis::CHANNELS, "Z"}, {Axis::BATCH, "B"},
};
std::string c = GetCommonDefines(op_def.precision);
std::map<Axis, std::string> axis_to_dst_coord = {
{Axis::WIDTH, "X"}, {Axis::HEIGHT, "Y"}, {Axis::DEPTH, "D"},
{Axis::CHANNELS, "Z"}, {Axis::BATCH, "B"},
};
std::vector<std::string> src_coords;
std::vector<std::string> dst_coords;
for (auto axis : {Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH, Axis::CHANNELS}) {
if (op_def.dst_tensors[0].HasAxis(axis)) {
dst_coords.push_back(axis_to_dst_coord[axis]);
}
if (op_def.src_tensors[0].HasAxis(axis)) {
src_coords.push_back(axis_to_src_coord[axis]);
}
}
std::string src_coord = src_coords[0];
for (int i = 1; i < src_coords.size(); ++i) {
src_coord += ", " + src_coords[i];
}
std::string dst_coord = dst_coords[0];
for (int i = 1; i < dst_coords.size(); ++i) {
dst_coord += ", " + dst_coords[i];
}
const bool manual_clamp =
op_def.src_tensors[0].storage_type == TensorStorageType::BUFFER ||
op_def.src_tensors[0].storage_type == TensorStorageType::IMAGE_BUFFER;
std::string c = GetCommonDefines(op_def.precision);
c += "__kernel void main_function(\n";
c += src_tensor.GetDeclaration(AccessType::READ);
c += GetArgsDeclaration(linked_operations);
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
c += " int4 src_size, \n";
c += " int4 dst_size, \n";
c += " int2 kernel_size, \n";
c += " int2 padding, \n";
c += " int2 stride \n";
c += ") {\n";
c += "$0) {\n";
c += " int X = get_global_id(0);\n";
c += " int Y = get_global_id(1);\n";
if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
c += " int linear_id_1 = get_global_id(1);\n";
c += " int Y = linear_id_1 / args.dst_tensor.Depth();\n";
c += " int D = linear_id_1 % args.dst_tensor.Depth();\n";
} else {
c += " int Y = get_global_id(1);\n";
}
c += " int Z = get_global_id(2);\n";
c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return;\n";
c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
"Z >= args.dst_tensor.Slices()) { \n";
c += " return; \n";
c += " } \n";
c += " float4 r = (float4)(0.0f);\n";
c += " float window_size = 0.0;\n";
if (stride_correction) {
c += " int xs = " +
GetXStrideCorrected("X", "src_size.w", "stride.x", "padding.x") +
GetXStrideCorrected("X", "args.src_tensor.Batch()", "args.stride_x",
"args.padding_x") +
";\n";
} else {
c += " int xs = X * stride.x + padding.x;\n";
c += " int xs = X * args.stride_x + args.padding_x;\n";
}
c += " int ys = Y * stride.y + padding.y;\n";
c += " for (int ky = 0; ky < kernel_size.y; ++ky) {\n";
c += " int ys = Y * args.stride_y + args.padding_y;\n";
if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
c += " int ds = D * args.stride_z + args.padding_z;\n";
c += " for (int kz = 0; kz < args.kernel_size_z; ++kz) {\n";
c += " int d_c = ds + kz;\n";
c += " if (d_c < 0 || d_c >= args.src_tensor.Depth()) continue;\n";
}
c += " for (int ky = 0; ky < args.kernel_size_y; ++ky) {\n";
c += " int y_c = ys + ky;\n";
c += " bool outside_y = y_c < 0 || y_c >= src_size.y;\n";
c += " for (int kx = 0; kx < kernel_size.x; ++kx) {\n";
c += " bool outside_y = y_c < 0 || y_c >= args.src_tensor.Height();\n";
c += " for (int kx = 0; kx < args.kernel_size_x; ++kx) {\n";
if (op_def.IsBatchSupported()) {
c += " int x_c = xs + kx * src_size.w;\n";
c += " int x_c = xs + kx * args.src_tensor.Batch();\n";
} else {
c += " int x_c = xs + kx;\n";
}
c += " bool outside = outside_y || x_c < 0 || x_c >= src_size.x;\n";
c += " bool outside = outside_y || x_c < 0 || x_c >= "
"args.src_tensor.Width();\n";
if (manual_clamp) {
c += " r += !outside ? " +
src_tensor.ReadAsFloatWHS("x_c", "y_c", "Z") + " : (float4)(0.0f);\n";
c += " r += !outside ? args.src_tensor.Read<float>(" + src_coord +
") : "
"(float4)(0.0f);\n";
} else {
c += " r += " +
src_tensor.ReadAsFloatWHS("x_c", "y_c", "Z", address_mode) + ";\n";
c += " r += args.src_tensor.Read<float>(" + src_coord + ");\n";
}
c += " window_size += !outside ? 1.0 : 0.0;\n";
c += " }\n";
c += " }\n";
if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
c += " } // Depth\n";
}
// If window_size==0, window covered nothing. This situation is a sign of
// incorrectly constructed operation. NaNs are expected as output.
c += " FLT4 result = TO_FLT4(r / window_size);\n";
const LinkingContext context{"result", "X", "Y", "Z"};
c += PostProcess(linked_operations, context);
c += " " + dst_tensor.WriteWHS("result", "X", "Y", "Z");
c += " args.dst_tensor.Write(result, " + dst_coord + ");\n";
c += "}\n";
return c;
}
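
The same generator now covers both the 2D and the 3D (DEPTH) case; the only difference in addressing is which coordinate names end up in src_coord / dst_coord. A minimal standalone sketch of the per-axis assembly above and the strings it produces (axis sets assumed to mirror HasAxis()):

// Standalone illustration; compiles on its own and mirrors the join loop above.
#include <iostream>
#include <string>
#include <vector>

int main() {
  // 2D case (WIDTH, HEIGHT, CHANNELS):  src "x_c, y_c, Z",      dst "X, Y, Z"
  // 3D case (plus DEPTH):               src "x_c, y_c, d_c, Z", dst "X, Y, D, Z"
  const std::vector<std::string> src_coords = {"x_c", "y_c", "d_c", "Z"};
  std::string src_coord = src_coords[0];
  for (size_t i = 1; i < src_coords.size(); ++i) {
    src_coord += ", " + src_coords[i];
  }
  std::cout << src_coord << "\n";  // prints: x_c, y_c, d_c, Z
  return 0;
}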
std::string GetAveragePooling3DKernelCode(
const OperationDef& op_def, bool stride_correction, const CLDevice& device,
const std::vector<ElementwiseOperation*>& linked_operations) {
TensorCodeGenerator src_tensor(
"src_data",
WHDSPoint{"src_size.x", "src_size.y", "src_size.z", "src_size.w"},
op_def.src_tensors[0]);
TensorCodeGenerator dst_tensor(
"dst_data",
WHDSPoint{"dst_size.x", "dst_size.y", "dst_size.z", "dst_size.w"},
op_def.dst_tensors[0]);
const auto address_mode = GetFastestZeroMode(device);
std::string c = GetCommonDefines(op_def.precision);
c += "__kernel void main_function(\n";
c += src_tensor.GetDeclaration(AccessType::READ);
c += GetArgsDeclaration(linked_operations);
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
c += " int4 src_size, \n";
c += " int4 dst_size, \n";
std::string GetMaxPoolingKernelCode(const OperationDef& op_def,
bool stride_correction, bool output_indices,
Arguments* args) {
auto src_desc = absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]);
if (op_def.IsBatchSupported()) {
c += " int batch_size, \n";
src_desc->SetStateVar("BatchedWidth", "true");
}
c += " int4 kernel_size, \n";
c += " int4 padding, \n";
c += " int4 stride \n";
c += ") {\n";
c += " int X = get_global_id(0);\n";
c += " int Y = get_global_id(1);\n";
c += " int linear_id_z = get_global_id(2);\n";
c += " int S = linear_id_z % dst_size.w;\n";
c += " int Z = linear_id_z / dst_size.w;\n";
c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return;\n";
c += " float4 r = (float4)(0.0f);\n";
c += " float window_size = 0.0;\n";
if (stride_correction) {
c += " int xs = " +
GetXStrideCorrected("X", "batch_size", "stride.x", "padding.x") +
";\n";
} else {
c += " int xs = X * stride.x + padding.x;\n";
}
c += " int ys = Y * stride.y + padding.y;\n";
c += " int zs = Z * stride.z + padding.z;\n";
c += " for (int kz = 0; kz < kernel_size.z; ++kz) {\n";
c += " int z_c = zs + kz;\n";
c += " if (z_c < 0 || z_c >= src_size.z) continue;\n";
c += " for (int ky = 0; ky < kernel_size.y; ++ky) {\n";
c += " int y_c = ys + ky;\n";
c += " if (y_c < 0 || y_c >= src_size.y) continue;\n";
c += " for (int kx = 0; kx < kernel_size.x; ++kx) {\n";
args->AddObjectRef("src_tensor", AccessType::READ, std::move(src_desc));
auto dst_desc = absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]);
if (op_def.IsBatchSupported()) {
c += " int x_c = xs + kx * batch_size;\n";
} else {
c += " int x_c = xs + kx;\n";
dst_desc->SetStateVar("BatchedWidth", "true");
}
c += " if(x_c < 0 || x_c >= src_size.x) continue;\n";
c += " r += " +
src_tensor.ReadAsFloatWHDS("x_c", "y_c", "z_c", "S", address_mode) +
";\n";
c += " window_size += 1.0;\n";
c += " }\n";
c += " }\n";
c += " }\n";
// If window_size==0, window covered nothing. This situation is a sign of
// incorrectly constructed operation. NaNs are expected as output.
c += " FLT4 result = TO_FLT4(r / window_size);\n";
const LinkingContext context{"result", "X", "Y", "Z"};
c += PostProcess(linked_operations, context);
c += " " + dst_tensor.WriteWHDS("result", "X", "Y", "Z", "S");
c += "}\n";
return c;
}
std::string GetMaxPoolingKernelCode(
const OperationDef& op_def, bool stride_correction,
const std::vector<ElementwiseOperation*>& linked_operations,
bool output_indices) {
TensorCodeGenerator src_tensor(
"src_data", WHSPoint{"src_size.x", "src_size.y", "src_size.z"},
op_def.src_tensors[0]);
TensorCodeGenerator dst_tensor(
"dst_data", WHSPoint{"dst_size.x", "dst_size.y", "dst_size.z"},
op_def.dst_tensors[0]);
const auto dst_ind_def =
output_indices ? op_def.dst_tensors[1] : op_def.dst_tensors[0];
TensorCodeGenerator indices_tensor(
"dst_indices", WHSPoint{"dst_size.x", "dst_size.y", "dst_size.z"},
dst_ind_def);
std::string c = GetCommonDefines(op_def.precision);
c += "__kernel void main_function(\n";
c += src_tensor.GetDeclaration(AccessType::READ);
c += GetArgsDeclaration(linked_operations);
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
args->AddObjectRef("dst_tensor", AccessType::WRITE, std::move(dst_desc));
if (output_indices) {
c += indices_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
auto dst_ind_desc =
absl::make_unique<TensorDescriptor>(op_def.dst_tensors[1]);
if (op_def.IsBatchSupported()) {
dst_ind_desc->SetStateVar("BatchedWidth", "true");
}
args->AddObjectRef("dst_indices", AccessType::WRITE,
std::move(dst_ind_desc));
}
c += " int4 src_size, \n";
c += " int4 dst_size, \n";
c += " int2 kernel_size, \n";
c += " int2 padding, \n";
c += " int2 stride \n";
c += ") {\n";
if (op_def.dst_tensors[0].HasAxis(Axis::WIDTH)) {
args->AddInt("kernel_size_x");
args->AddInt("padding_x");
args->AddInt("stride_x");
}
if (op_def.dst_tensors[0].HasAxis(Axis::HEIGHT)) {
args->AddInt("kernel_size_y");
args->AddInt("padding_y");
args->AddInt("stride_y");
}
if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
args->AddInt("kernel_size_z");
args->AddInt("padding_z");
args->AddInt("stride_z");
}
std::map<Axis, std::string> axis_to_src_coord = {
{Axis::WIDTH, "x_c"}, {Axis::HEIGHT, "y_c"}, {Axis::DEPTH, "d_c"},
{Axis::CHANNELS, "Z"}, {Axis::BATCH, "B"},
};
std::map<Axis, std::string> axis_to_dst_coord = {
{Axis::WIDTH, "X"}, {Axis::HEIGHT, "Y"}, {Axis::DEPTH, "D"},
{Axis::CHANNELS, "Z"}, {Axis::BATCH, "B"},
};
std::vector<std::string> src_coords;
std::vector<std::string> dst_coords;
for (auto axis : {Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH, Axis::CHANNELS}) {
if (op_def.dst_tensors[0].HasAxis(axis)) {
dst_coords.push_back(axis_to_dst_coord[axis]);
}
if (op_def.src_tensors[0].HasAxis(axis)) {
src_coords.push_back(axis_to_src_coord[axis]);
}
}
std::string src_coord = src_coords[0];
for (int i = 1; i < src_coords.size(); ++i) {
src_coord += ", " + src_coords[i];
}
std::string dst_coord = dst_coords[0];
for (int i = 1; i < dst_coords.size(); ++i) {
dst_coord += ", " + dst_coords[i];
}
std::string c = GetCommonDefines(op_def.precision);
c += "__kernel void main_function(\n";
c += "$0) {\n";
c += " int X = get_global_id(0);\n";
c += " int Y = get_global_id(1);\n";
if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
c += " int linear_id_1 = get_global_id(1);\n";
c += " int Y = linear_id_1 / args.dst_tensor.Depth();\n";
c += " int D = linear_id_1 % args.dst_tensor.Depth();\n";
} else {
c += " int Y = get_global_id(1);\n";
}
c += " int Z = get_global_id(2);\n";
c +=
" if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return; \n";
c += " FLT4 maximum = (FLT4)(-10000.0f);\n";
if (output_indices) {
c += " FLT4 indexes = (FLT4)(0.0f);\n";
c += " FLT index_counter = (FLT)(0.1f);\n";
}
if (stride_correction) {
c += " int xs = " +
GetXStrideCorrected("X", "src_size.w", "stride.x", "padding.x") +
";\n";
} else {
c += " int xs = X * stride.x + padding.x;\n";
}
c += " int ys = Y * stride.y + padding.y;\n";
c += " for (int ky = 0; ky < kernel_size.y; ++ky) {\n";
c += " int y_c = ys + ky;\n";
c += " bool outside_y = y_c < 0 || y_c >= src_size.y;\n";
c += " for (int kx = 0; kx < kernel_size.x; ++kx) {\n";
if (op_def.IsBatchSupported()) {
c += " int x_c = xs + kx * src_size.w;\n";
} else {
c += " int x_c = xs + kx;\n";
}
c += " bool outside_x = x_c < 0 || x_c >= src_size.x;\n";
c += " if (!outside_x && !outside_y) {\n";
c += " FLT4 src = " + src_tensor.ReadWHS("x_c", "y_c", "Z") + ";\n";
if (output_indices) {
c += " if (src.x > maximum.x) {\n";
c += " indexes.x = index_counter;\n";
c += " maximum.x = src.x;\n";
c += " }\n";
c += " if (src.y > maximum.y) {\n";
c += " indexes.y = index_counter;\n";
c += " maximum.y = src.y;\n";
c += " }\n";
c += " if (src.z > maximum.z) {\n";
c += " indexes.z = index_counter;\n";
c += " maximum.z = src.z;\n";
c += " }\n";
c += " if (src.w > maximum.w) {\n";
c += " indexes.w = index_counter;\n";
c += " maximum.w = src.w;\n";
c += " }\n";
c += " index_counter += (FLT)(1.0f);\n";
} else {
c += " maximum = max(src, maximum);\n";
}
c += " }\n";
c += " }\n";
c += " }\n";
const LinkingContext context{"maximum", "X", "Y", "Z"};
c += PostProcess(linked_operations, context);
c += " " + dst_tensor.WriteWHS("maximum", "X", "Y", "Z");
if (output_indices) {
c += " " + indices_tensor.WriteWHS("indexes", "X", "Y", "Z");
}
c += "}\n";
return c;
}
std::string GetMaxPooling3DKernelCode(
const OperationDef& op_def, bool stride_correction,
const std::vector<ElementwiseOperation*>& linked_operations,
bool output_indices) {
TensorCodeGenerator src_tensor(
"src_data",
WHDSPoint{"src_size.x", "src_size.y", "src_size.z", "src_size.w"},
op_def.src_tensors[0]);
TensorCodeGenerator dst_tensor(
"dst_data",
WHDSPoint{"dst_size.x", "dst_size.y", "dst_size.z", "dst_size.w"},
op_def.dst_tensors[0]);
const auto dst_ind_def =
output_indices ? op_def.dst_tensors[1] : op_def.dst_tensors[0];
TensorCodeGenerator indices_tensor(
"dst_indices",
WHDSPoint{"dst_size.x", "dst_size.y", "dst_size.z", "dst_size.w"},
dst_ind_def);
std::string c = GetCommonDefines(op_def.precision);
c += "__kernel void main_function(\n";
c += src_tensor.GetDeclaration(AccessType::READ);
c += GetArgsDeclaration(linked_operations);
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
if (output_indices) {
c += indices_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
}
c += " int4 src_size, \n";
c += " int4 dst_size, \n";
if (op_def.IsBatchSupported()) {
c += " int batch_size, \n";
}
c += " int4 kernel_size, \n";
c += " int4 padding, \n";
c += " int4 stride \n";
c += ") {\n";
c += " int X = get_global_id(0);\n";
c += " int Y = get_global_id(1);\n";
c += " int linear_id_z = get_global_id(2);\n";
c += " int S = linear_id_z % dst_size.w;\n";
c += " int Z = linear_id_z / dst_size.w;\n";
c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return;\n";
c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
"Z >= args.dst_tensor.Slices()) { \n";
c += " return; \n";
c += " } \n";
c += " FLT4 maximum = (FLT4)(-10000.0f);\n";
if (output_indices) {
c += " FLT4 indexes = (FLT4)(0.0f);\n";
}
if (stride_correction) {
c += " int xs = " +
GetXStrideCorrected("X", "batch_size", "stride.x", "padding.x") +
GetXStrideCorrected("X", "args.src_tensor.Batch()", "args.stride_x",
"args.padding_x") +
";\n";
} else {
c += " int xs = X * stride.x + padding.x;\n";
c += " int xs = X * args.stride_x + args.padding_x;\n";
}
c += " int ys = Y * stride.y + padding.y;\n";
c += " int zs = Z * stride.z + padding.z;\n";
c += " for (int ky = 0; ky < kernel_size.y; ++ky) {\n";
c += " int ys = Y * args.stride_y + args.padding_y;\n";
c += " for (int ky = 0; ky < args.kernel_size_y; ++ky) {\n";
c += " int y_c = ys + ky;\n";
c += " if (y_c < 0 || y_c >= src_size.y) continue;\n";
c += " for (int kx = 0; kx < kernel_size.x; ++kx) {\n";
c += " if (y_c < 0 || y_c >= args.src_tensor.Height()) continue;\n";
c += " for (int kx = 0; kx < args.kernel_size_x; ++kx) {\n";
if (op_def.IsBatchSupported()) {
c += " int x_c = xs + kx * batch_size;\n";
c += " int x_c = xs + kx * args.src_tensor.Batch();\n";
} else {
c += " int x_c = xs + kx;\n";
}
c += " if (x_c < 0 || x_c >= src_size.x) continue;\n";
c += " for (int kz = 0; kz < kernel_size.z; ++kz) {\n";
c += " int z_c = zs + kz;\n";
c += " if (z_c < 0 || z_c >= src_size.z) continue;\n";
c += " FLT4 src = " + src_tensor.ReadWHDS("x_c", "y_c", "z_c", "S") +
";\n";
if (output_indices) {
c += " FLT index_counter = (FLT)((ky * kernel_size.x + kx) * "
"kernel_size.z + kz) + (FLT)(0.1f);\n";
c += " if (src.x > maximum.x) {\n";
c += " indexes.x = index_counter;\n";
c += " maximum.x = src.x;\n";
c += " }\n";
c += " if (src.y > maximum.y) {\n";
c += " indexes.y = index_counter;\n";
c += " maximum.y = src.y;\n";
c += " }\n";
c += " if (src.z > maximum.z) {\n";
c += " indexes.z = index_counter;\n";
c += " maximum.z = src.z;\n";
c += " }\n";
c += " if (src.w > maximum.w) {\n";
c += " indexes.w = index_counter;\n";
c += " maximum.w = src.w;\n";
c += " }\n";
} else {
c += " maximum = max(src, maximum);\n";
c += " if (x_c < 0 || x_c >= args.src_tensor.Width()) continue;\n";
if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
c += " int ds = D * args.stride_z + args.padding_z;\n";
c += " for (int kz = 0; kz < args.kernel_size_z; ++kz) {\n";
c += " int d_c = ds + kz;\n";
c += " if (d_c < 0 || d_c >= args.src_tensor.Depth()) continue;\n";
}
c += " FLT4 src = args.src_tensor.Read(" + src_coord + ");\n";
if (output_indices) {
if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
c += " FLT index_counter = (FLT)((ky * args.kernel_size_x + kx) * "
"args.kernel_size_z + kz) + (FLT)(0.1f);\n";
} else {
c += " FLT index_counter = (FLT)(ky * args.kernel_size_x + kx) + "
"(FLT)(0.1f);\n";
}
c += " if (src.x > maximum.x) {\n";
c += " indexes.x = index_counter;\n";
c += " maximum.x = src.x;\n";
c += " }\n";
c += " if (src.y > maximum.y) {\n";
c += " indexes.y = index_counter;\n";
c += " maximum.y = src.y;\n";
c += " }\n";
c += " if (src.z > maximum.z) {\n";
c += " indexes.z = index_counter;\n";
c += " maximum.z = src.z;\n";
c += " }\n";
c += " if (src.w > maximum.w) {\n";
c += " indexes.w = index_counter;\n";
c += " maximum.w = src.w;\n";
c += " }\n";
} else {
c += " maximum = max(src, maximum);\n";
}
if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
c += " } // Depth\n";
}
c += " };\n";
c += " }\n";
c += " }\n";
const LinkingContext context{"maximum", "X", "Y", "Z"};
c += PostProcess(linked_operations, context);
c += " " + dst_tensor.WriteWHDS("maximum", "X", "Y", "Z", "S");
c += " args.dst_tensor.Write(maximum, " + dst_coord + ");\n";
if (output_indices) {
c += " " + indices_tensor.WriteWHDS("indexes", "X", "Y", "Z", "S");
c += " args.dst_indices.Write(indexes, " + dst_coord + ");\n";
}
c += "}\n";
return c;
}
} // namespace
Pooling::Pooling(const OperationDef& definition,
const Pooling2DAttributes& attr)
: GPUOperation(definition),
stride_(attr.strides.w, attr.strides.h),
padding_(-attr.padding.prepended.w, -attr.padding.prepended.h),
kernel_size_(attr.kernel.w, attr.kernel.h),
stride_(attr.strides.w, attr.strides.h, 0, 0),
padding_(-attr.padding.prepended.w, -attr.padding.prepended.h, 0, 0),
kernel_size_(attr.kernel.w, attr.kernel.h, 0, 0),
type_(attr.type),
output_indices_(attr.output_indices) {}
Pooling::Pooling(const OperationDef& definition,
const Pooling3DAttributes& attr)
: GPUOperation(definition),
stride_(attr.strides.w, attr.strides.h, attr.strides.d, 0),
padding_(-attr.padding.prepended.w, -attr.padding.prepended.h,
-attr.padding.prepended.d, 0),
kernel_size_(attr.kernel.w, attr.kernel.h, attr.kernel.d, 0),
type_(attr.type),
output_indices_(attr.output_indices) {}
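
With the merge, one set of int4 fields serves both attribute types; the 2D constructor zero-fills the z component and, as before, negates the prepended padding. A worked example with made-up numbers:

// Hypothetical 2D attributes for illustration: kernel 2x2, stride 2x2,
// prepended padding 1x1. The merged constructor then stores
//   stride_      = int4( 2,  2, 0, 0)
//   padding_     = int4(-1, -1, 0, 0)   // sign flip on prepended padding
//   kernel_size_ = int4( 2,  2, 0, 0)
// The 3D constructor instead fills .z from attr.strides.d,
// -attr.padding.prepended.d and attr.kernel.d.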
@@ -419,44 +360,56 @@ absl::Status Pooling::Compile(const CreationContext& creation_context) {
switch (type_) {
case PoolingType::AVERAGE:
code = GetAveragePoolingKernelCode(definition_, stride_correction,
*creation_context.device,
linked_operations_);
*creation_context.device, &args_);
break;
case PoolingType::MAX:
code = GetMaxPoolingKernelCode(definition_, stride_correction,
linked_operations_, output_indices_);
output_indices_, &args_);
break;
default:
return absl::InvalidArgumentError(
"You should create another kernel with this params");
break;
}
std::string element_wise_code;
RETURN_IF_ERROR(
MergeOperations(linked_operations_, &args_, &element_wise_code));
RETURN_IF_ERROR(args_.TransformToCLCode(creation_context.device->GetInfo(),
{{"dst_tensor", element_wise_code}},
&code));
return creation_context.cache->GetOrCreateCLKernel(
code, "main_function", *creation_context.context,
*creation_context.device, &kernel_);
}
absl::Status Pooling::BindArguments() {
kernel_.ResetBindingCounter();
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
if (output_indices_) {
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[1]->GetMemoryPtrForWriting()));
RETURN_IF_ERROR(args_.SetObjectRef("src_tensor", src_[0]));
RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", dst_[0]));
if (definition_.dst_tensors[0].HasAxis(Axis::WIDTH)) {
RETURN_IF_ERROR(args_.SetInt("stride_x", stride_.x));
RETURN_IF_ERROR(args_.SetInt("padding_x", padding_.x * src_[0]->Batch()));
RETURN_IF_ERROR(args_.SetInt("kernel_size_x", kernel_size_.x));
}
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHSB()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHSB()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(kernel_size_));
RETURN_IF_ERROR(
kernel_.SetBytesAuto(int2(padding_.x * src_[0]->Batch(), padding_.y)));
RETURN_IF_ERROR(kernel_.SetBytesAuto(stride_));
return absl::OkStatus();
if (definition_.dst_tensors[0].HasAxis(Axis::HEIGHT)) {
RETURN_IF_ERROR(args_.SetInt("stride_y", stride_.y));
RETURN_IF_ERROR(args_.SetInt("padding_y", padding_.y));
RETURN_IF_ERROR(args_.SetInt("kernel_size_y", kernel_size_.y));
}
if (definition_.dst_tensors[0].HasAxis(Axis::DEPTH)) {
RETURN_IF_ERROR(args_.SetInt("stride_z", stride_.z));
RETURN_IF_ERROR(args_.SetInt("padding_z", padding_.z));
RETURN_IF_ERROR(args_.SetInt("kernel_size_z", kernel_size_.z));
}
if (output_indices_) {
RETURN_IF_ERROR(args_.SetObjectRef("dst_indices", dst_[1]));
}
RETURN_IF_ERROR(SetArguments(linked_operations_, &args_));
return args_.Bind(kernel_.kernel());
}
int3 Pooling::GetGridSize() const {
const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
const int grid_y = dst_[0]->Height();
const int grid_y = dst_[0]->Height() * dst_[0]->Depth();
const int grid_z = dst_[0]->Slices();
return int3(grid_x, grid_y, grid_z);
}
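
Folding D into the Y dispatch dimension is what lets one GetGridSize() serve both cases; presumably Depth() is 1 for tensors without a DEPTH axis, so the 2D grid is unchanged. A worked example with made-up sizes:

// Illustration only: output W=10, Batch=2, H=8, D=3, Slices=4 gives
//   grid_x = 10 * 2 = 20
//   grid_y =  8 * 3 = 24   // Y and D are unpacked in-kernel from linear_id_1
//   grid_z =  4
// For a 2D output (Depth() == 1), grid_y stays equal to Height().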
@@ -476,107 +429,9 @@ Pooling CreatePooling(const OperationDef& definition,
return Pooling(definition, attr);
}
Pooling3D::Pooling3D(const OperationDef& definition,
const Pooling3DAttributes& attr)
: GPUOperation(definition),
stride_(attr.strides.w, attr.strides.h, attr.strides.d),
padding_(-attr.padding.prepended.w, -attr.padding.prepended.h,
-attr.padding.prepended.d),
kernel_size_(attr.kernel.w, attr.kernel.h, attr.kernel.d),
type_(attr.type),
output_indices_(attr.output_indices) {}
Pooling3D::Pooling3D(Pooling3D&& kernel)
: GPUOperation(std::move(kernel)),
stride_(kernel.stride_),
padding_(kernel.padding_),
kernel_size_(kernel.kernel_size_),
type_(kernel.type_),
output_indices_(kernel.output_indices_),
kernel_(std::move(kernel.kernel_)),
work_group_size_(kernel.work_group_size_) {}
Pooling3D& Pooling3D::operator=(Pooling3D&& kernel) {
if (this != &kernel) {
std::swap(stride_, kernel.stride_);
std::swap(padding_, kernel.padding_);
std::swap(kernel_size_, kernel.kernel_size_);
std::swap(type_, kernel.type_);
std::swap(output_indices_, kernel.output_indices_);
kernel_ = std::move(kernel.kernel_);
std::swap(work_group_size_, kernel.work_group_size_);
GPUOperation::operator=(std::move(kernel));
}
return *this;
}
absl::Status Pooling3D::Compile(const CreationContext& creation_context) {
std::string code;
const bool stride_correction =
definition_.IsBatchSupported() && stride_.x != 1;
switch (type_) {
case PoolingType::AVERAGE:
code = GetAveragePooling3DKernelCode(definition_, stride_correction,
*creation_context.device,
linked_operations_);
break;
case PoolingType::MAX:
code = GetMaxPooling3DKernelCode(definition_, stride_correction,
linked_operations_, output_indices_);
break;
default:
return absl::InvalidArgumentError(
"You should create another kernel with this params");
break;
}
return creation_context.cache->GetOrCreateCLKernel(
code, "main_function", *creation_context.context,
*creation_context.device, &kernel_);
}
absl::Status Pooling3D::BindArguments() {
kernel_.ResetBindingCounter();
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
if (output_indices_) {
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[1]->GetMemoryPtrForWriting()));
}
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHDS()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHDS()));
if (definition_.IsBatchSupported()) {
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->Batch()));
}
RETURN_IF_ERROR(kernel_.SetBytesAuto(
int4(kernel_size_.x, kernel_size_.y, kernel_size_.z, 1)));
RETURN_IF_ERROR(kernel_.SetBytesAuto(
int4(padding_.x * src_[0]->Batch(), padding_.y, padding_.z, 1)));
RETURN_IF_ERROR(
kernel_.SetBytesAuto(int4(stride_.x, stride_.y, stride_.z, 1)));
return absl::OkStatus();
}
int3 Pooling3D::GetGridSize() const {
const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
const int grid_y = dst_[0]->Height();
const int grid_z = dst_[0]->Slices() * dst_[0]->Depth();
return int3(grid_x, grid_y, grid_z);
}
absl::Status Pooling3D::Tune(const TuningParameters& params) {
RETURN_IF_ERROR(BindArguments());
return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_);
}
absl::Status Pooling3D::AddToQueue(CLCommandQueue* queue) {
RETURN_IF_ERROR(BindArguments());
return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
}
Pooling3D CreatePooling3D(const OperationDef& definition,
const Pooling3DAttributes& attr) {
return Pooling3D(definition, attr);
Pooling CreatePooling(const OperationDef& definition,
const Pooling3DAttributes& attr) {
return Pooling(definition, attr);
}
} // namespace cl
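
Both attribute types now funnel into the same operation type; a hedged caller-side sketch (helper name and setup assumed, tensor wiring omitted):

// Hypothetical usage sketch, not part of this commit.
absl::Status RunMergedPooling(const OperationDef& definition,
                              const Pooling2DAttributes& attr2d,
                              const Pooling3DAttributes& attr3d,
                              const CreationContext& creation_context,
                              CLCommandQueue* queue) {
  Pooling pool2d = CreatePooling(definition, attr2d);  // 2D overload
  Pooling pool3d = CreatePooling(definition, attr3d);  // 3D overload, same type
  RETURN_IF_ERROR(pool2d.Compile(creation_context));
  RETURN_IF_ERROR(pool3d.Compile(creation_context));
  // SetSrc/SetDst of the actual tensors omitted for brevity.
  RETURN_IF_ERROR(pool2d.AddToQueue(queue));
  RETURN_IF_ERROR(pool3d.AddToQueue(queue));
  return absl::OkStatus();
}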

View File

@@ -30,6 +30,7 @@ namespace cl {
class Pooling : public GPUOperation {
public:
Pooling(const OperationDef& definition, const Pooling2DAttributes& attr);
Pooling(const OperationDef& definition, const Pooling3DAttributes& attr);
absl::Status AddToQueue(CLCommandQueue* queue) override;
absl::Status Tune(const TuningParameters& params) override;
@@ -45,9 +46,9 @@ class Pooling : public GPUOperation {
absl::Status BindArguments();
int3 GetGridSize() const;
int2 stride_;
int2 padding_;
int2 kernel_size_;
int4 stride_;
int4 padding_;
int4 kernel_size_;
PoolingType type_;
bool output_indices_;
@@ -59,37 +60,8 @@ class Pooling : public GPUOperation {
Pooling CreatePooling(const OperationDef& definition,
const Pooling2DAttributes& attr);
class Pooling3D : public GPUOperation {
public:
Pooling3D(const OperationDef& definition, const Pooling3DAttributes& attr);
absl::Status AddToQueue(CLCommandQueue* queue) override;
absl::Status Tune(const TuningParameters& params) override;
absl::Status Compile(const CreationContext& creation_context) override;
// Move only
Pooling3D(Pooling3D&& kernel);
Pooling3D& operator=(Pooling3D&& kernel);
Pooling3D(const Pooling3D&) = delete;
Pooling3D& operator=(const Pooling3D&) = delete;
private:
absl::Status BindArguments();
int3 GetGridSize() const;
int3 stride_;
int3 padding_;
int3 kernel_size_;
PoolingType type_;
bool output_indices_;
CLKernel kernel_;
int3 work_group_size_ = int3(8, 4, 1);
};
Pooling3D CreatePooling3D(const OperationDef& definition,
const Pooling3DAttributes& attr);
Pooling CreatePooling(const OperationDef& definition,
const Pooling3DAttributes& attr);
} // namespace cl
} // namespace gpu

View File

@@ -60,15 +60,6 @@ std::string GetImageModifier(AccessType access) {
}
}
std::string TextureAddressModeToString(TextureAddressMode address_mode) {
switch (address_mode) {
case TextureAddressMode::DONT_CARE:
return "smp_none";
case TextureAddressMode::ZERO:
return "smp_zero";
}
}
} // namespace
std::string GetCommonDefines(CalculationsPrecision precision) {

View File

@@ -36,11 +36,6 @@ namespace cl {
std::string GetCommonDefines(CalculationsPrecision precision);
enum class TextureAddressMode {
DONT_CARE, // translated to CLK_ADDRESS_NONE
ZERO, // translated to CLK_ADDRESS_CLAMP
};
struct WHSPoint {
std::string w_name;
std::string h_name;

View File

@@ -45,6 +45,15 @@ std::string GetWriteImageFromDataType(DataType data_type) {
} // namespace
std::string TextureAddressModeToString(TextureAddressMode address_mode) {
switch (address_mode) {
case TextureAddressMode::DONT_CARE:
return "smp_none";
case TextureAddressMode::ZERO:
return "smp_zero";
}
}
std::string ToString(TensorStorageType type) {
switch (type) {
case TensorStorageType::UNKNOWN:
@@ -271,8 +280,10 @@ std::string TensorDescriptor::Read(DataType read_as_type,
case TensorStorageType::TEXTURE_3D:
case TensorStorageType::SINGLE_TEXTURE_2D:
case TensorStorageType::TEXTURE_ARRAY:
return absl::StrCat(read_as, "(", image_type, ", smp_none, ",
global_address, ")");
return absl::StrCat(
read_as, "(", image_type,
", " + TextureAddressModeToString(ModeFromState()) + ", ",
global_address, ")");
case TensorStorageType::IMAGE_BUFFER:
return absl::StrCat(read_as, "(image_buffer, ", global_address, ")");
case TensorStorageType::UNKNOWN:
@@ -500,6 +511,14 @@ bool TensorDescriptor::HasAxis(Axis axis) const {
return false;
}
void TensorDescriptor::SetTextureAddressMode(TextureAddressMode mode) {
if (mode == TextureAddressMode::ZERO) {
state_vars_["TextureMode"] = "ZERO";
} else {
state_vars_["TextureMode"] = "DONT_CARE";
}
}
bool TensorDescriptor::ParseCoordsFromArgs(const std::vector<std::string>& args,
int offset, std::string* xc,
std::string* yc, std::string* zc,
@@ -549,6 +568,19 @@ bool TensorDescriptor::IsBatchedWidth() const {
return it != state_vars_.end() && it->second == "true";
}
TextureAddressMode TensorDescriptor::ModeFromState() const {
auto it = state_vars_.find("TextureMode");
if (it != state_vars_.end()) {
if (it->second == "ZERO") {
return TextureAddressMode::ZERO;
} else {
return TextureAddressMode::DONT_CARE;
}
} else {
return TextureAddressMode::DONT_CARE;
}
}
} // namespace cl
} // namespace gpu
} // namespace tflite
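
The address-mode enum now lives with TensorDescriptor and is carried as a state var, which is how GetAveragePoolingKernelCode's SetTextureAddressMode(GetFastestZeroMode(device)) reaches the Read() codegen. A hedged sketch of the round trip (strings taken from the code above):

// Illustrative helper, assuming tensor_type.h is included.
void ConfigureZeroClamping(TensorDescriptor* desc) {
  // Records state_vars_["TextureMode"] = "ZERO" on the descriptor...
  desc->SetTextureAddressMode(TextureAddressMode::ZERO);
  // ...so ModeFromState() later returns TextureAddressMode::ZERO and texture
  // reads are emitted with TextureAddressModeToString(...) == "smp_zero"
  // instead of the previously hard-coded "smp_none".
}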

View File

@@ -27,6 +27,13 @@ namespace tflite {
namespace gpu {
namespace cl {
enum class TextureAddressMode {
DONT_CARE, // translated to CLK_ADDRESS_NONE
ZERO, // translated to CLK_ADDRESS_CLAMP
};
std::string TextureAddressModeToString(TextureAddressMode address_mode);
enum class TensorStorageType {
UNKNOWN,
BUFFER,
@@ -71,6 +78,7 @@ struct TensorDescriptor : public GPUObjectDescriptor {
GPUResources GetGPUResources(AccessType access_type) const override;
bool HasAxis(Axis axis) const;
void SetTextureAddressMode(TextureAddressMode mode);
absl::Status GetLinkingContextFromWriteSelector(
const std::vector<std::string>& args, std::string* value_name,
@@ -106,6 +114,8 @@ struct TensorDescriptor : public GPUObjectDescriptor {
bool IsBatchedWidth() const;
TextureAddressMode ModeFromState() const;
absl::Status GetDataTypeFromTemplateArgs(const std::string& template_arg,
DataType* result) const;