Generate a cubin header for tanh.
So far, the header is only generated for f32 and f64; f16 doesn't work yet. PiperOrigin-RevId: 312258425 Change-Id: I73c7a58d8fa2ebf02729fe1f7317aabb746fa8b0
This commit is contained in:
parent
8121e42ca4
commit
b93bd76a9f
@ -231,8 +231,14 @@ StatusOr<std::vector<uint8_t>> tensorflow::kernel_gen::GenerateCubinForTfCode(
|
|||||||
xla::mlir_gpu::LowerLHLOToGPU(module.get(), tile_sizes, unroll_factors,
|
xla::mlir_gpu::LowerLHLOToGPU(module.get(), tile_sizes, unroll_factors,
|
||||||
/*collapseParallelLoops=*/false));
|
/*collapseParallelLoops=*/false));
|
||||||
TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get()));
|
TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get()));
|
||||||
TF_RETURN_IF_ERROR(
|
// TODO(b/156985522): Figure out why we get a segfault when generating Tanh
|
||||||
PropagateStaticShapeKnowledgeToKernel(module.get(), same_shape));
|
// with 'same_shape' containing {0, 1}. We would also get the crash if we
|
||||||
|
// unconditionally call PropagateStaticShapeKnowledgeToKernel while
|
||||||
|
// 'same_shape' is empty.
|
||||||
|
if (!same_shape.empty()) {
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
PropagateStaticShapeKnowledgeToKernel(module.get(), same_shape));
|
||||||
|
}
|
||||||
|
|
||||||
mlir::OwningModuleRef kernel_module =
|
mlir::OwningModuleRef kernel_module =
|
||||||
xla::mlir_gpu::ExtractKernelModule(*module).ValueOrDie();
|
xla::mlir_gpu::ExtractKernelModule(*module).ValueOrDie();
|
||||||
|
@ -45,3 +45,23 @@ func @relu(%arg0: tensor<?xf99>) -> tensor<?xf99> {
|
|||||||
("f64", "DT_DOUBLE"),
|
("f64", "DT_DOUBLE"),
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
tanh_kernel = """
|
||||||
|
func @tanh(%arg0: tensor<?xf99>) -> tensor<?xf99> {
|
||||||
|
%0 = "tf.Tanh"(%arg0) { T = "tfdtype$DT_TYPE" }
|
||||||
|
: (tensor<?xf99>) -> tensor<?xf99>
|
||||||
|
return %0 : tensor<?xf99>
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
[
|
||||||
|
gen_kernel_image_hdr(
|
||||||
|
name = "tanh_{type}_kernel".format(type = type),
|
||||||
|
op = tanh_kernel.replace("f99", type).replace("DT_TYPE", dtype),
|
||||||
|
tile_size = "256",
|
||||||
|
)
|
||||||
|
for (type, dtype) in [
|
||||||
|
("f32", "DT_FLOAT"),
|
||||||
|
("f64", "DT_DOUBLE"),
|
||||||
|
]
|
||||||
|
]
|
||||||
|
Loading…
Reference in New Issue
Block a user