Add optional layout constraints for AllReduce
PiperOrigin-RevId: 284198168
Change-Id: I4ef59638851ca1cef689f7db622bd06ca41bccad
Parent: c28ca27b96
Commit: c94797780e
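
For context, a minimal usage sketch of the new client-side parameter (the helper name and the concrete shape are illustrative assumptions, not part of this change):

    // Pin the result layout of an AllReduce so that separately compiled
    // modules communicating through it agree on the data layout.
    xla::XlaOp ConstrainedAllReduce(xla::XlaOp x,
                                    const xla::XlaComputation& add) {
      xla::Shape with_layout = xla::ShapeUtil::MakeShapeWithLayout(
          xla::F32, /*dimensions=*/{8, 12}, /*minor_to_major=*/{0, 1});
      return xla::AllReduce(x, add, /*replica_groups=*/{},
                            /*channel_id=*/absl::nullopt,
                            /*shape_with_layout=*/with_layout);
    }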
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -2112,7 +2112,8 @@ XlaOp XlaBuilder::CrossReplicaSum(
 XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation,
                             absl::Span<const ReplicaGroup> replica_groups,
-                            const absl::optional<ChannelHandle>& channel_id) {
+                            const absl::optional<ChannelHandle>& channel_id,
+                            const absl::optional<Shape>& shape_with_layout) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     HloInstructionProto instr;
     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
@@ -2136,9 +2137,31 @@ XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation,
       operand_shapes.push_back(operand_shape);
       operands.push_back(operand);
     }
-    TF_ASSIGN_OR_RETURN(Shape shape,
+    TF_ASSIGN_OR_RETURN(Shape inferred_shape,
                         ShapeInference::InferAllReduceShape(operand_shapes));
-    *instr.mutable_shape() = shape.ToProto();
+    if (shape_with_layout) {
+      if (!LayoutUtil::HasLayout(*shape_with_layout)) {
+        return InvalidArgument("shape_with_layout must have the layout set: %s",
+                               shape_with_layout->ToString());
+      }
+      if (!ShapeUtil::Compatible(*shape_with_layout, *operand_shape)) {
+        return InvalidArgument(
+            "Provided shape_with_layout must be compatible with the "
+            "operand shape: %s vs %s",
+            shape_with_layout->ToString(), operand_shape->ToString());
+      }
+      instr.set_constrain_layout(true);
+      if (operand_shape->IsTuple() && !inferred_shape.IsTuple()) {
+        // For a single-element tuple, take the tuple element shape.
+        TF_RET_CHECK(shape_with_layout->tuple_shapes_size() == 1);
+        *instr.mutable_shape() = shape_with_layout->tuple_shapes(0).ToProto();
+      } else {
+        *instr.mutable_shape() = shape_with_layout->ToProto();
+      }
+    } else {
+      *instr.mutable_shape() = inferred_shape.ToProto();
+    }
 
     for (const ReplicaGroup& group : replica_groups) {
       *instr.add_replica_groups() = group;
@@ -2153,10 +2176,10 @@ XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation,
     TF_ASSIGN_OR_RETURN(
         auto all_reduce,
         AddInstruction(std::move(instr), HloOpcode::kAllReduce, operands));
-    if (operand_shape->IsTuple() && !shape.IsTuple()) {
+    if (operand_shape->IsTuple() && !inferred_shape.IsTuple()) {
       // For a single-element tuple, wrap the result into a tuple.
       TF_RET_CHECK(operand_shapes.size() == 1);
-      TF_RET_CHECK(ShapeUtil::Compatible(*operand_shapes[0], shape));
+      TF_RET_CHECK(ShapeUtil::Compatible(*operand_shapes[0], inferred_shape));
       return Tuple({all_reduce});
     }
     return all_reduce;
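
A note on the single-element-tuple paths above: for a one-element tuple operand the inferred result is the bare array, so a layout-constraining caller passes a one-element tuple shape and the builder unwraps it. A sketch of the corresponding shape_with_layout value (the concrete shape is an illustrative assumption):

    // Operand (f32[8]) infers result f32[8]; the constraint is supplied as a
    // one-element tuple so it stays Compatible() with the operand shape.
    xla::Shape element = xla::ShapeUtil::MakeShapeWithLayout(
        xla::F32, /*dimensions=*/{8}, /*minor_to_major=*/{0});
    xla::Shape with_layout = xla::ShapeUtil::MakeTupleShape({element});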
@@ -3282,9 +3305,10 @@ XlaOp CrossReplicaSum(const XlaOp operand,
 
 XlaOp AllReduce(const XlaOp operand, const XlaComputation& computation,
                 absl::Span<const ReplicaGroup> replica_groups,
-                const absl::optional<ChannelHandle>& channel_id) {
+                const absl::optional<ChannelHandle>& channel_id,
+                const absl::optional<Shape>& shape_with_layout) {
   return operand.builder()->AllReduce(operand, computation, replica_groups,
-                                      channel_id);
+                                      channel_id, shape_with_layout);
 }
 
 XlaOp AllToAll(const XlaOp operand, int64 split_dimension,
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -514,7 +514,8 @@ class XlaBuilder {
   XlaOp AllReduce(
       XlaOp operand, const XlaComputation& computation,
       absl::Span<const ReplicaGroup> replica_groups = {},
-      const absl::optional<ChannelHandle>& channel_id = absl::nullopt);
+      const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
+      const absl::optional<Shape>& shape_with_layout = absl::nullopt);
 
   XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension,
                  int64 split_count,
@@ -922,7 +923,8 @@ class XlaBuilder {
                          absl::Span<const ReplicaGroup> replica_groups);
   friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
                          absl::Span<const ReplicaGroup> replica_groups,
-                         const absl::optional<ChannelHandle>& channel_id);
+                         const absl::optional<ChannelHandle>& channel_id,
+                         const absl::optional<Shape>& shape_with_layout);
   friend XlaOp AllToAll(XlaOp operand, int64 split_dimension,
                         int64 concat_dimension, int64 split_count,
                         const std::vector<ReplicaGroup>& replica_groups);
@@ -1666,10 +1668,14 @@ XlaOp CrossReplicaSum(XlaOp operand,
 // - `channel_id`: for Allreduce nodes from different modules, if they have the
 // same channel_id, they will be 'AllReduce'd. If empty, AllReduce will not be
 // applied cross modules.
-XlaOp AllReduce(
-    XlaOp operand, const XlaComputation& computation,
-    absl::Span<const ReplicaGroup> replica_groups = {},
-    const absl::optional<ChannelHandle>& channel_id = absl::nullopt);
+//
+// - `shape_with_layout`: forces the layout of the AllReduce to the given
+// layout. This is used to guarantee the same layout for a group of AllReduce
+// ops compiled separately.
+XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
+                absl::Span<const ReplicaGroup> replica_groups = {},
+                const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
+                const absl::optional<Shape>& shape_with_layout = absl::nullopt);
 
 // Enqueues an operation that do an Alltoall of the operand cross cores.
 XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension,
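
The new doc comment above captures the motivation: modules compiled separately but communicating through a cross-module AllReduce must agree on the wire layout. A sketch of what each side might look like (`input`, `add`, and `shared_channel` are assumed to exist in each module; they are not part of this change):

    // In every separately compiled module, pair the ops with the same channel
    // handle and pin the same layout.
    xla::Shape wire_shape =
        xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {128}, {0});
    xla::XlaOp ar = xla::AllReduce(input, add, /*replica_groups=*/{},
                                   /*channel_id=*/shared_channel,
                                   /*shape_with_layout=*/wire_shape);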
--- a/tensorflow/compiler/xla/python/xla.cc
+++ b/tensorflow/compiler/xla/python/xla.cc
@@ -639,10 +639,12 @@ PYBIND11_MODULE(xla_extension, m) {
   py::module ops = m.def_submodule("ops", "XLA operations");
 
   ops.def("AfterAll", &AfterAll);
-  ops.def("AllReduce",
-          static_cast<XlaOp (*)(
-              XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>,
-              const absl::optional<ChannelHandle>&)>(&AllReduce));
+  ops.def(
+      "AllReduce",
+      static_cast<XlaOp (*)(
+          XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>,
+          const absl::optional<ChannelHandle>&, const absl::optional<Shape>&)>(
+          &AllReduce));
   ops.def("AllToAll", &AllToAll);
   ops.def("CollectivePermute", &CollectivePermute);
   ops.def("CreateToken", &CreateToken);
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -1034,7 +1034,7 @@ class ComputationBuilder(object):
     """
     replica_groups_protos = _get_replica_groups_protos(replica_groups)
     return ops.AllReduce(operand, computation.computation,
-                         replica_groups_protos, None)
+                         replica_groups_protos, None, None)
 
   def AllToAll(self,
                operand,
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1505,6 +1505,7 @@ cc_library(
     hdrs = ["hlo_query.h"],
     deps = [
         ":hlo",
+        ":hlo_casting_utils",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:shape_util",
         "@com_google_absl//absl/container:flat_hash_set",
--- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
@@ -239,6 +239,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldAllReduceTupleOutput) {
   HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateAllReduce(
       ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}, sum,
       /*replica_groups=*/{},
+      /*constrain_layout=*/false,
       /*channel_id=*/absl::nullopt));
   HloInstruction* gte_a = builder.AddInstruction(
       HloInstruction::CreateGetTupleElement(f32_shape, crs, 0));
--- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
@@ -259,6 +259,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllReduce) {
   HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateAllReduce(
       ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction,
       /*replica_groups=*/{},
+      /*constrain_layout=*/false,
       /*channel_id=*/absl::nullopt));
   builder.AddInstruction(
       HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1));
--- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
@@ -211,7 +211,8 @@ TEST_F(BFloat16PropagationTest, DoNotChangeAllReduce) {
   HloInstruction* all_reduce =
       builder.AddInstruction(HloInstruction::CreateAllReduce(
           ShapeUtil::MakeTupleShape({shape, shape}), {a, b}, reduction,
-          /*replica_groups=*/{}, /*channel_id=*/1));
+          /*replica_groups=*/{}, /*constrain_layout=*/false,
+          /*channel_id=*/1));
   HloInstruction* gte0 = builder.AddInstruction(
       HloInstruction::CreateGetTupleElement(shape, all_reduce, 0));
   HloInstruction* gte1 = builder.AddInstruction(
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -400,6 +400,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
           /*replica_groups=*/
          std::vector<ReplicaGroup>(proto.replica_groups().begin(),
                                     proto.replica_groups().end()),
+          /*constrain_layout=*/proto.constrain_layout(),
           /*channel_id=*/channel_id);
       break;
     }
@@ -900,10 +901,11 @@ HloInstruction::CreateReducePrecision(const Shape& shape,
 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllReduce(
     const Shape& shape, absl::Span<HloInstruction* const> operands,
     HloComputation* reduce_computation,
-    const std::vector<ReplicaGroup>& replica_groups,
+    const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
     const absl::optional<int64>& channel_id) {
   return absl::make_unique<HloAllReduceInstruction>(
-      shape, operands, reduce_computation, replica_groups, channel_id);
+      shape, operands, reduce_computation, replica_groups, constrain_layout,
+      channel_id);
 }
 
 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllToAll(
@@ -1341,7 +1343,8 @@ bool HloInstruction::HasSideEffectNoRecurse() const {
     case HloOpcode::kTrace:
       return true;
     case HloOpcode::kAllReduce:
-      return channel_id().has_value();
+      return channel_id().has_value() ||
+             Cast<HloAllReduceInstruction>(this)->constrain_layout();
     case HloOpcode::kCustomCall:
       return Cast<HloCustomCallInstruction>(this)
           ->custom_call_has_side_effect();
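
Because constrain_layout now implies a side effect, passes gated on HasSideEffect() (dead code elimination and the like) leave such an AllReduce in place even if its result is unused. A small sketch (the shape, operand, and reduce_computation setup is assumed):

    // A layout-constrained AllReduce reports a side effect even without a
    // channel_id.
    std::unique_ptr<xla::HloInstruction> ar =
        xla::HloInstruction::CreateAllReduce(
            shape, {operand}, reduce_computation,
            /*replica_groups=*/{}, /*constrain_layout=*/true,
            /*channel_id=*/absl::nullopt);
    CHECK(ar->HasSideEffect());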
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -607,7 +607,7 @@ class HloInstruction {
   static std::unique_ptr<HloInstruction> CreateAllReduce(
       const Shape& shape, absl::Span<HloInstruction* const> operands,
       HloComputation* reduce_computation,
-      const std::vector<ReplicaGroup>& replica_groups,
+      const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
       const absl::optional<int64>& channel_id);
 
   // An all-to-all op takes N array operands of the same shape and scatters them
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -553,10 +553,11 @@ bool HloCollectiveInstruction::IdenticalSlowPath(
 HloAllReduceInstruction::HloAllReduceInstruction(
     const Shape& shape, absl::Span<HloInstruction* const> operands,
     HloComputation* reduce_computation,
-    const std::vector<ReplicaGroup>& replica_groups,
+    const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
     const absl::optional<int64>& channel_id)
     : HloCollectiveInstruction(HloOpcode::kAllReduce, shape, operands,
-                               replica_groups, channel_id) {
+                               replica_groups, channel_id),
+      constrain_layout_(constrain_layout) {
   AppendComputation(reduce_computation);
 }
 
@@ -569,12 +570,29 @@ bool HloAllReduceInstruction::IsNoop() const {
   return !channel_id();
 }
 
+HloInstructionProto HloAllReduceInstruction::ToProto() const {
+  HloInstructionProto proto = HloCollectiveInstruction::ToProto();
+  proto.set_constrain_layout(constrain_layout_);
+  return proto;
+}
+
+std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl(
+    const HloPrintOptions& options) const {
+  std::vector<string> result =
+      HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
+  if (constrain_layout_) {
+    result.push_back("constrain_layout=true");
+  }
+  return result;
+}
+
 bool HloAllReduceInstruction::IdenticalSlowPath(
     const HloInstruction& other,
     const std::function<bool(const HloComputation*, const HloComputation*)>&
         eq_computations) const {
   const auto& casted_other = static_cast<const HloAllReduceInstruction&>(other);
   return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) &&
+         constrain_layout() == casted_other.constrain_layout() &&
          eq_computations(to_apply(), casted_other.to_apply());
 }
@@ -583,7 +601,8 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl(
     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
     HloCloneContext* /*context*/) const {
   return absl::make_unique<HloAllReduceInstruction>(
-      shape, new_operands, to_apply(), replica_groups(), channel_id());
+      shape, new_operands, to_apply(), replica_groups(), constrain_layout(),
+      channel_id());
 }
 
 HloAllToAllInstruction::HloAllToAllInstruction(
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -336,13 +336,33 @@ class HloAllReduceInstruction : public HloCollectiveInstruction {
   explicit HloAllReduceInstruction(
       const Shape& shape, absl::Span<HloInstruction* const> operands,
       HloComputation* reduce_computation,
-      const std::vector<ReplicaGroup>& replica_groups,
+      const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
       const absl::optional<int64>& channel_id);
 
   // Returns true if the AllReduce does no communication, so it's equivalent
   // to a mem copy.
   bool IsNoop() const;
 
+  // Returns true if the layout of the AllReduce is enforced by XLA client (as
+  // the layout set in the shape). The only reason for the client to set the
+  // layout is to separately compile computations that communicate with
+  // AllReduce. Since this field is only set `true` by the client, the compiler
+  // only needs to propagate existing values (e.g., Clone, X64Rewriter) or set
+  // `false` for all other cases.
+  //
+  // When this is `true`, there may be communication endpoints outside the
+  // current compilation unit, so the compiler considers this AllReduce as
+  // side-effecting to disable compiler transformations. The compiler is free to
+  // transform unconstrained AllReduces differently across compilation units.
+  // It is an error for an HloModule to have a mix of constrained and
+  // unconstrained AllReduce instructions (checked by HloVerifier).
+  bool constrain_layout() const { return constrain_layout_; }
+
+ protected:
+  std::vector<string> ExtraAttributesToStringImpl(
+      const HloPrintOptions& options) const override;
+  HloInstructionProto ToProto() const override;
+
  private:
   bool IdenticalSlowPath(
       const HloInstruction& other,
@@ -353,6 +373,8 @@ class HloAllReduceInstruction : public HloCollectiveInstruction {
   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
       HloCloneContext* context) const override;
+
+  bool constrain_layout_;
 };
 
 class HloAllToAllInstruction : public HloCollectiveInstruction {
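
Per the comment block above, the compiler only propagates an existing constrain_layout value and never sets it on its own. A sketch of that propagation contract through cloning (the all_reduce and new_operand instructions are assumed to exist):

    // CloneWithNewOperands, backed by the CloneWithNewOperandsImpl override
    // declared above, preserves the flag.
    std::unique_ptr<xla::HloInstruction> clone =
        all_reduce->CloneWithNewOperands(all_reduce->shape(), {new_operand});
    CHECK_EQ(xla::Cast<xla::HloAllReduceInstruction>(clone.get())
                 ->constrain_layout(),
             all_reduce->constrain_layout());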
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -857,11 +857,14 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
       optional<HloComputation*> to_apply;
       optional<std::vector<int64>> replica_group_ids;
       optional<int64> channel_id;
+      optional<bool> constrain_layout;
       attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                            &to_apply};
       attrs["replica_groups"] = {/*required=*/false,
                                  AttrTy::kBracedInt64ListList, &tmp_groups};
       attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
+      attrs["constrain_layout"] = {/*required=*/false, AttrTy::kBool,
+                                   &constrain_layout};
       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
         return false;
       }
@@ -870,7 +873,8 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
         replica_groups = CreateReplicaGroups(*tmp_groups);
       }
       instruction = builder->AddInstruction(HloInstruction::CreateAllReduce(
-          shape, operands, *to_apply, replica_groups, channel_id));
+          shape, operands, *to_apply, replica_groups,
+          constrain_layout ? *constrain_layout : false, channel_id));
       break;
     }
     case HloOpcode::kAllToAll: {
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -1472,6 +1472,24 @@ ENTRY AllReduceWithSubgroups {
 ROOT all-reduce = f32[128,32]{0,1} all-reduce(input), replica_groups={{0,1},{2,3}}, to_apply=add
 }
 
 )"
 },
+// all-reduce with constrained layout
+{
+"AllReduceWithLayout",
+R"(HloModule CRS
+
+add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY CRS {
+  input = f32[8]{0} parameter(0)
+  ROOT crs = f32[8]{0} all-reduce(input), replica_groups={}, constrain_layout=true, to_apply=add
+}
+
+)"
+},
 // all-reduce with all-reduce-id
--- a/tensorflow/compiler/xla/service/hlo_query.cc
+++ b/tensorflow/compiler/xla/service/hlo_query.cc
@@ -16,6 +16,8 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_query.h"
 
 #include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 
@@ -119,5 +121,17 @@ bool ContainsInstrWithOpcode(const HloComputation* comp,
   return false;
 }
 
+bool ContainsLayoutConstrainedAllReduce(const HloModule& module) {
+  for (auto computation : module.computations()) {
+    for (auto hlo : computation->instructions()) {
+      if (hlo->opcode() == HloOpcode::kAllReduce &&
+          DynCast<HloAllReduceInstruction>(hlo)->constrain_layout()) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
 }  // namespace hlo_query
 }  // namespace xla
--- a/tensorflow/compiler/xla/service/hlo_query.h
+++ b/tensorflow/compiler/xla/service/hlo_query.h
@@ -19,6 +19,7 @@ limitations under the License.
 #include "absl/container/flat_hash_set.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
 
 namespace xla {
 
@@ -72,6 +73,10 @@ bool MatchBinaryInstructionOperandOpcode(HloOpcode opcode,
                                          HloInstruction** matching_operand,
                                          HloInstruction** other_operand);
 
+// Returns whether the module contains all-reduce instructions with constrained
+// layout.
+bool ContainsLayoutConstrainedAllReduce(const HloModule& module);
+
 }  // namespace hlo_query
 }  // namespace xla
 
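
A sketch of how a layout-changing pass might consult the new helper before touching a module (the surrounding pass boilerplate is assumed, not part of this change):

    // Bail out of a layout-mutating pass when any all-reduce has its layout
    // pinned by the client.
    if (xla::hlo_query::ContainsLayoutConstrainedAllReduce(*module)) {
      return false;  // Report "no change" and leave the module alone.
    }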
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -1310,6 +1310,29 @@ Status VerifyAsynchronousCopies(const HloModule& module) {
   return Status::OK();
 }
 
+// Checks that AllReduce instructions in the module are either all layout
+// constrained or all unconstrained.
+Status VerifyLayoutConstrainedAllReduce(const HloModule& module) {
+  const HloAllReduceInstruction* reference = nullptr;
+  for (const HloComputation* computation : module.computations()) {
+    for (const HloInstruction* instruction : computation->instructions()) {
+      if (instruction->opcode() != HloOpcode::kAllReduce) {
+        continue;
+      }
+      auto all_reduce = DynCast<HloAllReduceInstruction>(instruction);
+      if (!reference) {
+        reference = all_reduce;
+      }
+      if (reference->constrain_layout() != all_reduce->constrain_layout()) {
+        return FailedPrecondition(
+            "HloModule has a mix of layout constrained and unconstrained "
+            "AllReduce instructions.");
+      }
+    }
+  }
+  return Status::OK();
+}
+
 // Checks various invariants of send and recv instructions.
 Status VerifySendsAndRecvs(const HloModule& module) {
   absl::flat_hash_map<int64, const HloInstruction*> host_channels;
@@ -1697,6 +1720,7 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
       }));
 
   TF_RETURN_IF_ERROR(module->dynamic_parameter_binding().Verify(*module));
+  TF_RETURN_IF_ERROR(VerifyLayoutConstrainedAllReduce(*module));
 
   return false;
 }
--- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
@@ -988,5 +988,30 @@ TEST_F(HloVerifierTest, FusionShapeVerifier) {
               HasSubstr("Fused computation shape"));
 }
 
+TEST_F(HloVerifierTest, AllReduceVerifier) {
+  const char* const kModuleStr = R"(
+  HloModule test
+
+  add {
+    lhs = f32[] parameter(0)
+    rhs = f32[] parameter(1)
+    ROOT add = f32[] add(lhs, rhs)
+  }
+
+  ENTRY entry {
+    input = f32[8,12]{0,1} parameter(0)
+    crs0 = f32[8,12]{0,1} all-reduce(input), replica_groups={}, to_apply=add
+    crs1 = f32[8,12]{0,1} all-reduce(input), replica_groups={}, to_apply=add,
+      constrain_layout=true
+    ROOT result = (f32[8,12]{0,1}, f32[8,12]{0,1}) tuple(crs0, crs1)
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnUnverifiedModule(kModuleStr));
+  EXPECT_THAT(
+      verifier().Run(module.get()).status().error_message(),
+      HasSubstr("mix of layout constrained and unconstrained AllReduce"));
+}
+
 }  // namespace
 }  // namespace xla
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -432,6 +432,12 @@ bool IsLayoutConstrainedCustomCall(HloInstruction* instruction) {
   return custom_call != nullptr && custom_call->layout_constrained();
 }
 
+bool IsLayoutConstrainedAllReduce(HloInstruction* instruction) {
+  const HloAllReduceInstruction* all_reduce =
+      DynCast<HloAllReduceInstruction>(instruction);
+  return all_reduce != nullptr && all_reduce->constrain_layout();
+}
+
 }  // namespace
 
 Status LayoutAssignment::AddMandatoryConstraints(
@@ -516,6 +522,9 @@ Status LayoutAssignment::AddMandatoryConstraints(
         TF_RETURN_IF_ERROR(
             constraints->SetBufferLayout(new_shape.layout(), *buffer));
       }
+    } else if (IsLayoutConstrainedAllReduce(instruction)) {
+      TF_RETURN_IF_ERROR(
+          constraints->SetInstructionLayout(instruction->shape(), instruction));
     } else if (instruction->IsCrossModuleAllReduce()) {
       CHECK(get_channel_constraints(instruction))
           << "Multi-module layout assignment requires ChannelLayoutConstraints";
@@ -1765,7 +1774,8 @@ Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) {
     }
     // Some instructions carry mandatory layouts in their shape.
     if (instruction->opcode() != HloOpcode::kInfeed &&
-        !IsLayoutConstrainedCustomCall(instruction)) {
+        !IsLayoutConstrainedCustomCall(instruction) &&
+        !IsLayoutConstrainedAllReduce(instruction)) {
       LayoutUtil::ClearLayout(instruction->mutable_shape());
     }
   }