From f329ec5aec49b77bf67a3c7331e4a6417dd27284 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Thu, 28 Jan 2021 21:07:51 -0800 Subject: [PATCH] Only import the entry computation when importing an HLO Module into MLIR The reachable subcomputations will be imported on demand as the module is traversed. This avoids importing some computations that are only used in reduction or similar MLIR region-based operations. This does not change any functionality: the unused computation would be deleted by SymbolDCE. PiperOrigin-RevId: 354460669 Change-Id: I22ec7be64cf09d4a1c7e5443a343ced7014ce403 --- .../compiler/mlir/xla/hlo_module_importer.cc | 20 ++++++---- .../compiler/mlir/xla/hlo_module_importer.h | 4 +- .../compiler/mlir/xla/hlo_to_mlir_hlo.cc | 11 ++++-- .../compiler/mlir/xla/hlo_to_mlir_hlo.h | 10 ++++- .../mlir/xla/tests/translate/import.hlotxt | 5 ++- .../compiler/mlir/xla/xla_mlir_translate.cc | 37 +++++++++++++++---- .../compiler/mlir/xla/xla_mlir_translate.h | 12 ++++-- 7 files changed, 73 insertions(+), 26 deletions(-) diff --git a/tensorflow/compiler/mlir/xla/hlo_module_importer.cc b/tensorflow/compiler/mlir/xla/hlo_module_importer.cc index b554f38b148..1ad16db7db9 100644 --- a/tensorflow/compiler/mlir/xla/hlo_module_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_module_importer.cc @@ -30,21 +30,25 @@ limitations under the License. namespace xla { -HloModuleImporter::HloModuleImporter(mlir::ModuleOp module) - : module_(module), builder_(module.getContext()) { +HloModuleImporter::HloModuleImporter(mlir::ModuleOp module, + bool import_all_computation) + : import_all_computation_(import_all_computation), + module_(module), + builder_(module.getContext()) { module.getContext()->loadDialect(); module.getContext()->loadDialect(); } Status HloModuleImporter::Import(const xla::HloModule& module) { - // TODO(hinsu): Only import the entry computation here once all HLO ops with - // reference to other computation are updated to have a region instead of a - // function attribute. Currently the importer test doesn't refer to all the - // computations from the entry computation so tests may need some update. - for (const auto* computation : module.computations()) { + if (!import_all_computation_) + // Only import the entry computation, any reachable one will be imported + // unless turned into a region operation. + return HloFunctionImporter::ImportAsFunc( + *module.entry_computation(), module_, &function_map_, &builder_); + + for (const auto* computation : module.computations()) TF_RETURN_IF_ERROR(HloFunctionImporter::ImportAsFunc( *computation, module_, &function_map_, &builder_)); - } return Status::OK(); } diff --git a/tensorflow/compiler/mlir/xla/hlo_module_importer.h b/tensorflow/compiler/mlir/xla/hlo_module_importer.h index 64cccbe46f6..b4a7113ea15 100644 --- a/tensorflow/compiler/mlir/xla/hlo_module_importer.h +++ b/tensorflow/compiler/mlir/xla/hlo_module_importer.h @@ -37,7 +37,8 @@ class Shape; // dialect. HloModuleImporter does not take ownership. class HloModuleImporter { public: - explicit HloModuleImporter(mlir::ModuleOp module); + explicit HloModuleImporter(mlir::ModuleOp module, + bool import_all_computation = false); // Import the HloModule into the MLIR Module. Status Import(const xla::HloModule& module); @@ -46,6 +47,7 @@ class HloModuleImporter { Status Import(const xla::HloModuleProto& module); private: + bool import_all_computation_; mlir::ModuleOp module_; mlir::Builder builder_; diff --git a/tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.cc b/tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.cc index d9ffa166289..34286878351 100644 --- a/tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.cc +++ b/tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.cc @@ -22,14 +22,17 @@ limitations under the License. namespace xla { Status ConvertHloToMlirHlo(mlir::ModuleOp module, - xla::HloModuleProto* hlo_module_proto) { + xla::HloModuleProto* hlo_module_proto, + bool import_all_computation) { mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext()); - return HloModuleImporter(module).Import(*hlo_module_proto); + return HloModuleImporter(module, import_all_computation) + .Import(*hlo_module_proto); } -Status ConvertHloToMlirHlo(mlir::ModuleOp module, xla::HloModule* hlo_module) { +Status ConvertHloToMlirHlo(mlir::ModuleOp module, xla::HloModule* hlo_module, + bool import_all_computation) { mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext()); - return HloModuleImporter(module).Import(*hlo_module); + return HloModuleImporter(module, import_all_computation).Import(*hlo_module); } } // namespace xla diff --git a/tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h b/tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h index e613ce72b23..e4021b5c6e2 100644 --- a/tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h +++ b/tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h @@ -28,11 +28,17 @@ class HloModule; class HloModuleProto; // Converts an HLO module proto to a MLIR module in HLO dialect. +// If import_all_computation is set to true, imports all computations +// irrespective if transitively called from entry computation. Status ConvertHloToMlirHlo(mlir::ModuleOp module, - xla::HloModuleProto* hlo_module); + xla::HloModuleProto* hlo_module, + bool import_all_computations = false); // Converts an HLO module to a MLIR module in HLO dialect. -Status ConvertHloToMlirHlo(mlir::ModuleOp module, xla::HloModule* hlo_module); +// If import_all_computation is set to true, imports all computations +// irrespective if transitively called from entry computation. +Status ConvertHloToMlirHlo(mlir::ModuleOp module, xla::HloModule* hlo_module, + bool import_all_computations = false); } // namespace xla diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt index a9744f0884e..377a1f5db11 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt @@ -1,4 +1,7 @@ -// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FILECHECK_OPTS="" FileCheck %s +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo -hlo-import-all-computations %s -o - | FileCheck %s +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s -check-prefix=NO_DEAD_FUNCTION + +// NO_DEAD_FUNCTION-NOT: @test HloModule main diff --git a/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc index fb4164b4ced..05c1cec1304 100644 --- a/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc +++ b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc @@ -58,7 +58,8 @@ bool LoadHloProto(const std::string& contents, HloProto* hlo_proto) { } // namespace mlir::OwningModuleRef HloToMlirHloTranslateFunction( - llvm::StringRef input, mlir::MLIRContext* context) { + llvm::StringRef input, mlir::MLIRContext* context, + bool import_all_computations) { HloProto hlo_proto; string content(input.data(), input.size()); if (!LoadHloProto(content, &hlo_proto)) { @@ -68,8 +69,8 @@ mlir::OwningModuleRef HloToMlirHloTranslateFunction( mlir::OwningModuleRef module = mlir::ModuleOp::create(mlir::UnknownLoc::get(context)); - auto status = - ConvertHloToMlirHlo(module.get(), hlo_proto.mutable_hlo_module()); + auto status = ConvertHloToMlirHlo( + module.get(), hlo_proto.mutable_hlo_module(), import_all_computations); if (!status.ok()) { LOG(ERROR) << "Hlo module import failed: " << status; return nullptr; @@ -79,7 +80,8 @@ mlir::OwningModuleRef HloToMlirHloTranslateFunction( } mlir::OwningModuleRef HloTextToMlirHloTranslateFunction( - llvm::StringRef input, mlir::MLIRContext* context) { + llvm::StringRef input, mlir::MLIRContext* context, + bool import_all_computations) { HloProto hlo_proto; string content(input.data(), input.size()); @@ -92,7 +94,8 @@ mlir::OwningModuleRef HloTextToMlirHloTranslateFunction( auto hlo_module = std::move(hlo_module_error.ValueOrDie()); mlir::OwningModuleRef module = mlir::ModuleOp::create(mlir::UnknownLoc::get(context)); - auto status = ConvertHloToMlirHlo(*module, hlo_module.get()); + auto status = + ConvertHloToMlirHlo(*module, hlo_module.get(), import_all_computations); if (!status.ok()) { LOG(ERROR) << "HLO Module import failed: " << status; return nullptr; @@ -233,6 +236,26 @@ static mlir::LogicalResult MlirHloToHloTextViaBuilderTranslateFunction( } // namespace xla +//----------------------------------------------------------------------------// +// Hooks for tf-mlir-translate +//----------------------------------------------------------------------------/ + +static llvm::cl::opt import_all_computations( + "hlo-import-all-computations", + llvm::cl::desc("Enable importing unreachable computations.")); + +static mlir::OwningModuleRef HloToMlirHloTranslate(llvm::StringRef input, + mlir::MLIRContext* context) { + return xla::HloToMlirHloTranslateFunction(input, context, + import_all_computations); +} + +static mlir::OwningModuleRef HloTextToMlirHloTranslate( + llvm::StringRef input, mlir::MLIRContext* context) { + return xla::HloTextToMlirHloTranslateFunction(input, context, + import_all_computations); +} + static void RegisterInputDialects(mlir::DialectRegistry& registry) { registry.insert(); @@ -255,10 +278,10 @@ static mlir::TranslateFromMLIRRegistration MlirHloToHloTextViaBuilderTranslate( xla::MlirHloToHloTextViaBuilderTranslateFunction, RegisterInputDialects); static mlir::TranslateToMLIRRegistration HloToHloMlirTranslate( - "hlo-to-mlir-hlo", xla::HloToMlirHloTranslateFunction); + "hlo-to-mlir-hlo", HloToMlirHloTranslate); static mlir::TranslateToMLIRRegistration HloTextToHloMlirTranslate( - "hlo-text-to-mlir-hlo", xla::HloTextToMlirHloTranslateFunction); + "hlo-text-to-mlir-hlo", HloTextToMlirHloTranslate); // MHLO doesn't support explicit layouts, while XLA service does. // TODO(timshen): remove it once MHLO supports explicit layouts. diff --git a/tensorflow/compiler/mlir/xla/xla_mlir_translate.h b/tensorflow/compiler/mlir/xla/xla_mlir_translate.h index e086cacfb4b..253a93228f1 100644 --- a/tensorflow/compiler/mlir/xla/xla_mlir_translate.h +++ b/tensorflow/compiler/mlir/xla/xla_mlir_translate.h @@ -33,14 +33,20 @@ namespace xla { // Converts a HloModuleProto stored in the file with the given `input_filename` // into a MLIR module. Creates MLIR entities into the given MLIR `context`. -mlir::OwningModuleRef HloToMlirHloTranslateFunction(llvm::StringRef input, - mlir::MLIRContext* context); +// If import_all_computation is set to true, imports all computations +// irrespective if transitively called from entry computation. +mlir::OwningModuleRef HloToMlirHloTranslateFunction( + llvm::StringRef input, mlir::MLIRContext* context, + bool import_all_computations = false); // Converts a HloModule stored in text form for a file with the given // `input_filename` into a MLIR module. Creates MLIR entities into the given // MLIR `context`. +// If import_all_computation is set to true, imports all computations +// irrespective if transitively called from entry computation. mlir::OwningModuleRef HloTextToMlirHloTranslateFunction( - llvm::StringRef input, mlir::MLIRContext* context); + llvm::StringRef input, mlir::MLIRContext* context, + bool import_all_computations = false); } // namespace xla