diff --git a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc index 27cd7e21147..10dc5ec9dd0 100644 --- a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc +++ b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc @@ -68,10 +68,11 @@ static StringRef GetClientBuilder(const Operator& op) { return kOpToXLABuilderMap->lookup(op_name); } -static void BuildOperator(const Operator& op, raw_ostream* output) { - auto& os = *output; - os << " auto& value_map = *lowering_context.values;\n" - << " auto result = xla_op.getResult();\n"; +static void BuildOperator(const Operator& op, raw_ostream& os) { + os << "mlir::LogicalResult ExportXlaOp(mlir::xla_hlo::" + << op.getCppClassName() << " op, OpLoweringContext ctx) {\n" + << " auto& value_map = *ctx.values;\n" + << " auto result = op.getResult();\n"; // Build a conversion for each of the arguments. int operand_number = 0; @@ -82,15 +83,14 @@ static void BuildOperator(const Operator& op, raw_ostream* output) { if (auto* operand_cst = arg.dyn_cast()) { // Handle a non-variadic operand. if (!operand_cst->isVariableLength()) { - os << " auto xla_arg_" << index - << " = value_map[*xla_op.getODSOperands(" << operand_number++ - << ").begin()];\n"; + os << " auto xla_arg_" << index << " = value_map[*op.getODSOperands(" + << operand_number++ << ").begin()];\n"; continue; } // Otherwise, this is a varidiac operand list. - os << " std::vector xla_arg_" << index << ";\n" - << " for (auto operand : xla_op.getODSOperands(" << operand_number++ + os << " std::vector xla_arg_" << index << ";\n" + << " for (auto operand : op.getODSOperands(" << operand_number++ << "))\n xla_arg_" << index << ".push_back(value_map[operand]);\n"; continue; @@ -98,18 +98,18 @@ static void BuildOperator(const Operator& op, raw_ostream* output) { // Otherwise, this is an attribute. auto named_attr = arg.get(); - os << " auto xla_arg_" << index << " = " - << GetDefaultAttrExport(*named_attr) << "(xla_op." - << op.getArgName(index) << "());\n"; + os << " auto xla_arg_" << index << " = " + << GetDefaultAttrExport(*named_attr) << "(op." << op.getArgName(index) + << "());\n"; } // Emit call to client API - os << " auto xla_result = xla::" << GetClientBuilder(op) << "("; + os << " auto xla_result = xla::" << GetClientBuilder(op) << "("; // If all operands are variadic, then pass the builder explicitly to xla // client API call if (op.getNumOperands() == op.getNumVariableLengthOperands()) { - os << "lowering_context.builder"; + os << "ctx.builder"; if (op.getNumArgs() != 0) os << ", "; } @@ -118,8 +118,9 @@ static void BuildOperator(const Operator& op, raw_ostream* output) { [&](int i) { os << "Unwrap(xla_arg_" << i << ')'; }); os << ");\n"; - os << " value_map[result] = xla_result;\n"; - os << " return mlir::success();\n"; + os << " value_map[result] = xla_result;\n"; + os << " return mlir::success();\n"; + os << "}\n"; } // The function below has a non-constant reference as that is required by LLVM's @@ -128,6 +129,14 @@ static void BuildOperator(const Operator& op, raw_ostream* output) { static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) { emitSourceFileHeader("MLIR XLA Builders", os); + // Emit all the helper functions. + for (const auto* def : records.getAllDerivedDefinitions("HLO_Op")) { + Operator op(def); + + // Skip operations that have a custom exporter. + if (!def->getValueAsBit("hasCustomHLOConverter")) BuildOperator(op, os); + } + // Emit a function to generate an XLA operation for the operations with // auto-generated builders. os << "mlir::LogicalResult ExportXlaOperator(\n" @@ -153,12 +162,11 @@ static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) { // Cast to the current operation and build the exporter. os << " if (auto xla_op = llvm::dyn_cast(op)) {\n"; - if (def->getValueAsBit("hasCustomHLOConverter")) { - os << " return mlir::xla_hlo::ExportXlaOp(xla_op, " - "lowering_context);\n"; - } else { - BuildOperator(op, &os); - } + os << " return "; + // The autogenerated converters aren't in the same namespace. + // TODO(jpienaar): Reconsider this. + if (def->getValueAsBit("hasCustomHLOConverter")) os << "mlir::xla_hlo::"; + os << "ExportXlaOp(xla_op, lowering_context);\n"; os << " }\n"; }