[XLA:SPMD] More cases of reverse sharding

- Improve sharding propagation to reverse the tile assignment (sketched below)
- Use reshard (collective permute) to fix mismatched operand sharding
- Use halo exchange to fix uneven partitioning

PiperOrigin-RevId: 313672162
Change-Id: I0816de794a0c18a0173889ed8cd638baecf389e9
Author: Yuanzhong Xu, 2020-05-28 15:34:54 -07:00 (committed by TensorFlower Gardener)
parent b1cb3f12da
commit 5c77174291
5 changed files with 131 additions and 12 deletions
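
A quick way to see why reversing the data calls for a reversed tile assignment (the first bullet) is the following minimal sketch. It is plain C++ with no XLA dependencies, and everything in it is our own toy model rather than code from this change:

// Tag each element of an f32[4] array with its owning device under
// sharding={devices=[2]0,1} and reverse the data: the ownership regions swap,
// which is exactly the reversed sharding {devices=[2]1,0}.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Device 0 owns the first half, device 1 the second half.
  std::vector<int64_t> owner = {0, 0, 1, 1};
  std::reverse(owner.begin(), owner.end());  // owner is now {1, 1, 0, 0}
  // After the reverse, tile 0 lives on device 1 and tile 1 on device 0.
  for (int64_t tile = 0; tile < 2; ++tile) {
    std::cout << "output tile " << tile << " -> device " << owner[tile * 2]
              << "\n";
  }
}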

tensorflow/compiler/xla/service/hlo_sharding_util.cc

@@ -220,6 +220,24 @@ absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
   return HloSharding::Tile(new_tile_assignment);
 }
 
+HloSharding ReverseSharding(const HloSharding& sharding,
+                            absl::Span<const int64> dimensions) {
+  if (sharding.IsTileMaximal() || dimensions.empty()) {
+    return sharding;
+  }
+
+  Array<int64> new_tile_assignment(sharding.tile_assignment().dimensions());
+  new_tile_assignment.Each([&](absl::Span<const int64> indices, int64* device) {
+    std::vector<int64> original_indices(indices.begin(), indices.end());
+    for (int64 d : dimensions) {
+      original_indices[d] =
+          new_tile_assignment.dim(d) - 1 - original_indices[d];
+    }
+    *device = sharding.tile_assignment()(original_indices);
+  });
+  return HloSharding::Tile(new_tile_assignment);
+}
+
 HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64 dim,
                                    absl::Span<const int64> dims) {
   CHECK(!sharding.IsTuple() && !sharding.IsTileMaximal());
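
For the multi-dimensional case, the Each loop above reads each device from the mirrored position along every reversed dimension. A standalone sketch of the same computation on a [2,2] row-major tile assignment (plain C++; std::vector stands in for xla::Array<int64>):

// Reverse {devices=[2,2]0,1,2,3} along dimension 0 only: position (r, c) of
// the new assignment reads position (rows - 1 - r, c) of the original.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::vector<int64_t> tile_assignment = {0, 1, 2, 3};  // row-major
  const int64_t rows = 2, cols = 2;
  std::vector<int64_t> reversed(tile_assignment.size());
  for (int64_t r = 0; r < rows; ++r) {
    for (int64_t c = 0; c < cols; ++c) {
      reversed[r * cols + c] = tile_assignment[(rows - 1 - r) * cols + c];
    }
  }
  for (int64_t d : reversed) std::cout << d << " ";  // prints: 2 3 0 1
  std::cout << "\n";
}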

tensorflow/compiler/xla/service/hlo_sharding_util.h

@@ -70,6 +70,12 @@ absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
                                             const Shape& target_shape,
                                             const HloSharding& sharding);
 
+// Returns the HloSharding with the tile dimensions and tile assignment
+// reversed based on the specified dimension numbers. In case of a tile
+// maximal sharding returns the original sharding.
+HloSharding ReverseSharding(const HloSharding& sharding,
+                            absl::Span<const int64> dimensions);
+
 // Returns a sharding tiled on unique dimension dim by reshaping the tile
 // assignment of the sharding argument. Only dimensions in the dims span
 // argument are considered for reshaping, the others are ignored.

tensorflow/compiler/xla/service/sharding_propagation.cc

@@ -717,6 +717,15 @@ bool InferShardingFromOperands(HloInstruction* instruction,
       }
       return false;
     }
+    case HloOpcode::kReverse: {
+      if (!IsSpatiallyPartitioned(instruction->operand(0))) {
+        return false;
+      }
+      return MaybeImproveInstructionSharding(
+          hlo_sharding_util::ReverseSharding(
+              instruction->operand(0)->sharding(), instruction->dimensions()),
+          instruction);
+    }
     case HloOpcode::kDot: {
       auto& dot_dim_numbs = instruction->dot_dimension_numbers();
       // Batch dimensions are the same for lhs and rhs on dot operations.
@@ -1188,6 +1197,10 @@ absl::optional<HloSharding> GetShardingFromUser(
         return user.sharding();
       }
     }
+    case HloOpcode::kReverse: {
+      return hlo_sharding_util::ReverseSharding(user.sharding(),
+                                                user.dimensions());
+    }
     default: {
       // If the user output shape is compatible with the current instruction
       // shape excluding element type and the current instruction is supported
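
A property worth noting about the two cases above: mirroring indices is an involution, so a single ReverseSharding helper serves both operand-to-user and user-to-operand propagation. A small standalone check of that claim on a toy 1-D assignment (plain C++, not XLA code):

// Reversing a reversed tile assignment restores the original.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> Reverse1D(const std::vector<int64_t>& tiles) {
  return std::vector<int64_t>(tiles.rbegin(), tiles.rend());
}

int main() {
  const std::vector<int64_t> tiles = {0, 1, 2, 3};
  assert(Reverse1D(Reverse1D(tiles)) == tiles);
  std::cout << "reverse twice == identity\n";
}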

tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc

@@ -2325,19 +2325,45 @@ Status SpmdPartitioningVisitor::HandleReverse(HloInstruction* hlo) {
   if (reverse->sharding().IsTileMaximal()) {
     return DefaultAction(hlo);
   }
-  if (absl::c_all_of(reverse->dimensions(), [&](int64 d) {
-        return reverse->sharding().tile_assignment().dim(d) == 1;
-      })) {
-    auto operand =
-        GetPartitionedHlo(reverse->operand(0)).Reshard(reverse->sharding());
-    SetPartitionedHlo(hlo, [&] {
-      return b_.AddInstruction(
-          hlo->CloneWithNewOperands(operand.hlo()->shape(), {operand.hlo()}));
-    });
-    return Status::OK();
-  }
-  return DefaultAction(hlo);
+  auto operand = GetPartitionedHlo(reverse->operand(0))
+                     .Reshard(hlo_sharding_util::ReverseSharding(
+                         reverse->sharding(), reverse->dimensions()));
+  // Create a window config to halo exchange for unevenly partitioned reverse
+  // dimensions.
+  Window window;
+  for (int64 i = 0; i < hlo->shape().rank(); ++i) {
+    WindowDimension* dim = window.add_dimensions();
+    dim->set_size(1);
+    dim->set_stride(1);
+    dim->set_window_dilation(1);
+    dim->set_window_reversal(false);
+    int64 low_padding = 0;
+    if (absl::c_linear_search(reverse->dimensions(), i)) {
+      low_padding =
+          RoundUpToNearest(reverse->shape().dimensions(i),
+                           reverse->sharding().tile_assignment().dim(i)) -
+          reverse->shape().dimensions(i);
+    }
+    dim->set_padding_low(low_padding);
+    dim->set_padding_high(0);
+    dim->set_base_dilation(1);
+  }
+
+  auto reshard_operand = operand.ReshardAsWindowedInput(
+      window, operand.sharding(),
+      CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_),
+      /*mask_invalid_region=*/false);
+  if (!reshard_operand.has_value()) {
+    return DefaultAction(hlo);
+  }
+  TF_RET_CHECK(!reshard_operand->dynamic_slice_index_on_output.has_value());
+  SetPartitionedHlo(hlo, [&] {
+    return b_.AddInstruction(
+        hlo->CloneWithNewOperands(reshard_operand->sharded_input->shape(),
+                                  {reshard_operand->sharded_input}));
+  });
+  return Status::OK();
 }
 
 Status SpmdPartitioningVisitor::HandleWhile(HloInstruction* hlo) {
   const HloSharding& sharding = hlo->sharding();
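
The low-padding arithmetic above can be checked by hand: with p partitions, each shard holds ceil(n / p) elements, so an uneven dimension behaves as if padded to ceil(n / p) * p at the high end, and reversing moves that implicit padding to the low end; the window requests the difference as explicit low padding so ReshardAsWindowedInput can set up the halo exchange. Here is a standalone sketch for the f32[3], two-partition case exercised by the TiledReverseHaloExchange test below (RoundUpToNearest is reimplemented here to match its use above):

#include <cstdint>
#include <iostream>

// Smallest multiple of `multiple` that is >= `value`.
int64_t RoundUpToNearest(int64_t value, int64_t multiple) {
  return (value + multiple - 1) / multiple * multiple;
}

int main() {
  const int64_t n = 3;  // reverse dimension size, as in f32[3]
  const int64_t p = 2;  // partitions along that dimension
  std::cout << "shard size:  " << RoundUpToNearest(n, p) / p << "\n"   // 2
            << "low padding: " << RoundUpToNearest(n, p) - n << "\n";  // 1
}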

tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc

@ -3212,7 +3212,7 @@ ENTRY entry {
op::Shape("f32[9,9]"))); op::Shape("f32[9,9]")));
} }
TEST_F(SpmdPartitioningTest, TiledReverse) { TEST_F(SpmdPartitioningTest, TiledReversePassthrough) {
const char* const hlo_string = R"( const char* const hlo_string = R"(
HloModule module HloModule module
@ -3232,6 +3232,62 @@ ENTRY entry {
op::Reshape(), op::Constant())))); op::Reshape(), op::Constant()))));
} }
TEST_F(SpmdPartitioningTest, TiledReversePassthroughViaReversedSharding) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
param = f32[4] parameter(0), sharding={devices=[2]0,1}
ROOT reverse = f32[4] reverse(param), dimensions={0},
sharding={devices=[2]1,0}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Shape("f32[2]"), op::Reverse(op::Parameter(0))));
}
TEST_F(SpmdPartitioningTest, TiledReverseSwapShards) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
param = f32[4] parameter(0), sharding={devices=[2]0,1}
ROOT reverse = f32[4] reverse(param), dimensions={0},
sharding={devices=[2]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root,
AllOf(op::Shape("f32[2]"),
op::Reverse(op::CollectivePermute(op::Parameter(0)))));
}
TEST_F(SpmdPartitioningTest, TiledReverseHaloExchange) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
param = f32[3] parameter(0), sharding={devices=[2]0,1}
ROOT reverse = f32[3] reverse(param), dimensions={0},
sharding={devices=[2]1,0}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
auto halo_exchange_concat =
op::Concatenate(AllOf(op::Shape("f32[1]"),
op::CollectivePermute(op::Slice(op::Parameter(0)))),
op::Parameter(0));
auto after_halo_exchange = op::Slice(halo_exchange_concat);
EXPECT_THAT(root,
AllOf(op::Shape("f32[2]"), op::Reverse(after_halo_exchange)));
}
TEST_F(SpmdPartitioningTest, MixWithManualPartitioning) { TEST_F(SpmdPartitioningTest, MixWithManualPartitioning) {
const char* const hlo_string = R"( const char* const hlo_string = R"(
HloModule module HloModule module