Remove intermediate relocatable code stored in __nv_relfatbin sections, if objcopy is at least version 2.26 (which added support for --update-sections).

The intermediate code is a result of separate compilation and linking, and removing it reduces TF's GPU wheel size.

PiperOrigin-RevId: 317081343
Change-Id: I603477b4499344aeec653765be78de11f392eac6
This commit is contained in:
Christian Sigg 2020-06-18 05:07:41 -07:00 committed by TensorFlower Gardener
parent 16e6f9e792
commit 0a541ad1cc
1 changed files with 88 additions and 15 deletions

View File

@ -169,35 +169,94 @@ _device_link = rule(
) )
"""Links device code and generates source code for kernel registration.""" """Links device code and generates source code for kernel registration."""
def _prune_relocatable_code_impl(ctx):
"""Clears __nv_relfatbin section containing relocatable device code."""
empty_file = ctx.actions.declare_file(ctx.attr.name + "__nv_relfatbin")
ctx.actions.write(empty_file, "")
# Parse 'objcopy --version' and update section if it's at least v2.26.
# Otherwise, simply copy the file without changing it.
# TODO(csigg): version parsing is brittle, can we do better?
command = r"""
objcopy=$1 \
section=$2 \
input=$3 \
output=$4 \
args="" \
pattern='([0-9])\.([0-9]+)'; \
if [[ $($objcopy --version) =~ $pattern ]] && { \
[ ${BASH_REMATCH[1]} -gt 2 ] || \
[ ${BASH_REMATCH[2]} -ge 26 ]; }; then \
args="--update-section __nv_relfatbin=$section"; \
fi; \
$objcopy $args $input $output
"""
cc_toolchain = find_cpp_toolchain(ctx)
outputs = []
for src in ctx.files.srcs:
out = ctx.actions.declare_file("pruned_" + src.basename, sibling = src)
ctx.actions.run_shell(
inputs = [empty_file] + ctx.files.srcs, # + ctx.files._crosstool,
outputs = [out],
arguments = [
cc_toolchain.objcopy_executable,
empty_file.path,
src.path,
out.path,
],
command = command,
)
outputs.append(out)
return DefaultInfo(files = depset(outputs))
_prune_relocatable_code = rule(
implementation = _prune_relocatable_code_impl,
attrs = {
"srcs": attr.label_list(mandatory = True, allow_files = True),
"_cc_toolchain": attr.label(
default = "@bazel_tools//tools/cpp:current_cc_toolchain",
),
# "_crosstool": attr.label_list(
# cfg = "host",
# default = ["@bazel_tools//tools/cpp:crosstool"]
# ),
},
)
def _merge_archive_impl(ctx): def _merge_archive_impl(ctx):
# Generate an mri script to the merge archives in srcs and pass it to 'ar'. # Generate an mri script to the merge archives in srcs and pass it to 'ar'.
# See https://stackoverflow.com/a/23621751. # See https://stackoverflow.com/a/23621751.
files = _pic_only(ctx.files.srcs) files = _pic_only(ctx.files.srcs)
mri_script = "create " + ctx.outputs.out.path mri_script = "create " + ctx.outputs.out.path
for f in files: for f in files:
mri_script += "\\naddlib " + f.path mri_script += r"\naddlib " + f.path
mri_script += "\\nsave\\nend" mri_script += r"\nsave\nend"
cc_toolchain = find_cpp_toolchain(ctx) cc_toolchain = find_cpp_toolchain(ctx)
ctx.actions.run_shell( ctx.actions.run_shell(
inputs = ctx.files.srcs, # + ctx.files._crosstool, inputs = ctx.files.srcs, # + ctx.files._crosstool,
outputs = [ctx.outputs.out], outputs = [ctx.outputs.out],
command = "printf \"%s\" | %s -M" % (mri_script, cc_toolchain.ar_executable), command = "echo -e \"%s\" | %s -M" % (mri_script, cc_toolchain.ar_executable),
) )
_merge_archive = rule( _merge_archive = rule(
implementation = _merge_archive_impl, implementation = _merge_archive_impl,
attrs = { attrs = {
"srcs": attr.label_list(mandatory = True, allow_files = True), "srcs": attr.label_list(mandatory = True, allow_files = True),
"_cc_toolchain": attr.label(default = "@bazel_tools//tools/cpp:current_cc_toolchain"), "_cc_toolchain": attr.label(
# "_crosstool": attr.label_list(cfg = "host", default = ["@bazel_tools//tools/cpp:crosstool"]), default = "@bazel_tools//tools/cpp:current_cc_toolchain",
),
# "_crosstool": attr.label_list(
# cfg = "host",
# default = ["@bazel_tools//tools/cpp:crosstool"]
# ),
}, },
outputs = {"out": "lib%{name}.a"}, outputs = {"out": "lib%{name}.a"},
) )
"""Merges srcs into a single archive.""" """Merges srcs into a single archive."""
def cuda_rdc_library(name, hdrs = None, copts = None, linkstatic = True, **kwargs): def cuda_rdc_library(name, hdrs = None, copts = None, linkstatic = True, **kwargs):
"""Produces a cuda_library using separate compilation and linking. r"""Produces a cuda_library using separate compilation and linking.
CUDA separate compilation and linking allows device function calls across CUDA separate compilation and linking allows device function calls across
translation units. This is different from the normal whole program translation units. This is different from the normal whole program
@ -239,17 +298,24 @@ def cuda_rdc_library(name, hdrs = None, copts = None, linkstatic = True, **kwarg
The steps marked with '*' are implemented in the _device_link rule. The steps marked with '*' are implemented in the _device_link rule.
The intermediate relocatable device code in xy.a is no longer needed at
this point and the corresponding section is replaced with an empty one using
objcopy. We do not remove the section completely because it is referenced by
relocations, and removing those as well breaks fatbin registration.
The object files in both xy.a and dlink.a reference symbols defined in the The object files in both xy.a and dlink.a reference symbols defined in the
other archive. The separate archives are a side effect of using two other archive. The separate archives are a side effect of using two
cc_library targets to implement a single compilation trajectory. We could cc_library targets to implement a single compilation trajectory. We could
fix this once bazel supports C++ sandwich. For now, we just merge the two fix this once bazel supports C++ sandwich. For now, we just merge the two
archives to avoid unresolved symbols: archives to avoid unresolved symbols:
xy.a dlink.a xy.a
\ / merge archive | objcopy --update-section __nv_relfatbin=''
xy_dlink.a dlink.a xy_pruned.a
| cc_library (or alternatively, cc_import) \ / merge archive
final target xy_merged.a
| cc_library (or alternatively, cc_import)
final target
Another complication is that cc_library produces (depending on the Another complication is that cc_library produces (depending on the
configuration) both PIC and non-PIC archives, but the distinction configuration) both PIC and non-PIC archives, but the distinction
@ -313,19 +379,26 @@ def cuda_rdc_library(name, hdrs = None, copts = None, linkstatic = True, **kwarg
linkstatic = linkstatic, linkstatic = linkstatic,
) )
# Remove intermediate relocatable device code.
pruned = name + "_pruned"
_prune_relocatable_code(
name = pruned,
srcs = [lib],
)
# Repackage the two libs into a single archive. This is required because # Repackage the two libs into a single archive. This is required because
# both libs reference symbols defined in the other one. For details, see # both libs reference symbols defined in the other one. For details, see
# https://eli.thegreenplace.net/2013/07/09/library-order-in-static-linking # https://eli.thegreenplace.net/2013/07/09/library-order-in-static-linking
archive = name + "_a" merged = name + "_merged"
_merge_archive( _merge_archive(
name = archive, name = merged,
srcs = [lib, dlink], srcs = [pruned, dlink],
) )
# Create cc target from archive. # Create cc target from archive.
native.cc_library( native.cc_library(
name = name, name = name,
srcs = [archive], srcs = [merged],
hdrs = hdrs, hdrs = hdrs,
linkstatic = linkstatic, linkstatic = linkstatic,
) )