// Patch summary (review annotation; this preamble precedes the `diff --git`
// header and is not part of either hunk).
//
// Target file: tensorflow/lite/delegates/gpu/cl/kernels/padding.cc
//
// Hunk 1 (@@ -28,54 +28,57 @@) -- GetPaddingCode(), which assembles the
// OpenCL kernel source for the padding operation as a std::string:
//   * Both TensorCodeGenerator objects are now constructed with three
//     per-component size expressions ({"src_size.x", "src_size.y",
//     "src_size.z"} and the dst equivalents) instead of one size name.
//   * The local accumulator `code` is renamed to `c`; every `code += ...`
//     line is replaced by the matching `c += ...` line.
//   * The generated kernel gains an `int src_channels` argument, declared
//     between `int4 src_size` and `int4 dst_size`.
//   * The generated early-return guard compares Z against `dst_size.z`
//     instead of `dst_size.w`.
//   * The generated per-channel guard compares `s_z` against `src_channels`
//     instead of `src_size.z`; the read still uses slice index `s_z / 4`
//     and selects lane `s_z % 4`.
//   * The named `LinkingContext context` local is dropped; the braced list
//     {"result", "X", "Y", "Z"} is passed to PostProcess() directly.
//
// Hunk 2 (@@ -119,14 +122,17 @@) -- Padding::BindArguments() and
// Padding::GetGridSize():
//   * BindArguments() now passes GetWBatchedHDB() for src and dst instead
//     of GetSizeWithDepth(), adds src_[0]->Channels() as a new argument, and
//     passes `prepended_` with its x component multiplied by
//     src_[0]->Batch().
//   * The SetBytesAuto() call order (src size, src channels, dst size,
//     prepended) matches the order of the scalar arguments declared in the
//     generated kernel signature in hunk 1; the two must stay in sync.
//   * GetGridSize() now uses dst Width() * Batch() for grid_x.
//   * Presumably GetWBatchedHDB() folds the batch into the width component,
//     which would explain the Batch() scaling of prepended_.x and grid_x --
//     TODO confirm against the Tensor class; it is not shown in this patch.
//
// NOTE(review): this copy of the patch has had its line breaks collapsed, so
// it will not apply as-is. Runs of whitespace inside the string literals also
// look collapsed, and `const std::vector& linked_operations` appears to have
// lost its template argument during extraction -- verify against the
// repository before relying on exact text.
// NOTE(review): the last hunk ends at `return int3(grid_x, grid_y, grid_z);`
// -- the closing brace of GetGridSize() lies outside the diff context.
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/padding.cc b/tensorflow/lite/delegates/gpu/cl/kernels/padding.cc index eab4c32a3f3..ca50b57db71 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/padding.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/padding.cc @@ -28,54 +28,57 @@ namespace { std::string GetPaddingCode( const OperationDef& op_def, const std::vector& linked_operations) { - TensorCodeGenerator src_tensor("src_data", "src_size", op_def.src_tensors[0]); - TensorCodeGenerator dst_tensor("dst_data", "dst_size", op_def.dst_tensors[0]); + TensorCodeGenerator src_tensor("src_data", + {"src_size.x", "src_size.y", "src_size.z"}, + op_def.src_tensors[0]); + TensorCodeGenerator dst_tensor("dst_data", + {"dst_size.x", "dst_size.y", "dst_size.z"}, + op_def.dst_tensors[0]); - std::string code = GetCommonDefines(op_def.precision); + std::string c = GetCommonDefines(op_def.precision); const std::string channels[] = {".x", ".y", ".z", ".w"}; - code += "__kernel void main_function(\n"; - code += src_tensor.GetDeclaration(AccessType::READ); - code += GetArgsDeclaration(linked_operations); - code += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n"; - code += " int4 src_size, \n"; - code += " int4 dst_size, \n"; - code += " int4 prepended \n"; - code += ") {\n"; - code += " int X = get_global_id(0);\n"; - code += " int Y = get_global_id(1);\n"; - code += " int Z = get_global_id(2);\n"; - code += - " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.w) return; \n"; - code += " FLT4 result = (FLT4)(0.0);\n"; - code += " int s_x = X - prepended.x;\n"; - code += " int s_y = Y - prepended.y;\n"; - code += " bool inside_x = s_x >= 0 && s_x < src_size.x;\n"; - code += " bool inside_y = s_y >= 0 && s_y < src_size.y;\n"; - code += " if (inside_x && inside_y) {\n"; - code += " int start_channel = Z * 4;\n"; + c += "__kernel void main_function(\n"; + c += src_tensor.GetDeclaration(AccessType::READ); + c += GetArgsDeclaration(linked_operations); + c += 
dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n"; + c += " int4 src_size, \n"; + c += " int src_channels, \n"; + c += " int4 dst_size, \n"; + c += " int4 prepended \n"; + c += ") {\n"; + c += " int X = get_global_id(0);\n"; + c += " int Y = get_global_id(1);\n"; + c += " int Z = get_global_id(2);\n"; + c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return;\n"; + c += " FLT4 result = (FLT4)(0.0);\n"; + c += " int s_x = X - prepended.x;\n"; + c += " int s_y = Y - prepended.y;\n"; + c += " bool inside_x = s_x >= 0 && s_x < src_size.x;\n"; + c += " bool inside_y = s_y >= 0 && s_y < src_size.y;\n"; + c += " if (inside_x && inside_y) {\n"; + c += " int start_channel = Z * 4;\n"; for (int i = 0; i < 4; ++i) { const auto& s = channels[i]; - code += " {\n"; - code += " int channel = start_channel + " + std::to_string(i) + ";\n"; - code += " int s_z = channel - prepended.z;\n"; - code += " if (s_z >= 0 && s_z < src_size.z) {\n"; - code += " FLT4 t = " + - src_tensor.Read3D("s_x", "s_y", "s_z / 4", - TextureAddressMode::DONT_CARE) + - ";\n"; - code += " FLT t_ar[4] = {t.x, t.y, t.z, t.w};\n"; - code += " result" + s + " = t_ar[s_z % 4];\n"; - code += " }\n"; - code += " }\n"; + c += " {\n"; + c += " int channel = start_channel + " + std::to_string(i) + ";\n"; + c += " int s_z = channel - prepended.z;\n"; + c += " if (s_z >= 0 && s_z < src_channels) {\n"; + c += " FLT4 t = " + + src_tensor.Read3D("s_x", "s_y", "s_z / 4", + TextureAddressMode::DONT_CARE) + + ";\n"; + c += " FLT t_ar[4] = {t.x, t.y, t.z, t.w};\n"; + c += " result" + s + " = t_ar[s_z % 4];\n"; + c += " }\n"; + c += " }\n"; } - code += " }\n"; - const LinkingContext context{"result", "X", "Y", "Z"}; - code += PostProcess(linked_operations, context); - code += " " + dst_tensor.Write3D("result", "X", "Y", "Z"); - code += "}\n"; + c += " }\n"; + c += PostProcess(linked_operations, {"result", "X", "Y", "Z"}); + c += " " + dst_tensor.Write3D("result", "X", "Y", "Z"); + c += "}\n"; - return code; 
+ return c; } } // namespace @@ -119,14 +122,17 @@ Status Padding::BindArguments() { RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting())); - RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetSizeWithDepth())); - RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetSizeWithDepth())); - RETURN_IF_ERROR(kernel_.SetBytesAuto(prepended_)); + RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHDB())); + RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->Channels())); + RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHDB())); + RETURN_IF_ERROR( + kernel_.SetBytesAuto(int4(prepended_.x * src_[0]->Batch(), prepended_.y, + prepended_.z, prepended_.w))); return OkStatus(); } int3 Padding::GetGridSize() const { - const int grid_x = dst_[0]->Width(); + const int grid_x = dst_[0]->Width() * dst_[0]->Batch(); const int grid_y = dst_[0]->Height(); const int grid_z = dst_[0]->Depth(); return int3(grid_x, grid_y, grid_z);