From b93bd76a9f5025ec42b6b9a2ca4a26562b49c405 Mon Sep 17 00:00:00 2001
From: Adrian Kuegel
Date: Tue, 19 May 2020 05:15:31 -0700
Subject: [PATCH] Generate a cubin header for tanh.

So far, we only generate it for f32 and f64; f16 doesn't work yet.

PiperOrigin-RevId: 312258425
Change-Id: I73c7a58d8fa2ebf02729fe1f7317aabb746fa8b0
---
 .../mlir/tools/kernel_gen/cubin_creator.cc  | 10 ++++++++--
 tensorflow/core/kernels/cubin_headers/BUILD | 20 +++++++++++++++++++
 2 files changed, 28 insertions(+), 2 deletions(-)

diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc
index b1c4b1beae1..f47485d0214 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc
@@ -231,8 +231,14 @@ StatusOr<std::vector<uint8>> tensorflow::kernel_gen::GenerateCubinForTfCode(
       xla::mlir_gpu::LowerLHLOToGPU(module.get(), tile_sizes, unroll_factors,
                                     /*collapseParallelLoops=*/false));
   TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get()));
-  TF_RETURN_IF_ERROR(
-      PropagateStaticShapeKnowledgeToKernel(module.get(), same_shape));
+  // TODO(b/156985522): Figure out why we get a segfault when generating Tanh
+  // with 'same_shape' containing {0, 1}. We would also get the crash if we
+  // unconditionally call PropagateStaticShapeKnowledgeToKernel while
+  // 'same_shape' is empty.
+  if (!same_shape.empty()) {
+    TF_RETURN_IF_ERROR(
+        PropagateStaticShapeKnowledgeToKernel(module.get(), same_shape));
+  }
 
   mlir::OwningModuleRef kernel_module =
       xla::mlir_gpu::ExtractKernelModule(*module).ValueOrDie();
diff --git a/tensorflow/core/kernels/cubin_headers/BUILD b/tensorflow/core/kernels/cubin_headers/BUILD
index bb7995dd221..509ac008355 100644
--- a/tensorflow/core/kernels/cubin_headers/BUILD
+++ b/tensorflow/core/kernels/cubin_headers/BUILD
@@ -45,3 +45,23 @@ func @relu(%arg0: tensor<?xf99>) -> tensor<?xf99> {
         ("f64", "DT_DOUBLE"),
     ]
 ]
+
+tanh_kernel = """
+func @tanh(%arg0: tensor<?xf99>) -> tensor<?xf99> {
+  %0 = "tf.Tanh"(%arg0) { T = "tfdtype$DT_TYPE" }
+    : (tensor<?xf99>) -> tensor<?xf99>
+  return %0 : tensor<?xf99>
+}
+"""
+
+[
+    gen_kernel_image_hdr(
+        name = "tanh_{type}_kernel".format(type = type),
+        op = tanh_kernel.replace("f99", type).replace("DT_TYPE", dtype),
+        tile_size = "256",
+    )
+    for (type, dtype) in [
+        ("f32", "DT_FLOAT"),
+        ("f64", "DT_DOUBLE"),
+    ]
+]
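
As an illustration of the templating in the BUILD rule above: Starlark's
string.replace has the same semantics as Python's str.replace, so the op
attribute handed to gen_kernel_image_hdr for each (type, dtype) entry can be
reproduced with a small standalone Python sketch (illustration only, not part
of the patch or the build):

# Standalone sketch: reproduce the `op` strings that the BUILD list
# comprehension passes to gen_kernel_image_hdr. Starlark's string.replace
# behaves like Python's str.replace.

tanh_kernel = """
func @tanh(%arg0: tensor<?xf99>) -> tensor<?xf99> {
  %0 = "tf.Tanh"(%arg0) { T = "tfdtype$DT_TYPE" }
    : (tensor<?xf99>) -> tensor<?xf99>
  return %0 : tensor<?xf99>
}
"""

for ty, dtype in [("f32", "DT_FLOAT"), ("f64", "DT_DOUBLE")]:
    # For f32 this yields a func @tanh on tensor<?xf32> operands with
    # T = "tfdtype$DT_FLOAT"; analogously for f64.
    print(tanh_kernel.replace("f99", ty).replace("DT_TYPE", dtype))

The f99/DT_TYPE placeholders let one kernel template serve every entry in the
instantiation list; per the commit message, an f16 entry is omitted because
f16 doesn't work yet.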