From 0a541ad1cc89f1eead6a47dc1676bd86dc810937 Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Thu, 18 Jun 2020 05:07:41 -0700 Subject: [PATCH] Remove intermediate relocatable code stored in __nv_relfatbin sections, if objcopy is at least version 2.26 (which added support for --update-sections). The intermediate code is a result of separate compilation and linking, and removing it reduces TF's GPU wheel size. PiperOrigin-RevId: 317081343 Change-Id: I603477b4499344aeec653765be78de11f392eac6 --- third_party/nccl/build_defs.bzl.tpl | 103 ++++++++++++++++++++++++---- 1 file changed, 88 insertions(+), 15 deletions(-) diff --git a/third_party/nccl/build_defs.bzl.tpl b/third_party/nccl/build_defs.bzl.tpl index 9268af7c890..b520f71d0f1 100644 --- a/third_party/nccl/build_defs.bzl.tpl +++ b/third_party/nccl/build_defs.bzl.tpl @@ -169,35 +169,94 @@ _device_link = rule( ) """Links device code and generates source code for kernel registration.""" +def _prune_relocatable_code_impl(ctx): + """Clears __nv_relfatbin section containing relocatable device code.""" + empty_file = ctx.actions.declare_file(ctx.attr.name + "__nv_relfatbin") + ctx.actions.write(empty_file, "") + + # Parse 'objcopy --version' and update section if it's at least v2.26. + # Otherwise, simply copy the file without changing it. + # TODO(csigg): version parsing is brittle, can we do better? + command = r""" + objcopy=$1 \ + section=$2 \ + input=$3 \ + output=$4 \ + args="" \ + pattern='([0-9])\.([0-9]+)'; \ + if [[ $($objcopy --version) =~ $pattern ]] && { \ + [ ${BASH_REMATCH[1]} -gt 2 ] || \ + [ ${BASH_REMATCH[2]} -ge 26 ]; }; then \ + args="--update-section __nv_relfatbin=$section"; \ + fi; \ + $objcopy $args $input $output + """ + cc_toolchain = find_cpp_toolchain(ctx) + outputs = [] + for src in ctx.files.srcs: + out = ctx.actions.declare_file("pruned_" + src.basename, sibling = src) + ctx.actions.run_shell( + inputs = [empty_file] + ctx.files.srcs, # + ctx.files._crosstool, + outputs = [out], + arguments = [ + cc_toolchain.objcopy_executable, + empty_file.path, + src.path, + out.path, + ], + command = command, + ) + outputs.append(out) + return DefaultInfo(files = depset(outputs)) + +_prune_relocatable_code = rule( + implementation = _prune_relocatable_code_impl, + attrs = { + "srcs": attr.label_list(mandatory = True, allow_files = True), + "_cc_toolchain": attr.label( + default = "@bazel_tools//tools/cpp:current_cc_toolchain", + ), + # "_crosstool": attr.label_list( + # cfg = "host", + # default = ["@bazel_tools//tools/cpp:crosstool"] + # ), + }, +) + def _merge_archive_impl(ctx): # Generate an mri script to the merge archives in srcs and pass it to 'ar'. # See https://stackoverflow.com/a/23621751. files = _pic_only(ctx.files.srcs) mri_script = "create " + ctx.outputs.out.path for f in files: - mri_script += "\\naddlib " + f.path - mri_script += "\\nsave\\nend" + mri_script += r"\naddlib " + f.path + mri_script += r"\nsave\nend" cc_toolchain = find_cpp_toolchain(ctx) ctx.actions.run_shell( inputs = ctx.files.srcs, # + ctx.files._crosstool, outputs = [ctx.outputs.out], - command = "printf \"%s\" | %s -M" % (mri_script, cc_toolchain.ar_executable), + command = "echo -e \"%s\" | %s -M" % (mri_script, cc_toolchain.ar_executable), ) _merge_archive = rule( implementation = _merge_archive_impl, attrs = { "srcs": attr.label_list(mandatory = True, allow_files = True), - "_cc_toolchain": attr.label(default = "@bazel_tools//tools/cpp:current_cc_toolchain"), - # "_crosstool": attr.label_list(cfg = "host", default = ["@bazel_tools//tools/cpp:crosstool"]), + "_cc_toolchain": attr.label( + default = "@bazel_tools//tools/cpp:current_cc_toolchain", + ), + # "_crosstool": attr.label_list( + # cfg = "host", + # default = ["@bazel_tools//tools/cpp:crosstool"] + # ), }, outputs = {"out": "lib%{name}.a"}, ) """Merges srcs into a single archive.""" def cuda_rdc_library(name, hdrs = None, copts = None, linkstatic = True, **kwargs): - """Produces a cuda_library using separate compilation and linking. + r"""Produces a cuda_library using separate compilation and linking. CUDA separate compilation and linking allows device function calls across translation units. This is different from the normal whole program @@ -239,17 +298,24 @@ def cuda_rdc_library(name, hdrs = None, copts = None, linkstatic = True, **kwarg The steps marked with '*' are implemented in the _device_link rule. + The intermediate relocatable device code in xy.a is no longer needed at + this point and the corresponding section is replaced with an empty one using + objcopy. We do not remove the section completely because it is referenced by + relocations, and removing those as well breaks fatbin registration. + The object files in both xy.a and dlink.a reference symbols defined in the other archive. The separate archives are a side effect of using two cc_library targets to implement a single compilation trajectory. We could fix this once bazel supports C++ sandwich. For now, we just merge the two archives to avoid unresolved symbols: - xy.a dlink.a - \ / merge archive - xy_dlink.a - | cc_library (or alternatively, cc_import) - final target + xy.a + | objcopy --update-section __nv_relfatbin='' + dlink.a xy_pruned.a + \ / merge archive + xy_merged.a + | cc_library (or alternatively, cc_import) + final target Another complication is that cc_library produces (depending on the configuration) both PIC and non-PIC archives, but the distinction @@ -313,19 +379,26 @@ def cuda_rdc_library(name, hdrs = None, copts = None, linkstatic = True, **kwarg linkstatic = linkstatic, ) + # Remove intermediate relocatable device code. + pruned = name + "_pruned" + _prune_relocatable_code( + name = pruned, + srcs = [lib], + ) + # Repackage the two libs into a single archive. This is required because # both libs reference symbols defined in the other one. For details, see # https://eli.thegreenplace.net/2013/07/09/library-order-in-static-linking - archive = name + "_a" + merged = name + "_merged" _merge_archive( - name = archive, - srcs = [lib, dlink], + name = merged, + srcs = [pruned, dlink], ) # Create cc target from archive. native.cc_library( name = name, - srcs = [archive], + srcs = [merged], hdrs = hdrs, linkstatic = linkstatic, )