diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.cc b/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.cc index 5bc32597b31..7dcb767d666 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.cc @@ -28,65 +28,79 @@ namespace { std::string GetMaxUnoolingKernelCode( const OperationDef& op_def, const CLDevice& device, const std::vector& linked_operations) { - TensorCodeGenerator src("src_data", "src_size", op_def.src_tensors[0]); - TensorCodeGenerator src_ind("src_data_indices", "src_size", + TensorCodeGenerator src("src_data", + {"src_size.x", "src_size.y", "src_size.z"}, + op_def.src_tensors[0]); + TensorCodeGenerator src_ind("src_data_indices", + {"src_size.x", "src_size.y", "src_size.z"}, op_def.src_tensors[1]); - TensorCodeGenerator dst("dst_data", "dst_size", op_def.dst_tensors[0]); + TensorCodeGenerator dst("dst_data", + {"dst_size.x", "dst_size.y", "dst_size.z"}, + op_def.dst_tensors[0]); const auto address_mode = GetFastestZeroMode(device); - std::string code = GetCommonDefines(op_def.precision); + std::string c = GetCommonDefines(op_def.precision); - code += "__kernel void main_function(\n"; - code += src.GetDeclaration(AccessType::READ) + ",\n"; - code += src_ind.GetDeclaration(AccessType::READ); - code += GetArgsDeclaration(linked_operations); - code += dst.GetDeclaration(AccessType::WRITE) + ",\n"; - code += " int4 src_size, \n"; - code += " int4 dst_size, \n"; - code += " int2 kernel_size, \n"; - code += " int2 padding, \n"; - code += " int2 stride \n"; - code += ") {\n"; - code += " int X = get_global_id(0);\n"; - code += " int Y = get_global_id(1);\n"; - code += " int Z = get_global_id(2);\n"; - code += - " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.w) return; \n"; - code += " int src_x = (X + padding.x) / stride.x;\n"; - code += " int src_y = (Y + padding.y) / stride.y;\n"; - code += " " + src.GetAddress("src_adr", "src_x", "src_y", "Z") + "\n"; - if (op_def.src_tensors[0].storage_type == TensorStorageType::BUFFER) { - code += " bool outside = src_x < 0 || src_y < 0 ||"; - code += " src_x >= src_size.x || src_y >= src_size.y;\n"; - code += " FLT4 src = (FLT4)(0.0f);\n"; - code += " int4 ind = (int4)(0);\n"; - code += " if (!outside) {\n"; - code += " src = " + src.Read("src_adr", TextureAddressMode::DONT_CARE) + - ";\n"; - code += " ind = convert_int4(" + - src_ind.Read("src_adr", TextureAddressMode::DONT_CARE) + ");\n"; - code += " }\n"; + c += "__kernel void main_function(\n"; + c += src.GetDeclaration(AccessType::READ) + ",\n"; + c += src_ind.GetDeclaration(AccessType::READ); + c += GetArgsDeclaration(linked_operations); + c += dst.GetDeclaration(AccessType::WRITE) + ",\n"; + c += " int4 src_size, \n"; + c += " int4 dst_size, \n"; + c += " int2 kernel_size, \n"; + c += " int2 padding, \n"; + c += " int2 stride \n"; + c += ") {\n"; + c += " int X = get_global_id(0);\n"; + c += " int Y = get_global_id(1);\n"; + c += " int Z = get_global_id(2);\n"; + c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return;\n"; + if (op_def.batch_support) { + c += " int B = get_global_id(0) % dst_size.w;\n"; + c += " int X0 = get_global_id(0) / dst_size.w;\n"; + c += " int src_x0 = (X0 + padding.x) / stride.x;\n"; + c += " int src_x = src_x0 * dst_size.w + B;\n"; } else { - code += " FLT4 src = " + src.Read("src_adr", address_mode) + ";\n"; - code += " int4 ind = convert_int4(" + - src_ind.Read("src_adr", address_mode) + ");\n"; + c += " int src_x = (X + padding.x) / stride.x;\n"; } - code += " int t_x = X - (src_x * stride.x - padding.x);\n"; - code += " int t_y = Y - (src_y * stride.y - padding.y);\n"; - code += " int t_index = t_y * kernel_size.x + t_x;\n"; - code += " FLT4 result;\n"; + c += " int src_y = (Y + padding.y) / stride.y;\n"; + c += " " + src.GetAddress("src_adr", "src_x", "src_y", "Z") + "\n"; + if (op_def.src_tensors[0].storage_type == TensorStorageType::BUFFER) { + c += " bool outside = src_x < 0 || src_y < 0 ||"; + c += " src_x >= src_size.x || src_y >= src_size.y;\n"; + c += " FLT4 src = (FLT4)(0.0f);\n"; + c += " int4 ind = (int4)(0);\n"; + c += " if (!outside) {\n"; + c += " src = " + src.Read("src_adr", TextureAddressMode::DONT_CARE) + + ";\n"; + c += " ind = convert_int4(" + + src_ind.Read("src_adr", TextureAddressMode::DONT_CARE) + ");\n"; + c += " }\n"; + } else { + c += " FLT4 src = " + src.Read("src_adr", address_mode) + ";\n"; + c += " int4 ind = convert_int4(" + src_ind.Read("src_adr", address_mode) + + ");\n"; + } + if (op_def.batch_support) { + c += " int t_x = X0 - (src_x0 * stride.x - padding.x);\n"; + } else { + c += " int t_x = X - (src_x * stride.x - padding.x);\n"; + } + c += " int t_y = Y - (src_y * stride.y - padding.y);\n"; + c += " int t_index = t_y * kernel_size.x + t_x;\n"; + c += " FLT4 result;\n"; const std::string channels[] = {".x", ".y", ".z", ".w"}; for (int i = 0; i < 4; ++i) { const auto& s = channels[i]; - code += " result" + s + "= t_index == ind" + s + "? src" + s + ": 0.0f;\n"; + c += " result" + s + "= t_index == ind" + s + "? src" + s + ": 0.0f;\n"; } - const LinkingContext context{"result", "X", "Y", "Z"}; - code += PostProcess(linked_operations, context); - code += " " + dst.Write3D("result", "X", "Y", "Z"); - code += "}\n"; + c += PostProcess(linked_operations, {"result", "X", "Y", "Z"}); + c += " " + dst.Write3D("result", "X", "Y", "Z"); + c += "}\n"; - return code; + return c; } } // namespace @@ -131,8 +145,8 @@ Status MaxUnpooling::BindArguments() { RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[1]->GetMemoryPtr())); RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting())); - RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetSizeWithDepth())); - RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetSizeWithDepth())); + RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHDB())); + RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHDB())); RETURN_IF_ERROR(kernel_.SetBytesAuto(kernel_size_)); RETURN_IF_ERROR(kernel_.SetBytesAuto(padding_)); RETURN_IF_ERROR(kernel_.SetBytesAuto(stride_)); @@ -141,7 +155,7 @@ Status MaxUnpooling::BindArguments() { } int3 MaxUnpooling::GetGridSize() const { - const int grid_x = dst_[0]->Width(); + const int grid_x = dst_[0]->Width() * dst_[0]->Batch(); const int grid_y = dst_[0]->Height(); const int grid_z = dst_[0]->Depth(); return int3(grid_x, grid_y, grid_z);