Added support of different layouts for src/dst. For example src - HWC, dst - BHWC, or vice versa.
PiperOrigin-RevId: 316998239 Change-Id: I89b07923020f185c356bb0b63926bbe81be55cb5
This commit is contained in:
parent
73cf8263c7
commit
b1933d67e5
@ -42,11 +42,12 @@ std::string GetStridedSliceCode(const OperationDef& op_def, bool alignedx4,
|
||||
args->AddInt("stride_z");
|
||||
args->AddInt("stride_b");
|
||||
|
||||
const std::string dst_batch = op_def.IsBatchSupported() ? "B" : "";
|
||||
const std::string batch_id =
|
||||
op_def.dst_tensors[0].HasAxis(Axis::BATCH) ? "B" : "0";
|
||||
std::string c = GetCommonDefines(op_def.precision);
|
||||
c += "__kernel void main_function(\n";
|
||||
c += "$0) {\n";
|
||||
if (op_def.IsBatchSupported()) {
|
||||
if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
|
||||
c += " int linear_id = get_global_id(0);\n";
|
||||
c += " int X = linear_id / args.dst_tensor.Batch();\n";
|
||||
c += " int B = linear_id % args.dst_tensor.Batch();\n";
|
||||
@ -62,11 +63,10 @@ std::string GetStridedSliceCode(const OperationDef& op_def, bool alignedx4,
|
||||
c += " } \n";
|
||||
c += " int s_x = X * args.stride_x + args.offset_x;\n";
|
||||
c += " int s_y = Y * args.stride_y + args.offset_y;\n";
|
||||
if (op_def.IsBatchSupported()) {
|
||||
c += " int s_b = B * args.stride_b + args.offset_b;\n";
|
||||
if (op_def.src_tensors[0].HasAxis(Axis::BATCH)) {
|
||||
c += " int s_b = " + batch_id + " * args.stride_b + args.offset_b;\n";
|
||||
c += " args.src_tensor.SetBatchRef(s_b);\n";
|
||||
}
|
||||
const std::string src_batch = op_def.IsBatchSupported() ? "s_b" : "";
|
||||
if (alignedx4) {
|
||||
c += " int s_z = Z + args.offset_z;\n";
|
||||
c += " FLT4 result = args.src_tensor.Read(s_x, s_y, s_z);\n";
|
||||
|
@ -36,11 +36,12 @@ std::string GetTransposeCode(
|
||||
"dst_tensor", AccessType::WRITE,
|
||||
absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]));
|
||||
|
||||
const std::string batch_id = op_def.IsBatchSupported() ? "B" : "";
|
||||
const std::string batch_id =
|
||||
op_def.dst_tensors[0].HasAxis(Axis::BATCH) ? "B" : "0";
|
||||
std::string c = GetCommonDefines(op_def.precision);
|
||||
c += "__kernel void main_function(\n";
|
||||
c += "$0) {\n";
|
||||
if (op_def.IsBatchSupported()) {
|
||||
if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
|
||||
c += " int linear_id = get_global_id(0);\n";
|
||||
c += " int X = linear_id / args.dst_tensor.Batch();\n";
|
||||
c += " int B = linear_id % args.dst_tensor.Batch();\n";
|
||||
@ -65,7 +66,7 @@ std::string GetTransposeCode(
|
||||
remap[attr.perm.w] = 2;
|
||||
remap[attr.perm.c] = 3;
|
||||
if (attr.perm.c == 3) { // optimized reading when no channels permutation
|
||||
const std::string bhw[] = {"B", "Y", "X"};
|
||||
const std::string bhw[] = {batch_id, "Y", "X"};
|
||||
if (op_def.src_tensors[0].HasAxis(Axis::BATCH)) {
|
||||
c += " args.src_tensor.SetBatchRef(" + bhw[remap[0]] + ");\n";
|
||||
}
|
||||
@ -80,7 +81,7 @@ std::string GetTransposeCode(
|
||||
c += " for (int i = 0; i < 4; ++i) {\n";
|
||||
c += " int dst_channel = Z * 4 + i;\n";
|
||||
c += " if (dst_channel < args.dst_tensor.Channels()) {\n";
|
||||
const std::string bhwc[] = {"B", "Y", "X", "dst_channel"};
|
||||
const std::string bhwc[] = {batch_id, "Y", "X", "dst_channel"};
|
||||
if (op_def.src_tensors[0].HasAxis(Axis::BATCH)) {
|
||||
c += " args.src_tensor.SetBatchRef(" + bhwc[remap[0]] + ");\n";
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user