[XLA:SPMD] Improve resharding

1. Fix a collective-permute bug: do not skip i -> i (identity) pairs, since a device that is not the target of any pair receives 0 in the output.
2. All-to-all resharding for divisible partition counts, e.g., [4,2] to [2,4] can be done as a subgroup all-to-all, since 4 % 2 == 0 (see the sketch after this list).
3. Multi-step all-to-all resharding, e.g., resharding from [16,8,1] to [1,16,8] can be done via the intermediate sharding [16,1,8].
4. Allow more ReshapeSharding cases.
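
As a rough illustration of items 2 and 3 (not part of this change, and not XLA API), the standalone sketch below decomposes a reshard into per-dimension-pair swaps of partition counts. Each swap is legal when the larger count is divisible by the smaller one, and corresponds to one subgroup all-to-all with group_size equal to their quotient. The helper name and the hard-coded swap schedules are hypothetical, chosen only to reproduce the examples above.

  // Standalone illustration only: decompose a reshard into per-dimension swaps.
  #include <algorithm>
  #include <cstdio>
  #include <utility>
  #include <vector>

  // Hypothetical helper: applies a swap schedule to tile-assignment dimensions
  // and prints each intermediate sharding.
  void ApplySwaps(std::vector<int> dims,
                  const std::vector<std::pair<int, int>>& swaps) {
    for (const auto& [a, b] : swaps) {
      int larger = std::max(dims[a], dims[b]);
      int smaller = std::min(dims[a], dims[b]);
      if (larger % smaller != 0) {
        std::printf("cannot reshard with a subgroup all-to-all\n");
        return;
      }
      std::swap(dims[a], dims[b]);  // one subgroup all-to-all per swap
      std::printf("group_size=%d -> [", larger / smaller);
      for (size_t i = 0; i < dims.size(); ++i) {
        std::printf(i ? ",%d" : "%d", dims[i]);
      }
      std::printf("]\n");
    }
  }

  int main() {
    // Item 2: [4,2] -> [2,4]; 4 % 2 == 0, so a single swap (group_size=2) works.
    ApplySwaps({4, 2}, {{0, 1}});
    // Item 3: [16,8,1] -> [1,16,8] via the intermediate sharding [16,1,8].
    ApplySwaps({16, 8, 1}, {{1, 2}, {0, 1}});
    return 0;
  }

The second call prints [16,1,8] and then [1,16,8], which is the intermediate step named in item 3. In the actual change, the decomposition is computed by GetReshardAllToAllSourceTargetDims and executed one pair at a time by ReshardWithAllToAll (see the diffs below).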

PiperOrigin-RevId: 322729324
Change-Id: Ica2cf164e3c2bd15953ce37d6223723501be5b87
Yuanzhong Xu, 2020-07-22 23:12:08 -07:00 (committed by TensorFlower Gardener)
commit f59ff5d44e, parent 488448c742
7 changed files with 236 additions and 66 deletions


@@ -190,13 +190,22 @@ absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
       target_dims_stack.push_back(t_size);
     } else if (s_size > t_size) {
       // Dimension split.
-      if (s_size % t_size != 0 || t_size % s_partitions != 0) {
+      if (s_size % t_size != 0 || s_size % s_partitions != 0) {
         return absl::nullopt;
       }
+      if (t_size % s_partitions == 0) {
+        target_tile_assignment_dimensions.push_back(s_partitions);
+        // We have part of the s_size unprocessed, so put it back to stack.
+        source_dims_stack.push_back(s_size / t_size);
+        sharding_tile_dims_stack.push_back(1);
+      } else if (s_partitions % t_size == 0) {
+        target_tile_assignment_dimensions.push_back(t_size);
+        // We have part of the s_size unprocessed, so put it back to stack.
+        source_dims_stack.push_back(s_size / t_size);
+        sharding_tile_dims_stack.push_back(s_partitions / t_size);
+      } else {
+        return absl::nullopt;
+      }
-      target_tile_assignment_dimensions.push_back(s_partitions);
-      // We have part of the s_size unprocessed, so put it back to stack.
-      source_dims_stack.push_back(s_size / t_size);
-      sharding_tile_dims_stack.push_back(1);
     } else {
       // Dimension merge. Also merge the source dimension with the next, and
       // process it next time.


@@ -76,6 +76,20 @@ TEST(HloShardingUtilTest, ReshapeShardingTiledSplit) {
   EXPECT_EQ(result.value(), output_sharding);
 }
 
+TEST(HloShardingUtilTest, ReshapeShardingTiledSplit2) {
+  Shape input_shape = ShapeUtil::MakeShape(F32, {16, 7});
+  Shape output_shape = ShapeUtil::MakeShape(F32, {4, 4, 7});
+  Array2D<int64> tile(16, 1);
+  tile.FillIota(0);
+  HloSharding input_sharding = HloSharding::Tile(tile);
+  tile.Reshape({4, 4, 1});
+  HloSharding output_sharding = HloSharding::Tile(tile);
+  absl::optional<HloSharding> result =
+      ReshapeSharding(input_shape, output_shape, input_sharding);
+  EXPECT_TRUE(result.has_value());
+  EXPECT_EQ(result.value(), output_sharding);
+}
+
 TEST(HloShardingUtilTest, ReshapeShardingTiledSplitThenMerge) {
   Shape input_shape = ShapeUtil::MakeShape(F32, {16, 4, 7});
   Shape output_shape = ShapeUtil::MakeShape(F32, {4, 16, 7});


@@ -267,8 +267,7 @@ PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) {
   if (auto src_tgt_dims =
           GetReshardAllToAllSourceTargetDims(sharding(), target)) {
-    return ReshardWithAllToAll(target, src_tgt_dims->first,
-                               src_tgt_dims->second);
+    return ReshardWithAllToAll(target, *src_tgt_dims);
   }
 
   // If not replicated yet, first replicate and then reshard to use one of the
@@ -734,40 +733,82 @@ PartitionedHlo PartitionedHlo::Broadcast() const {
   return PartitionedHlo(result, base_shape_, state_);
 }
 
-PartitionedHlo PartitionedHlo::ReshardWithAllToAll(const HloSharding& target,
-                                                   int64 source_dim,
-                                                   int64 target_dim) const {
-  const int64 group_size = sharding().tile_assignment().dim(source_dim);
-  // If the device order is different in the target, fix the order with
-  // ReshardWithCollectivePermute.
-  std::vector<int64> xpose_dims(target.tile_assignment().num_dimensions());
-  std::iota(xpose_dims.begin(), xpose_dims.end(), 0);
-  xpose_dims[source_dim] = target_dim;
-  xpose_dims[target_dim] = source_dim;
-  auto input_sharding_fixed_device_order =
-      hlo_sharding_util::TransposeSharding(target, xpose_dims);
-  if (input_sharding_fixed_device_order != sharding()) {
-    auto fixed_order =
-        ReshardWithCollectivePermute(input_sharding_fixed_device_order);
-    return fixed_order.ReshardWithAllToAll(target, source_dim, target_dim);
+PartitionedHlo PartitionedHlo::ReshardWithAllToAll(
+    const HloSharding& target,
+    absl::Span<const std::pair<int64, int64>> source_target_dims) const {
+  if (source_target_dims.empty()) {
+    if (target == sharding()) {
+      return *this;
+    }
+    // If the device order is different in the target, fix the order with
+    // ReshardWithCollectivePermute.
+    return ReshardWithCollectivePermute(target);
   }
-  auto padded_hlo =
-      PadBaseShapeBeforeUnevenTiledSharding(hlo_, target, state_.b);
+
+  // Swap one pair of dimensions.
+  int64 source_dim = source_target_dims[0].first;
+  int64 target_dim = source_target_dims[0].second;
+  const int64 group_size = sharding().tile_assignment().dim(source_dim) /
+                           sharding().tile_assignment().dim(target_dim);
-  // The order of ids in the group must follow the target sharding.
-  std::vector<ReplicaGroup> groups(target.tile_assignment().num_elements() /
-                                   group_size);
-  target.tile_assignment().Each(
+
+  auto temp_target_tile = sharding().tile_assignment();
+  {
+    std::vector<int64> reshape_tile_dims(temp_target_tile.num_dimensions() + 2);
+    int64 i = 0;
+    int64 added_source_dim = -1;
+    int64 added_target_dim = -1;
+    for (int64 j = 0; j < temp_target_tile.num_dimensions(); ++j) {
+      if (source_dim == j) {
+        reshape_tile_dims[i] = temp_target_tile.dim(j) / group_size;
+        reshape_tile_dims[++i] = group_size;
+        added_source_dim = i;
+      } else if (target_dim == j) {
+        reshape_tile_dims[i] = temp_target_tile.dim(j);
+        reshape_tile_dims[++i] = 1;
+        added_target_dim = i;
+      } else {
+        reshape_tile_dims[i] = temp_target_tile.dim(j);
+      }
+      ++i;
+    }
+    temp_target_tile.Reshape(reshape_tile_dims);
+    std::vector<int64> xpose_dims(temp_target_tile.num_dimensions());
+    std::iota(xpose_dims.begin(), xpose_dims.end(), 0);
+    xpose_dims[added_source_dim] = added_target_dim;
+    xpose_dims[added_target_dim] = added_source_dim;
+    temp_target_tile = hlo_sharding_util::TransposeSharding(
+                           HloSharding::Tile(temp_target_tile), xpose_dims)
+                           .tile_assignment();
+    auto temp_target_tile_dims = sharding().tile_assignment().dimensions();
+    temp_target_tile_dims[source_dim] =
+        sharding().tile_assignment().dim(target_dim);
+    temp_target_tile_dims[target_dim] =
+        sharding().tile_assignment().dim(source_dim);
+    temp_target_tile.Reshape(temp_target_tile_dims);
+  }
+  auto temp_target = HloSharding::Tile(temp_target_tile);
+
+  auto padded_shape = hlo_->shape();
+  padded_shape.set_dimensions(
+      target_dim,
+      RoundUpToNearest(padded_shape.dimensions(target_dim),
+                       temp_target.tile_assignment().dim(target_dim)));
+  auto padded_hlo = PadToShape(hlo_, padded_shape, state_.b);
+
+  // The order of ids in the group must follow the temp_target sharding.
+  std::vector<ReplicaGroup> groups(
+      temp_target.tile_assignment().num_elements() / group_size);
+  temp_target.tile_assignment().Each(
       [&](absl::Span<const int64> indices, int64 device) {
         int64 group_id = 0;
         for (int64 dim = 0; dim < indices.size(); ++dim) {
           if (dim == target_dim) {
-            continue;
+            group_id *= temp_target.tile_assignment().dim(dim) / group_size;
+            group_id += indices[dim] / group_size;
+          } else {
+            group_id *= temp_target.tile_assignment().dim(dim);
+            group_id += indices[dim];
           }
-          group_id *= target.tile_assignment().dim(dim);
-          group_id += indices[dim];
         }
         groups[group_id].add_replica_ids(device);
       });
@@ -819,14 +860,17 @@ PartitionedHlo PartitionedHlo::ReshardWithAllToAll(const HloSharding& target,
   result = state_.b->AddInstruction(
       HloInstruction::CreateReshape(new_shape, transpose));
 
-  const Shape result_shape = MakePartitionedShape(base_shape_, target);
+  const Shape result_shape = MakePartitionedShape(base_shape_, temp_target);
   if (result_shape != result->shape()) {
     result = state_.b->AddInstruction(HloInstruction::CreateSlice(
         result_shape, result, std::vector<int64>(result_shape.rank(), 0),
         result_shape.dimensions(), std::vector<int64>(result_shape.rank(), 1)));
   }
-  result->set_sharding(target);
-  return PartitionedHlo(result, base_shape_, state_);
+  result->set_sharding(temp_target);
+  auto remaining_source_target_dims = source_target_dims;
+  remaining_source_target_dims.remove_prefix(1);
+  return PartitionedHlo(result, base_shape_, state_)
+      .ReshardWithAllToAll(target, remaining_source_target_dims);
 }
 
 PartitionedHlo PartitionedHlo::ReshardWithCollectivePermute(
@@ -837,9 +881,7 @@ PartitionedHlo PartitionedHlo::ReshardWithCollectivePermute(
   sharding().tile_assignment().Each(
       [&](absl::Span<const int64> indices, int64 src_device) {
         int64 dst_device = target.tile_assignment()(indices);
-        if (dst_device != src_device) {
-          src_dst_pairs.emplace_back(src_device, dst_device);
-        }
+        src_dst_pairs.emplace_back(src_device, dst_device);
       });
   auto cp =
       state_.collective_ops_creator.create_cross_partition_collective_permute(


@@ -284,8 +284,9 @@ class PartitionedHlo {
   // Helper function to reshard the tensor using AllToAll (instead of the
   // default of Replicate followed by Slice).
-  PartitionedHlo ReshardWithAllToAll(const HloSharding& target,
-                                     int64 source_dim, int64 target_dim) const;
+  PartitionedHlo ReshardWithAllToAll(
+      const HloSharding& target,
+      absl::Span<const std::pair<int64, int64>> source_target_dims) const;
 
   // Helper function to reshard the tensor using CollectivePermute.
   PartitionedHlo ReshardWithCollectivePermute(const HloSharding& target) const;


@@ -3792,6 +3792,56 @@ ENTRY entry {
                     4);
 }
 
+TEST_F(SpmdPartitioningTest, SubgroupAllToAllReshard2) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %param0 = f32[8,8] parameter(0),
+    sharding={devices=[2,4]0,1,2,3,4,5,6,7}
+  ROOT %copy = f32[8,8] copy(%param0),
+    sharding={devices=[4,2]0,1,4,5,2,3,6,7}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/8));
+  VLOG(1) << module->ToString();
+
+  auto root = module->entry_computation()->root_instruction();
+  auto all_to_all = op::AllToAll(
+      AllOf(op::Shape("f32[2,2,2]"), op::Reshape(op::Parameter(0))));
+  auto reshape =
+      AllOf(op::Shape("f32[2,4]"), op::Reshape(op::Transpose(all_to_all)));
+  EXPECT_THAT(root, op::Copy(op::CollectivePermute(reshape)));
+}
+
+TEST_F(SpmdPartitioningTest, SubgroupAllToAllReshard3) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %param0 = f32[8,8,8] parameter(0),
+    sharding={devices=[2,4,1]0,1,2,3,4,5,6,7}
+  ROOT %copy = f32[8,8,8] copy(%param0),
+    sharding={devices=[1,2,4]0,1,4,5,2,3,6,7}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/8));
+  VLOG(1) << module->ToString();
+
+  auto root = module->entry_computation()->root_instruction();
+  auto all_to_all = op::AllToAll(
+      AllOf(op::Shape("f32[4,2,4,2]"), op::Reshape(op::Parameter(0))));
+  auto reshape =
+      AllOf(op::Shape("f32[4,8,2]"), op::Reshape(op::Transpose(all_to_all)));
+  auto all_to_all2 =
+      op::AllToAll(AllOf(op::Shape("f32[4,2,4,2]"), op::Reshape(reshape)));
+  auto reshape2 =
+      AllOf(op::Shape("f32[8,4,2]"), op::Reshape(op::Transpose(all_to_all2)));
+  EXPECT_THAT(root, op::Copy(op::CollectivePermute(reshape2)));
+}
+
 }  // namespace
 }  // namespace spmd
 }  // namespace xla


@@ -885,37 +885,89 @@ int64 ShardCountAtDim(const HloSharding& sharding, int64 dim) {
   return sharding.tile_assignment().dim(dim);
 }
 
-absl::optional<std::pair<int64, int64>> GetReshardAllToAllSourceTargetDims(
-    const HloSharding& source, const HloSharding& target) {
+absl::optional<std::vector<std::pair<int64, int64>>>
+GetReshardAllToAllSourceTargetDims(const HloSharding& source,
+                                   const HloSharding& target) {
   if (source.IsTileMaximal() || target.IsTileMaximal() ||
       source.tile_assignment().num_dimensions() !=
           target.tile_assignment().num_dimensions()) {
     return absl::nullopt;
   }
-  int64 source_dim = -1;
-  int64 target_dim = -1;
+  // Record partition count to index for indices that have different partition
+  // counts on source and target.
+  std::map<int64, std::vector<int64>> source_size_to_dim;
+  std::map<int64, std::vector<int64>> target_size_to_dim;
   for (int64 i = 0; i < source.tile_assignment().num_dimensions(); ++i) {
-    if (source.tile_assignment().dim(i) > 1 &&
-        target.tile_assignment().dim(i) == 1) {
-      if (source_dim != -1) {
-        return absl::nullopt;
-      }
-      source_dim = i;
-    } else if (source.tile_assignment().dim(i) == 1 &&
-               target.tile_assignment().dim(i) > 1) {
-      if (target_dim != -1) {
-        return absl::nullopt;
-      }
-      target_dim = i;
-    } else if (source.tile_assignment().dim(i) !=
-               target.tile_assignment().dim(i)) {
+    if (source.tile_assignment().dim(i) == target.tile_assignment().dim(i)) {
+      continue;
    }
+    source_size_to_dim[source.tile_assignment().dim(i)].push_back(i);
+    target_size_to_dim[target.tile_assignment().dim(i)].push_back(i);
  }
+  // In order to shard via AllToAll, source_size_to_dim and target_size_to_dim
+  // must have the same distribution.
+  if (source_size_to_dim.empty() ||
+      source_size_to_dim.size() != target_size_to_dim.size()) {
+    return absl::nullopt;
+  }
+  for (const auto& entry : source_size_to_dim) {
+    auto target_it = target_size_to_dim.find(entry.first);
+    if (target_it == target_size_to_dim.end() ||
+        target_it->second.size() != entry.second.size()) {
+      return absl::nullopt;
+    }
+  }
-  if (source_dim == -1 || target_dim == -1 || source_dim == target_dim) {
-    return absl::nullopt;
+  std::vector<std::pair<int64, int64>> result;
+  auto remove_entry = [](int64 size, int64 dim,
+                         std::map<int64, std::vector<int64>>& size_to_dim) {
+    size_to_dim[size].erase(
+        std::remove_if(size_to_dim[size].begin(), size_to_dim[size].end(),
+                       [dim](int64 a) { return a == dim; }),
+        size_to_dim[size].end());
+    if (size_to_dim[size].empty()) {
+      size_to_dim.erase(size);
+    }
+  };
+  // Find one pair of dimensions to swap at a time.
+  while (!source_size_to_dim.empty()) {
+    int64 source_size = source_size_to_dim.begin()->first;
+    int64 i = source_size_to_dim.begin()->second.back();
+    int64 target_i_size = target.tile_assignment().dim(i);
+    if (target_i_size == source_size) {
+      remove_entry(source_size, i, source_size_to_dim);
+      remove_entry(source_size, i, target_size_to_dim);
+      continue;
+    }
+    auto j_it = source_size_to_dim[target_i_size].begin();
+    int64 j = *j_it;
+    if (source_size == 1) {
+      // If possible, find a j where the target partition count is not one, so
+      // that when we swap, the resulting size-1 dimension will still be useful
+      // to other dimensions.
+      while (target.tile_assignment().dim(j) == 1) {
+        if (++j_it == source_size_to_dim[target_i_size].end()) {
+          break;
+        }
+        j = *j_it;
+      }
+    } else if (target_i_size % source_size == 0) {
+      // If possible, find a j where the target partition count is source_size,
+      // so that we can do a single swap.
+      while (target.tile_assignment().dim(j) != source_size) {
+        if (++j_it == source_size_to_dim[target_i_size].end()) {
+          break;
+        }
+        j = *j_it;
+      }
+    } else {
+      return absl::nullopt;
+    }
+    result.emplace_back(j, i);
+    remove_entry(target_i_size, i, target_size_to_dim);
+    source_size_to_dim.begin()->second.back() = j;
+    remove_entry(target_i_size, j, source_size_to_dim);
+  }
-  return std::pair<int64, int64>(source_dim, target_dim);
+  return result;
 }
 
 bool CanReshardWithCollectivePermute(const HloSharding& source,


@@ -265,10 +265,12 @@ HloInstruction* SliceFirstK(HloInstruction* hlo, SpmdBuilder* builder,
 // Check if a dimension is sharded.
 int64 ShardCountAtDim(const HloSharding& sharding, int64 dim);
 
-// Returns the pair of source and target dimensions is the resharding can be
-// done via all-to-all.
-absl::optional<std::pair<int64, int64>> GetReshardAllToAllSourceTargetDims(
-    const HloSharding& source, const HloSharding& target);
+// Returns the list of source-target pairs of dimensions to swap during
+// resharding via all-to-all. Reshard can be done by swapping each pair at a
+// time.
+absl::optional<std::vector<std::pair<int64, int64>>>
+GetReshardAllToAllSourceTargetDims(const HloSharding& source,
+                                   const HloSharding& target);
 
 // Returns whether the resharding can be done via collective-permute.
 bool CanReshardWithCollectivePermute(const HloSharding& source,