From 270147ce8191969ab0d843d75db32686e1abd862 Mon Sep 17 00:00:00 2001
From: Stephan Herhut
Date: Mon, 30 Nov 2020 02:17:01 -0800
Subject: [PATCH] Tensorflowify the cuda runtime wrappers used by kernel
 generator.

Also add a cache for modules and streams to avoid repeated creation.

PiperOrigin-RevId: 344772914
Change-Id: If7ce536da0f7abc3e5dcb57192aac45ae381701c
---
 .../compiler/mlir/tools/kernel_gen/BUILD      |   5 +-
 .../kernel_gen/tf_cuda_runtime_wrappers.cc    | 117 ++++++++++--------
 2 files changed, 68 insertions(+), 54 deletions(-)

diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD
index e91f0912a36..a44c4ac1276 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD
@@ -184,9 +184,10 @@ cc_library(
     compatible_with = get_compatible_with_cloud(),
     copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]),
     deps = [
+        "//tensorflow/core/platform:logging",
+        "//tensorflow/core/platform:mutex",
         "//tensorflow/core/platform/default/build_config:stream_executor_cuda",
-        "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:mlir_c_runner_utils",
+        "@com_google_absl//absl/container:flat_hash_map",
         "@local_config_cuda//cuda:cuda_headers",
     ],
 )
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_cuda_runtime_wrappers.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_cuda_runtime_wrappers.cc
index 8c4380ffac8..a5c74ba5473 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_cuda_runtime_wrappers.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_cuda_runtime_wrappers.cc
@@ -20,32 +20,79 @@ limitations under the License.
 #include <cassert>
 #include <numeric>
 
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/Support/raw_ostream.h"
-#include "mlir/ExecutionEngine/CRunnerUtils.h"  // from @llvm-project
+#include "absl/container/flat_hash_map.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
 
 #if GOOGLE_CUDA
 #include "third_party/gpus/cuda/include/cuda.h"
 
-#define CUDA_REPORT_IF_ERROR(expr)                                      \
-  [](CUresult result) {                                                 \
-    if (!result)                                                        \
-      return;                                                           \
-    const char *name = nullptr;                                         \
-    cuGetErrorName(result, &name);                                      \
-    if (!name)                                                          \
-      name = "<unknown>";                                               \
-    llvm::errs() << "'" << #expr << "' failed with '" << name << "'\n"; \
+#define CUDA_REPORT_IF_ERROR(expr)                                      \
+  [](CUresult result) {                                                 \
+    if (!result) return;                                                \
+    const char *name = nullptr;                                         \
+    cuGetErrorName(result, &name);                                      \
+    if (!name) name = "<unknown>";                                      \
+    LOG(WARNING) << "'" << #expr << "' failed with '" << name << "'\n"; \
   }(expr)
 
+namespace {
+// Implements a cache for loading modules and creating streams. The assumption
+// is that we never unload modules or delete streams again during the lifetime
+// of a tensorflow runtime process.
+struct CudaRuntimeCache {
+ public:
+  CUmodule loadModule(void *data) {
+    tensorflow::mutex_lock lock(module_handle_mutex);
+    auto it = module_handles.find(data);
+    if (it != module_handles.end()) {
+      return it->second;
+    }
+    CUmodule module = nullptr;
+    CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data));
+    module_handles.insert({data, module});
+    return module;
+  }
+
+  CUstream createStream() {
+    tensorflow::mutex_lock lock(stream_handle_mutex);
+    CUstream stream = nullptr;
+    if (stream_handles.empty()) {
+      CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING));
+    } else {
+      stream = stream_handles.back();
+      stream_handles.pop_back();
+    }
+    return stream;
+  }
+
+  void releaseStream(CUstream stream) {
+    tensorflow::mutex_lock lock(stream_handle_mutex);
+    stream_handles.push_back(stream);
+  }
+
+  static CudaRuntimeCache *get() {
+    static auto *instance = new CudaRuntimeCache();
+    return instance;
+  }
+
+ private:
+  CudaRuntimeCache() = default;
+
+  tensorflow::mutex stream_handle_mutex;
+  std::vector<CUstream> stream_handles TF_GUARDED_BY(stream_handle_mutex);
+  tensorflow::mutex module_handle_mutex;
+  absl::flat_hash_map<void *, CUmodule> module_handles
+      TF_GUARDED_BY(module_handle_mutex);
+};
+}  // namespace
+
 extern "C" CUmodule mgpuModuleLoad(void *data) {
-  CUmodule module = nullptr;
-  CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data));
-  return module;
+  return CudaRuntimeCache::get()->loadModule(data);
 }
 
 extern "C" void mgpuModuleUnload(CUmodule module) {
-  CUDA_REPORT_IF_ERROR(cuModuleUnload(module));
+  // We never unload modules.
 }
 
 extern "C" CUfunction mgpuModuleGetFunction(CUmodule module,
                                             const char *name) {
@@ -68,13 +115,11 @@ extern "C" void mgpuLaunchKernel(CUfunction function, intptr_t gridX,
 }
 
 extern "C" CUstream mgpuStreamCreate() {
-  CUstream stream = nullptr;
-  CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING));
-  return stream;
+  return CudaRuntimeCache::get()->createStream();
 }
 
 extern "C" void mgpuStreamDestroy(CUstream stream) {
-  CUDA_REPORT_IF_ERROR(cuStreamDestroy(stream));
+  return CudaRuntimeCache::get()->releaseStream(stream);
 }
 
 extern "C" void mgpuStreamSynchronize(CUstream stream) {
@@ -103,36 +148,4 @@ extern "C" void mgpuEventRecord(CUevent event, CUstream stream) {
   CUDA_REPORT_IF_ERROR(cuEventRecord(event, stream));
 }
 
-/// Helper functions for writing mlir example code
-
-// Allows to register byte array with the CUDA runtime. Helpful until we have
-// transfer functions implemented.
-extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
-  CUDA_REPORT_IF_ERROR(cuMemHostRegister(ptr, sizeBytes, /*flags=*/0));
-}
-
-// Allows to register a MemRef with the CUDA runtime. Helpful until we have
-// transfer functions implemented.
-extern "C" void
-mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor,
-                          int64_t elementSizeBytes) {
-  llvm::SmallVector<int64_t, 4> denseStrides(rank);
-  llvm::ArrayRef<int64_t> sizes(descriptor->sizes, rank);
-  llvm::ArrayRef<int64_t> strides(sizes.end(), rank);
-
-  std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(),
-                   std::multiplies<int64_t>());
-  auto sizeBytes = denseStrides.front() * elementSizeBytes;
-
-  // Only densely packed tensors are currently supported.
-  std::rotate(denseStrides.begin(), denseStrides.begin() + 1,
-              denseStrides.end());
-  denseStrides.back() = 1;
-  assert(strides == llvm::makeArrayRef(denseStrides));
-
-  auto ptr = descriptor->data + descriptor->offset * elementSizeBytes;
-  mgpuMemHostRegister(ptr, sizeBytes);
-}
-
 #endif
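
For illustration, a minimal standalone sketch of the caching scheme this patch introduces. This is a sketch under stated assumptions, not the patched code: integer handles stand in for CUmodule/CUstream, std::mutex replaces tensorflow::mutex, and std::unordered_map replaces absl::flat_hash_map, so it builds with a plain C++ toolchain and needs neither CUDA nor a TensorFlow checkout.

#include <cstdio>
#include <mutex>
#include <unordered_map>
#include <vector>

// Stand-in for CudaRuntimeCache: caches "module" handles per data pointer
// and recycles "stream" handles instead of destroying them.
struct RuntimeCacheSketch {
  // First call for a given blob creates a handle (as cuModuleLoadData would);
  // later calls return the cached handle. Nothing is ever unloaded.
  int loadModule(void *data) {
    std::lock_guard<std::mutex> lock(module_mu_);
    auto it = modules_.find(data);
    if (it != modules_.end()) return it->second;
    int handle = next_handle_++;  // stands in for cuModuleLoadData
    modules_.insert({data, handle});
    return handle;
  }

  // Reuses a previously released stream if one is available, otherwise
  // creates a fresh one (as cuStreamCreate would).
  int createStream() {
    std::lock_guard<std::mutex> lock(stream_mu_);
    if (free_streams_.empty()) return next_handle_++;
    int stream = free_streams_.back();
    free_streams_.pop_back();
    return stream;
  }

  // Instead of destroying the stream, parks it on a free list for reuse.
  void releaseStream(int stream) {
    std::lock_guard<std::mutex> lock(stream_mu_);
    free_streams_.push_back(stream);
  }

 private:
  std::mutex module_mu_;
  std::mutex stream_mu_;
  std::unordered_map<void *, int> modules_;
  std::vector<int> free_streams_;
  int next_handle_ = 1;
};

int main() {
  RuntimeCacheSketch cache;
  char blob;  // pretend cubin image
  // Loading the same data twice yields the same cached handle.
  std::printf("%d %d\n", cache.loadModule(&blob), cache.loadModule(&blob));
  // A released stream is handed out again by the next create.
  int s = cache.createStream();
  cache.releaseStream(s);
  std::printf("%d\n", cache.createStream() == s);  // prints 1
  return 0;
}

This mirrors why mgpuModuleUnload becomes a no-op and mgpuStreamDestroy turns into releaseStream in the patch: module and stream lifetimes are assumed to match the process lifetime, trading a little memory for skipping repeated cuModuleLoadData and cuStreamCreate calls.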