Add support for unrolling to the build rules.

PiperOrigin-RevId: 316851326
Change-Id: Icef09025b154f6d88c94884a7970f93022bcd160
This commit is contained in:
Stephan Herhut 2020-06-17 02:32:08 -07:00 committed by TensorFlower Gardener
parent 70a617943c
commit 17cacd3a21

View File

@ -15,9 +15,11 @@ def _gen_kernel_image_hdr_impl(ctx):
name = ctx.attr.name name = ctx.attr.name
tile_sizes = ctx.attr.tile_size.replace("x", ",") tile_sizes = ctx.attr.tile_size.replace("x", ",")
same_shape = [] cmd_args = []
if ctx.attr.same_shape: if ctx.attr.same_shape:
same_shape.append("--same_shape=%s" % ctx.attr.same_shape) cmd_args.append("--same_shape=%s" % ctx.attr.same_shape)
if ctx.attr.unroll_factors:
cmd_args.append("--unroll_factors=%s" % ctx.attr.unroll_factors)
cubins = [] cubins = []
images = [] images = []
@ -30,7 +32,7 @@ def _gen_kernel_image_hdr_impl(ctx):
inputs = [ctx.file.mlir_op], inputs = [ctx.file.mlir_op],
outputs = [cubin], outputs = [cubin],
executable = ctx.executable._tool, executable = ctx.executable._tool,
arguments = same_shape + [ arguments = cmd_args + [
"--tile_sizes=%s" % tile_sizes, "--tile_sizes=%s" % tile_sizes,
"--arch=%s" % arch.split("_")[1], "--arch=%s" % arch.split("_")[1],
"--input=%s" % ctx.file.mlir_op.path, "--input=%s" % ctx.file.mlir_op.path,
@ -74,6 +76,7 @@ _gen_kernel_image_hdr_rule = rule(
"mlir_op": attr.label(mandatory = True, allow_single_file = True), "mlir_op": attr.label(mandatory = True, allow_single_file = True),
"tile_size": attr.string(mandatory = True), "tile_size": attr.string(mandatory = True),
"same_shape": attr.string(), "same_shape": attr.string(),
"unroll_factors": attr.string(),
"out": attr.output(mandatory = True), "out": attr.output(mandatory = True),
"symbol": attr.string(mandatory = True), "symbol": attr.string(mandatory = True),
"gpu_archs": attr.string_list(mandatory = True), "gpu_archs": attr.string_list(mandatory = True),
@ -88,7 +91,7 @@ _gen_kernel_image_hdr_rule = rule(
}, },
) )
def _gen_kernel_image_hdr(name, mlir_op, tile_size, tags = [], same_shape = None): def _gen_kernel_image_hdr(name, mlir_op, tile_size, tags = [], same_shape = None, unroll_factors = None):
"""Generates a C header with fatbin data from a Tensorflow op.""" """Generates a C header with fatbin data from a Tensorflow op."""
if cuda_gpu_architectures(): if cuda_gpu_architectures():
_gen_kernel_image_hdr_rule( _gen_kernel_image_hdr_rule(
@ -96,6 +99,7 @@ def _gen_kernel_image_hdr(name, mlir_op, tile_size, tags = [], same_shape = None
mlir_op = mlir_op, mlir_op = mlir_op,
tile_size = tile_size, tile_size = tile_size,
same_shape = same_shape, same_shape = same_shape,
unroll_factors = unroll_factors,
out = "%s.h" % name, out = "%s.h" % name,
symbol = "k%s" % name.replace("_", " ").title().replace(" ", ""), symbol = "k%s" % name.replace("_", " ").title().replace(" ", ""),
gpu_archs = cuda_gpu_architectures(), gpu_archs = cuda_gpu_architectures(),
@ -131,13 +135,14 @@ def _gen_mlir_op(name, type):
out = "{name}_{type}.mlir".format(name = name, type = type), out = "{name}_{type}.mlir".format(name = name, type = type),
) )
def gen_kernel_library(name, types, tile_size, tags = [], same_shape = None): def gen_kernel_library(name, types, tile_size, tags = [], same_shape = None, unroll_factors = None):
""" Generate a library with kernels for a specific tensorflow op. """ Generate a library with kernels for a specific tensorflow op.
Args: Args:
name: The name of the tensorflow op. name: The name of the tensorflow op.
types: The types ("f16", "f32", "f64") for which a kernel should be generated. types: The types ("f16", "f32", "f64") for which a kernel should be generated.
tile_size: The tiling specification, e.g. "16x16". tile_size: The tiling specification, e.g. "16x16".
unroll_factors: The unrolling specification, e.g. "4,4"
tags: The tags which should be added to the library. tags: The tags which should be added to the library.
same_shape: The information about which shapes are the same, e.g. "0,1". same_shape: The information about which shapes are the same, e.g. "0,1".
""" """
@ -154,6 +159,7 @@ def gen_kernel_library(name, types, tile_size, tags = [], same_shape = None):
tile_size = tile_size, tile_size = tile_size,
tags = tags, tags = tags,
same_shape = same_shape, same_shape = same_shape,
unroll_factors = unroll_factors,
) )
native.cc_library( native.cc_library(