Add support for unrolling to the build rules.
PiperOrigin-RevId: 316851326 Change-Id: Icef09025b154f6d88c94884a7970f93022bcd160
This commit is contained in:
parent
70a617943c
commit
17cacd3a21
@ -15,9 +15,11 @@ def _gen_kernel_image_hdr_impl(ctx):
|
|||||||
|
|
||||||
name = ctx.attr.name
|
name = ctx.attr.name
|
||||||
tile_sizes = ctx.attr.tile_size.replace("x", ",")
|
tile_sizes = ctx.attr.tile_size.replace("x", ",")
|
||||||
same_shape = []
|
cmd_args = []
|
||||||
if ctx.attr.same_shape:
|
if ctx.attr.same_shape:
|
||||||
same_shape.append("--same_shape=%s" % ctx.attr.same_shape)
|
cmd_args.append("--same_shape=%s" % ctx.attr.same_shape)
|
||||||
|
if ctx.attr.unroll_factors:
|
||||||
|
cmd_args.append("--unroll_factors=%s" % ctx.attr.unroll_factors)
|
||||||
|
|
||||||
cubins = []
|
cubins = []
|
||||||
images = []
|
images = []
|
||||||
@ -30,7 +32,7 @@ def _gen_kernel_image_hdr_impl(ctx):
|
|||||||
inputs = [ctx.file.mlir_op],
|
inputs = [ctx.file.mlir_op],
|
||||||
outputs = [cubin],
|
outputs = [cubin],
|
||||||
executable = ctx.executable._tool,
|
executable = ctx.executable._tool,
|
||||||
arguments = same_shape + [
|
arguments = cmd_args + [
|
||||||
"--tile_sizes=%s" % tile_sizes,
|
"--tile_sizes=%s" % tile_sizes,
|
||||||
"--arch=%s" % arch.split("_")[1],
|
"--arch=%s" % arch.split("_")[1],
|
||||||
"--input=%s" % ctx.file.mlir_op.path,
|
"--input=%s" % ctx.file.mlir_op.path,
|
||||||
@ -74,6 +76,7 @@ _gen_kernel_image_hdr_rule = rule(
|
|||||||
"mlir_op": attr.label(mandatory = True, allow_single_file = True),
|
"mlir_op": attr.label(mandatory = True, allow_single_file = True),
|
||||||
"tile_size": attr.string(mandatory = True),
|
"tile_size": attr.string(mandatory = True),
|
||||||
"same_shape": attr.string(),
|
"same_shape": attr.string(),
|
||||||
|
"unroll_factors": attr.string(),
|
||||||
"out": attr.output(mandatory = True),
|
"out": attr.output(mandatory = True),
|
||||||
"symbol": attr.string(mandatory = True),
|
"symbol": attr.string(mandatory = True),
|
||||||
"gpu_archs": attr.string_list(mandatory = True),
|
"gpu_archs": attr.string_list(mandatory = True),
|
||||||
@ -88,7 +91,7 @@ _gen_kernel_image_hdr_rule = rule(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def _gen_kernel_image_hdr(name, mlir_op, tile_size, tags = [], same_shape = None):
|
def _gen_kernel_image_hdr(name, mlir_op, tile_size, tags = [], same_shape = None, unroll_factors = None):
|
||||||
"""Generates a C header with fatbin data from a Tensorflow op."""
|
"""Generates a C header with fatbin data from a Tensorflow op."""
|
||||||
if cuda_gpu_architectures():
|
if cuda_gpu_architectures():
|
||||||
_gen_kernel_image_hdr_rule(
|
_gen_kernel_image_hdr_rule(
|
||||||
@ -96,6 +99,7 @@ def _gen_kernel_image_hdr(name, mlir_op, tile_size, tags = [], same_shape = None
|
|||||||
mlir_op = mlir_op,
|
mlir_op = mlir_op,
|
||||||
tile_size = tile_size,
|
tile_size = tile_size,
|
||||||
same_shape = same_shape,
|
same_shape = same_shape,
|
||||||
|
unroll_factors = unroll_factors,
|
||||||
out = "%s.h" % name,
|
out = "%s.h" % name,
|
||||||
symbol = "k%s" % name.replace("_", " ").title().replace(" ", ""),
|
symbol = "k%s" % name.replace("_", " ").title().replace(" ", ""),
|
||||||
gpu_archs = cuda_gpu_architectures(),
|
gpu_archs = cuda_gpu_architectures(),
|
||||||
@ -131,13 +135,14 @@ def _gen_mlir_op(name, type):
|
|||||||
out = "{name}_{type}.mlir".format(name = name, type = type),
|
out = "{name}_{type}.mlir".format(name = name, type = type),
|
||||||
)
|
)
|
||||||
|
|
||||||
def gen_kernel_library(name, types, tile_size, tags = [], same_shape = None):
|
def gen_kernel_library(name, types, tile_size, tags = [], same_shape = None, unroll_factors = None):
|
||||||
""" Generate a library with kernels for a specific tensorflow op.
|
""" Generate a library with kernels for a specific tensorflow op.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: The name of the tensorflow op.
|
name: The name of the tensorflow op.
|
||||||
types: The types ("f16", "f32", "f64") for which a kernel should be generated.
|
types: The types ("f16", "f32", "f64") for which a kernel should be generated.
|
||||||
tile_size: The tiling specification, e.g. "16x16".
|
tile_size: The tiling specification, e.g. "16x16".
|
||||||
|
unroll_factors: The unrolling specification, e.g. "4,4"
|
||||||
tags: The tags which should be added to the library.
|
tags: The tags which should be added to the library.
|
||||||
same_shape: The information about which shapes are the same, e.g. "0,1".
|
same_shape: The information about which shapes are the same, e.g. "0,1".
|
||||||
"""
|
"""
|
||||||
@ -154,6 +159,7 @@ def gen_kernel_library(name, types, tile_size, tags = [], same_shape = None):
|
|||||||
tile_size = tile_size,
|
tile_size = tile_size,
|
||||||
tags = tags,
|
tags = tags,
|
||||||
same_shape = same_shape,
|
same_shape = same_shape,
|
||||||
|
unroll_factors = unroll_factors,
|
||||||
)
|
)
|
||||||
|
|
||||||
native.cc_library(
|
native.cc_library(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user