[XLA/GPU] [NFC] Introduce a function for emitting a __syncthreads() call
PiperOrigin-RevId: 291175258 Change-Id: Ibefbbbd6b6b4ec6b38079a94907640cac182549c
This commit is contained in:
parent
3cce429ac2
commit
fcad65986c
@ -2440,12 +2440,12 @@ IrArray::Index IrEmitterUnnested::EmitTilingKernel(
|
||||
++i) {
|
||||
int64 tile_size_for_dim = mapping_scheme.GetTileSizeFor(i);
|
||||
// Only last row or column may not have full size.
|
||||
llvm::Value* is_last_row =
|
||||
llvm::Value* is_last =
|
||||
b_.CreateICmpEQ(tile_index[i], constant(dims_in_blocks[i] - 1));
|
||||
int64 partial_row_size =
|
||||
int64 partial_row =
|
||||
dims_in_elems[i] - (dims_in_blocks[i] - 1) * tile_size_for_dim;
|
||||
output_tile_bounds[i] =
|
||||
b_.CreateSelect(is_last_row, constant(partial_row_size),
|
||||
b_.CreateSelect(is_last, constant(partial_row),
|
||||
constant(tile_size_for_dim), "tile_bound");
|
||||
}
|
||||
IrArray::Index tile_origin =
|
||||
@ -2484,6 +2484,10 @@ IrArray::Index IrEmitterUnnested::EmitTilingKernel(
|
||||
return GetElementIndexForTileOrigin(starting_tile, mapping_scheme, &b_);
|
||||
}
|
||||
|
||||
// Emits a call to the target intrinsic identified by
// TargetIntrinsicID::kBarrierId (the `__syncthreads` barrier in CUDA) through
// the builder b_, passing no operands and no overloaded types, and returns the
// resulting call instruction.
llvm::CallInst* IrEmitterUnnested::EmitSyncThreads() {
|
||||
return EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, &b_);
|
||||
}
|
||||
|
||||
// Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose
|
||||
// algorithm to improve the memory access patterns for the input parameters
|
||||
// with a shape that is a 0-2-1 transpose of the output tensor shape. The caller
|
||||
@ -2617,7 +2621,7 @@ void IrEmitterUnnested::EmitHlo021Tile(
|
||||
|
||||
// Wait for all threads to reach this point using `__syncthreads` in
|
||||
// CUDA.
|
||||
EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, &b_);
|
||||
EmitSyncThreads();
|
||||
}
|
||||
|
||||
EmitTile(mapping_scheme, index, loop_name, ksl, &b_, y, x, tile_height,
|
||||
@ -2629,7 +2633,7 @@ void IrEmitterUnnested::EmitHlo021Tile(
|
||||
// memory buffer for the current tile before we move on to process the
|
||||
// next tile and overwrite the shared memory buffers.
|
||||
if (block_contains_multi_tiles && !tiled_param_ids.empty()) {
|
||||
EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, &b_);
|
||||
EmitSyncThreads();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -335,9 +335,16 @@ class IrEmitterUnnested : public IrEmitter,
|
||||
// Lane id: `thread_id % kWarpSize`
|
||||
llvm::Value* lane_id;
|
||||
};
|
||||
|
||||
// Emits the LLVM values for thread_id, thread_id.x, thread_id.y and lane id.
|
||||
//
|
||||
// Returns a struct containing these values.
|
||||
ThreadIdInfo EmitThreadIdInfo(int64 threads_per_block, llvm::Type* index_ty,
|
||||
int64 num_threads_x);
|
||||
|
||||
// Emit __syncthreads(), synchronization barrier for all threads in a block.
|
||||
llvm::CallInst* EmitSyncThreads();
|
||||
|
||||
// Emits current block id.
|
||||
llvm::Value* EmitBlockId();
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user