diff --git a/tensorflow/compiler/xla/service/hlo_sharding_util.cc b/tensorflow/compiler/xla/service/hlo_sharding_util.cc
index 129091ca06f..7fc05608800 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_util.cc
@@ -220,6 +220,24 @@ absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
   return HloSharding::Tile(new_tile_assignment);
 }
 
+HloSharding ReverseSharding(const HloSharding& sharding,
+                            absl::Span<const int64> dimensions) {
+  if (sharding.IsTileMaximal() || dimensions.empty()) {
+    return sharding;
+  }
+
+  Array<int64> new_tile_assignment(sharding.tile_assignment().dimensions());
+  new_tile_assignment.Each([&](absl::Span<const int64> indices, int64* device) {
+    std::vector<int64> original_indices(indices.begin(), indices.end());
+    for (int64 d : dimensions) {
+      original_indices[d] =
+          new_tile_assignment.dim(d) - 1 - original_indices[d];
+    }
+    *device = sharding.tile_assignment()(original_indices);
+  });
+  return HloSharding::Tile(new_tile_assignment);
+}
+
 HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64 dim,
                                    absl::Span<const int64> dims) {
   CHECK(!sharding.IsTuple() && !sharding.IsTileMaximal());
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_util.h b/tensorflow/compiler/xla/service/hlo_sharding_util.h
index 00d9434a34d..562f6d1420d 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_util.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding_util.h
@@ -70,6 +70,12 @@ absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
                                             const Shape& target_shape,
                                             const HloSharding& sharding);
 
+// Returns the HloSharding with the tile dimensions and tile assignment
+// reversed based on the specified dimension numbers. In case of a tile
+// maximal sharding returns the original sharding.
+HloSharding ReverseSharding(const HloSharding& sharding,
+                            absl::Span<const int64> dimensions);
+
 // Returns a sharding tiled on unique dimension dim by reshaping the tile
 // assignment of the sharding argument. Only dimensions in the dims span
 // argument are considered for reshaping, the others are ignored.
diff --git a/tensorflow/compiler/xla/service/sharding_propagation.cc b/tensorflow/compiler/xla/service/sharding_propagation.cc
index bee2e04fabf..c6990e76c95 100644
--- a/tensorflow/compiler/xla/service/sharding_propagation.cc
+++ b/tensorflow/compiler/xla/service/sharding_propagation.cc
@@ -717,6 +717,15 @@ bool InferShardingFromOperands(HloInstruction* instruction,
       }
       return false;
     }
+    case HloOpcode::kReverse: {
+      if (!IsSpatiallyPartitioned(instruction->operand(0))) {
+        return false;
+      }
+      return MaybeImproveInstructionSharding(
+          hlo_sharding_util::ReverseSharding(
+              instruction->operand(0)->sharding(), instruction->dimensions()),
+          instruction);
+    }
     case HloOpcode::kDot: {
       auto& dot_dim_numbs = instruction->dot_dimension_numbers();
       // Batch dimensions are the same for lhs and rhs on dot operations.
@@ -1188,6 +1197,10 @@ absl::optional<HloSharding> GetShardingFromUser(
         return user.sharding();
       }
     }
