From 1e336d3ef0706792772b605a1fed0135f5cc5cfe Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 10 Aug 2020 15:19:24 -0700 Subject: [PATCH] [SE] Don't assume that the CUDA context has not changed in the outermost ScopedActivationContext. Will fix https://github.com/google/jax/issues/3802 when incorporated into JAX. PiperOrigin-RevId: 325899237 Change-Id: I1f2bf59d982da16db138229d8fa155f41a7e094a --- tensorflow/stream_executor/cuda/BUILD | 12 +++ .../stream_executor/cuda/cuda_driver.cc | 18 ++++- .../stream_executor/cuda/cuda_driver_test.cc | 76 +++++++++++++++++++ 3 files changed, 104 insertions(+), 2 deletions(-) create mode 100644 tensorflow/stream_executor/cuda/cuda_driver_test.cc diff --git a/tensorflow/stream_executor/cuda/BUILD b/tensorflow/stream_executor/cuda/BUILD index f3cffc04465..bd545f097cf 100644 --- a/tensorflow/stream_executor/cuda/BUILD +++ b/tensorflow/stream_executor/cuda/BUILD @@ -130,6 +130,18 @@ cc_library( ], ) +tf_cuda_cc_test( + name = "cuda_driver_test", + srcs = ["cuda_driver_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/stream_executor/lib", + "@local_config_cuda//cuda:cuda_headers", + ], +) + tf_cuda_cc_test( name = "memcpy_test", srcs = ["memcpy_test.cc"], diff --git a/tensorflow/stream_executor/cuda/cuda_driver.cc b/tensorflow/stream_executor/cuda/cuda_driver.cc index e30eb549a9c..67fd72d52f3 100644 --- a/tensorflow/stream_executor/cuda/cuda_driver.cc +++ b/tensorflow/stream_executor/cuda/cuda_driver.cc @@ -200,6 +200,21 @@ ScopedActivateContext::ScopedActivateContext(GpuContext* cuda_context) { if (FLAGS_gpuexec_cuda_sync_around_driver_calls) SynchronizeOrDie(); auto* tls = &tls_data.get(); + + // If this is an outermost scope, we must not assume that the CUDA context has + // been left in the same state we left it. Other code may have run on this + // thread and altered the context. + if (tls->depth == 0) { + VLOG(3) << "ScopedActivateContext switching to " << cuda_context->id(); + FAIL_IF_CUDA_RES_ERROR(cuCtxSetCurrent(cuda_context->context()), + "Failed setting context"); + tls->depth = 1; + tls->id = cuda_context->id(); + tls->context = cuda_context; + to_restore_ = nullptr; + return; + } + tls->depth++; if (tls->id == cuda_context->id()) { if (kVerifyGpuContext) { @@ -212,8 +227,7 @@ ScopedActivateContext::ScopedActivateContext(GpuContext* cuda_context) { VLOG(3) << "ScopedActivateContext switching context from " << tls->id << " to " << cuda_context->id(); - to_restore_ = (tls->depth == 1 ? nullptr : tls->context); - + to_restore_ = tls->context; // Set the context and update thread local. FAIL_IF_CUDA_RES_ERROR(cuCtxSetCurrent(cuda_context->context()), "Failed setting context"); diff --git a/tensorflow/stream_executor/cuda/cuda_driver_test.cc b/tensorflow/stream_executor/cuda/cuda_driver_test.cc new file mode 100644 index 00000000000..5b173f96d85 --- /dev/null +++ b/tensorflow/stream_executor/cuda/cuda_driver_test.cc @@ -0,0 +1,76 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA +#include "tensorflow/stream_executor/cuda/cuda_driver.h" + +#include "absl/memory/memory.h" +#include "third_party/gpus/cuda/include/cuda_runtime_api.h" +#include "tensorflow/core/platform/test.h" + +namespace stream_executor { +namespace gpu { + +void CheckCuda(CUresult result, const char* file, int line) { + if (result == CUDA_SUCCESS) { + return; + } + const char* name; + cuGetErrorName(result, &name); + const char* message; + cuGetErrorString(result, &message); + LOG(FATAL) << file << "(" << line << "): " << name << ", " << message; +} + +void CheckCuda(cudaError_t result, const char* file, int line) { + if (result == cudaSuccess) { + return; + } + const char* name = cudaGetErrorName(result); + const char* message = cudaGetErrorString(result); + LOG(FATAL) << file << "(" << line << "): " << name << ", " << message; +} + +#define CHECK_CUDA(result) CheckCuda(result, __FILE__, __LINE__) + +TEST(CudaDriverTest, ScopedActivateContextTest) { + CHECK_CUDA(cuInit(0)); + CUdevice device; + CHECK_CUDA(cuDeviceGet(&device, 0)); + CUcontext context0, context1; + CHECK_CUDA(cuCtxCreate(&context0, 0, device)); + CHECK_CUDA(cuCtxCreate(&context1, 0, device)); + GpuContext se_context1(context1, /*id=*/101); + { + ScopedActivateContext scope(&se_context1); + CUcontext c; + CHECK_CUDA(cuCtxGetCurrent(&c)); + EXPECT_EQ(c, context1); + } + CHECK_CUDA(cuCtxSetCurrent(context0)); + // ScopedActivateContext must correctly set the CUDA context even if some + // other code changes the context between the two scopes. + { + ScopedActivateContext scope(&se_context1); + CUcontext c; + CHECK_CUDA(cuCtxGetCurrent(&c)); + EXPECT_EQ(c, context1); + } +} + +} // namespace gpu +} // namespace stream_executor + +#endif // GOOGLE_CUDA