[XLA/GPU] [NFC] Introduce a function for emitting a __syncthreads() call

PiperOrigin-RevId: 291175258
Change-Id: Ibefbbbd6b6b4ec6b38079a94907640cac182549c
This commit is contained in:
George Karpenkov 2020-01-23 08:50:16 -08:00 committed by TensorFlower Gardener
parent 3cce429ac2
commit fcad65986c
2 changed files with 16 additions and 5 deletions

View File

@ -2440,12 +2440,12 @@ IrArray::Index IrEmitterUnnested::EmitTilingKernel(
++i) {
int64 tile_size_for_dim = mapping_scheme.GetTileSizeFor(i);
// Only last row or column may not have full size.
llvm::Value* is_last_row =
llvm::Value* is_last =
b_.CreateICmpEQ(tile_index[i], constant(dims_in_blocks[i] - 1));
int64 partial_row_size =
int64 partial_row =
dims_in_elems[i] - (dims_in_blocks[i] - 1) * tile_size_for_dim;
output_tile_bounds[i] =
b_.CreateSelect(is_last_row, constant(partial_row_size),
b_.CreateSelect(is_last, constant(partial_row),
constant(tile_size_for_dim), "tile_bound");
}
IrArray::Index tile_origin =
@ -2484,6 +2484,10 @@ IrArray::Index IrEmitterUnnested::EmitTilingKernel(
return GetElementIndexForTileOrigin(starting_tile, mapping_scheme, &b_);
}
// Emits a call to the target's barrier intrinsic (`__syncthreads` in CUDA)
// through the current IR builder and returns the emitted call instruction.
llvm::CallInst* IrEmitterUnnested::EmitSyncThreads() {
  // The barrier intrinsic is emitted with empty operand and overloaded-type
  // lists.
  llvm::CallInst* barrier_call =
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, &b_);
  return barrier_call;
}
// Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose
// algorithm to improve the memory access patterns for the input parameters
// with a shape that is a 0-2-1 transpose of the output tensor shape. The caller
@ -2617,7 +2621,7 @@ void IrEmitterUnnested::EmitHlo021Tile(
// Wait for all threads to reach this point using `__syncthreads` in
// CUDA.
EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, &b_);
EmitSyncThreads();
}
EmitTile(mapping_scheme, index, loop_name, ksl, &b_, y, x, tile_height,
@ -2629,7 +2633,7 @@ void IrEmitterUnnested::EmitHlo021Tile(
// memory buffer for the current tile before we move on to process the
// next tile and overwrite the shared memory buffers.
if (block_contains_multi_tiles && !tiled_param_ids.empty()) {
EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, &b_);
EmitSyncThreads();
}
};

View File

@ -335,9 +335,16 @@ class IrEmitterUnnested : public IrEmitter,
// Lane id: `thread_id % kWarpSize`
llvm::Value* lane_id;
};
// Emits the LLVM values for thread_id, thread_id.x, thread_id.y and lane id.
//
// Returns a struct containing these values.
ThreadIdInfo EmitThreadIdInfo(int64 threads_per_block, llvm::Type* index_ty,
int64 num_threads_x);
// Emits __syncthreads(), a synchronization barrier for all threads in a block.
llvm::CallInst* EmitSyncThreads();
// Emits current block id.
llvm::Value* EmitBlockId();