diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt
index 26f68273e09..39a6f175d79 100644
--- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt
@@ -412,3 +412,33 @@ ENTRY main {
   %tok = token[] parameter(0)
   ROOT %infeed = (f32[3]{0}, token[]) infeed(token[] %tok)
 }
+
+// -----
+
+HloModule Outfeed
+
+// CHECK: func @main
+// CHECK: "lmhlo.outfeed"
+// CHECK-SAME: config = ""
+// CHECK-SAME: (memref<3xf32>) -> ()
+ENTRY main {
+  %source = f32[3] parameter(0)
+  %tok = token[] parameter(1)
+  ROOT %o = token[] outfeed(f32[3] %source, token[] %tok)
+}
+
+// -----
+
+HloModule Outfeed
+
+// CHECK: func @main
+// CHECK: "lmhlo.custom_call"
+// CHECK-SAME: call_target_name = "foo"
+// CHECK: "lmhlo.outfeed"
+// CHECK-SAME: config = ""
+// CHECK-SAME: (memref<3xf32>, memref<5xf16>) -> ()
+ENTRY main {
+  %tok = token[] parameter(0)
+  %tuple = (f32[3], f16[5]) custom-call(), custom_call_target="foo"
+  ROOT %o = token[] outfeed((f32[3], f16[5]) %tuple, token[] %tok)
+}
\ No newline at end of file
diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
index b5ddd564227..a2f620048bb 100644
--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
@@ -73,6 +73,7 @@ using xla::HloInfeedInstruction;
 using xla::HloInstruction;
 using xla::HloModule;
 using xla::HloModuleProto;
+using xla::HloOutfeedInstruction;
 using xla::HloProto;
 using xla::Shape;
 using xla::StatusOr;
@@ -304,6 +305,8 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(HloInstruction* instr) {
       return CreateOpWithoutAttrs<lmhlo::NotOp>(instr);
     case HloOpcode::kOr:
      return CreateOpWithoutAttrs<lmhlo::OrOp>(instr);
+    case HloOpcode::kOutfeed:
+      return EmitOutfeedOp(instr);
     case HloOpcode::kPopulationCount:
       return CreateOpWithoutAttrs<lmhlo::PopulationCountOp>(instr);
     case HloOpcode::kPower:
       return CreateOpWithoutAttrs<lmhlo::PowOp>(instr);
@@ -1004,6 +1007,19 @@ StatusOr<lmhlo::InfeedOp> LhloDialectEmitter::EmitInfeedOp(
   return infeed_op;
 }
 
+StatusOr<lmhlo::OutfeedOp> LhloDialectEmitter::EmitOutfeedOp(
+    HloInstruction* instr) {
+  HloOutfeedInstruction* outfeed = xla::Cast<HloOutfeedInstruction>(instr);
+  // HLO outfeed instruction has 2 operands, the source and a token, and a
+  // single token output. LMHLO Outfeed does not need the token operand and
+  // result, so drop it.
+  SmallVector<Value, 2> operands;
+  TF_RETURN_IF_ERROR(GetOrCreateView(instr->operand(0), &operands));
+  auto outfeed_op = CreateOpWithoutAttrs<lmhlo::OutfeedOp>(instr, operands);
+  outfeed_op.configAttr(builder_.getStringAttr(outfeed->outfeed_config()));
+  return outfeed_op;
+}
+
 StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
     const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape,
     const ::xla::ShapeIndex& shape_index) {
diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
index 49c2bc36491..d45d8760b0d 100644
--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
@@ -82,6 +82,7 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
   ::xla::StatusOr<lmhlo::CompareOp> EmitCompareOp(::xla::HloInstruction* instr);
 
   ::xla::StatusOr<lmhlo::InfeedOp> EmitInfeedOp(::xla::HloInstruction* instr);
+  ::xla::StatusOr<lmhlo::OutfeedOp> EmitOutfeedOp(::xla::HloInstruction* instr);
 
   ::xla::StatusOr<lmhlo::MapOp> EmitMapOp(::xla::HloInstruction* instr);
   ::xla::StatusOr<lmhlo::ReducePrecisionOp> EmitReducePrecisionOp(
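
For reference, the first test case above is expected to lower to LMHLO along these lines. This is a minimal sketch reconstructed from the CHECK lines, not verbatim compiler output; the SSA/argument names and the exact function signature produced by buffer assignment are assumptions:

func @main(%arg0: memref<3xf32>) {
  "lmhlo.outfeed"(%arg0) {config = ""} : (memref<3xf32>) -> ()
  return
}

The token operand and token result of the HLO outfeed do not appear here, matching the drop performed in EmitOutfeedOp.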