[XLA] Make sharding propagation faster
1. Avoid a copy in MergeSharding.
2. Keep track of a workset to avoid unnecessary recomputation.

PiperOrigin-RevId: 326768403
Change-Id: Iea3f1ff3c448864a06f4ebb14c37f73a16ebea1e
Parent: a74a993732
Commit: ea0a469bdd
@@ -120,34 +120,34 @@ HloSharding MergeForMoreSpecificSharding(const HloSharding& a,
   return IsShardingMoreSpecific(a, b) ? a : b;
 }
 
-// Returns a sharding that is refined by merging old and to_merge. May combine
-// partial sharding in addition to MergeForMoreSpecificSharding().
-HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge,
-                          bool may_combine_partial_sharding) {
+// Tries to refine `to_merge` by combining with `old`. Returns if the final
+// `to_merge` is more specific than `old`. May combine partial sharding in
+// addition to MergeForMoreSpecificSharding().
+bool MergeSharding(const HloSharding& old, HloSharding* to_merge,
+                   bool may_combine_partial_sharding) {
   if (old.IsTuple()) {
-    HloSharding result = old;
-    CHECK(to_merge.IsTuple());
-    CHECK_EQ(old.tuple_elements().size(), to_merge.tuple_elements().size());
-    for (int64 i = 0; i < result.tuple_elements().size(); ++i) {
-      result.tuple_elements()[i] =
-          MergeSharding(old.tuple_elements()[i], to_merge.tuple_elements()[i],
+    CHECK(to_merge->IsTuple());
+    bool changed = false;
+    for (int64 i = 0; i < old.tuple_elements().size(); ++i) {
+      changed |=
+          MergeSharding(old.tuple_elements()[i], &to_merge->tuple_elements()[i],
                         may_combine_partial_sharding);
     }
-    return result;
+    return changed;
   }
   if (!may_combine_partial_sharding || !old.ReplicateOnLastTileDim() ||
-      !to_merge.ReplicateOnLastTileDim() ||
+      !to_merge->ReplicateOnLastTileDim() ||
       old.tile_assignment().num_elements() !=
-          to_merge.tile_assignment().num_elements()) {
-    return IsShardingMoreSpecific(to_merge, old) ? to_merge : old;
+          to_merge->tile_assignment().num_elements()) {
+    return IsShardingMoreSpecific(*to_merge, old);
   }
   // Combine the tile dimension sizes from new and old.
   int64 num_devices = old.tile_assignment().num_elements();
   std::vector<int64> new_tile_dims;
   bool compatible = true;
-  new_tile_dims.reserve(to_merge.tile_assignment().num_dimensions());
-  for (int64 i = 0; i < to_merge.tile_assignment().num_dimensions() - 1; ++i) {
-    int64 new_dim = to_merge.tile_assignment().dim(i);
+  new_tile_dims.reserve(to_merge->tile_assignment().num_dimensions());
+  for (int64 i = 0; i < to_merge->tile_assignment().num_dimensions() - 1; ++i) {
+    int64 new_dim = to_merge->tile_assignment().dim(i);
     int64 old_dim = old.tile_assignment().dim(i);
     if (new_dim == 1) {
       new_tile_dims.push_back(old_dim);
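The signature change in this hunk is the "avoid copy" optimization from the commit message: rather than materializing a merged HloSharding and returning it by value (forcing the caller to copy it and then compare it against the old value), MergeSharding now refines `to_merge` in place and returns whether the result improved on `old`. A minimal standalone sketch of the same pattern, using a hypothetical IntListSharding struct rather than the real HloSharding API:

#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a tuple-shaped sharding; not the XLA type.
struct IntListSharding {
  std::vector<int> elements;
};

// Old style: builds and returns a fresh object, so every call copies `old`
// even when the merge changes nothing.
IntListSharding MergeCopying(const IntListSharding& old,
                             const IntListSharding& to_merge) {
  IntListSharding result = old;  // unconditional copy
  for (std::size_t i = 0; i < result.elements.size(); ++i) {
    result.elements[i] = std::max(old.elements[i], to_merge.elements[i]);
  }
  return result;  // caller must compare against `old` to detect a change
}

// New style: refine `to_merge` in place and report whether it improved.
// No temporary object, and the change bit comes back for free.
bool MergeInPlace(const IntListSharding& old, IntListSharding* to_merge) {
  bool changed = false;
  for (std::size_t i = 0; i < old.elements.size(); ++i) {
    if (old.elements[i] > to_merge->elements[i]) {
      to_merge->elements[i] = old.elements[i];
      changed = true;
    }
  }
  return changed;
}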
@@ -163,7 +163,7 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge,
   int64 replication = num_devices / Product(new_tile_dims);
   if (!compatible || num_devices % Product(new_tile_dims) != 0 ||
       replication >= old.tile_assignment().dimensions().back()) {
-    return IsShardingMoreSpecific(to_merge, old) ? to_merge : old;
+    return IsShardingMoreSpecific(*to_merge, old);
   }
   new_tile_dims.push_back(replication);
   Array<int64> new_tile(new_tile_dims);
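A rough worked example of the arithmetic in this hunk (illustrative numbers, not from the commit): with num_devices = 8, an old sharding tiled [2,1] with a trailing partial-replication dimension of 4, and a to_merge tiled [1,2] with the same trailing 4, the combined new_tile_dims are [2,2]; replication = 8 / (2*2) = 2, which divides the device count evenly and is strictly smaller than old's trailing dimension of 4, so the merge proceeds and can produce a partially replicated [2,2,2] tile assignment. When either condition fails, merging cannot reduce replication, so the function just reports whether `to_merge` is already more specific than `old` and leaves the caller to keep the better of the two.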
@@ -174,7 +174,7 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge,
                             const HloSharding& sharding) {
     int64 group_id = 0;
     for (int64 i = 0; i < tile_indices.size() - 1; ++i) {
-      group_id *= to_merge.tile_assignment().dim(i);
+      group_id *= to_merge->tile_assignment().dim(i);
       group_id += tile_indices[i];
     }
     return group_id;
@@ -183,9 +183,9 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge,
       [&](absl::Span<const int64> indices, int64 device) {
         old_group_members[get_group_index(indices, old)].insert(device);
       });
-  to_merge.tile_assignment().Each(
+  to_merge->tile_assignment().Each(
       [&](absl::Span<const int64> indices, int64 device) {
-        new_group_members[get_group_index(indices, to_merge)].insert(device);
+        new_group_members[get_group_index(indices, *to_merge)].insert(device);
       });
   // Try to find the intersection of old and new replication groups, in
   // order to determine the merged tile assignment.
@@ -199,12 +199,12 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge,
       if (old.tile_assignment().dim(i) == 1) {
         old_index[i] = 0;
       }
-      if (to_merge.tile_assignment().dim(i) == 1) {
+      if (to_merge->tile_assignment().dim(i) == 1) {
         new_index[i] = 0;
       }
     }
     int64 old_group_id = get_group_index(old_index, old);
-    int64 new_group_id = get_group_index(new_index, to_merge);
+    int64 new_group_id = get_group_index(new_index, *to_merge);
     if (old_group_members[old_group_id].empty() ||
         new_group_members[new_group_id].empty() ||
         *old_group_members[old_group_id].begin() !=
@@ -220,11 +220,13 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge,
     if (replication == 1) {
       new_tile_dims.pop_back();
       new_tile.Reshape(new_tile_dims);
-      return HloSharding::Tile(new_tile);
+      *to_merge = HloSharding::Tile(new_tile);
+    } else {
+      *to_merge = HloSharding::PartialTile(new_tile);
     }
-    return HloSharding::PartialTile(new_tile);
+    return true;
   }
-  return IsShardingMoreSpecific(to_merge, old) ? to_merge : old;
+  return IsShardingMoreSpecific(*to_merge, old);
 }
 
 // Updates the sharding of the specified instruction with the specified sharding
@@ -232,7 +234,7 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge,
 // been applied. If may_combine_partial_sharding is true, this may combine the
 // new and existing sharding if they are both partial tiling partial
 // replication.
-bool MaybeImproveInstructionSharding(const HloSharding& sharding,
+bool MaybeImproveInstructionSharding(HloSharding sharding,
                                      HloInstruction* instruction,
                                      bool may_combine_partial_sharding) {
   // We don't want to propagate tile maximal shardings.
@@ -241,13 +243,13 @@ bool MaybeImproveInstructionSharding(const HloSharding& sharding,
   }
   // Any sharding is better then no sharding.
   if (!instruction->has_sharding()) {
-    instruction->set_sharding(sharding);
+    instruction->set_sharding(std::move(sharding));
     return true;
   }
-  auto merged = MergeSharding(instruction->sharding(), sharding,
+  auto merged = MergeSharding(instruction->sharding(), &sharding,
                               may_combine_partial_sharding);
-  if (merged != instruction->sharding()) {
-    instruction->set_sharding(merged);
+  if (merged) {
+    instruction->set_sharding(std::move(sharding));
     return true;
   }
   return false;
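Taking `sharding` by value instead of by const reference is the standard sink-argument idiom: MaybeImproveInstructionSharding ultimately stores the sharding on the instruction, so a by-value parameter lets every call site that passes a temporary (which the rest of this diff converts to std::move) forward it with moves only; a copy happens just when the caller still needs its own object afterwards. A minimal sketch of the idiom with hypothetical names, using std::string as a cheap stand-in for HloSharding:

#include <iostream>
#include <string>
#include <utility>

class Instruction {
 public:
  // Sink argument: take by value, then move into the member. Callers that
  // pass an rvalue pay one move; callers that pass an lvalue pay one copy.
  void set_sharding(std::string sharding) { sharding_ = std::move(sharding); }
  const std::string& sharding() const { return sharding_; }

 private:
  std::string sharding_;
};

// Mirrors the new MaybeImproveInstructionSharding shape: the argument is
// consumed, so it is taken by value and moved along.
bool MaybeImprove(std::string sharding, Instruction* instruction) {
  if (instruction->sharding().empty()) {
    instruction->set_sharding(std::move(sharding));
    return true;
  }
  return false;
}

int main() {
  Instruction instr;
  std::string s = "{devices=[2,2]}";
  // The string is moved all the way into the member; no copy is made.
  MaybeImprove(std::move(s), &instr);
  std::cout << instr.sharding() << "\n";
}

The same reasoning applies to the get_maybe_tuple_sharding lambda later in this diff, which now takes its argument by value and moves it into the tuple it builds.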
@@ -620,7 +622,8 @@ bool InferShardingFromOperands(HloInstruction* instruction,
       HloSharding new_sharding = operand->sharding().GetSubSharding(
           operand->shape(), {instruction->tuple_index()});
       return MaybeImproveInstructionSharding(
-          new_sharding, instruction, /*may_combine_partial_sharding=*/is_spmd);
+          std::move(new_sharding), instruction,
+          /*may_combine_partial_sharding=*/is_spmd);
     }
     case HloOpcode::kTuple: {
       if (absl::c_none_of(instruction->operands(),
@@ -685,12 +688,12 @@ bool InferShardingFromOperands(HloInstruction* instruction,
         if (!IsSpatiallyPartitioned(operand)) {
           continue;
         }
-        auto get_maybe_tuple_sharding = [&](const HloSharding& sharding) {
+        auto get_maybe_tuple_sharding = [&](HloSharding sharding) {
           if (instruction->operand_count() == 2) {
             return sharding;
           }
           std::vector<HloSharding> tuple(instruction->operand_count() / 2,
-                                         sharding);
+                                         std::move(sharding));
           return HloSharding::Tuple(instruction->shape(), tuple);
         };
         if (operand->sharding().IsReplicated() ||
@@ -722,7 +725,7 @@ bool InferShardingFromOperands(HloInstruction* instruction,
             get_maybe_tuple_sharding(hlo_sharding_util::RemoveShapeDimensions(
                 after_partial_replication, instruction->dimensions()));
         changed |= MaybeImproveInstructionSharding(
-            new_sharding, instruction,
+            std::move(new_sharding), instruction,
             /*may_combine_partial_sharding=*/is_spmd);
       }
       return changed;
@@ -764,7 +767,8 @@ bool InferShardingFromOperands(HloInstruction* instruction,
               ? HloSharding::PartialTile(new_tile_assignment)
               : HloSharding::Tile(new_tile_assignment);
       return MaybeImproveInstructionSharding(
-          new_sharding, instruction, /*may_combine_partial_sharding=*/is_spmd);
+          std::move(new_sharding), instruction,
+          /*may_combine_partial_sharding=*/is_spmd);
     }
     case HloOpcode::kConvolution:
       return InferConvolutionShardingFromOperands(
@@ -778,7 +782,8 @@ bool InferShardingFromOperands(HloInstruction* instruction,
       HloSharding sharding = hlo_sharding_util::TransposeSharding(
           input->sharding(), instruction->dimensions());
       return MaybeImproveInstructionSharding(
-          sharding, instruction, /*may_combine_partial_sharding=*/is_spmd);
+          std::move(sharding), instruction,
+          /*may_combine_partial_sharding=*/is_spmd);
     }
     case HloOpcode::kReduceWindow: {
       const HloInstruction* lhs = instruction->operand(0);
@@ -831,7 +836,7 @@ bool InferShardingFromOperands(HloInstruction* instruction,
           instruction->operand(0)->sharding());
       if (new_sharding.has_value()) {
         return MaybeImproveInstructionSharding(
-            new_sharding.value(), instruction,
+            std::move(*new_sharding), instruction,
             /*may_combine_partial_sharding=*/is_spmd);
       }
       return false;
@@ -947,7 +952,7 @@ bool InferShardingFromOperands(HloInstruction* instruction,
         HloSharding new_sharding = hlo_sharding_util::GatherOutputSharding(
             instruction->operand(1)->sharding(), instruction);
         changed |= MaybeImproveInstructionSharding(
-            new_sharding, instruction,
+            std::move(new_sharding), instruction,
             /*may_combine_partial_sharding=*/is_spmd);
       }
       if (is_spmd && IsSpatiallyPartitioned(instruction->operand(0))) {
@@ -956,7 +961,7 @@ bool InferShardingFromOperands(HloInstruction* instruction,
             instruction->operand(0)->sharding(), *instruction);
         if (maybe_from_data) {
           changed |= MaybeImproveInstructionSharding(
-              *maybe_from_data, instruction,
+              std::move(*maybe_from_data), instruction,
              /*may_combine_partial_sharding=*/is_spmd);
         }
       }
@@ -979,7 +984,7 @@ bool InferShardingFromOperands(HloInstruction* instruction,
             instruction->operand(2)->sharding(), *instruction);
         if (maybe_from_update) {
           changed |= MaybeImproveInstructionSharding(
-              *maybe_from_update, instruction,
+              std::move(*maybe_from_update), instruction,
              /*may_combine_partial_sharding=*/is_spmd);
         }
       }
@@ -998,7 +1003,8 @@ bool InferShardingFromOperands(HloInstruction* instruction,
           MergeForMoreSpecificSharding(sharding, instruction->sharding());
       }
       return MaybeImproveInstructionSharding(
-          sharding, instruction, /*may_combine_partial_sharding=*/is_spmd);
+          std::move(sharding), instruction,
+          /*may_combine_partial_sharding=*/is_spmd);
     }
     default: {
       if (instruction->IsElementwise() && is_spmd) {
@@ -1089,12 +1095,14 @@ HloSharding InferDotOperandSharding(
       operand_to_other_dims[operand_index == 0 ? dim.lhs : dim.rhs] =
           operand_index == 0 ? dim.rhs : dim.lhs;
     }
-    sharding =
-        MergeSharding(sharding,
-                      *hlo_sharding_util::TransposeShardingWithCollapsedDims(
-                          other_operand_dims_replicated, other_to_operand_dims,
-                          operand_to_other_dims),
-                      may_combine_partial_sharding);
+    HloSharding sharding_from_other =
+        *hlo_sharding_util::TransposeShardingWithCollapsedDims(
+            other_operand_dims_replicated, other_to_operand_dims,
+            operand_to_other_dims);
+    if (MergeSharding(sharding, &sharding_from_other,
+                      may_combine_partial_sharding)) {
+      sharding = std::move(sharding_from_other);
+    }
   }
   return sharding;
 }
@@ -1376,7 +1384,7 @@ bool InferShardingFromUsers(HloInstruction* instruction,
         GetShardingFromUser(*instruction, *user, aggressive_prop, is_spmd);
     if (user_sharding) {
       improved_sharding |= MaybeImproveInstructionSharding(
-          *user_sharding, instruction,
+          std::move(*user_sharding), instruction,
           /*may_combine_partial_sharding=*/is_spmd);
     }
   }
@@ -1648,9 +1656,17 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
   // indefinitely.
   int64 iterations = 0;
   auto run_to_fix_point = [&](bool aggressive_prop) {
-    bool changed = true;
-    while (changed) {
-      changed = false;
+    absl::flat_hash_set<const HloInstruction*> workset;
+    for (const HloComputation* computation : module->computations()) {
+      for (const HloInstruction* instruction : computation->instructions()) {
+        // Remove the instructions where the sharding was provided from the
+        // outside so we don't modify them.
+        if (!provided_shardings.contains(instruction)) {
+          workset.insert(instruction);
+        }
+      }
+    }
+    while (!workset.empty()) {
       int64 inferred_from_operand_counter = 0;
       int64 inferred_from_user_counter = 0;
       int64 instruction_counter = 0;
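This hunk is the "workset" optimization: previously the fixpoint loop rescanned every instruction on every iteration even when nothing nearby had changed. The new code seeds a set with all instructions whose sharding was not provided from the outside, then (in the following hunks) drops instructions that yield no new sharding and re-queues only the users or operands of instructions that did change. A generic sketch of that worklist structure, independent of the XLA types:

#include <iostream>
#include <set>
#include <vector>

// Toy dataflow problem standing in for sharding inference: every node wants
// the max value seen among its operands.
struct Node {
  int value = 0;
  std::vector<int> operands;  // indices of upstream nodes
  std::vector<int> users;     // indices of downstream nodes
};

// Returns true if nodes[id] learned a better value from its operands,
// loosely mirroring InferShardingFromOperands.
bool InferFromOperands(std::vector<Node>& nodes, int id) {
  bool changed = false;
  for (int op : nodes[id].operands) {
    if (nodes[op].value > nodes[id].value) {
      nodes[id].value = nodes[op].value;
      changed = true;
    }
  }
  return changed;
}

void RunToFixPoint(std::vector<Node>& nodes) {
  // Seed the workset with every node (the real pass skips instructions
  // whose sharding was provided from the outside).
  std::set<int> workset;
  for (int i = 0; i < static_cast<int>(nodes.size()); ++i) workset.insert(i);

  while (!workset.empty()) {
    int id = *workset.begin();
    if (InferFromOperands(nodes, id)) {
      // A change here can only affect this node's users; queue just those.
      for (int user : nodes[id].users) workset.insert(user);
    } else {
      // Nothing new to learn until an operand changes; drop the node.
      workset.erase(id);
    }
  }
}

int main() {
  std::vector<Node> nodes(3);
  nodes[0] = {7, {}, {1}};
  nodes[1] = {0, {0}, {2}};
  nodes[2] = {0, {1}, {}};
  RunToFixPoint(nodes);
  std::cout << nodes[2].value << "\n";  // prints 7: the value propagated
}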
@@ -1664,12 +1680,10 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
         already_sharded_counter += (instruction->has_sharding() ? 1 : 0);
       }
 
-      // Remove the instructions where the sharding was provided from the
-      // outside so we don't modify them.
       instructions.erase(
           std::remove_if(instructions.begin(), instructions.end(),
                          [&](HloInstruction* instruction) {
-                           return provided_shardings.contains(instruction);
+                           return !workset.contains(instruction);
                          }),
           instructions.end());
 
@@ -1679,10 +1693,17 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
         if (InferShardingFromOperands(instruction, computation_map, is_spmd_,
                                       aggressive_prop)) {
           ++inferred_from_operand_counter;
-          changed = true;
+          any_changed = true;
           VLOG(2) << "Add sharding (forward-pass): "
                   << instruction->ToString();
           maybe_computation_propagation(instruction);
+          for (auto user : instruction->users()) {
+            if (!provided_shardings.contains(user)) {
+              workset.insert(user);
+            }
+          }
+        } else {
+          workset.erase(instruction);
         }
       }
 
@@ -1692,13 +1713,18 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
         if (InferShardingFromUsers(*it, computation_map, aggressive_prop,
                                    is_spmd_)) {
           ++inferred_from_user_counter;
-          changed = true;
+          any_changed = true;
           VLOG(2) << "Add sharding (backward-pass): " << (*it)->ToString();
           maybe_computation_propagation(*it);
+          workset.insert(*it);
+          for (auto operand : (*it)->operands()) {
+            if (!provided_shardings.contains(operand)) {
+              workset.insert(operand);
+            }
+          }
         }
       }
     }
-    any_changed |= changed;
     VLOG(1) << "Sharding propagation iteration " << iterations << ";";
     VLOG(1) << " total instructions: " << instruction_counter;
     VLOG(1) << " instructions already sharded: " << already_sharded_counter;
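Taken together, these two hunks keep the workset current in both directions: a successful forward-pass inference re-queues the instruction's users, a successful backward-pass inference re-queues the instruction itself along with its operands, and instructions that yield nothing new drop out of the set. The intended effect, per the commit title, is that later fixpoint iterations revisit only the frontier where shardings are still changing rather than rescanning every instruction in the module.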