Add option in kernel_lowering to use the tanh approximation.

PiperOrigin-RevId: 317077338
Change-Id: I5cc7bfb84d6defba05439186377f90e422247d68
This commit is contained in:
Stephan Herhut 2020-06-18 04:33:29 -07:00 committed by TensorFlower Gardener
parent dd49e65c5b
commit de5620b74c
3 changed files with 7 additions and 0 deletions

View File

@ -167,6 +167,7 @@ cc_library(
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine", "//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine",
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_gpu", "//tensorflow/compiler/mlir/xla:lhlo_legalize_to_gpu",
"//tensorflow/compiler/mlir/xla:xla_dialect_registration", "//tensorflow/compiler/mlir/xla:xla_dialect_registration",
"//tensorflow/compiler/mlir/xla:xla_legalize_tanh_to_approximation",
"//tensorflow/compiler/mlir/xla:xla_legalize_to_linalg", "//tensorflow/compiler/mlir/xla:xla_legalize_to_linalg",
"//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:statusor",

View File

@ -505,6 +505,11 @@ Status LowerLHLOToGPU(mlir::ModuleOp module, LowerLHLOToGPUOptions options) {
// Some basic cleanup. // Some basic cleanup.
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass()); pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass()); pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
// Approximate if requested.
if (options.use_approximations) {
pm.addNestedPass<::mlir::FuncOp>(
::mlir::xla::createLegalizeTanhToApproximationPass());
}
// Move scalar operations into the launch to ensure smaller signatures. // Move scalar operations into the launch to ensure smaller signatures.
pm.addPass(absl::make_unique<MoveScalarComputationsIntoGpuLaunch>()); pm.addPass(absl::make_unique<MoveScalarComputationsIntoGpuLaunch>());
// Take launches to launches with kernels. // Take launches to launches with kernels.

View File

@ -28,6 +28,7 @@ struct LowerLHLOToGPUOptions {
llvm::ArrayRef<unsigned> unroll_factors = {}; llvm::ArrayRef<unsigned> unroll_factors = {};
bool collapse_parallel_loops = true; bool collapse_parallel_loops = true;
bool rewrite_signature = true; bool rewrite_signature = true;
bool use_approximations = false;
}; };
Status LowerLHLOToGPU(mlir::ModuleOp module, Status LowerLHLOToGPU(mlir::ModuleOp module,