From 8be0435b0147263c3872bedec58fd215f784b450 Mon Sep 17 00:00:00 2001 From: Ehsan Toosi Date: Wed, 9 Sep 2020 11:48:38 +0200 Subject: [PATCH] [hlo] Unbreak hlo-legalize-to-lhlo test --- .../mhlo/transforms/hlo_legalize_to_lhlo.cc | 42 ++++++++++++++++--- .../mlir/hlo/tests/hlo-legalize-to-lhlo.mlir | 3 -- 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index b3cb8bf69e6..e900bae5a15 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -287,11 +287,36 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion { } }; -// Legalize mhlo.return to a lmhlo.copy and lmhlo.terminator. This functionality -// is provided by mlir buffer assignment, so use the pattern from there. -using HloToLhloReturnOpConverter = - BufferAssignmentReturnOpConverter; +// Legalize mhlo.return to a lmhlo.copy and lmhlo.terminator. +struct HloToLhloReturnOpConverter : public BaseOpConversion { + public: + using BaseOpConversion::BaseOpConversion; + + LogicalResult matchAndRewrite( + mhlo::ReturnOp op, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + auto loc = op.getLoc(); + auto& entry_block = op.getParentRegion()->front(); + auto num_arguments = entry_block.getNumArguments(); + if (operands.size() > num_arguments) { + return op.emitError( + "The number of operands that need Copy operations is more " + "than the number of target function arguments."); + } + + // The index of the first output block argument. + auto dest_arg_idx = num_arguments - operands.size(); + + // Create a lmhlo.copy for each operand of mhlo.return. + for (Value operand : operands) { + rewriter.create(loc, operand, + entry_block.getArgument(dest_arg_idx)); + ++dest_arg_idx; + } + rewriter.replaceOpWithNewOp(op); + return success(); + } +}; class HloToLhloTensorLoadOpConverter : public BaseOpConversion { @@ -429,6 +454,13 @@ struct HloLegalizeToLhlo isMemRefType); }); + auto kind = results_escape_function + ? BufferAssignmentTypeConverter::KeepAsFunctionResult + : BufferAssignmentTypeConverter::AppendToArgumentsList; + converter.setResultConversionKind( + kind); + converter.setResultConversionKind(kind); + populateHLOToLHLOConversionPattern(&context, &converter, &patterns); populateWithBufferAssignmentOpConversionPatterns< mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp>(&context, &converter, diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir index cd263483afe..c01d451bbeb 100644 --- a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir @@ -1,8 +1,5 @@ // RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-placement -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=PRE,BOTH %s // RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=ESC,BOTH %s -// TODO(herhut): unbreak the test after upstream API changes. -// XFAIL: * - // BOTH-LABEL: func @attrs func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {