diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index a1929cb8cd0..ee9a133a22e 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -67,7 +67,9 @@ HloSharding HloSharding::PartialTile(
 HloSharding HloSharding::PartialTile(
     const Array<int64>& tile_assignment_last_dim_replicate,
     absl::Span<const OpMetadata> metadata) {
-  if (tile_assignment_last_dim_replicate.num_dimensions() == 1) {
+  if (tile_assignment_last_dim_replicate.num_dimensions() == 1 ||
+      tile_assignment_last_dim_replicate.dimensions().back() ==
+          tile_assignment_last_dim_replicate.num_elements()) {
     return Replicate(metadata);
   }
   if (tile_assignment_last_dim_replicate.dimensions().back() == 1) {
diff --git a/tensorflow/compiler/xla/service/sharding_propagation_test.cc b/tensorflow/compiler/xla/service/sharding_propagation_test.cc
index cc425a24468..eed6de9b5b7 100644
--- a/tensorflow/compiler/xla/service/sharding_propagation_test.cc
+++ b/tensorflow/compiler/xla/service/sharding_propagation_test.cc
@@ -4732,5 +4732,43 @@ ENTRY %module {
   }
 }
 
+TEST_P(ParameterizedMetadataTest, CorrectlyReplicateGatherIndex) {
+  absl::string_view hlo_string = R"(
+HloModule module
+
+ENTRY %module {
+  %parameter.0 = bf16[1,2,2,2,8]{4,3,2,1,0} parameter(0)
+  %parameter.1 = s32[1,2,2]{2,1,0} parameter(1)
+  %index = s32[1,2,2]{2,1,0} copy(%parameter.1)
+  %gather = bf16[1,2,2,2,8]{4,3,2,1,0} gather(
+    bf16[1,2,2,2,8]{4,3,2,1,0} %parameter.0, s32[1,2,2]{2,1,0} %index),
+    offset_dims={2,3,4}, collapsed_slice_dims={0,1}, start_index_map={0,1},
+    index_vector_dim=2, slice_sizes={1,1,2,2,8},
+    sharding={devices=[1,1,2,1,1]0,1 metadata={op_name="a"}}
+  ROOT %copy = bf16[1,2,2,2,8]{4,3,2,1,0} copy(%gather)
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  if (GetParam().clear_metadata) {
+    ClearMetadata(module.get());
+  }
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool changed,
+      ShardingPropagation(/*is_spmd=*/true, GetParam().propagate_metadata)
+          .Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  const HloInstruction* index = FindInstruction(module.get(), "index");
+
+  ASSERT_NE(index, nullptr);
+
+  EXPECT_THAT(index, op::Sharding("{replicated}"));
+  if (GetParam().propagate_metadata && !GetParam().clear_metadata) {
+    EXPECT_THAT(index->sharding(), ShardingMetadata({CreateMetadata("a")}));
+  } else {
+    EXPECT_THAT(index->sharding(), ShardingMetadata({}));
+  }
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
index eba37ae142c..9bf523e2281 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
@@ -5730,34 +5730,6 @@ ENTRY entry {
   EXPECT_THAT(root, tiled);
 }
 
-TEST_F(SpmdPartitioningTest,
-       PartialReplicateToPartialReplicateReshard_DynamicSlice2) {
-  absl::string_view hlo_string = R"(
-HloModule module
-
-ENTRY entry {
-  %param0 = f32[8,8] parameter(0)
-  %copy = f32[8,8] copy(%param0),
-    sharding={devices=[1,1,8]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
-  ROOT %copy0 = f32[8,8] copy(%copy),
-    sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          PartitionComputation(hlo_string, /*num_devices=*/8));
-  VLOG(1) << module->ToString();
-  auto partially_replicated =
-      AllOf(op::Shape("f32[8,8]"),
-            op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
-                                      op::Constant())));
-  auto tiled =
-      AllOf(op::Shape("f32[4,4]"),
-            op::Copy(op::DynamicSlice(partially_replicated, op::Subtract(),
-                                      op::Subtract())));
-  auto root = module->entry_computation()->root_instruction();
-  EXPECT_THAT(root, tiled);
-}
-
 TEST_F(SpmdPartitioningTest,
        PartialReplicateToPartialReplicateReshardWithCollectivePermute) {
   absl::string_view hlo_string = R"(