+    case HloOpcode::kReverse: {
+      return hlo_sharding_util::ReverseSharding(user.sharding(),
+                                                user.dimensions());
+    }
     default: {
       // If the user output shape is compatible with the current instruction
       // shape excluding element type and the current instruction is supported
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
index 068442ad5c7..a0c46e0b6e7 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
@@ -2325,18 +2325,44 @@ Status SpmdPartitioningVisitor::HandleReverse(HloInstruction* hlo) {
   if (reverse->sharding().IsTileMaximal()) {
     return DefaultAction(hlo);
   }
-  if (absl::c_all_of(reverse->dimensions(), [&](int64 d) {
-        return reverse->sharding().tile_assignment().dim(d) == 1;
-      })) {
-    auto operand =
-        GetPartitionedHlo(reverse->operand(0)).Reshard(reverse->sharding());
-    SetPartitionedHlo(hlo, [&] {
-      return b_.AddInstruction(
-          hlo->CloneWithNewOperands(operand.hlo()->shape(), {operand.hlo()}));
-    });
-    return Status::OK();
+  auto operand = GetPartitionedHlo(reverse->operand(0))
+                     .Reshard(hlo_sharding_util::ReverseSharding(
+                         reverse->sharding(), reverse->dimensions()));
+  // Create a window config to halo exchange for unevenly partitioned reverse
+  // dimensions.
+  Window window;
+  for (int64 i = 0; i < hlo->shape().rank(); ++i) {
+    WindowDimension* dim = window.add_dimensions();
+    dim->set_size(1);
+    dim->set_stride(1);
+    dim->set_window_dilation(1);
+    dim->set_window_reversal(false);
+    int64 low_padding = 0;
+    if (absl::c_linear_search(reverse->dimensions(), i)) {
+      low_padding =
+          RoundUpToNearest(reverse->shape().dimensions(i),
+                           reverse->sharding().tile_assignment().dim(i)) -
+          reverse->shape().dimensions(i);
+    }
+    dim->set_padding_low(low_padding);
+    dim->set_padding_high(0);
+    dim->set_base_dilation(1);
   }
-  return DefaultAction(hlo);
+
+  auto reshard_operand = operand.ReshardAsWindowedInput(
+      window, operand.sharding(),
+      CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_),
+      /*mask_invalid_region=*/false);
+  if (!reshard_operand.has_value()) {
+    return DefaultAction(hlo);
+  }
+  TF_RET_CHECK(!reshard_operand->dynamic_slice_index_on_output.has_value());
+  SetPartitionedHlo(hlo, [&] {
+    return b_.AddInstruction(
+        hlo->CloneWithNewOperands(reshard_operand->sharded_input->shape(),
+                                  {reshard_operand->sharded_input}));
+  });
+  return Status::OK();
 }
 
 Status SpmdPartitioningVisitor::HandleWhile(HloInstruction* hlo) {
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
index e766695385b..2daf3444014 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
@@ -3212,7 +3212,7 @@ ENTRY entry {
                           op::Shape("f32[9,9]")));
 }
 
-TEST_F(SpmdPartitioningTest, TiledReverse) {
+TEST_F(SpmdPartitioningTest, TiledReversePassthrough) {
   const char* const hlo_string = R"(
 HloModule module
 
@@ -3232,6 +3232,62 @@ ENTRY entry {
                           op::Reshape(), op::Constant()))));
 }
 
+TEST_F(SpmdPartitioningTest, TiledReversePassthroughViaReversedSharding) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  param = f32[4] parameter(0), sharding={devices=[2]0,1}
+  ROOT reverse = f32[4] reverse(param), dimensions={0},
+    sharding={devices=[2]1,0}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/2));
+  VLOG(1) << module->ToString();
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, AllOf(op::Shape("f32[2]"), op::Reverse(op::Parameter(0))));
+}
+
+TEST_F(SpmdPartitioningTest, TiledReverseSwapShards) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  param = f32[4] parameter(0), sharding={devices=[2]0,1}
+  ROOT reverse = f32[4] reverse(param), dimensions={0},
+    sharding={devices=[2]0,1}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/2));
+  VLOG(1) << module->ToString();
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root,
+              AllOf(op::Shape("f32[2]"),
+                    op::Reverse(op::CollectivePermute(op::Parameter(0)))));
+}
+
+TEST_F(SpmdPartitioningTest, TiledReverseHaloExchange) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  param = f32[3] parameter(0), sharding={devices=[2]0,1}
+  ROOT reverse = f32[3] reverse(param), dimensions={0},
+    sharding={devices=[2]1,0}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/2));
+  VLOG(1) << module->ToString();
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  auto halo_exchange_concat =
+      op::Concatenate(AllOf(op::Shape("f32[1]"),
+                            op::CollectivePermute(op::Slice(op::Parameter(0)))),
+                      op::Parameter(0));
+  auto after_halo_exchange = op::Slice(halo_exchange_concat);
+  EXPECT_THAT(root,
+              AllOf(op::Shape("f32[2]"), op::Reverse(after_halo_exchange)));
+}
+
 TEST_F(SpmdPartitioningTest, MixWithManualPartitioning) {
   const char* const hlo_string = R"(
 HloModule module