[XLA:SPMD] Use per-mesh-dimension allreduce/allgather
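
Instead of emitting a single all-reduce or all-gather whose replica groups span every device in the sharding's tile assignment, emit one collective per sharded mesh dimension, each with replica groups that vary only along that dimension. For a [2,4] device mesh sharded on both dimensions, reducing partial results now lowers to an all-reduce over 2 groups of 4 devices followed by an all-reduce over 4 groups of 2, rather than one all-reduce over all 8 devices; mesh dimensions of size 1 are skipped.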

PiperOrigin-RevId: 356589391
Change-Id: I0649b6e9395f11a02ba1890102991e7566443b4f
Yuanzhong Xu 2021-02-09 14:17:13 -08:00 committed by TensorFlower Gardener
parent 65f038e14d
commit 47917266f3
5 changed files with 106 additions and 84 deletions

View File

@@ -92,17 +92,6 @@ Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) {
 namespace {
-std::vector<int64> GetAllDevicesInOrder(const HloSharding& sharding) {
-  CHECK(!sharding.IsTileMaximal());
-  std::vector<int64> results;
-  results.reserve(sharding.tile_assignment().num_elements());
-  sharding.tile_assignment().Each(
-      [&](absl::Span<const int64> /* indices */, int64 device) {
-        results.push_back(device);
-      });
-  return results;
-}
 StatusOr<HloInstruction*> PartitionBaseCase(
     PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
     const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
@@ -1018,11 +1007,15 @@ StatusOr<HloInstruction*> PartitionBaseCase(
   }
   TF_ASSIGN_OR_RETURN(
       auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
-  auto ar =
-      lhs.state().collective_ops_creator.create_cross_partition_all_reduce(
-          b, dot, MakeBinaryAdd(output_base_shape.element_type(), module),
-          {GetAllDevicesInOrder(lhs.sharding())},
-          (*lhs.state().next_channel_id)++);
+  std::vector<int64> lhs_contracting_dims;
+  lhs_contracting_dims.reserve(lhs.base_shape().rank());
+  for (const auto& cd : dims_mapping.contracting_dims) {
+    lhs_contracting_dims.push_back(cd.lhs);
+  }
+  auto ar = lhs.state().partitioner->AllReduceAlongShardingDims(
+      b, dot, lhs.sharding(), lhs.state().next_channel_id,
+      lhs_contracting_dims, lhs.state().collective_ops_creator,
+      MakeBinaryAdd(output_base_shape.element_type(), module));
   ar->set_sharding(HloSharding::Replicate());
   return PartitionedHlo(ar, output_base_shape, lhs.state())
       .Reshard(output_sharding)
@@ -1123,10 +1116,16 @@ StatusOr<HloInstruction*> PartitionBaseCase(
   }
   TF_ASSIGN_OR_RETURN(
       auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
-  return lhs.state().collective_ops_creator.create_cross_partition_all_reduce(
-      b, dot, MakeBinaryAdd(output_base_shape.element_type(), module),
-      {GetAllDevicesInOrder(lhs.sharding())},
-      (*lhs.state().next_channel_id)++);
+  std::vector<int64> lhs_contracting_dims;
+  lhs_contracting_dims.reserve(lhs.base_shape().rank());
+  for (const auto& cd : dims_mapping.contracting_dims) {
+    lhs_contracting_dims.push_back(cd.lhs);
+  }
+  return lhs.state().partitioner->AllReduceAlongShardingDims(
+      b, dot, lhs.sharding(), lhs.state().next_channel_id,
+      lhs_contracting_dims, lhs.state().collective_ops_creator,
+      MakeBinaryAdd(output_base_shape.element_type(), module));
   }
   return nullptr;
 }
@@ -1679,20 +1678,10 @@ StatusOr<HloInstruction*> PartitionDotGroupOnContracting(
   if (!dot) {
     return nullptr;
   }
-  std::vector<int64> other_lhs_dims;
-  for (int64 i = 0; i < lhs_sharding.tile_assignment().num_dimensions(); ++i) {
-    if (!absl::c_linear_search(lhs_dims, i)) {
-      other_lhs_dims.push_back(i);
-    }
-  }
-  auto inverse_grouped = GroupShardingOnDims(lhs_sharding, other_lhs_dims);
-  auto ar =
-      CreatePerGroupPartitioningState(lhs.state(),
-                                      inverse_grouped.device_groups, b)
-          .collective_ops_creator.create_cross_partition_all_reduce(
-              b, dot, MakeBinaryAdd(output_base_shape.element_type(), module),
-              {GetAllDevicesInOrder(inverse_grouped.sharding)},
-              (*lhs.state().next_channel_id)++);
+  auto ar = lhs.state().partitioner->AllReduceAlongShardingDims(
+      b, dot, lhs_sharding, lhs.state().next_channel_id, lhs_dims,
+      lhs.state().collective_ops_creator,
+      MakeBinaryAdd(output_base_shape.element_type(), module));
   ar->set_sharding(outer_output_tmp_sharding);
   return PartitionedHlo(ar, output_base_shape, lhs.state())
       .Reshard(output_sharding)
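
For intuition, here is a minimal standalone sketch, using hypothetical names rather than XLA's GetPartitionGroupsForReplication, of how the per-dimension replica groups can be derived for a row-major device mesh:

// A minimal, self-contained sketch -- hypothetical helper, not XLA's
// GetPartitionGroupsForReplication -- of per-mesh-dimension replica groups
// for a row-major device mesh: each group varies only along `dim`.
#include <cstdint>
#include <map>
#include <vector>

std::vector<std::vector<int64_t>> GroupsAlongMeshDim(
    const std::vector<int64_t>& mesh_shape, int64_t dim) {
  int64_t num_devices = 1;
  for (int64_t d : mesh_shape) num_devices *= d;
  std::map<int64_t, std::vector<int64_t>> buckets;
  for (int64_t device = 0; device < num_devices; ++device) {
    // Decode the row-major mesh coordinates of this device id.
    std::vector<int64_t> coords(mesh_shape.size());
    int64_t rem = device;
    for (int64_t i = mesh_shape.size() - 1; i >= 0; --i) {
      coords[i] = rem % mesh_shape[i];
      rem /= mesh_shape[i];
    }
    // Devices sharing all coordinates except `dim` land in the same group.
    coords[dim] = 0;
    int64_t key = 0;
    for (size_t i = 0; i < coords.size(); ++i) {
      key = key * mesh_shape[i] + coords[i];
    }
    buckets[key].push_back(device);
  }
  std::vector<std::vector<int64_t>> groups;
  for (auto& entry : buckets) groups.push_back(std::move(entry.second));
  return groups;
}
// For mesh_shape {2, 4}: dim 1 -> {{0,1,2,3},{4,5,6,7}};
// dim 0 -> {{0,4},{1,5},{2,6},{3,7}}.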

View File

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include "absl/algorithm/container.h"
 #include "absl/types/span.h"
 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
@@ -285,17 +286,15 @@ StatusOr<HloInstruction*> ParititonTrivialIndexedOperandDimension(
     replicated_dim.push_back(
         operand.sharding().tile_assignment().num_dimensions() - 1);
   }
-  auto sharding_grouped =
-      GroupShardingOnDims(operand.sharding(), replicated_dim);
-  auto per_group_partitioner_state = CreatePerGroupPartitioningState(
-      operand.state(), sharding_grouped.device_groups, b);
-  auto collective_ops_creator =
-      per_group_partitioner_state.collective_ops_creator;
-  auto ar = collective_ops_creator.create_cross_partition_all_reduce(
-      b, filtered,
-      MakeBinaryAdd(filtered->shape().element_type(),
-                    per_group_partitioner_state.module),
-      {}, visitor->NewChannel());
+  // All-reduce along all dims in operand sharding -- this is OK because the
+  // operand is sharded only on trivially sliced dimensions.
+  std::vector<int64> all_dims(operand.base_shape().rank());
+  absl::c_iota(all_dims, 0);
+  auto ar = operand.state().partitioner->AllReduceAlongShardingDims(
+      b, filtered, operand.sharding(), operand.state().next_channel_id,
+      all_dims, operand.state().collective_ops_creator,
+      MakeBinaryAdd(filtered->shape().element_type(),
+                    operand.state().module));
   VLOG(5) << "[Gather partitioning]: Partitioned as trivial operand "
              "batch_dim slice";
   ar->set_sharding(HloSharding::Replicate());
@@ -574,8 +573,7 @@ Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) {
       update_dim_to_index_dim);
   CHECK(new_updates_sharding.has_value());
   updates = updates.Reshard(*new_updates_sharding);
-  // Update collective_ops_creator and partition_id for partial replicate.
-  auto collective_ops_creator = collective_ops_creator_;
+  // Update partition_id for partial replicate.
   auto partition_id = partition_id_;
   if (indices.sharding().ReplicateOnLastTileDim()) {
     auto sharding_grouped = GroupShardingOnDims(
@@ -583,8 +581,6 @@ Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) {
         {indices.sharding().tile_assignment().num_dimensions() - 1});
     auto per_group_partitioner_state = CreatePerGroupPartitioningState(
         indices.state(), sharding_grouped.device_groups, &b_);
-    collective_ops_creator =
-        per_group_partitioner_state.collective_ops_creator;
     partition_id = per_group_partitioner_state.partition_id;
   }
   // To avoid accumulating the initial operand multiple times during
@@ -600,9 +596,13 @@ Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) {
           identity, operand.Replicate().hlo()));
   auto pscatter = b_.AddInstruction(scatter->CloneWithNewOperands(
       scatter->shape(), {select_operand, indices.hlo(), updates.hlo()}));
-  auto all_reduce =
-      collective_ops_creator.create_cross_partition_all_reduce(
-          &b_, pscatter, scatter->to_apply(), {}, NewChannel());
+  // All-reduce along all dims in operand sharding -- this is OK because the
+  // operand is not sharded on index_vector_dim.
+  std::vector<int64> all_dims(indices.base_shape().rank());
+  absl::c_iota(all_dims, 0);
+  auto all_reduce = operand.state().partitioner->AllReduceAlongShardingDims(
+      &b_, pscatter, indices.sharding(), indices.state().next_channel_id,
+      all_dims, collective_ops_creator_, scatter->to_apply());
   all_reduce->set_sharding(HloSharding::Replicate());
   SetPartitionedHlo(hlo, [&]() {
     return PartitionedHlo(all_reduce, hlo->shape(), MakePartitioningState())
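
Note: HandleScatter no longer routes the all-reduce through a per-group collective_ops_creator; the per-dimension collectives are derived directly from indices.sharding(), so the grouped partitioning state is kept only for partition_id.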

View File

@@ -877,7 +877,7 @@ HloInstruction* PartitionedHlo::ReplicatePartial(absl::Span<const int64> dims) {
   HloInstruction* result = nullptr;
   if (state_.collective_ops_creator.create_cross_partition_all_gather) {
     result = state_.partitioner->AllGatherShards(state_.b, hlo_, sharding(),
-                                                 NewChannel(), dims,
+                                                 state_.next_channel_id, dims,
                                                  state_.collective_ops_creator);
   }
   if (result == nullptr) {
@@ -892,12 +892,9 @@ HloInstruction* PartitionedHlo::ReplicatePartial(absl::Span<const int64> dims) {
         padded_target_shape, zero_bcast, hlo_, offsets));
     HloComputation* reduction =
         MakeBinaryAdd(shard_shape.element_type(), state_.module);
-    auto all_reduce =
-        state_.collective_ops_creator.create_cross_partition_all_reduce(
-            state_.b, dus, reduction,
-            GetPartitionGroupsForReplication(sharding(), dims), NewChannel());
-    result = all_reduce;
+    result = state_.partitioner->AllReduceAlongShardingDims(
+        state_.b, dus, sharding(), state_.next_channel_id, dims,
+        state_.collective_ops_creator, reduction);
   }
   if (!ShapeUtil::Compatible(target_shape, padded_target_shape)) {
     std::vector<int64> start_indices(target_shape.rank(), 0);
@@ -2765,14 +2762,15 @@ Status SpmdPartitioningVisitor::HandleReduce(HloInstruction* hlo) {
     if (inputs[0].sharding().ReplicateOnLastTileDim()) {
       preserved_dims.push_back(inputs[0].base_shape().rank());
     }
-    auto grouped = GroupShardingOnDims(inputs[0].sharding(), preserved_dims);
-    auto grouped_state = CreatePerGroupPartitioningState(
-        inputs[0].state(), grouped.device_groups, &b_);
     if (local_reduce->shape().IsArray()) {
-      reduce = grouped_state.collective_ops_creator
-                   .create_cross_partition_all_reduce(
-                       &b_, local_reduce, hlo->to_apply(), {}, NewChannel());
+      reduce = partitioner_->AllReduceAlongShardingDims(
+          &b_, local_reduce, inputs[0].sharding(), next_channel_id_,
+          hlo->dimensions(), collective_ops_creator_, hlo->to_apply());
     } else {
+      auto grouped =
+          GroupShardingOnDims(inputs[0].sharding(), preserved_dims);
+      auto grouped_state = CreatePerGroupPartitioningState(
+          inputs[0].state(), grouped.device_groups, &b_);
       std::vector<HloInstruction*> all_gathered_partial_results(input_count);
       for (int64 i = 0; i < input_count; ++i) {
         auto gte = b_.AddInstruction(HloInstruction::CreateGetTupleElement(
@@ -3500,8 +3498,11 @@ SpmdPartitioner::SpmdPartitioner(int64 num_partitions, int64 num_replicas,
 HloInstruction* SpmdPartitioner::AllGatherShards(
     SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
-    int64 channel_id, absl::Span<const int64> selected_dims,
+    int64* next_channel_id, absl::Span<const int64> selected_dims,
     const SPMDCollectiveOpsCreator& collectives_creator) {
+  if (selected_dims.empty()) {
+    return operand;
+  }
   CHECK(!sharding.IsTileMaximal());
   // Add one leading dimension to gather all partitions.
   std::vector<int64> shape;
@@ -3511,12 +3512,18 @@ HloInstruction* SpmdPartitioner::AllGatherShards(
   }
   auto reshape = b->AddInstruction(HloInstruction::CreateReshape(
       ShapeUtil::MakeShape(operand->shape().element_type(), shape), operand));
-  auto partition_subgroups =
-      GetPartitionGroupsForReplication(sharding, selected_dims);
-  shape[0] = partition_subgroups[0].size();
-  auto result = collectives_creator.create_cross_partition_all_gather(
-      b, reshape, ShapeUtil::MakeShape(operand->shape().element_type(), shape),
-      partition_subgroups, channel_id, /*all_gather_dimension=*/0);
+  HloInstruction* result = reshape;
+  for (auto it = selected_dims.rbegin(); it != selected_dims.rend(); ++it) {
+    if (sharding.tile_assignment().dim(*it) == 1) {
+      continue;
+    }
+    auto partition_subgroups =
+        GetPartitionGroupsForReplication(sharding, {*it});
+    shape[0] *= partition_subgroups[0].size();
+    result = collectives_creator.create_cross_partition_all_gather(
+        b, result, ShapeUtil::MakeShape(operand->shape().element_type(), shape),
+        partition_subgroups, (*next_channel_id)++, /*all_gather_dimension=*/0);
+  }
   // If n > 1 dimensions are partitioned, split the leading dimension to n.
   std::vector<int64> tiled_dims;
   for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
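
AllGatherShards now emits one all-gather per selected dimension, iterating selected_dims in reverse and skipping dimensions of extent 1 in the tile assignment; each step grows the leading gather dimension by that dimension's group size, and the channel id is threaded through as a pointer so every emitted collective gets a fresh id.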
@@ -3567,6 +3574,24 @@ HloInstruction* SpmdPartitioner::AllGatherShards(
   return result;
 }
+HloInstruction* SpmdPartitioner::AllReduceAlongShardingDims(
+    SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
+    int64* next_channel_id, absl::Span<const int64> selected_dims,
+    const SPMDCollectiveOpsCreator& collectives_creator,
+    HloComputation* reduction) {
+  auto result = operand;
+  for (auto it = selected_dims.rbegin(); it != selected_dims.rend(); ++it) {
+    if (sharding.tile_assignment().dim(*it) == 1) {
+      continue;
+    }
+    auto partition_subgroups =
+        GetPartitionGroupsForReplication(sharding, {*it});
+    result = collectives_creator.create_cross_partition_all_reduce(
+        b, result, reduction, partition_subgroups, (*next_channel_id)++);
+  }
+  return result;
+}
 StatusOr<bool> SpmdPartitioner::PartitionComputation(
     HloComputation* computation, const HloSharding& root_sharding,
     int64* next_channel_id, SpmdLogger* logger) {
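
As a sanity check on the decomposition, a small self-contained simulation in plain C++ (illustrative only, no XLA types) showing that chaining one all-reduce per dimension of a 2x4 mesh reduces over all 8 devices:

// Illustrative simulation: chaining one all-reduce per dimension of a 2x4
// mesh gives the same result as a single all-reduce over all 8 devices.
#include <iostream>
#include <vector>

int main() {
  std::vector<double> value = {1, 2, 3, 4, 5, 6, 7, 8};  // per-device partials

  // All-reduce along one mesh dimension: each device's value becomes the sum
  // over the devices that share its other mesh coordinate.
  auto all_reduce_dim = [&](int dim) {
    std::vector<double> next(8, 0.0);
    for (int d = 0; d < 8; ++d) {
      int c0 = d / 4, c1 = d % 4;
      int extent = dim == 0 ? 2 : 4;
      for (int p = 0; p < extent; ++p) {
        int peer = dim == 0 ? p * 4 + c1 : c0 * 4 + p;
        next[d] += value[peer];
      }
    }
    value = next;
  };

  all_reduce_dim(1);  // 2 groups of 4 along the minor dimension
  all_reduce_dim(0);  // 4 groups of 2 along the major dimension
  std::cout << value[0] << "\n";  // prints 36 on every device: 1+2+...+8
  return 0;
}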

View File

@@ -197,7 +197,7 @@ class SpmdPartitioner : public HloModulePass {
                                 int64* next_channel_id,
                                 SpmdLogger* logger);
-  // Creates all-gather based on HloSharding. Can be overridden to customize.
+  // Creates all-gather(s) based on HloSharding. Can be overridden to customize.
   // The default uses a single all-gather even if there are multiple sharded
   // dimensions, and adds potential reshapes and transposes to achieve that.
   // If it returns false, the partitioner will fall back to all-reduce.
@@ -206,9 +206,17 @@ class SpmdPartitioner : public HloModulePass {
   // all-gather.
   virtual HloInstruction* AllGatherShards(
       SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
-      int64 channel_id, absl::Span<const int64> selected_dims,
+      int64* next_channel_id, absl::Span<const int64> selected_dims,
       const SPMDCollectiveOpsCreator& collectives_creator);
+  // Creates all-reduce(s) across devices along selected_dims in sharding. Can
+  // be overridden to customize.
+  virtual HloInstruction* AllReduceAlongShardingDims(
+      SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
+      int64* next_channel_id, absl::Span<const int64> selected_dims,
+      const SPMDCollectiveOpsCreator& collectives_creator,
+      HloComputation* reduction);
   const SpmdPartitionerOptions& options() { return options_; }
 protected:
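
The switch from int64 channel_id to int64* next_channel_id in these virtual hooks is what allows a single call to emit several collectives, each consuming a fresh channel id.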

View File

@@ -3705,7 +3705,7 @@ ENTRY entry {
       op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(),
                                 op::Constant(), op::Reshape(), op::Reshape())),
       op::Shape("f32[32,39296,32,64]"));
-  EXPECT_THAT(root, AllOf(op::AllReduce(op::Dot(lhs, rhs)),
+  EXPECT_THAT(root, AllOf(op::AllReduce(op::AllReduce(op::Dot(lhs, rhs))),
                           op::Shape("f32[32,24,39296]")));
 }
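
The test expectations change accordingly: the lowering now produces one all-reduce per sharded mesh dimension, hence the nested op::AllReduce matchers here and in the scatter and gather tests below.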
@@ -4662,10 +4662,10 @@ ENTRY entry {
   HloInstruction* root = module->entry_computation()->root_instruction();
   EXPECT_THAT(
       root,
-      AllOf(op::AllReduce(op::Scatter(
+      AllOf(op::AllReduce(op::AllReduce(op::Scatter(
                 op::Select(op::Broadcast(op::Convert(op::PartitionId())),
                            op::Broadcast(op::Constant()), op::Parameter(0)),
-                op::Parameter(1), op::Parameter(2))),
+                op::Parameter(1), op::Parameter(2)))),
             op::Shape("f32[2,9,8]")));
 }
@@ -4698,10 +4698,10 @@ ENTRY entry {
   HloInstruction* root = module->entry_computation()->root_instruction();
   EXPECT_THAT(
       root,
-      AllOf(op::AllReduce(op::Scatter(
+      AllOf(op::AllReduce(op::AllReduce(op::Scatter(
                 op::Select(op::Broadcast(op::Convert(op::Reshape())),
                            op::Broadcast(op::Constant()), op::Parameter(0)),
-                op::Parameter(1), op::Parameter(2))),
+                op::Parameter(1), op::Parameter(2)))),
             op::Shape("f32[2,9,8]")));
 }
@@ -4732,10 +4732,10 @@ ENTRY entry {
   HloInstruction* root = module->entry_computation()->root_instruction();
   EXPECT_THAT(
       root,
-      AllOf(op::AllReduce(op::Scatter(
+      AllOf(op::AllReduce(op::AllReduce(op::Scatter(
                 op::Select(op::Broadcast(op::Convert(op::PartitionId())),
                            op::Broadcast(op::Constant()), op::Parameter(0)),
-                op::Parameter(1), op::Parameter(2))),
+                op::Parameter(1), op::Parameter(2)))),
             op::Shape("f32[2,9,8]")));
 }
@@ -6733,8 +6733,8 @@ ENTRY %module {
   auto operand = AllOf(op::Shape("s32[2,2,2,2]"), op::DynamicSlice());
   auto indices = AllOf(op::Shape("s32[2,2,2]"), op::Subtract());
   auto gather = AllOf(op::Shape("s32[2,2,2,2]"), op::Gather(operand, indices));
-  EXPECT_THAT(root,
-              op::AllReduce(op::DynamicUpdateSlice(_, gather, _, _, _, _)));
+  EXPECT_THAT(root, op::AllReduce(op::AllReduce(
+                        op::DynamicUpdateSlice(_, gather, _, _, _, _))));
 }
 
 TEST_F(SpmdPartitioningTest, GatherParallelDimReplicatedIndices) {
@@ -6887,8 +6887,8 @@ ENTRY %module {
   auto operand = AllOf(op::Shape("s32[4,1,2,2]"), op::CollectivePermute());
   auto indices = AllOf(op::Shape("s32[2,4,1]"), op::Subtract());
   auto gather = AllOf(op::Shape("s32[4,1,2,2]"), op::Gather(operand, indices));
-  EXPECT_THAT(root,
-              op::AllReduce(op::DynamicUpdateSlice(_, gather, _, _, _, _)));
+  EXPECT_THAT(root, op::AllReduce(op::AllReduce(
+                        op::DynamicUpdateSlice(_, gather, _, _, _, _))));
 }
 
 TEST_F(SpmdPartitioningTest, GatherMergedParalleIndexPassthrough) {