[XLA:GPU] Add HLO -> LMHLO conversion for CollectivePermuteOp
PiperOrigin-RevId: 355254944
Change-Id: I652e3e9996e06fa206b96bb0ab24518d55a8334a

parent b4bf78ffec
commit 110a7631b2
@@ -311,8 +311,8 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
       return new_operation;
     }
     case HloOpcode::kCollectivePermute: {
-      attributes.push_back(
-          ConvertSourceTargetPairs(instruction->source_target_pairs()));
+      attributes.push_back(ConvertSourceTargetPairs(
+          instruction->source_target_pairs(), builder_));
       MakeAndReturn(CollectivePermuteOp);
     }
     case HloOpcode::kCustomCall: {
@@ -539,7 +539,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
     case HloOpcode::kAllReduce: {
       auto all_reduce = Cast<HloAllReduceInstruction>(instruction);
       attributes.push_back(
-          ConvertReplicaGroups(all_reduce->replica_groups(), *builder_));
+          ConvertReplicaGroups(all_reduce->replica_groups(), builder_));
       attributes.push_back(ConvertChannelHandle(all_reduce->channel_id()));
       auto all_reduce_op = func_builder->create<mlir::mhlo::AllReduceOp>(
           loc, result_type, operands, attributes);
@@ -921,20 +921,21 @@ mlir::NamedAttribute HloFunctionImporter::ConvertPadding(
 
 mlir::NamedAttribute HloFunctionImporter::ConvertSourceTargetPairs(
     const std::vector<std::pair<tensorflow::int64, tensorflow::int64>>&
-        source_target_pairs) {
+        source_target_pairs,
+    mlir::Builder* builder) {
   std::vector<int64_t> attr(source_target_pairs.size() * 2);
   for (auto p : llvm::enumerate(source_target_pairs)) {
     attr[2 * p.index()] = p.value().first;
     attr[2 * p.index() + 1] = p.value().second;
   }
   auto type = mlir::RankedTensorType::get(
-      {static_cast<int64_t>(attr.size() / 2), 2}, builder_->getIntegerType(64));
-  return builder_->getNamedAttr("source_target_pairs",
-                                DenseIntElementsAttr::get(type, attr));
+      {static_cast<int64_t>(attr.size() / 2), 2}, builder->getIntegerType(64));
+  return builder->getNamedAttr("source_target_pairs",
+                               DenseIntElementsAttr::get(type, attr));
 }
 
 mlir::NamedAttribute HloFunctionImporter::ConvertReplicaGroups(
-    const std::vector<ReplicaGroup>& replica_groups, mlir::Builder builder) {
+    const std::vector<ReplicaGroup>& replica_groups, mlir::Builder* builder) {
   const int64_t num_groups = replica_groups.size();
   // Replica groups in HLO can be non-uniform in size, for example:
   // replica_groups={{0},{1,2},{3}}. Since we are representing them as a 2D
@@ -950,9 +951,9 @@ mlir::NamedAttribute HloFunctionImporter::ConvertReplicaGroups(
     for (const int64& id : replica_groups[i].replica_ids()) attr[index++] = id;
   }
   auto type = mlir::RankedTensorType::get({num_groups, group_size},
-                                          builder.getIntegerType(64));
-  return builder.getNamedAttr("replica_groups",
-                              DenseIntElementsAttr::get(type, attr));
+                                          builder->getIntegerType(64));
+  return builder->getNamedAttr("replica_groups",
+                               DenseIntElementsAttr::get(type, attr));
 }
 
 mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle(
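Note: both helpers now take the builder explicitly instead of reading the importer's builder_ member, which is what lets them become static and callable from the LMHLO emitter further down. The attribute layout itself is unchanged: each (source, target) pair becomes one row of an Nx2 i64 dense tensor. A minimal standalone sketch of that flattening, using plain STL types in place of the real MLIR attribute APIs (FlattenSourceTargetPairs is a made-up name, illustration only):

    #include <cstddef>
    #include <cstdint>
    #include <utility>
    #include <vector>

    // Illustration only: mirrors the row-major storage handed to
    // DenseIntElementsAttr::get with a RankedTensorType of shape {N, 2}.
    // {{0,1},{0,2},{1,0}} -> [0, 1, 0, 2, 1, 0], printed as
    // dense<[[0, 1], [0, 2], [1, 0]]> : tensor<3x2xi64>.
    std::vector<int64_t> FlattenSourceTargetPairs(
        const std::vector<std::pair<int64_t, int64_t>>& pairs) {
      std::vector<int64_t> flat(pairs.size() * 2);
      for (std::size_t i = 0; i < pairs.size(); ++i) {
        flat[2 * i] = pairs[i].first;       // source replica
        flat[2 * i + 1] = pairs[i].second;  // target replica
      }
      return flat;
    }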
@@ -64,11 +64,17 @@ class HloFunctionImporter {
 
   static void SetLayoutForMlir(mlir::Operation* op, const Shape& shape);
 
+  // TODO(b/179166199): move this to attribute_importer.h.
+  // Converts XLA instruction source target pairs to MLIR attribute.
+  static mlir::NamedAttribute ConvertSourceTargetPairs(
+      const std::vector<std::pair<tensorflow::int64, tensorflow::int64>>&
+          source_target_pairs,
+      mlir::Builder* builder);
+
+  // TODO(b/179166199): move this to attribute_importer.h.
   // Converts replica groups to attribute
-  //
-  // TODO(timshen): move this to attribute_importer.h.
   static mlir::NamedAttribute ConvertReplicaGroups(
-      const std::vector<ReplicaGroup>& replica_groups, mlir::Builder builder);
+      const std::vector<ReplicaGroup>& replica_groups, mlir::Builder* builder);
 
  private:
   HloFunctionImporter(mlir::ModuleOp module,
@@ -149,11 +155,6 @@ class HloFunctionImporter {
   // Converts channel handle to attribute
   mlir::NamedAttribute ConvertChannelHandle(const xla::ChannelHandle& channel);
 
-  // Converts XLA instruction source target pairs to MLIR attribute.
-  mlir::NamedAttribute ConvertSourceTargetPairs(
-      const std::vector<std::pair<tensorflow::int64, tensorflow::int64>>&
-          source_target_pairs);
-
   mlir::MLIRContext* context_;
   mlir::ModuleOp module_;
   mlir::Builder* builder_;
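The private, member-based overload above is deleted in favor of the static declaration added earlier in this header. The reason the static version needs a Builder* parameter: a static member cannot reach the importer's builder_ field, and MLIR attributes and types can only be created through a Builder (a thin wrapper around the MLIRContext), so the context dependency has to arrive explicitly. A compilable toy contrast (all types here are stand-ins, not the real MLIR classes):

    struct Context {};
    struct Builder { Context* ctx = nullptr; };  // stand-in for mlir::Builder

    struct Importer {
      Builder* builder_;
      // A non-static helper can silently use the builder_ member...
      void MemberHelper() { Builder* b = builder_; (void)b; }
      // ...while a static helper must be handed every context dependency.
      static void StaticHelper(Builder* builder) { (void)builder; }
    };

    int main() {
      Builder b;
      Importer imp{&b};
      imp.MemberHelper();
      Importer::StaticHelper(&b);
    }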
@@ -627,6 +627,18 @@ ENTRY entry {
     replica_groups={{0,1}}
 }
 
+// -----
+
+// CHECK: func @main
+// CHECK: "lmhlo.collective_permute"
+// CHECK-SAME: channel_id = {handle = 2 : i64, type = 0 : i64}
+// CHECK-SAME{LITERAL}: source_target_pairs = dense<[[0, 1], [0, 2], [1, 0]]> : tensor<3x2xi64>
+HloModule TestCollectivePermute
+ENTRY main {
+  p0 = f32[128] parameter(0)
+  ROOT permute = f32[128] collective-permute(p0),
+      source_target_pairs={{0,1}, {0,2}, {1,0}}, channel_id=2
+}
 
 // -----
 
 HloModule TestReplicaId
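One detail worth flagging in the new test: the {LITERAL} suffix on CHECK-SAME makes FileCheck match the line verbatim; without it, the `[[` inside the dense attribute would be parsed as the start of a `[[VAR:...]]` substitution block. The directives are written to match printed LMHLO of roughly this shape (a sketch only; the exact SSA names and memref operands depend on buffer assignment):

    "lmhlo.collective_permute"(%arg0, %out) {channel_id = {handle = 2 : i64, type = 0 : i64}, source_target_pairs = dense<[[0, 1], [0, 2], [1, 0]]> : tensor<3x2xi64>} : (memref<128xf32>, memref<128xf32>) -> ()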
@@ -269,6 +269,8 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(
       return CreateOpWithoutAttrs<lmhlo::CbrtOp>(instr);
     case HloOpcode::kClamp:
       return CreateOpWithoutAttrs<lmhlo::ClampOp>(instr);
+    case HloOpcode::kCollectivePermute:
+      return EmitCollectivePermuteOp(instr);
     case HloOpcode::kClz:
       return CreateOpWithoutAttrs<lmhlo::ClzOp>(instr);
     case HloOpcode::kCompare:
@@ -1012,21 +1014,29 @@ StatusOr<lmhlo::ReducePrecisionOp> LhloDialectEmitter::EmitReducePrecisionOp(
   return reduce_precision_op;
 }
 
 namespace {
+template <typename OpT>
+void SetupChannelIdAttribute(OpT op, const xla::HloChannelInstruction* instr,
+                             mlir::Builder builder) {
+  if (instr->channel_id().has_value()) {
+    op.channel_idAttr(mlir::mhlo::ChannelHandle::get(
+        builder.getI64IntegerAttr(*instr->channel_id()),
+        builder.getI64IntegerAttr(0), builder.getContext()));
+  }
+}
+
 template <typename OpT>
 Status SetupCommonCollectiveOpAttributes(OpT op, const HloInstruction* instr,
                                          mlir::OpBuilder& builder) {
   auto* collective = xla::Cast<xla::HloCollectiveInstruction>(instr);
   auto replica_groups_attr = xla::HloFunctionImporter::ConvertReplicaGroups(
-      collective->replica_groups(), builder);
+      collective->replica_groups(), &builder);
   op->setAttr(replica_groups_attr.first, replica_groups_attr.second);
   op.constrain_layoutAttr(builder.getBoolAttr(collective->constrain_layout()));
-  if (collective->channel_id().has_value()) {
-    op.channel_idAttr(mlir::mhlo::ChannelHandle::get(
-        builder.getI64IntegerAttr(*collective->channel_id()),
-        builder.getI64IntegerAttr(0), builder.getContext()));
-  }
+  SetupChannelIdAttribute(op, collective, builder);
   return Status::OK();
 }
 }  // namespace
 
 StatusOr<lmhlo::AllToAllOp> LhloDialectEmitter::EmitAllToAllOp(
     const HloInstruction* instr) {
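SetupChannelIdAttribute is split out of SetupCommonCollectiveOpAttributes because collective-permute carries a channel id but is not an HloCollectiveInstruction: it has no replica groups or constrain_layout to set, so it needs the channel-id conversion alone. Typing the parameter as HloChannelInstruction, the shared base that owns channel_id(), lets both paths reuse it, and taking mlir::Builder by value is cheap since a Builder is essentially just a context pointer. The op parameter stays a template because the generated LMHLO op classes share no common base exposing channel_idAttr; they are duck-typed. A standalone toy version of that pattern (the Fake* names are stand-ins, illustration only):

    #include <iostream>
    #include <optional>

    // Any op type with a channel_idAttr(...) method works; no base class needed.
    struct FakeAllReduceOp {
      void channel_idAttr(long h) { std::cout << "all_reduce handle=" << h << "\n"; }
    };
    struct FakeCollectivePermuteOp {
      void channel_idAttr(long h) { std::cout << "permute handle=" << h << "\n"; }
    };

    template <typename OpT>
    void SetupChannelId(OpT op, std::optional<long> channel_id) {
      // Mirrors the real helper: only set the attribute when a channel id exists.
      if (channel_id.has_value()) op.channel_idAttr(*channel_id);
    }

    int main() {
      SetupChannelId(FakeAllReduceOp{}, 1);
      SetupChannelId(FakeCollectivePermuteOp{}, 2);
      SetupChannelId(FakeCollectivePermuteOp{}, std::nullopt);  // prints nothing
    }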
@@ -1071,6 +1081,20 @@ StatusOr<lmhlo::AllReduceOp> LhloDialectEmitter::EmitAllReduceOp(
   return all_reduce_op;
 }
 
+StatusOr<lmhlo::CollectivePermuteOp>
+LhloDialectEmitter::EmitCollectivePermuteOp(const HloInstruction* instr) {
+  TF_ASSIGN_OR_RETURN(auto permute_op,
+                      CreateOpWithoutAttrs<lmhlo::CollectivePermuteOp>(instr));
+  auto* permute = xla::Cast<xla::HloCollectivePermuteInstruction>(instr);
+  SetupChannelIdAttribute(permute_op, permute, builder_);
+  mlir::NamedAttribute source_target_pairs_attr =
+      xla::HloFunctionImporter::ConvertSourceTargetPairs(
+          permute->source_target_pairs(), &builder_);
+  permute_op->setAttr(source_target_pairs_attr.first,
+                      source_target_pairs_attr.second);
+  return permute_op;
+}
+
 StatusOr<lmhlo::InfeedOp> LhloDialectEmitter::EmitInfeedOp(
     const HloInstruction* instr) {
   const HloInfeedInstruction* infeed = xla::Cast<HloInfeedInstruction>(instr);
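EmitCollectivePermuteOp deliberately skips SetupCommonCollectiveOpAttributes and sets only the two attributes the op actually has. Passing &builder_ works because the emitter's builder_ is an mlir::OpBuilder, which derives from mlir::Builder. The setAttr(attr.first, attr.second) shape follows from NamedAttribute being a std::pair<Identifier, Attribute> in MLIR at this point; a simplified stand-in (FakeOp and FakeNamedAttribute are made up for illustration):

    #include <iostream>
    #include <map>
    #include <string>
    #include <utility>

    // NamedAttribute behaves like a (name, value) pair, and setAttr installs it
    // into the op's attribute dictionary keyed by name.
    using FakeNamedAttribute = std::pair<std::string, std::string>;

    struct FakeOp {
      std::map<std::string, std::string> attrs;  // the op's attribute dictionary
      void setAttr(const std::string& name, const std::string& value) {
        attrs[name] = value;
      }
    };

    int main() {
      FakeNamedAttribute attr = {
          "source_target_pairs",
          "dense<[[0, 1], [0, 2], [1, 0]]> : tensor<3x2xi64>"};
      FakeOp op;
      op.setAttr(attr.first, attr.second);  // mirrors permute_op->setAttr(...)
      std::cout << op.attrs["source_target_pairs"] << "\n";
    }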
@@ -93,6 +93,8 @@ class LhloDialectEmitter : public xla::ConstDfsHloVisitorWithDefault {
       const xla::HloInstruction* instr);
   xla::StatusOr<lmhlo::AllReduceOp> EmitAllReduceOp(
       const xla::HloInstruction* instr);
+  xla::StatusOr<lmhlo::CollectivePermuteOp> EmitCollectivePermuteOp(
+      const xla::HloInstruction* instr);
 
   xla::StatusOr<lmhlo::BroadcastInDimOp> EmitBroadcastOp(
       const xla::HloInstruction* instr);