[XLA:SPMD] Fix vector indices sharding in GatherIndexSharding.
PiperOrigin-RevId: 326781481
Change-Id: I95945e894c3ecac0d74fa0a4f51570d70b43ebf3
parent 9e416e14c0
commit 71ff216576
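In brief: GatherIndexSharding previously took the index rank as an extra
parameter and padded the index tile assignment by appending a 1 for every
dimension from index_vector_dim onward, which both miscounts and misplaces
the unsharded slot whenever the vector dimension is not the trailing one.
The fix derives the rank from the indices operand itself and inserts a
single 1 at index_vector_dim. Below is a minimal standalone sketch of that
dimension bookkeeping (plain C++ with std::vector standing in for XLA's
Array<int64>; not the XLA code itself), using the shapes from the
GatherToIndex3 test added in this change:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Shapes from the GatherToIndex3 test below: output bf16[2,1000,4] with
  // sharding {devices=[1,2,1]}, offset_dims={2}, indices s32[2,2,1000],
  // index_vector_dim=1. The non-offset output dims {0,1} contribute their
  // tile dims, so the partial index tile assignment is {1,2}.
  std::vector<int64_t> index_tile_assignment_dims = {1, 2};
  const int64_t index_rank = 3;        // rank of the s32[2,2,1000] indices
  const int64_t index_vector_dim = 1;  // the vector dim must stay unsharded

  // Old behavior: appended a 1 for every i in [index_vector_dim, index_rank),
  // yielding {1,2,1,1} here -- wrong size, and the sharded value 2 is left in
  // the vector-dim slot. Fixed behavior: insert a single 1 at
  // index_vector_dim, yielding {1,1,2}, which matches the expected indices
  // sharding {devices=[1,1,2]0,1} in the test.
  if (index_rank > static_cast<int64_t>(index_tile_assignment_dims.size())) {
    index_tile_assignment_dims.insert(
        index_tile_assignment_dims.begin() + index_vector_dim, 1);
  }
  for (int64_t d : index_tile_assignment_dims) std::cout << d << ' ';
  std::cout << '\n';  // prints: 1 1 2
}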
					
@@ -357,8 +357,8 @@ HloSharding GatherOutputSharding(const HloSharding& index_sharding,
 }
 
 HloSharding GatherIndexSharding(const HloSharding& output_sharding,
-                                const HloInstruction* hlo,
-                                const int64 index_rank) {
+                                const HloInstruction* hlo) {
   CHECK(hlo->opcode() == HloOpcode::kGather);
   if (output_sharding.IsTileMaximal()) {
     return output_sharding;
   }
@@ -371,10 +371,12 @@ HloSharding GatherIndexSharding(const HloSharding& output_sharding,
           output_sharding.tile_assignment().dim(i));
     }
   }
+  int64 index_rank = hlo->operand(1)->shape().rank();
 
   // Vector indices sharding is not supported yet.
-  for (int64 i = dnums.index_vector_dim(); i < index_rank; ++i) {
-    index_tile_assignment_dims.push_back(1);
-  }
+  if (index_rank > index_tile_assignment_dims.size()) {
+    index_tile_assignment_dims.insert(
+        index_tile_assignment_dims.begin() + dnums.index_vector_dim(), 1);
+  }
 
   Array<int64> new_tile_assignment = output_sharding.tile_assignment();
@@ -95,8 +95,7 @@ HloSharding GatherOutputSharding(const HloSharding& index_sharding,
 // Returns the preferred index sharding for a gather op based on the sharding
 // of the output.
 HloSharding GatherIndexSharding(const HloSharding& output_sharding,
-                                const HloInstruction* hlo,
-                                const int64 index_rank);
+                                const HloInstruction* hlo);
 
 // Returns a new HloSharding for a gather op so that only non offset dimensions
 // are sharded. Assume "result" is returned by this function. It is ensured that
@@ -1319,8 +1319,7 @@ absl::optional<HloSharding> GetShardingFromUser(
     }
     case HloOpcode::kGather: {
       if (&instruction == user.operand(1)) {
-        return hlo_sharding_util::GatherIndexSharding(
-            user.sharding(), &user, instruction.shape().rank());
+        return hlo_sharding_util::GatherIndexSharding(user.sharding(), &user);
       }
       if (is_spmd) {
         return hlo_sharding_util::GatherDataOperandShardingFromOutput(
@@ -1839,6 +1839,29 @@ ENTRY entry {
               op::Sharding("{devices=[1,2,1]0,1}"));
 }
 
+TEST_F(ShardingPropagationTest, GatherToIndex3) {
+  const char* hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %input = bf16[2,4819,4] parameter(0), sharding={replicated}
+  %p1 = s32[2,2,1000] parameter(1)
+  %indices = s32[2,2,1000] copy(%p1)
+  ROOT %gather = bf16[2,1000,4]
+    gather(bf16[2,4819,4] %input, s32[2,2,1000] %indices),
+    offset_dims={2}, collapsed_slice_dims={0,1},
+    start_index_map={0,1}, index_vector_dim=1, slice_sizes={1,1,4},
+    sharding={devices=[1,2,1]0,1}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          ShardingPropagation().Run(module.get()));
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(FindInstruction(module.get(), "indices"),
+              op::Sharding("{devices=[1,1,2]0,1}"));
+}
+
 TEST_F(ShardingPropagationTest, GatherToDataOperand) {
   const char* hlo_string = R"(
 HloModule module
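To run just the new test, something like the following should work (the exact
Bazel target is an assumption, based on TensorFlow's convention of naming
cc_test targets after the source file):

bazel test //tensorflow/compiler/xla/service:sharding_propagation_test \
    --test_filter=ShardingPropagationTest.GatherToIndex3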