[SE] Don't assume that the CUDA context has not changed in the outermost ScopedActivateContext.
Will fix https://github.com/google/jax/issues/3802 when incorporated into JAX. PiperOrigin-RevId: 325899237 Change-Id: I1f2bf59d982da16db138229d8fa155f41a7e094a
This commit is contained in:
parent
dd2ee4e8cc
commit
1e336d3ef0
tensorflow/stream_executor/cuda
@ -130,6 +130,18 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# GPU test target built from cuda_driver_test.cc; tagged with the standard
# CUDA test tags.
tf_cuda_cc_test(
    name = "cuda_driver_test",
    srcs = ["cuda_driver_test.cc"],
    tags = tf_cuda_tests_tags(),
    deps = [
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/stream_executor/lib",
        "@local_config_cuda//cuda:cuda_headers",
    ],
)
|
||||||
|
|
||||||
tf_cuda_cc_test(
|
tf_cuda_cc_test(
|
||||||
name = "memcpy_test",
|
name = "memcpy_test",
|
||||||
srcs = ["memcpy_test.cc"],
|
srcs = ["memcpy_test.cc"],
|
||||||
|
@ -200,6 +200,21 @@ ScopedActivateContext::ScopedActivateContext(GpuContext* cuda_context) {
|
|||||||
if (FLAGS_gpuexec_cuda_sync_around_driver_calls) SynchronizeOrDie();
|
if (FLAGS_gpuexec_cuda_sync_around_driver_calls) SynchronizeOrDie();
|
||||||
|
|
||||||
auto* tls = &tls_data.get();
|
auto* tls = &tls_data.get();
|
||||||
|
|
||||||
|
// If this is an outermost scope, we must not assume that the CUDA context has
|
||||||
|
// been left in the same state we left it. Other code may have run on this
|
||||||
|
// thread and altered the context.
|
||||||
|
if (tls->depth == 0) {
|
||||||
|
VLOG(3) << "ScopedActivateContext switching to " << cuda_context->id();
|
||||||
|
FAIL_IF_CUDA_RES_ERROR(cuCtxSetCurrent(cuda_context->context()),
|
||||||
|
"Failed setting context");
|
||||||
|
tls->depth = 1;
|
||||||
|
tls->id = cuda_context->id();
|
||||||
|
tls->context = cuda_context;
|
||||||
|
to_restore_ = nullptr;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
tls->depth++;
|
tls->depth++;
|
||||||
if (tls->id == cuda_context->id()) {
|
if (tls->id == cuda_context->id()) {
|
||||||
if (kVerifyGpuContext) {
|
if (kVerifyGpuContext) {
|
||||||
@ -212,8 +227,7 @@ ScopedActivateContext::ScopedActivateContext(GpuContext* cuda_context) {
|
|||||||
VLOG(3) << "ScopedActivateContext switching context from " << tls->id
|
VLOG(3) << "ScopedActivateContext switching context from " << tls->id
|
||||||
<< " to " << cuda_context->id();
|
<< " to " << cuda_context->id();
|
||||||
|
|
||||||
to_restore_ = (tls->depth == 1 ? nullptr : tls->context);
|
to_restore_ = tls->context;
|
||||||
|
|
||||||
// Set the context and update thread local.
|
// Set the context and update thread local.
|
||||||
FAIL_IF_CUDA_RES_ERROR(cuCtxSetCurrent(cuda_context->context()),
|
FAIL_IF_CUDA_RES_ERROR(cuCtxSetCurrent(cuda_context->context()),
|
||||||
"Failed setting context");
|
"Failed setting context");
|
||||||
|
76
tensorflow/stream_executor/cuda/cuda_driver_test.cc
Normal file
76
tensorflow/stream_executor/cuda/cuda_driver_test.cc
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
#include "tensorflow/stream_executor/cuda/cuda_driver.h"
|
||||||
|
|
||||||
|
#include "absl/memory/memory.h"
|
||||||
|
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
namespace stream_executor {
|
||||||
|
namespace gpu {
|
||||||
|
|
||||||
|
// Terminates the process with a fatal log message when a CUDA driver API call
// did not succeed; returns normally on CUDA_SUCCESS.
//
//   result: status code returned by the driver API call.
//   file:   source file of the call site, included in the log message.
//   line:   source line of the call site, included in the log message.
void CheckCuda(CUresult result, const char* file, int line) {
  if (result == CUDA_SUCCESS) {
    return;
  }
  // The lookup functions can themselves fail: for an unrecognized error code
  // they return an error and set the output pointer to NULL. Streaming a null
  // const char* into the log is undefined behaviour, so substitute fallbacks.
  const char* name = nullptr;
  if (cuGetErrorName(result, &name) != CUDA_SUCCESS || name == nullptr) {
    name = "UNKNOWN_CUDA_ERROR";
  }
  const char* message = nullptr;
  if (cuGetErrorString(result, &message) != CUDA_SUCCESS ||
      message == nullptr) {
    message = "unrecognized CUDA driver error code";
  }
  LOG(FATAL) << file << "(" << line << "): " << name << ", " << message;
}
|
||||||
|
|
||||||
|
// CUDA runtime API overload: logs the error name and description together
// with the call-site location and aborts via LOG(FATAL) unless `result` is
// cudaSuccess.
//
//   result: status code returned by the runtime API call.
//   file:   source file of the call site.
//   line:   source line of the call site.
void CheckCuda(cudaError_t result, const char* file, int line) {
  if (result != cudaSuccess) {
    const char* error_name = cudaGetErrorName(result);
    const char* error_text = cudaGetErrorString(result);
    LOG(FATAL) << file << "(" << line << "): " << error_name << ", "
               << error_text;
  }
}
|
||||||
|
|
||||||
|
// Forwards the status of a CUDA call to the matching CheckCuda overload
// together with the file and line of the call site.
#define CHECK_CUDA(result) CheckCuda(result, __FILE__, __LINE__)
|
||||||
|
|
||||||
|
// Checks that a ScopedActivateContext makes its context current, both on
// first use and after the thread's current context was changed behind its
// back between two outermost scopes.
TEST(CudaDriverTest, ScopedActivateContextTest) {
  CHECK_CUDA(cuInit(0));

  CUdevice gpu;
  CHECK_CUDA(cuDeviceGet(&gpu, 0));

  // Two contexts on the same device: context1 is the one under test,
  // context0 is used to disturb the thread's current context.
  CUcontext context0;
  CUcontext context1;
  CHECK_CUDA(cuCtxCreate(&context0, 0, gpu));
  CHECK_CUDA(cuCtxCreate(&context1, 0, gpu));
  GpuContext se_context1(context1, /*id=*/101);

  // First outermost scope: context1 must be current inside it.
  {
    ScopedActivateContext activation(&se_context1);
    CUcontext current;
    CHECK_CUDA(cuCtxGetCurrent(&current));
    EXPECT_EQ(current, context1);
  }

  // Switch the thread to context0 outside of any scope, as unrelated code
  // running on this thread might do.
  CHECK_CUDA(cuCtxSetCurrent(context0));

  // Second outermost scope: context1 must be current again despite the
  // change made between the two scopes.
  {
    ScopedActivateContext activation(&se_context1);
    CUcontext current;
    CHECK_CUDA(cuCtxGetCurrent(&current));
    EXPECT_EQ(current, context1);
  }
}
|
||||||
|
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace stream_executor
|
||||||
|
|
||||||
|
#endif // GOOGLE_CUDA
|
Loading…
Reference in New Issue
Block a user