Merged Reshape code generation for HWC/BHWC layouts.

Added support for different src/dst layouts, e.g. src HWC and dst BHWC, or vice versa.

PiperOrigin-RevId: 316999310
Change-Id: I20bd9a12afba8bdcb832565f09350440349041bd
Raman Sarokin 2020-06-17 17:29:11 -07:00 committed by TensorFlower Gardener
parent 274a0f944e
commit 70e2387ecc
2 changed files with 47 additions and 101 deletions
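
Both merged kernels fold the batch dimension into the first axis of the work grid: the kernel reads linear_id from get_global_id(0) and recovers X = linear_id / Batch() and B = linear_id % Batch(). A minimal host-side sketch of that linearization, assuming the grid's X extent is Width() * Batch() (the struct and helper names below are illustrative, not the actual TFLite GPU API):

#include <cstdint>

struct GridSize {
  int64_t x, y, z;
};

// Hypothetical helper: dst has shape BxHxWxC, with channels packed into
// FLT4 slices. X and B share the first grid dimension, so
// linear_id = X * batch + B.
GridSize ReshapeGridSize(int batch, int height, int width, int channels) {
  const int slices = (channels + 3) / 4;  // channel groups of 4, rounded up
  return {static_cast<int64_t>(width) * batch, height, slices};
}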

View File

@@ -25,56 +25,6 @@ namespace gpu {
namespace cl {
namespace {
std::string GetReshapeBatchedCode(const OperationDef& op_def, Arguments* args) {
args->AddObjectRef(
"src_tensor", AccessType::READ,
absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]));
args->AddObjectRef(
"dst_tensor", AccessType::WRITE,
absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]));
std::string c = GetCommonDefines(op_def.precision);
c += "__kernel void main_function(\n";
c += "$0) {\n";
c += " int linear_id = get_global_id(0);\n";
c += " int X = linear_id / args.dst_tensor.Batch();\n";
c += " int B = linear_id % args.dst_tensor.Batch();\n";
c += " int Y = get_global_id(1);\n";
c += " int Z = get_global_id(2);\n";
c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
"Z >= args.dst_tensor.Slices()) { \n";
c += " return; \n";
c += " } \n";
c += " FLT temps[4];\n";
c += " temps[0] = (FLT)(0.0f);\n";
c += " temps[1] = (FLT)(0.0f);\n";
c += " temps[2] = (FLT)(0.0f);\n";
c += " temps[3] = (FLT)(0.0f);\n";
c += " int base = ((B * args.dst_tensor.Height() + Y) * "
"args.dst_tensor.Width() + X) * args.dst_tensor.Channels() + Z * 4;\n";
c += " for (int i = 0; i < 4; ++i) {\n";
c += " int dst_channel = Z * 4 + i;\n";
c += " if (dst_channel < args.dst_tensor.Channels()) {;\n";
c += " int p = base + i;\n";
c += " int src_c = p % args.src_tensor.Channels();\n";
c += " p = p / args.src_tensor.Channels();\n";
c += " int src_x = p % args.src_tensor.Width();\n";
c += " p = p / args.src_tensor.Width();\n";
c += " int src_y = p % args.src_tensor.Height();\n";
c += " int src_b = p / args.src_tensor.Height();\n";
c += " int src_z = src_c / 4;\n";
c += " int src_sub_ch = src_c % 4;\n";
c += " FLT4 t = args.src_tensor.Read(src_x, src_y, src_z, src_b);\n";
c += " FLT t_ar[4] = {t.x, t.y, t.z, t.w};\n";
c += " temps[i] = t_ar[src_sub_ch];\n";
c += " }\n";
c += " }\n";
c += " FLT4 result = (FLT4)(temps[0], temps[1], temps[2], temps[3]);\n";
c += " args.dst_tensor.Write(result, X, Y, Z, B);\n";
c += "}\n";
return c;
}
std::string GetReshapeCode(const OperationDef& op_def, Arguments* args) {
args->AddObjectRef(
"src_tensor", AccessType::READ,
@@ -86,7 +36,14 @@ std::string GetReshapeCode(const OperationDef& op_def, Arguments* args) {
std::string c = GetCommonDefines(op_def.precision);
c += "__kernel void main_function(\n";
c += "$0) {\n";
c += " int X = get_global_id(0);\n";
if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
c += " int linear_id = get_global_id(0);\n";
c += " int X = linear_id / args.dst_tensor.Batch();\n";
c += " int B = linear_id % args.dst_tensor.Batch();\n";
c += " args.dst_tensor.SetBatchRef(B);\n";
} else {
c += " int X = get_global_id(0);\n";
}
c += " int Y = get_global_id(1);\n";
c += " int Z = get_global_id(2);\n";
c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
@@ -98,8 +55,13 @@ std::string GetReshapeCode(const OperationDef& op_def, Arguments* args) {
c += " temps[1] = (FLT)(0.0f);\n";
c += " temps[2] = (FLT)(0.0f);\n";
c += " temps[3] = (FLT)(0.0f);\n";
c += " int base = (Y * args.dst_tensor.Width() + X) * "
"args.dst_tensor.Channels() + Z * 4;\n";
if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
c += " int base = B;\n";
} else {
c += " int base = 0;\n";
}
c += " base = ((base * args.dst_tensor.Height() + Y) * "
"args.dst_tensor.Width() + X) * args.dst_tensor.Channels() + Z * 4;\n";
c += " for (int i = 0; i < 4; ++i) {\n";
c += " int dst_channel = Z * 4 + i;\n";
c += " if (dst_channel < args.dst_tensor.Channels()) {;\n";
@@ -107,7 +69,12 @@ std::string GetReshapeCode(const OperationDef& op_def, Arguments* args) {
c += " int src_c = p % args.src_tensor.Channels();\n";
c += " p = p / args.src_tensor.Channels();\n";
c += " int src_x = p % args.src_tensor.Width();\n";
c += " int src_y = p / args.src_tensor.Width();\n";
c += " p = p / args.src_tensor.Width();\n";
c += " int src_y = p % args.src_tensor.Height();\n";
if (op_def.src_tensors[0].HasAxis(Axis::BATCH)) {
c += " int src_b = p / args.src_tensor.Height();\n";
c += " args.src_tensor.SetBatchRef(src_b);\n";
}
c += " int src_z = src_c / 4;\n";
c += " int src_sub_ch = src_c % 4;\n";
c += " FLT4 t = args.src_tensor.Read(src_x, src_y, src_z);\n";
@@ -137,9 +104,7 @@ Reshape& Reshape::operator=(Reshape&& operation) {
}
absl::Status Reshape::Compile(const CreationContext& creation_context) {
std::string code = definition_.IsBatchSupported()
? GetReshapeBatchedCode(definition_, &args_)
: GetReshapeCode(definition_, &args_);
std::string code = GetReshapeCode(definition_, &args_);
std::string element_wise_code;
RETURN_IF_ERROR(
MergeOperations(linked_operations_, &args_, &element_wise_code));
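
For reference, the per-element remapping that the merged Reshape kernel above generates can be written on the host as a plain nested loop. This is only an illustrative sketch with made-up parameter names that mirrors the kernel's index math, not code from this change:

#include <vector>

// Sketch: element-wise reshape between two BHWC buffers. The flat index of a
// destination element is decomposed into source coordinates exactly as the
// generated kernel does with 'base'.
void ReshapeBHWC(const std::vector<float>& src, int src_h, int src_w,
                 int src_c, std::vector<float>* dst, int dst_b, int dst_h,
                 int dst_w, int dst_c) {
  for (int b = 0; b < dst_b; ++b) {
    for (int y = 0; y < dst_h; ++y) {
      for (int x = 0; x < dst_w; ++x) {
        for (int c = 0; c < dst_c; ++c) {
          int dst_index = ((b * dst_h + y) * dst_w + x) * dst_c + c;
          int p = dst_index;
          int sc = p % src_c; p /= src_c;
          int sx = p % src_w; p /= src_w;
          int sy = p % src_h;
          int sb = p / src_h;
          (*dst)[dst_index] =
              src[((sb * src_h + sy) * src_w + sx) * src_c + sc];
        }
      }
    }
  }
}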

View File

@@ -25,40 +25,6 @@ namespace gpu {
namespace cl {
namespace {
std::string GetReshapeBatchedCode(const OperationDef& op_def, Arguments* args) {
args->AddObjectRef(
"src_tensor", AccessType::READ,
absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]));
args->AddObjectRef(
"dst_tensor", AccessType::WRITE,
absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]));
std::string c = GetCommonDefines(op_def.precision);
c += "__kernel void main_function(\n";
c += "$0) {\n";
c += " int linear_id = get_global_id(0);\n";
c += " int X = linear_id / args.dst_tensor.Batch();\n";
c += " int B = linear_id % args.dst_tensor.Batch();\n";
c += " int Y = get_global_id(1);\n";
c += " int Z = get_global_id(2);\n";
c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
"Z >= args.dst_tensor.Slices()) { \n";
c += " return; \n";
c += " } \n";
c += " int dst_bhwc4 = ((B * args.dst_tensor.Height() + Y) * "
"args.dst_tensor.Width() + X) * args.dst_tensor.Slices() + Z;\n";
c += " int src_z = dst_bhwc4 % args.src_tensor.Slices();\n";
c += " dst_bhwc4 = dst_bhwc4 / args.src_tensor.Slices();\n";
c += " int src_x = dst_bhwc4 % args.src_tensor.Width();\n";
c += " dst_bhwc4 = dst_bhwc4 / args.src_tensor.Width();\n";
c += " int src_y = dst_bhwc4 % args.src_tensor.Height();\n";
c += " int src_b = dst_bhwc4 / args.src_tensor.Height();\n";
c += " FLT4 result = args.src_tensor.Read(src_x, src_y, src_z, src_b);\n";
c += " args.dst_tensor.Write(result, X, Y, Z, B);\n";
c += "}\n";
return c;
}
std::string GetReshapeCode(const OperationDef& op_def, Arguments* args) {
args->AddObjectRef(
"src_tensor", AccessType::READ,
@@ -70,19 +36,36 @@ std::string GetReshapeCode(const OperationDef& op_def, Arguments* args) {
std::string c = GetCommonDefines(op_def.precision);
c += "__kernel void main_function(\n";
c += "$0) {\n";
c += " int X = get_global_id(0);\n";
if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
c += " int linear_id = get_global_id(0);\n";
c += " int X = linear_id / args.dst_tensor.Batch();\n";
c += " int B = linear_id % args.dst_tensor.Batch();\n";
c += " args.dst_tensor.SetBatchRef(B);\n";
} else {
c += " int X = get_global_id(0);\n";
}
c += " int Y = get_global_id(1);\n";
c += " int Z = get_global_id(2);\n";
c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
"Z >= args.dst_tensor.Slices()) { \n";
c += " return; \n";
c += " } \n";
c += " int dst_hwc4 = (Y * args.dst_tensor.Width() + X) * "
"args.dst_tensor.Slices() + Z;\n";
c += " int src_z = dst_hwc4 % args.src_tensor.Slices();\n";
c += " dst_hwc4 = dst_hwc4 / args.src_tensor.Slices();\n";
c += " int src_x = dst_hwc4 % args.src_tensor.Width();\n";
c += " int src_y = dst_hwc4 / args.src_tensor.Width();\n";
if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
c += " int dst_bhwc4 = B;\n";
} else {
c += " int dst_bhwc4 = 0;\n";
}
c += " dst_bhwc4 = ((dst_bhwc4 * args.dst_tensor.Height() + Y) * "
"args.dst_tensor.Width() + X) * args.dst_tensor.Slices() + Z;\n";
c += " int src_z = dst_bhwc4 % args.src_tensor.Slices();\n";
c += " dst_bhwc4 = dst_bhwc4 / args.src_tensor.Slices();\n";
c += " int src_x = dst_bhwc4 % args.src_tensor.Width();\n";
c += " dst_bhwc4 = dst_bhwc4 / args.src_tensor.Width();\n";
c += " int src_y = dst_bhwc4 % args.src_tensor.Height();\n";
if (op_def.src_tensors[0].HasAxis(Axis::BATCH)) {
c += " int src_b = dst_bhwc4 / args.src_tensor.Height();\n";
c += " args.src_tensor.SetBatchRef(src_b);\n";
}
c += " FLT4 result = args.src_tensor.Read(src_x, src_y, src_z);\n";
c += " args.dst_tensor.Write(result, X, Y, Z);\n";
c += "}\n";
@@ -105,9 +88,7 @@ Reshapex4& Reshapex4::operator=(Reshapex4&& operation) {
}
absl::Status Reshapex4::Compile(const CreationContext& creation_context) {
std::string code = definition_.IsBatchSupported()
? GetReshapeBatchedCode(definition_, &args_)
: GetReshapeCode(definition_, &args_);
std::string code = GetReshapeCode(definition_, &args_);
std::string element_wise_code;
RETURN_IF_ERROR(
MergeOperations(linked_operations_, &args_, &element_wise_code));
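
The Reshapex4 kernel above applies the same remapping at FLT4-slice granularity, which is only exact when both src and dst channel counts are multiples of four. A hedged host-side sketch of its slice-index mapping (the struct and function names are illustrative, not part of the change):

// Sketch: map a destination slice coordinate (B, Y, X, Z) to the source slice
// it copies from, mirroring the dst_bhwc4 decomposition in the generated code.
struct SliceCoord {
  int b, y, x, z;
};

SliceCoord SourceSlice(int B, int Y, int X, int Z, int dst_h, int dst_w,
                       int dst_slices, int src_h, int src_w, int src_slices) {
  int p = ((B * dst_h + Y) * dst_w + X) * dst_slices + Z;
  SliceCoord s;
  s.z = p % src_slices; p /= src_slices;
  s.x = p % src_w;      p /= src_w;
  s.y = p % src_h;
  s.b = p / src_h;
  return s;
}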