Make the runtime cache in the TF CUDA runtime wrappers context-specific.
PiperOrigin-RevId: 344796382 Change-Id: I0c8dc738f9838a1a30546d001865eb119aca4893
This commit is contained in:
parent
9c6a53d746
commit
3b2b12fb53
@ -44,13 +44,10 @@ struct CudaRuntimeCache {
|
||||
public:
|
||||
CUmodule loadModule(void *data) {
|
||||
tensorflow::mutex_lock lock(module_handle_mutex);
|
||||
auto it = module_handles.find(data);
|
||||
if (it != module_handles.end()) {
|
||||
return it->second;
|
||||
auto &module = module_handles[data];
|
||||
if (!module) {
|
||||
CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data));
|
||||
}
|
||||
CUmodule module = nullptr;
|
||||
CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data));
|
||||
module_handles.insert({data, module});
|
||||
return module;
|
||||
}
|
||||
|
||||
@ -71,9 +68,20 @@ struct CudaRuntimeCache {
|
||||
stream_handles.push_back(stream);
|
||||
}
|
||||
|
||||
// Returns the runtime cache for the current context.
|
||||
static CudaRuntimeCache *get() {
|
||||
static auto *instance = new CudaRuntimeCache();
|
||||
return instance;
|
||||
using CacheWithLock =
|
||||
std::pair<tensorflow::mutex,
|
||||
absl::flat_hash_map<CUcontext, CudaRuntimeCache *>>;
|
||||
static auto *cache_with_lock = new CacheWithLock();
|
||||
tensorflow::mutex_lock lock(cache_with_lock->first);
|
||||
CUcontext context;
|
||||
CUDA_REPORT_IF_ERROR(cuCtxGetCurrent(&context));
|
||||
auto &runtime_cache = cache_with_lock->second[context];
|
||||
if (!runtime_cache) {
|
||||
runtime_cache = new CudaRuntimeCache();
|
||||
}
|
||||
return runtime_cache;
|
||||
}
|
||||
|
||||
private:
|
||||
|
Loading…
Reference in New Issue
Block a user