Make the runtime cache in the TF CUDA runtime wrappers context-specific.

PiperOrigin-RevId: 344796382
Change-Id: I0c8dc738f9838a1a30546d001865eb119aca4893
This commit is contained in:
Stephan Herhut 2020-11-30 05:53:42 -08:00 committed by TensorFlower Gardener
parent 9c6a53d746
commit 3b2b12fb53

View File

@ -44,13 +44,10 @@ struct CudaRuntimeCache {
public:
// Returns the CUmodule for `data`, loading it on first use.
// The module handle is cached per data pointer; subsequent calls with the
// same pointer return the cached handle without reloading.
// Thread-safe: the lookup-or-load is done under module_handle_mutex.
CUmodule loadModule(void *data) {
  tensorflow::mutex_lock lock(module_handle_mutex);
  // operator[] default-constructs a null CUmodule on first access, so a
  // single map lookup serves both the hit and the miss path.
  auto &module = module_handles[data];
  if (!module) {
    // First time we see this data pointer: load the module. On failure the
    // error is reported and a null handle stays cached.
    CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data));
  }
  return module;
}
@ -71,9 +68,20 @@ struct CudaRuntimeCache {
stream_handles.push_back(stream);
}
// Returns the runtime cache for the current CUDA context.
// Each CUcontext gets its own CudaRuntimeCache instance, since module and
// stream handles are only valid within the context that created them.
// The per-context caches are intentionally leaked (never deleted), matching
// the lifetime of the process.
static CudaRuntimeCache *get() {
  using CacheWithLock =
      std::pair<tensorflow::mutex,
                absl::flat_hash_map<CUcontext, CudaRuntimeCache *>>;
  // Function-local static: initialized once, thread-safely, on first call.
  static auto *cache_with_lock = new CacheWithLock();
  tensorflow::mutex_lock lock(cache_with_lock->first);
  CUcontext context;
  CUDA_REPORT_IF_ERROR(cuCtxGetCurrent(&context));
  // operator[] default-inserts nullptr for a new context; create the
  // cache lazily on first access from that context.
  auto &runtime_cache = cache_with_lock->second[context];
  if (!runtime_cache) {
    runtime_cache = new CudaRuntimeCache();
  }
  return runtime_cache;
}
private: