From e2aa1ff751d437546b064c1a4ede92e1bc4de2f0 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Mon, 12 Oct 2020 12:32:21 -0700 Subject: [PATCH] [KERNEL_GEN] Cache cuda stream to avoid OOM. PiperOrigin-RevId: 336719808 Change-Id: I6422a64e897b75f4259eb4c0d6dad0f1f50affc9 --- .../mlir/tools/kernel_gen/tf_cuda_runtime_wrappers.cc | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_cuda_runtime_wrappers.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_cuda_runtime_wrappers.cc index 3744a5ea31f..06d613e0599 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_cuda_runtime_wrappers.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_cuda_runtime_wrappers.cc @@ -64,8 +64,13 @@ extern "C" void mgpuLaunchKernel(CUfunction function, intptr_t gridX, } extern "C" CUstream mgpuStreamCreate() { - CUstream stream = nullptr; - CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING)); + static CUstream stream = []() { + // TODO(b/170649852): This is neither thread-safe nor handles + // creation/destruction of one stream per context. + CUstream stream = nullptr; + CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING)); + return stream; + }(); return stream; }