Pooling converted to new style.
Merged 2D and 3D versions into one.

PiperOrigin-RevId: 316998142
Change-Id: I92c020476f085e6160a02282c1edafabdd72ca30

commit 73cf8263c7 (parent 2a6d9a1e81)
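For context, a minimal sketch of what "new style" means here, based only on the diff below (no API details beyond what it shows): the old generators received the linked element-wise operations and emitted hand-declared kernel parameters (src_size, kernel_size, stride, ...), while the new generators register their objects and scalars on an Arguments object, so the emitted OpenCL refers to args.src_tensor / args.dst_tensor and ints such as args.stride_x, which Pooling::BindArguments() fills at dispatch time.

// Old style (removed): explicit kernel parameters plus linked_operations.
std::string GetAveragePoolingKernelCode(
    const OperationDef& op_def, bool stride_correction, const CLDevice& device,
    const std::vector<ElementwiseOperation*>& linked_operations);

// New style (added): the generator only describes its arguments; binding and
// element-wise fusion go through the Arguments object (args_ in Pooling::Compile).
std::string GetAveragePoolingKernelCode(const OperationDef& op_def,
                                        bool stride_correction,
                                        const CLDevice& device,
                                        Arguments* args);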
@@ -25,366 +25,307 @@ namespace gpu {
namespace cl {
namespace {

std::string GetAveragePoolingKernelCode(
    const OperationDef& op_def, bool stride_correction, const CLDevice& device,
    const std::vector<ElementwiseOperation*>& linked_operations) {
  TensorCodeGenerator src_tensor(
      "src_data", WHSPoint{"src_size.x", "src_size.y", "src_size.z"},
      op_def.src_tensors[0]);
  TensorCodeGenerator dst_tensor(
      "dst_data", WHSPoint{"dst_size.x", "dst_size.y", "dst_size.z"},
      op_def.dst_tensors[0]);
std::string GetAveragePoolingKernelCode(const OperationDef& op_def,
                                        bool stride_correction,
                                        const CLDevice& device,
                                        Arguments* args) {
  auto src_desc = absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]);
  src_desc->SetTextureAddressMode(GetFastestZeroMode(device));
  if (op_def.IsBatchSupported()) {
    src_desc->SetStateVar("BatchedWidth", "true");
  }
  args->AddObjectRef("src_tensor", AccessType::READ, std::move(src_desc));
  auto dst_desc = absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]);
  if (op_def.IsBatchSupported()) {
    dst_desc->SetStateVar("BatchedWidth", "true");
  }
  args->AddObjectRef("dst_tensor", AccessType::WRITE, std::move(dst_desc));
  if (op_def.dst_tensors[0].HasAxis(Axis::WIDTH)) {
    args->AddInt("kernel_size_x");
    args->AddInt("padding_x");
    args->AddInt("stride_x");
  }
  if (op_def.dst_tensors[0].HasAxis(Axis::HEIGHT)) {
    args->AddInt("kernel_size_y");
    args->AddInt("padding_y");
    args->AddInt("stride_y");
  }
  if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
    args->AddInt("kernel_size_z");
    args->AddInt("padding_z");
    args->AddInt("stride_z");
  }

  const auto address_mode = GetFastestZeroMode(device);
  std::map<Axis, std::string> axis_to_src_coord = {
      {Axis::WIDTH, "x_c"}, {Axis::HEIGHT, "y_c"}, {Axis::DEPTH, "d_c"},
      {Axis::CHANNELS, "Z"}, {Axis::BATCH, "B"},
  };

  std::string c = GetCommonDefines(op_def.precision);
  std::map<Axis, std::string> axis_to_dst_coord = {
      {Axis::WIDTH, "X"}, {Axis::HEIGHT, "Y"}, {Axis::DEPTH, "D"},
      {Axis::CHANNELS, "Z"}, {Axis::BATCH, "B"},
  };

  std::vector<std::string> src_coords;
  std::vector<std::string> dst_coords;
  for (auto axis : {Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH, Axis::CHANNELS}) {
    if (op_def.dst_tensors[0].HasAxis(axis)) {
      dst_coords.push_back(axis_to_dst_coord[axis]);
    }
    if (op_def.src_tensors[0].HasAxis(axis)) {
      src_coords.push_back(axis_to_src_coord[axis]);
    }
  }
  std::string src_coord = src_coords[0];
  for (int i = 1; i < src_coords.size(); ++i) {
    src_coord += ", " + src_coords[i];
  }
  std::string dst_coord = dst_coords[0];
  for (int i = 1; i < dst_coords.size(); ++i) {
    dst_coord += ", " + dst_coords[i];
  }

  const bool manual_clamp =
      op_def.src_tensors[0].storage_type == TensorStorageType::BUFFER ||
      op_def.src_tensors[0].storage_type == TensorStorageType::IMAGE_BUFFER;

  std::string c = GetCommonDefines(op_def.precision);
  c += "__kernel void main_function(\n";
  c += src_tensor.GetDeclaration(AccessType::READ);
  c += GetArgsDeclaration(linked_operations);
  c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
  c += " int4 src_size, \n";
  c += " int4 dst_size, \n";
  c += " int2 kernel_size, \n";
  c += " int2 padding, \n";
  c += " int2 stride \n";
  c += ") {\n";
  c += "$0) {\n";
  c += " int X = get_global_id(0);\n";
  c += " int Y = get_global_id(1);\n";
  if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
    c += " int linear_id_1 = get_global_id(1);\n";
    c += " int Y = linear_id_1 / args.dst_tensor.Depth();\n";
    c += " int D = linear_id_1 % args.dst_tensor.Depth();\n";
  } else {
    c += " int Y = get_global_id(1);\n";
  }
  c += " int Z = get_global_id(2);\n";
  c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return;\n";
  c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
       "Z >= args.dst_tensor.Slices()) { \n";
  c += " return; \n";
  c += " } \n";
  c += " float4 r = (float4)(0.0f);\n";
  c += " float window_size = 0.0;\n";
  if (stride_correction) {
    c += " int xs = " +
         GetXStrideCorrected("X", "src_size.w", "stride.x", "padding.x") +
         GetXStrideCorrected("X", "args.src_tensor.Batch()", "args.stride_x",
                             "args.padding_x") +
         ";\n";
  } else {
    c += " int xs = X * stride.x + padding.x;\n";
    c += " int xs = X * args.stride_x + args.padding_x;\n";
  }
  c += " int ys = Y * stride.y + padding.y;\n";
  c += " for (int ky = 0; ky < kernel_size.y; ++ky) {\n";
  c += " int ys = Y * args.stride_y + args.padding_y;\n";
  if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
    c += " int ds = D * args.stride_z + args.padding_z;\n";
    c += " for (int kz = 0; kz < args.kernel_size_z; ++kz) {\n";
    c += " int d_c = ds + kz;\n";
    c += " if (d_c < 0 || d_c >= args.src_tensor.Depth()) continue;\n";
  }
  c += " for (int ky = 0; ky < args.kernel_size_y; ++ky) {\n";
  c += " int y_c = ys + ky;\n";
  c += " bool outside_y = y_c < 0 || y_c >= src_size.y;\n";
  c += " for (int kx = 0; kx < kernel_size.x; ++kx) {\n";
  c += " bool outside_y = y_c < 0 || y_c >= args.src_tensor.Height();\n";
  c += " for (int kx = 0; kx < args.kernel_size_x; ++kx) {\n";
  if (op_def.IsBatchSupported()) {
    c += " int x_c = xs + kx * src_size.w;\n";
    c += " int x_c = xs + kx * args.src_tensor.Batch();\n";
  } else {
    c += " int x_c = xs + kx;\n";
  }
  c += " bool outside = outside_y || x_c < 0 || x_c >= src_size.x;\n";
  c += " bool outside = outside_y || x_c < 0 || x_c >= "
       "args.src_tensor.Width();\n";
  if (manual_clamp) {
    c += " r += !outside ? " +
         src_tensor.ReadAsFloatWHS("x_c", "y_c", "Z") + " : (float4)(0.0f);\n";
    c += " r += !outside ? args.src_tensor.Read<float>(" + src_coord +
         ") : "
         "(float4)(0.0f);\n";
  } else {
    c += " r += " +
         src_tensor.ReadAsFloatWHS("x_c", "y_c", "Z", address_mode) + ";\n";
    c += " r += args.src_tensor.Read<float>(" + src_coord + ");\n";
  }
  c += " window_size += !outside ? 1.0 : 0.0;\n";
  c += " }\n";
  c += " }\n";
  if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
    c += " } // Depth\n";
  }
  // If window_size==0, window covered nothing. This situation is a sign of
  // incorrectly constructed operation. NaNs are expected as output.
  c += " FLT4 result = TO_FLT4(r / window_size);\n";
  const LinkingContext context{"result", "X", "Y", "Z"};
  c += PostProcess(linked_operations, context);
  c += " " + dst_tensor.WriteWHS("result", "X", "Y", "Z");
  c += " args.dst_tensor.Write(result, " + dst_coord + ");\n";
  c += "}\n";

  return c;
}

std::string GetAveragePooling3DKernelCode(
    const OperationDef& op_def, bool stride_correction, const CLDevice& device,
    const std::vector<ElementwiseOperation*>& linked_operations) {
  TensorCodeGenerator src_tensor(
      "src_data",
      WHDSPoint{"src_size.x", "src_size.y", "src_size.z", "src_size.w"},
      op_def.src_tensors[0]);
  TensorCodeGenerator dst_tensor(
      "dst_data",
      WHDSPoint{"dst_size.x", "dst_size.y", "dst_size.z", "dst_size.w"},
      op_def.dst_tensors[0]);

  const auto address_mode = GetFastestZeroMode(device);

  std::string c = GetCommonDefines(op_def.precision);

  c += "__kernel void main_function(\n";
  c += src_tensor.GetDeclaration(AccessType::READ);
  c += GetArgsDeclaration(linked_operations);
  c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
  c += " int4 src_size, \n";
  c += " int4 dst_size, \n";
std::string GetMaxPoolingKernelCode(const OperationDef& op_def,
                                    bool stride_correction, bool output_indices,
                                    Arguments* args) {
  auto src_desc = absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]);
  if (op_def.IsBatchSupported()) {
    c += " int batch_size, \n";
    src_desc->SetStateVar("BatchedWidth", "true");
  }
  c += " int4 kernel_size, \n";
  c += " int4 padding, \n";
  c += " int4 stride \n";
  c += ") {\n";
  c += " int X = get_global_id(0);\n";
  c += " int Y = get_global_id(1);\n";
  c += " int linear_id_z = get_global_id(2);\n";
  c += " int S = linear_id_z % dst_size.w;\n";
  c += " int Z = linear_id_z / dst_size.w;\n";
  c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return;\n";
  c += " float4 r = (float4)(0.0f);\n";
  c += " float window_size = 0.0;\n";
  if (stride_correction) {
    c += " int xs = " +
         GetXStrideCorrected("X", "batch_size", "stride.x", "padding.x") +
         ";\n";
  } else {
    c += " int xs = X * stride.x + padding.x;\n";
  }
  c += " int ys = Y * stride.y + padding.y;\n";
  c += " int zs = Z * stride.z + padding.z;\n";
  c += " for (int kz = 0; kz < kernel_size.z; ++kz) {\n";
  c += " int z_c = zs + kz;\n";
  c += " if (z_c < 0 || z_c >= src_size.z) continue;\n";
  c += " for (int ky = 0; ky < kernel_size.y; ++ky) {\n";
  c += " int y_c = ys + ky;\n";
  c += " if (y_c < 0 || y_c >= src_size.y) continue;\n";
  c += " for (int kx = 0; kx < kernel_size.x; ++kx) {\n";
  args->AddObjectRef("src_tensor", AccessType::READ, std::move(src_desc));
  auto dst_desc = absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]);
  if (op_def.IsBatchSupported()) {
    c += " int x_c = xs + kx * batch_size;\n";
  } else {
    c += " int x_c = xs + kx;\n";
    dst_desc->SetStateVar("BatchedWidth", "true");
  }
  c += " if(x_c < 0 || x_c >= src_size.x) continue;\n";
  c += " r += " +
       src_tensor.ReadAsFloatWHDS("x_c", "y_c", "z_c", "S", address_mode) +
       ";\n";
  c += " window_size += 1.0;\n";
  c += " }\n";
  c += " }\n";
  c += " }\n";
  // If window_size==0, window covered nothing. This situation is a sign of
  // incorrectly constructed operation. NaNs are expected as output.
  c += " FLT4 result = TO_FLT4(r / window_size);\n";
  const LinkingContext context{"result", "X", "Y", "Z"};
  c += PostProcess(linked_operations, context);
  c += " " + dst_tensor.WriteWHDS("result", "X", "Y", "Z", "S");
  c += "}\n";

  return c;
}

std::string GetMaxPoolingKernelCode(
    const OperationDef& op_def, bool stride_correction,
    const std::vector<ElementwiseOperation*>& linked_operations,
    bool output_indices) {
  TensorCodeGenerator src_tensor(
      "src_data", WHSPoint{"src_size.x", "src_size.y", "src_size.z"},
      op_def.src_tensors[0]);
  TensorCodeGenerator dst_tensor(
      "dst_data", WHSPoint{"dst_size.x", "dst_size.y", "dst_size.z"},
      op_def.dst_tensors[0]);
  const auto dst_ind_def =
      output_indices ? op_def.dst_tensors[1] : op_def.dst_tensors[0];
  TensorCodeGenerator indices_tensor(
      "dst_indices", WHSPoint{"dst_size.x", "dst_size.y", "dst_size.z"},
      dst_ind_def);

  std::string c = GetCommonDefines(op_def.precision);

  c += "__kernel void main_function(\n";
  c += src_tensor.GetDeclaration(AccessType::READ);
  c += GetArgsDeclaration(linked_operations);
  c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
  args->AddObjectRef("dst_tensor", AccessType::WRITE, std::move(dst_desc));
  if (output_indices) {
    c += indices_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
    auto dst_ind_desc =
        absl::make_unique<TensorDescriptor>(op_def.dst_tensors[1]);
    if (op_def.IsBatchSupported()) {
      dst_ind_desc->SetStateVar("BatchedWidth", "true");
    }
    args->AddObjectRef("dst_indices", AccessType::WRITE,
                       std::move(dst_ind_desc));
  }
  c += " int4 src_size, \n";
  c += " int4 dst_size, \n";
  c += " int2 kernel_size, \n";
  c += " int2 padding, \n";
  c += " int2 stride \n";
  c += ") {\n";
  if (op_def.dst_tensors[0].HasAxis(Axis::WIDTH)) {
    args->AddInt("kernel_size_x");
    args->AddInt("padding_x");
    args->AddInt("stride_x");
  }
  if (op_def.dst_tensors[0].HasAxis(Axis::HEIGHT)) {
    args->AddInt("kernel_size_y");
    args->AddInt("padding_y");
    args->AddInt("stride_y");
  }
  if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
    args->AddInt("kernel_size_z");
    args->AddInt("padding_z");
    args->AddInt("stride_z");
  }

  std::map<Axis, std::string> axis_to_src_coord = {
      {Axis::WIDTH, "x_c"}, {Axis::HEIGHT, "y_c"}, {Axis::DEPTH, "d_c"},
      {Axis::CHANNELS, "Z"}, {Axis::BATCH, "B"},
  };

  std::map<Axis, std::string> axis_to_dst_coord = {
      {Axis::WIDTH, "X"}, {Axis::HEIGHT, "Y"}, {Axis::DEPTH, "D"},
      {Axis::CHANNELS, "Z"}, {Axis::BATCH, "B"},
  };

  std::vector<std::string> src_coords;
  std::vector<std::string> dst_coords;
  for (auto axis : {Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH, Axis::CHANNELS}) {
    if (op_def.dst_tensors[0].HasAxis(axis)) {
      dst_coords.push_back(axis_to_dst_coord[axis]);
    }
    if (op_def.src_tensors[0].HasAxis(axis)) {
      src_coords.push_back(axis_to_src_coord[axis]);
    }
  }
  std::string src_coord = src_coords[0];
  for (int i = 1; i < src_coords.size(); ++i) {
    src_coord += ", " + src_coords[i];
  }
  std::string dst_coord = dst_coords[0];
  for (int i = 1; i < dst_coords.size(); ++i) {
    dst_coord += ", " + dst_coords[i];
  }

  std::string c = GetCommonDefines(op_def.precision);
  c += "__kernel void main_function(\n";
  c += "$0) {\n";
  c += " int X = get_global_id(0);\n";
  c += " int Y = get_global_id(1);\n";
  if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
    c += " int linear_id_1 = get_global_id(1);\n";
    c += " int Y = linear_id_1 / args.dst_tensor.Depth();\n";
    c += " int D = linear_id_1 % args.dst_tensor.Depth();\n";
  } else {
    c += " int Y = get_global_id(1);\n";
  }
  c += " int Z = get_global_id(2);\n";
  c +=
      " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return; \n";
  c += " FLT4 maximum = (FLT4)(-10000.0f);\n";
  if (output_indices) {
    c += " FLT4 indexes = (FLT4)(0.0f);\n";
    c += " FLT index_counter = (FLT)(0.1f);\n";
  }
  if (stride_correction) {
    c += " int xs = " +
         GetXStrideCorrected("X", "src_size.w", "stride.x", "padding.x") +
         ";\n";
  } else {
    c += " int xs = X * stride.x + padding.x;\n";
  }
  c += " int ys = Y * stride.y + padding.y;\n";
  c += " for (int ky = 0; ky < kernel_size.y; ++ky) {\n";
  c += " int y_c = ys + ky;\n";
  c += " bool outside_y = y_c < 0 || y_c >= src_size.y;\n";
  c += " for (int kx = 0; kx < kernel_size.x; ++kx) {\n";
  if (op_def.IsBatchSupported()) {
    c += " int x_c = xs + kx * src_size.w;\n";
  } else {
    c += " int x_c = xs + kx;\n";
  }
  c += " bool outside_x = x_c < 0 || x_c >= src_size.x;\n";
  c += " if (!outside_x && !outside_y) {\n";
  c += " FLT4 src = " + src_tensor.ReadWHS("x_c", "y_c", "Z") + ";\n";
  if (output_indices) {
    c += " if (src.x > maximum.x) {\n";
    c += " indexes.x = index_counter;\n";
    c += " maximum.x = src.x;\n";
    c += " }\n";
    c += " if (src.y > maximum.y) {\n";
    c += " indexes.y = index_counter;\n";
    c += " maximum.y = src.y;\n";
    c += " }\n";
    c += " if (src.z > maximum.z) {\n";
    c += " indexes.z = index_counter;\n";
    c += " maximum.z = src.z;\n";
    c += " }\n";
    c += " if (src.w > maximum.w) {\n";
    c += " indexes.w = index_counter;\n";
    c += " maximum.w = src.w;\n";
    c += " }\n";
    c += " index_counter += (FLT)(1.0f);\n";
  } else {
    c += " maximum = max(src, maximum);\n";
  }
  c += " }\n";
  c += " }\n";
  c += " }\n";
  const LinkingContext context{"maximum", "X", "Y", "Z"};
  c += PostProcess(linked_operations, context);
  c += " " + dst_tensor.WriteWHS("maximum", "X", "Y", "Z");
  if (output_indices) {
    c += " " + indices_tensor.WriteWHS("indexes", "X", "Y", "Z");
  }
  c += "}\n";

  return c;
}

std::string GetMaxPooling3DKernelCode(
    const OperationDef& op_def, bool stride_correction,
    const std::vector<ElementwiseOperation*>& linked_operations,
    bool output_indices) {
  TensorCodeGenerator src_tensor(
      "src_data",
      WHDSPoint{"src_size.x", "src_size.y", "src_size.z", "src_size.w"},
      op_def.src_tensors[0]);
  TensorCodeGenerator dst_tensor(
      "dst_data",
      WHDSPoint{"dst_size.x", "dst_size.y", "dst_size.z", "dst_size.w"},
      op_def.dst_tensors[0]);
  const auto dst_ind_def =
      output_indices ? op_def.dst_tensors[1] : op_def.dst_tensors[0];
  TensorCodeGenerator indices_tensor(
      "dst_indices",
      WHDSPoint{"dst_size.x", "dst_size.y", "dst_size.z", "dst_size.w"},
      dst_ind_def);

  std::string c = GetCommonDefines(op_def.precision);

  c += "__kernel void main_function(\n";
  c += src_tensor.GetDeclaration(AccessType::READ);
  c += GetArgsDeclaration(linked_operations);
  c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
  if (output_indices) {
    c += indices_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
  }
  c += " int4 src_size, \n";
  c += " int4 dst_size, \n";
  if (op_def.IsBatchSupported()) {
    c += " int batch_size, \n";
  }
  c += " int4 kernel_size, \n";
  c += " int4 padding, \n";
  c += " int4 stride \n";
  c += ") {\n";
  c += " int X = get_global_id(0);\n";
  c += " int Y = get_global_id(1);\n";
  c += " int linear_id_z = get_global_id(2);\n";
  c += " int S = linear_id_z % dst_size.w;\n";
  c += " int Z = linear_id_z / dst_size.w;\n";
  c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return;\n";
  c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
       "Z >= args.dst_tensor.Slices()) { \n";
  c += " return; \n";
  c += " } \n";
  c += " FLT4 maximum = (FLT4)(-10000.0f);\n";
  if (output_indices) {
    c += " FLT4 indexes = (FLT4)(0.0f);\n";
  }
  if (stride_correction) {
    c += " int xs = " +
         GetXStrideCorrected("X", "batch_size", "stride.x", "padding.x") +
         GetXStrideCorrected("X", "args.src_tensor.Batch()", "args.stride_x",
                             "args.padding_x") +
         ";\n";
  } else {
    c += " int xs = X * stride.x + padding.x;\n";
    c += " int xs = X * args.stride_x + args.padding_x;\n";
  }
  c += " int ys = Y * stride.y + padding.y;\n";
  c += " int zs = Z * stride.z + padding.z;\n";
  c += " for (int ky = 0; ky < kernel_size.y; ++ky) {\n";
  c += " int ys = Y * args.stride_y + args.padding_y;\n";
  c += " for (int ky = 0; ky < args.kernel_size_y; ++ky) {\n";
  c += " int y_c = ys + ky;\n";
  c += " if (y_c < 0 || y_c >= src_size.y) continue;\n";
  c += " for (int kx = 0; kx < kernel_size.x; ++kx) {\n";
  c += " if (y_c < 0 || y_c >= args.src_tensor.Height()) continue;\n";
  c += " for (int kx = 0; kx < args.kernel_size_x; ++kx) {\n";
  if (op_def.IsBatchSupported()) {
    c += " int x_c = xs + kx * batch_size;\n";
    c += " int x_c = xs + kx * args.src_tensor.Batch();\n";
  } else {
    c += " int x_c = xs + kx;\n";
  }
  c += " if (x_c < 0 || x_c >= src_size.x) continue;\n";
  c += " for (int kz = 0; kz < kernel_size.z; ++kz) {\n";
  c += " int z_c = zs + kz;\n";
  c += " if (z_c < 0 || z_c >= src_size.z) continue;\n";
  c += " FLT4 src = " + src_tensor.ReadWHDS("x_c", "y_c", "z_c", "S") +
       ";\n";
  if (output_indices) {
    c += " FLT index_counter = (FLT)((ky * kernel_size.x + kx) * "
         "kernel_size.z + kz) + (FLT)(0.1f);\n";
    c += " if (src.x > maximum.x) {\n";
    c += " indexes.x = index_counter;\n";
    c += " maximum.x = src.x;\n";
    c += " }\n";
    c += " if (src.y > maximum.y) {\n";
    c += " indexes.y = index_counter;\n";
    c += " maximum.y = src.y;\n";
    c += " }\n";
    c += " if (src.z > maximum.z) {\n";
    c += " indexes.z = index_counter;\n";
    c += " maximum.z = src.z;\n";
    c += " }\n";
    c += " if (src.w > maximum.w) {\n";
    c += " indexes.w = index_counter;\n";
    c += " maximum.w = src.w;\n";
    c += " }\n";
  } else {
    c += " maximum = max(src, maximum);\n";
  c += " if (x_c < 0 || x_c >= args.src_tensor.Width()) continue;\n";
  if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
    c += " int ds = D * args.stride_z + args.padding_z;\n";
    c += " for (int kz = 0; kz < args.kernel_size_z; ++kz) {\n";
    c += " int d_c = ds + kz;\n";
    c += " if (d_c < 0 || d_c >= args.src_tensor.Depth()) continue;\n";
  }
  c += " FLT4 src = args.src_tensor.Read(" + src_coord + ");\n";
  if (output_indices) {
    if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
      c += " FLT index_counter = (FLT)((ky * args.kernel_size_x + kx) * "
           "args.kernel_size_z + kz) + (FLT)(0.1f);\n";
    } else {
      c += " FLT index_counter = (FLT)(ky * args.kernel_size_x + kx) + "
           "(FLT)(0.1f);\n";
    }
    c += " if (src.x > maximum.x) {\n";
    c += " indexes.x = index_counter;\n";
    c += " maximum.x = src.x;\n";
    c += " }\n";
    c += " if (src.y > maximum.y) {\n";
    c += " indexes.y = index_counter;\n";
    c += " maximum.y = src.y;\n";
    c += " }\n";
    c += " if (src.z > maximum.z) {\n";
    c += " indexes.z = index_counter;\n";
    c += " maximum.z = src.z;\n";
    c += " }\n";
    c += " if (src.w > maximum.w) {\n";
    c += " indexes.w = index_counter;\n";
    c += " maximum.w = src.w;\n";
    c += " }\n";
  } else {
    c += " maximum = max(src, maximum);\n";
  }
  if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
    c += " } // Depth\n";
  }
  c += " };\n";
  c += " }\n";
  c += " }\n";
  const LinkingContext context{"maximum", "X", "Y", "Z"};
  c += PostProcess(linked_operations, context);
  c += " " + dst_tensor.WriteWHDS("maximum", "X", "Y", "Z", "S");
  c += " args.dst_tensor.Write(maximum, " + dst_coord + ");\n";
  if (output_indices) {
    c += " " + indices_tensor.WriteWHDS("indexes", "X", "Y", "Z", "S");
    c += " args.dst_indices.Write(indexes, " + dst_coord + ");\n";
  }
  c += "}\n";

  return c;
}

} // namespace

Pooling::Pooling(const OperationDef& definition,
                 const Pooling2DAttributes& attr)
    : GPUOperation(definition),
      stride_(attr.strides.w, attr.strides.h),
      padding_(-attr.padding.prepended.w, -attr.padding.prepended.h),
      kernel_size_(attr.kernel.w, attr.kernel.h),
      stride_(attr.strides.w, attr.strides.h, 0, 0),
      padding_(-attr.padding.prepended.w, -attr.padding.prepended.h, 0, 0),
      kernel_size_(attr.kernel.w, attr.kernel.h, 0, 0),
      type_(attr.type),
      output_indices_(attr.output_indices) {}

Pooling::Pooling(const OperationDef& definition,
                 const Pooling3DAttributes& attr)
    : GPUOperation(definition),
      stride_(attr.strides.w, attr.strides.h, attr.strides.d, 0),
      padding_(-attr.padding.prepended.w, -attr.padding.prepended.h,
               -attr.padding.prepended.d, 0),
      kernel_size_(attr.kernel.w, attr.kernel.h, attr.kernel.d, 0),
      type_(attr.type),
      output_indices_(attr.output_indices) {}

@@ -419,44 +360,56 @@ absl::Status Pooling::Compile(const CreationContext& creation_context) {
  switch (type_) {
    case PoolingType::AVERAGE:
      code = GetAveragePoolingKernelCode(definition_, stride_correction,
                                         *creation_context.device,
                                         linked_operations_);
                                         *creation_context.device, &args_);
      break;
    case PoolingType::MAX:
      code = GetMaxPoolingKernelCode(definition_, stride_correction,
                                     linked_operations_, output_indices_);
                                     output_indices_, &args_);
      break;
    default:
      return absl::InvalidArgumentError(
          "You should create another kernel with this params");
      break;
  }
  std::string element_wise_code;
  RETURN_IF_ERROR(
      MergeOperations(linked_operations_, &args_, &element_wise_code));
  RETURN_IF_ERROR(args_.TransformToCLCode(creation_context.device->GetInfo(),
                                          {{"dst_tensor", element_wise_code}},
                                          &code));
  return creation_context.cache->GetOrCreateCLKernel(
      code, "main_function", *creation_context.context,
      *creation_context.device, &kernel_);
}

absl::Status Pooling::BindArguments() {
  kernel_.ResetBindingCounter();
  RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
  RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
  RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
  if (output_indices_) {
    RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[1]->GetMemoryPtrForWriting()));
  RETURN_IF_ERROR(args_.SetObjectRef("src_tensor", src_[0]));
  RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", dst_[0]));
  if (definition_.dst_tensors[0].HasAxis(Axis::WIDTH)) {
    RETURN_IF_ERROR(args_.SetInt("stride_x", stride_.x));
    RETURN_IF_ERROR(args_.SetInt("padding_x", padding_.x * src_[0]->Batch()));
    RETURN_IF_ERROR(args_.SetInt("kernel_size_x", kernel_size_.x));
  }
  RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHSB()));
  RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHSB()));
  RETURN_IF_ERROR(kernel_.SetBytesAuto(kernel_size_));
  RETURN_IF_ERROR(
      kernel_.SetBytesAuto(int2(padding_.x * src_[0]->Batch(), padding_.y)));
  RETURN_IF_ERROR(kernel_.SetBytesAuto(stride_));

  return absl::OkStatus();
  if (definition_.dst_tensors[0].HasAxis(Axis::HEIGHT)) {
    RETURN_IF_ERROR(args_.SetInt("stride_y", stride_.y));
    RETURN_IF_ERROR(args_.SetInt("padding_y", padding_.y));
    RETURN_IF_ERROR(args_.SetInt("kernel_size_y", kernel_size_.y));
  }
  if (definition_.dst_tensors[0].HasAxis(Axis::DEPTH)) {
    RETURN_IF_ERROR(args_.SetInt("stride_z", stride_.z));
    RETURN_IF_ERROR(args_.SetInt("padding_z", padding_.z));
    RETURN_IF_ERROR(args_.SetInt("kernel_size_z", kernel_size_.z));
  }
  if (output_indices_) {
    RETURN_IF_ERROR(args_.SetObjectRef("dst_indices", dst_[1]));
  }
  RETURN_IF_ERROR(SetArguments(linked_operations_, &args_));
  return args_.Bind(kernel_.kernel());
}

int3 Pooling::GetGridSize() const {
  const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
  const int grid_y = dst_[0]->Height();
  const int grid_y = dst_[0]->Height() * dst_[0]->Depth();
  const int grid_z = dst_[0]->Slices();
  return int3(grid_x, grid_y, grid_z);
}
@@ -476,107 +429,9 @@ Pooling CreatePooling(const OperationDef& definition,
  return Pooling(definition, attr);
}

Pooling3D::Pooling3D(const OperationDef& definition,
                     const Pooling3DAttributes& attr)
    : GPUOperation(definition),
      stride_(attr.strides.w, attr.strides.h, attr.strides.d),
      padding_(-attr.padding.prepended.w, -attr.padding.prepended.h,
               -attr.padding.prepended.d),
      kernel_size_(attr.kernel.w, attr.kernel.h, attr.kernel.d),
      type_(attr.type),
      output_indices_(attr.output_indices) {}

Pooling3D::Pooling3D(Pooling3D&& kernel)
    : GPUOperation(std::move(kernel)),
      stride_(kernel.stride_),
      padding_(kernel.padding_),
      kernel_size_(kernel.kernel_size_),
      type_(kernel.type_),
      output_indices_(kernel.output_indices_),
      kernel_(std::move(kernel.kernel_)),
      work_group_size_(kernel.work_group_size_) {}

Pooling3D& Pooling3D::operator=(Pooling3D&& kernel) {
  if (this != &kernel) {
    std::swap(stride_, kernel.stride_);
    std::swap(padding_, kernel.padding_);
    std::swap(kernel_size_, kernel.kernel_size_);
    std::swap(type_, kernel.type_);
    std::swap(output_indices_, kernel.output_indices_);
    kernel_ = std::move(kernel.kernel_);
    std::swap(work_group_size_, kernel.work_group_size_);
    GPUOperation::operator=(std::move(kernel));
  }
  return *this;
}

absl::Status Pooling3D::Compile(const CreationContext& creation_context) {
  std::string code;
  const bool stride_correction =
      definition_.IsBatchSupported() && stride_.x != 1;
  switch (type_) {
    case PoolingType::AVERAGE:
      code = GetAveragePooling3DKernelCode(definition_, stride_correction,
                                           *creation_context.device,
                                           linked_operations_);
      break;
    case PoolingType::MAX:
      code = GetMaxPooling3DKernelCode(definition_, stride_correction,
                                       linked_operations_, output_indices_);
      break;
    default:
      return absl::InvalidArgumentError(
          "You should create another kernel with this params");
      break;
  }
  return creation_context.cache->GetOrCreateCLKernel(
      code, "main_function", *creation_context.context,
      *creation_context.device, &kernel_);
}

absl::Status Pooling3D::BindArguments() {
  kernel_.ResetBindingCounter();
  RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
  RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
  RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
  if (output_indices_) {
    RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[1]->GetMemoryPtrForWriting()));
  }
  RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHDS()));
  RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHDS()));
  if (definition_.IsBatchSupported()) {
    RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->Batch()));
  }
  RETURN_IF_ERROR(kernel_.SetBytesAuto(
      int4(kernel_size_.x, kernel_size_.y, kernel_size_.z, 1)));
  RETURN_IF_ERROR(kernel_.SetBytesAuto(
      int4(padding_.x * src_[0]->Batch(), padding_.y, padding_.z, 1)));
  RETURN_IF_ERROR(
      kernel_.SetBytesAuto(int4(stride_.x, stride_.y, stride_.z, 1)));

  return absl::OkStatus();
}

int3 Pooling3D::GetGridSize() const {
  const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
  const int grid_y = dst_[0]->Height();
  const int grid_z = dst_[0]->Slices() * dst_[0]->Depth();
  return int3(grid_x, grid_y, grid_z);
}

absl::Status Pooling3D::Tune(const TuningParameters& params) {
  RETURN_IF_ERROR(BindArguments());
  return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_);
}

absl::Status Pooling3D::AddToQueue(CLCommandQueue* queue) {
  RETURN_IF_ERROR(BindArguments());
  return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
}

Pooling3D CreatePooling3D(const OperationDef& definition,
                          const Pooling3DAttributes& attr) {
  return Pooling3D(definition, attr);
Pooling CreatePooling(const OperationDef& definition,
                      const Pooling3DAttributes& attr) {
  return Pooling(definition, attr);
}

} // namespace cl

@@ -30,6 +30,7 @@ namespace cl {
class Pooling : public GPUOperation {
 public:
  Pooling(const OperationDef& definition, const Pooling2DAttributes& attr);
  Pooling(const OperationDef& definition, const Pooling3DAttributes& attr);
  absl::Status AddToQueue(CLCommandQueue* queue) override;
  absl::Status Tune(const TuningParameters& params) override;

@@ -45,9 +46,9 @@ class Pooling : public GPUOperation {
  absl::Status BindArguments();
  int3 GetGridSize() const;

  int2 stride_;
  int2 padding_;
  int2 kernel_size_;
  int4 stride_;
  int4 padding_;
  int4 kernel_size_;

  PoolingType type_;
  bool output_indices_;
@@ -59,37 +60,8 @@ class Pooling : public GPUOperation {
Pooling CreatePooling(const OperationDef& definition,
                      const Pooling2DAttributes& attr);

class Pooling3D : public GPUOperation {
 public:
  Pooling3D(const OperationDef& definition, const Pooling3DAttributes& attr);
  absl::Status AddToQueue(CLCommandQueue* queue) override;
  absl::Status Tune(const TuningParameters& params) override;

  absl::Status Compile(const CreationContext& creation_context) override;

  // Move only
  Pooling3D(Pooling3D&& kernel);
  Pooling3D& operator=(Pooling3D&& kernel);
  Pooling3D(const Pooling3D&) = delete;
  Pooling3D& operator=(const Pooling3D&) = delete;

 private:
  absl::Status BindArguments();
  int3 GetGridSize() const;

  int3 stride_;
  int3 padding_;
  int3 kernel_size_;

  PoolingType type_;
  bool output_indices_;

  CLKernel kernel_;
  int3 work_group_size_ = int3(8, 4, 1);
};

Pooling3D CreatePooling3D(const OperationDef& definition,
                          const Pooling3DAttributes& attr);
Pooling CreatePooling(const OperationDef& definition,
                      const Pooling3DAttributes& attr);

} // namespace cl
} // namespace gpu

@@ -60,15 +60,6 @@ std::string GetImageModifier(AccessType access) {
  }
}

std::string TextureAddressModeToString(TextureAddressMode address_mode) {
  switch (address_mode) {
    case TextureAddressMode::DONT_CARE:
      return "smp_none";
    case TextureAddressMode::ZERO:
      return "smp_zero";
  }
}

} // namespace

std::string GetCommonDefines(CalculationsPrecision precision) {

@@ -36,11 +36,6 @@ namespace cl {

std::string GetCommonDefines(CalculationsPrecision precision);

enum class TextureAddressMode {
  DONT_CARE,  // translated to CLK_ADDRESS_NONE
  ZERO,       // translated to CLK_ADDRESS_CLAMP
};

struct WHSPoint {
  std::string w_name;
  std::string h_name;

@@ -45,6 +45,15 @@ std::string GetWriteImageFromDataType(DataType data_type) {

} // namespace

std::string TextureAddressModeToString(TextureAddressMode address_mode) {
  switch (address_mode) {
    case TextureAddressMode::DONT_CARE:
      return "smp_none";
    case TextureAddressMode::ZERO:
      return "smp_zero";
  }
}

std::string ToString(TensorStorageType type) {
  switch (type) {
    case TensorStorageType::UNKNOWN:
@@ -271,8 +280,10 @@ std::string TensorDescriptor::Read(DataType read_as_type,
    case TensorStorageType::TEXTURE_3D:
    case TensorStorageType::SINGLE_TEXTURE_2D:
    case TensorStorageType::TEXTURE_ARRAY:
      return absl::StrCat(read_as, "(", image_type, ", smp_none, ",
                          global_address, ")");
      return absl::StrCat(
          read_as, "(", image_type,
          ", " + TextureAddressModeToString(ModeFromState()) + ", ",
          global_address, ")");
    case TensorStorageType::IMAGE_BUFFER:
      return absl::StrCat(read_as, "(image_buffer, ", global_address, ")");
    case TensorStorageType::UNKNOWN:
@@ -500,6 +511,14 @@ bool TensorDescriptor::HasAxis(Axis axis) const {
  return false;
}

void TensorDescriptor::SetTextureAddressMode(TextureAddressMode mode) {
  if (mode == TextureAddressMode::ZERO) {
    state_vars_["TextureMode"] = "ZERO";
  } else {
    state_vars_["TextureMode"] = "DONT_CARE";
  }
}

bool TensorDescriptor::ParseCoordsFromArgs(const std::vector<std::string>& args,
                                           int offset, std::string* xc,
                                           std::string* yc, std::string* zc,
@@ -549,6 +568,19 @@ bool TensorDescriptor::IsBatchedWidth() const {
  return it != state_vars_.end() && it->second == "true";
}

TextureAddressMode TensorDescriptor::ModeFromState() const {
  auto it = state_vars_.find("TextureMode");
  if (it != state_vars_.end()) {
    if (it->second == "ZERO") {
      return TextureAddressMode::ZERO;
    } else {
      return TextureAddressMode::DONT_CARE;
    }
  } else {
    return TextureAddressMode::DONT_CARE;
  }
}

} // namespace cl
} // namespace gpu
} // namespace tflite

@@ -27,6 +27,13 @@ namespace tflite {
namespace gpu {
namespace cl {

enum class TextureAddressMode {
  DONT_CARE,  // translated to CLK_ADDRESS_NONE
  ZERO,       // translated to CLK_ADDRESS_CLAMP
};

std::string TextureAddressModeToString(TextureAddressMode address_mode);

enum class TensorStorageType {
  UNKNOWN,
  BUFFER,
@@ -71,6 +78,7 @@ struct TensorDescriptor : public GPUObjectDescriptor {
  GPUResources GetGPUResources(AccessType access_type) const override;

  bool HasAxis(Axis axis) const;
  void SetTextureAddressMode(TextureAddressMode mode);

  absl::Status GetLinkingContextFromWriteSelector(
      const std::vector<std::string>& args, std::string* value_name,
@@ -106,6 +114,8 @@ struct TensorDescriptor : public GPUObjectDescriptor {

  bool IsBatchedWidth() const;

  TextureAddressMode ModeFromState() const;

  absl::Status GetDataTypeFromTemplateArgs(const std::string& template_arg,
                                           DataType* result) const;