[XLA:SPMD] Improve resharding

1. Fix a collective-permute bug: do not skip i -> i pairs, since skipping
   them would leave 0 in the output on those devices.
2. All-to-all resharding for divisible partition counts: e.g., resharding
   from [4,2] to [2,4] can be done as a subgroup all-to-all, since 4 % 2 == 0.
3. Multi-step all-to-all resharding. E.g., resharding from [16,8,1] to
   [1,16,8] can be done via an intermediate sharding [16,1,8].
4. Allow more ReshapeSharding cases.

PiperOrigin-RevId: 322729324
Change-Id: Ica2cf164e3c2bd15953ce37d6223723501be5b87
commit f59ff5d44e
parent 488448c742
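The core of items 2 and 3 is decomposing a resharding between same-rank tile assignments into pairwise dimension swaps, each lowered to one subgroup all-to-all. A minimal standalone sketch of that decomposition for the example in the message (illustrative only, not XLA code; the swap list shown is the one the new GetReshardAllToAllSourceTargetDims in the diff below produces for this input, as traced by hand):

```cpp
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

int main() {
  // Tile-assignment dimensions of the source sharding from the example.
  std::vector<int64_t> dims = {16, 8, 1};
  // One (source_dim, target_dim) entry per subgroup all-to-all step.
  const std::vector<std::pair<int, int>> swaps = {{1, 2}, {0, 1}};
  for (const auto& [src, tgt] : swaps) {
    std::swap(dims[src], dims[tgt]);  // one all-to-all moves one dim's shards
    std::cout << '[' << dims[0] << ',' << dims[1] << ',' << dims[2] << "]\n";
  }
  // Prints [16,1,8] (the intermediate sharding) and then [1,16,8].
}
```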
@@ -190,13 +190,22 @@ absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
       target_dims_stack.push_back(t_size);
     } else if (s_size > t_size) {
       // Dimension split.
-      if (s_size % t_size != 0 || t_size % s_partitions != 0) {
+      if (s_size % t_size != 0 || s_size % s_partitions != 0) {
         return absl::nullopt;
       }
+      if (t_size % s_partitions == 0) {
+        target_tile_assignment_dimensions.push_back(s_partitions);
+        // We have part of the s_size unprocessed, so put it back to stack.
+        source_dims_stack.push_back(s_size / t_size);
+        sharding_tile_dims_stack.push_back(1);
+      } else if (s_partitions % t_size == 0) {
+        target_tile_assignment_dimensions.push_back(t_size);
+        // We have part of the s_size unprocessed, so put it back to stack.
+        source_dims_stack.push_back(s_size / t_size);
+        sharding_tile_dims_stack.push_back(s_partitions / t_size);
+      } else {
+        return absl::nullopt;
+      }
-      target_tile_assignment_dimensions.push_back(s_partitions);
-      // We have part of the s_size unprocessed, so put it back to stack.
-      source_dims_stack.push_back(s_size / t_size);
-      sharding_tile_dims_stack.push_back(1);
     } else {
       // Dimension merge. Also merge the source dimension with the next, and
       // process it next time.
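The two new divisibility branches decide how many shards the leading output dimension of a split keeps. A standalone sketch of just that decision (function name is mine, not the XLA API):

```cpp
#include <cstdint>
#include <optional>

// Given a source dim of size s_size sharded s_partitions ways that is being
// reshaped into a leading output dim of size t_size, return how many shards
// the leading output dim gets, or nullopt if the split is unsupported.
std::optional<int64_t> LeadingSplitDimPartitions(int64_t s_size, int64_t t_size,
                                                 int64_t s_partitions) {
  if (s_size % t_size != 0 || s_size % s_partitions != 0) return std::nullopt;
  if (t_size % s_partitions == 0) {
    // All shards fit inside the leading dim, e.g. 16 -> (4,4) with 2 shards.
    return s_partitions;
  }
  if (s_partitions % t_size == 0) {
    // Shards span the split, e.g. 16 -> (4,4) with 16 shards: the leading dim
    // is fully sharded (4 ways) and 16 / 4 shards remain for the rest.
    return t_size;
  }
  return std::nullopt;  // e.g. 12 -> (4,3) with 6 shards: neither divides.
}
```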
@@ -76,6 +76,20 @@ TEST(HloShardingUtilTest, ReshapeShardingTiledSplit) {
   EXPECT_EQ(result.value(), output_sharding);
 }
 
+TEST(HloShardingUtilTest, ReshapeShardingTiledSplit2) {
+  Shape input_shape = ShapeUtil::MakeShape(F32, {16, 7});
+  Shape output_shape = ShapeUtil::MakeShape(F32, {4, 4, 7});
+  Array2D<int64> tile(16, 1);
+  tile.FillIota(0);
+  HloSharding input_sharding = HloSharding::Tile(tile);
+  tile.Reshape({4, 4, 1});
+  HloSharding output_sharding = HloSharding::Tile(tile);
+  absl::optional<HloSharding> result =
+      ReshapeSharding(input_shape, output_shape, input_sharding);
+  EXPECT_TRUE(result.has_value());
+  EXPECT_EQ(result.value(), output_sharding);
+}
+
 TEST(HloShardingUtilTest, ReshapeShardingTiledSplitThenMerge) {
   Shape input_shape = ShapeUtil::MakeShape(F32, {16, 4, 7});
   Shape output_shape = ShapeUtil::MakeShape(F32, {4, 16, 7});
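What ReshapeShardingTiledSplit2 exercises, in plain arithmetic: a 16-way sharded dimension reshaped as 16 = 4 * 4 keeps the iota device order; only the tile-assignment shape changes from {16,1} to {4,4,1}. A tiny standalone sketch of that mapping (not using XLA's Array, row-major layout assumed):

```cpp
#include <cstdio>

int main() {
  // FillIota(0) assigns devices 0..15 in row-major order; after reshaping the
  // tile assignment to {4,4,1}, device d owns tile (d / 4, d % 4, 0).
  for (int d = 0; d < 16; ++d) {
    std::printf("device %2d -> tile (%d, %d, 0)\n", d, d / 4, d % 4);
  }
}
```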
@@ -267,8 +267,7 @@ PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) {
 
   if (auto src_tgt_dims =
           GetReshardAllToAllSourceTargetDims(sharding(), target)) {
-    return ReshardWithAllToAll(target, src_tgt_dims->first,
-                               src_tgt_dims->second);
+    return ReshardWithAllToAll(target, *src_tgt_dims);
   }
 
   // If not replicated yet, first replicate and then reshard to use one of the
@@ -734,40 +733,82 @@ PartitionedHlo PartitionedHlo::Broadcast() const {
   return PartitionedHlo(result, base_shape_, state_);
 }
 
-PartitionedHlo PartitionedHlo::ReshardWithAllToAll(const HloSharding& target,
-                                                   int64 source_dim,
-                                                   int64 target_dim) const {
-  const int64 group_size = sharding().tile_assignment().dim(source_dim);
-
-  // If the device order is different in the target, fix the order with
-  // ReshardWithCollectivePermute.
-  std::vector<int64> xpose_dims(target.tile_assignment().num_dimensions());
-  std::iota(xpose_dims.begin(), xpose_dims.end(), 0);
-  xpose_dims[source_dim] = target_dim;
-  xpose_dims[target_dim] = source_dim;
-  auto input_sharding_fixed_device_order =
-      hlo_sharding_util::TransposeSharding(target, xpose_dims);
-  if (input_sharding_fixed_device_order != sharding()) {
-    auto fixed_order =
-        ReshardWithCollectivePermute(input_sharding_fixed_device_order);
-    return fixed_order.ReshardWithAllToAll(target, source_dim, target_dim);
+PartitionedHlo PartitionedHlo::ReshardWithAllToAll(
+    const HloSharding& target,
+    absl::Span<const std::pair<int64, int64>> source_target_dims) const {
+  if (source_target_dims.empty()) {
+    if (target == sharding()) {
+      return *this;
+    }
+    // If the device order is different in the target, fix the order with
+    // ReshardWithCollectivePermute.
+    return ReshardWithCollectivePermute(target);
   }
 
-  auto padded_hlo =
-      PadBaseShapeBeforeUnevenTiledSharding(hlo_, target, state_.b);
+  // Swap one pair of dimensions.
+  int64 source_dim = source_target_dims[0].first;
+  int64 target_dim = source_target_dims[0].second;
+  const int64 group_size = sharding().tile_assignment().dim(source_dim) /
+                           sharding().tile_assignment().dim(target_dim);
 
-  // The order of ids in the group must follow the target sharding.
-  std::vector<ReplicaGroup> groups(target.tile_assignment().num_elements() /
-                                   group_size);
-  target.tile_assignment().Each(
+  auto temp_target_tile = sharding().tile_assignment();
+  {
+    std::vector<int64> reshape_tile_dims(temp_target_tile.num_dimensions() + 2);
+    int64 i = 0;
+    int64 added_source_dim = -1;
+    int64 added_target_dim = -1;
+    for (int64 j = 0; j < temp_target_tile.num_dimensions(); ++j) {
+      if (source_dim == j) {
+        reshape_tile_dims[i] = temp_target_tile.dim(j) / group_size;
+        reshape_tile_dims[++i] = group_size;
+        added_source_dim = i;
+      } else if (target_dim == j) {
+        reshape_tile_dims[i] = temp_target_tile.dim(j);
+        reshape_tile_dims[++i] = 1;
+        added_target_dim = i;
+      } else {
+        reshape_tile_dims[i] = temp_target_tile.dim(j);
+      }
+      ++i;
+    }
+    temp_target_tile.Reshape(reshape_tile_dims);
+    std::vector<int64> xpose_dims(temp_target_tile.num_dimensions());
+    std::iota(xpose_dims.begin(), xpose_dims.end(), 0);
+    xpose_dims[added_source_dim] = added_target_dim;
+    xpose_dims[added_target_dim] = added_source_dim;
+    temp_target_tile = hlo_sharding_util::TransposeSharding(
+                           HloSharding::Tile(temp_target_tile), xpose_dims)
+                           .tile_assignment();
+    auto temp_target_tile_dims = sharding().tile_assignment().dimensions();
+    temp_target_tile_dims[source_dim] =
+        sharding().tile_assignment().dim(target_dim);
+    temp_target_tile_dims[target_dim] =
+        sharding().tile_assignment().dim(source_dim);
+    temp_target_tile.Reshape(temp_target_tile_dims);
+  }
+  auto temp_target = HloSharding::Tile(temp_target_tile);
+
+  auto padded_shape = hlo_->shape();
+  padded_shape.set_dimensions(
+      target_dim,
+      RoundUpToNearest(padded_shape.dimensions(target_dim),
+                       temp_target.tile_assignment().dim(target_dim)));
+  auto padded_hlo = PadToShape(hlo_, padded_shape, state_.b);
+
+  // The order of ids in the group must follow the temp_target sharding.
+  std::vector<ReplicaGroup> groups(
+      temp_target.tile_assignment().num_elements() / group_size);
+  temp_target.tile_assignment().Each(
       [&](absl::Span<const int64> indices, int64 device) {
         int64 group_id = 0;
         for (int64 dim = 0; dim < indices.size(); ++dim) {
           if (dim == target_dim) {
-            continue;
+            group_id *= temp_target.tile_assignment().dim(dim) / group_size;
+            group_id += indices[dim] / group_size;
+          } else {
+            group_id *= temp_target.tile_assignment().dim(dim);
+            group_id += indices[dim];
           }
-          group_id *= target.tile_assignment().dim(dim);
-          group_id += indices[dim];
         }
         groups[group_id].add_replica_ids(device);
       });
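The group-id computation above is a mixed-radix index over all tile dimensions, where target_dim contributes its position divided by group_size: devices that differ only within a group_size-sized window along target_dim land in the same all-to-all group. A standalone sketch on a hand-picked 4x2 example (row-major device layout is my assumption; not XLA code):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int64_t dims[2] = {4, 2};            // temp_target tile: 4 x 2 devices
  const int target_dim = 0, group_size = 2;  // swap window along dim 0
  std::vector<std::vector<int64_t>> groups(dims[0] * dims[1] / group_size);
  for (int64_t d = 0; d < dims[0] * dims[1]; ++d) {
    const int64_t idx[2] = {d / dims[1], d % dims[1]};  // row-major indices
    int64_t group_id = 0;
    for (int dim = 0; dim < 2; ++dim) {
      if (dim == target_dim) {
        group_id = group_id * (dims[dim] / group_size) + idx[dim] / group_size;
      } else {
        group_id = group_id * dims[dim] + idx[dim];
      }
    }
    groups[group_id].push_back(d);
  }
  // Prints {0,2} {1,3} {4,6} {5,7}: each group is one subgroup all-to-all.
  for (size_t g = 0; g < groups.size(); ++g) {
    std::printf("group %zu:", g);
    for (int64_t d : groups[g]) std::printf(" %lld", static_cast<long long>(d));
    std::printf("\n");
  }
}
```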
@@ -819,14 +860,17 @@ PartitionedHlo PartitionedHlo::ReshardWithAllToAll(const HloSharding& target,
   result = state_.b->AddInstruction(
       HloInstruction::CreateReshape(new_shape, transpose));
 
-  const Shape result_shape = MakePartitionedShape(base_shape_, target);
+  const Shape result_shape = MakePartitionedShape(base_shape_, temp_target);
   if (result_shape != result->shape()) {
     result = state_.b->AddInstruction(HloInstruction::CreateSlice(
         result_shape, result, std::vector<int64>(result_shape.rank(), 0),
         result_shape.dimensions(), std::vector<int64>(result_shape.rank(), 1)));
   }
-  result->set_sharding(target);
-  return PartitionedHlo(result, base_shape_, state_);
+  result->set_sharding(temp_target);
+  auto remaining_source_target_dims = source_target_dims;
+  remaining_source_target_dims.remove_prefix(1);
+  return PartitionedHlo(result, base_shape_, state_)
+      .ReshardWithAllToAll(target, remaining_source_target_dims);
 }
 
 PartitionedHlo PartitionedHlo::ReshardWithCollectivePermute(
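The recursion at the end peels the handled pair off the span and re-enters with the rest; the empty-span base case above then finishes with a collective-permute if the device order still differs. A sketch of that control flow, with C++20 std::span standing in for absl::Span (whose remove_prefix plays the same role):

```cpp
#include <cstdio>
#include <span>
#include <utility>

void ReshardStep(std::span<const std::pair<int, int>> dims) {
  if (dims.empty()) {
    std::printf("base case: collective-permute if order differs, else done\n");
    return;
  }
  std::printf("all-to-all swapping dims %d and %d\n", dims.front().first,
              dims.front().second);
  ReshardStep(dims.subspan(1));  // absl::Span would use remove_prefix(1)
}

int main() {
  const std::pair<int, int> steps[] = {{1, 2}, {0, 1}};
  ReshardStep(steps);
}
```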
@@ -837,9 +881,7 @@ PartitionedHlo PartitionedHlo::ReshardWithCollectivePermute(
   sharding().tile_assignment().Each(
       [&](absl::Span<const int64> indices, int64 src_device) {
         int64 dst_device = target.tile_assignment()(indices);
-        if (dst_device != src_device) {
-          src_dst_pairs.emplace_back(src_device, dst_device);
-        }
+        src_dst_pairs.emplace_back(src_device, dst_device);
       });
   auto cp =
       state_.collective_ops_creator.create_cross_partition_collective_permute(
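This is fix 1 from the commit message: collective-permute zero-fills any destination that no source sends to, so dropping the i -> i pairs silently zeroes the shards on devices that keep their own data. A toy simulation of that semantics (not XLA code):

```cpp
#include <cstdio>
#include <utility>
#include <vector>

int main() {
  const std::vector<int> shard = {10, 11, 12, 13};  // one value per device
  // Buggy pair list: {2,2} and {3,3} were skipped because src == dst.
  const std::vector<std::pair<int, int>> pairs = {{0, 1}, {1, 0}};
  std::vector<int> out(shard.size(), 0);  // unaddressed outputs stay 0
  for (const auto& [src, dst] : pairs) out[dst] = shard[src];
  for (int v : out) std::printf("%d ", v);  // "11 10 0 0", not "11 10 12 13"
  std::printf("\n");
}
```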
@@ -284,8 +284,9 @@ class PartitionedHlo {
 
   // Helper function to reshard the tensor using AllToAll (instead of the
   // default of Replicate followed by Slice).
-  PartitionedHlo ReshardWithAllToAll(const HloSharding& target,
-                                     int64 source_dim, int64 target_dim) const;
+  PartitionedHlo ReshardWithAllToAll(
+      const HloSharding& target,
+      absl::Span<const std::pair<int64, int64>> source_target_dims) const;
 
   // Helper function to reshard the tensor using CollectivePermute.
   PartitionedHlo ReshardWithCollectivePermute(const HloSharding& target) const;
@@ -3792,6 +3792,56 @@ ENTRY entry {
           4);
 }
 
+TEST_F(SpmdPartitioningTest, SubgroupAllToAllReshard2) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %param0 = f32[8,8] parameter(0),
+    sharding={devices=[2,4]0,1,2,3,4,5,6,7}
+  ROOT %copy = f32[8,8] copy(%param0),
+    sharding={devices=[4,2]0,1,4,5,2,3,6,7}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/8));
+  VLOG(1) << module->ToString();
+
+  auto root = module->entry_computation()->root_instruction();
+  auto all_to_all = op::AllToAll(
+      AllOf(op::Shape("f32[2,2,2]"), op::Reshape(op::Parameter(0))));
+  auto reshape =
+      AllOf(op::Shape("f32[2,4]"), op::Reshape(op::Transpose(all_to_all)));
+  EXPECT_THAT(root, op::Copy(op::CollectivePermute(reshape)));
+}
+
+TEST_F(SpmdPartitioningTest, SubgroupAllToAllReshard3) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %param0 = f32[8,8,8] parameter(0),
+    sharding={devices=[2,4,1]0,1,2,3,4,5,6,7}
+  ROOT %copy = f32[8,8,8] copy(%param0),
+    sharding={devices=[1,2,4]0,1,4,5,2,3,6,7}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/8));
+  VLOG(1) << module->ToString();
+
+  auto root = module->entry_computation()->root_instruction();
+  auto all_to_all = op::AllToAll(
+      AllOf(op::Shape("f32[4,2,4,2]"), op::Reshape(op::Parameter(0))));
+  auto reshape =
+      AllOf(op::Shape("f32[4,8,2]"), op::Reshape(op::Transpose(all_to_all)));
+  auto all_to_all2 =
+      op::AllToAll(AllOf(op::Shape("f32[4,2,4,2]"), op::Reshape(reshape)));
+  auto reshape2 =
+      AllOf(op::Shape("f32[8,4,2]"), op::Reshape(op::Transpose(all_to_all2)));
+  EXPECT_THAT(root, op::Copy(op::CollectivePermute(reshape2)));
+}
+
 }  // namespace
 }  // namespace spmd
 }  // namespace xla
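Reading the expectations (my own trace of the shapes): in SubgroupAllToAllReshard2, the per-device shard of f32[8,8] under [2,4] is f32[4,2]; the single swap has group_size 4 / 2 = 2, so the shard's first dimension is reshaped into 2 x 2 for the f32[2,2,2] all-to-all operand, and a trailing collective-permute fixes the device order. SubgroupAllToAllReshard3 takes the multi-step path, [2,4,1] -> [2,1,4] -> [1,2,4], hence the two all-to-alls with group sizes 4 and 2.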
@@ -885,37 +885,89 @@ int64 ShardCountAtDim(const HloSharding& sharding, int64 dim) {
   return sharding.tile_assignment().dim(dim);
 }
 
-absl::optional<std::pair<int64, int64>> GetReshardAllToAllSourceTargetDims(
-    const HloSharding& source, const HloSharding& target) {
+absl::optional<std::vector<std::pair<int64, int64>>>
+GetReshardAllToAllSourceTargetDims(const HloSharding& source,
+                                   const HloSharding& target) {
   if (source.IsTileMaximal() || target.IsTileMaximal() ||
       source.tile_assignment().num_dimensions() !=
           target.tile_assignment().num_dimensions()) {
     return absl::nullopt;
   }
-  int64 source_dim = -1;
-  int64 target_dim = -1;
+  // Record partition count to index for indices that have different partition
+  // counts on source and target.
+  std::map<int64, std::vector<int64>> source_size_to_dim;
+  std::map<int64, std::vector<int64>> target_size_to_dim;
   for (int64 i = 0; i < source.tile_assignment().num_dimensions(); ++i) {
-    if (source.tile_assignment().dim(i) > 1 &&
-        target.tile_assignment().dim(i) == 1) {
-      if (source_dim != -1) {
-        return absl::nullopt;
-      }
-      source_dim = i;
-    } else if (source.tile_assignment().dim(i) == 1 &&
-               target.tile_assignment().dim(i) > 1) {
-      if (target_dim != -1) {
-        return absl::nullopt;
-      }
-      target_dim = i;
-    } else if (source.tile_assignment().dim(i) !=
-               target.tile_assignment().dim(i)) {
-      return absl::nullopt;
-    }
+    if (source.tile_assignment().dim(i) == target.tile_assignment().dim(i)) {
+      continue;
+    }
+    source_size_to_dim[source.tile_assignment().dim(i)].push_back(i);
+    target_size_to_dim[target.tile_assignment().dim(i)].push_back(i);
   }
-  if (source_dim == -1 || target_dim == -1 || source_dim == target_dim) {
+  // In order to shard via AllToAll, source_size_to_dim and target_size_to_dim
+  // must have the same distribution.
+  if (source_size_to_dim.empty() ||
+      source_size_to_dim.size() != target_size_to_dim.size()) {
     return absl::nullopt;
   }
-  return std::pair<int64, int64>(source_dim, target_dim);
+  for (const auto& entry : source_size_to_dim) {
+    auto target_it = target_size_to_dim.find(entry.first);
+    if (target_it == target_size_to_dim.end() ||
+        target_it->second.size() != entry.second.size()) {
+      return absl::nullopt;
+    }
+  }
+  std::vector<std::pair<int64, int64>> result;
+  auto remove_entry = [](int64 size, int64 dim,
+                         std::map<int64, std::vector<int64>>& size_to_dim) {
+    size_to_dim[size].erase(
+        std::remove_if(size_to_dim[size].begin(), size_to_dim[size].end(),
+                       [dim](int64 a) { return a == dim; }),
+        size_to_dim[size].end());
+    if (size_to_dim[size].empty()) {
+      size_to_dim.erase(size);
+    }
+  };
+  // Find one pair of dimensions to swap at a time.
+  while (!source_size_to_dim.empty()) {
+    int64 source_size = source_size_to_dim.begin()->first;
+    int64 i = source_size_to_dim.begin()->second.back();
+    int64 target_i_size = target.tile_assignment().dim(i);
+    if (target_i_size == source_size) {
+      remove_entry(source_size, i, source_size_to_dim);
+      remove_entry(source_size, i, target_size_to_dim);
+      continue;
+    }
+    auto j_it = source_size_to_dim[target_i_size].begin();
+    int64 j = *j_it;
+    if (source_size == 1) {
+      // If possible, find a j where the target partition count is not one, so
+      // that when we swap, the resulting size-1 dimension will still be useful
+      // to other dimensions.
+      while (target.tile_assignment().dim(j) == 1) {
+        if (++j_it == source_size_to_dim[target_i_size].end()) {
+          break;
+        }
+        j = *j_it;
+      }
+    } else if (target_i_size % source_size == 0) {
+      // If possible, find a j where the target partition count is source_size,
+      // so that we can do a single swap.
+      while (target.tile_assignment().dim(j) != source_size) {
+        if (++j_it == source_size_to_dim[target_i_size].end()) {
+          break;
+        }
+        j = *j_it;
+      }
+    } else {
+      return absl::nullopt;
+    }
+    result.emplace_back(j, i);
+    remove_entry(target_i_size, i, target_size_to_dim);
+    source_size_to_dim.begin()->second.back() = j;
+    remove_entry(target_i_size, j, source_size_to_dim);
+  }
+  return result;
 }
 
 bool CanReshardWithCollectivePermute(const HloSharding& source,
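The greedy loop above pairs each dimension that must change its shard count with one that currently holds the count it needs. A trimmed standalone re-implementation for intuition (my own simplification: it assumes the target counts are a permutation of the source's, and it may pick different intermediates than the XLA heuristics above):

```cpp
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

std::vector<std::pair<int, int>> PairSwaps(std::vector<int64_t> src,
                                           const std::vector<int64_t>& tgt) {
  std::vector<std::pair<int, int>> result;
  for (bool changed = true; changed;) {
    changed = false;
    for (int i = 0; i < static_cast<int>(src.size()); ++i) {
      if (src[i] == tgt[i]) continue;
      // Find another dim j currently holding the shard count dim i needs.
      for (int j = 0; j < static_cast<int>(src.size()); ++j) {
        if (j != i && src[j] == tgt[i]) {
          result.emplace_back(j, i);  // swap source dim j into dim i
          std::swap(src[j], src[i]);
          changed = true;
          break;
        }
      }
    }
  }
  return result;
}

int main() {
  // Two swaps realize [16,8,1] -> [1,16,8], as in the commit message example.
  for (const auto& [s, t] : PairSwaps({16, 8, 1}, {1, 16, 8})) {
    std::cout << "swap dims (" << s << ", " << t << ")\n";
  }
}
```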
@@ -265,10 +265,12 @@ HloInstruction* SliceFirstK(HloInstruction* hlo, SpmdBuilder* builder,
 // Check if a dimension is sharded.
 int64 ShardCountAtDim(const HloSharding& sharding, int64 dim);
 
-// Returns the pair of source and target dimensions is the resharding can be
-// done via all-to-all.
-absl::optional<std::pair<int64, int64>> GetReshardAllToAllSourceTargetDims(
-    const HloSharding& source, const HloSharding& target);
+// Returns the list of source-target pairs of dimensions to swap during
+// resharding via all-to-all. Reshard can be done by swapping each pair at a
+// time.
+absl::optional<std::vector<std::pair<int64, int64>>>
+GetReshardAllToAllSourceTargetDims(const HloSharding& source,
+                                   const HloSharding& target);
 
 // Returns whether the resharding can be done via collective-permute.
 bool CanReshardWithCollectivePermute(const HloSharding& source,