From 110a7631b2b75ea5700c88ddc979662a5370815f Mon Sep 17 00:00:00 2001 From: Rahul Joshi Date: Tue, 2 Feb 2021 14:36:30 -0800 Subject: [PATCH] [XLA:GPU] Add HLO -> LMHLO conversion for CollectivePermuteOp PiperOrigin-RevId: 355254944 Change-Id: I652e3e9996e06fa206b96bb0ab24518d55a8334a --- .../mlir/xla/hlo_function_importer.cc | 23 ++++++------ .../compiler/mlir/xla/hlo_function_importer.h | 17 ++++----- .../hlo_text_to_lhlo_no_opt.hlotxt | 12 +++++++ .../xla/transforms/mhlo_to_lhlo_with_xla.cc | 36 +++++++++++++++---- .../xla/transforms/mhlo_to_lhlo_with_xla.h | 2 ++ 5 files changed, 65 insertions(+), 25 deletions(-) diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc index 957267c9740..5d5c6c8b4be 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -311,8 +311,8 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( return new_operation; } case HloOpcode::kCollectivePermute: { - attributes.push_back( - ConvertSourceTargetPairs(instruction->source_target_pairs())); + attributes.push_back(ConvertSourceTargetPairs( + instruction->source_target_pairs(), builder_)); MakeAndReturn(CollectivePermuteOp); } case HloOpcode::kCustomCall: { @@ -539,7 +539,7 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( case HloOpcode::kAllReduce: { auto all_reduce = Cast(instruction); attributes.push_back( - ConvertReplicaGroups(all_reduce->replica_groups(), *builder_)); + ConvertReplicaGroups(all_reduce->replica_groups(), builder_)); attributes.push_back(ConvertChannelHandle(all_reduce->channel_id())); auto all_reduce_op = func_builder->create( loc, result_type, operands, attributes); @@ -921,20 +921,21 @@ mlir::NamedAttribute HloFunctionImporter::ConvertPadding( mlir::NamedAttribute HloFunctionImporter::ConvertSourceTargetPairs( const std::vector>& - source_target_pairs) { + source_target_pairs, + mlir::Builder* builder) { std::vector attr(source_target_pairs.size() * 2); for (auto p : llvm::enumerate(source_target_pairs)) { attr[2 * p.index()] = p.value().first; attr[2 * p.index() + 1] = p.value().second; } auto type = mlir::RankedTensorType::get( - {static_cast(attr.size() / 2), 2}, builder_->getIntegerType(64)); - return builder_->getNamedAttr("source_target_pairs", - DenseIntElementsAttr::get(type, attr)); + {static_cast(attr.size() / 2), 2}, builder->getIntegerType(64)); + return builder->getNamedAttr("source_target_pairs", + DenseIntElementsAttr::get(type, attr)); } mlir::NamedAttribute HloFunctionImporter::ConvertReplicaGroups( - const std::vector& replica_groups, mlir::Builder builder) { + const std::vector& replica_groups, mlir::Builder* builder) { const int64_t num_groups = replica_groups.size(); // Replica groups in HLO can be non-uniform in size, for example: // replica_groups={{0},{1,2},{3}}. Since we are representing them as a 2D @@ -950,9 +951,9 @@ mlir::NamedAttribute HloFunctionImporter::ConvertReplicaGroups( for (const int64& id : replica_groups[i].replica_ids()) attr[index++] = id; } auto type = mlir::RankedTensorType::get({num_groups, group_size}, - builder.getIntegerType(64)); - return builder.getNamedAttr("replica_groups", - DenseIntElementsAttr::get(type, attr)); + builder->getIntegerType(64)); + return builder->getNamedAttr("replica_groups", + DenseIntElementsAttr::get(type, attr)); } mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle( diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.h b/tensorflow/compiler/mlir/xla/hlo_function_importer.h index 99fc64f40ba..d83dcdeead5 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.h +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.h @@ -64,11 +64,17 @@ class HloFunctionImporter { static void SetLayoutForMlir(mlir::Operation* op, const Shape& shape); + // TODO(b/179166199): move this to attribute_importer.h. + // Converts XLA instruction source target pairs to MLIR attribute. + static mlir::NamedAttribute ConvertSourceTargetPairs( + const std::vector>& + source_target_pairs, + mlir::Builder* builder); + + // TODO(b/179166199): move this to attribute_importer.h. // Converts replica groups to attribute - // - // TODO(timshen): move this to attribute_importer.h. static mlir::NamedAttribute ConvertReplicaGroups( - const std::vector& replica_groups, mlir::Builder builder); + const std::vector& replica_groups, mlir::Builder* builder); private: HloFunctionImporter(mlir::ModuleOp module, @@ -149,11 +155,6 @@ class HloFunctionImporter { // Converts channel handle to attribute mlir::NamedAttribute ConvertChannelHandle(const xla::ChannelHandle& channel); - // Converts XLA instruction source target pairs to MLIR attribute. - mlir::NamedAttribute ConvertSourceTargetPairs( - const std::vector>& - source_target_pairs); - mlir::MLIRContext* context_; mlir::ModuleOp module_; mlir::Builder* builder_; diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt index 51f43ecd7eb..4c29b5f36c9 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt @@ -627,6 +627,18 @@ ENTRY entry { replica_groups={{0,1}} } +// ----- + +// CHECK: func @main +// CHECK: "lmhlo.collective_permute" +// CHECK-SAME: channel_id = {handle = 2 : i64, type = 0 : i64} +// CHECK-SAME{LITERAL}: source_target_pairs = dense<[[0, 1], [0, 2], [1, 0]]> : tensor<3x2xi64> +HloModule TestCollectivePermute +ENTRY main { + p0 = f32[128] parameter(0) + ROOT permute = f32[128] collective-permute(p0), + source_target_pairs={{0,1}, {0,2}, {1,0}}, channel_id=2 +} // ----- HloModule TestReplicaId diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc index 3eb2fa14b70..0018f0ae36b 100644 --- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc @@ -269,6 +269,8 @@ StatusOr LhloDialectEmitter::EmitOp( return CreateOpWithoutAttrs(instr); case HloOpcode::kClamp: return CreateOpWithoutAttrs(instr); + case HloOpcode::kCollectivePermute: + return EmitCollectivePermuteOp(instr); case HloOpcode::kClz: return CreateOpWithoutAttrs(instr); case HloOpcode::kCompare: @@ -1012,21 +1014,29 @@ StatusOr LhloDialectEmitter::EmitReducePrecisionOp( return reduce_precision_op; } +namespace { +template +void SetupChannelIdAttribute(OpT op, const xla::HloChannelInstruction* instr, + mlir::Builder builder) { + if (instr->channel_id().has_value()) { + op.channel_idAttr(mlir::mhlo::ChannelHandle::get( + builder.getI64IntegerAttr(*instr->channel_id()), + builder.getI64IntegerAttr(0), builder.getContext())); + } +} + template Status SetupCommonCollectiveOpAttributes(OpT op, const HloInstruction* instr, mlir::OpBuilder& builder) { auto* collective = xla::Cast(instr); auto replica_groups_attr = xla::HloFunctionImporter::ConvertReplicaGroups( - collective->replica_groups(), builder); + collective->replica_groups(), &builder); op->setAttr(replica_groups_attr.first, replica_groups_attr.second); op.constrain_layoutAttr(builder.getBoolAttr(collective->constrain_layout())); - if (collective->channel_id().has_value()) { - op.channel_idAttr(mlir::mhlo::ChannelHandle::get( - builder.getI64IntegerAttr(*collective->channel_id()), - builder.getI64IntegerAttr(0), builder.getContext())); - } + SetupChannelIdAttribute(op, collective, builder); return Status::OK(); } +} // namespace StatusOr LhloDialectEmitter::EmitAllToAllOp( const HloInstruction* instr) { @@ -1071,6 +1081,20 @@ StatusOr LhloDialectEmitter::EmitAllReduceOp( return all_reduce_op; } +StatusOr +LhloDialectEmitter::EmitCollectivePermuteOp(const HloInstruction* instr) { + TF_ASSIGN_OR_RETURN(auto permute_op, + CreateOpWithoutAttrs(instr)); + auto* permute = xla::Cast(instr); + SetupChannelIdAttribute(permute_op, permute, builder_); + mlir::NamedAttribute source_target_pairs_attr = + xla::HloFunctionImporter::ConvertSourceTargetPairs( + permute->source_target_pairs(), &builder_); + permute_op->setAttr(source_target_pairs_attr.first, + source_target_pairs_attr.second); + return permute_op; +} + StatusOr LhloDialectEmitter::EmitInfeedOp( const HloInstruction* instr) { const HloInfeedInstruction* infeed = xla::Cast(instr); diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h index d5cc0fb4985..7a4e46e59dd 100644 --- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h +++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h @@ -93,6 +93,8 @@ class LhloDialectEmitter : public xla::ConstDfsHloVisitorWithDefault { const xla::HloInstruction* instr); xla::StatusOr EmitAllReduceOp( const xla::HloInstruction* instr); + xla::StatusOr EmitCollectivePermuteOp( + const xla::HloInstruction* instr); xla::StatusOr EmitBroadcastOp( const xla::HloInstruction* instr);