From ad8a4e1bdaf172a6490552c0bb072207fbaf08c0 Mon Sep 17 00:00:00 2001 From: Yuanzhong Xu Date: Wed, 27 May 2020 15:06:00 -0700 Subject: [PATCH] [XLA:SPMD] Halo exchange beyond direct neighbors PiperOrigin-RevId: 313471820 Change-Id: Ie0c4fa412dff534ebae462726dd880c2e7093d40 --- .../xla/service/spmd/spmd_partitioner_test.cc | 37 ++++++++++++ .../xla/service/spmd/spmd_partitioner_util.cc | 59 ++++++++++--------- 2 files changed, 67 insertions(+), 29 deletions(-) diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc index 55d7dc43785..e766695385b 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc @@ -649,6 +649,43 @@ ENTRY entry { op::ReduceWindow(masked, op::Constant()))); } +TEST_F(SpmdPartitioningTest, ReduceWindowTiledOneSideHaloBeyondNeighbor) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + param = f32[9,2] parameter(0), sharding={devices=[5,1]0,1,2,3,4} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT reduce-window = f32[5,2]{1,0} reduce-window(param, constant.1), + window={size=4x1 stride=2x1 pad=3_0x0_0}, to_apply=sum, + sharding={devices=[5,1]0,1,2,3,4} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/5)); + VLOG(1) << module->ToString(); + auto halo0 = AllOf(op::Shape("f32[1,2]"), + op::CollectivePermute(op::Slice(op::Parameter(0)))); + auto halo1 = + AllOf(op::Shape("f32[2,2]"), op::CollectivePermute(op::Parameter(0))); + auto pre_mask = + AllOf(op::Shape("f32[4,2]"), + op::Slice(AllOf(op::Shape("f32[5,2]"), + op::Concatenate(halo0, halo1, op::Parameter(0))))); + auto masked = + op::Select(op::Compare(op::Add(op::Iota(), op::Broadcast(op::Multiply())), + op::Broadcast(op::Constant())), + pre_mask, op::Broadcast(op::Constant())); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[1,2]{1,0}"), + op::ReduceWindow(masked, op::Constant()))); +} + TEST_F(SpmdPartitioningTest, ReduceWindowTiledOneSideUnequalHalo) { const char* const hlo_string = R"( HloModule module diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc index 57617b59ffb..8db2ca84a05 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h" +#include + #include "absl/types/optional.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -23,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_sharding.h" #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -407,33 +410,30 @@ absl::optional ExchangeHalo( std::vector concat_pieces; int64 max_left_halo_size = left_halo_size_function.MaxInRange(1, shard_count); - if (max_left_halo_size > input_shard_size) { - VLOG(1) << "ExchangeHalo failed: halo is beyond the left neighbor."; - return absl::nullopt; - } - if (max_left_halo_size > 0) { + for (int64 i = CeilOfRatio(max_left_halo_size, input_shard_size) - 1; i >= 0; + --i) { std::vector> source_target_pairs; target.tile_assignment().Each( [&](absl::Span indices, int64 device) { - if (indices[dim] > 0) { + if (indices[dim] > i) { std::vector source_indices(indices.begin(), indices.end()); - source_indices[dim] -= 1; + source_indices[dim] -= i + 1; source_target_pairs.emplace_back( target.tile_assignment()(source_indices), device); } }); + int64 halo_size = + std::min(max_left_halo_size - input_shard_size * i, input_shard_size); auto halo_shape = hlo->shape(); auto source_halo_slice = hlo; - if (max_left_halo_size != hlo->shape().dimensions(dim)) { - halo_shape.set_dimensions(dim, max_left_halo_size); + if (halo_size != hlo->shape().dimensions(dim)) { + halo_shape.set_dimensions(dim, halo_size); std::vector halo_start_indices(halo_shape.rank(), 0); - halo_start_indices[dim] = - hlo->shape().dimensions(dim) - max_left_halo_size; + halo_start_indices[dim] = hlo->shape().dimensions(dim) - halo_size; std::vector halo_slice_strides(halo_shape.rank(), 1); - - source_halo_slice = b->AddInstruction( - hlo->CreateSlice(halo_shape, hlo, halo_start_indices, - hlo->shape().dimensions(), halo_slice_strides)); + source_halo_slice = b->AddInstruction(HloInstruction::CreateSlice( + halo_shape, hlo, halo_start_indices, hlo->shape().dimensions(), + halo_slice_strides)); } auto left_halo = collective_ops_creator.create_cross_partition_collective_permute( @@ -446,29 +446,30 @@ absl::optional ExchangeHalo( // Right halo. int64 max_right_halo_size = right_halo_size_function.MaxInRange(0, shard_count - 1); - if (max_right_halo_size > input_shard_size) { - VLOG(1) << "ExchangeHalo failed: halo is beyond the right neighbor."; - return absl::nullopt; - } - if (max_right_halo_size > 0) { + for (int64 i = 0; i < CeilOfRatio(max_right_halo_size, input_shard_size); + ++i) { std::vector> source_target_pairs; target.tile_assignment().Each( [&](absl::Span indices, int64 device) { - if (indices[dim] > 0) { + if (indices[dim] > i) { std::vector target_indices(indices.begin(), indices.end()); - target_indices[dim] -= 1; + target_indices[dim] -= i + 1; source_target_pairs.emplace_back( device, target.tile_assignment()(target_indices)); } }); + int64 halo_size = + std::min(max_right_halo_size - input_shard_size * i, input_shard_size); auto halo_shape = hlo->shape(); - halo_shape.set_dimensions(dim, max_right_halo_size); - std::vector halo_start_indices(halo_shape.rank(), 0); - std::vector halo_slice_strides(halo_shape.rank(), 1); - - auto source_halo_slice = b->AddInstruction( - hlo->CreateSlice(halo_shape, hlo, halo_start_indices, - halo_shape.dimensions(), halo_slice_strides)); + HloInstruction* source_halo_slice = hlo; + if (halo_size != halo_shape.dimensions(dim)) { + halo_shape.set_dimensions(dim, halo_size); + std::vector halo_start_indices(halo_shape.rank(), 0); + std::vector halo_slice_strides(halo_shape.rank(), 1); + source_halo_slice = b->AddInstruction(HloInstruction::CreateSlice( + halo_shape, hlo, halo_start_indices, halo_shape.dimensions(), + halo_slice_strides)); + } auto right_halo = collective_ops_creator.create_cross_partition_collective_permute( b, source_halo_slice, source_target_pairs, (*next_channel_id)++);