[XLA:GPU] Eliminate tuple population from batch norm thunks
- These tuples on the GPU side should not be directly used by anyone, since XLA should have folded the GetTupleElement instructions into which these tuples feed. PiperOrigin-RevId: 346672532 Change-Id: Ia5980e14ddd157d84fe60cb40dd4ebc2f5a77c9e
This commit is contained in:
		
							parent
							
								
									4d1c107bef
								
							
						
					
					
						commit
						5a08d776c6
					
				| @ -76,8 +76,7 @@ CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk( | ||||
|     const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset, | ||||
|     const BufferAllocation::Slice& output_data, | ||||
|     const BufferAllocation::Slice& output_mean, | ||||
|     const BufferAllocation::Slice& output_inv_stddev, | ||||
|     const BufferAllocation::Slice& output_tuple) | ||||
|     const BufferAllocation::Slice& output_inv_stddev) | ||||
|     : Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, thunk_info), | ||||
|       config_(std::move(config)), | ||||
|       operand_(operand), | ||||
| @ -85,8 +84,7 @@ CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk( | ||||
|       offset_(offset), | ||||
|       output_data_(output_data), | ||||
|       output_mean_(output_mean), | ||||
|       output_inv_stddev_(output_inv_stddev), | ||||
|       output_tuple_(output_tuple) {} | ||||
|       output_inv_stddev_(output_inv_stddev) {} | ||||
| 
 | ||||
| Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream( | ||||
|     const ExecuteParams& params) { | ||||
| @ -110,16 +108,6 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream( | ||||
|       se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(offset_)), | ||||
|       &stream)); | ||||
| 
 | ||||
|   // Write the output tuple.
 | ||||
|   const int kNumOutputs = 3; | ||||
|   auto ptrs = absl::make_unique<void*[]>(kNumOutputs); | ||||
|   ptrs[0] = output_data.opaque(); | ||||
|   ptrs[1] = output_mean.opaque(); | ||||
|   ptrs[2] = output_inv_stddev.opaque(); | ||||
|   se::DeviceMemory<void*> tuple_addr( | ||||
|       buffer_allocations.GetDeviceAddress(output_tuple_)); | ||||
|   SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, &stream, | ||||
|                 params.deferred_host_callbacks); | ||||
|   if (!stream.ok()) { | ||||
|     return InternalError("BatchNormalizationTraining call failed."); | ||||
|   } | ||||
| @ -134,8 +122,7 @@ CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk( | ||||
|     const BufferAllocation::Slice& grad_output, | ||||
|     const BufferAllocation::Slice& output_grad_data, | ||||
|     const BufferAllocation::Slice& output_grad_scale, | ||||
|     const BufferAllocation::Slice& output_grad_offset, | ||||
|     const BufferAllocation::Slice& output_tuple) | ||||
|     const BufferAllocation::Slice& output_grad_offset) | ||||
|     : Thunk(Thunk::Kind::kCudnnBatchNormBackward, thunk_info), | ||||
|       config_(std::move(config)), | ||||
|       operand_(operand), | ||||
| @ -145,8 +132,7 @@ CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk( | ||||
|       grad_output_(grad_output), | ||||
|       output_grad_data_(output_grad_data), | ||||
|       output_grad_scale_(output_grad_scale), | ||||
|       output_grad_offset_(output_grad_offset), | ||||
|       output_tuple_(output_tuple) {} | ||||
|       output_grad_offset_(output_grad_offset) {} | ||||
| 
 | ||||
| Status CudnnBatchNormBackwardThunk::ExecuteOnStream( | ||||
|     const ExecuteParams& params) { | ||||
| @ -172,17 +158,6 @@ Status CudnnBatchNormBackwardThunk::ExecuteOnStream( | ||||
|       se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(inv_stddev_)), | ||||
|       stream)); | ||||
| 
 | ||||
|   // Write the output tuple.
 | ||||
|   const int kNumOutputs = 3; | ||||
|   auto ptrs = absl::make_unique<void*[]>(kNumOutputs); | ||||
|   ptrs[0] = output_grad_data.opaque(); | ||||
|   ptrs[1] = output_grad_scale.opaque(); | ||||
|   ptrs[2] = output_grad_offset.opaque(); | ||||
|   se::DeviceMemory<void*> tuple_addr( | ||||
|       buffer_allocations.GetDeviceAddress(output_tuple_)); | ||||
|   SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, stream, | ||||
|                 params.deferred_host_callbacks); | ||||
| 
 | ||||
