diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index 9e319d4c356..69399e36c4c 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -60,14 +60,19 @@ tensorflow::Status KernelThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream) { // Load the kernel. se::StreamExecutor* executor = stream->parent(); - se::KernelBase kernel(executor); LaunchDimensions launch_dimensions; + const se::KernelBase* kernel = nullptr; { tensorflow::mutex_lock lock(mutex_); - if (!executor->GetKernel(*loader_spec_, &kernel)) { - return InternalError("Unable to load kernel %s", kernel_name_.c_str()); + auto it = kernel_cache_.find(executor); + if (kernel_cache_.end() == it) { + it = kernel_cache_.emplace(executor, se::KernelBase(executor)).first; + if (!executor->GetKernel(*loader_spec_, &it->second)) { + return InternalError("Unable to load kernel %s", kernel_name_.c_str()); + } } launch_dimensions = launch_dimensions_; + kernel = &it->second; } // Launch the kernel with potentially multiple blocks and threads. @@ -81,7 +86,7 @@ tensorflow::Status KernelThunk::ExecuteOnStream( buffer_allocations.GetTempBufferBase()); if (!stream->parent()->Launch( stream, se::ThreadDim(launch_dimensions.threads_per_block()), - se::BlockDim(launch_dimensions.block_count()), kernel, + se::BlockDim(launch_dimensions.block_count()), *kernel, *kernel_args)) { return InternalError("Unable to launch kernel %s", kernel_name_.c_str()); } diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h index dd636552b42..350b5aaf360 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h @@ -78,6 +78,11 @@ class KernelThunk : public Thunk { mutable tensorflow::mutex mutex_; std::unique_ptr loader_spec_ GUARDED_BY(mutex_); + + // Loaded kernels for each `StreamExecutor` + std::unordered_map + kernel_cache_ GUARDED_BY(mutex_); }; } // namespace gpu diff --git a/tensorflow/stream_executor/kernel.h b/tensorflow/stream_executor/kernel.h index bbe02e5112f..d9d40d77bd9 100644 --- a/tensorflow/stream_executor/kernel.h +++ b/tensorflow/stream_executor/kernel.h @@ -136,6 +136,8 @@ class KernelMetadata { // Thread-compatible. class KernelBase { public: + KernelBase(KernelBase &&) = default; + // Constructs an "empty" (not-yet-loaded) kernel instance. // // parent is the StreamExecutor that will be responsible for loading the