[XLA:GPU] Add HLO -> LMHLO conversion for CollectivePermuteOp

PiperOrigin-RevId: 355254944
Change-Id: I652e3e9996e06fa206b96bb0ab24518d55a8334a
This commit is contained in:
Rahul Joshi 2021-02-02 14:36:30 -08:00 committed by TensorFlower Gardener
parent b4bf78ffec
commit 110a7631b2
5 changed files with 65 additions and 25 deletions

View File

@@ -311,8 +311,8 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
return new_operation;
}
case HloOpcode::kCollectivePermute: {
attributes.push_back(
ConvertSourceTargetPairs(instruction->source_target_pairs()));
attributes.push_back(ConvertSourceTargetPairs(
instruction->source_target_pairs(), builder_));
MakeAndReturn(CollectivePermuteOp);
}
case HloOpcode::kCustomCall: {
@@ -539,7 +539,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
case HloOpcode::kAllReduce: {
auto all_reduce = Cast<HloAllReduceInstruction>(instruction);
attributes.push_back(
ConvertReplicaGroups(all_reduce->replica_groups(), *builder_));
ConvertReplicaGroups(all_reduce->replica_groups(), builder_));
attributes.push_back(ConvertChannelHandle(all_reduce->channel_id()));
auto all_reduce_op = func_builder->create<mlir::mhlo::AllReduceOp>(
loc, result_type, operands, attributes);
@@ -921,20 +921,21 @@ mlir::NamedAttribute HloFunctionImporter::ConvertPadding(
// Converts an XLA collective-permute source/target pair list into an MLIR
// named attribute "source_target_pairs" holding a dense Nx2 i64 tensor.
// NOTE(review): this is a diff rendering — the lines below interleave the
// pre-change version (uses the member builder_) with the post-change version
// (takes an explicit builder parameter); it is not compilable as-is.
mlir::NamedAttribute HloFunctionImporter::ConvertSourceTargetPairs(
const std::vector<std::pair<tensorflow::int64, tensorflow::int64>>&
// Old signature tail (removed by this commit):
source_target_pairs) {
// New signature tail (added by this commit): explicit builder parameter.
source_target_pairs,
mlir::Builder* builder) {
// Flatten the pairs into [src0, tgt0, src1, tgt1, ...].
std::vector<int64_t> attr(source_target_pairs.size() * 2);
for (auto p : llvm::enumerate(source_target_pairs)) {
attr[2 * p.index()] = p.value().first;
attr[2 * p.index() + 1] = p.value().second;
}
// Shape the flat vector as Nx2 and wrap it in a DenseIntElementsAttr.
auto type = mlir::RankedTensorType::get(
// Old lines (removed): went through the member builder_.
{static_cast<int64_t>(attr.size() / 2), 2}, builder_->getIntegerType(64));
return builder_->getNamedAttr("source_target_pairs",
DenseIntElementsAttr::get(type, attr));
// New lines (added): use the builder parameter instead.
{static_cast<int64_t>(attr.size() / 2), 2}, builder->getIntegerType(64));
return builder->getNamedAttr("source_target_pairs",
DenseIntElementsAttr::get(type, attr));
}
// Converts HLO replica groups into an MLIR named attribute "replica_groups"
// holding a dense num_groups x group_size i64 tensor.
// NOTE(review): diff rendering — old (builder by value) and new (builder by
// pointer) lines are interleaved below, and a hunk header splits the body;
// the middle of the function (padding logic) is elided by the diff.
mlir::NamedAttribute HloFunctionImporter::ConvertReplicaGroups(
// Old signature (removed by this commit): builder taken by value.
const std::vector<ReplicaGroup>& replica_groups, mlir::Builder builder) {
// New signature (added by this commit): builder taken by pointer.
const std::vector<ReplicaGroup>& replica_groups, mlir::Builder* builder) {
const int64_t num_groups = replica_groups.size();
// Replica groups in HLO can be non-uniform in size, for example:
// replica_groups={{0},{1,2},{3}}. Since we are representing them as a 2D
@ -950,9 +951,9 @@ mlir::NamedAttribute HloFunctionImporter::ConvertReplicaGroups(
for (const int64& id : replica_groups[i].replica_ids()) attr[index++] = id;
}
auto type = mlir::RankedTensorType::get({num_groups, group_size},
// Old lines (removed): member access on the by-value builder.
builder.getIntegerType(64));
return builder.getNamedAttr("replica_groups",
DenseIntElementsAttr::get(type, attr));
// New lines (added): dereference the builder pointer.
builder->getIntegerType(64));
return builder->getNamedAttr("replica_groups",
DenseIntElementsAttr::get(type, attr));
}
mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle(

View File

@@ -64,11 +64,17 @@ class HloFunctionImporter {
static void SetLayoutForMlir(mlir::Operation* op, const Shape& shape);
// TODO(b/179166199): move this to attribute_importer.h.
// Converts XLA instruction source target pairs to MLIR attribute.
static mlir::NamedAttribute ConvertSourceTargetPairs(
const std::vector<std::pair<tensorflow::int64, tensorflow::int64>>&
source_target_pairs,
mlir::Builder* builder);
// TODO(b/179166199): move this to attribute_importer.h.
// Converts replica groups to attribute
//
// TODO(timshen): move this to attribute_importer.h.
static mlir::NamedAttribute ConvertReplicaGroups(
const std::vector<ReplicaGroup>& replica_groups, mlir::Builder builder);
const std::vector<ReplicaGroup>& replica_groups, mlir::Builder* builder);
private:
HloFunctionImporter(mlir::ModuleOp module,
@@ -149,11 +155,6 @@ class HloFunctionImporter {
// Converts channel handle to attribute
mlir::NamedAttribute ConvertChannelHandle(const xla::ChannelHandle& channel);
// Converts XLA instruction source target pairs to MLIR attribute.
mlir::NamedAttribute ConvertSourceTargetPairs(
const std::vector<std::pair<tensorflow::int64, tensorflow::int64>>&
source_target_pairs);
mlir::MLIRContext* context_;
mlir::ModuleOp module_;
mlir::Builder* builder_;

View File

@@ -627,6 +627,18 @@ ENTRY entry {
replica_groups={{0,1}}
}
// -----
// CHECK: func @main
// CHECK: "lmhlo.collective_permute"
// CHECK-SAME: channel_id = {handle = 2 : i64, type = 0 : i64}
// CHECK-SAME{LITERAL}: source_target_pairs = dense<[[0, 1], [0, 2], [1, 0]]> : tensor<3x2xi64>
HloModule TestCollectivePermute
ENTRY main {
p0 = f32[128] parameter(0)
ROOT permute = f32[128] collective-permute(p0),
source_target_pairs={{0,1}, {0,2}, {1,0}}, channel_id=2
}
// -----
HloModule TestReplicaId

View File

@@ -269,6 +269,8 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(
return CreateOpWithoutAttrs<lmhlo::CbrtOp>(instr);
case HloOpcode::kClamp:
return CreateOpWithoutAttrs<lmhlo::ClampOp>(instr);
case HloOpcode::kCollectivePermute:
return EmitCollectivePermuteOp(instr);
case HloOpcode::kClz:
return CreateOpWithoutAttrs<lmhlo::ClzOp>(instr);
case HloOpcode::kCompare:
@@ -1012,21 +1014,29 @@ StatusOr<lmhlo::ReducePrecisionOp> LhloDialectEmitter::EmitReducePrecisionOp(
return reduce_precision_op;
}
namespace {
// Attaches the HLO channel id (when present) to `op` as an
// mhlo::ChannelHandle attribute. Ops without a channel id are left
// untouched; the handle's second field is always set to 0 here.
template <typename OpT>
void SetupChannelIdAttribute(OpT op, const xla::HloChannelInstruction* instr,
                             mlir::Builder builder) {
  if (!instr->channel_id().has_value()) return;
  auto handle = mlir::mhlo::ChannelHandle::get(
      builder.getI64IntegerAttr(*instr->channel_id()),
      builder.getI64IntegerAttr(0), builder.getContext());
  op.channel_idAttr(handle);
}
// Copies the attributes shared by all collective ops (replica_groups,
// constrain_layout, channel_id) from an HLO collective instruction onto the
// given op.
// NOTE(review): diff rendering — pre- and post-change lines interleave below;
// this commit replaces the inline channel-id handling with the shared
// SetupChannelIdAttribute helper.
template <typename OpT>
Status SetupCommonCollectiveOpAttributes(OpT op, const HloInstruction* instr,
mlir::OpBuilder& builder) {
auto* collective = xla::Cast<xla::HloCollectiveInstruction>(instr);
auto replica_groups_attr = xla::HloFunctionImporter::ConvertReplicaGroups(
// Old line (removed): ConvertReplicaGroups took the builder by value.
collective->replica_groups(), builder);
// New line (added): ConvertReplicaGroups now takes a builder pointer.
collective->replica_groups(), &builder);
op->setAttr(replica_groups_attr.first, replica_groups_attr.second);
op.constrain_layoutAttr(builder.getBoolAttr(collective->constrain_layout()));
// Old lines (removed): channel id handled inline ...
if (collective->channel_id().has_value()) {
op.channel_idAttr(mlir::mhlo::ChannelHandle::get(
builder.getI64IntegerAttr(*collective->channel_id()),
builder.getI64IntegerAttr(0), builder.getContext()));
}
// New line (added): ... now delegated to the shared helper.
SetupChannelIdAttribute(op, collective, builder);
return Status::OK();
}
} // namespace
StatusOr<lmhlo::AllToAllOp> LhloDialectEmitter::EmitAllToAllOp(
const HloInstruction* instr) {
@@ -1071,6 +1081,20 @@ StatusOr<lmhlo::AllReduceOp> LhloDialectEmitter::EmitAllReduceOp(
return all_reduce_op;
}
// Lowers an HLO collective-permute instruction to lmhlo::CollectivePermuteOp,
// attaching the optional channel id and the "source_target_pairs" attribute
// (a dense Nx2 i64 tensor).
StatusOr<lmhlo::CollectivePermuteOp>
LhloDialectEmitter::EmitCollectivePermuteOp(const HloInstruction* instr) {
  const auto* permute_instr =
      xla::Cast<xla::HloCollectivePermuteInstruction>(instr);
  TF_ASSIGN_OR_RETURN(auto op,
                      CreateOpWithoutAttrs<lmhlo::CollectivePermuteOp>(instr));
  // Channel id handling is shared with the other collective emitters.
  SetupChannelIdAttribute(op, permute_instr, builder_);
  auto pairs_attr = xla::HloFunctionImporter::ConvertSourceTargetPairs(
      permute_instr->source_target_pairs(), &builder_);
  op->setAttr(pairs_attr.first, pairs_attr.second);
  return op;
}
StatusOr<lmhlo::InfeedOp> LhloDialectEmitter::EmitInfeedOp(
const HloInstruction* instr) {
const HloInfeedInstruction* infeed = xla::Cast<HloInfeedInstruction>(instr);

View File

@@ -93,6 +93,8 @@ class LhloDialectEmitter : public xla::ConstDfsHloVisitorWithDefault {
const xla::HloInstruction* instr);
xla::StatusOr<lmhlo::AllReduceOp> EmitAllReduceOp(
const xla::HloInstruction* instr);
xla::StatusOr<lmhlo::CollectivePermuteOp> EmitCollectivePermuteOp(
const xla::HloInstruction* instr);
xla::StatusOr<lmhlo::BroadcastInDimOp> EmitBroadcastOp(
const xla::HloInstruction* instr);