Share the call to tf_to_gpu_binary between ROCM and CUDA.

PiperOrigin-RevId: 329891456
Change-Id: Ie0a1de2ac138fc35916a3dba20e6a8f5ed33fae2
Authored by Adrian Kuegel on 2020-09-03 04:15:38 -07:00; committed by TensorFlower Gardener
parent 8f80e6e013
commit 73a37969bf
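
Note on the diff below: the CUDA architecture names ("sm_70", or "compute_70" before the TODO rewrite) and the ROCm names ("gfx906") both carry a three-character prefix, which is what lets the merged implementation pass a bare architecture number to the tool via arch[3:]. A minimal sketch in plain Python (Starlark is a Python subset; the architecture strings here are illustrative examples, not taken from the commit):

for arch in ["sm_70", "compute_75", "gfx906"]:
    # Mirror the shared implementation: rewrite "compute_" targets to "sm_",
    # then drop the three-character prefix ("sm_" or "gfx").
    arch = arch.replace("compute_", "sm_")
    print("--arch=%s" % arch[3:])  # prints 70, 75, 906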


@@ -23,10 +23,10 @@ def _lookup_file(filegroup, path):
 GpuBinaryInfo = provider(
     "GPU binaries in either cubin format or hsaco format",
-    fields = ["cubins", "hsacos"],
+    fields = ["gpu_bins"],
 )
 
-def _gen_kernel_cubin_impl_cuda(ctx):
+def _gen_kernel_gpu_bin_impl(ctx):
     name = ctx.attr.name
     tile_sizes = ctx.attr.tile_size.replace("x", ",")
     cmd_args = []
@@ -35,56 +35,29 @@ def _gen_kernel_cubin_impl_cuda(ctx):
     if ctx.attr.unroll_factors:
         cmd_args.append("--unroll_factors=%s" % ctx.attr.unroll_factors)
 
-    cubins = []
+    gpu_bins = []
     for arch in ctx.attr.gpu_archs:
         # TODO(b/152737872): 'compute_' should generate both SASS and PTX.
         arch = arch.replace("compute_", "sm_")
-        filename = "%s.%s.cubin" % (name, arch)
-        cubin = ctx.actions.declare_file(filename)
+        filename = "%s.%s.bin" % (name, arch)
+        gpu_bin = ctx.actions.declare_file(filename)
         ctx.actions.run(
             inputs = [ctx.file.mlir_op, ctx.file._tfso],
-            outputs = [cubin],
+            outputs = [gpu_bin],
             executable = ctx.executable._tool,
             arguments = cmd_args + [
                 "--tile_sizes=%s" % tile_sizes,
-                "--arch=%s" % arch.split("_")[1],
+                # For ROCM, remove the "gfx" prefix. For CUDA, remove the "sm_" prefix.
+                "--arch=%s" % arch[3:],
                 "--input=%s" % ctx.file.mlir_op.path,
-                "--output=%s" % cubin.path,
+                "--output=%s" % gpu_bin.path,
             ],
             mnemonic = "compile",
         )
-        cubins.append(cubin)
-    return [GpuBinaryInfo(cubins = cubins)]
+        gpu_bins.append(gpu_bin)
+    return [GpuBinaryInfo(gpu_bins = gpu_bins)]
 
-def _gen_kernel_cubin_impl_rocm(ctx):
-    name = ctx.attr.name
-    tile_sizes = ctx.attr.tile_size.replace("x", ",")
-    cmd_args = []
-    if ctx.attr.same_shape:
-        cmd_args.append("--same_shape=%s" % ctx.attr.same_shape)
-    if ctx.attr.unroll_factors:
-        cmd_args.append("--unroll_factors=%s" % ctx.attr.unroll_factors)
-
-    hsacos = []
-    for arch in ctx.attr.gpu_archs:
-        filename = "%s.%s.hsaco" % (name, arch)
-        hsaco = ctx.actions.declare_file(filename)
-        ctx.actions.run(
-            inputs = [ctx.file.mlir_op, ctx.file._tfso],
-            outputs = [hsaco],
-            executable = ctx.executable._tool,
-            arguments = cmd_args + [
-                "--tile_sizes=%s" % tile_sizes,
-                "--arch=%s" % arch[3:],  # DDD in "gfxDDD"
-                "--input=%s" % ctx.file.mlir_op.path,
-                "--output=%s" % hsaco.path,
-            ],
-            mnemonic = "compile",
-        )
-        hsacos.append(hsaco)
-    return [GpuBinaryInfo(hsacos = hsacos)]
-
-_gen_kernel_cubin_rule = rule(
+_gen_kernel_gpu_bin_rule = rule(
     attrs = {
         "mlir_op": attr.label(mandatory = True, allow_single_file = True),
         "tile_size": attr.string(mandatory = True),
@@ -103,12 +76,12 @@ _gen_kernel_cubin_rule = rule(
         ),
     },
     output_to_genfiles = True,
-    implementation = _gen_kernel_cubin_impl_rocm if rocm_is_configured() else _gen_kernel_cubin_impl_cuda,
+    implementation = _gen_kernel_gpu_bin_impl,
 )
 
 def _gen_kernel_image_hdr_impl_cuda(ctx):
     images = []
-    for cubin in ctx.attr.input[GpuBinaryInfo].cubins:
+    for cubin in ctx.attr.input[GpuBinaryInfo].gpu_bins:
         arch = cubin.path.split(".")[-2]
         images.append("--image=profile=%s,file=%s" % (arch, cubin.path))
@@ -116,7 +89,7 @@ def _gen_kernel_image_hdr_impl_cuda(ctx):
     fatbin = ctx.actions.declare_file("%s.fatbin" % ctx.attr.name)
     ctx.actions.run(
         outputs = [fatbin],
-        inputs = ctx.attr.input[GpuBinaryInfo].cubins,
+        inputs = ctx.attr.input[GpuBinaryInfo].gpu_bins,
         executable = _lookup_file(ctx.attr._gpu_root, "bin/fatbinary"),
         arguments = [
             "--64",
@@ -146,7 +119,7 @@ def _gen_kernel_image_hdr_impl_rocm(ctx):
     hsaco_files.append("/dev/null")
     hsaco_targets.append("host-x86_64-unknown-linux")
 
-    hsacos = ctx.attr.input[GpuBinaryInfo].hsacos
+    hsacos = ctx.attr.input[GpuBinaryInfo].gpu_bins
     for hsaco in hsacos:
         gfx_arch = hsaco.path.split(".")[-2]
         hsaco_files.append(hsaco.path)
@@ -196,23 +169,22 @@ _gen_kernel_image_hdr_rule = rule(
     },
 )
 
-def _gen_kernel_image_hdr(name, mlir_op, tile_size, same_shape = None, unroll_factors = None):
+def _gen_kernel_image_hdr(name, mlir_op, gpu_archs, tile_size, same_shape = None, unroll_factors = None):
     """Generates a C header with fatbin data from a Tensorflow op."""
-    if cuda_gpu_architectures() or rocm_gpu_architectures():
-        _gen_kernel_cubin_rule(
-            name = name + "_cubin",
-            mlir_op = mlir_op,
-            tile_size = tile_size,
-            same_shape = same_shape,
-            unroll_factors = unroll_factors,
-            gpu_archs = rocm_gpu_architectures() if rocm_is_configured() else cuda_gpu_architectures(),
-        )
-        _gen_kernel_image_hdr_rule(
-            name = name,
-            input = ":" + name + "_cubin",
-            out = "%s.h" % name,
-            symbol = "k%s" % name.replace("_", " ").title().replace(" ", ""),
-        )
+    _gen_kernel_gpu_bin_rule(
+        name = name + "_cubin",
+        mlir_op = mlir_op,
+        tile_size = tile_size,
+        same_shape = same_shape,
+        unroll_factors = unroll_factors,
+        gpu_archs = gpu_archs,
+    )
+    _gen_kernel_image_hdr_rule(
+        name = name,
+        input = ":" + name + "_cubin",
+        out = "%s.h" % name,
+        symbol = "k%s" % name.replace("_", " ").title().replace(" ", ""),
+    )
 
 def _gen_mlir_op_impl(ctx):
     ctx.actions.run_shell(
@@ -264,6 +236,7 @@ def gen_kernel_library(name, types, tile_size, tags = [], same_shape = None, unr
         _gen_kernel_image_hdr(
             name = "{name}_{type}_kernel".format(name = name, type = type),
             mlir_op = "{name}_{type}.mlir".format(name = name, type = type),
+            gpu_archs = rocm_gpu_architectures() if rocm_is_configured() else cuda_gpu_architectures(),
            tile_size = tile_size,
            same_shape = same_shape,
            unroll_factors = unroll_factors,
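
For orientation, a hypothetical BUILD invocation of the gen_kernel_library macro whose signature appears in the last hunk; the load label, target name, types, tile size, and unroll factor below are assumptions for illustration, only the macro and parameter names come from the diff:

load(":build_defs.bzl", "gen_kernel_library")  # load label is an assumption

gen_kernel_library(
    name = "abs",            # hypothetical op name
    types = ["f16", "f32"],  # hypothetical element types
    tile_size = "256",       # hypothetical tile size
    unroll_factors = "4",    # optional, as in the signature above
)

With this change the macro, not the binary-generation rule, decides between rocm_gpu_architectures() and cuda_gpu_architectures() and hands the result to each _gen_kernel_image_hdr as gpu_archs.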