[XLA:SPMD] More cases of reverse sharding
- Improve sharding propagation to reverse the tile assignment
- Use reshard (collective permute) to fix mismatched operand sharding
- Use halo exchange to fix uneven partitioning

PiperOrigin-RevId: 313672162
Change-Id: I0816de794a0c18a0173889ed8cd638baecf389e9
commit 5c77174291
parent b1cb3f12da
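The first bullet is the key mechanism: reversing a tiled dimension only mirrors the tile assignment along that dimension; each partition keeps the same amount of local data, and only the device order flips. Below is a minimal standalone sketch of that index arithmetic using plain C++ vectors instead of XLA's HloSharding/Array types; the name ReverseTileAssignment and the rows/cols parameters are illustrative, not part of the commit.

#include <cstdio>
#include <vector>

// Mirrors the commit's ReverseSharding index arithmetic on a row-major
// rows x cols grid of device ids: along every reversed dimension, the
// source index is dim - 1 - index.
std::vector<int> ReverseTileAssignment(const std::vector<int>& tiles,
                                       int rows, int cols,
                                       bool reverse_rows, bool reverse_cols) {
  std::vector<int> result(tiles.size());
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      const int src_r = reverse_rows ? rows - 1 - r : r;
      const int src_c = reverse_cols ? cols - 1 - c : c;
      result[r * cols + c] = tiles[src_r * cols + src_c];
    }
  }
  return result;
}

int main() {
  // {devices=[2,2]0,1,2,3} reversed along dimension 0 -> {devices=[2,2]2,3,0,1}.
  for (int d : ReverseTileAssignment({0, 1, 2, 3}, 2, 2,
                                     /*reverse_rows=*/true,
                                     /*reverse_cols=*/false)) {
    std::printf("%d ", d);  // prints: 2 3 0 1
  }
  std::printf("\n");
  return 0;
}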
@@ -220,6 +220,24 @@ absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
   return HloSharding::Tile(new_tile_assignment);
 }
 
+HloSharding ReverseSharding(const HloSharding& sharding,
+                            absl::Span<const int64> dimensions) {
+  if (sharding.IsTileMaximal() || dimensions.empty()) {
+    return sharding;
+  }
+
+  Array<int64> new_tile_assignment(sharding.tile_assignment().dimensions());
+  new_tile_assignment.Each([&](absl::Span<const int64> indices, int64* device) {
+    std::vector<int64> original_indices(indices.begin(), indices.end());
+    for (int64 d : dimensions) {
+      original_indices[d] =
+          new_tile_assignment.dim(d) - 1 - original_indices[d];
+    }
+    *device = sharding.tile_assignment()(original_indices);
+  });
+  return HloSharding::Tile(new_tile_assignment);
+}
+
 HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64 dim,
                                    absl::Span<const int64> dims) {
   CHECK(!sharding.IsTuple() && !sharding.IsTileMaximal());
@@ -70,6 +70,12 @@ absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
                                             const Shape& target_shape,
                                             const HloSharding& sharding);
 
+// Returns the HloSharding with the tile dimensions and tile assignment
+// reversed based on the specified dimension numbers. In case of a tile
+// maximal sharding returns the original sharding.
+HloSharding ReverseSharding(const HloSharding& sharding,
+                            absl::Span<const int64> dimensions);
+
 // Returns a sharding tiled on unique dimension dim by reshaping the tile
 // assignment of the sharding argument. Only dimensions in the dims span
 // argument are considered for reshaping, the others are ignored.
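For instance, given a 2x2 tile assignment {devices=[2,2]0,1,2,3}, ReverseSharding along dimension 0 would produce {devices=[2,2]2,3,0,1}, while a replicated (tile-maximal) sharding or an empty dimension list returns the input unchanged, as the implementation above shows.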
@@ -717,6 +717,15 @@ bool InferShardingFromOperands(HloInstruction* instruction,
       }
       return false;
     }
+    case HloOpcode::kReverse: {
+      if (!IsSpatiallyPartitioned(instruction->operand(0))) {
+        return false;
+      }
+      return MaybeImproveInstructionSharding(
+          hlo_sharding_util::ReverseSharding(
+              instruction->operand(0)->sharding(), instruction->dimensions()),
+          instruction);
+    }
     case HloOpcode::kDot: {
       auto& dot_dim_numbs = instruction->dot_dimension_numbers();
       // Batch dimensions are the same for lhs and rhs on dot operations.
@@ -1188,6 +1197,10 @@ absl::optional<HloSharding> GetShardingFromUser(
         return user.sharding();
       }
     }
+    case HloOpcode::kReverse: {
+      return hlo_sharding_util::ReverseSharding(user.sharding(),
+                                                user.dimensions());
+    }
     default: {
       // If the user output shape is compatible with the current instruction
       // shape excluding element type and the current instruction is supported
@@ -2325,18 +2325,44 @@ Status SpmdPartitioningVisitor::HandleReverse(HloInstruction* hlo) {
   if (reverse->sharding().IsTileMaximal()) {
     return DefaultAction(hlo);
  }
-  if (absl::c_all_of(reverse->dimensions(), [&](int64 d) {
-        return reverse->sharding().tile_assignment().dim(d) == 1;
-      })) {
-    auto operand =
-        GetPartitionedHlo(reverse->operand(0)).Reshard(reverse->sharding());
+  auto operand = GetPartitionedHlo(reverse->operand(0))
+                     .Reshard(hlo_sharding_util::ReverseSharding(
+                         reverse->sharding(), reverse->dimensions()));
+  // Create a window config to halo exchange for unevenly partitioned reverse
+  // dimensions.
+  Window window;
+  for (int64 i = 0; i < hlo->shape().rank(); ++i) {
+    WindowDimension* dim = window.add_dimensions();
+    dim->set_size(1);
+    dim->set_stride(1);
+    dim->set_window_dilation(1);
+    dim->set_window_reversal(false);
+    int64 low_padding = 0;
+    if (absl::c_linear_search(reverse->dimensions(), i)) {
+      low_padding =
+          RoundUpToNearest(reverse->shape().dimensions(i),
+                           reverse->sharding().tile_assignment().dim(i)) -
+          reverse->shape().dimensions(i);
+    }
+    dim->set_padding_low(low_padding);
+    dim->set_padding_high(0);
+    dim->set_base_dilation(1);
+  }
+
+  auto reshard_operand = operand.ReshardAsWindowedInput(
+      window, operand.sharding(),
+      CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_),
+      /*mask_invalid_region=*/false);
+  if (!reshard_operand.has_value()) {
+    return DefaultAction(hlo);
+  }
+  TF_RET_CHECK(!reshard_operand->dynamic_slice_index_on_output.has_value());
   SetPartitionedHlo(hlo, [&] {
     return b_.AddInstruction(
-        hlo->CloneWithNewOperands(operand.hlo()->shape(), {operand.hlo()}));
+        hlo->CloneWithNewOperands(reshard_operand->sharded_input->shape(),
+                                  {reshard_operand->sharded_input}));
   });
   return Status::OK();
 }
-  return DefaultAction(hlo);
-}
 
 Status SpmdPartitioningVisitor::HandleWhile(HloInstruction* hlo) {
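A note on the low padding computed above: a reversed dimension that does not divide evenly by its shard count has implicit padding at the high end of the full shape, and reversal moves that slack to the low end. For f32[3] split across 2 devices, RoundUpToNearest(3, 2) - 3 = 1, so ReshardAsWindowedInput performs a one-element halo exchange to realign shard boundaries before the local reverse; this is exactly what the TiledReverseHaloExchange test below checks.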
@@ -3212,7 +3212,7 @@ ENTRY entry {
                           op::Shape("f32[9,9]")));
 }
 
-TEST_F(SpmdPartitioningTest, TiledReverse) {
+TEST_F(SpmdPartitioningTest, TiledReversePassthrough) {
   const char* const hlo_string = R"(
 HloModule module
 
@@ -3232,6 +3232,62 @@ ENTRY entry {
                                  op::Reshape(), op::Constant()))));
 }
 
+TEST_F(SpmdPartitioningTest, TiledReversePassthroughViaReversedSharding) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  param = f32[4] parameter(0), sharding={devices=[2]0,1}
+  ROOT reverse = f32[4] reverse(param), dimensions={0},
+    sharding={devices=[2]1,0}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/2));
+  VLOG(1) << module->ToString();
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, AllOf(op::Shape("f32[2]"), op::Reverse(op::Parameter(0))));
+}
+
+TEST_F(SpmdPartitioningTest, TiledReverseSwapShards) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  param = f32[4] parameter(0), sharding={devices=[2]0,1}
+  ROOT reverse = f32[4] reverse(param), dimensions={0},
+    sharding={devices=[2]0,1}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/2));
+  VLOG(1) << module->ToString();
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root,
+              AllOf(op::Shape("f32[2]"),
+                    op::Reverse(op::CollectivePermute(op::Parameter(0)))));
+}
+
+TEST_F(SpmdPartitioningTest, TiledReverseHaloExchange) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  param = f32[3] parameter(0), sharding={devices=[2]0,1}
+  ROOT reverse = f32[3] reverse(param), dimensions={0},
+    sharding={devices=[2]1,0}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/2));
+  VLOG(1) << module->ToString();
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  auto halo_exchange_concat =
+      op::Concatenate(AllOf(op::Shape("f32[1]"),
+                            op::CollectivePermute(op::Slice(op::Parameter(0)))),
+                      op::Parameter(0));
+  auto after_halo_exchange = op::Slice(halo_exchange_concat);
+  EXPECT_THAT(root,
+              AllOf(op::Shape("f32[2]"), op::Reverse(after_halo_exchange)));
+}
+
 TEST_F(SpmdPartitioningTest, MixWithManualPartitioning) {
   const char* const hlo_string = R"(
 HloModule module
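The three new tests line up with the three cases in the commit message: TiledReversePassthroughViaReversedSharding needs no data movement because the operand's sharding is already the reverse of the output's, TiledReverseSwapShards inserts a collective permute to fix the mismatched operand sharding, and TiledReverseHaloExchange handles the unevenly partitioned f32[3] case with a halo exchange.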