Tensorflowify the CUDA runtime wrappers used by the kernel generator.

Also add a cache for modules and streams to avoid repeated creation.

PiperOrigin-RevId: 344772914
Change-Id: If7ce536da0f7abc3e5dcb57192aac45ae381701c
This commit is contained in:
parent ccb98fd368
commit 270147ce81
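For context, a minimal sketch of how code emitted by the kernel generator is expected to drive these wrappers. The cubin blob, kernel name, and launch geometry are placeholders, and the mgpuLaunchKernel parameter list is assumed to follow the upstream MLIR wrapper signature; this is illustrative only, not part of the change.

// Illustrative driver for the wrappers below (all names except the mgpu*
// entry points are hypothetical).
#include <cstdint>
#include "third_party/gpus/cuda/include/cuda.h"

extern "C" CUmodule mgpuModuleLoad(void *data);
extern "C" CUfunction mgpuModuleGetFunction(CUmodule module, const char *name);
extern "C" void mgpuLaunchKernel(CUfunction function, intptr_t gridX,
                                 intptr_t gridY, intptr_t gridZ,
                                 intptr_t blockX, intptr_t blockY,
                                 intptr_t blockZ, int32_t smem, CUstream stream,
                                 void **params, void **extra);
extern "C" CUstream mgpuStreamCreate();
extern "C" void mgpuStreamSynchronize(CUstream stream);
extern "C" void mgpuStreamDestroy(CUstream stream);

void RunGeneratedKernel(void *cubin_blob, void **kernel_params) {
  // With the cache added in this change, the module is loaded at most once per
  // blob and the stream is taken from (and later returned to) a pool.
  CUmodule module = mgpuModuleLoad(cubin_blob);
  CUfunction kernel = mgpuModuleGetFunction(module, "the_kernel");  // hypothetical name
  CUstream stream = mgpuStreamCreate();
  mgpuLaunchKernel(kernel, /*gridX=*/1, /*gridY=*/1, /*gridZ=*/1,
                   /*blockX=*/256, /*blockY=*/1, /*blockZ=*/1,
                   /*smem=*/0, stream, kernel_params, /*extra=*/nullptr);
  mgpuStreamSynchronize(stream);
  mgpuStreamDestroy(stream);  // returns the stream to the pool rather than destroying it
}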
@@ -184,9 +184,10 @@ cc_library(
     compatible_with = get_compatible_with_cloud(),
     copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]),
     deps = [
+        "//tensorflow/core/platform:logging",
+        "//tensorflow/core/platform:mutex",
         "//tensorflow/core/platform/default/build_config:stream_executor_cuda",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:mlir_c_runner_utils",
+        "@com_google_absl//absl/container:flat_hash_map",
         "@local_config_cuda//cuda:cuda_headers",
     ],
 )
@@ -20,32 +20,79 @@ limitations under the License.
 #include <cassert>
 #include <numeric>

 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/ExecutionEngine/CRunnerUtils.h"  // from @llvm-project
+#include "absl/container/flat_hash_map.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"

 #if GOOGLE_CUDA
 #include "third_party/gpus/cuda/include/cuda.h"

-#define CUDA_REPORT_IF_ERROR(expr)                                           \
-  [](CUresult result) {                                                      \
-    if (!result)                                                             \
-      return;                                                                \
-    const char *name = nullptr;                                              \
-    cuGetErrorName(result, &name);                                           \
-    if (!name)                                                               \
-      name = "<unknown>";                                                    \
-    llvm::errs() << "'" << #expr << "' failed with '" << name << "'\n";      \
+#define CUDA_REPORT_IF_ERROR(expr)                                           \
+  [](CUresult result) {                                                      \
+    if (!result) return;                                                     \
+    const char *name = nullptr;                                              \
+    cuGetErrorName(result, &name);                                           \
+    if (!name) name = "<unknown>";                                           \
+    LOG(WARNING) << "'" << #expr << "' failed with '" << name << "'\n";      \
   }(expr)

+namespace {
+// Implements a cache for loading modules and creating streams. The assumption
+// is that we never unload modules or delete streams again during the lifetime
+// of a tensorflow runtime process.
+struct CudaRuntimeCache {
+ public:
+  CUmodule loadModule(void *data) {
+    tensorflow::mutex_lock lock(module_handle_mutex);
+    auto it = module_handles.find(data);
+    if (it != module_handles.end()) {
+      return it->second;
+    }
+    CUmodule module = nullptr;
+    CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data));
+    module_handles.insert({data, module});
+    return module;
+  }
+
+  CUstream createStream() {
+    tensorflow::mutex_lock lock(stream_handle_mutex);
+    CUstream stream = nullptr;
+    if (stream_handles.empty()) {
+      CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING));
+    } else {
+      stream = stream_handles.back();
+      stream_handles.pop_back();
+    }
+    return stream;
+  }
+
+  void releaseStream(CUstream stream) {
+    tensorflow::mutex_lock lock(stream_handle_mutex);
+    stream_handles.push_back(stream);
+  }
+
+  static CudaRuntimeCache *get() {
+    static auto *instance = new CudaRuntimeCache();
+    return instance;
+  }
+
+ private:
+  CudaRuntimeCache() = default;
+
+  tensorflow::mutex stream_handle_mutex;
+  std::vector<CUstream> stream_handles TF_GUARDED_BY(stream_handle_mutex);
+  tensorflow::mutex module_handle_mutex;
+  absl::flat_hash_map<void *, CUmodule> module_handles
+      TF_GUARDED_BY(module_handle_mutex);
+};
+}  // namespace

 extern "C" CUmodule mgpuModuleLoad(void *data) {
-  CUmodule module = nullptr;
-  CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data));
-  return module;
+  return CudaRuntimeCache::get()->loadModule(data);
 }

 extern "C" void mgpuModuleUnload(CUmodule module) {
-  CUDA_REPORT_IF_ERROR(cuModuleUnload(module));
+  // We never unload modules.
 }

 extern "C" CUfunction mgpuModuleGetFunction(CUmodule module, const char *name) {
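The cache is keyed on the pointer to the module blob, so repeated loads of the same blob reuse a single CUmodule, and mgpuModuleUnload becomes a no-op. A small illustration of that behavior (illustrative only; `blob` is assumed to point at a valid cubin image produced by the kernel generator):

#include <cassert>
#include "third_party/gpus/cuda/include/cuda.h"

extern "C" CUmodule mgpuModuleLoad(void *data);
extern "C" void mgpuModuleUnload(CUmodule module);

void CheckModuleIsCached(void *blob) {
  CUmodule first = mgpuModuleLoad(blob);
  CUmodule second = mgpuModuleLoad(blob);
  assert(first == second);   // second call returns the cached handle
  mgpuModuleUnload(first);   // no-op; modules stay loaded for the process lifetime
}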
@@ -68,13 +115,11 @@ extern "C" void mgpuLaunchKernel(CUfunction function, intptr_t gridX,
 }

 extern "C" CUstream mgpuStreamCreate() {
-  CUstream stream = nullptr;
-  CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING));
-  return stream;
+  return CudaRuntimeCache::get()->createStream();
 }

 extern "C" void mgpuStreamDestroy(CUstream stream) {
-  CUDA_REPORT_IF_ERROR(cuStreamDestroy(stream));
+  return CudaRuntimeCache::get()->releaseStream(stream);
 }

 extern "C" void mgpuStreamSynchronize(CUstream stream) {
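Streams are pooled rather than destroyed: mgpuStreamDestroy returns the stream to the cache, and a later mgpuStreamCreate pops it back off. A minimal single-threaded sketch of the resulting reuse (illustrative only):

#include "third_party/gpus/cuda/include/cuda.h"

extern "C" CUstream mgpuStreamCreate();
extern "C" void mgpuStreamDestroy(CUstream stream);

void ReuseStreams() {
  CUstream a = mgpuStreamCreate();   // first call creates a non-blocking stream
  mgpuStreamDestroy(a);              // pushes it onto the pool instead of calling cuStreamDestroy
  CUstream b = mgpuStreamCreate();   // pops the pooled stream; in this single-threaded example b == a
  mgpuStreamDestroy(b);
}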
@@ -103,36 +148,4 @@ extern "C" void mgpuEventRecord(CUevent event, CUstream stream) {
   CUDA_REPORT_IF_ERROR(cuEventRecord(event, stream));
 }

-/// Helper functions for writing mlir example code
-
-// Allows to register byte array with the CUDA runtime. Helpful until we have
-// transfer functions implemented.
-extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
-  CUDA_REPORT_IF_ERROR(cuMemHostRegister(ptr, sizeBytes, /*flags=*/0));
-}
-
-// Allows to register a MemRef with the CUDA runtime. Helpful until we have
-// transfer functions implemented.
-extern "C" void
-mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor,
-                          int64_t elementSizeBytes) {
-
-  llvm::SmallVector<int64_t, 4> denseStrides(rank);
-  llvm::ArrayRef<int64_t> sizes(descriptor->sizes, rank);
-  llvm::ArrayRef<int64_t> strides(sizes.end(), rank);
-
-  std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(),
-                   std::multiplies<int64_t>());
-  auto sizeBytes = denseStrides.front() * elementSizeBytes;
-
-  // Only densely packed tensors are currently supported.
-  std::rotate(denseStrides.begin(), denseStrides.begin() + 1,
-              denseStrides.end());
-  denseStrides.back() = 1;
-  assert(strides == llvm::makeArrayRef(denseStrides));
-
-  auto ptr = descriptor->data + descriptor->offset * elementSizeBytes;
-  mgpuMemHostRegister(ptr, sizeBytes);
-}
-
 #endif