diff --git a/tensorflow/core/kernels/cubin_headers/build_defs.bzl b/tensorflow/core/kernels/cubin_headers/build_defs.bzl index f0f4a944e74..5880cbe8add 100644 --- a/tensorflow/core/kernels/cubin_headers/build_defs.bzl +++ b/tensorflow/core/kernels/cubin_headers/build_defs.bzl @@ -15,9 +15,11 @@ def _gen_kernel_image_hdr_impl(ctx): name = ctx.attr.name tile_sizes = ctx.attr.tile_size.replace("x", ",") - same_shape = [] + cmd_args = [] if ctx.attr.same_shape: - same_shape.append("--same_shape=%s" % ctx.attr.same_shape) + cmd_args.append("--same_shape=%s" % ctx.attr.same_shape) + if ctx.attr.unroll_factors: + cmd_args.append("--unroll_factors=%s" % ctx.attr.unroll_factors) cubins = [] images = [] @@ -30,7 +32,7 @@ def _gen_kernel_image_hdr_impl(ctx): inputs = [ctx.file.mlir_op], outputs = [cubin], executable = ctx.executable._tool, - arguments = same_shape + [ + arguments = cmd_args + [ "--tile_sizes=%s" % tile_sizes, "--arch=%s" % arch.split("_")[1], "--input=%s" % ctx.file.mlir_op.path, @@ -74,6 +76,7 @@ _gen_kernel_image_hdr_rule = rule( "mlir_op": attr.label(mandatory = True, allow_single_file = True), "tile_size": attr.string(mandatory = True), "same_shape": attr.string(), + "unroll_factors": attr.string(), "out": attr.output(mandatory = True), "symbol": attr.string(mandatory = True), "gpu_archs": attr.string_list(mandatory = True), @@ -88,7 +91,7 @@ _gen_kernel_image_hdr_rule = rule( }, ) -def _gen_kernel_image_hdr(name, mlir_op, tile_size, tags = [], same_shape = None): +def _gen_kernel_image_hdr(name, mlir_op, tile_size, tags = [], same_shape = None, unroll_factors = None): """Generates a C header with fatbin data from a Tensorflow op.""" if cuda_gpu_architectures(): _gen_kernel_image_hdr_rule( @@ -96,6 +99,7 @@ def _gen_kernel_image_hdr(name, mlir_op, tile_size, tags = [], same_shape = None mlir_op = mlir_op, tile_size = tile_size, same_shape = same_shape, + unroll_factors = unroll_factors, out = "%s.h" % name, symbol = "k%s" % name.replace("_", " ").title().replace(" ", ""), gpu_archs = cuda_gpu_architectures(), @@ -131,13 +135,14 @@ def _gen_mlir_op(name, type): out = "{name}_{type}.mlir".format(name = name, type = type), ) -def gen_kernel_library(name, types, tile_size, tags = [], same_shape = None): +def gen_kernel_library(name, types, tile_size, tags = [], same_shape = None, unroll_factors = None): """ Generate a library with kernels for a specific tensorflow op. Args: name: The name of the tensorflow op. types: The types ("f16", "f32", "f64") for which a kernel should be generated. tile_size: The tiling specification, e.g. "16x16". + unroll_factors: The unrolling specification, e.g. "4,4" tags: The tags which should be added to the library. same_shape: The information about which shapes are the same, e.g. "0,1". """ @@ -154,6 +159,7 @@ def gen_kernel_library(name, types, tile_size, tags = [], same_shape = None): tile_size = tile_size, tags = tags, same_shape = same_shape, + unroll_factors = unroll_factors, ) native.cc_library(