Mask operands for unevenly partitioned contracting dims.
PiperOrigin-RevId: 340997388
Change-Id: I43ed530c3c3ef077e190dddcc832632bf024a71c
parent 18ec783484
commit 63446fb3b5
@@ -1148,6 +1148,27 @@ StatusOr<HloInstruction*> PartitionDotGroupOnContracting(
     }
     lhs = lhs.Reshard(lhs_sharding);
   }
+  // Mask out invalid data.
+  std::vector<int64> lhs_skipped_dims;
+  for (int64 i = 0; i < lhs.base_shape().rank(); ++i) {
+    if (absl::c_linear_search(lhs_dims, i)) {
+      continue;
+    }
+    lhs_skipped_dims.push_back(i);
+  }
+  lhs = lhs.PadWithValue(
+      CreateZero(ShapeUtil::MakeShape(lhs.base_shape().element_type(), {}), b),
+      /*left_padded_dims=*/{}, lhs_skipped_dims);
+  std::vector<int64> rhs_skipped_dims;
+  for (int64 i = 0; i < rhs.base_shape().rank(); ++i) {
+    if (absl::c_linear_search(rhs_dims, i)) {
+      continue;
+    }
+    rhs_skipped_dims.push_back(i);
+  }
+  rhs = rhs.PadWithValue(
+      CreateZero(ShapeUtil::MakeShape(rhs.base_shape().element_type(), {}), b),
+      /*left_padded_dims=*/{}, rhs_skipped_dims);
   top_level_sharding_to_reset.emplace_back(lhs.hlo(), lhs_sharding);
   lhs.hlo()->set_sharding(lhs_grouped.sharding);
   top_level_sharding_to_reset.emplace_back(rhs.hlo(), rhs_sharding);

@@ -463,7 +463,8 @@ PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) {
 }
 
 PartitionedHlo PartitionedHlo::PadWithValue(
-    HloInstruction* pad_value, absl::Span<const int64> left_padded_dims) const {
+    HloInstruction* pad_value, absl::Span<const int64> left_padded_dims,
+    absl::Span<const int64> skipped_dims) const {
   const HloSharding& sharding = hlo_->sharding();
   const Shape& shape = hlo_->shape();
   CHECK(!shape.IsTuple() && shape.element_type() != TOKEN);
@@ -502,7 +503,8 @@ PartitionedHlo PartitionedHlo::PadWithValue(
   auto offsets = MakePartitionOffsets(base_shape_, sharding,
                                       state_.partition_id, state_.b);
   for (int64 i = 0; i < shape.rank(); ++i) {
-    if (base_shape_.dimensions(i) % sharding.tile_assignment().dim(i) == 0) {
+    if (base_shape_.dimensions(i) % sharding.tile_assignment().dim(i) == 0 ||
+        absl::c_linear_search(skipped_dims, i)) {
       continue;
     }
     if (mask == nullptr) {

@@ -283,9 +283,9 @@ class PartitionedHlo {
   // unevenly partitioned dimensions are padded on the right, but this function
   // allows specifying left-padded dimensions, which can be used during the
   // handling of kReverse, etc.
-  PartitionedHlo PadWithValue(
-      HloInstruction* pad_value,
-      absl::Span<const int64> left_padded_dims = {}) const;
+  PartitionedHlo PadWithValue(HloInstruction* pad_value,
+                              absl::Span<const int64> left_padded_dims = {},
+                              absl::Span<const int64> skipped_dims = {}) const;
 
   // Returns the SPMD instruction.
   HloInstruction* hlo() const { return hlo_; }

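The header comment above still documents only left_padded_dims; based on the PadWithValue change earlier in this diff, the new skipped_dims parameter excludes the listed dimensions from masking even when they are unevenly partitioned. A small standalone sketch of that selection rule follows; the NeedsMask helper, its signature, and the example numbers are hypothetical, not part of the library.

    // Hypothetical helper mirroring the masking rule in the modified
    // PadWithValue: a dimension needs a validity mask only when it is
    // unevenly partitioned and not listed in skipped_dims.
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    bool NeedsMask(int64_t dim_size, int64_t shards_on_dim, int64_t dim,
                   const std::vector<int64_t>& skipped_dims) {
      if (dim_size % shards_on_dim == 0) {
        return false;  // evenly partitioned: every element is valid
      }
      // Skipped dimensions stay unmasked even when uneven, e.g. the
      // non-contracting dims in PartitionDotGroupOnContracting above.
      return std::find(skipped_dims.begin(), skipped_dims.end(), dim) ==
             skipped_dims.end();
    }

    int main() {
      // A dim of size 23 split over 2 shards is uneven: masked unless skipped.
      std::printf("%d %d\n", NeedsMask(23, 2, /*dim=*/0, {}),
                  NeedsMask(23, 2, /*dim=*/0, {0}));
      return 0;
    }
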
@@ -5003,6 +5003,37 @@ ENTRY entry {
                           op::Dot(lhs_slice, partial_replicated_rhs)));
 }
 
+TEST_F(SpmdPartitioningTest, Dot2DPartitionedNoncontractingAndContracting3) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %lhs = f32[23,24] parameter(0), sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
+  %rhs = f32[23,32] parameter(1), sharding={devices=[2,2]0,1,2,3}
+  ROOT %dot = f32[24,32] dot(%lhs, %rhs),
+    lhs_contracting_dims={0}, rhs_contracting_dims={0},
+    sharding={devices=[2,2]1,0,3,2}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/4));
+  VLOG(1) << module->ToString();
+
+  auto lhs = AllOf(op::Shape("f32[12,24]"), op::Parameter(0));
+  auto masked_lhs = op::Select(_, lhs, op::Broadcast(op::Constant()));
+  auto rhs = AllOf(op::Shape("f32[12,16]"), op::Parameter(1));
+  auto masked_rhs = op::Select(_, rhs, op::Broadcast(op::Constant()));
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(
+      root,
+      AllOf(op::Shape("f32[12,16]"),
+            op::DynamicSlice(
+                AllOf(op::Shape("f32[24,16]"),
+                      op::AllReduce(op::Dot(
+                          masked_lhs, op::CollectivePermute(masked_rhs)))),
+                _, _)));
+}
+
 TEST_F(SpmdPartitioningTest, Dot2DPartitionedBatchAndNonContracting) {
   const char* const hlo_string = R"(
 HloModule module

@@ -82,6 +82,9 @@ HloInstruction* CreateZero(const Shape& shape, SpmdBuilder* b) {
   }
   auto zero = b->AddInstruction(
       HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type())));
+  if (shape.rank() == 0) {
+    return zero;
+  }
   return b->AddInstruction(HloInstruction::CreateBroadcast(shape, zero, {}));
 }