diff --git a/tensorflow/compiler/xla/service/hlo_sharding_util.cc b/tensorflow/compiler/xla/service/hlo_sharding_util.cc
index 7fc05608800..11a24b30ac9 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_util.cc
@@ -190,13 +190,22 @@ absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
       target_dims_stack.push_back(t_size);
     } else if (s_size > t_size) {
       // Dimension split.
-      if (s_size % t_size != 0 || t_size % s_partitions != 0) {
+      if (s_size % t_size != 0 || s_size % s_partitions != 0) {
+        return absl::nullopt;
+      }
+      if (t_size % s_partitions == 0) {
+        target_tile_assignment_dimensions.push_back(s_partitions);
+        // We have part of the s_size unprocessed, so put it back to stack.
+        source_dims_stack.push_back(s_size / t_size);
+        sharding_tile_dims_stack.push_back(1);
+      } else if (s_partitions % t_size == 0) {
+        target_tile_assignment_dimensions.push_back(t_size);
+        // We have part of the s_size unprocessed, so put it back to stack.
+        source_dims_stack.push_back(s_size / t_size);
+        sharding_tile_dims_stack.push_back(s_partitions / t_size);
+      } else {
         return absl::nullopt;
       }
-      target_tile_assignment_dimensions.push_back(s_partitions);
-      // We have part of the s_size unprocessed, so put it back to stack.
-      source_dims_stack.push_back(s_size / t_size);
-      sharding_tile_dims_stack.push_back(1);
     } else {
       // Dimension merge. Also merge the source dimension with the next, and
       // process it next time.
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_util_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_util_test.cc
index 02496c75965..08f136b2e45 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_util_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_util_test.cc
@@ -76,6 +76,20 @@ TEST(HloShardingUtilTest, ReshapeShardingTiledSplit) {
   EXPECT_EQ(result.value(), output_sharding);
 }
 
+TEST(HloShardingUtilTest, ReshapeShardingTiledSplit2) {
+  Shape input_shape = ShapeUtil::MakeShape(F32, {16, 7});
+  Shape output_shape = ShapeUtil::MakeShape(F32, {4, 4, 7});
+  Array2D<int64> tile(16, 1);
+  tile.FillIota(0);
+  HloSharding input_sharding = HloSharding::Tile(tile);
+  tile.Reshape({4, 4, 1});
+  HloSharding output_sharding = HloSharding::Tile(tile);
+  absl::optional<HloSharding> result =
+      ReshapeSharding(input_shape, output_shape, input_sharding);
+  EXPECT_TRUE(result.has_value());
+  EXPECT_EQ(result.value(), output_sharding);
+}
+
 TEST(HloShardingUtilTest, ReshapeShardingTiledSplitThenMerge) {
   Shape input_shape = ShapeUtil::MakeShape(F32, {16, 4, 7});
   Shape output_shape = ShapeUtil::MakeShape(F32, {4, 16, 7});
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
index ceead32f530..bac5c812814 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
@@ -267,8 +267,7 @@ PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) {
 
   if (auto src_tgt_dims =
           GetReshardAllToAllSourceTargetDims(sharding(), target)) {
-    return ReshardWithAllToAll(target, src_tgt_dims->first,
-                               src_tgt_dims->second);
+    return ReshardWithAllToAll(target, *src_tgt_dims);
   }
 
   // If not replicated yet, first replicate and then reshard to use one of the
@@ -734,40 +733,82 @@ PartitionedHlo PartitionedHlo::Broadcast() const {
   return PartitionedHlo(result, base_shape_, state_);
 }
 
-PartitionedHlo PartitionedHlo::ReshardWithAllToAll(const HloSharding& target,
-                                                   int64 source_dim,
-                                                   int64 target_dim) const {
-  const int64 group_size = sharding().tile_assignment().dim(source_dim);
-
-  // If the device order is different in the target, fix the order with
-  // ReshardWithCollectivePermute.
-  std::vector<int64> xpose_dims(target.tile_assignment().num_dimensions());
-  std::iota(xpose_dims.begin(), xpose_dims.end(), 0);
-  xpose_dims[source_dim] = target_dim;
-  xpose_dims[target_dim] = source_dim;
-  auto input_sharding_fixed_device_order =
-      hlo_sharding_util::TransposeSharding(target, xpose_dims);
-  if (input_sharding_fixed_device_order != sharding()) {
-    auto fixed_order =
-        ReshardWithCollectivePermute(input_sharding_fixed_device_order);
-    return fixed_order.ReshardWithAllToAll(target, source_dim, target_dim);
+PartitionedHlo PartitionedHlo::ReshardWithAllToAll(
+    const HloSharding& target,
+    absl::Span<const std::pair<int64, int64>> source_target_dims) const {
+  if (source_target_dims.empty()) {
+    if (target == sharding()) {
+      return *this;
+    }
+    // If the device order is different in the target, fix the order with
+    // ReshardWithCollectivePermute.
+    return ReshardWithCollectivePermute(target);
   }
-  auto padded_hlo =
-      PadBaseShapeBeforeUnevenTiledSharding(hlo_, target, state_.b);
 
+  // Swap one pair of dimensions.
+  int64 source_dim = source_target_dims[0].first;
+  int64 target_dim = source_target_dims[0].second;
+  const int64 group_size = sharding().tile_assignment().dim(source_dim) /
+                           sharding().tile_assignment().dim(target_dim);
 
-  // The order of ids in the group must follow the target sharding.
-  std::vector<ReplicaGroup> groups(target.tile_assignment().num_elements() /
-                                   group_size);
-  target.tile_assignment().Each(
+  auto temp_target_tile = sharding().tile_assignment();
+  {
+    std::vector<int64> reshape_tile_dims(temp_target_tile.num_dimensions() + 2);
+    int64 i = 0;
+    int64 added_source_dim = -1;
+    int64 added_target_dim = -1;
+    for (int64 j = 0; j < temp_target_tile.num_dimensions(); ++j) {
+      if (source_dim == j) {
+        reshape_tile_dims[i] = temp_target_tile.dim(j) / group_size;
+        reshape_tile_dims[++i] = group_size;
+        added_source_dim = i;
+      } else if (target_dim == j) {
+        reshape_tile_dims[i] = temp_target_tile.dim(j);
+        reshape_tile_dims[++i] = 1;
+        added_target_dim = i;
+      } else {
+        reshape_tile_dims[i] = temp_target_tile.dim(j);
+      }
+      ++i;
+    }
+    temp_target_tile.Reshape(reshape_tile_dims);
+    std::vector<int64> xpose_dims(temp_target_tile.num_dimensions());
+    std::iota(xpose_dims.begin(), xpose_dims.end(), 0);
+    xpose_dims[added_source_dim] = added_target_dim;
+    xpose_dims[added_target_dim] = added_source_dim;
+    temp_target_tile = hlo_sharding_util::TransposeSharding(
+                           HloSharding::Tile(temp_target_tile), xpose_dims)
+                           .tile_assignment();
+    auto temp_target_tile_dims = sharding().tile_assignment().dimensions();
+    temp_target_tile_dims[source_dim] =
+        sharding().tile_assignment().dim(target_dim);
+    temp_target_tile_dims[target_dim] =
+        sharding().tile_assignment().dim(source_dim);
+    temp_target_tile.Reshape(temp_target_tile_dims);
+  }
+  auto temp_target = HloSharding::Tile(temp_target_tile);
+
+  auto padded_shape = hlo_->shape();
+  padded_shape.set_dimensions(
+      target_dim,
+      RoundUpToNearest(padded_shape.dimensions(target_dim),
+                       temp_target.tile_assignment().dim(target_dim)));
+  auto padded_hlo = PadToShape(hlo_, padded_shape, state_.b);
+
+  // The order of ids in the group must follow the temp_target sharding.
+  std::vector<ReplicaGroup> groups(
+      temp_target.tile_assignment().num_elements() / group_size);
+  temp_target.tile_assignment().Each(
       [&](absl::Span<const int64> indices, int64 device) {
         int64 group_id = 0;
         for (int64 dim = 0; dim < indices.size(); ++dim) {
           if (dim == target_dim) {
-            continue;
+            group_id *= temp_target.tile_assignment().dim(dim) / group_size;
+            group_id += indices[dim] / group_size;
+          } else {
+            group_id *= temp_target.tile_assignment().dim(dim);
+            group_id += indices[dim];
           }
-          group_id *= target.tile_assignment().dim(dim);
-          group_id += indices[dim];
         }
         groups[group_id].add_replica_ids(device);
       });
@@ -819,14 +860,17 @@ PartitionedHlo PartitionedHlo::ReshardWithAllToAll(const HloSharding& target,
   result = state_.b->AddInstruction(
       HloInstruction::CreateReshape(new_shape, transpose));
 
-  const Shape result_shape = MakePartitionedShape(base_shape_, target);
+  const Shape result_shape = MakePartitionedShape(base_shape_, temp_target);
   if (result_shape != result->shape()) {
     result = state_.b->AddInstruction(HloInstruction::CreateSlice(
         result_shape, result, std::vector<int64>(result_shape.rank(), 0),
        result_shape.dimensions(), std::vector<int64>(result_shape.rank(), 1)));
   }
-  result->set_sharding(target);
-  return PartitionedHlo(result, base_shape_, state_);
+  result->set_sharding(temp_target);
+  auto remaining_source_target_dims = source_target_dims;
+  remaining_source_target_dims.remove_prefix(1);
+  return PartitionedHlo(result, base_shape_, state_)
+      .ReshardWithAllToAll(target, remaining_source_target_dims);
 }
 
 PartitionedHlo PartitionedHlo::ReshardWithCollectivePermute(
@@ -837,9 +881,7 @@ PartitionedHlo PartitionedHlo::ReshardWithCollectivePermute(
   sharding().tile_assignment().Each(
       [&](absl::Span<const int64> indices, int64 src_device) {
         int64 dst_device = target.tile_assignment()(indices);
-        if (dst_device != src_device) {
-          src_dst_pairs.emplace_back(src_device, dst_device);
-        }
+        src_dst_pairs.emplace_back(src_device, dst_device);
       });
   auto cp =
       state_.collective_ops_creator.create_cross_partition_collective_permute(
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h
index 40881b4b91c..606a7ae5f14 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h
@@ -284,8 +284,9 @@ class PartitionedHlo {
 
   // Helper function to reshard the tensor using AllToAll (instead of the
   // default of Replicate followed by Slice).
-  PartitionedHlo ReshardWithAllToAll(const HloSharding& target,
-                                     int64 source_dim, int64 target_dim) const;
+  PartitionedHlo ReshardWithAllToAll(
+      const HloSharding& target,
+      absl::Span<const std::pair<int64, int64>> source_target_dims) const;
 
   // Helper function to reshard the tensor using CollectivePermute.
   PartitionedHlo ReshardWithCollectivePermute(const HloSharding& target) const;
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
index 5b6c869c5fa..1045d1187b8 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
@@ -3792,6 +3792,56 @@ ENTRY entry {
       4);
 }
 
+TEST_F(SpmdPartitioningTest, SubgroupAllToAllReshard2) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %param0 = f32[8,8] parameter(0),
+    sharding={devices=[2,4]0,1,2,3,4,5,6,7}
+  ROOT %copy = f32[8,8] copy(%param0),
+    sharding={devices=[4,2]0,1,4,5,2,3,6,7}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/8));
+  VLOG(1) << module->ToString();
+
+  auto root = module->entry_computation()->root_instruction();
+  auto all_to_all = op::AllToAll(
+      AllOf(op::Shape("f32[2,2,2]"), op::Reshape(op::Parameter(0))));
+  auto reshape =
+      AllOf(op::Shape("f32[2,4]"), op::Reshape(op::Transpose(all_to_all)));
+  EXPECT_THAT(root, op::Copy(op::CollectivePermute(reshape)));
+}
+
+TEST_F(SpmdPartitioningTest, SubgroupAllToAllReshard3) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %param0 = f32[8,8,8] parameter(0),
+    sharding={devices=[2,4,1]0,1,2,3,4,5,6,7}
+  ROOT %copy = f32[8,8,8] copy(%param0),
+    sharding={devices=[1,2,4]0,1,4,5,2,3,6,7}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/8));
+  VLOG(1) << module->ToString();
+
+  auto root = module->entry_computation()->root_instruction();
+  auto all_to_all = op::AllToAll(
+      AllOf(op::Shape("f32[4,2,4,2]"), op::Reshape(op::Parameter(0))));
+  auto reshape =
+      AllOf(op::Shape("f32[4,8,2]"), op::Reshape(op::Transpose(all_to_all)));
+  auto all_to_all2 =
+      op::AllToAll(AllOf(op::Shape("f32[4,2,4,2]"), op::Reshape(reshape)));
+  auto reshape2 =
+      AllOf(op::Shape("f32[8,4,2]"), op::Reshape(op::Transpose(all_to_all2)));
+  EXPECT_THAT(root, op::Copy(op::CollectivePermute(reshape2)));
+}
+
 }  // namespace
 }  // namespace spmd
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc
index 65f066db629..6beed5a15e5 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc
@@ -885,37 +885,89 @@ int64 ShardCountAtDim(const HloSharding& sharding, int64 dim) {
   return sharding.tile_assignment().dim(dim);
 }
 
-absl::optional<std::pair<int64, int64>> GetReshardAllToAllSourceTargetDims(
-    const HloSharding& source, const HloSharding& target) {
+absl::optional<std::vector<std::pair<int64, int64>>>
+GetReshardAllToAllSourceTargetDims(const HloSharding& source,
+                                   const HloSharding& target) {
   if (source.IsTileMaximal() || target.IsTileMaximal() ||
       source.tile_assignment().num_dimensions() !=
          target.tile_assignment().num_dimensions()) {
     return absl::nullopt;
   }
-  int64 source_dim = -1;
-  int64 target_dim = -1;
+  // Record partition count to index for indices that have different partition
+  // counts on source and target.
+  std::map<int64, std::vector<int64>> source_size_to_dim;
+  std::map<int64, std::vector<int64>> target_size_to_dim;
   for (int64 i = 0; i < source.tile_assignment().num_dimensions(); ++i) {
-    if (source.tile_assignment().dim(i) > 1 &&
-        target.tile_assignment().dim(i) == 1) {
-      if (source_dim != -1) {
-        return absl::nullopt;
-      }
-      source_dim = i;
-    } else if (source.tile_assignment().dim(i) == 1 &&
-               target.tile_assignment().dim(i) > 1) {
-      if (target_dim != -1) {
-        return absl::nullopt;
-      }
-      target_dim = i;
-    } else if (source.tile_assignment().dim(i) !=
-               target.tile_assignment().dim(i)) {
+    if (source.tile_assignment().dim(i) == target.tile_assignment().dim(i)) {
+      continue;
+    }
+    source_size_to_dim[source.tile_assignment().dim(i)].push_back(i);
+    target_size_to_dim[target.tile_assignment().dim(i)].push_back(i);
+  }
+  // In order to shard via AllToAll, source_size_to_dim and target_size_to_dim
+  // must have the same distribution.
+  if (source_size_to_dim.empty() ||
+      source_size_to_dim.size() != target_size_to_dim.size()) {
+    return absl::nullopt;
+  }
+  for (const auto& entry : source_size_to_dim) {
+    auto target_it = target_size_to_dim.find(entry.first);
+    if (target_it == target_size_to_dim.end() ||
+        target_it->second.size() != entry.second.size()) {
       return absl::nullopt;
     }
   }
-  if (source_dim == -1 || target_dim == -1 || source_dim == target_dim) {
-    return absl::nullopt;
+  std::vector<std::pair<int64, int64>> result;
+  auto remove_entry = [](int64 size, int64 dim,
+                         std::map<int64, std::vector<int64>>& size_to_dim) {
+    size_to_dim[size].erase(
+        std::remove_if(size_to_dim[size].begin(), size_to_dim[size].end(),
+                       [dim](int64 a) { return a == dim; }),
+        size_to_dim[size].end());
+    if (size_to_dim[size].empty()) {
+      size_to_dim.erase(size);
+    }
+  };
+  // Find one pair of dimensions to swap at a time.
+  while (!source_size_to_dim.empty()) {
+    int64 source_size = source_size_to_dim.begin()->first;
+    int64 i = source_size_to_dim.begin()->second.back();
+    int64 target_i_size = target.tile_assignment().dim(i);
+    if (target_i_size == source_size) {
+      remove_entry(source_size, i, source_size_to_dim);
+      remove_entry(source_size, i, target_size_to_dim);
+      continue;
+    }
+    auto j_it = source_size_to_dim[target_i_size].begin();
+    int64 j = *j_it;
+    if (source_size == 1) {
+      // If possible, find a j where the target partition count is not one, so
+      // that when we swap, the resulting size-1 dimension will still be useful
+      // to other dimensions.
+      while (target.tile_assignment().dim(j) == 1) {
+        if (++j_it == source_size_to_dim[target_i_size].end()) {
+          break;
+        }
+        j = *j_it;
+      }
+    } else if (target_i_size % source_size == 0) {
+      // If possible, find a j where the target partition count is source_size,
+      // so that we can do a single swap.
+      while (target.tile_assignment().dim(j) != source_size) {
+        if (++j_it == source_size_to_dim[target_i_size].end()) {
+          break;
+        }
+        j = *j_it;
+      }
+    } else {
+      return absl::nullopt;
+    }
+    result.emplace_back(j, i);
+    remove_entry(target_i_size, i, target_size_to_dim);
+    source_size_to_dim.begin()->second.back() = j;
+    remove_entry(target_i_size, j, source_size_to_dim);
   }
-  return std::pair<int64, int64>(source_dim, target_dim);
+  return result;
 }
 
 bool CanReshardWithCollectivePermute(const HloSharding& source,
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h
index d924a5c7151..7b737daf78c 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h
@@ -265,10 +265,12 @@ HloInstruction* SliceFirstK(HloInstruction* hlo, SpmdBuilder* builder,
 // Check if a dimension is sharded.
 int64 ShardCountAtDim(const HloSharding& sharding, int64 dim);
 
-// Returns the pair of source and target dimensions is the resharding can be
-// done via all-to-all.
-absl::optional<std::pair<int64, int64>> GetReshardAllToAllSourceTargetDims(
-    const HloSharding& source, const HloSharding& target);
+// Returns the list of source-target pairs of dimensions to swap during
+// resharding via all-to-all. Reshard can be done by swapping each pair at a
+// time.
+absl::optional<std::vector<std::pair<int64, int64>>>
+GetReshardAllToAllSourceTargetDims(const HloSharding& source,
+                                   const HloSharding& target);
 
 // Returns whether the resharding can be done via collective-permute.
 bool CanReshardWithCollectivePermute(const HloSharding& source,
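Note (illustrative sketch, not part of the patch): with this change, GetReshardAllToAllSourceTargetDims no longer requires exactly one dimension to gain and one to lose all of its partitions; it greedily picks one (source_dim, target_dim) pair at a time, and ReshardWithAllToAll consumes the pairs recursively, emitting one all-to-all per pair. The standalone program below mirrors that greedy pairing over plain per-dimension partition counts, stripped of HloSharding; the type alias Dims, the function name SourceTargetDims, and the main() driver are assumptions made for illustration and do not exist in XLA.

// Minimal sketch of the greedy source/target dimension pairing used for
// all-to-all resharding, operating on raw partition-count vectors.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <map>
#include <optional>
#include <utility>
#include <vector>

using Dims = std::vector<int64_t>;

std::optional<std::vector<std::pair<int64_t, int64_t>>> SourceTargetDims(
    const Dims& source, const Dims& target) {
  if (source.size() != target.size()) return std::nullopt;
  // Bucket the dimensions whose partition counts differ, keyed by count.
  std::map<int64_t, std::vector<int64_t>> source_size_to_dim;
  std::map<int64_t, std::vector<int64_t>> target_size_to_dim;
  for (int64_t i = 0; i < static_cast<int64_t>(source.size()); ++i) {
    if (source[i] == target[i]) continue;
    source_size_to_dim[source[i]].push_back(i);
    target_size_to_dim[target[i]].push_back(i);
  }
  // Source and target must use the same multiset of partition counts.
  if (source_size_to_dim.empty() ||
      source_size_to_dim.size() != target_size_to_dim.size()) {
    return std::nullopt;
  }
  for (const auto& entry : source_size_to_dim) {
    auto it = target_size_to_dim.find(entry.first);
    if (it == target_size_to_dim.end() ||
        it->second.size() != entry.second.size()) {
      return std::nullopt;
    }
  }
  std::vector<std::pair<int64_t, int64_t>> result;
  auto remove_entry = [](int64_t size, int64_t dim,
                         std::map<int64_t, std::vector<int64_t>>& size_to_dim) {
    auto& dims = size_to_dim[size];
    dims.erase(std::remove(dims.begin(), dims.end(), dim), dims.end());
    if (dims.empty()) size_to_dim.erase(size);
  };
  // Pick one (source_dim, target_dim) pair to swap per iteration.
  while (!source_size_to_dim.empty()) {
    int64_t source_size = source_size_to_dim.begin()->first;
    int64_t i = source_size_to_dim.begin()->second.back();
    int64_t target_i_size = target[i];
    if (target_i_size == source_size) {
      // Already consistent after earlier swaps; drop the bookkeeping entries.
      remove_entry(source_size, i, source_size_to_dim);
      remove_entry(source_size, i, target_size_to_dim);
      continue;
    }
    auto j_it = source_size_to_dim[target_i_size].begin();
    int64_t j = *j_it;
    if (source_size == 1) {
      // Prefer a j whose target partition count is not 1, so the size-1
      // dimension created by the swap stays useful for later swaps.
      while (target[j] == 1) {
        if (++j_it == source_size_to_dim[target_i_size].end()) break;
        j = *j_it;
      }
    } else if (target_i_size % source_size == 0) {
      // Prefer a j whose target partition count equals source_size, so one
      // swap finishes both dimensions.
      while (target[j] != source_size) {
        if (++j_it == source_size_to_dim[target_i_size].end()) break;
        j = *j_it;
      }
    } else {
      return std::nullopt;
    }
    result.emplace_back(j, i);
    remove_entry(target_i_size, i, target_size_to_dim);
    source_size_to_dim.begin()->second.back() = j;
    remove_entry(target_i_size, j, source_size_to_dim);
  }
  return result;
}

int main() {
  // Mirrors SubgroupAllToAllReshard3: tile dims [2,4,1] -> [1,2,4].
  auto pairs = SourceTargetDims({2, 4, 1}, {1, 2, 4});
  for (const auto& p : *pairs) {
    std::cout << "(" << p.first << ", " << p.second << ")\n";  // (1, 2) (0, 1)
  }
}

Run on the SubgroupAllToAllReshard3 shardings, the tile dims [2,4,1] -> [1,2,4] decompose into the (source_dim, target_dim) pairs (1, 2) and then (0, 1), matching the two all-to-alls that test expects. The ReshapeSharding change above plays the analogous role on the sharding-propagation side: a [16,1] tiling of shape [16,7] can now reshape to [4,4,7] as a [4,4,1] tiling, since the 16-way partitioned dimension divides evenly across the two new dimensions.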