Generate a cubin header for tanh.

So far, we only generate it for f32 and f64; f16 doesn't work yet.

PiperOrigin-RevId: 312258425
Change-Id: I73c7a58d8fa2ebf02729fe1f7317aabb746fa8b0
This commit is contained in:
Adrian Kuegel 2020-05-19 05:15:31 -07:00 committed by TensorFlower Gardener
parent 8121e42ca4
commit b93bd76a9f
2 changed files with 28 additions and 2 deletions

View File

@ -231,8 +231,14 @@ StatusOr<std::vector<uint8_t>> tensorflow::kernel_gen::GenerateCubinForTfCode(
xla::mlir_gpu::LowerLHLOToGPU(module.get(), tile_sizes, unroll_factors,
/*collapseParallelLoops=*/false));
TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get()));
TF_RETURN_IF_ERROR(
PropagateStaticShapeKnowledgeToKernel(module.get(), same_shape));
// TODO(b/156985522): Figure out why we get a segfault when generating Tanh
// with 'same_shape' containing {0, 1}. We would also get the crash if we
// unconditionally call PropagateStaticShapeKnowledgeToKernel while
// 'same_shape' is empty.
if (!same_shape.empty()) {
TF_RETURN_IF_ERROR(
PropagateStaticShapeKnowledgeToKernel(module.get(), same_shape));
}
mlir::OwningModuleRef kernel_module =
xla::mlir_gpu::ExtractKernelModule(*module).ValueOrDie();

View File

@ -45,3 +45,23 @@ func @relu(%arg0: tensor<?xf99>) -> tensor<?xf99> {
("f64", "DT_DOUBLE"),
]
]
# MLIR source template for a tf.Tanh kernel. The placeholders "f99"
# (element type) and "DT_TYPE" (TF dtype attribute) are substituted per
# supported type by the comprehension below, mirroring the relu kernels above.
tanh_kernel = """
func @tanh(%arg0: tensor<?xf99>) -> tensor<?xf99> {
%0 = "tf.Tanh"(%arg0) { T = "tfdtype$DT_TYPE" }
: (tensor<?xf99>) -> tensor<?xf99>
return %0 : tensor<?xf99>
}
"""
# Instantiate one cubin-header target per supported element type
# (tanh_f32_kernel, tanh_f64_kernel). f16 is intentionally absent — it is
# not yet supported (see commit description).
[
gen_kernel_image_hdr(
name = "tanh_{type}_kernel".format(type = type),
# Splice the concrete element type and TF dtype into the MLIR template.
op = tanh_kernel.replace("f99", type).replace("DT_TYPE", dtype),
# NOTE(review): tile_size "256" matches the other kernels in this file —
# presumably the CUDA block size; confirm against gen_kernel_image_hdr.
tile_size = "256",
)
for (type, dtype) in [
("f32", "DT_FLOAT"),
("f64", "DT_DOUBLE"),
]
]