From 7dd629d9a15d444e57a9c4a721f2e94328080e38 Mon Sep 17 00:00:00 2001 From: Tres Popp Date: Fri, 2 Oct 2020 04:18:04 -0700 Subject: [PATCH] Add applyTensorflowAndCLOptions to handle pass manager setup This handles Tensorflow wide configurations along with calling mlir's PassManager's command line parsing and setup. PiperOrigin-RevId: 335000457 Change-Id: I06d3bf1e4f8d7a79959b132e04ead0f14094cabe --- .../mlir/tensorflow/c/c_api_unified_experimental_mlir.cc | 2 +- tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc | 2 +- .../tensorflow/transforms/graph_optimization_pass.cc | 2 +- .../compiler/mlir/tensorflow/utils/compile_mlir_util.cc | 4 ++-- .../compiler/mlir/tensorflow/utils/dump_mlir_util.cc | 6 ++++++ .../compiler/mlir/tensorflow/utils/dump_mlir_util.h | 9 +++++++++ tensorflow/compiler/mlir/tools/kernel_gen/BUILD | 1 + .../compiler/mlir/tools/kernel_gen/kernel_creator.cc | 8 ++------ tensorflow/compiler/xla/service/mlir_gpu/BUILD | 1 + .../compiler/xla/service/mlir_gpu/kernel_lowering.cc | 7 ++++--- 10 files changed, 28 insertions(+), 14 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc index f97f59c67f8..32c51f2e2bd 100644 --- a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc +++ b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc @@ -512,7 +512,7 @@ Status MlirFunction::GetFunctionDef(tensorflow::FunctionDef** f) { return Status::OK(); } PassManager pm(func_.getContext()); - ::tensorflow::SetCrashReproducer(pm); + ::tensorflow::applyTensorflowAndCLOptions(pm); pm.addNestedPass(CreateFunctionalToExecutorDialectConversionPass()); pm.addPass(CreateBreakUpIslandsPass()); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc index df311712242..2e1e63eb143 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc @@ -58,7 +58,7 @@ tensorflow::Status RunTPUBridge( ModuleOp module, bool enable_logging, llvm::function_ref pipeline_builder) { PassManager bridge(module.getContext()); - ::tensorflow::SetCrashReproducer(bridge); + ::tensorflow::applyTensorflowAndCLOptions(bridge); if (enable_logging) EnableLogging(&bridge); // Populate a passmanager with the list of passes that implement the bridge. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc index 2677f5bdfac..a18d893fac7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc @@ -40,7 +40,7 @@ Status MlirGraphOptimizationPass::Run(const ConfigProto& config_proto, VLOG(1) << "Run MLIR Graph Optimization Passes"; PassManager pm(module.getContext()); - ::tensorflow::SetCrashReproducer(pm); + ::tensorflow::applyTensorflowAndCLOptions(pm); // Run island coarsening before shape inference to allow more exact shape // inference using constant folding within islands. diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index 824828052c0..2b9c7ed9a6f 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -324,7 +324,7 @@ Status ConvertMLIRToXlaComputation( llvm::MutableArrayRef> custom_legalization_passes) { mlir::PassManager tf2xla(module_op.getContext()); - SetCrashReproducer(tf2xla); + applyTensorflowAndCLOptions(tf2xla); CreateConvertMlirToXlaHloPipeline(tf2xla, device_type, custom_legalization_passes); @@ -513,7 +513,7 @@ Status CompileGraphToXlaHlo( } mlir::PassManager pm(module_op.getContext()); - SetCrashReproducer(pm); + applyTensorflowAndCLOptions(pm); mlir::TF::StandardPipelineOptions tf_options; mlir::TF::CreateTFStandardPipeline(pm, tf_options); { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc index b9d26c9849a..6c1cab435d3 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc @@ -227,4 +227,10 @@ void SetCrashReproducer(mlir::PassManager& pm, llvm::StringRef dir_path) { pm.enableCrashReproducerGeneration(path, /*genLocalReproducer=*/false); } +void applyTensorflowAndCLOptions(mlir::PassManager& pm, + llvm::StringRef dir_path) { + mlir::applyPassManagerCLOptions(pm); + SetCrashReproducer(pm, dir_path); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h index 55eaeadc43a..133285864f6 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h @@ -72,6 +72,15 @@ std::string DumpRawStringToFile(llvm::StringRef name, llvm::StringRef content, // Files" by looking up the environment to infer the directory path. void SetCrashReproducer(mlir::PassManager& pm, llvm::StringRef dir_path = ""); +// This applies both the PassManagerCLOptions provided by MLIR along with any +// tensorflow specific options. +// +// Note that this function should be in a more appropriate file, but it is +// unclear what a proper file would be as no other functions would currently be +// in the file also. +void applyTensorflowAndCLOptions(mlir::PassManager& pm, + llvm::StringRef dir_path = ""); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DUMP_MLIR_UTIL_H_ diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD index 2c59cd880e2..619a56cb6a9 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD @@ -44,6 +44,7 @@ cc_library( "//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_gpu", "//tensorflow/compiler/mlir/hlo:transform_unranked_hlo", # buildcleaner: keep "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes", "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", "//tensorflow/compiler/mlir/tools/kernel_gen/transforms:passes", "//tensorflow/compiler/mlir/xla:xla_legalize_tf", diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc index 24d96230c54..48696f6e8b0 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc @@ -69,9 +69,7 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only, llvm::ArrayRef tile_sizes, llvm::ArrayRef unroll_factors) { mlir::PassManager pm(module.getContext()); - applyPassManagerCLOptions(pm); - // TODO(b/169357508): renable when pipeline is serializable. - // SetCrashReproducer(pm); + applyTensorflowAndCLOptions(pm); pm.addPass(mlir::mhlo::createLegalizeTFPass(false)); if (gpu_binary_only) { @@ -178,9 +176,7 @@ Status LowerGPUToLLVM(mlir::ModuleOp module, bool gpu_binary_only, llvm::StringRef gpu_binary_attr_name, int32_t architecture) { mlir::PassManager pm(module.getContext()); - applyPassManagerCLOptions(pm); - // TODO(b/169357508): renable when pipeline is serializable. - // SetCrashReproducer(pm); + applyTensorflowAndCLOptions(pm); auto& kernel_pm = pm.nest(); if (gpu_binary_only) { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/BUILD index 954fb1360fd..456d0c9fab5 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/BUILD @@ -210,6 +210,7 @@ cc_library( "//tensorflow/compiler/mlir/hlo:lhlo_fuse_linalg", "//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_affine", "//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_gpu", + "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc index db03aaeb7d1..bb8a990fa6d 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/xla/service/mlir_gpu/passes.h" #include "tensorflow/compiler/xla/util.h" @@ -48,7 +49,7 @@ namespace mlir_gpu { Status LowerLHLOToGPU(mlir::ModuleOp module, LowerLHLOToGPUOptions options) { mlir::PassManager pm(module.getContext()); - applyPassManagerCLOptions(pm); + tensorflow::applyTensorflowAndCLOptions(pm); // We have to anticipate later unrolling in tiling to make sure that we get // the requested tiling after unrolling. Compute the new tiling here if @@ -181,7 +182,7 @@ class LowerToNVVMPass Status LowerKernelBodiesToNVVM(mlir::ModuleOp module) { // We cannot verify as the signature of the kernel is rewritten. ::mlir::PassManager pm(module.getContext(), /*verifyPasses=*/false); - applyPassManagerCLOptions(pm); + tensorflow::applyTensorflowAndCLOptions(pm); // Rewrite kernel functions to LLVM IR. auto& kernelPm = pm.nest<::mlir::gpu::GPUModuleOp>(); @@ -250,7 +251,7 @@ class LowerToROCDLPass Status LowerKernelBodiesToROCDL(mlir::ModuleOp module) { // We cannot verify as the signature of the kernel is rewritten. ::mlir::PassManager pm(module.getContext(), /*verifyPasses=*/false); - applyPassManagerCLOptions(pm); + tensorflow::applyTensorflowAndCLOptions(pm); auto enable_if_vlog_is_on = [](mlir::Pass*, mlir::Operation*) { return VLOG_IS_ON(1);