Fix sin kernel for ROCM.

It should be lowered to an intrinsic call.
Re-enable generating a GPU kernel for sin, now that the
problem is fixed.

PiperOrigin-RevId: 342243974
Change-Id: If151fedbdde3146c01ee5f80f002e224d9e8e3b6
This commit is contained in:
Adrian Kuegel 2020-11-13 05:24:45 -08:00 committed by TensorFlower Gardener
parent 8a1bb87da5
commit 4d27598a22
2 changed files with 5 additions and 19 deletions

View File

@ -236,10 +236,10 @@ class LowerToROCDLPass
::mlir::ConversionTarget target(getContext()); ::mlir::ConversionTarget target(getContext());
target.addIllegalDialect<::mlir::gpu::GPUDialect>(); target.addIllegalDialect<::mlir::gpu::GPUDialect>();
target target.addIllegalOp<mlir::LLVM::CosOp, mlir::LLVM::ExpOp,
.addIllegalOp<mlir::LLVM::CosOp, mlir::LLVM::ExpOp, mlir::LLVM::FAbsOp, mlir::LLVM::FAbsOp, mlir::LLVM::FCeilOp,
mlir::LLVM::FCeilOp, mlir::LLVM::LogOp, mlir::LLVM::LogOp, mlir::LLVM::Log10Op,
mlir::LLVM::Log10Op, mlir::LLVM::Log2Op>(); mlir::LLVM::Log2Op, mlir::LLVM::SinOp>();
target.addIllegalOp<mlir::FuncOp>(); target.addIllegalOp<mlir::FuncOp>();
target.addLegalDialect<::mlir::LLVM::LLVMDialect>(); target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>(); target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();

View File

@ -302,21 +302,6 @@ gen_kernel_library(
unroll_factors = "4", unroll_factors = "4",
) )
# Temporarily disabled due to failure on ROCM, fix will come soon.
# gen_kernel_library(
# name = "sin",
# generate_ranked = False,
# generate_unranked = True,
# tags = ["no_rocm"],
# tile_size = "256",
# types = [
# "f16",
# "f32",
# "f64",
# ],
# unroll_factors = "4",
# )
gen_kernel_library( gen_kernel_library(
name = "addv2", name = "addv2",
generate_ranked = False, generate_ranked = False,
@ -352,6 +337,7 @@ gen_kernel_library(
"log", "log",
"neg", "neg",
"rsqrt", "rsqrt",
"sin",
"sqrt", "sqrt",
"tanh", "tanh",
] ]