Add applyTensorflowAndCLOptions to handle pass manager setup
This handles Tensorflow wide configurations along with calling mlir's PassManager's command line parsing and setup. PiperOrigin-RevId: 335000457 Change-Id: I06d3bf1e4f8d7a79959b132e04ead0f14094cabe
This commit is contained in:
parent
fed53e76a7
commit
7dd629d9a1
@ -512,7 +512,7 @@ Status MlirFunction::GetFunctionDef(tensorflow::FunctionDef** f) {
|
||||
return Status::OK();
|
||||
}
|
||||
PassManager pm(func_.getContext());
|
||||
::tensorflow::SetCrashReproducer(pm);
|
||||
::tensorflow::applyTensorflowAndCLOptions(pm);
|
||||
pm.addNestedPass<FuncOp>(CreateFunctionalToExecutorDialectConversionPass());
|
||||
pm.addPass(CreateBreakUpIslandsPass());
|
||||
|
||||
|
@ -58,7 +58,7 @@ tensorflow::Status RunTPUBridge(
|
||||
ModuleOp module, bool enable_logging,
|
||||
llvm::function_ref<void(OpPassManager &pm)> pipeline_builder) {
|
||||
PassManager bridge(module.getContext());
|
||||
::tensorflow::SetCrashReproducer(bridge);
|
||||
::tensorflow::applyTensorflowAndCLOptions(bridge);
|
||||
if (enable_logging) EnableLogging(&bridge);
|
||||
|
||||
// Populate a passmanager with the list of passes that implement the bridge.
|
||||
|
@ -40,7 +40,7 @@ Status MlirGraphOptimizationPass::Run(const ConfigProto& config_proto,
|
||||
|
||||
VLOG(1) << "Run MLIR Graph Optimization Passes";
|
||||
PassManager pm(module.getContext());
|
||||
::tensorflow::SetCrashReproducer(pm);
|
||||
::tensorflow::applyTensorflowAndCLOptions(pm);
|
||||
|
||||
// Run island coarsening before shape inference to allow more exact shape
|
||||
// inference using constant folding within islands.
|
||||
|
@ -324,7 +324,7 @@ Status ConvertMLIRToXlaComputation(
|
||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||
custom_legalization_passes) {
|
||||
mlir::PassManager tf2xla(module_op.getContext());
|
||||
SetCrashReproducer(tf2xla);
|
||||
applyTensorflowAndCLOptions(tf2xla);
|
||||
CreateConvertMlirToXlaHloPipeline(tf2xla, device_type,
|
||||
custom_legalization_passes);
|
||||
|
||||
@ -513,7 +513,7 @@ Status CompileGraphToXlaHlo(
|
||||
}
|
||||
|
||||
mlir::PassManager pm(module_op.getContext());
|
||||
SetCrashReproducer(pm);
|
||||
applyTensorflowAndCLOptions(pm);
|
||||
mlir::TF::StandardPipelineOptions tf_options;
|
||||
mlir::TF::CreateTFStandardPipeline(pm, tf_options);
|
||||
{
|
||||
|
@ -227,4 +227,10 @@ void SetCrashReproducer(mlir::PassManager& pm, llvm::StringRef dir_path) {
|
||||
pm.enableCrashReproducerGeneration(path, /*genLocalReproducer=*/false);
|
||||
}
|
||||
|
||||
void applyTensorflowAndCLOptions(mlir::PassManager& pm,
|
||||
llvm::StringRef dir_path) {
|
||||
mlir::applyPassManagerCLOptions(pm);
|
||||
SetCrashReproducer(pm, dir_path);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -72,6 +72,15 @@ std::string DumpRawStringToFile(llvm::StringRef name, llvm::StringRef content,
|
||||
// Files" by looking up the environment to infer the directory path.
|
||||
void SetCrashReproducer(mlir::PassManager& pm, llvm::StringRef dir_path = "");
|
||||
|
||||
// This applies both the PassManagerCLOptions provided by MLIR along with any
|
||||
// tensorflow specific options.
|
||||
//
|
||||
// Note that this function should be in a more appropriate file, but it is
|
||||
// unclear what a proper file would be as no other functions would currently be
|
||||
// in the file also.
|
||||
void applyTensorflowAndCLOptions(mlir::PassManager& pm,
|
||||
llvm::StringRef dir_path = "");
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DUMP_MLIR_UTIL_H_
|
||||
|
@ -44,6 +44,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_gpu",
|
||||
"//tensorflow/compiler/mlir/hlo:transform_unranked_hlo", # buildcleaner: keep
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
|
||||
"//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
|
||||
"//tensorflow/compiler/mlir/tools/kernel_gen/transforms:passes",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
|
||||
|
@ -69,9 +69,7 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only,
|
||||
llvm::ArrayRef<uint32_t> tile_sizes,
|
||||
llvm::ArrayRef<uint32_t> unroll_factors) {
|
||||
mlir::PassManager pm(module.getContext());
|
||||
applyPassManagerCLOptions(pm);
|
||||
// TODO(b/169357508): renable when pipeline is serializable.
|
||||
// SetCrashReproducer(pm);
|
||||
applyTensorflowAndCLOptions(pm);
|
||||
|
||||
pm.addPass(mlir::mhlo::createLegalizeTFPass(false));
|
||||
if (gpu_binary_only) {
|
||||
@ -178,9 +176,7 @@ Status LowerGPUToLLVM(mlir::ModuleOp module, bool gpu_binary_only,
|
||||
llvm::StringRef gpu_binary_attr_name,
|
||||
int32_t architecture) {
|
||||
mlir::PassManager pm(module.getContext());
|
||||
applyPassManagerCLOptions(pm);
|
||||
// TODO(b/169357508): renable when pipeline is serializable.
|
||||
// SetCrashReproducer(pm);
|
||||
applyTensorflowAndCLOptions(pm);
|
||||
|
||||
auto& kernel_pm = pm.nest<mlir::gpu::GPUModuleOp>();
|
||||
if (gpu_binary_only) {
|
||||
|
@ -210,6 +210,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/hlo:lhlo_fuse_linalg",
|
||||
"//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_affine",
|
||||
"//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_gpu",
|
||||
"//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
|
@ -40,6 +40,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
|
||||
#include "tensorflow/compiler/xla/service/mlir_gpu/passes.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
|
||||
@ -48,7 +49,7 @@ namespace mlir_gpu {
|
||||
|
||||
Status LowerLHLOToGPU(mlir::ModuleOp module, LowerLHLOToGPUOptions options) {
|
||||
mlir::PassManager pm(module.getContext());
|
||||
applyPassManagerCLOptions(pm);
|
||||
tensorflow::applyTensorflowAndCLOptions(pm);
|
||||
|
||||
// We have to anticipate later unrolling in tiling to make sure that we get
|
||||
// the requested tiling after unrolling. Compute the new tiling here if
|
||||
@ -181,7 +182,7 @@ class LowerToNVVMPass
|
||||
Status LowerKernelBodiesToNVVM(mlir::ModuleOp module) {
|
||||
// We cannot verify as the signature of the kernel is rewritten.
|
||||
::mlir::PassManager pm(module.getContext(), /*verifyPasses=*/false);
|
||||
applyPassManagerCLOptions(pm);
|
||||
tensorflow::applyTensorflowAndCLOptions(pm);
|
||||
|
||||
// Rewrite kernel functions to LLVM IR.
|
||||
auto& kernelPm = pm.nest<::mlir::gpu::GPUModuleOp>();
|
||||
@ -250,7 +251,7 @@ class LowerToROCDLPass
|
||||
Status LowerKernelBodiesToROCDL(mlir::ModuleOp module) {
|
||||
// We cannot verify as the signature of the kernel is rewritten.
|
||||
::mlir::PassManager pm(module.getContext(), /*verifyPasses=*/false);
|
||||
applyPassManagerCLOptions(pm);
|
||||
tensorflow::applyTensorflowAndCLOptions(pm);
|
||||
|
||||
auto enable_if_vlog_is_on = [](mlir::Pass*, mlir::Operation*) {
|
||||
return VLOG_IS_ON(1);
|
||||
|
Loading…
x
Reference in New Issue
Block a user