[XLA:SPMD] Halo exchange beyond direct neighbors
PiperOrigin-RevId: 313471820 Change-Id: Ie0c4fa412dff534ebae462726dd880c2e7093d40
This commit is contained in:
parent
d807bfcb64
commit
ad8a4e1bda
@ -649,6 +649,43 @@ ENTRY entry {
|
|||||||
op::ReduceWindow(masked, op::Constant())));
|
op::ReduceWindow(masked, op::Constant())));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(SpmdPartitioningTest, ReduceWindowTiledOneSideHaloBeyondNeighbor) {
|
||||||
|
const char* const hlo_string = R"(
|
||||||
|
HloModule module
|
||||||
|
|
||||||
|
sum {
|
||||||
|
a = f32[] parameter(0)
|
||||||
|
b = f32[] parameter(1)
|
||||||
|
ROOT add = f32[] add(a, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY entry {
|
||||||
|
param = f32[9,2] parameter(0), sharding={devices=[5,1]0,1,2,3,4}
|
||||||
|
constant.1 = f32[] constant(0), sharding={replicated}
|
||||||
|
ROOT reduce-window = f32[5,2]{1,0} reduce-window(param, constant.1),
|
||||||
|
window={size=4x1 stride=2x1 pad=3_0x0_0}, to_apply=sum,
|
||||||
|
sharding={devices=[5,1]0,1,2,3,4}
|
||||||
|
})";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
|
PartitionComputation(hlo_string, /*num_devices=*/5));
|
||||||
|
VLOG(1) << module->ToString();
|
||||||
|
auto halo0 = AllOf(op::Shape("f32[1,2]"),
|
||||||
|
op::CollectivePermute(op::Slice(op::Parameter(0))));
|
||||||
|
auto halo1 =
|
||||||
|
AllOf(op::Shape("f32[2,2]"), op::CollectivePermute(op::Parameter(0)));
|
||||||
|
auto pre_mask =
|
||||||
|
AllOf(op::Shape("f32[4,2]"),
|
||||||
|
op::Slice(AllOf(op::Shape("f32[5,2]"),
|
||||||
|
op::Concatenate(halo0, halo1, op::Parameter(0)))));
|
||||||
|
auto masked =
|
||||||
|
op::Select(op::Compare(op::Add(op::Iota(), op::Broadcast(op::Multiply())),
|
||||||
|
op::Broadcast(op::Constant())),
|
||||||
|
pre_mask, op::Broadcast(op::Constant()));
|
||||||
|
HloInstruction* root = module->entry_computation()->root_instruction();
|
||||||
|
EXPECT_THAT(root, AllOf(op::Shape("f32[1,2]{1,0}"),
|
||||||
|
op::ReduceWindow(masked, op::Constant())));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(SpmdPartitioningTest, ReduceWindowTiledOneSideUnequalHalo) {
|
TEST_F(SpmdPartitioningTest, ReduceWindowTiledOneSideUnequalHalo) {
|
||||||
const char* const hlo_string = R"(
|
const char* const hlo_string = R"(
|
||||||
HloModule module
|
HloModule module
|
||||||
|
@ -15,6 +15,8 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
|
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
#include "absl/types/optional.h"
|
#include "absl/types/optional.h"
|
||||||
#include "tensorflow/compiler/xla/literal_util.h"
|
#include "tensorflow/compiler/xla/literal_util.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
@ -23,6 +25,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
|
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
|
||||||
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
|
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
@ -407,33 +410,30 @@ absl::optional<HloInstruction*> ExchangeHalo(
|
|||||||
std::vector<HloInstruction*> concat_pieces;
|
std::vector<HloInstruction*> concat_pieces;
|
||||||
|
|
||||||
int64 max_left_halo_size = left_halo_size_function.MaxInRange(1, shard_count);
|
int64 max_left_halo_size = left_halo_size_function.MaxInRange(1, shard_count);
|
||||||
if (max_left_halo_size > input_shard_size) {
|
for (int64 i = CeilOfRatio(max_left_halo_size, input_shard_size) - 1; i >= 0;
|
||||||
VLOG(1) << "ExchangeHalo failed: halo is beyond the left neighbor.";
|
--i) {
|
||||||
return absl::nullopt;
|
|
||||||
}
|
|
||||||
if (max_left_halo_size > 0) {
|
|
||||||
std::vector<std::pair<int64, int64>> source_target_pairs;
|
std::vector<std::pair<int64, int64>> source_target_pairs;
|
||||||
target.tile_assignment().Each(
|
target.tile_assignment().Each(
|
||||||
[&](absl::Span<const int64> indices, int64 device) {
|
[&](absl::Span<const int64> indices, int64 device) {
|
||||||
if (indices[dim] > 0) {
|
if (indices[dim] > i) {
|
||||||
std::vector<int64> source_indices(indices.begin(), indices.end());
|
std::vector<int64> source_indices(indices.begin(), indices.end());
|
||||||
source_indices[dim] -= 1;
|
source_indices[dim] -= i + 1;
|
||||||
source_target_pairs.emplace_back(
|
source_target_pairs.emplace_back(
|
||||||
target.tile_assignment()(source_indices), device);
|
target.tile_assignment()(source_indices), device);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
int64 halo_size =
|
||||||
|
std::min(max_left_halo_size - input_shard_size * i, input_shard_size);
|
||||||
auto halo_shape = hlo->shape();
|
auto halo_shape = hlo->shape();
|
||||||
auto source_halo_slice = hlo;
|
auto source_halo_slice = hlo;
|
||||||
if (max_left_halo_size != hlo->shape().dimensions(dim)) {
|
if (halo_size != hlo->shape().dimensions(dim)) {
|
||||||
halo_shape.set_dimensions(dim, max_left_halo_size);
|
halo_shape.set_dimensions(dim, halo_size);
|
||||||
std::vector<int64> halo_start_indices(halo_shape.rank(), 0);
|
std::vector<int64> halo_start_indices(halo_shape.rank(), 0);
|
||||||
halo_start_indices[dim] =
|
halo_start_indices[dim] = hlo->shape().dimensions(dim) - halo_size;
|
||||||
hlo->shape().dimensions(dim) - max_left_halo_size;
|
|
||||||
std::vector<int64> halo_slice_strides(halo_shape.rank(), 1);
|
std::vector<int64> halo_slice_strides(halo_shape.rank(), 1);
|
||||||
|
source_halo_slice = b->AddInstruction(HloInstruction::CreateSlice(
|
||||||
source_halo_slice = b->AddInstruction(
|
halo_shape, hlo, halo_start_indices, hlo->shape().dimensions(),
|
||||||
hlo->CreateSlice(halo_shape, hlo, halo_start_indices,
|
halo_slice_strides));
|
||||||
hlo->shape().dimensions(), halo_slice_strides));
|
|
||||||
}
|
}
|
||||||
auto left_halo =
|
auto left_halo =
|
||||||
collective_ops_creator.create_cross_partition_collective_permute(
|
collective_ops_creator.create_cross_partition_collective_permute(
|
||||||
@ -446,29 +446,30 @@ absl::optional<HloInstruction*> ExchangeHalo(
|
|||||||
// Right halo.
|
// Right halo.
|
||||||
int64 max_right_halo_size =
|
int64 max_right_halo_size =
|
||||||
right_halo_size_function.MaxInRange(0, shard_count - 1);
|
right_halo_size_function.MaxInRange(0, shard_count - 1);
|
||||||
if (max_right_halo_size > input_shard_size) {
|
for (int64 i = 0; i < CeilOfRatio(max_right_halo_size, input_shard_size);
|
||||||
VLOG(1) << "ExchangeHalo failed: halo is beyond the right neighbor.";
|
++i) {
|
||||||
return absl::nullopt;
|
|
||||||
}
|
|
||||||
if (max_right_halo_size > 0) {
|
|
||||||
std::vector<std::pair<int64, int64>> source_target_pairs;
|
std::vector<std::pair<int64, int64>> source_target_pairs;
|
||||||
target.tile_assignment().Each(
|
target.tile_assignment().Each(
|
||||||
[&](absl::Span<const int64> indices, int64 device) {
|
[&](absl::Span<const int64> indices, int64 device) {
|
||||||
if (indices[dim] > 0) {
|
if (indices[dim] > i) {
|
||||||
std::vector<int64> target_indices(indices.begin(), indices.end());
|
std::vector<int64> target_indices(indices.begin(), indices.end());
|
||||||
target_indices[dim] -= 1;
|
target_indices[dim] -= i + 1;
|
||||||
source_target_pairs.emplace_back(
|
source_target_pairs.emplace_back(
|
||||||
device, target.tile_assignment()(target_indices));
|
device, target.tile_assignment()(target_indices));
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
int64 halo_size =
|
||||||
|
std::min(max_right_halo_size - input_shard_size * i, input_shard_size);
|
||||||
auto halo_shape = hlo->shape();
|
auto halo_shape = hlo->shape();
|
||||||
halo_shape.set_dimensions(dim, max_right_halo_size);
|
HloInstruction* source_halo_slice = hlo;
|
||||||
std::vector<int64> halo_start_indices(halo_shape.rank(), 0);
|
if (halo_size != halo_shape.dimensions(dim)) {
|
||||||
std::vector<int64> halo_slice_strides(halo_shape.rank(), 1);
|
halo_shape.set_dimensions(dim, halo_size);
|
||||||
|
std::vector<int64> halo_start_indices(halo_shape.rank(), 0);
|
||||||
auto source_halo_slice = b->AddInstruction(
|
std::vector<int64> halo_slice_strides(halo_shape.rank(), 1);
|
||||||
hlo->CreateSlice(halo_shape, hlo, halo_start_indices,
|
source_halo_slice = b->AddInstruction(HloInstruction::CreateSlice(
|
||||||
halo_shape.dimensions(), halo_slice_strides));
|
halo_shape, hlo, halo_start_indices, halo_shape.dimensions(),
|
||||||
|
halo_slice_strides));
|
||||||
|
}
|
||||||
auto right_halo =
|
auto right_halo =
|
||||||
collective_ops_creator.create_cross_partition_collective_permute(
|
collective_ops_creator.create_cross_partition_collective_permute(
|
||||||
b, source_halo_slice, source_target_pairs, (*next_channel_id)++);
|
b, source_halo_slice, source_target_pairs, (*next_channel_id)++);
|
||||||
|
Loading…
Reference in New Issue
Block a user