diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.cc b/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.cc index 2cf65f24447..d0c4e432f3a 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.cc @@ -42,11 +42,12 @@ std::string GetStridedSliceCode(const OperationDef& op_def, bool alignedx4, args->AddInt("stride_z"); args->AddInt("stride_b"); - const std::string dst_batch = op_def.IsBatchSupported() ? "B" : ""; + const std::string batch_id = + op_def.dst_tensors[0].HasAxis(Axis::BATCH) ? "B" : "0"; std::string c = GetCommonDefines(op_def.precision); c += "__kernel void main_function(\n"; c += "$0) {\n"; - if (op_def.IsBatchSupported()) { + if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) { c += " int linear_id = get_global_id(0);\n"; c += " int X = linear_id / args.dst_tensor.Batch();\n"; c += " int B = linear_id % args.dst_tensor.Batch();\n"; @@ -62,11 +63,10 @@ std::string GetStridedSliceCode(const OperationDef& op_def, bool alignedx4, c += " } \n"; c += " int s_x = X * args.stride_x + args.offset_x;\n"; c += " int s_y = Y * args.stride_y + args.offset_y;\n"; - if (op_def.IsBatchSupported()) { - c += " int s_b = B * args.stride_b + args.offset_b;\n"; + if (op_def.src_tensors[0].HasAxis(Axis::BATCH)) { + c += " int s_b = " + batch_id + " * args.stride_b + args.offset_b;\n"; c += " args.src_tensor.SetBatchRef(s_b);\n"; } - const std::string src_batch = op_def.IsBatchSupported() ? "s_b" : ""; if (alignedx4) { c += " int s_z = Z + args.offset_z;\n"; c += " FLT4 result = args.src_tensor.Read(s_x, s_y, s_z);\n"; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/transpose.cc b/tensorflow/lite/delegates/gpu/cl/kernels/transpose.cc index e12c44566b7..cacfd52542d 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/transpose.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/transpose.cc @@ -36,11 +36,12 @@ std::string GetTransposeCode( "dst_tensor", AccessType::WRITE, absl::make_unique(op_def.dst_tensors[0])); - const std::string batch_id = op_def.IsBatchSupported() ? "B" : ""; + const std::string batch_id = + op_def.dst_tensors[0].HasAxis(Axis::BATCH) ? "B" : "0"; std::string c = GetCommonDefines(op_def.precision); c += "__kernel void main_function(\n"; c += "$0) {\n"; - if (op_def.IsBatchSupported()) { + if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) { c += " int linear_id = get_global_id(0);\n"; c += " int X = linear_id / args.dst_tensor.Batch();\n"; c += " int B = linear_id % args.dst_tensor.Batch();\n"; @@ -65,7 +66,7 @@ std::string GetTransposeCode( remap[attr.perm.w] = 2; remap[attr.perm.c] = 3; if (attr.perm.c == 3) { // optimized reading when no channels permutation - const std::string bhw[] = {"B", "Y", "X"}; + const std::string bhw[] = {batch_id, "Y", "X"}; if (op_def.src_tensors[0].HasAxis(Axis::BATCH)) { c += " args.src_tensor.SetBatchRef(" + bhw[remap[0]] + ");\n"; } @@ -80,7 +81,7 @@ std::string GetTransposeCode( c += " for (int i = 0; i < 4; ++i) {\n"; c += " int dst_channel = Z * 4 + i;\n"; c += " if (dst_channel < args.dst_tensor.Channels()) {\n"; - const std::string bhwc[] = {"B", "Y", "X", "dst_channel"}; + const std::string bhwc[] = {batch_id, "Y", "X", "dst_channel"}; if (op_def.src_tensors[0].HasAxis(Axis::BATCH)) { c += " args.src_tensor.SetBatchRef(" + bhwc[remap[0]] + ");\n"; }