From 5e260f92acaf2e07f2e2109bda9f6c51a346a28e Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Thu, 11 Jun 2020 04:33:05 -0700 Subject: [PATCH] Add a factory method to import HLO computation as a region in MLIR This new factory method ImportAsRegion will be used from Mlir HloBuilder to handle ops with XlaComputations. Also, updated methods to take HloComputation as const reference instead of pointer. PiperOrigin-RevId: 315874568 Change-Id: I0f1001cc858ea82fdfbbc89f923dcd5ea057fd28 --- .../mlir/xla/hlo_function_importer.cc | 99 ++++++++++--------- .../compiler/mlir/xla/hlo_function_importer.h | 34 ++++--- .../compiler/mlir/xla/hlo_module_importer.cc | 10 +- .../compiler/mlir/xla/hlo_module_importer.h | 2 +- 4 files changed, 80 insertions(+), 65 deletions(-) diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc index 22a0b038833..a3b6222a8af 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h" +#include + #include "absl/types/optional.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" @@ -79,30 +81,35 @@ bool DotIsDefault(const HloInstruction* instruction) { } } // namespace -StatusOr HloFunctionImporter::ImportFunction( - mlir::ModuleOp module, mlir::Builder* builder, - std::unordered_map* function_map, - HloComputation* computation) { - HloFunctionImporter importer(module, builder, function_map); - return importer.ImportFunction(computation); +Status HloFunctionImporter::ImportAsFunc( + const HloComputation& computation, mlir::ModuleOp module, + std::unordered_map* function_map, + mlir::Builder* builder) { + HloFunctionImporter importer(module, function_map, builder); + return importer.ImportAsFunc(computation).status(); } -StatusOr HloFunctionImporter::ImportFunction( - HloComputation* computation) { - auto& imported = (*function_map_)[computation]; +Status HloFunctionImporter::ImportAsRegion( + const xla::HloComputation& computation, mlir::Region* region, + mlir::Builder* builder) { + HloFunctionImporter importer(region->getParentOfType(), {}, + builder); + return importer.ImportAsRegion(computation, region); +} + +StatusOr HloFunctionImporter::ImportAsFunc( + const HloComputation& computation) { + auto& imported = (*function_map_)[&computation]; if (imported) return imported; - llvm::SmallVector args, rets; - TF_RETURN_IF_ERROR( - GetMlirTypes(computation->parameter_instructions(), &args)); - TF_RETURN_IF_ERROR(GetMlirTypes({computation->root_instruction()}, &rets)); - + TF_RETURN_IF_ERROR(GetMlirTypes(computation.parameter_instructions(), &args)); + TF_RETURN_IF_ERROR(GetMlirTypes({computation.root_instruction()}, &rets)); auto func_type = mlir::FunctionType::get(args, rets, context_); string computation_name = - computation->parent()->entry_computation() == computation + computation.parent()->entry_computation() == &computation ? "main" - : SanitizeFunctionName(computation->name()); + : SanitizeFunctionName(computation.name()); // Construct the MLIR function and map arguments. llvm::ArrayRef attrs; @@ -119,31 +126,30 @@ StatusOr HloFunctionImporter::ImportFunction( return function; } -tensorflow::Status HloFunctionImporter::ImportComputation( - HloComputation* computation, mlir::Region* region) { +tensorflow::Status HloFunctionImporter::ImportAsRegion( + const HloComputation& computation, mlir::Region* region) { // TODO(hinsu): Store computation name as an attribute for round-trip. auto* block = new mlir::Block; region->push_back(block); llvm::SmallVector args; - TF_RETURN_IF_ERROR( - GetMlirTypes(computation->parameter_instructions(), &args)); + TF_RETURN_IF_ERROR(GetMlirTypes(computation.parameter_instructions(), &args)); block->addArguments(args); return ImportInstructions(computation, block); } tensorflow::Status HloFunctionImporter::ImportInstructions( - HloComputation* computation, mlir::Block* block) { + const HloComputation& computation, mlir::Block* block) { // Setup the input parameters. - const int num_parameters = computation->num_parameters(); + const int num_parameters = computation.num_parameters(); for (int i = 0; i < num_parameters; i++) { - auto hlo_parameter = computation->parameter_instruction(i); + auto hlo_parameter = computation.parameter_instruction(i); instruction_value_map_[hlo_parameter] = block->getArgument(i); } mlir::OpBuilder builder = mlir::OpBuilder::atBlockEnd(block); - for (auto instruction : computation->MakeInstructionPostOrder()) { + for (auto instruction : computation.MakeInstructionPostOrder()) { TF_ASSIGN_OR_RETURN(auto new_operation, ImportInstruction(instruction, &builder)); if (new_operation) { @@ -156,7 +162,7 @@ tensorflow::Status HloFunctionImporter::ImportInstructions( // Setup the return type (HLO only supports a single return value). TF_ASSIGN_OR_RETURN(auto result, - GetMlirValue(computation->root_instruction())); + GetMlirValue(computation.root_instruction())); // Create terminator op depending on the parent op of this region. if (llvm::isa(block->getParentOp())) { @@ -249,7 +255,7 @@ StatusOr HloFunctionImporter::ImportInstruction( } case HloOpcode::kCall: { TF_ASSIGN_OR_RETURN(FuncOp function, - ImportFunction(instruction->to_apply())); + ImportAsFunc(*instruction->to_apply())); mlir::Operation* new_operation = func_builder->create(loc, function, operands); return new_operation; @@ -365,8 +371,8 @@ StatusOr HloFunctionImporter::ImportInstruction( auto scatter_op = func_builder->create( loc, result_type, operands, attributes); - TF_RETURN_IF_ERROR(ImportComputation(scatter->to_apply(), - &scatter_op.update_computation())); + TF_RETURN_IF_ERROR(ImportAsRegion(*scatter->to_apply(), + &scatter_op.update_computation())); return scatter_op.getOperation(); } case HloOpcode::kSelectAndScatter: { @@ -387,10 +393,10 @@ StatusOr HloFunctionImporter::ImportInstruction( auto select_scatter_op = func_builder->create( loc, result_type, operands, attributes); - TF_RETURN_IF_ERROR(ImportComputation(select_scatter->select(), - &select_scatter_op.select())); - TF_RETURN_IF_ERROR(ImportComputation(select_scatter->scatter(), - &select_scatter_op.scatter())); + TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->select(), + &select_scatter_op.select())); + TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->scatter(), + &select_scatter_op.scatter())); return select_scatter_op.getOperation(); } case HloOpcode::kSetDimensionSize: { @@ -414,8 +420,8 @@ StatusOr HloFunctionImporter::ImportInstruction( loc, result_type, operands, builder_->getI64IntegerAttr(sort_instruction->sort_dimension()), builder_->getBoolAttr(sort_instruction->is_stable())); - TF_RETURN_IF_ERROR(ImportComputation(sort_instruction->to_apply(), - &sort_op.comparator())); + TF_RETURN_IF_ERROR( + ImportAsRegion(*sort_instruction->to_apply(), &sort_op.comparator())); return sort_op.getOperation(); } case HloOpcode::kConditional: { @@ -430,10 +436,10 @@ StatusOr HloFunctionImporter::ImportInstruction( auto op = func_builder->create(loc, rets, operands, attributes); - TF_RETURN_IF_ERROR(ImportComputation(instruction->true_computation(), - &op.true_branch())); - TF_RETURN_IF_ERROR(ImportComputation(instruction->false_computation(), - &op.false_branch())); + TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->true_computation(), + &op.true_branch())); + TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->false_computation(), + &op.false_branch())); return op.getOperation(); } @@ -448,8 +454,7 @@ StatusOr HloFunctionImporter::ImportInstruction( llvm::enumerate(instruction->branch_computations())) { auto index = index_and_computation.index(); HloComputation* computation = index_and_computation.value(); - TF_RETURN_IF_ERROR( - ImportComputation(computation, &op.branches()[index])); + TF_RETURN_IF_ERROR(ImportAsRegion(*computation, &op.branches()[index])); } return op.getOperation(); } @@ -468,8 +473,8 @@ StatusOr HloFunctionImporter::ImportInstruction( attributes.push_back(ConvertChannelHandle(all_reduce->channel_id())); auto all_reduce_op = func_builder->create( loc, result_type, operands, attributes); - TF_RETURN_IF_ERROR(ImportComputation(all_reduce->to_apply(), - &all_reduce_op.computation())); + TF_RETURN_IF_ERROR(ImportAsRegion(*all_reduce->to_apply(), + &all_reduce_op.computation())); return all_reduce_op.getOperation(); } case HloOpcode::kReduce: { @@ -481,7 +486,7 @@ StatusOr HloFunctionImporter::ImportInstruction( llvm::makeArrayRef(operands).drop_front(num_inputs), ConvertDimensions(instruction->dimensions())); TF_RETURN_IF_ERROR( - ImportComputation(instruction->to_apply(), &reduce.body())); + ImportAsRegion(*instruction->to_apply(), &reduce.body())); return reduce.getOperation(); } case HloOpcode::kReverse: { @@ -517,9 +522,9 @@ StatusOr HloFunctionImporter::ImportInstruction( auto op = func_builder->create( loc, operands[0].getType(), operands[0]); TF_RETURN_IF_ERROR( - ImportComputation(instruction->while_condition(), &op.cond())); + ImportAsRegion(*instruction->while_condition(), &op.cond())); TF_RETURN_IF_ERROR( - ImportComputation(instruction->while_body(), &op.body())); + ImportAsRegion(*instruction->while_body(), &op.body())); return op.getOperation(); } case HloOpcode::kGetTupleElement: { @@ -580,7 +585,7 @@ StatusOr HloFunctionImporter::ImportInstruction( auto reduce = func_builder->create( loc, result_type, operands, attributes); TF_RETURN_IF_ERROR( - ImportComputation(instruction->to_apply(), &reduce.body())); + ImportAsRegion(*instruction->to_apply(), &reduce.body())); return reduce.getOperation(); } case HloOpcode::kMap: { @@ -588,7 +593,7 @@ StatusOr HloFunctionImporter::ImportInstruction( loc, result_type, operands, ConvertDimensions(instruction->dimensions())); TF_RETURN_IF_ERROR( - ImportComputation(instruction->to_apply(), &op.computation())); + ImportAsRegion(*instruction->to_apply(), &op.computation())); return op.getOperation(); } case HloOpcode::kConvolution: { diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.h b/tensorflow/compiler/mlir/xla/hlo_function_importer.h index 14b6d309e94..f0b978d5306 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.h +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.h @@ -42,29 +42,39 @@ class Shape; // Helper class for importing HloComputations. class HloFunctionImporter { public: - static StatusOr ImportFunction( - mlir::ModuleOp module, mlir::Builder* builder, - std::unordered_map* function_map, - xla::HloComputation* computation); + // Imports the given computation as a function in the given module. This also + // imports any computations referred by instructions in this computation. + static Status ImportAsFunc(const xla::HloComputation& computation, + mlir::ModuleOp module, + std::unordered_map* function_map, + mlir::Builder* builder); + + // Imports the given hlo computation to the specified region. + static Status ImportAsRegion(const xla::HloComputation& computation, + mlir::Region* region, mlir::Builder* builder); private: - HloFunctionImporter( - mlir::ModuleOp module, mlir::Builder* builder, - std::unordered_map* function_map) + HloFunctionImporter(mlir::ModuleOp module, + std::unordered_map* function_map, + mlir::Builder* builder) : context_(module.getContext()), module_(module), builder_(builder), function_map_(function_map) {} - StatusOr ImportFunction(xla::HloComputation* computation); + // Imports the given computation as a new function, if it hasn't been already + // imported. + StatusOr ImportAsFunc(const xla::HloComputation& computation); // Imports the given computation in the specified region. - tensorflow::Status ImportComputation(HloComputation* computation, - mlir::Region* region); + tensorflow::Status ImportAsRegion(const HloComputation& computation, + mlir::Region* region); // Imports instructions from the given computation in the specified block. // Assumes that the block already has correct arguments populated. - tensorflow::Status ImportInstructions(HloComputation* computation, + tensorflow::Status ImportInstructions(const HloComputation& computation, mlir::Block* block); // Imports an instruction. @@ -125,7 +135,7 @@ class HloFunctionImporter { mlir::Builder* builder_; // Mapping from HloComputation to the created MLIR function. - std::unordered_map* function_map_; + std::unordered_map* function_map_; // Mapping from HloInstructions to the associative MLIR values. std::unordered_map instruction_value_map_; diff --git a/tensorflow/compiler/mlir/xla/hlo_module_importer.cc b/tensorflow/compiler/mlir/xla/hlo_module_importer.cc index 906dcba0083..888abae1efd 100644 --- a/tensorflow/compiler/mlir/xla/hlo_module_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_module_importer.cc @@ -33,11 +33,11 @@ namespace xla { Status HloModuleImporter::Import(const xla::HloModule& module) { // TODO(hinsu): Only import the entry computation here once all HLO ops with // reference to other computation are updated to have a region instead of a - // function attribute. - for (const auto& computation : module.computations()) { - auto result = HloFunctionImporter::ImportFunction( - module_, &builder_, &function_map_, computation); - TF_RETURN_IF_ERROR(result.status()); + // function attribute. Currently the importer test doesn't refer to all the + // computations from the entry computation so tests may need some update. + for (const auto* computation : module.computations()) { + TF_RETURN_IF_ERROR(HloFunctionImporter::ImportAsFunc( + *computation, module_, &function_map_, &builder_)); } return Status::OK(); diff --git a/tensorflow/compiler/mlir/xla/hlo_module_importer.h b/tensorflow/compiler/mlir/xla/hlo_module_importer.h index 2fd7102c5a6..b0a8bf4c0c7 100644 --- a/tensorflow/compiler/mlir/xla/hlo_module_importer.h +++ b/tensorflow/compiler/mlir/xla/hlo_module_importer.h @@ -54,7 +54,7 @@ class HloModuleImporter { // Map for tracking which MLIR function map to which HLO Computation. This // tracks functions as they are imported and provides a quick lookup for // functions invoked by control flow related operations (e.g. while, call). - std::unordered_map function_map_; + std::unordered_map function_map_; }; } // namespace xla