[XLA:GPU] Add conversion from HLO -> LMHLO for all-gather and all-to-all
- Also create a common function to set up the attributes shared by all-reduce, all-gather, and all-to-all. PiperOrigin-RevId: 354990623 Change-Id: I8f836200c178722fb61a210ff4a443630b0261d5
This commit is contained in:
parent
22bae65f06
commit
e5473809f4
tensorflow/compiler/mlir/xla
tests/hlo_to_lhlo_with_xla
transforms
@ -598,6 +598,35 @@ ENTRY main {
|
||||
ROOT %rng-get-and-update-state = u64[2]{0} rng-get-and-update-state(), delta=131072
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
HloModule TestAllGather
|
||||
|
||||
// CHECK: func @main
|
||||
// CHECK: "lmhlo.all_gather"
|
||||
// CHECK-SAME: all_gather_dimension = 1 : i64
|
||||
// CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>
|
||||
// CHECK-SAME: use_global_device_ids = false
|
||||
ENTRY main {
|
||||
param0 = f32[10,20] parameter(0)
|
||||
ROOT ag = f32[10,80] all-gather(param0), replica_groups={{0,1,2,3}},
|
||||
dimensions={1}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func @entry
|
||||
// CHECK: "lmhlo.all_to_all"
|
||||
// CHECK-SAME: constrain_layout = false
|
||||
// CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
|
||||
HloModule TestAllToAll
|
||||
ENTRY entry {
|
||||
p0 = f32[128,4]{0,1} parameter(0)
|
||||
p1 = f32[128,4]{1,0} parameter(1)
|
||||
ROOT a2a = (f32[128,4]{0,1}, f32[128,4]{1,0}) all-to-all(p0, p1),
|
||||
replica_groups={{0,1}}
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
HloModule TestReplicaId
|
||||
|
@ -249,6 +249,10 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(
|
||||
return CreateOpWithoutAttrs<lmhlo::AbsOp>(instr);
|
||||
case HloOpcode::kAdd:
|
||||
return CreateOpWithoutAttrs<lmhlo::AddOp>(instr);
|
||||
case HloOpcode::kAllToAll:
|
||||
return EmitAllToAllOp(instr);
|
||||
case HloOpcode::kAllGather:
|
||||
return EmitAllGatherOp(instr);
|
||||
case HloOpcode::kAllReduce:
|
||||
return EmitAllReduceOp(instr);
|
||||
case HloOpcode::kAnd:
|
||||
@ -1008,21 +1012,57 @@ StatusOr<lmhlo::ReducePrecisionOp> LhloDialectEmitter::EmitReducePrecisionOp(
|
||||
return reduce_precision_op;
|
||||
}
|
||||
|
||||
template <typename OpT>
|
||||
Status SetupCommonCollectiveOpAttributes(OpT op, const HloInstruction* instr,
|
||||
mlir::OpBuilder& builder) {
|
||||
auto* collective = xla::Cast<xla::HloCollectiveInstruction>(instr);
|
||||
auto replica_groups_attr = xla::HloFunctionImporter::ConvertReplicaGroups(
|
||||
collective->replica_groups(), builder);
|
||||
op->setAttr(replica_groups_attr.first, replica_groups_attr.second);
|
||||
op.constrain_layoutAttr(builder.getBoolAttr(collective->constrain_layout()));
|
||||
if (collective->channel_id().has_value()) {
|
||||
op.channel_idAttr(mlir::mhlo::ChannelHandle::get(
|
||||
builder.getI64IntegerAttr(*collective->channel_id()),
|
||||
builder.getI64IntegerAttr(0), builder.getContext()));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Lowers an HLO all-to-all instruction to an lmhlo.all_to_all op,
// carrying over the common collective attributes plus the optional
// split dimension.
StatusOr<lmhlo::AllToAllOp> LhloDialectEmitter::EmitAllToAllOp(
    const HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto op, CreateOpWithoutAttrs<lmhlo::AllToAllOp>(instr));
  TF_RETURN_IF_ERROR(SetupCommonCollectiveOpAttributes(op, instr, builder_));

  // split_dimension is optional on the HLO instruction; only materialize
  // the attribute when it is actually set.
  const auto* all_to_all = xla::Cast<xla::HloAllToAllInstruction>(instr);
  if (const auto& split = all_to_all->split_dimension()) {
    op.split_dimensionAttr(builder_.getI64IntegerAttr(*split));
  }
  return op;
}
|
||||
|
||||
// Lowers an HLO all-gather instruction to an lmhlo.all_gather op.
// On top of the common collective attributes this records the gather
// dimension and the use_global_device_ids flag.
StatusOr<lmhlo::AllGatherOp> LhloDialectEmitter::EmitAllGatherOp(
    const HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto op, CreateOpWithoutAttrs<lmhlo::AllGatherOp>(instr));
  TF_RETURN_IF_ERROR(SetupCommonCollectiveOpAttributes(op, instr, builder_));

  const auto* all_gather = xla::Cast<xla::HloAllGatherInstruction>(instr);
  op.all_gather_dimensionAttr(
      builder_.getI64IntegerAttr(all_gather->all_gather_dimension()));
  op.use_global_device_idsAttr(
      builder_.getBoolAttr(all_gather->use_global_device_ids()));
  return op;
}
|
||||
|
||||
StatusOr<lmhlo::AllReduceOp> LhloDialectEmitter::EmitAllReduceOp(
|
||||
const HloInstruction* instr) {
|
||||
TF_ASSIGN_OR_RETURN(auto all_reduce_op,
|
||||
CreateOpWithoutAttrs<lmhlo::AllReduceOp>(instr));
|
||||
auto* all_reduce = xla::Cast<xla::HloAllReduceInstruction>(instr);
|
||||
auto replica_groups_attr = xla::HloFunctionImporter::ConvertReplicaGroups(
|
||||
all_reduce->replica_groups(), builder_);
|
||||
all_reduce_op->setAttr(replica_groups_attr.first, replica_groups_attr.second);
|
||||
all_reduce_op.constrain_layoutAttr(
|
||||
builder_.getBoolAttr(all_reduce->constrain_layout()));
|
||||
if (all_reduce->channel_id().has_value()) {
|
||||
all_reduce_op.channel_idAttr(mlir::mhlo::ChannelHandle::get(
|
||||
builder_.getI64IntegerAttr(all_reduce->channel_id().value()),
|
||||
builder_.getI64IntegerAttr(0), builder_.getContext()));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
SetupCommonCollectiveOpAttributes(all_reduce_op, instr, builder_));
|
||||
all_reduce_op.use_global_device_idsAttr(
|
||||
builder_.getBoolAttr(all_reduce->use_global_device_ids()));
|
||||
TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
|
||||
|
@ -87,6 +87,10 @@ class LhloDialectEmitter : public xla::ConstDfsHloVisitorWithDefault {
|
||||
xla::StatusOr<lmhlo::ReducePrecisionOp> EmitReducePrecisionOp(
|
||||
const xla::HloInstruction* instr);
|
||||
|
||||
xla::StatusOr<lmhlo::AllToAllOp> EmitAllToAllOp(
|
||||
const xla::HloInstruction* instr);
|
||||
xla::StatusOr<lmhlo::AllGatherOp> EmitAllGatherOp(
|
||||
const xla::HloInstruction* instr);
|
||||
xla::StatusOr<lmhlo::AllReduceOp> EmitAllReduceOp(
|
||||
const xla::HloInstruction* instr);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user