Added support for different layouts for src/dst. For example, src can be HWC while dst is BHWC, or vice versa.

PiperOrigin-RevId: 316998239
Change-Id: I89b07923020f185c356bb0b63926bbe81be55cb5
This commit is contained in:
Raman Sarokin 2020-06-17 17:22:03 -07:00 committed by TensorFlower Gardener
parent 73cf8263c7
commit b1933d67e5
2 changed files with 10 additions and 9 deletions

View File

@ -42,11 +42,12 @@ std::string GetStridedSliceCode(const OperationDef& op_def, bool alignedx4,
args->AddInt("stride_z");
args->AddInt("stride_b");
const std::string dst_batch = op_def.IsBatchSupported() ? "B" : "";
const std::string batch_id =
op_def.dst_tensors[0].HasAxis(Axis::BATCH) ? "B" : "0";
std::string c = GetCommonDefines(op_def.precision);
c += "__kernel void main_function(\n";
c += "$0) {\n";
if (op_def.IsBatchSupported()) {
if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
c += " int linear_id = get_global_id(0);\n";
c += " int X = linear_id / args.dst_tensor.Batch();\n";
c += " int B = linear_id % args.dst_tensor.Batch();\n";
@ -62,11 +63,10 @@ std::string GetStridedSliceCode(const OperationDef& op_def, bool alignedx4,
c += " } \n";
c += " int s_x = X * args.stride_x + args.offset_x;\n";
c += " int s_y = Y * args.stride_y + args.offset_y;\n";
if (op_def.IsBatchSupported()) {
c += " int s_b = B * args.stride_b + args.offset_b;\n";
if (op_def.src_tensors[0].HasAxis(Axis::BATCH)) {
c += " int s_b = " + batch_id + " * args.stride_b + args.offset_b;\n";
c += " args.src_tensor.SetBatchRef(s_b);\n";
}
const std::string src_batch = op_def.IsBatchSupported() ? "s_b" : "";
if (alignedx4) {
c += " int s_z = Z + args.offset_z;\n";
c += " FLT4 result = args.src_tensor.Read(s_x, s_y, s_z);\n";

View File

@ -36,11 +36,12 @@ std::string GetTransposeCode(
"dst_tensor", AccessType::WRITE,
absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]));
const std::string batch_id = op_def.IsBatchSupported() ? "B" : "";
const std::string batch_id =
op_def.dst_tensors[0].HasAxis(Axis::BATCH) ? "B" : "0";
std::string c = GetCommonDefines(op_def.precision);
c += "__kernel void main_function(\n";
c += "$0) {\n";
if (op_def.IsBatchSupported()) {
if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
c += " int linear_id = get_global_id(0);\n";
c += " int X = linear_id / args.dst_tensor.Batch();\n";
c += " int B = linear_id % args.dst_tensor.Batch();\n";
@ -65,7 +66,7 @@ std::string GetTransposeCode(
remap[attr.perm.w] = 2;
remap[attr.perm.c] = 3;
if (attr.perm.c == 3) { // optimized reading when no channels permutation
const std::string bhw[] = {"B", "Y", "X"};
const std::string bhw[] = {batch_id, "Y", "X"};
if (op_def.src_tensors[0].HasAxis(Axis::BATCH)) {
c += " args.src_tensor.SetBatchRef(" + bhw[remap[0]] + ");\n";
}
@ -80,7 +81,7 @@ std::string GetTransposeCode(
c += " for (int i = 0; i < 4; ++i) {\n";
c += " int dst_channel = Z * 4 + i;\n";
c += " if (dst_channel < args.dst_tensor.Channels()) {\n";
const std::string bhwc[] = {"B", "Y", "X", "dst_channel"};
const std::string bhwc[] = {batch_id, "Y", "X", "dst_channel"};
if (op_def.src_tensors[0].HasAxis(Axis::BATCH)) {
c += " args.src_tensor.SetBatchRef(" + bhwc[remap[0]] + ");\n";
}