diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.cc b/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.cc index 34227f6b887..439b7d0fc15 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.cc @@ -27,28 +27,33 @@ namespace gpu { namespace cl { namespace { -std::string GetSpaceToDepthCode( - const OperationDef& op_def, - const std::vector& linked_operations) { - TensorCodeGenerator src_tensor( - "src_data", WHSPoint{"src_size.x", "src_size.y", "src_size.z"}, - op_def.src_tensors[0]); - TensorCodeGenerator dst_tensor( - "dst_data", WHSPoint{"dst_size.x", "dst_size.y", "dst_size.z"}, - op_def.dst_tensors[0]); +std::string GetSpaceToDepthCode(const OperationDef& op_def, Arguments* args) { + args->AddObjectRef( + "src_tensor", AccessType::READ, + absl::make_unique(op_def.src_tensors[0])); + args->AddObjectRef( + "dst_tensor", AccessType::WRITE, + absl::make_unique(op_def.dst_tensors[0])); + args->AddInt("block_size"); + std::string c = GetCommonDefines(op_def.precision); c += "__kernel void main_function(\n"; - c += src_tensor.GetDeclaration(AccessType::READ); - c += GetArgsDeclaration(linked_operations); - c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n"; - c += " int4 src_size,\n"; - c += " int4 dst_size,\n"; - c += " int src_channels,\n"; - c += " int block_size) {\n"; - c += " int X = get_global_id(0);\n"; + c += "$0) {\n"; + if (op_def.IsBatchSupported()) { + c += " int linear_id = get_global_id(0);\n"; + c += " int X = linear_id / args.dst_tensor.Batch();\n"; + c += " int B = linear_id % args.dst_tensor.Batch();\n"; + c += " args.dst_tensor.SetBatchRef(B);\n"; + c += " args.src_tensor.SetBatchRef(B);\n"; + } else { + c += " int X = get_global_id(0);\n"; + } c += " int Y = get_global_id(1);\n"; c += " int Z = get_global_id(2);\n"; - c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return;\n"; + c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || " + "Z >= args.dst_tensor.Slices()) { \n"; + c += " return; \n"; + c += " } \n"; c += " FLT tmp[4];\n"; c += " tmp[0] = (FLT)(0.0f);\n"; c += " tmp[1] = (FLT)(0.0f);\n"; @@ -56,19 +61,17 @@ std::string GetSpaceToDepthCode( c += " tmp[3] = (FLT)(0.0f);\n"; c += " for (int i = 0; i < 4; ++i) {\n"; c += " int dst_c = 4 * Z + i;\n"; - c += " int block_id = dst_c / src_channels;\n"; - c += " int src_x = X * block_size + block_id % block_size;\n"; - c += " int src_y = Y * block_size + block_id / block_size;\n"; - c += " int src_c = dst_c % src_channels;\n"; + c += " int block_id = dst_c / args.src_tensor.Channels();\n"; + c += " int src_x = X * args.block_size + block_id % args.block_size;\n"; + c += " int src_y = Y * args.block_size + block_id / args.block_size;\n"; + c += " int src_c = dst_c % args.src_tensor.Channels();\n"; c += " int src_z = src_c / 4;\n"; - c += " FLT4 t = " + src_tensor.ReadWHS("src_x", "src_y", "src_z") + ";\n"; + c += " FLT4 t = args.src_tensor.Read(src_x, src_y, src_z);\n"; c += " FLT t_ar[4] = {t.x, t.y, t.z, t.w};\n"; c += " tmp[i] = t_ar[src_c % 4];\n"; c += " }\n"; c += " FLT4 result = (FLT4)(tmp[0], tmp[1], tmp[2], tmp[3]);\n"; - const LinkingContext context{"result", "X", "Y", "Z"}; - c += PostProcess(linked_operations, context); - c += " " + dst_tensor.WriteWHS("result", "X", "Y", "Z"); + c += " args.dst_tensor.Write(result, X, Y, Z);\n"; c += "}\n"; return c; } @@ -92,21 +95,24 @@ SpaceToDepth& SpaceToDepth::operator=(SpaceToDepth&& operation) { } absl::Status SpaceToDepth::Compile(const CreationContext& creation_context) { - const auto code = GetSpaceToDepthCode(definition_, linked_operations_); + std::string code = GetSpaceToDepthCode(definition_, &args_); + std::string element_wise_code; + RETURN_IF_ERROR( + MergeOperations(linked_operations_, &args_, &element_wise_code)); + RETURN_IF_ERROR(args_.TransformToCLCode(creation_context.device->GetInfo(), + {{"dst_tensor", element_wise_code}}, + &code)); return creation_context.cache->GetOrCreateCLKernel( code, "main_function", *creation_context.context, *creation_context.device, &kernel_); } absl::Status SpaceToDepth::BindArguments() { - kernel_.ResetBindingCounter(); - RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); - RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); - RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting())); - RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB())); - RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB())); - RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->Channels())); - return kernel_.SetBytesAuto(attr_.block_size); + RETURN_IF_ERROR(args_.SetObjectRef("src_tensor", src_[0])); + RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", dst_[0])); + RETURN_IF_ERROR(args_.SetInt("block_size", attr_.block_size)); + RETURN_IF_ERROR(SetArguments(linked_operations_, &args_)); + return args_.Bind(kernel_.kernel()); } int3 SpaceToDepth::GetGridSize() const {