Share the call to tf_to_gpu_binary between ROCM and CUDA.
PiperOrigin-RevId: 329891456
Change-Id: Ie0a1de2ac138fc35916a3dba20e6a8f5ed33fae2
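The change relies on one detail: after the existing "compute_" to "sm_" normalization, CUDA and ROCM architecture names both carry a three-character prefix ("sm_70", "gfx906"), so a single arch[3:] slice yields the bare architecture number on either platform. A minimal sketch of that slice, with a helper name invented here for illustration rather than taken from this change:

def _strip_arch_prefix(arch):
    # For ROCM, remove the "gfx" prefix. For CUDA, remove the "sm_" prefix.
    return arch[3:]

# _strip_arch_prefix("sm_70") == "70"; _strip_arch_prefix("gfx906") == "906"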
parent 8f80e6e013
commit 73a37969bf
@@ -23,10 +23,10 @@ def _lookup_file(filegroup, path):

 GpuBinaryInfo = provider(
     "GPU binaries in either cubin format or hsaco format",
-    fields = ["cubins", "hsacos"],
+    fields = ["gpu_bins"],
 )

-def _gen_kernel_cubin_impl_cuda(ctx):
+def _gen_kernel_gpu_bin_impl(ctx):
     name = ctx.attr.name
     tile_sizes = ctx.attr.tile_size.replace("x", ",")
     cmd_args = []
@@ -35,56 +35,29 @@ def _gen_kernel_cubin_impl_cuda(ctx):
     if ctx.attr.unroll_factors:
         cmd_args.append("--unroll_factors=%s" % ctx.attr.unroll_factors)

-    cubins = []
+    gpu_bins = []
     for arch in ctx.attr.gpu_archs:
         # TODO(b/152737872): 'compute_' should generate both SASS and PTX.
         arch = arch.replace("compute_", "sm_")
-        filename = "%s.%s.cubin" % (name, arch)
-        cubin = ctx.actions.declare_file(filename)
+        filename = "%s.%s.bin" % (name, arch)
+        gpu_bin = ctx.actions.declare_file(filename)
         ctx.actions.run(
             inputs = [ctx.file.mlir_op, ctx.file._tfso],
-            outputs = [cubin],
+            outputs = [gpu_bin],
             executable = ctx.executable._tool,
             arguments = cmd_args + [
                 "--tile_sizes=%s" % tile_sizes,
-                "--arch=%s" % arch.split("_")[1],
+                # For ROCM, remove the "gfx" prefix. For CUDA, remove the "sm_" prefix.
+                "--arch=%s" % arch[3:],
                 "--input=%s" % ctx.file.mlir_op.path,
-                "--output=%s" % cubin.path,
+                "--output=%s" % gpu_bin.path,
             ],
             mnemonic = "compile",
         )
-        cubins.append(cubin)
-    return [GpuBinaryInfo(cubins = cubins)]
+        gpu_bins.append(gpu_bin)
+    return [GpuBinaryInfo(gpu_bins = gpu_bins)]

-def _gen_kernel_cubin_impl_rocm(ctx):
-    name = ctx.attr.name
-    tile_sizes = ctx.attr.tile_size.replace("x", ",")
-    cmd_args = []
-    if ctx.attr.same_shape:
-        cmd_args.append("--same_shape=%s" % ctx.attr.same_shape)
-    if ctx.attr.unroll_factors:
-        cmd_args.append("--unroll_factors=%s" % ctx.attr.unroll_factors)
-
-    hsacos = []
-    for arch in ctx.attr.gpu_archs:
-        filename = "%s.%s.hsaco" % (name, arch)
-        hsaco = ctx.actions.declare_file(filename)
-        ctx.actions.run(
-            inputs = [ctx.file.mlir_op, ctx.file._tfso],
-            outputs = [hsaco],
-            executable = ctx.executable._tool,
-            arguments = cmd_args + [
-                "--tile_sizes=%s" % tile_sizes,
-                "--arch=%s" % arch[3:],  # DDD in "gfxDDD"
-                "--input=%s" % ctx.file.mlir_op.path,
-                "--output=%s" % hsaco.path,
-            ],
-            mnemonic = "compile",
-        )
-        hsacos.append(hsaco)
-    return [GpuBinaryInfo(hsacos = hsacos)]
-
-_gen_kernel_cubin_rule = rule(
+_gen_kernel_gpu_bin_rule = rule(
     attrs = {
         "mlir_op": attr.label(mandatory = True, allow_single_file = True),
         "tile_size": attr.string(mandatory = True),
@@ -103,12 +76,12 @@ _gen_kernel_cubin_rule = rule(
         ),
     },
     output_to_genfiles = True,
-    implementation = _gen_kernel_cubin_impl_rocm if rocm_is_configured() else _gen_kernel_cubin_impl_cuda,
+    implementation = _gen_kernel_gpu_bin_impl,
 )
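With a single implementation function bound to the rule, the CUDA/ROCM choice is no longer made when the .bzl file loads; it travels through the gpu_archs attribute instead, as the remaining hunks show. A reduced stand-in for that pattern, in plain Python with invented names rather than the Bazel API:

def gen_gpu_bins(name, gpu_archs):
    # One code path; the arch list alone encodes the platform.
    return ["%s.%s.bin" % (name, arch) for arch in gpu_archs]

# gen_gpu_bins("abs", ["sm_70", "sm_75"])  -> ["abs.sm_70.bin", "abs.sm_75.bin"]
# gen_gpu_bins("abs", ["gfx906"])          -> ["abs.gfx906.bin"]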

 def _gen_kernel_image_hdr_impl_cuda(ctx):
     images = []
-    for cubin in ctx.attr.input[GpuBinaryInfo].cubins:
+    for cubin in ctx.attr.input[GpuBinaryInfo].gpu_bins:
         arch = cubin.path.split(".")[-2]
         images.append("--image=profile=%s,file=%s" % (arch, cubin.path))

@@ -116,7 +89,7 @@ def _gen_kernel_image_hdr_impl_cuda(ctx):
     fatbin = ctx.actions.declare_file("%s.fatbin" % ctx.attr.name)
     ctx.actions.run(
         outputs = [fatbin],
-        inputs = ctx.attr.input[GpuBinaryInfo].cubins,
+        inputs = ctx.attr.input[GpuBinaryInfo].gpu_bins,
         executable = _lookup_file(ctx.attr._gpu_root, "bin/fatbinary"),
         arguments = [
             "--64",
@@ -146,7 +119,7 @@ def _gen_kernel_image_hdr_impl_rocm(ctx):
         hsaco_files.append("/dev/null")
         hsaco_targets.append("host-x86_64-unknown-linux")

-    hsacos = ctx.attr.input[GpuBinaryInfo].hsacos
+    hsacos = ctx.attr.input[GpuBinaryInfo].gpu_bins
     for hsaco in hsacos:
         gfx_arch = hsaco.path.split(".")[-2]
         hsaco_files.append(hsaco.path)
@@ -196,23 +169,22 @@ _gen_kernel_image_hdr_rule = rule(
     },
 )

-def _gen_kernel_image_hdr(name, mlir_op, tile_size, same_shape = None, unroll_factors = None):
+def _gen_kernel_image_hdr(name, mlir_op, gpu_archs, tile_size, same_shape = None, unroll_factors = None):
     """Generates a C header with fatbin data from a Tensorflow op."""
-    if cuda_gpu_architectures() or rocm_gpu_architectures():
-        _gen_kernel_cubin_rule(
-            name = name + "_cubin",
-            mlir_op = mlir_op,
-            tile_size = tile_size,
-            same_shape = same_shape,
-            unroll_factors = unroll_factors,
-            gpu_archs = rocm_gpu_architectures() if rocm_is_configured() else cuda_gpu_architectures(),
-        )
-        _gen_kernel_image_hdr_rule(
-            name = name,
-            input = ":" + name + "_cubin",
-            out = "%s.h" % name,
-            symbol = "k%s" % name.replace("_", " ").title().replace(" ", ""),
-        )
+    _gen_kernel_gpu_bin_rule(
+        name = name + "_cubin",
+        mlir_op = mlir_op,
+        tile_size = tile_size,
+        same_shape = same_shape,
+        unroll_factors = unroll_factors,
+        gpu_archs = gpu_archs,
+    )
+    _gen_kernel_image_hdr_rule(
+        name = name,
+        input = ":" + name + "_cubin",
+        out = "%s.h" % name,
+        symbol = "k%s" % name.replace("_", " ").title().replace(" ", ""),
+    )

 def _gen_mlir_op_impl(ctx):
     ctx.actions.run_shell(
@@ -264,6 +236,7 @@ def gen_kernel_library(name, types, tile_size, tags = [], same_shape = None, unroll_factors = None):
         _gen_kernel_image_hdr(
             name = "{name}_{type}_kernel".format(name = name, type = type),
             mlir_op = "{name}_{type}.mlir".format(name = name, type = type),
+            gpu_archs = rocm_gpu_architectures() if rocm_is_configured() else cuda_gpu_architectures(),
             tile_size = tile_size,
             same_shape = same_shape,
             unroll_factors = unroll_factors,
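For orientation, a macro like gen_kernel_library is invoked from a BUILD file along these lines; the target name, type list, and size values below are placeholders, not values taken from this commit:

gen_kernel_library(
    name = "abs",
    types = ["f16", "f32", "f64"],
    tile_size = "256",
    unroll_factors = "4",
)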