From f9161fab8d8b6e362e222b5140d38a0b30abd105 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Sat, 13 Jun 2020 15:00:27 -0700 Subject: [PATCH] Split up operator writer's main function Instead of having all generated functions be in the same function, split them out to separate functions similar to how the custom exports are done. Also update pointer to follow new reference style. PiperOrigin-RevId: 316286894 Change-Id: I27cb1325c160ca5b8b9668f1f6731443d409a7d5 --- .../compiler/mlir/xla/operator_writer_gen.cc | 52 +++++++++++-------- 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc index 27cd7e21147..10dc5ec9dd0 100644 --- a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc +++ b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc @@ -68,10 +68,11 @@ static StringRef GetClientBuilder(const Operator& op) { return kOpToXLABuilderMap->lookup(op_name); } -static void BuildOperator(const Operator& op, raw_ostream* output) { - auto& os = *output; - os << " auto& value_map = *lowering_context.values;\n" - << " auto result = xla_op.getResult();\n"; +static void BuildOperator(const Operator& op, raw_ostream& os) { + os << "mlir::LogicalResult ExportXlaOp(mlir::xla_hlo::" + << op.getCppClassName() << " op, OpLoweringContext ctx) {\n" + << " auto& value_map = *ctx.values;\n" + << " auto result = op.getResult();\n"; // Build a conversion for each of the arguments. int operand_number = 0; @@ -82,15 +83,14 @@ static void BuildOperator(const Operator& op, raw_ostream* output) { if (auto* operand_cst = arg.dyn_cast()) { // Handle a non-variadic operand. if (!operand_cst->isVariableLength()) { - os << " auto xla_arg_" << index - << " = value_map[*xla_op.getODSOperands(" << operand_number++ - << ").begin()];\n"; + os << " auto xla_arg_" << index << " = value_map[*op.getODSOperands(" + << operand_number++ << ").begin()];\n"; continue; } // Otherwise, this is a varidiac operand list. - os << " std::vector xla_arg_" << index << ";\n" - << " for (auto operand : xla_op.getODSOperands(" << operand_number++ + os << " std::vector xla_arg_" << index << ";\n" + << " for (auto operand : op.getODSOperands(" << operand_number++ << "))\n xla_arg_" << index << ".push_back(value_map[operand]);\n"; continue; @@ -98,18 +98,18 @@ static void BuildOperator(const Operator& op, raw_ostream* output) { // Otherwise, this is an attribute. auto named_attr = arg.get(); - os << " auto xla_arg_" << index << " = " - << GetDefaultAttrExport(*named_attr) << "(xla_op." - << op.getArgName(index) << "());\n"; + os << " auto xla_arg_" << index << " = " + << GetDefaultAttrExport(*named_attr) << "(op." << op.getArgName(index) + << "());\n"; } // Emit call to client API - os << " auto xla_result = xla::" << GetClientBuilder(op) << "("; + os << " auto xla_result = xla::" << GetClientBuilder(op) << "("; // If all operands are variadic, then pass the builder explicitly to xla // client API call if (op.getNumOperands() == op.getNumVariableLengthOperands()) { - os << "lowering_context.builder"; + os << "ctx.builder"; if (op.getNumArgs() != 0) os << ", "; } @@ -118,8 +118,9 @@ static void BuildOperator(const Operator& op, raw_ostream* output) { [&](int i) { os << "Unwrap(xla_arg_" << i << ')'; }); os << ");\n"; - os << " value_map[result] = xla_result;\n"; - os << " return mlir::success();\n"; + os << " value_map[result] = xla_result;\n"; + os << " return mlir::success();\n"; + os << "}\n"; } // The function below has a non-constant reference as that is required by LLVM's @@ -128,6 +129,14 @@ static void BuildOperator(const Operator& op, raw_ostream* output) { static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) { emitSourceFileHeader("MLIR XLA Builders", os); + // Emit all the helper functions. + for (const auto* def : records.getAllDerivedDefinitions("HLO_Op")) { + Operator op(def); + + // Skip operations that have a custom exporter. + if (!def->getValueAsBit("hasCustomHLOConverter")) BuildOperator(op, os); + } + // Emit a function to generate an XLA operation for the operations with // auto-generated builders. os << "mlir::LogicalResult ExportXlaOperator(\n" @@ -153,12 +162,11 @@ static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) { // Cast to the current operation and build the exporter. os << " if (auto xla_op = llvm::dyn_cast(op)) {\n"; - if (def->getValueAsBit("hasCustomHLOConverter")) { - os << " return mlir::xla_hlo::ExportXlaOp(xla_op, " - "lowering_context);\n"; - } else { - BuildOperator(op, &os); - } + os << " return "; + // The autogenerated converters aren't in the same namespace. + // TODO(jpienaar): Reconsider this. + if (def->getValueAsBit("hasCustomHLOConverter")) os << "mlir::xla_hlo::"; + os << "ExportXlaOp(xla_op, lowering_context);\n"; os << " }\n"; }