[XLA:SPMD] More cases of reverse sharding
- Improve sharding propagation to reverse the tile assignment.
- Use reshard (collective permute) to fix mismatched operand sharding.
- Use halo exchange to fix uneven partitioning.

PiperOrigin-RevId: 313672162
Change-Id: I0816de794a0c18a0173889ed8cd638baecf389e9
parent b1cb3f12da
commit 5c77174291
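For a concrete picture of the three cases: take an f32[4] operand sharded {devices=[2]0,1}, so device 0 holds elements 0-1 and device 1 holds elements 2-3. Reversing dimension 0 means the first half of the output comes from the operand's second half, so the natural output sharding is the mirrored tile assignment {devices=[2]1,0}. If the user instead requires {devices=[2]0,1}, the shards must physically swap devices, which a collective permute does. If the dimension size is not a multiple of the tile count (say f32[3] on two devices), input and output shard boundaries do not line up, and a halo exchange moves the boundary elements first.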
@@ -220,6 +220,24 @@ absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
   return HloSharding::Tile(new_tile_assignment);
 }
 
+HloSharding ReverseSharding(const HloSharding& sharding,
+                            absl::Span<const int64> dimensions) {
+  if (sharding.IsTileMaximal() || dimensions.empty()) {
+    return sharding;
+  }
+
+  Array<int64> new_tile_assignment(sharding.tile_assignment().dimensions());
+  new_tile_assignment.Each([&](absl::Span<const int64> indices, int64* device) {
+    std::vector<int64> original_indices(indices.begin(), indices.end());
+    for (int64 d : dimensions) {
+      original_indices[d] =
+          new_tile_assignment.dim(d) - 1 - original_indices[d];
+    }
+    *device = sharding.tile_assignment()(original_indices);
+  });
+  return HloSharding::Tile(new_tile_assignment);
+}
+
 HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64 dim,
                                    absl::Span<const int64> dims) {
   CHECK(!sharding.IsTuple() && !sharding.IsTileMaximal());
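As a standalone illustration of the index mirroring above, here is a minimal sketch in plain C++. The function name and the flat std::vector representation are ours for illustration, not XLA's Array<int64>:

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Simplified 1-D model: tile_assignment[i] is the device that owns tile i.
// Mirroring the assignment sends output tile i to the device that owns
// input tile (num_tiles - 1 - i), which is what ReverseSharding computes
// per reversed dimension.
std::vector<int64_t> ReverseTileAssignment1D(
    const std::vector<int64_t>& tile_assignment) {
  std::vector<int64_t> reversed(tile_assignment.size());
  for (size_t i = 0; i < tile_assignment.size(); ++i) {
    reversed[i] = tile_assignment[tile_assignment.size() - 1 - i];
  }
  return reversed;
}

int main() {
  // {devices=[2]0,1} mirrored along dim 0 becomes {devices=[2]1,0}.
  for (int64_t d : ReverseTileAssignment1D({0, 1})) std::cout << d << ' ';
  std::cout << '\n';  // prints: 1 0
}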
@@ -70,6 +70,12 @@ absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
                                             const Shape& target_shape,
                                             const HloSharding& sharding);
 
+// Returns the HloSharding with the tile dimensions and tile assignment
+// reversed based on the specified dimension numbers. In case of a tile
+// maximal sharding returns the original sharding.
+HloSharding ReverseSharding(const HloSharding& sharding,
+                            absl::Span<const int64> dimensions);
+
 // Returns a sharding tiled on unique dimension dim by reshaping the tile
 // assignment of the sharding argument. Only dimensions in the dims span
 // argument are considered for reshaping, the others are ignored.
@@ -717,6 +717,15 @@ bool InferShardingFromOperands(HloInstruction* instruction,
       }
       return false;
     }
+    case HloOpcode::kReverse: {
+      if (!IsSpatiallyPartitioned(instruction->operand(0))) {
+        return false;
+      }
+      return MaybeImproveInstructionSharding(
+          hlo_sharding_util::ReverseSharding(
+              instruction->operand(0)->sharding(), instruction->dimensions()),
+          instruction);
+    }
     case HloOpcode::kDot: {
       auto& dot_dim_numbs = instruction->dot_dimension_numbers();
       // Batch dimensions are the same for lhs and rhs on dot operations.
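Operand-to-result propagation through kReverse is sound because reverse only permutes data along the reversed dimensions: mirroring the operand's tile assignment along those same dimensions keeps every output tile on the device that already holds its data, so accepting the propagated sharding requires no communication.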
@@ -1188,6 +1197,10 @@ absl::optional<HloSharding> GetShardingFromUser(
         return user.sharding();
       }
     }
+    case HloOpcode::kReverse: {
+      return hlo_sharding_util::ReverseSharding(user.sharding(),
+                                                user.dimensions());
+    }
     default: {
       // If the user output shape is compatible with the current instruction
       // shape excluding element type and the current instruction is supported
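The user-to-operand direction reuses the same helper because ReverseSharding is an involution: mirroring a tile assignment twice along the same dimensions restores the original, so the operand sharding implied by a user's sharding is computed the same way.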
@@ -2325,18 +2325,44 @@ Status SpmdPartitioningVisitor::HandleReverse(HloInstruction* hlo) {
   if (reverse->sharding().IsTileMaximal()) {
     return DefaultAction(hlo);
   }
-  if (absl::c_all_of(reverse->dimensions(), [&](int64 d) {
-        return reverse->sharding().tile_assignment().dim(d) == 1;
-      })) {
-    auto operand =
-        GetPartitionedHlo(reverse->operand(0)).Reshard(reverse->sharding());
-    SetPartitionedHlo(hlo, [&] {
-      return b_.AddInstruction(
-          hlo->CloneWithNewOperands(operand.hlo()->shape(), {operand.hlo()}));
-    });
-    return Status::OK();
+  auto operand = GetPartitionedHlo(reverse->operand(0))
+                     .Reshard(hlo_sharding_util::ReverseSharding(
+                         reverse->sharding(), reverse->dimensions()));
+  // Create a window config to halo exchange for unevenly partitioned reverse
+  // dimensions.
+  Window window;
+  for (int64 i = 0; i < hlo->shape().rank(); ++i) {
+    WindowDimension* dim = window.add_dimensions();
+    dim->set_size(1);
+    dim->set_stride(1);
+    dim->set_window_dilation(1);
+    dim->set_window_reversal(false);
+    int64 low_padding = 0;
+    if (absl::c_linear_search(reverse->dimensions(), i)) {
+      low_padding =
+          RoundUpToNearest(reverse->shape().dimensions(i),
+                           reverse->sharding().tile_assignment().dim(i)) -
+          reverse->shape().dimensions(i);
+    }
+    dim->set_padding_low(low_padding);
+    dim->set_padding_high(0);
+    dim->set_base_dilation(1);
   }
-  return DefaultAction(hlo);
+
+  auto reshard_operand = operand.ReshardAsWindowedInput(
+      window, operand.sharding(),
+      CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_),
+      /*mask_invalid_region=*/false);
+  if (!reshard_operand.has_value()) {
+    return DefaultAction(hlo);
+  }
+  TF_RET_CHECK(!reshard_operand->dynamic_slice_index_on_output.has_value());
+  SetPartitionedHlo(hlo, [&] {
+    return b_.AddInstruction(
+        hlo->CloneWithNewOperands(reshard_operand->sharded_input->shape(),
+                                  {reshard_operand->sharded_input}));
+  });
+  return Status::OK();
 }
 
 Status SpmdPartitioningVisitor::HandleWhile(HloInstruction* hlo) {
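The low padding is what makes uneven partitions work: padding the low side up to the next multiple of the tile count shifts the data so that, after reversal, every shard boundary is aligned. A small sketch of the arithmetic, with RoundUpToNearest written out explicitly (XLA provides a helper of that name; this standalone version is ours):

#include <cstdint>
#include <iostream>

// Round size up to the nearest multiple of divisor (integer arithmetic).
int64_t RoundUpToNearest(int64_t size, int64_t divisor) {
  return (size + divisor - 1) / divisor * divisor;
}

int main() {
  // f32[3] tiled over 2 devices: each shard covers ceil(3/2) = 2 elements,
  // so the padded dimension is 4 and one element of low padding is needed.
  // This matches the TiledReverseHaloExchange test added below, where the
  // halo exchange moves exactly one f32[1] slice between the two devices.
  const int64_t size = 3;
  const int64_t num_tiles = 2;
  std::cout << RoundUpToNearest(size, num_tiles) - size << '\n';  // prints: 1
}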
@@ -3212,7 +3212,7 @@ ENTRY entry {
                              op::Shape("f32[9,9]")));
 }
 
-TEST_F(SpmdPartitioningTest, TiledReverse) {
+TEST_F(SpmdPartitioningTest, TiledReversePassthrough) {
   const char* const hlo_string = R"(
 HloModule module
 
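The rename presumably marks what this existing test covers: the only case the deleted handler supported, where the reversed dimensions are not tiled (the removed `tile_assignment().dim(d) == 1` check), so the reverse passes through without communication.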
@@ -3232,6 +3232,62 @@ ENTRY entry {
                           op::Reshape(), op::Constant()))));
 }
 
+TEST_F(SpmdPartitioningTest, TiledReversePassthroughViaReversedSharding) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  param = f32[4] parameter(0), sharding={devices=[2]0,1}
+  ROOT reverse = f32[4] reverse(param), dimensions={0},
+    sharding={devices=[2]1,0}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/2));
+  VLOG(1) << module->ToString();
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, AllOf(op::Shape("f32[2]"), op::Reverse(op::Parameter(0))));
+}
+
+TEST_F(SpmdPartitioningTest, TiledReverseSwapShards) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  param = f32[4] parameter(0), sharding={devices=[2]0,1}
+  ROOT reverse = f32[4] reverse(param), dimensions={0},
+    sharding={devices=[2]0,1}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/2));
+  VLOG(1) << module->ToString();
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root,
+              AllOf(op::Shape("f32[2]"),
+                    op::Reverse(op::CollectivePermute(op::Parameter(0)))));
+}
+
+TEST_F(SpmdPartitioningTest, TiledReverseHaloExchange) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  param = f32[3] parameter(0), sharding={devices=[2]0,1}
+  ROOT reverse = f32[3] reverse(param), dimensions={0},
+    sharding={devices=[2]1,0}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/2));
+  VLOG(1) << module->ToString();
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  auto halo_exchange_concat =
+      op::Concatenate(AllOf(op::Shape("f32[1]"),
+                            op::CollectivePermute(op::Slice(op::Parameter(0)))),
+                      op::Parameter(0));
+  auto after_halo_exchange = op::Slice(halo_exchange_concat);
+  EXPECT_THAT(root,
+              AllOf(op::Shape("f32[2]"), op::Reverse(after_halo_exchange)));
+}
+
 TEST_F(SpmdPartitioningTest, MixWithManualPartitioning) {
   const char* const hlo_string = R"(
 HloModule module