Split up operator writer's main function
Instead of having all generated functions be in the same function, split them out to separate functions similar to how the custom exports are done. Also update pointer to follow new reference style. PiperOrigin-RevId: 316286894 Change-Id: I27cb1325c160ca5b8b9668f1f6731443d409a7d5
This commit is contained in:
parent
52d1e72c66
commit
f9161fab8d
@ -68,10 +68,11 @@ static StringRef GetClientBuilder(const Operator& op) {
|
||||
return kOpToXLABuilderMap->lookup(op_name);
|
||||
}
|
||||
|
||||
static void BuildOperator(const Operator& op, raw_ostream* output) {
|
||||
auto& os = *output;
|
||||
os << " auto& value_map = *lowering_context.values;\n"
|
||||
<< " auto result = xla_op.getResult();\n";
|
||||
static void BuildOperator(const Operator& op, raw_ostream& os) {
|
||||
os << "mlir::LogicalResult ExportXlaOp(mlir::xla_hlo::"
|
||||
<< op.getCppClassName() << " op, OpLoweringContext ctx) {\n"
|
||||
<< " auto& value_map = *ctx.values;\n"
|
||||
<< " auto result = op.getResult();\n";
|
||||
|
||||
// Build a conversion for each of the arguments.
|
||||
int operand_number = 0;
|
||||
@ -82,15 +83,14 @@ static void BuildOperator(const Operator& op, raw_ostream* output) {
|
||||
if (auto* operand_cst = arg.dyn_cast<NamedTypeConstraint*>()) {
|
||||
// Handle a non-variadic operand.
|
||||
if (!operand_cst->isVariableLength()) {
|
||||
os << " auto xla_arg_" << index
|
||||
<< " = value_map[*xla_op.getODSOperands(" << operand_number++
|
||||
<< ").begin()];\n";
|
||||
os << " auto xla_arg_" << index << " = value_map[*op.getODSOperands("
|
||||
<< operand_number++ << ").begin()];\n";
|
||||
continue;
|
||||
}
|
||||
|
||||
// Otherwise, this is a varidiac operand list.
|
||||
os << " std::vector<xla::XlaOp> xla_arg_" << index << ";\n"
|
||||
<< " for (auto operand : xla_op.getODSOperands(" << operand_number++
|
||||
os << " std::vector<xla::XlaOp> xla_arg_" << index << ";\n"
|
||||
<< " for (auto operand : op.getODSOperands(" << operand_number++
|
||||
<< "))\n xla_arg_" << index
|
||||
<< ".push_back(value_map[operand]);\n";
|
||||
continue;
|
||||
@ -98,18 +98,18 @@ static void BuildOperator(const Operator& op, raw_ostream* output) {
|
||||
|
||||
// Otherwise, this is an attribute.
|
||||
auto named_attr = arg.get<NamedAttribute*>();
|
||||
os << " auto xla_arg_" << index << " = "
|
||||
<< GetDefaultAttrExport(*named_attr) << "(xla_op."
|
||||
<< op.getArgName(index) << "());\n";
|
||||
os << " auto xla_arg_" << index << " = "
|
||||
<< GetDefaultAttrExport(*named_attr) << "(op." << op.getArgName(index)
|
||||
<< "());\n";
|
||||
}
|
||||
|
||||
// Emit call to client API
|
||||
os << " auto xla_result = xla::" << GetClientBuilder(op) << "(";
|
||||
os << " auto xla_result = xla::" << GetClientBuilder(op) << "(";
|
||||
|
||||
// If all operands are variadic, then pass the builder explicitly to xla
|
||||
// client API call
|
||||
if (op.getNumOperands() == op.getNumVariableLengthOperands()) {
|
||||
os << "lowering_context.builder";
|
||||
os << "ctx.builder";
|
||||
if (op.getNumArgs() != 0) os << ", ";
|
||||
}
|
||||
|
||||
@ -118,8 +118,9 @@ static void BuildOperator(const Operator& op, raw_ostream* output) {
|
||||
[&](int i) { os << "Unwrap(xla_arg_" << i << ')'; });
|
||||
os << ");\n";
|
||||
|
||||
os << " value_map[result] = xla_result;\n";
|
||||
os << " return mlir::success();\n";
|
||||
os << " value_map[result] = xla_result;\n";
|
||||
os << " return mlir::success();\n";
|
||||
os << "}\n";
|
||||
}
|
||||
|
||||
// The function below has a non-constant reference as that is required by LLVM's
|
||||
@ -128,6 +129,14 @@ static void BuildOperator(const Operator& op, raw_ostream* output) {
|
||||
static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) {
|
||||
emitSourceFileHeader("MLIR XLA Builders", os);
|
||||
|
||||
// Emit all the helper functions.
|
||||
for (const auto* def : records.getAllDerivedDefinitions("HLO_Op")) {
|
||||
Operator op(def);
|
||||
|
||||
// Skip operations that have a custom exporter.
|
||||
if (!def->getValueAsBit("hasCustomHLOConverter")) BuildOperator(op, os);
|
||||
}
|
||||
|
||||
// Emit a function to generate an XLA operation for the operations with
|
||||
// auto-generated builders.
|
||||
os << "mlir::LogicalResult ExportXlaOperator(\n"
|
||||
@ -153,12 +162,11 @@ static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) {
|
||||
// Cast to the current operation and build the exporter.
|
||||
os << " if (auto xla_op = llvm::dyn_cast<mlir::xla_hlo::"
|
||||
<< op.getCppClassName() << ">(op)) {\n";
|
||||
if (def->getValueAsBit("hasCustomHLOConverter")) {
|
||||
os << " return mlir::xla_hlo::ExportXlaOp(xla_op, "
|
||||
"lowering_context);\n";
|
||||
} else {
|
||||
BuildOperator(op, &os);
|
||||
}
|
||||
os << " return ";
|
||||
// The autogenerated converters aren't in the same namespace.
|
||||
// TODO(jpienaar): Reconsider this.
|
||||
if (def->getValueAsBit("hasCustomHLOConverter")) os << "mlir::xla_hlo::";
|
||||
os << "ExportXlaOp(xla_op, lowering_context);\n";
|
||||
os << " }\n";
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user