Add optional layout constraints for AllReduce

PiperOrigin-RevId: 284198168
Change-Id: I4ef59638851ca1cef689f7db622bd06ca41bccad
Authored by HyoukJoong Lee on 2019-12-06 09:16:38 -08:00; committed by TensorFlower Gardener
parent c28ca27b96
commit c94797780e
19 changed files with 209 additions and 29 deletions

View File

@ -2112,7 +2112,8 @@ XlaOp XlaBuilder::CrossReplicaSum(
XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups,
const absl::optional<ChannelHandle>& channel_id) {
const absl::optional<ChannelHandle>& channel_id,
const absl::optional<Shape>& shape_with_layout) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
@ -2136,9 +2137,31 @@ XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation,
operand_shapes.push_back(operand_shape);
operands.push_back(operand);
}
TF_ASSIGN_OR_RETURN(Shape shape,
TF_ASSIGN_OR_RETURN(Shape inferred_shape,
ShapeInference::InferAllReduceShape(operand_shapes));
*instr.mutable_shape() = shape.ToProto();
if (shape_with_layout) {
if (!LayoutUtil::HasLayout(*shape_with_layout)) {
return InvalidArgument("shape_with_layout must have the layout set: %s",
shape_with_layout->ToString());
}
if (!ShapeUtil::Compatible(*shape_with_layout, *operand_shape)) {
return InvalidArgument(
"Provided shape_with_layout must be compatible with the "
"operand shape: %s vs %s",
shape_with_layout->ToString(), operand_shape->ToString());
}
instr.set_constrain_layout(true);
if (operand_shape->IsTuple() && !inferred_shape.IsTuple()) {
// For a single-element tuple, take the tuple element shape.
TF_RET_CHECK(shape_with_layout->tuple_shapes_size() == 1);
*instr.mutable_shape() = shape_with_layout->tuple_shapes(0).ToProto();
} else {
*instr.mutable_shape() = shape_with_layout->ToProto();
}
} else {
*instr.mutable_shape() = inferred_shape.ToProto();
}
for (const ReplicaGroup& group : replica_groups) {
*instr.add_replica_groups() = group;
@ -2153,10 +2176,10 @@ XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation,
TF_ASSIGN_OR_RETURN(
auto all_reduce,
AddInstruction(std::move(instr), HloOpcode::kAllReduce, operands));
if (operand_shape->IsTuple() && !shape.IsTuple()) {
if (operand_shape->IsTuple() && !inferred_shape.IsTuple()) {
// For a single-element tuple, wrap the result into a tuple.
TF_RET_CHECK(operand_shapes.size() == 1);
TF_RET_CHECK(ShapeUtil::Compatible(*operand_shapes[0], shape));
TF_RET_CHECK(ShapeUtil::Compatible(*operand_shapes[0], inferred_shape));
return Tuple({all_reduce});
}
return all_reduce;
@ -3282,9 +3305,10 @@ XlaOp CrossReplicaSum(const XlaOp operand,
XlaOp AllReduce(const XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups,
const absl::optional<ChannelHandle>& channel_id) {
const absl::optional<ChannelHandle>& channel_id,
const absl::optional<Shape>& shape_with_layout) {
return operand.builder()->AllReduce(operand, computation, replica_groups,
channel_id);
channel_id, shape_with_layout);
}
XlaOp AllToAll(const XlaOp operand, int64 split_dimension,

View File

@ -514,7 +514,8 @@ class XlaBuilder {
XlaOp AllReduce(
XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups = {},
const absl::optional<ChannelHandle>& channel_id = absl::nullopt);
const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
const absl::optional<Shape>& shape_with_layout = absl::nullopt);
XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension,
int64 split_count,
@ -922,7 +923,8 @@ class XlaBuilder {
absl::Span<const ReplicaGroup> replica_groups);
friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups,
const absl::optional<ChannelHandle>& channel_id);
const absl::optional<ChannelHandle>& channel_id,
const absl::optional<Shape>& shape_with_layout);
friend XlaOp AllToAll(XlaOp operand, int64 split_dimension,
int64 concat_dimension, int64 split_count,
const std::vector<ReplicaGroup>& replica_groups);
@ -1666,10 +1668,14 @@ XlaOp CrossReplicaSum(XlaOp operand,
// - `channel_id`: for Allreduce nodes from different modules, if they have the
// same channel_id, they will be 'AllReduce'd. If empty, AllReduce will not be
// applied cross modules.
XlaOp AllReduce(
XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups = {},
const absl::optional<ChannelHandle>& channel_id = absl::nullopt);
//
// - `shape_with_layout`: forces the layout of the AllReduce to the given
// layout. This is used to guarantee the same layout for a group of AllReduce
// ops compiled separately.
XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups = {},
const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
const absl::optional<Shape>& shape_with_layout = absl::nullopt);
// Enqueues an operation that does an AllToAll of the operand across cores.
XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension,

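A usage sketch only (not part of this change), assuming `add` is a scalar-add XlaComputation built elsewhere and the usual XLA client headers: a client that wants separately compiled modules to agree on the all-reduce layout passes a fully laid-out shape through the new parameter.

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

xla::XlaOp BuildConstrainedAllReduce(xla::XlaBuilder* b,
                                     const xla::XlaComputation& add) {
  // Pin the f32[8,12] all-reduce result to a fixed {1,0} layout so that
  // separately compiled peers agree on it.
  xla::Shape with_layout = xla::ShapeUtil::MakeShapeWithLayout(
      xla::F32, /*dimensions=*/{8, 12}, /*minor_to_major=*/{1, 0});
  xla::XlaOp x = xla::Parameter(
      b, 0, xla::ShapeUtil::MakeShape(xla::F32, {8, 12}), "x");
  return xla::AllReduce(x, add, /*replica_groups=*/{},
                        /*channel_id=*/absl::nullopt,
                        /*shape_with_layout=*/with_layout);
}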
View File

@ -639,10 +639,12 @@ PYBIND11_MODULE(xla_extension, m) {
py::module ops = m.def_submodule("ops", "XLA operations");
ops.def("AfterAll", &AfterAll);
ops.def("AllReduce",
static_cast<XlaOp (*)(
XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>,
const absl::optional<ChannelHandle>&)>(&AllReduce));
ops.def(
"AllReduce",
static_cast<XlaOp (*)(
XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>,
const absl::optional<ChannelHandle>&, const absl::optional<Shape>&)>(
&AllReduce));
ops.def("AllToAll", &AllToAll);
ops.def("CollectivePermute", &CollectivePermute);
ops.def("CreateToken", &CreateToken);

View File

@ -1034,7 +1034,7 @@ class ComputationBuilder(object):
"""
replica_groups_protos = _get_replica_groups_protos(replica_groups)
return ops.AllReduce(operand, computation.computation,
replica_groups_protos, None)
replica_groups_protos, None, None)
def AllToAll(self,
operand,

View File

@ -1505,6 +1505,7 @@ cc_library(
hdrs = ["hlo_query.h"],
deps = [
":hlo",
":hlo_casting_utils",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"@com_google_absl//absl/container:flat_hash_set",

View File

@ -239,6 +239,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldAllReduceTupleOutput) {
HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateAllReduce(
ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}, sum,
/*replica_groups=*/{},
/*constrain_layout=*/false,
/*channel_id=*/absl::nullopt));
HloInstruction* gte_a = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(f32_shape, crs, 0));

View File

@ -259,6 +259,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllReduce) {
HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateAllReduce(
ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction,
/*replica_groups=*/{},
/*constrain_layout=*/false,
/*channel_id=*/absl::nullopt));
builder.AddInstruction(
HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1));

View File

@ -211,7 +211,8 @@ TEST_F(BFloat16PropagationTest, DoNotChangeAllReduce) {
HloInstruction* all_reduce =
builder.AddInstruction(HloInstruction::CreateAllReduce(
ShapeUtil::MakeTupleShape({shape, shape}), {a, b}, reduction,
/*replica_groups=*/{}, /*channel_id=*/1));
/*replica_groups=*/{}, /*constrain_layout=*/false,
/*channel_id=*/1));
HloInstruction* gte0 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, all_reduce, 0));
HloInstruction* gte1 = builder.AddInstruction(

View File

@ -400,6 +400,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
/*replica_groups=*/
std::vector<ReplicaGroup>(proto.replica_groups().begin(),
proto.replica_groups().end()),
/*constrain_layout=*/proto.constrain_layout(),
/*channel_id=*/channel_id);
break;
}
@ -900,10 +901,11 @@ HloInstruction::CreateReducePrecision(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllReduce(
const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* reduce_computation,
const std::vector<ReplicaGroup>& replica_groups,
const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
const absl::optional<int64>& channel_id) {
return absl::make_unique<HloAllReduceInstruction>(
shape, operands, reduce_computation, replica_groups, channel_id);
shape, operands, reduce_computation, replica_groups, constrain_layout,
channel_id);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllToAll(
@ -1341,7 +1343,8 @@ bool HloInstruction::HasSideEffectNoRecurse() const {
case HloOpcode::kTrace:
return true;
case HloOpcode::kAllReduce:
return channel_id().has_value();
return channel_id().has_value() ||
Cast<HloAllReduceInstruction>(this)->constrain_layout();
case HloOpcode::kCustomCall:
return Cast<HloCustomCallInstruction>(this)
->custom_call_has_side_effect();

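Why marking these as side-effecting matters (a simplified sketch, not the actual HloDCE code): cleanup passes typically guard removal on HasSideEffect(), so an unused layout-constrained all-reduce is now kept alive even when it has no channel_id.

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"

// Typical removal guard in a dead-code-style cleanup (simplified).
bool IsRemovableIfUnused(const xla::HloInstruction* instr) {
  // A layout-constrained all-reduce now reports HasSideEffect() == true, so
  // this returns false for it even when it has no users.
  return instr->user_count() == 0 &&
         instr != instr->parent()->root_instruction() &&
         !instr->HasSideEffect();
}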
View File

@ -607,7 +607,7 @@ class HloInstruction {
static std::unique_ptr<HloInstruction> CreateAllReduce(
const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* reduce_computation,
const std::vector<ReplicaGroup>& replica_groups,
const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
const absl::optional<int64>& channel_id);
// An all-to-all op takes N array operands of the same shape and scatters them

View File

@ -553,10 +553,11 @@ bool HloCollectiveInstruction::IdenticalSlowPath(
HloAllReduceInstruction::HloAllReduceInstruction(
const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* reduce_computation,
const std::vector<ReplicaGroup>& replica_groups,
const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
const absl::optional<int64>& channel_id)
: HloCollectiveInstruction(HloOpcode::kAllReduce, shape, operands,
replica_groups, channel_id) {
replica_groups, channel_id),
constrain_layout_(constrain_layout) {
AppendComputation(reduce_computation);
}
@ -569,12 +570,29 @@ bool HloAllReduceInstruction::IsNoop() const {
return !channel_id();
}
HloInstructionProto HloAllReduceInstruction::ToProto() const {
HloInstructionProto proto = HloCollectiveInstruction::ToProto();
proto.set_constrain_layout(constrain_layout_);
return proto;
}
std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
std::vector<string> result =
HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
if (constrain_layout_) {
result.push_back("constrain_layout=true");
}
return result;
}
bool HloAllReduceInstruction::IdenticalSlowPath(
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
eq_computations) const {
const auto& casted_other = static_cast<const HloAllReduceInstruction&>(other);
return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) &&
constrain_layout() == casted_other.constrain_layout() &&
eq_computations(to_apply(), casted_other.to_apply());
}
@ -583,7 +601,8 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* /*context*/) const {
return absl::make_unique<HloAllReduceInstruction>(
shape, new_operands, to_apply(), replica_groups(), channel_id());
shape, new_operands, to_apply(), replica_groups(), constrain_layout(),
channel_id());
}
HloAllToAllInstruction::HloAllToAllInstruction(

View File

@ -336,13 +336,33 @@ class HloAllReduceInstruction : public HloCollectiveInstruction {
explicit HloAllReduceInstruction(
const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* reduce_computation,
const std::vector<ReplicaGroup>& replica_groups,
const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
const absl::optional<int64>& channel_id);
// Returns true if the AllReduce does no communication, so it's equivalent
// to a mem copy.
bool IsNoop() const;
// Returns true if the layout of the AllReduce is enforced by the XLA client
// (i.e., the layout is set in the instruction's shape). The only reason for
// the client to set the layout is to separately compile computations that
// communicate with AllReduce. Since this field is only set to `true` by the
// client, the compiler only needs to propagate existing values (e.g., in
// Clone or X64Rewriter) or set it to `false` in all other cases.
//
// When this is `true`, there may be communication endpoints outside the
// current compilation unit, so the compiler treats this AllReduce as
// side-effecting to disable compiler transformations. The compiler is free to
// transform unconstrained AllReduces differently across compilation units.
// It is an error for an HloModule to have a mix of constrained and
// unconstrained AllReduce instructions (checked by HloVerifier).
bool constrain_layout() const { return constrain_layout_; }
protected:
std::vector<string> ExtraAttributesToStringImpl(
const HloPrintOptions& options) const override;
HloInstructionProto ToProto() const override;
private:
bool IdenticalSlowPath(
const HloInstruction& other,
@ -353,6 +373,8 @@ class HloAllReduceInstruction : public HloCollectiveInstruction {
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
bool constrain_layout_;
};
class HloAllToAllInstruction : public HloCollectiveInstruction {

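For illustration only (`builder`, `param`, and `reduce_computation` are assumed, as in the test snippets above): at the HLO level a constrained all-reduce is created by passing constrain_layout=true and baking the layout into the instruction shape.

// `builder` is an HloComputation::Builder, `param` an existing instruction,
// and `reduce_computation` a scalar-add HloComputation.
xla::Shape shape =
    xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {8, 12}, {0, 1});
xla::HloInstruction* crs = builder.AddInstruction(
    xla::HloInstruction::CreateAllReduce(shape, {param}, reduce_computation,
                                         /*replica_groups=*/{},
                                         /*constrain_layout=*/true,
                                         /*channel_id=*/absl::nullopt));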
View File

@ -857,11 +857,14 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
optional<HloComputation*> to_apply;
optional<std::vector<int64>> replica_group_ids;
optional<int64> channel_id;
optional<bool> constrain_layout;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&to_apply};
attrs["replica_groups"] = {/*required=*/false,
AttrTy::kBracedInt64ListList, &tmp_groups};
attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
attrs["constrain_layout"] = {/*required=*/false, AttrTy::kBool,
&constrain_layout};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
@ -870,7 +873,8 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
replica_groups = CreateReplicaGroups(*tmp_groups);
}
instruction = builder->AddInstruction(HloInstruction::CreateAllReduce(
shape, operands, *to_apply, replica_groups, channel_id));
shape, operands, *to_apply, replica_groups,
constrain_layout ? *constrain_layout : false, channel_id));
break;
}
case HloOpcode::kAllToAll: {

View File

@ -1472,6 +1472,24 @@ ENTRY AllReduceWithSubgroups {
ROOT all-reduce = f32[128,32]{0,1} all-reduce(input), replica_groups={{0,1},{2,3}}, to_apply=add
}
)"
},
// all-reduce with constrained layout
{
"AllReduceWithLayout",
R"(HloModule CRS
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY CRS {
input = f32[8]{0} parameter(0)
ROOT crs = f32[8]{0} all-reduce(input), replica_groups={}, constrain_layout=true, to_apply=add
}
)"
},
// all-reduce with all-reduce-id

View File

@ -16,6 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
@ -119,5 +121,17 @@ bool ContainsInstrWithOpcode(const HloComputation* comp,
return false;
}
bool ContainsLayoutConstrainedAllReduce(const HloModule& module) {
for (auto computation : module.computations()) {
for (auto hlo : computation->instructions()) {
if (hlo->opcode() == HloOpcode::kAllReduce &&
DynCast<HloAllReduceInstruction>(hlo)->constrain_layout()) {
return true;
}
}
}
return false;
}
} // namespace hlo_query
} // namespace xla

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
namespace xla {
@ -72,6 +73,10 @@ bool MatchBinaryInstructionOperandOpcode(HloOpcode opcode,
HloInstruction** matching_operand,
HloInstruction** other_operand);
// Returns whether the module contains all-reduce instructions with constrained
// layout.
bool ContainsLayoutConstrainedAllReduce(const HloModule& module);
} // namespace hlo_query
} // namespace xla

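A hedged sketch of intended use (SomeLayoutChangingPass is hypothetical, not an existing pass): a transformation that would alter layouts can simply bail out when the module contains constrained all-reduces.

#include "tensorflow/compiler/xla/service/hlo_query.h"

xla::StatusOr<bool> SomeLayoutChangingPass::Run(xla::HloModule* module) {
  if (xla::hlo_query::ContainsLayoutConstrainedAllReduce(*module)) {
    return false;  // Leave the constrained layouts untouched.
  }
  // ... perform layout-changing rewrites ...
  return true;
}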
View File

@ -1310,6 +1310,29 @@ Status VerifyAsynchronousCopies(const HloModule& module) {
return Status::OK();
}
// Checks that AllReduce instructions in the module are either all layout
// constrained or all unconstrained.
Status VerifyLayoutConstrainedAllReduce(const HloModule& module) {
const HloAllReduceInstruction* reference = nullptr;
for (const HloComputation* computation : module.computations()) {
for (const HloInstruction* instruction : computation->instructions()) {
if (instruction->opcode() != HloOpcode::kAllReduce) {
continue;
}
auto all_reduce = DynCast<HloAllReduceInstruction>(instruction);
if (!reference) {
reference = all_reduce;
}
if (reference->constrain_layout() != all_reduce->constrain_layout()) {
return FailedPrecondition(
"HloModule has a mix of layout constrained and unconstrained "
"AllReduce instructions.");
}
}
}
return Status::OK();
}
// Checks various invariants of send and recv instructions.
Status VerifySendsAndRecvs(const HloModule& module) {
absl::flat_hash_map<int64, const HloInstruction*> host_channels;
@ -1697,6 +1720,7 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
}));
TF_RETURN_IF_ERROR(module->dynamic_parameter_binding().Verify(*module));
TF_RETURN_IF_ERROR(VerifyLayoutConstrainedAllReduce(*module));
return false;
}

View File

@ -988,5 +988,30 @@ TEST_F(HloVerifierTest, FusionShapeVerifier) {
HasSubstr("Fused computation shape"));
}
TEST_F(HloVerifierTest, AllReduceVerifier) {
const char* const kModuleStr = R"(
HloModule test
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY entry {
input = f32[8,12]{0,1} parameter(0)
crs0 = f32[8,12]{0,1} all-reduce(input), replica_groups={}, to_apply=add
crs1 = f32[8,12]{0,1} all-reduce(input), replica_groups={}, to_apply=add,
constrain_layout=true
ROOT result = (f32[8,12]{0,1}, f32[8,12]{0,1}) tuple(crs0, crs1)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(kModuleStr));
EXPECT_THAT(
verifier().Run(module.get()).status().error_message(),
HasSubstr("mix of layout constrained and unconstrained AllReduce"));
}
} // namespace
} // namespace xla

View File

@ -432,6 +432,12 @@ bool IsLayoutConstrainedCustomCall(HloInstruction* instruction) {
return custom_call != nullptr && custom_call->layout_constrained();
}
bool IsLayoutConstrainedAllReduce(HloInstruction* instruction) {
const HloAllReduceInstruction* all_reduce =
DynCast<HloAllReduceInstruction>(instruction);
return all_reduce != nullptr && all_reduce->constrain_layout();
}
} // namespace
Status LayoutAssignment::AddMandatoryConstraints(
@ -516,6 +522,9 @@ Status LayoutAssignment::AddMandatoryConstraints(
TF_RETURN_IF_ERROR(
constraints->SetBufferLayout(new_shape.layout(), *buffer));
}
} else if (IsLayoutConstrainedAllReduce(instruction)) {
TF_RETURN_IF_ERROR(
constraints->SetInstructionLayout(instruction->shape(), instruction));
} else if (instruction->IsCrossModuleAllReduce()) {
CHECK(get_channel_constraints(instruction))
<< "Multi-module layout assignment requires ChannelLayoutConstraints";
@ -1765,7 +1774,8 @@ Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) {
}
// Some instructions carry mandatory layouts in their shape.
if (instruction->opcode() != HloOpcode::kInfeed &&
!IsLayoutConstrainedCustomCall(instruction)) {
!IsLayoutConstrainedCustomCall(instruction) &&
!IsLayoutConstrainedAllReduce(instruction)) {
LayoutUtil::ClearLayout(instruction->mutable_shape());
}
}