Generate a cubin header for tanh.
So far, the header is only generated for f32 and f64; f16 doesn't work yet.

PiperOrigin-RevId: 312258425
Change-Id: I73c7a58d8fa2ebf02729fe1f7317aabb746fa8b0
parent 8121e42ca4
commit b93bd76a9f
@@ -231,8 +231,14 @@ StatusOr<std::vector<uint8_t>> tensorflow::kernel_gen::GenerateCubinForTfCode(
       xla::mlir_gpu::LowerLHLOToGPU(module.get(), tile_sizes, unroll_factors,
                                     /*collapseParallelLoops=*/false));
   TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get()));
-  TF_RETURN_IF_ERROR(
-      PropagateStaticShapeKnowledgeToKernel(module.get(), same_shape));
+  // TODO(b/156985522): Figure out why we get a segfault when generating Tanh
+  // with 'same_shape' containing {0, 1}. We would also get the crash if we
+  // unconditionally call PropagateStaticShapeKnowledgeToKernel while
+  // 'same_shape' is empty.
+  if (!same_shape.empty()) {
+    TF_RETURN_IF_ERROR(
+        PropagateStaticShapeKnowledgeToKernel(module.get(), same_shape));
+  }
 
   mlir::OwningModuleRef kernel_module =
       xla::mlir_gpu::ExtractKernelModule(*module).ValueOrDie();
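For context on the guard above: 'same_shape' lists the indices of kernel arguments that are known to share a shape. For a unary elementwise op such as tanh those would be the input buffer (argument 0) and the output buffer (argument 1), which is the {0, 1} mentioned in the TODO. The following is a loose, standalone Python sketch of that invariant only; it is not the C++ pass, and the helper name and shapes are made up for illustration.

# Loose illustration only: same_shape = [0, 1] asserts that kernel arguments
# 0 and 1 (tanh's input and output buffers) have equal shapes, so the kernel
# can rely on a single shape description for both.
# The function and data below are hypothetical, not TensorFlow APIs.
def check_same_shape(arg_shapes, same_shape):
    """Verify that all arguments listed in same_shape share one shape."""
    shapes = [arg_shapes[i] for i in same_shape]
    return all(s == shapes[0] for s in shapes)

if __name__ == "__main__":
    # For tanh: argument 0 is the input tensor, argument 1 the result.
    arg_shapes = {0: (1024,), 1: (1024,)}
    assert check_same_shape(arg_shapes, same_shape=[0, 1])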
@@ -45,3 +45,23 @@ func @relu(%arg0: tensor<?xf99>) -> tensor<?xf99> {
         ("f64", "DT_DOUBLE"),
     ]
 ]
+
+tanh_kernel = """
+func @tanh(%arg0: tensor<?xf99>) -> tensor<?xf99> {
+  %0 = "tf.Tanh"(%arg0) { T = "tfdtype$DT_TYPE" }
+    : (tensor<?xf99>) -> tensor<?xf99>
+  return %0 : tensor<?xf99>
+}
+"""
+
+[
+    gen_kernel_image_hdr(
+        name = "tanh_{type}_kernel".format(type = type),
+        op = tanh_kernel.replace("f99", type).replace("DT_TYPE", dtype),
+        tile_size = "256",
+    )
+    for (type, dtype) in [
+        ("f32", "DT_FLOAT"),
+        ("f64", "DT_DOUBLE"),
+    ]
+]
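To see what the op attribute expands to, here is a small standalone Python snippet (not part of the commit) that performs the same .replace substitutions as the list comprehension above and prints the MLIR handed to gen_kernel_image_hdr for each type:

# Standalone illustration: apply the same string substitutions as the
# BUILD-file list comprehension to show the MLIR passed as `op`.
tanh_kernel = """
func @tanh(%arg0: tensor<?xf99>) -> tensor<?xf99> {
  %0 = "tf.Tanh"(%arg0) { T = "tfdtype$DT_TYPE" }
    : (tensor<?xf99>) -> tensor<?xf99>
  return %0 : tensor<?xf99>
}
"""

for type_, dtype in [("f32", "DT_FLOAT"), ("f64", "DT_DOUBLE")]:
    op = tanh_kernel.replace("f99", type_).replace("DT_TYPE", dtype)
    print("// tanh_{}_kernel".format(type_))
    print(op)

For f32, for example, the op becomes a func @tanh over tensor<?xf32> with T = "tfdtype$DT_FLOAT"; that TF-dialect function is what gets lowered to a cubin and wrapped in the generated header named tanh_f32_kernel.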