Mask operands for unevenly partitioned contracting dims.

PiperOrigin-RevId: 340997388
Change-Id: I43ed530c3c3ef077e190dddcc832632bf024a71c
parent 18ec783484
commit 63446fb3b5
@@ -1148,6 +1148,27 @@ StatusOr<HloInstruction*> PartitionDotGroupOnContracting(
     }
     lhs = lhs.Reshard(lhs_sharding);
   }
+  // Mask out invalid data.
+  std::vector<int64> lhs_skipped_dims;
+  for (int64 i = 0; i < lhs.base_shape().rank(); ++i) {
+    if (absl::c_linear_search(lhs_dims, i)) {
+      continue;
+    }
+    lhs_skipped_dims.push_back(i);
+  }
+  lhs = lhs.PadWithValue(
+      CreateZero(ShapeUtil::MakeShape(lhs.base_shape().element_type(), {}), b),
+      /*left_padded_dims=*/{}, lhs_skipped_dims);
+  std::vector<int64> rhs_skipped_dims;
+  for (int64 i = 0; i < rhs.base_shape().rank(); ++i) {
+    if (absl::c_linear_search(rhs_dims, i)) {
+      continue;
+    }
+    rhs_skipped_dims.push_back(i);
+  }
+  rhs = rhs.PadWithValue(
+      CreateZero(ShapeUtil::MakeShape(rhs.base_shape().element_type(), {}), b),
+      /*left_padded_dims=*/{}, rhs_skipped_dims);
   top_level_sharding_to_reset.emplace_back(lhs.hlo(), lhs_sharding);
   lhs.hlo()->set_sharding(lhs_grouped.sharding);
   top_level_sharding_to_reset.emplace_back(rhs.hlo(), rhs_sharding);
@@ -463,7 +463,8 @@ PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) {
 }
 
 PartitionedHlo PartitionedHlo::PadWithValue(
-    HloInstruction* pad_value, absl::Span<const int64> left_padded_dims) const {
+    HloInstruction* pad_value, absl::Span<const int64> left_padded_dims,
+    absl::Span<const int64> skipped_dims) const {
   const HloSharding& sharding = hlo_->sharding();
   const Shape& shape = hlo_->shape();
   CHECK(!shape.IsTuple() && shape.element_type() != TOKEN);
@@ -502,7 +503,8 @@ PartitionedHlo PartitionedHlo::PadWithValue(
   auto offsets = MakePartitionOffsets(base_shape_, sharding,
                                       state_.partition_id, state_.b);
   for (int64 i = 0; i < shape.rank(); ++i) {
-    if (base_shape_.dimensions(i) % sharding.tile_assignment().dim(i) == 0) {
+    if (base_shape_.dimensions(i) % sharding.tile_assignment().dim(i) == 0 ||
+        absl::c_linear_search(skipped_dims, i)) {
       continue;
     }
     if (mask == nullptr) {
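For context on the hunk above, PadWithValue builds the mask per dimension: it compares the global index of each local element (an iota plus the shard's partition offset) against the valid extent of the dimension and selects between the shard's data and the broadcast pad value; the new skipped_dims argument lets callers opt a dimension out of that mask even when it is unevenly partitioned. A rough scalar rendering of the idea (plain C++, not the HLO-building code, with illustrative names):

// Illustrative mask construction for one unevenly partitioned dimension:
// global indices at or beyond the valid extent select the pad value instead
// of the shard's data.
#include <cstdio>
#include <vector>

std::vector<float> PadWithValueSketch(const std::vector<float>& shard_data,
                                      int partition_offset, int valid_size,
                                      float pad_value) {
  std::vector<float> result(shard_data.size());
  for (int i = 0; i < static_cast<int>(shard_data.size()); ++i) {
    // iota + partition offset gives the global index of each local element.
    const int global_index = partition_offset + i;
    result[i] = global_index < valid_size ? shard_data[i] : pad_value;
  }
  return result;
}

int main() {
  // Last shard of a length-23 dimension split into two shards of 12.
  std::vector<float> shard(12, 1.0f);
  auto masked = PadWithValueSketch(shard, /*partition_offset=*/12,
                                   /*valid_size=*/23, /*pad_value=*/0.0f);
  for (float v : masked) std::printf("%g ", v);  // last element becomes 0
  std::printf("\n");
  return 0;
}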
@@ -283,9 +283,9 @@ class PartitionedHlo {
   // unevenly partitioned dimensions are padded on the right, but this function
   // allows specifying left-padded dimensions, which can be used during the
   // handling of kReverse, etc.
-  PartitionedHlo PadWithValue(
-      HloInstruction* pad_value,
-      absl::Span<const int64> left_padded_dims = {}) const;
+  PartitionedHlo PadWithValue(HloInstruction* pad_value,
+                              absl::Span<const int64> left_padded_dims = {},
+                              absl::Span<const int64> skipped_dims = {}) const;
 
   // Returns the SPMD instruction.
   HloInstruction* hlo() const { return hlo_; }
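Because both new parameters default to empty spans, existing call sites keep compiling unchanged. A hypothetical call using the extended signature, mirroring the dot-handler hunk above (the operand name and the builder b are illustrative, not identifiers from this change):

// Hypothetical call site: mask only dimension 0 (the contracting dim here)
// and skip masking of dimension 1 by listing it in skipped_dims.
PartitionedHlo masked = operand.PadWithValue(
    CreateZero(ShapeUtil::MakeShape(operand.base_shape().element_type(), {}),
               b),
    /*left_padded_dims=*/{}, /*skipped_dims=*/{1});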
@@ -5003,6 +5003,37 @@ ENTRY entry {
                           op::Dot(lhs_slice, partial_replicated_rhs)));
 }
 
+TEST_F(SpmdPartitioningTest, Dot2DPartitionedNoncontractingAndContracting3) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %lhs = f32[23,24] parameter(0), sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
+  %rhs = f32[23,32] parameter(1), sharding={devices=[2,2]0,1,2,3}
+  ROOT %dot = f32[24,32] dot(%lhs, %rhs),
+    lhs_contracting_dims={0}, rhs_contracting_dims={0},
+    sharding={devices=[2,2]1,0,3,2}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/4));
+  VLOG(1) << module->ToString();
+
+  auto lhs = AllOf(op::Shape("f32[12,24]"), op::Parameter(0));
+  auto masked_lhs = op::Select(_, lhs, op::Broadcast(op::Constant()));
+  auto rhs = AllOf(op::Shape("f32[12,16]"), op::Parameter(1));
+  auto masked_rhs = op::Select(_, rhs, op::Broadcast(op::Constant()));
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(
+      root,
+      AllOf(op::Shape("f32[12,16]"),
+            op::DynamicSlice(
+                AllOf(op::Shape("f32[24,16]"),
+                      op::AllReduce(op::Dot(
+                          masked_lhs, op::CollectivePermute(masked_rhs)))),
+                _, _)));
+}
+
 TEST_F(SpmdPartitioningTest, Dot2DPartitionedBatchAndNonContracting) {
   const char* const hlo_string = R"(
 HloModule module
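The shapes asserted in this test follow from the uneven split: the contracting dimension has length 23 but is partitioned two ways, so each operand shard holds ceil(23/2) = 12 rows, leaving one padded row that the op::Select matchers verify is masked before the dot. The shard arithmetic, as a quick check:

// Shard-size arithmetic behind the f32[12,24] and f32[12,16] operand shapes
// above: contracting dim 23 over 2 shards gives 12 rows per shard and
// 2 * 12 - 23 = 1 padded row that must be masked to zero.
#include <cstdint>
constexpr int64_t kContracting = 23, kShards = 2;
constexpr int64_t kRowsPerShard = (kContracting + kShards - 1) / kShards;  // 12
constexpr int64_t kPaddedRows = kRowsPerShard * kShards - kContracting;    // 1
static_assert(kRowsPerShard == 12 && kPaddedRows == 1, "uneven split");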
@@ -82,6 +82,9 @@ HloInstruction* CreateZero(const Shape& shape, SpmdBuilder* b) {
   }
   auto zero = b->AddInstruction(
       HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type())));
+  if (shape.rank() == 0) {
+    return zero;
+  }
   return b->AddInstruction(HloInstruction::CreateBroadcast(shape, zero, {}));
 }
 