Add optional layout constraint to AllToAll
PiperOrigin-RevId: 308740413
Change-Id: I48773db125a4c542d1c8c46c71100bda4a2b1108
parent c1d09d25d2
commit e2ccb9a5a5
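The new optional `layout` argument lets a client pin the result layout of an AllToAll when the computation is built. A minimal sketch of the builder-side usage (the shape, replica setup, and layout value here are illustrative, not taken from this change); note the new validation below rejects a layout whose rank does not match the operand rank:

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Build a computation whose AllToAll result layout is pinned to {1, 0}.
xla::XlaComputation BuildPinnedAllToAll() {
  xla::XlaBuilder builder("a2a_with_layout");
  xla::XlaOp p0 = xla::Parameter(
      &builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {16, 4}), "p0");
  // The layout {1, 0} must have the same rank as the operand (rank 2 here),
  // or the builder returns InvalidArgument.
  xla::AllToAll(p0, /*split_dimension=*/0, /*concat_dimension=*/0,
                /*split_count=*/2, /*replica_groups=*/{},
                /*layout=*/xla::LayoutUtil::MakeLayout({1, 0}));
  return builder.Build().ValueOrDie();
}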
					
@@ -2307,7 +2307,8 @@ XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation,
 
 XlaOp XlaBuilder::AllToAll(XlaOp operand, int64 split_dimension,
                            int64 concat_dimension, int64 split_count,
-                           const std::vector<ReplicaGroup>& replica_groups) {
+                           const std::vector<ReplicaGroup>& replica_groups,
+                           const absl::optional<Layout>& layout) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
 
@@ -2342,7 +2343,21 @@ XlaOp XlaBuilder::AllToAll(XlaOp operand, int64 split_dimension,
                       [](const Shape& shape) { return &shape; });
     TF_ASSIGN_OR_RETURN(
         Shape shape, ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs));
+
+    if (layout) {
+      TF_RET_CHECK(shape.IsTuple() && !ShapeUtil::IsNestedTuple(shape));
+      for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
+        if (layout->minor_to_major().size() != shape.tuple_shapes(i).rank()) {
+          return InvalidArgument(
+              "Provided layout must be compatible with the operand shape: %s "
+              "vs %s",
+              layout->ToString(), operand_shape->ToString());
+        }
+        *(shape.mutable_tuple_shapes(i)->mutable_layout()) = *layout;
+      }
+    }
     *instr.mutable_shape() = shape.ToProto();
 
     for (const ReplicaGroup& group : replica_groups) {
       *instr.add_replica_groups() = group;
     }
@@ -3466,9 +3481,10 @@ XlaOp AllReduce(const XlaOp operand, const XlaComputation& computation,
 
 XlaOp AllToAll(const XlaOp operand, int64 split_dimension,
                int64 concat_dimension, int64 split_count,
-               const std::vector<ReplicaGroup>& replica_groups) {
+               const std::vector<ReplicaGroup>& replica_groups,
+               const absl::optional<Layout>& layout) {
   return operand.builder()->AllToAll(operand, split_dimension, concat_dimension,
-                                     split_count, replica_groups);
+                                     split_count, replica_groups, layout);
 }
 
 XlaOp CollectivePermute(
@@ -557,7 +557,8 @@ class XlaBuilder {
 
   XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension,
                  int64 split_count,
-                 const std::vector<ReplicaGroup>& replica_groups);
+                 const std::vector<ReplicaGroup>& replica_groups,
+                 const absl::optional<Layout>& layout = absl::nullopt);
 
   XlaOp CollectivePermute(
       XlaOp operand,
@@ -993,7 +994,8 @@ class XlaBuilder {
                          const absl::optional<Shape>& shape_with_layout);
   friend XlaOp AllToAll(XlaOp operand, int64 split_dimension,
                         int64 concat_dimension, int64 split_count,
-                        const std::vector<ReplicaGroup>& replica_groups);
+                        const std::vector<ReplicaGroup>& replica_groups,
+                        const absl::optional<Layout>& layout);
   friend XlaOp CollectivePermute(
       XlaOp operand,
       const std::vector<std::pair<int64, int64>>& source_target_pairs);
@@ -1789,9 +1791,13 @@ XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
                 const absl::optional<Shape>& shape_with_layout = absl::nullopt);
 
 // Enqueues an operation that does an AllToAll of the operand across cores.
+// An optional `layout` can be specified to force the layout of the instruction.
+// This is used to guarantee the same layout for a group of AllToAll ops
+// compiled separately.
 XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension,
                int64 split_count,
-               const std::vector<ReplicaGroup>& replica_groups = {});
+               const std::vector<ReplicaGroup>& replica_groups = {},
+               const absl::optional<Layout>& layout = absl::nullopt);
 
 // Enqueues a collective operation that sends and receives data across replicas.
 //
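The comment above states the motivation: AllToAll ops that live in separately compiled computations must agree on the layout of the data they exchange. A sketch of that pattern, reusing one pinned layout across independently built computations (the helper, names, and shape are hypothetical):

#include <string>

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Builds one stage of a pipeline. Each stage may be compiled in a separate
// session, so the caller passes the same `pinned` layout to every stage to
// guarantee the AllToAll instructions agree on the exchanged data layout.
xla::XlaComputation BuildStage(const std::string& name,
                               const xla::Layout& pinned) {
  xla::XlaBuilder builder(name);
  xla::XlaOp p = xla::Parameter(
      &builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {16, 4}), "p");
  xla::AllToAll(p, /*split_dimension=*/0, /*concat_dimension=*/0,
                /*split_count=*/2, /*replica_groups=*/{}, pinned);
  return builder.Build().ValueOrDie();
}

// Usage: both stages now carry AllToAll results with layout {0, 1}.
//   const xla::Layout pinned = xla::LayoutUtil::MakeLayout({0, 1});
//   xla::XlaComputation stage0 = BuildStage("stage0", pinned);
//   xla::XlaComputation stage1 = BuildStage("stage1", pinned);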
@@ -330,7 +330,8 @@ void BuildOpsSubmodule(py::module* m) {
       py::arg("shape_with_layout") = absl::nullopt);
   ops.def("AllToAll", &AllToAll, py::arg("operand"), py::arg("split_dimension"),
           py::arg("concat_dimension"), py::arg("split_count"),
-          py::arg("replica_groups") = py::list());
+          py::arg("replica_groups") = py::list(),
+          py::arg("layout") = absl::nullopt);
   ops.def("CollectivePermute", &CollectivePermute, py::arg("operand"),
           py::arg("source_target_pairs"));
   ops.def("CreateToken", &CreateToken, py::arg("builder"));
@@ -289,7 +289,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllToAllToBF16) {
   replica_groups[0].add_replica_ids(1);
   HloInstruction* a2a = builder.AddInstruction(HloInstruction::CreateAllToAll(
       ShapeUtil::MakeTupleShape({bf16_shape, bf16_shape}), {a, a},
-      replica_groups, absl::nullopt));
+      replica_groups, /*constrain_layout=*/false, absl::nullopt));
   auto computation = module->AddEntryComputation(builder.Build());
 
   EXPECT_TRUE(Normalize(module.get()));
@@ -318,7 +318,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllToAllToF32) {
   replica_groups[0].add_replica_ids(1);
   HloInstruction* a2a = builder.AddInstruction(HloInstruction::CreateAllToAll(
       ShapeUtil::MakeTupleShape({bf16_shape, f32_shape}), {a, a},
-      replica_groups, absl::nullopt));
+      replica_groups, /*constrain_layout=*/false, absl::nullopt));
   auto computation = module->AddEntryComputation(builder.Build());
 
   EXPECT_TRUE(Normalize(module.get()));
@@ -54,8 +54,9 @@ StatusOr<bool> HloDCE::RunOnComputation(
          (remove_cross_partition_collective_ops &&
           ((instruction->opcode() == HloOpcode::kAllReduce &&
             !Cast<HloAllReduceInstruction>(instruction)->constrain_layout()) ||
-           instruction->opcode() == HloOpcode::kCollectivePermute ||
-           instruction->opcode() == HloOpcode::kAllToAll)))) {
+           (instruction->opcode() == HloOpcode::kAllToAll &&
+            !Cast<HloAllToAllInstruction>(instruction)->constrain_layout()) ||
+           instruction->opcode() == HloOpcode::kCollectivePermute)))) {
       dead_roots.push_back(instruction);
     }
   }
@@ -430,6 +430,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
           /*replica_groups=*/
           std::vector<ReplicaGroup>(proto.replica_groups().begin(),
                                     proto.replica_groups().end()),
+          /*constrain_layout=*/proto.constrain_layout(),
           /*channel_id=*/channel_id, split_dimension);
       break;
     }
@@ -939,11 +940,12 @@ HloInstruction::CreateReducePrecision(const Shape& shape,
 
 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllToAll(
     const Shape& shape, absl::Span<HloInstruction* const> operands,
-    const std::vector<ReplicaGroup>& replica_groups,
+    const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
    const absl::optional<int64>& channel_id,
     const absl::optional<int64>& split_dimension) {
   return absl::make_unique<HloAllToAllInstruction>(
-      shape, operands, replica_groups, channel_id, split_dimension);
+      shape, operands, replica_groups, constrain_layout, channel_id,
+      split_dimension);
 }
 
 /* static */ std::unique_ptr<HloInstruction>
@@ -1375,6 +1377,8 @@ bool HloInstruction::HasSideEffectNoRecurse() const {
     case HloOpcode::kAllReduce:
       return channel_id().has_value() ||
              Cast<HloAllReduceInstruction>(this)->constrain_layout();
+    case HloOpcode::kAllToAll:
+      return Cast<HloAllToAllInstruction>(this)->constrain_layout();
     case HloOpcode::kCustomCall:
       return Cast<HloCustomCallInstruction>(this)
           ->custom_call_has_side_effect();
@@ -667,7 +667,7 @@ class HloInstruction {
   // It is used to implement the higher-level instruction in XlaBuilder.
   static std::unique_ptr<HloInstruction> CreateAllToAll(
       const Shape& shape, absl::Span<HloInstruction* const> operands,
-      const std::vector<ReplicaGroup>& replica_groups,
+      const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
       const absl::optional<int64>& channel_id,
       const absl::optional<int64>& split_dimension = absl::nullopt);
 
@@ -513,10 +513,11 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl(
 HloCollectiveInstruction::HloCollectiveInstruction(
     HloOpcode opcode, const Shape& shape,
     absl::Span<HloInstruction* const> operands,
-    const std::vector<ReplicaGroup>& replica_groups,
+    const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
     const absl::optional<int64>& channel_id)
     : HloChannelInstruction(opcode, shape, channel_id),
-      replica_groups_(replica_groups) {
+      replica_groups_(replica_groups),
+      constrain_layout_(constrain_layout) {
   for (auto operand : operands) {
     AppendOperand(operand);
   }
@@ -526,6 +527,7 @@ HloInstructionProto HloCollectiveInstruction::ToProto() const {
   HloInstructionProto proto = HloChannelInstruction::ToProto();
   *proto.mutable_replica_groups() = {replica_groups_.begin(),
                                      replica_groups_.end()};
+  proto.set_constrain_layout(constrain_layout_);
   return proto;
 }
 
@@ -535,6 +537,9 @@ std::vector<string> HloCollectiveInstruction::ExtraAttributesToStringImpl(
       HloChannelInstruction::ExtraAttributesToStringImpl(options);
   result.push_back(
       StrCat("replica_groups=", ReplicaGroupsToString(replica_groups())));
+  if (constrain_layout_) {
+    result.push_back("constrain_layout=true");
+  }
   return result;
 }
 
@@ -557,8 +562,7 @@ HloAllReduceInstruction::HloAllReduceInstruction(
     const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
     const absl::optional<int64>& channel_id, bool use_global_device_ids)
     : HloCollectiveInstruction(HloOpcode::kAllReduce, shape, operands,
-                               replica_groups, channel_id),
-      constrain_layout_(constrain_layout),
+                               replica_groups, constrain_layout, channel_id),
       use_global_device_ids_(use_global_device_ids) {
   AppendComputation(reduce_computation);
 }
@@ -574,7 +578,6 @@ bool HloAllReduceInstruction::IsNoop() const {
 
 HloInstructionProto HloAllReduceInstruction::ToProto() const {
   HloInstructionProto proto = HloCollectiveInstruction::ToProto();
-  proto.set_constrain_layout(constrain_layout_);
   proto.set_use_global_device_ids(use_global_device_ids_);
   return proto;
 }
@@ -583,9 +586,6 @@ std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl(
     const HloPrintOptions& options) const {
   std::vector<string> result =
       HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
-  if (constrain_layout_) {
-    result.push_back("constrain_layout=true");
-  }
   if (use_global_device_ids_) {
     result.push_back("use_global_device_ids=true");
   }
@@ -614,11 +614,11 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl(
 
 HloAllToAllInstruction::HloAllToAllInstruction(
     const Shape& shape, absl::Span<HloInstruction* const> operands,
-    const std::vector<ReplicaGroup>& replica_groups,
+    const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
     const absl::optional<int64>& channel_id,
     const absl::optional<int64>& split_dimension)
     : HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands,
-                               replica_groups, channel_id),
+                               replica_groups, constrain_layout, channel_id),
       split_dimension_(split_dimension) {}
 
 std::unique_ptr<HloInstruction>
@@ -626,7 +626,8 @@ HloAllToAllInstruction::CloneWithNewOperandsImpl(
     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
     HloCloneContext* /*context*/) const {
   return absl::make_unique<HloAllToAllInstruction>(
-      shape, new_operands, replica_groups(), channel_id(), split_dimension());
+      shape, new_operands, replica_groups(), constrain_layout(), channel_id(),
+      split_dimension());
 }
 
 HloInstructionProto HloAllToAllInstruction::ToProto() const {
@@ -313,37 +313,6 @@ class HloCollectiveInstruction : public HloChannelInstruction {
     return replica_groups_;
   }
 
- protected:
-  explicit HloCollectiveInstruction(
-      HloOpcode opcode, const Shape& shape,
-      absl::Span<HloInstruction* const> operands,
-      const std::vector<ReplicaGroup>& replica_groups,
-      const absl::optional<int64>& channel_id);
-
-  HloInstructionProto ToProto() const override;
-
-  std::vector<string> ExtraAttributesToStringImpl(
-      const HloPrintOptions& options) const override;
-  bool IdenticalSlowPath(
-      const HloInstruction& other,
-      const std::function<bool(const HloComputation*, const HloComputation*)>&
-          eq_computations) const override;
-
-  std::vector<ReplicaGroup> replica_groups_;
-};
-
-class HloAllReduceInstruction : public HloCollectiveInstruction {
- public:
-  explicit HloAllReduceInstruction(
-      const Shape& shape, absl::Span<HloInstruction* const> operands,
-      HloComputation* reduce_computation,
-      const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
-      const absl::optional<int64>& channel_id, bool use_global_device_ids);
-
-  // Returns true if the AllReduce does no communication, so it's equivalent
-  // to a mem copy.
-  bool IsNoop() const;
-
   // Returns true if the layout of the AllReduce is enforced by XLA client (as
   // the layout set in the shape). The only reason for the client to set the
   // layout is to separately compile computations that communicate with
@@ -359,6 +328,38 @@ class HloAllReduceInstruction : public HloCollectiveInstruction {
   // unconstrained AllReduce instructions (checked by HloVerifier).
   bool constrain_layout() const { return constrain_layout_; }
 
+ protected:
+  explicit HloCollectiveInstruction(
+      HloOpcode opcode, const Shape& shape,
+      absl::Span<HloInstruction* const> operands,
+      const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
+      const absl::optional<int64>& channel_id);
+
+  HloInstructionProto ToProto() const override;
+
+  std::vector<string> ExtraAttributesToStringImpl(
+      const HloPrintOptions& options) const override;
+  bool IdenticalSlowPath(
+      const HloInstruction& other,
+      const std::function<bool(const HloComputation*, const HloComputation*)>&
+          eq_computations) const override;
+
+  std::vector<ReplicaGroup> replica_groups_;
+  bool constrain_layout_;
+};
+
+class HloAllReduceInstruction : public HloCollectiveInstruction {
+ public:
+  explicit HloAllReduceInstruction(
+      const Shape& shape, absl::Span<HloInstruction* const> operands,
+      HloComputation* reduce_computation,
+      const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
+      const absl::optional<int64>& channel_id, bool use_global_device_ids);
+
+  // Returns true if the AllReduce does no communication, so it's equivalent
+  // to a mem copy.
+  bool IsNoop() const;
+
   // Returns true if the ids in the ReplicaGroup config represent a global id of
   // (replica_id * partition_count + partition_id) instead of a replica id.
   // This enables more flexible grouping of devices if this all-reduce is both
@@ -387,7 +388,6 @@ class HloAllReduceInstruction : public HloCollectiveInstruction {
       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
       HloCloneContext* context) const override;
 
-  bool constrain_layout_;
   bool use_global_device_ids_;
 };
 
@@ -395,7 +395,7 @@ class HloAllToAllInstruction : public HloCollectiveInstruction {
  public:
   explicit HloAllToAllInstruction(
       const Shape& shape, absl::Span<HloInstruction* const> operands,
-      const std::vector<ReplicaGroup>& replica_groups,
+      const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
       const absl::optional<int64>& channel_id,
       const absl::optional<int64>& split_dimension);
 
@@ -887,6 +887,9 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
       optional<std::vector<int64>> dimensions;
       attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List,
                              &dimensions};
+      optional<bool> constrain_layout;
+      attrs["constrain_layout"] = {/*required=*/false, AttrTy::kBool,
+                                   &constrain_layout};
       if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
           (dimensions && dimensions->size() != 1)) {
         return false;
@@ -900,7 +903,9 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
         split_dimension = dimensions->at(0);
       }
       instruction = builder->AddInstruction(HloInstruction::CreateAllToAll(
-          shape, operands, replica_groups, channel_id, split_dimension));
+          shape, operands, replica_groups,
+          constrain_layout ? *constrain_layout : false, channel_id,
+          split_dimension));
       break;
     }
     case HloOpcode::kCollectivePermute: {
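Together with the printing change in HloCollectiveInstruction::ExtraAttributesToStringImpl above, this lets `constrain_layout=true` round-trip through HLO text. A test-style sketch of the round trip (the fixture name and module text are illustrative; it assumes the standard parser, casting, and gmock helpers used by these tests):

// Hypothetical test fixture; mirrors the style of the verifier tests below.
TEST_F(HloParserTest, AllToAllConstrainLayoutRoundTrips) {
  const char* const kConstrainedHlo = R"(
HloModule m
ENTRY e {
  p0 = f32[128,4]{0,1} parameter(0)
  p1 = f32[128,4]{0,1} parameter(1)
  ROOT a2a = (f32[128,4]{0,1}, f32[128,4]{0,1}) all-to-all(p0, p1),
    replica_groups={{0,1}}, constrain_layout=true
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kConstrainedHlo));
  auto* a2a = Cast<HloAllToAllInstruction>(
      module->entry_computation()->root_instruction());
  EXPECT_TRUE(a2a->constrain_layout());
  // The printer emits the attribute again, so the text form round-trips.
  EXPECT_THAT(module->ToString(),
              ::testing::HasSubstr("constrain_layout=true"));
}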
@@ -944,11 +944,6 @@ StatusOr<std::unique_ptr<HloModule>> MakeAllToAllComputation(
     std::vector<std::vector<int64>> replica_groups) {
   const char* kTemplate = R"(
   HloModule test
-  add {
-    x = f32[] parameter(0)
-    y = f32[] parameter(1)
-    ROOT add = f32[] add(x, y)
-  }
   ENTRY entry {
     p0 = f32[128]{0} parameter(0)
     p1 = f32[128]{0} parameter(1)
@@ -994,6 +989,24 @@ TEST_F(HloVerifierTest, AllToAll_WrongNumberOfReplicasInGroup) {
               HasSubstr("Replica group has size 1"));
 }
 
+TEST_F(HloVerifierTest, AllToAll_LayoutConstrained) {
+  const char* const kModuleStr = R"(
+  HloModule test
+  ENTRY entry {
+    p0 = f32[128,4]{0,1} parameter(0)
+    p1 = f32[128,4]{1,0} parameter(1)
+    ROOT a2a = (f32[128,4]{0,1}, f32[128,4]{1,0}) all-to-all(p0, p1),
+      replica_groups={{0,1}}
+  }
+  )";
+  HloModuleConfig config;
+  config.set_replica_count(2);
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnUnverifiedModule(kModuleStr, config));
+  EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
+              HasSubstr("HLO all-to-all has operands with different shapes"));
+}
+
 TEST_F(HloVerifierTest, CollectivePermuteSameSourceTwice) {
   const char* const kModuleStr = R"(
   HloModule test
@@ -432,10 +432,10 @@ bool IsLayoutConstrainedCustomCall(HloInstruction* instruction) {
   return custom_call != nullptr && custom_call->layout_constrained();
 }
 
-bool IsLayoutConstrainedAllReduce(const HloInstruction* instruction) {
-  const HloAllReduceInstruction* all_reduce =
-      DynCast<HloAllReduceInstruction>(instruction);
-  return all_reduce != nullptr && all_reduce->constrain_layout();
+bool IsLayoutConstrainedCollective(const HloInstruction* instruction) {
+  const HloCollectiveInstruction* collective =
+      DynCast<HloCollectiveInstruction>(instruction);
+  return collective != nullptr && collective->constrain_layout();
 }
 
 }  // namespace
@@ -520,7 +520,7 @@ Status LayoutAssignment::AddMandatoryConstraints(
         TF_RETURN_IF_ERROR(
             constraints->SetBufferLayout(new_shape.layout(), *buffer));
       }
-    } else if (IsLayoutConstrainedAllReduce(instruction)) {
+    } else if (IsLayoutConstrainedCollective(instruction)) {
       TF_RETURN_IF_ERROR(
           constraints->SetInstructionLayout(instruction->shape(), instruction));
     } else if (instruction->IsCrossModuleAllReduce()) {
@@ -1808,7 +1808,7 @@ Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) {
     // Some instructions carry mandatory layouts in their shape.
     if (instruction->opcode() != HloOpcode::kInfeed &&
         !IsLayoutConstrainedCustomCall(instruction) &&
-        !IsLayoutConstrainedAllReduce(instruction)) {
+        !IsLayoutConstrainedCollective(instruction)) {
       LayoutUtil::ClearLayout(instruction->mutable_shape());
     }
   }
@@ -1385,5 +1385,42 @@ ENTRY entry_computation {
   ExpectLayoutIs(crs->operand(1)->shape(), {1, 0});
 }
 
+TEST_F(LayoutAssignmentTest, LayoutConstrainedAllToAll) {
+  const char* module_str = R"(
+HloModule test_module
+
+add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY entry_computation {
+  param = (f32[16,4]{0,1}, f32[16,4]{1,0}) parameter(0)
+  gte0 = f32[16,4] get-tuple-element(param), index=0
+  gte1 = f32[16,4] get-tuple-element(param), index=1
+  alltoall = (f32[16,4]{1,0}, f32[16,4]{1,0}) all-to-all(gte0, gte1),
+    replica_groups={{0,1}}, constrain_layout=true
+  gte2 = f32[16,4] get-tuple-element(alltoall), index=0
+  gte3 = f32[16,4] get-tuple-element(alltoall), index=1
+  ROOT concat = f32[16,8]{0,1} concatenate(gte2, gte3), dimensions={1}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> m,
+      ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
+  ComputationLayout computation_layout(
+      m->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false);
+
+  ChannelLayoutConstraints channel_constraints;
+  AssignLayouts(m.get(), &computation_layout, &channel_constraints);
+
+  const HloInstruction* alltoall = FindInstruction(m.get(), "alltoall");
+  ExpectTupleLayoutIs(alltoall->shape(), {{1, 0}, {1, 0}});
+  ExpectLayoutIs(alltoall->operand(0)->shape(), {1, 0});
+  ExpectLayoutIs(alltoall->operand(1)->shape(), {1, 0});
+}
+
 }  // namespace
 }  // namespace xla