|   if (!stream->ok()) { | ||||
|     return InternalError("BatchNormalizationBackward call failed."); | ||||
|   } | ||||
|  | ||||
| @ -82,8 +82,7 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk { | ||||
|       const BufferAllocation::Slice& offset, | ||||
|       const BufferAllocation::Slice& output_data, | ||||
|       const BufferAllocation::Slice& output_mean, | ||||
|       const BufferAllocation::Slice& output_inv_stddev, | ||||
|       const BufferAllocation::Slice& output_tuple); | ||||
|       const BufferAllocation::Slice& output_inv_stddev); | ||||
| 
 | ||||
|   CudnnBatchNormForwardTrainingThunk( | ||||
|       const CudnnBatchNormForwardTrainingThunk&) = delete; | ||||
| @ -100,22 +99,19 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk { | ||||
|   BufferAllocation::Slice output_data_; | ||||
|   BufferAllocation::Slice output_mean_; | ||||
|   BufferAllocation::Slice output_inv_stddev_; | ||||
|   BufferAllocation::Slice output_tuple_; | ||||
| }; | ||||
| 
 | ||||
| class CudnnBatchNormBackwardThunk : public Thunk { | ||||
|  public: | ||||
|   CudnnBatchNormBackwardThunk(ThunkInfo thunk_info, | ||||
|                               CudnnBatchNormConfig&& config, | ||||
|   CudnnBatchNormBackwardThunk( | ||||
|       ThunkInfo thunk_info, CudnnBatchNormConfig&& config, | ||||
|       const BufferAllocation::Slice& operand, | ||||
|                               const BufferAllocation::Slice& scale, | ||||
|                               const BufferAllocation::Slice& mean, | ||||
|       const BufferAllocation::Slice& scale, const BufferAllocation::Slice& mean, | ||||
|       const BufferAllocation::Slice& inv_stddev, | ||||
|       const BufferAllocation::Slice& grad_output, | ||||
|       const BufferAllocation::Slice& output_grad_data, | ||||
|       const BufferAllocation::Slice& output_grad_scale, | ||||
|                               const BufferAllocation::Slice& output_grad_offset, | ||||
|                               const BufferAllocation::Slice& output_tuple); | ||||
|       const BufferAllocation::Slice& output_grad_offset); | ||||
| 
 | ||||
|   CudnnBatchNormBackwardThunk(const CudnnBatchNormBackwardThunk&) = delete; | ||||
|   CudnnBatchNormBackwardThunk& operator=(const CudnnBatchNormBackwardThunk&) = | ||||
| @ -133,7 +129,6 @@ class CudnnBatchNormBackwardThunk : public Thunk { | ||||
|   BufferAllocation::Slice output_grad_data_; | ||||
|   BufferAllocation::Slice output_grad_scale_; | ||||
|   BufferAllocation::Slice output_grad_offset_; | ||||
|   BufferAllocation::Slice output_tuple_; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace gpu
 | ||||
|  | ||||
| @ -258,8 +258,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { | ||||
|             /*offset=*/GetAllocationSlice(*custom_call->operand(2)), | ||||
|             /*output_data=*/output_data, | ||||
|             /*output_mean=*/output_mean, | ||||
|             /*output_inv_stddev=*/output_inv_stddev, | ||||
|             /*output_tuple=*/GetAllocationSlice(*custom_call))); | ||||
|             /*output_inv_stddev=*/output_inv_stddev)); | ||||
|     return Status::OK(); | ||||
|   } | ||||
| 
 | ||||
| @ -295,8 +294,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { | ||||
|         /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)), | ||||
|         /*output_grad_data=*/output_grad_data, | ||||
|         /*output_grad_scale=*/output_grad_scale, | ||||
|         /*output_grad_offset=*/output_grad_offset, | ||||
|         /*output_tuple=*/GetAllocationSlice(*custom_call))); | ||||
|         /*output_grad_offset=*/output_grad_offset)); | ||||
|     return Status::OK(); | ||||
|   } | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user