[XLA:GPU] Cache kernels in KernelThunk::ExecuteOnStream

Change: 147497724
This commit is contained in:
A. Unique TensorFlower 2017-02-14 11:22:50 -08:00 committed by TensorFlower Gardener
parent 2b2f26b088
commit 5799620640
3 changed files with 16 additions and 4 deletions

View File

@ -60,14 +60,19 @@ tensorflow::Status KernelThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream) {
// Load the kernel.
se::StreamExecutor* executor = stream->parent();
se::KernelBase kernel(executor);
LaunchDimensions launch_dimensions;
const se::KernelBase* kernel = nullptr;
{
tensorflow::mutex_lock lock(mutex_);
if (!executor->GetKernel(*loader_spec_, &kernel)) {
return InternalError("Unable to load kernel %s", kernel_name_.c_str());
auto it = kernel_cache_.find(executor);
if (kernel_cache_.end() == it) {
it = kernel_cache_.emplace(executor, se::KernelBase(executor)).first;
if (!executor->GetKernel(*loader_spec_, &it->second)) {
return InternalError("Unable to load kernel %s", kernel_name_.c_str());
}
}
launch_dimensions = launch_dimensions_;
kernel = &it->second;
}
// Launch the kernel with potentially multiple blocks and threads.
@ -81,7 +86,7 @@ tensorflow::Status KernelThunk::ExecuteOnStream(
buffer_allocations.GetTempBufferBase());
if (!stream->parent()->Launch(
stream, se::ThreadDim(launch_dimensions.threads_per_block()),
se::BlockDim(launch_dimensions.block_count()), kernel,
se::BlockDim(launch_dimensions.block_count()), *kernel,
*kernel_args)) {
return InternalError("Unable to launch kernel %s", kernel_name_.c_str());
}

View File

@ -78,6 +78,11 @@ class KernelThunk : public Thunk {
mutable tensorflow::mutex mutex_;
std::unique_ptr<perftools::gputools::MultiKernelLoaderSpec> loader_spec_
GUARDED_BY(mutex_);
// Loaded kernels, keyed by the `StreamExecutor` they were loaded on.
std::unordered_map<perftools::gputools::StreamExecutor*,
perftools::gputools::KernelBase>
kernel_cache_ GUARDED_BY(mutex_);
};
} // namespace gpu

View File

@ -136,6 +136,8 @@ class KernelMetadata {
// Thread-compatible.
class KernelBase {
public:
KernelBase(KernelBase &&) = default;
// Constructs an "empty" (not-yet-loaded) kernel instance.
//
// parent is the StreamExecutor that will be responsible for loading the