Improve the genrules for cubin kernel headers.
Add an additional build macro (`gen_kernel_library`) that generates the per-type kernel headers and bundles them into a library of headers. Also change the output of bin2c to use type char instead of type int.

PiperOrigin-RevId: 314321754
Change-Id: I81ee3c7c962e807d28bb0f580eea8032f2a390ee
This commit is contained in:
parent
5c3cdff00b
commit
1b6ad07ad3
@ -1,5 +1,10 @@
|
|||||||
# Generates headers containing cubin for CUDA kernels.
|
# Generates headers containing cubin for CUDA kernels.
|
||||||
load("//tensorflow/core/kernels/cubin_headers:build_defs.bzl", "gen_kernel_image_hdr")
|
load("//tensorflow/core/kernels/cubin_headers:build_defs.bzl", "gen_kernel_library")
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = ["//tensorflow/core/kernels:__subpackages__"],
|
||||||
|
licenses = ["notice"], # Apache 2.0
|
||||||
|
)
|
||||||
|
|
||||||
bias_add_kernel = """
|
bias_add_kernel = """
|
||||||
func @bias_add(%arg0: tensor<?x?xf99>,
|
func @bias_add(%arg0: tensor<?x?xf99>,
|
||||||
@ -10,19 +15,17 @@ func @bias_add(%arg0: tensor<?x?xf99>,
|
|||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
[
|
gen_kernel_library(
|
||||||
gen_kernel_image_hdr(
|
name = "bias_add",
|
||||||
name = "bias_add_{type}_kernel".format(type = type),
|
op = bias_add_kernel,
|
||||||
op = bias_add_kernel.replace("f99", type).replace("DT_TYPE", dtype),
|
same_shape = "0,2",
|
||||||
same_shape = "0,2",
|
tile_size = "16x16",
|
||||||
tile_size = "16x16",
|
types = [
|
||||||
)
|
"f16",
|
||||||
for (type, dtype) in [
|
"f32",
|
||||||
("f16", "DT_HALF"),
|
"f64",
|
||||||
("f32", "DT_FLOAT"),
|
],
|
||||||
("f64", "DT_DOUBLE"),
|
)
|
||||||
]
|
|
||||||
]
|
|
||||||
|
|
||||||
relu_kernel = """
|
relu_kernel = """
|
||||||
func @relu(%arg0: tensor<?xf99>) -> tensor<?xf99> {
|
func @relu(%arg0: tensor<?xf99>) -> tensor<?xf99> {
|
||||||
@ -32,19 +35,17 @@ func @relu(%arg0: tensor<?xf99>) -> tensor<?xf99> {
|
|||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
[
|
gen_kernel_library(
|
||||||
gen_kernel_image_hdr(
|
name = "relu",
|
||||||
name = "relu_{type}_kernel".format(type = type),
|
op = relu_kernel,
|
||||||
op = relu_kernel.replace("f99", type).replace("DT_TYPE", dtype),
|
same_shape = "0,1",
|
||||||
same_shape = "0,1",
|
tile_size = "256",
|
||||||
tile_size = "256",
|
types = [
|
||||||
)
|
"f16",
|
||||||
for (type, dtype) in [
|
"f32",
|
||||||
("f16", "DT_HALF"),
|
"f64",
|
||||||
("f32", "DT_FLOAT"),
|
],
|
||||||
("f64", "DT_DOUBLE"),
|
)
|
||||||
]
|
|
||||||
]
|
|
||||||
|
|
||||||
tanh_kernel = """
|
tanh_kernel = """
|
||||||
func @tanh(%arg0: tensor<?xf99>) -> tensor<?xf99> {
|
func @tanh(%arg0: tensor<?xf99>) -> tensor<?xf99> {
|
||||||
@ -54,14 +55,12 @@ func @tanh(%arg0: tensor<?xf99>) -> tensor<?xf99> {
|
|||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
[
|
gen_kernel_library(
|
||||||
gen_kernel_image_hdr(
|
name = "tanh",
|
||||||
name = "tanh_{type}_kernel".format(type = type),
|
op = tanh_kernel,
|
||||||
op = tanh_kernel.replace("f99", type).replace("DT_TYPE", dtype),
|
tile_size = "256",
|
||||||
tile_size = "256",
|
types = [
|
||||||
)
|
"f32",
|
||||||
for (type, dtype) in [
|
"f64",
|
||||||
("f32", "DT_FLOAT"),
|
],
|
||||||
("f64", "DT_DOUBLE"),
|
)
|
||||||
]
|
|
||||||
]
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
"""Generates cubin headers for TF dialect ops."""
|
"""Generates cubin headers for TF dialect ops."""
|
||||||
|
|
||||||
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_gpu_architectures")
|
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_gpu_architectures", "if_cuda")
|
||||||
|
|
||||||
def _lookup_file(filegroup, path):
|
def _lookup_file(filegroup, path):
|
||||||
"""Extracts file at (relative) path in filegroup."""
|
"""Extracts file at (relative) path in filegroup."""
|
||||||
@ -61,12 +61,12 @@ def _gen_kernel_image_hdr_impl(ctx):
|
|||||||
outputs = [ctx.outputs.out],
|
outputs = [ctx.outputs.out],
|
||||||
inputs = [fatbin],
|
inputs = [fatbin],
|
||||||
tools = [bin2c],
|
tools = [bin2c],
|
||||||
command = "%s --static --const --type=int --name=%s %s 1> %s" %
|
command = "%s --static --const --type=char --name=%s %s 1> %s" %
|
||||||
(bin2c.path, ctx.attr.symbol, fatbin.path, ctx.outputs.out.path),
|
(bin2c.path, ctx.attr.symbol, fatbin.path, ctx.outputs.out.path),
|
||||||
mnemonic = "bin2c",
|
mnemonic = "bin2c",
|
||||||
)
|
)
|
||||||
|
|
||||||
_gen_kernel_image_hdr = rule(
|
_gen_kernel_image_hdr_rule = rule(
|
||||||
implementation = _gen_kernel_image_hdr_impl,
|
implementation = _gen_kernel_image_hdr_impl,
|
||||||
output_to_genfiles = True,
|
output_to_genfiles = True,
|
||||||
attrs = {
|
attrs = {
|
||||||
@ -87,10 +87,10 @@ _gen_kernel_image_hdr = rule(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def gen_kernel_image_hdr(name, op, tile_size, tags = [], same_shape = None):
|
def _gen_kernel_image_hdr(name, op, tile_size, tags = [], same_shape = None):
|
||||||
"""Generates a C header with fatbin data from a Tensorflow op."""
|
"""Generates a C header with fatbin data from a Tensorflow op."""
|
||||||
if cuda_gpu_architectures():
|
if cuda_gpu_architectures():
|
||||||
_gen_kernel_image_hdr(
|
_gen_kernel_image_hdr_rule(
|
||||||
name = name,
|
name = name,
|
||||||
op = op,
|
op = op,
|
||||||
tile_size = tile_size,
|
tile_size = tile_size,
|
||||||
@ -100,3 +100,25 @@ def gen_kernel_image_hdr(name, op, tile_size, tags = [], same_shape = None):
|
|||||||
gpu_archs = cuda_gpu_architectures(),
|
gpu_archs = cuda_gpu_architectures(),
|
||||||
tags = tags,
|
tags = tags,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def gen_kernel_library(name, op, types, tile_size, tags = [], same_shape = None):
|
||||||
|
if cuda_gpu_architectures():
|
||||||
|
type_to_dtype = {
|
||||||
|
"f16": "DT_HALF",
|
||||||
|
"f32": "DT_FLOAT",
|
||||||
|
"f64": "DT_DOUBLE",
|
||||||
|
}
|
||||||
|
for type in types:
|
||||||
|
_gen_kernel_image_hdr(
|
||||||
|
name = "{name}_{type}_kernel".format(name = name, type = type),
|
||||||
|
op = op.replace("f99", type).replace("DT_TYPE", type_to_dtype[type]),
|
||||||
|
tile_size = tile_size,
|
||||||
|
tags = tags,
|
||||||
|
same_shape = same_shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
native.cc_library(
|
||||||
|
name = name + "_kernels",
|
||||||
|
hdrs = if_cuda(if_true = [":{name}_{type}_kernel".format(name = name, type = type) for type in types]),
|
||||||
|
tags = tags,
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user