Only import the entry computation when importing an HLO Module into MLIR

The reachable subcomputations will be imported on demand as the module is traversed.
This avoids importing some computations that are only used in reduction or similar
MLIR region-based operations.

This does not change any functionality: the unused computation would be deleted by
SymbolDCE.

PiperOrigin-RevId: 354460669
Change-Id: I22ec7be64cf09d4a1c7e5443a343ced7014ce403
This commit is contained in:
Mehdi Amini 2021-01-28 21:07:51 -08:00 committed by TensorFlower Gardener
parent de3d52e832
commit f329ec5aec
7 changed files with 73 additions and 26 deletions

View File

@ -30,21 +30,25 @@ limitations under the License.
namespace xla {
HloModuleImporter::HloModuleImporter(mlir::ModuleOp module)
: module_(module), builder_(module.getContext()) {
HloModuleImporter::HloModuleImporter(mlir::ModuleOp module,
bool import_all_computation)
: import_all_computation_(import_all_computation),
module_(module),
builder_(module.getContext()) {
module.getContext()->loadDialect<mlir::StandardOpsDialect>();
module.getContext()->loadDialect<mlir::mhlo::MhloDialect>();
}
Status HloModuleImporter::Import(const xla::HloModule& module) {
// TODO(hinsu): Only import the entry computation here once all HLO ops with
// reference to other computation are updated to have a region instead of a
// function attribute. Currently the importer test doesn't refer to all the
// computations from the entry computation so tests may need some update.
for (const auto* computation : module.computations()) {
if (!import_all_computation_)
// Only import the entry computation, any reachable one will be imported
// unless turned into a region operation.
return HloFunctionImporter::ImportAsFunc(
*module.entry_computation(), module_, &function_map_, &builder_);
for (const auto* computation : module.computations())
TF_RETURN_IF_ERROR(HloFunctionImporter::ImportAsFunc(
*computation, module_, &function_map_, &builder_));
}
return Status::OK();
}

View File

@ -37,7 +37,8 @@ class Shape;
// dialect. HloModuleImporter does not take ownership.
class HloModuleImporter {
public:
explicit HloModuleImporter(mlir::ModuleOp module);
explicit HloModuleImporter(mlir::ModuleOp module,
bool import_all_computation = false);
// Import the HloModule into the MLIR Module.
Status Import(const xla::HloModule& module);
@ -46,6 +47,7 @@ class HloModuleImporter {
Status Import(const xla::HloModuleProto& module);
private:
bool import_all_computation_;
mlir::ModuleOp module_;
mlir::Builder builder_;

View File

@ -22,14 +22,17 @@ limitations under the License.
namespace xla {
Status ConvertHloToMlirHlo(mlir::ModuleOp module,
xla::HloModuleProto* hlo_module_proto) {
xla::HloModuleProto* hlo_module_proto,
bool import_all_computation) {
mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
return HloModuleImporter(module).Import(*hlo_module_proto);
return HloModuleImporter(module, import_all_computation)
.Import(*hlo_module_proto);
}
Status ConvertHloToMlirHlo(mlir::ModuleOp module, xla::HloModule* hlo_module) {
Status ConvertHloToMlirHlo(mlir::ModuleOp module, xla::HloModule* hlo_module,
bool import_all_computation) {
mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
return HloModuleImporter(module).Import(*hlo_module);
return HloModuleImporter(module, import_all_computation).Import(*hlo_module);
}
} // namespace xla

View File

@ -28,11 +28,17 @@ class HloModule;
class HloModuleProto;
// Converts an HLO module proto to a MLIR module in HLO dialect.
// If import_all_computation is set to true, imports all computations
// irrespective if transitively called from entry computation.
Status ConvertHloToMlirHlo(mlir::ModuleOp module,
xla::HloModuleProto* hlo_module);
xla::HloModuleProto* hlo_module,
bool import_all_computations = false);
// Converts an HLO module to a MLIR module in HLO dialect.
Status ConvertHloToMlirHlo(mlir::ModuleOp module, xla::HloModule* hlo_module);
// If import_all_computation is set to true, imports all computations
// irrespective if transitively called from entry computation.
Status ConvertHloToMlirHlo(mlir::ModuleOp module, xla::HloModule* hlo_module,
bool import_all_computations = false);
} // namespace xla

View File

@ -1,4 +1,7 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FILECHECK_OPTS="" FileCheck %s
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo -hlo-import-all-computations %s -o - | FileCheck %s
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s -check-prefix=NO_DEAD_FUNCTION
// NO_DEAD_FUNCTION-NOT: @test
HloModule main

View File

@ -58,7 +58,8 @@ bool LoadHloProto(const std::string& contents, HloProto* hlo_proto) {
} // namespace
mlir::OwningModuleRef HloToMlirHloTranslateFunction(
llvm::StringRef input, mlir::MLIRContext* context) {
llvm::StringRef input, mlir::MLIRContext* context,
bool import_all_computations) {
HloProto hlo_proto;
string content(input.data(), input.size());
if (!LoadHloProto(content, &hlo_proto)) {
@ -68,8 +69,8 @@ mlir::OwningModuleRef HloToMlirHloTranslateFunction(
mlir::OwningModuleRef module =
mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
auto status =
ConvertHloToMlirHlo(module.get(), hlo_proto.mutable_hlo_module());
auto status = ConvertHloToMlirHlo(
module.get(), hlo_proto.mutable_hlo_module(), import_all_computations);
if (!status.ok()) {
LOG(ERROR) << "Hlo module import failed: " << status;
return nullptr;
@ -79,7 +80,8 @@ mlir::OwningModuleRef HloToMlirHloTranslateFunction(
}
mlir::OwningModuleRef HloTextToMlirHloTranslateFunction(
llvm::StringRef input, mlir::MLIRContext* context) {
llvm::StringRef input, mlir::MLIRContext* context,
bool import_all_computations) {
HloProto hlo_proto;
string content(input.data(), input.size());
@ -92,7 +94,8 @@ mlir::OwningModuleRef HloTextToMlirHloTranslateFunction(
auto hlo_module = std::move(hlo_module_error.ValueOrDie());
mlir::OwningModuleRef module =
mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
auto status = ConvertHloToMlirHlo(*module, hlo_module.get());
auto status =
ConvertHloToMlirHlo(*module, hlo_module.get(), import_all_computations);
if (!status.ok()) {
LOG(ERROR) << "HLO Module import failed: " << status;
return nullptr;
@ -233,6 +236,26 @@ static mlir::LogicalResult MlirHloToHloTextViaBuilderTranslateFunction(
} // namespace xla
//----------------------------------------------------------------------------//
// Hooks for tf-mlir-translate
//----------------------------------------------------------------------------/
static llvm::cl::opt<bool> import_all_computations(
"hlo-import-all-computations",
llvm::cl::desc("Enable importing unreachable computations."));
static mlir::OwningModuleRef HloToMlirHloTranslate(llvm::StringRef input,
mlir::MLIRContext* context) {
return xla::HloToMlirHloTranslateFunction(input, context,
import_all_computations);
}
static mlir::OwningModuleRef HloTextToMlirHloTranslate(
llvm::StringRef input, mlir::MLIRContext* context) {
return xla::HloTextToMlirHloTranslateFunction(input, context,
import_all_computations);
}
static void RegisterInputDialects(mlir::DialectRegistry& registry) {
registry.insert<mlir::StandardOpsDialect, mlir::mhlo::MhloDialect,
mlir::tensor::TensorDialect>();
@ -255,10 +278,10 @@ static mlir::TranslateFromMLIRRegistration MlirHloToHloTextViaBuilderTranslate(
xla::MlirHloToHloTextViaBuilderTranslateFunction, RegisterInputDialects);
static mlir::TranslateToMLIRRegistration HloToHloMlirTranslate(
"hlo-to-mlir-hlo", xla::HloToMlirHloTranslateFunction);
"hlo-to-mlir-hlo", HloToMlirHloTranslate);
static mlir::TranslateToMLIRRegistration HloTextToHloMlirTranslate(
"hlo-text-to-mlir-hlo", xla::HloTextToMlirHloTranslateFunction);
"hlo-text-to-mlir-hlo", HloTextToMlirHloTranslate);
// MHLO doesn't support explicit layouts, while XLA service does.
// TODO(timshen): remove it once MHLO supports explicit layouts.

View File

@ -33,14 +33,20 @@ namespace xla {
// Converts a HloModuleProto stored in the file with the given `input_filename`
// into a MLIR module. Creates MLIR entities into the given MLIR `context`.
mlir::OwningModuleRef HloToMlirHloTranslateFunction(llvm::StringRef input,
mlir::MLIRContext* context);
// If import_all_computation is set to true, imports all computations
// irrespective if transitively called from entry computation.
mlir::OwningModuleRef HloToMlirHloTranslateFunction(
llvm::StringRef input, mlir::MLIRContext* context,
bool import_all_computations = false);
// Converts a HloModule stored in text form for a file with the given
// `input_filename` into a MLIR module. Creates MLIR entities into the given
// MLIR `context`.
// If import_all_computation is set to true, imports all computations
// irrespective if transitively called from entry computation.
mlir::OwningModuleRef HloTextToMlirHloTranslateFunction(
llvm::StringRef input, mlir::MLIRContext* context);
llvm::StringRef input, mlir::MLIRContext* context,
bool import_all_computations = false);
} // namespace xla