[XLA:GPU] Cache kernels in KernelThunk::ExecuteOnStream
Change: 147497724
This commit is contained in:
parent
2b2f26b088
commit
5799620640
@ -60,14 +60,19 @@ tensorflow::Status KernelThunk::ExecuteOnStream(
|
||||
const BufferAllocations& buffer_allocations, se::Stream* stream) {
|
||||
// Load the kernel.
|
||||
se::StreamExecutor* executor = stream->parent();
|
||||
se::KernelBase kernel(executor);
|
||||
LaunchDimensions launch_dimensions;
|
||||
const se::KernelBase* kernel = nullptr;
|
||||
{
|
||||
tensorflow::mutex_lock lock(mutex_);
|
||||
if (!executor->GetKernel(*loader_spec_, &kernel)) {
|
||||
return InternalError("Unable to load kernel %s", kernel_name_.c_str());
|
||||
auto it = kernel_cache_.find(executor);
|
||||
if (kernel_cache_.end() == it) {
|
||||
it = kernel_cache_.emplace(executor, se::KernelBase(executor)).first;
|
||||
if (!executor->GetKernel(*loader_spec_, &it->second)) {
|
||||
return InternalError("Unable to load kernel %s", kernel_name_.c_str());
|
||||
}
|
||||
}
|
||||
launch_dimensions = launch_dimensions_;
|
||||
kernel = &it->second;
|
||||
}
|
||||
|
||||
// Launch the kernel with potentially multiple blocks and threads.
|
||||
@ -81,7 +86,7 @@ tensorflow::Status KernelThunk::ExecuteOnStream(
|
||||
buffer_allocations.GetTempBufferBase());
|
||||
if (!stream->parent()->Launch(
|
||||
stream, se::ThreadDim(launch_dimensions.threads_per_block()),
|
||||
se::BlockDim(launch_dimensions.block_count()), kernel,
|
||||
se::BlockDim(launch_dimensions.block_count()), *kernel,
|
||||
*kernel_args)) {
|
||||
return InternalError("Unable to launch kernel %s", kernel_name_.c_str());
|
||||
}
|
||||
|
@ -78,6 +78,11 @@ class KernelThunk : public Thunk {
|
||||
mutable tensorflow::mutex mutex_;
|
||||
std::unique_ptr<perftools::gputools::MultiKernelLoaderSpec> loader_spec_
|
||||
GUARDED_BY(mutex_);
|
||||
|
||||
// Loaded kernels for each `StreamExecutor`
|
||||
std::unordered_map<perftools::gputools::StreamExecutor*,
|
||||
perftools::gputools::KernelBase>
|
||||
kernel_cache_ GUARDED_BY(mutex_);
|
||||
};
|
||||
|
||||
} // namespace gpu
|
||||
|
@ -136,6 +136,8 @@ class KernelMetadata {
|
||||
// Thread-compatible.
|
||||
class KernelBase {
|
||||
public:
|
||||
KernelBase(KernelBase &&) = default;
|
||||
|
||||
// Constructs an "empty" (not-yet-loaded) kernel instance.
|
||||
//
|
||||
// parent is the StreamExecutor that will be responsible for loading the
|
||||
|
Loading…
Reference in New Issue
Block a user