Add a factory method to import HLO computation as a region in MLIR
This new factory method ImportAsRegion will be used from Mlir HloBuilder to handle ops with XlaComputations. Also, updated methods to take HloComputation as const reference instead of pointer. PiperOrigin-RevId: 315874568 Change-Id: I0f1001cc858ea82fdfbbc89f923dcd5ea057fd28
This commit is contained in:
parent
bc38810e99
commit
5e260f92ac
tensorflow/compiler/mlir/xla
@ -15,6 +15,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
@ -79,30 +81,35 @@ bool DotIsDefault(const HloInstruction* instruction) {
|
||||
}
|
||||
} // namespace
|
||||
|
||||
StatusOr<mlir::FuncOp> HloFunctionImporter::ImportFunction(
|
||||
mlir::ModuleOp module, mlir::Builder* builder,
|
||||
std::unordered_map<HloComputation*, FuncOp>* function_map,
|
||||
HloComputation* computation) {
|
||||
HloFunctionImporter importer(module, builder, function_map);
|
||||
return importer.ImportFunction(computation);
|
||||
Status HloFunctionImporter::ImportAsFunc(
|
||||
const HloComputation& computation, mlir::ModuleOp module,
|
||||
std::unordered_map<const HloComputation*, FuncOp>* function_map,
|
||||
mlir::Builder* builder) {
|
||||
HloFunctionImporter importer(module, function_map, builder);
|
||||
return importer.ImportAsFunc(computation).status();
|
||||
}
|
||||
|
||||
StatusOr<mlir::FuncOp> HloFunctionImporter::ImportFunction(
|
||||
HloComputation* computation) {
|
||||
auto& imported = (*function_map_)[computation];
|
||||
Status HloFunctionImporter::ImportAsRegion(
|
||||
const xla::HloComputation& computation, mlir::Region* region,
|
||||
mlir::Builder* builder) {
|
||||
HloFunctionImporter importer(region->getParentOfType<mlir::ModuleOp>(), {},
|
||||
builder);
|
||||
return importer.ImportAsRegion(computation, region);
|
||||
}
|
||||
|
||||
StatusOr<mlir::FuncOp> HloFunctionImporter::ImportAsFunc(
|
||||
const HloComputation& computation) {
|
||||
auto& imported = (*function_map_)[&computation];
|
||||
if (imported) return imported;
|
||||
|
||||
llvm::SmallVector<Type, 4> args, rets;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetMlirTypes(computation->parameter_instructions(), &args));
|
||||
TF_RETURN_IF_ERROR(GetMlirTypes({computation->root_instruction()}, &rets));
|
||||
|
||||
TF_RETURN_IF_ERROR(GetMlirTypes(computation.parameter_instructions(), &args));
|
||||
TF_RETURN_IF_ERROR(GetMlirTypes({computation.root_instruction()}, &rets));
|
||||
auto func_type = mlir::FunctionType::get(args, rets, context_);
|
||||
|
||||
string computation_name =
|
||||
computation->parent()->entry_computation() == computation
|
||||
computation.parent()->entry_computation() == &computation
|
||||
? "main"
|
||||
: SanitizeFunctionName(computation->name());
|
||||
: SanitizeFunctionName(computation.name());
|
||||
|
||||
// Construct the MLIR function and map arguments.
|
||||
llvm::ArrayRef<mlir::NamedAttribute> attrs;
|
||||
@ -119,31 +126,30 @@ StatusOr<mlir::FuncOp> HloFunctionImporter::ImportFunction(
|
||||
return function;
|
||||
}
|
||||
|
||||
tensorflow::Status HloFunctionImporter::ImportComputation(
|
||||
HloComputation* computation, mlir::Region* region) {
|
||||
tensorflow::Status HloFunctionImporter::ImportAsRegion(
|
||||
const HloComputation& computation, mlir::Region* region) {
|
||||
// TODO(hinsu): Store computation name as an attribute for round-trip.
|
||||
auto* block = new mlir::Block;
|
||||
region->push_back(block);
|
||||
|
||||
llvm::SmallVector<Type, 4> args;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetMlirTypes(computation->parameter_instructions(), &args));
|
||||
TF_RETURN_IF_ERROR(GetMlirTypes(computation.parameter_instructions(), &args));
|
||||
block->addArguments(args);
|
||||
|
||||
return ImportInstructions(computation, block);
|
||||
}
|
||||
|
||||
tensorflow::Status HloFunctionImporter::ImportInstructions(
|
||||
HloComputation* computation, mlir::Block* block) {
|
||||
const HloComputation& computation, mlir::Block* block) {
|
||||
// Setup the input parameters.
|
||||
const int num_parameters = computation->num_parameters();
|
||||
const int num_parameters = computation.num_parameters();
|
||||
for (int i = 0; i < num_parameters; i++) {
|
||||
auto hlo_parameter = computation->parameter_instruction(i);
|
||||
auto hlo_parameter = computation.parameter_instruction(i);
|
||||
instruction_value_map_[hlo_parameter] = block->getArgument(i);
|
||||
}
|
||||
|
||||
mlir::OpBuilder builder = mlir::OpBuilder::atBlockEnd(block);
|
||||
for (auto instruction : computation->MakeInstructionPostOrder()) {
|
||||
for (auto instruction : computation.MakeInstructionPostOrder()) {
|
||||
TF_ASSIGN_OR_RETURN(auto new_operation,
|
||||
ImportInstruction(instruction, &builder));
|
||||
if (new_operation) {
|
||||
@ -156,7 +162,7 @@ tensorflow::Status HloFunctionImporter::ImportInstructions(
|
||||
|
||||
// Setup the return type (HLO only supports a single return value).
|
||||
TF_ASSIGN_OR_RETURN(auto result,
|
||||
GetMlirValue(computation->root_instruction()));
|
||||
GetMlirValue(computation.root_instruction()));
|
||||
|
||||
// Create terminator op depending on the parent op of this region.
|
||||
if (llvm::isa<FuncOp>(block->getParentOp())) {
|
||||
@ -249,7 +255,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
||||
}
|
||||
case HloOpcode::kCall: {
|
||||
TF_ASSIGN_OR_RETURN(FuncOp function,
|
||||
ImportFunction(instruction->to_apply()));
|
||||
ImportAsFunc(*instruction->to_apply()));
|
||||
mlir::Operation* new_operation =
|
||||
func_builder->create<mlir::CallOp>(loc, function, operands);
|
||||
return new_operation;
|
||||
@ -365,8 +371,8 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
||||
|
||||
auto scatter_op = func_builder->create<mlir::xla_hlo::ScatterOp>(
|
||||
loc, result_type, operands, attributes);
|
||||
TF_RETURN_IF_ERROR(ImportComputation(scatter->to_apply(),
|
||||
&scatter_op.update_computation()));
|
||||
TF_RETURN_IF_ERROR(ImportAsRegion(*scatter->to_apply(),
|
||||
&scatter_op.update_computation()));
|
||||
return scatter_op.getOperation();
|
||||
}
|
||||
case HloOpcode::kSelectAndScatter: {
|
||||
@ -387,10 +393,10 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
||||
auto select_scatter_op =
|
||||
func_builder->create<mlir::xla_hlo::SelectAndScatterOp>(
|
||||
loc, result_type, operands, attributes);
|
||||
TF_RETURN_IF_ERROR(ImportComputation(select_scatter->select(),
|
||||
&select_scatter_op.select()));
|
||||
TF_RETURN_IF_ERROR(ImportComputation(select_scatter->scatter(),
|
||||
&select_scatter_op.scatter()));
|
||||
TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->select(),
|
||||
&select_scatter_op.select()));
|
||||
TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->scatter(),
|
||||
&select_scatter_op.scatter()));
|
||||
return select_scatter_op.getOperation();
|
||||
}
|
||||
case HloOpcode::kSetDimensionSize: {
|
||||
@ -414,8 +420,8 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
||||
loc, result_type, operands,
|
||||
builder_->getI64IntegerAttr(sort_instruction->sort_dimension()),
|
||||
builder_->getBoolAttr(sort_instruction->is_stable()));
|
||||
TF_RETURN_IF_ERROR(ImportComputation(sort_instruction->to_apply(),
|
||||
&sort_op.comparator()));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ImportAsRegion(*sort_instruction->to_apply(), &sort_op.comparator()));
|
||||
return sort_op.getOperation();
|
||||
}
|
||||
case HloOpcode::kConditional: {
|
||||
@ -430,10 +436,10 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
||||
|
||||
auto op = func_builder->create<mlir::xla_hlo::IfOp>(loc, rets, operands,
|
||||
attributes);
|
||||
TF_RETURN_IF_ERROR(ImportComputation(instruction->true_computation(),
|
||||
&op.true_branch()));
|
||||
TF_RETURN_IF_ERROR(ImportComputation(instruction->false_computation(),
|
||||
&op.false_branch()));
|
||||
TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->true_computation(),
|
||||
&op.true_branch()));
|
||||
TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->false_computation(),
|
||||
&op.false_branch()));
|
||||
return op.getOperation();
|
||||
}
|
||||
|
||||
@ -448,8 +454,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
||||
llvm::enumerate(instruction->branch_computations())) {
|
||||
auto index = index_and_computation.index();
|
||||
HloComputation* computation = index_and_computation.value();
|
||||
TF_RETURN_IF_ERROR(
|
||||
ImportComputation(computation, &op.branches()[index]));
|
||||
TF_RETURN_IF_ERROR(ImportAsRegion(*computation, &op.branches()[index]));
|
||||
}
|
||||
return op.getOperation();
|
||||
}
|
||||
@ -468,8 +473,8 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
||||
attributes.push_back(ConvertChannelHandle(all_reduce->channel_id()));
|
||||
auto all_reduce_op = func_builder->create<mlir::xla_hlo::AllReduceOp>(
|
||||
loc, result_type, operands, attributes);
|
||||
TF_RETURN_IF_ERROR(ImportComputation(all_reduce->to_apply(),
|
||||
&all_reduce_op.computation()));
|
||||
TF_RETURN_IF_ERROR(ImportAsRegion(*all_reduce->to_apply(),
|
||||
&all_reduce_op.computation()));
|
||||
return all_reduce_op.getOperation();
|
||||
}
|
||||
case HloOpcode::kReduce: {
|
||||
@ -481,7 +486,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
||||
llvm::makeArrayRef(operands).drop_front(num_inputs),
|
||||
ConvertDimensions(instruction->dimensions()));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ImportComputation(instruction->to_apply(), &reduce.body()));
|
||||
ImportAsRegion(*instruction->to_apply(), &reduce.body()));
|
||||
return reduce.getOperation();
|
||||
}
|
||||
case HloOpcode::kReverse: {
|
||||
@ -517,9 +522,9 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
||||
auto op = func_builder->create<mlir::xla_hlo::WhileOp>(
|
||||
loc, operands[0].getType(), operands[0]);
|
||||
TF_RETURN_IF_ERROR(
|
||||
ImportComputation(instruction->while_condition(), &op.cond()));
|
||||
ImportAsRegion(*instruction->while_condition(), &op.cond()));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ImportComputation(instruction->while_body(), &op.body()));
|
||||
ImportAsRegion(*instruction->while_body(), &op.body()));
|
||||
return op.getOperation();
|
||||
}
|
||||
case HloOpcode::kGetTupleElement: {
|
||||
@ -580,7 +585,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
||||
auto reduce = func_builder->create<mlir::xla_hlo::ReduceWindowOp>(
|
||||
loc, result_type, operands, attributes);
|
||||
TF_RETURN_IF_ERROR(
|
||||
ImportComputation(instruction->to_apply(), &reduce.body()));
|
||||
ImportAsRegion(*instruction->to_apply(), &reduce.body()));
|
||||
return reduce.getOperation();
|
||||
}
|
||||
case HloOpcode::kMap: {
|
||||
@ -588,7 +593,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
||||
loc, result_type, operands,
|
||||
ConvertDimensions(instruction->dimensions()));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ImportComputation(instruction->to_apply(), &op.computation()));
|
||||
ImportAsRegion(*instruction->to_apply(), &op.computation()));
|
||||
return op.getOperation();
|
||||
}
|
||||
case HloOpcode::kConvolution: {
|
||||
|
@ -42,29 +42,39 @@ class Shape;
|
||||
// Helper class for importing HloComputations.
|
||||
class HloFunctionImporter {
|
||||
public:
|
||||
static StatusOr<mlir::FuncOp> ImportFunction(
|
||||
mlir::ModuleOp module, mlir::Builder* builder,
|
||||
std::unordered_map<xla::HloComputation*, mlir::FuncOp>* function_map,
|
||||
xla::HloComputation* computation);
|
||||
// Imports the given computation as a function in the given module. This also
|
||||
// imports any computations referred by instructions in this computation.
|
||||
static Status ImportAsFunc(const xla::HloComputation& computation,
|
||||
mlir::ModuleOp module,
|
||||
std::unordered_map<const xla::HloComputation*,
|
||||
mlir::FuncOp>* function_map,
|
||||
mlir::Builder* builder);
|
||||
|
||||
// Imports the given hlo computation to the specified region.
|
||||
static Status ImportAsRegion(const xla::HloComputation& computation,
|
||||
mlir::Region* region, mlir::Builder* builder);
|
||||
|
||||
private:
|
||||
HloFunctionImporter(
|
||||
mlir::ModuleOp module, mlir::Builder* builder,
|
||||
std::unordered_map<xla::HloComputation*, mlir::FuncOp>* function_map)
|
||||
HloFunctionImporter(mlir::ModuleOp module,
|
||||
std::unordered_map<const xla::HloComputation*,
|
||||
mlir::FuncOp>* function_map,
|
||||
mlir::Builder* builder)
|
||||
: context_(module.getContext()),
|
||||
module_(module),
|
||||
builder_(builder),
|
||||
function_map_(function_map) {}
|
||||
|
||||
StatusOr<mlir::FuncOp> ImportFunction(xla::HloComputation* computation);
|
||||
// Imports the given computation as a new function, if it hasn't been already
|
||||
// imported.
|
||||
StatusOr<mlir::FuncOp> ImportAsFunc(const xla::HloComputation& computation);
|
||||
|
||||
// Imports the given computation in the specified region.
|
||||
tensorflow::Status ImportComputation(HloComputation* computation,
|
||||
mlir::Region* region);
|
||||
tensorflow::Status ImportAsRegion(const HloComputation& computation,
|
||||
mlir::Region* region);
|
||||
|
||||
// Imports instructions from the given computation in the specified block.
|
||||
// Assumes that the block already has correct arguments populated.
|
||||
tensorflow::Status ImportInstructions(HloComputation* computation,
|
||||
tensorflow::Status ImportInstructions(const HloComputation& computation,
|
||||
mlir::Block* block);
|
||||
|
||||
// Imports an instruction.
|
||||
@ -125,7 +135,7 @@ class HloFunctionImporter {
|
||||
mlir::Builder* builder_;
|
||||
|
||||
// Mapping from HloComputation to the created MLIR function.
|
||||
std::unordered_map<xla::HloComputation*, mlir::FuncOp>* function_map_;
|
||||
std::unordered_map<const xla::HloComputation*, mlir::FuncOp>* function_map_;
|
||||
|
||||
// Mapping from HloInstructions to the associative MLIR values.
|
||||
std::unordered_map<xla::HloInstruction*, mlir::Value> instruction_value_map_;
|
||||
|
@ -33,11 +33,11 @@ namespace xla {
|
||||
Status HloModuleImporter::Import(const xla::HloModule& module) {
|
||||
// TODO(hinsu): Only import the entry computation here once all HLO ops with
|
||||
// reference to other computation are updated to have a region instead of a
|
||||
// function attribute.
|
||||
for (const auto& computation : module.computations()) {
|
||||
auto result = HloFunctionImporter::ImportFunction(
|
||||
module_, &builder_, &function_map_, computation);
|
||||
TF_RETURN_IF_ERROR(result.status());
|
||||
// function attribute. Currently the importer test doesn't refer to all the
|
||||
// computations from the entry computation so tests may need some update.
|
||||
for (const auto* computation : module.computations()) {
|
||||
TF_RETURN_IF_ERROR(HloFunctionImporter::ImportAsFunc(
|
||||
*computation, module_, &function_map_, &builder_));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
@ -54,7 +54,7 @@ class HloModuleImporter {
|
||||
// Map for tracking which MLIR function map to which HLO Computation. This
|
||||
// tracks functions as they are imported and provides a quick lookup for
|
||||
// functions invoked by control flow related operations (e.g. while, call).
|
||||
std::unordered_map<xla::HloComputation*, mlir::FuncOp> function_map_;
|
||||
std::unordered_map<const xla::HloComputation*, mlir::FuncOp> function_map_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user