[XLA:GPU] Map HLO Outfeed to LHLO OutfeedOp
- Drop tuple argument and result during the conversion. PiperOrigin-RevId: 351451341 Change-Id: Ia053c248b5297e8c59e8818ceeb8e4400c7fcacf
This commit is contained in:
parent
b8cd771a05
commit
6191636c4a
@ -412,3 +412,33 @@ ENTRY main {
|
||||
%tok = token[] parameter(0)
|
||||
ROOT %infeed = (f32[3]{0}, token[]) infeed(token[] %tok)
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
HloModule Outfeed
|
||||
|
||||
// CHECK: func @main
|
||||
// CHECK: "lmhlo.outfeed"
|
||||
// CHECK-SAME: config = ""
|
||||
// CHECK-SAME: (memref<3xf32>) -> ()
|
||||
ENTRY main {
|
||||
%source = f32[3] parameter(0)
|
||||
%tok = token[] parameter(1)
|
||||
ROOT %o = token[] outfeed(f32[3] %source, token[] %tok)
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
HloModule Outfeed
|
||||
|
||||
// CHECK: func @main
|
||||
// CHECK: "lmhlo.custom_call"
|
||||
// CHECK-SAME: call_target_name = "foo"
|
||||
// CHECK: "lmhlo.outfeed"
|
||||
// CHECK-SAME: config = ""
|
||||
// CHECK-SAME: (memref<3xf32>, memref<5xf16>) -> ()
|
||||
ENTRY main {
|
||||
%tok = token[] parameter(0)
|
||||
%tuple = (f32[3], f16[5]) custom-call(),custom_call_target="foo"
|
||||
ROOT %o = token[] outfeed((f32[3], f16[5]) %tuple, token[] %tok)
|
||||
}
|
@ -73,6 +73,7 @@ using xla::HloInfeedInstruction;
|
||||
using xla::HloInstruction;
|
||||
using xla::HloModule;
|
||||
using xla::HloModuleProto;
|
||||
using xla::HloOutfeedInstruction;
|
||||
using xla::HloProto;
|
||||
using xla::Shape;
|
||||
using xla::StatusOr;
|
||||
@ -304,6 +305,8 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(HloInstruction* instr) {
|
||||
return CreateOpWithoutAttrs<lmhlo::NotOp>(instr);
|
||||
case HloOpcode::kOr:
|
||||
return CreateOpWithoutAttrs<lmhlo::OrOp>(instr);
|
||||
case HloOpcode::kOutfeed:
|
||||
return EmitOutfeedOp(instr);
|
||||
case HloOpcode::kPopulationCount:
|
||||
return CreateOpWithoutAttrs<lmhlo::PopulationCountOp>(instr);
|
||||
case HloOpcode::kPower:
|
||||
@ -1004,6 +1007,19 @@ StatusOr<lmhlo::InfeedOp> LhloDialectEmitter::EmitInfeedOp(
|
||||
return infeed_op;
|
||||
}
|
||||
|
||||
StatusOr<lmhlo::OutfeedOp> LhloDialectEmitter::EmitOutfeedOp(
|
||||
HloInstruction* instr) {
|
||||
HloOutfeedInstruction* outfeed = xla::Cast<HloOutfeedInstruction>(instr);
|
||||
// HLO outfeed instruction has 2 operands, the source and a token, and a
|
||||
// single token output. LMHLO Outfeed does not need the token operand and
|
||||
// result, do drop it.
|
||||
SmallVector<Value, 2> operands;
|
||||
TF_RETURN_IF_ERROR(GetOrCreateView(instr->operand(0), &operands));
|
||||
auto outfeed_op = CreateOpWithoutAttrs<lmhlo::OutfeedOp>(instr, operands);
|
||||
outfeed_op.configAttr(builder_.getStringAttr(outfeed->outfeed_config()));
|
||||
return outfeed_op;
|
||||
}
|
||||
|
||||
StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
|
||||
const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape,
|
||||
const ::xla::ShapeIndex& shape_index) {
|
||||
|
@ -82,6 +82,7 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
|
||||
::xla::StatusOr<lmhlo::CompareOp> EmitCompareOp(::xla::HloInstruction* instr);
|
||||
|
||||
::xla::StatusOr<lmhlo::InfeedOp> EmitInfeedOp(::xla::HloInstruction* instr);
|
||||
::xla::StatusOr<lmhlo::OutfeedOp> EmitOutfeedOp(::xla::HloInstruction* instr);
|
||||
::xla::StatusOr<lmhlo::MapOp> EmitMapOp(::xla::HloInstruction* instr);
|
||||
|
||||
::xla::StatusOr<lmhlo::ReducePrecisionOp> EmitReducePrecisionOp(
|
||||
|
Loading…
Reference in New Issue
Block a user