[XLA:SPMD] Halo exchange beyond direct neighbors

PiperOrigin-RevId: 313471820
Change-Id: Ie0c4fa412dff534ebae462726dd880c2e7093d40
This commit is contained in:
Yuanzhong Xu 2020-05-27 15:06:00 -07:00 committed by TensorFlower Gardener
parent d807bfcb64
commit ad8a4e1bda
2 changed files with 67 additions and 29 deletions

View File

@ -649,6 +649,43 @@ ENTRY entry {
op::ReduceWindow(masked, op::Constant()))); op::ReduceWindow(masked, op::Constant())));
} }
// Verifies one-sided halo exchange when the required halo extends BEYOND the
// immediate neighbor partition: the f32[9,2] operand is tiled over 5 devices
// (2 rows per shard), while the reduce-window's left padding of 3 means a
// shard may need up to 3 rows of halo — more than one neighbor shard holds.
TEST_F(SpmdPartitioningTest, ReduceWindowTiledOneSideHaloBeyondNeighbor) {
const char* const hlo_string = R"(
HloModule module
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add = f32[] add(a, b)
}
ENTRY entry {
param = f32[9,2] parameter(0), sharding={devices=[5,1]0,1,2,3,4}
constant.1 = f32[] constant(0), sharding={replicated}
ROOT reduce-window = f32[5,2]{1,0} reduce-window(param, constant.1),
window={size=4x1 stride=2x1 pad=3_0x0_0}, to_apply=sum,
sharding={devices=[5,1]0,1,2,3,4}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/5));
VLOG(1) << module->ToString();
// Expect TWO halo transfers on the left: a 1-row sliced piece (halo0,
// presumably from the partition two steps away — the farthest piece comes
// first in the concat) and a full 2-row shard (halo1) from the immediate
// left neighbor.
auto halo0 = AllOf(op::Shape("f32[1,2]"),
op::CollectivePermute(op::Slice(op::Parameter(0))));
auto halo1 =
AllOf(op::Shape("f32[2,2]"), op::CollectivePermute(op::Parameter(0)));
// Halos + local shard are concatenated to 5 rows, then sliced down to the
// 4-row window input each shard actually needs.
auto pre_mask =
AllOf(op::Shape("f32[4,2]"),
op::Slice(AllOf(op::Shape("f32[5,2]"),
op::Concatenate(halo0, halo1, op::Parameter(0)))));
// Rows that correspond to padding (outside the original [0, 9) row range)
// are masked to the init value via an iota/partition-offset comparison.
auto masked =
op::Select(op::Compare(op::Add(op::Iota(), op::Broadcast(op::Multiply())),
op::Broadcast(op::Constant())),
pre_mask, op::Broadcast(op::Constant()));
HloInstruction* root = module->entry_computation()->root_instruction();
// Each shard produces 1 of the 5 output rows.
EXPECT_THAT(root, AllOf(op::Shape("f32[1,2]{1,0}"),
op::ReduceWindow(masked, op::Constant())));
}
TEST_F(SpmdPartitioningTest, ReduceWindowTiledOneSideUnequalHalo) { TEST_F(SpmdPartitioningTest, ReduceWindowTiledOneSideUnequalHalo) {
const char* const hlo_string = R"( const char* const hlo_string = R"(
HloModule module HloModule module

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h" #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
#include <algorithm>
#include "absl/types/optional.h" #include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_computation.h"
@ -23,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_sharding.h" #include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h" #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla { namespace xla {
@ -407,33 +410,30 @@ absl::optional<HloInstruction*> ExchangeHalo(
std::vector<HloInstruction*> concat_pieces; std::vector<HloInstruction*> concat_pieces;
int64 max_left_halo_size = left_halo_size_function.MaxInRange(1, shard_count); int64 max_left_halo_size = left_halo_size_function.MaxInRange(1, shard_count);
if (max_left_halo_size > input_shard_size) { for (int64 i = CeilOfRatio(max_left_halo_size, input_shard_size) - 1; i >= 0;
VLOG(1) << "ExchangeHalo failed: halo is beyond the left neighbor."; --i) {
return absl::nullopt;
}
if (max_left_halo_size > 0) {
std::vector<std::pair<int64, int64>> source_target_pairs; std::vector<std::pair<int64, int64>> source_target_pairs;
target.tile_assignment().Each( target.tile_assignment().Each(
[&](absl::Span<const int64> indices, int64 device) { [&](absl::Span<const int64> indices, int64 device) {
if (indices[dim] > 0) { if (indices[dim] > i) {
std::vector<int64> source_indices(indices.begin(), indices.end()); std::vector<int64> source_indices(indices.begin(), indices.end());
source_indices[dim] -= 1; source_indices[dim] -= i + 1;
source_target_pairs.emplace_back( source_target_pairs.emplace_back(
target.tile_assignment()(source_indices), device); target.tile_assignment()(source_indices), device);
} }
}); });
int64 halo_size =
std::min(max_left_halo_size - input_shard_size * i, input_shard_size);
auto halo_shape = hlo->shape(); auto halo_shape = hlo->shape();
auto source_halo_slice = hlo; auto source_halo_slice = hlo;
if (max_left_halo_size != hlo->shape().dimensions(dim)) { if (halo_size != hlo->shape().dimensions(dim)) {
halo_shape.set_dimensions(dim, max_left_halo_size); halo_shape.set_dimensions(dim, halo_size);
std::vector<int64> halo_start_indices(halo_shape.rank(), 0); std::vector<int64> halo_start_indices(halo_shape.rank(), 0);
halo_start_indices[dim] = halo_start_indices[dim] = hlo->shape().dimensions(dim) - halo_size;
hlo->shape().dimensions(dim) - max_left_halo_size;
std::vector<int64> halo_slice_strides(halo_shape.rank(), 1); std::vector<int64> halo_slice_strides(halo_shape.rank(), 1);
source_halo_slice = b->AddInstruction(HloInstruction::CreateSlice(
source_halo_slice = b->AddInstruction( halo_shape, hlo, halo_start_indices, hlo->shape().dimensions(),
hlo->CreateSlice(halo_shape, hlo, halo_start_indices, halo_slice_strides));
hlo->shape().dimensions(), halo_slice_strides));
} }
auto left_halo = auto left_halo =
collective_ops_creator.create_cross_partition_collective_permute( collective_ops_creator.create_cross_partition_collective_permute(
@ -446,29 +446,30 @@ absl::optional<HloInstruction*> ExchangeHalo(
// Right halo. // Right halo.
int64 max_right_halo_size = int64 max_right_halo_size =
right_halo_size_function.MaxInRange(0, shard_count - 1); right_halo_size_function.MaxInRange(0, shard_count - 1);
if (max_right_halo_size > input_shard_size) { for (int64 i = 0; i < CeilOfRatio(max_right_halo_size, input_shard_size);
VLOG(1) << "ExchangeHalo failed: halo is beyond the right neighbor."; ++i) {
return absl::nullopt;
}
if (max_right_halo_size > 0) {
std::vector<std::pair<int64, int64>> source_target_pairs; std::vector<std::pair<int64, int64>> source_target_pairs;
target.tile_assignment().Each( target.tile_assignment().Each(
[&](absl::Span<const int64> indices, int64 device) { [&](absl::Span<const int64> indices, int64 device) {
if (indices[dim] > 0) { if (indices[dim] > i) {
std::vector<int64> target_indices(indices.begin(), indices.end()); std::vector<int64> target_indices(indices.begin(), indices.end());
target_indices[dim] -= 1; target_indices[dim] -= i + 1;
source_target_pairs.emplace_back( source_target_pairs.emplace_back(
device, target.tile_assignment()(target_indices)); device, target.tile_assignment()(target_indices));
} }
}); });
int64 halo_size =
std::min(max_right_halo_size - input_shard_size * i, input_shard_size);
auto halo_shape = hlo->shape(); auto halo_shape = hlo->shape();
halo_shape.set_dimensions(dim, max_right_halo_size); HloInstruction* source_halo_slice = hlo;
std::vector<int64> halo_start_indices(halo_shape.rank(), 0); if (halo_size != halo_shape.dimensions(dim)) {
std::vector<int64> halo_slice_strides(halo_shape.rank(), 1); halo_shape.set_dimensions(dim, halo_size);
std::vector<int64> halo_start_indices(halo_shape.rank(), 0);
auto source_halo_slice = b->AddInstruction( std::vector<int64> halo_slice_strides(halo_shape.rank(), 1);
hlo->CreateSlice(halo_shape, hlo, halo_start_indices, source_halo_slice = b->AddInstruction(HloInstruction::CreateSlice(
halo_shape.dimensions(), halo_slice_strides)); halo_shape, hlo, halo_start_indices, halo_shape.dimensions(),
halo_slice_strides));
}
auto right_halo = auto right_halo =
collective_ops_creator.create_cross_partition_collective_permute( collective_ops_creator.create_cross_partition_collective_permute(
b, source_halo_slice, source_target_pairs, (*next_channel_id)++); b, source_halo_slice, source_target_pairs, (*next_channel_id)++);