[XLA:GPU] Add conversion from HLO -> LMHLO for all-gather and all-to-all

- Also create a common function to set up the attributes shared by
  all-reduce, all-gather, and all-to-all.

PiperOrigin-RevId: 354990623
Change-Id: I8f836200c178722fb61a210ff4a443630b0261d5
Rahul Joshi authored on 2021-02-01 11:46:55 -08:00, committed by TensorFlower Gardener
parent 22bae65f06
commit e5473809f4
3 changed files with 83 additions and 10 deletions
tensorflow/compiler/mlir/xla
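
For orientation, the first new test below checks for an "lmhlo.all_gather" op roughly like the hand-written sketch that follows. The attribute values mirror the CHECK lines and the common collective attributes set by the converter; the memref types are inferred from the f32[10,20] -> f32[10,80] shapes in the test and are assumptions, not actual converter output:

"lmhlo.all_gather"(%input, %output) {
  all_gather_dimension = 1 : i64,
  constrain_layout = false,
  replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
  use_global_device_ids = false
} : (memref<10x20xf32>, memref<10x80xf32>) -> ()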

@@ -598,6 +598,35 @@ ENTRY main {
  ROOT %rng-get-and-update-state = u64[2]{0} rng-get-and-update-state(), delta=131072
}

// -----

HloModule TestAllGather

// CHECK: func @main
// CHECK: "lmhlo.all_gather"
// CHECK-SAME: all_gather_dimension = 1 : i64
// CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>
// CHECK-SAME: use_global_device_ids = false
ENTRY main {
  param0 = f32[10,20] parameter(0)
  ROOT ag = f32[10,80] all-gather(param0), replica_groups={{0,1,2,3}},
      dimensions={1}
}
// -----

HloModule TestAllToAll

// CHECK: func @entry
// CHECK: "lmhlo.all_to_all"
// CHECK-SAME: constrain_layout = false
// CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
ENTRY entry {
  p0 = f32[128,4]{0,1} parameter(0)
  p1 = f32[128,4]{1,0} parameter(1)
  ROOT a2a = (f32[128,4]{0,1}, f32[128,4]{1,0}) all-to-all(p0, p1),
      replica_groups={{0,1}}
}

// -----

HloModule TestReplicaId

@@ -249,6 +249,10 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(
      return CreateOpWithoutAttrs<lmhlo::AbsOp>(instr);
    case HloOpcode::kAdd:
      return CreateOpWithoutAttrs<lmhlo::AddOp>(instr);
    case HloOpcode::kAllGather:
      return EmitAllGatherOp(instr);
    case HloOpcode::kAllReduce:
      return EmitAllReduceOp(instr);
    case HloOpcode::kAllToAll:
      return EmitAllToAllOp(instr);
    case HloOpcode::kAnd:
@@ -1008,21 +1012,57 @@ StatusOr<lmhlo::ReducePrecisionOp> LhloDialectEmitter::EmitReducePrecisionOp(
  return reduce_precision_op;
}
template <typename OpT>
Status SetupCommonCollectiveOpAttributes(OpT op, const HloInstruction* instr,
                                         mlir::OpBuilder& builder) {
  auto* collective = xla::Cast<xla::HloCollectiveInstruction>(instr);
  auto replica_groups_attr = xla::HloFunctionImporter::ConvertReplicaGroups(
      collective->replica_groups(), builder);
  op->setAttr(replica_groups_attr.first, replica_groups_attr.second);
  op.constrain_layoutAttr(builder.getBoolAttr(collective->constrain_layout()));
  if (collective->channel_id().has_value()) {
    op.channel_idAttr(mlir::mhlo::ChannelHandle::get(
        builder.getI64IntegerAttr(*collective->channel_id()),
        builder.getI64IntegerAttr(0), builder.getContext()));
  }
  return Status::OK();
}

StatusOr<lmhlo::AllToAllOp> LhloDialectEmitter::EmitAllToAllOp(
    const HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto all_to_all_op,
                      CreateOpWithoutAttrs<lmhlo::AllToAllOp>(instr));
  auto* all_to_all = xla::Cast<xla::HloAllToAllInstruction>(instr);
  TF_RETURN_IF_ERROR(
      SetupCommonCollectiveOpAttributes(all_to_all_op, instr, builder_));
  if (all_to_all->split_dimension().has_value()) {
    all_to_all_op.split_dimensionAttr(
        builder_.getI64IntegerAttr(*all_to_all->split_dimension()));
  }
  return all_to_all_op;
}
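
EmitAllToAllOp sets split_dimension only for HLO's array variant of all-to-all (the dimensions={...} form); the tuple form used by the TestAllToAll case above has no split dimension, so the attribute stays unset. A hand-written sketch of what the array variant would lower to, with the SSA names and memref types assumed for illustration rather than taken from this change:

"lmhlo.all_to_all"(%operand, %result) {
  constrain_layout = false,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  split_dimension = 0 : i64
} : (memref<8x8xf32>, memref<8x8xf32>) -> ()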

StatusOr<lmhlo::AllGatherOp> LhloDialectEmitter::EmitAllGatherOp(
    const HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto all_gather_op,
                      CreateOpWithoutAttrs<lmhlo::AllGatherOp>(instr));
  auto* all_gather = xla::Cast<xla::HloAllGatherInstruction>(instr);
  TF_RETURN_IF_ERROR(
      SetupCommonCollectiveOpAttributes(all_gather_op, instr, builder_));
  all_gather_op.use_global_device_idsAttr(
      builder_.getBoolAttr(all_gather->use_global_device_ids()));
  all_gather_op.all_gather_dimensionAttr(
      builder_.getI64IntegerAttr(all_gather->all_gather_dimension()));
  return all_gather_op;
}

StatusOr<lmhlo::AllReduceOp> LhloDialectEmitter::EmitAllReduceOp(
    const HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto all_reduce_op,
                      CreateOpWithoutAttrs<lmhlo::AllReduceOp>(instr));
  auto* all_reduce = xla::Cast<xla::HloAllReduceInstruction>(instr);
  TF_RETURN_IF_ERROR(
      SetupCommonCollectiveOpAttributes(all_reduce_op, instr, builder_));
  all_reduce_op.use_global_device_idsAttr(
      builder_.getBoolAttr(all_reduce->use_global_device_ids()));
  TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(

@@ -87,6 +87,10 @@ class LhloDialectEmitter : public xla::ConstDfsHloVisitorWithDefault {
  xla::StatusOr<lmhlo::ReducePrecisionOp> EmitReducePrecisionOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::AllToAllOp> EmitAllToAllOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::AllGatherOp> EmitAllGatherOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::AllReduceOp> EmitAllReduceOp(
      const xla::HloInstruction* instr);