[SE] Don't assume that the CUDA context has not changed in the outermost ScopedActivateContext.
Will fix https://github.com/google/jax/issues/3802 when incorporated into JAX. PiperOrigin-RevId: 325899237 Change-Id: I1f2bf59d982da16db138229d8fa155f41a7e094a
This commit is contained in:
parent
dd2ee4e8cc
commit
1e336d3ef0
tensorflow/stream_executor/cuda
@ -130,6 +130,18 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
# Test target built from cuda_driver_test.cc. Its tags are taken from
# tf_cuda_tests_tags(); it links the TensorFlow test library and test main,
# the stream_executor lib package, and the CUDA headers.
tf_cuda_cc_test(
    name = "cuda_driver_test",
    srcs = ["cuda_driver_test.cc"],
    tags = tf_cuda_tests_tags(),
    deps = [
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/stream_executor/lib",
        "@local_config_cuda//cuda:cuda_headers",
    ],
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "memcpy_test",
|
||||
srcs = ["memcpy_test.cc"],
|
||||
|
@ -200,6 +200,21 @@ ScopedActivateContext::ScopedActivateContext(GpuContext* cuda_context) {
|
||||
if (FLAGS_gpuexec_cuda_sync_around_driver_calls) SynchronizeOrDie();
|
||||
|
||||
auto* tls = &tls_data.get();
|
||||
|
||||
// If this is an outermost scope, we must not assume that the CUDA context has
|
||||
// been left in the same state we left it. Other code may have run on this
|
||||
// thread and altered the context.
|
||||
if (tls->depth == 0) {
|
||||
VLOG(3) << "ScopedActivateContext switching to " << cuda_context->id();
|
||||
FAIL_IF_CUDA_RES_ERROR(cuCtxSetCurrent(cuda_context->context()),
|
||||
"Failed setting context");
|
||||
tls->depth = 1;
|
||||
tls->id = cuda_context->id();
|
||||
tls->context = cuda_context;
|
||||
to_restore_ = nullptr;
|
||||
return;
|
||||
}
|
||||
|
||||
tls->depth++;
|
||||
if (tls->id == cuda_context->id()) {
|
||||
if (kVerifyGpuContext) {
|
||||
@ -212,8 +227,7 @@ ScopedActivateContext::ScopedActivateContext(GpuContext* cuda_context) {
|
||||
VLOG(3) << "ScopedActivateContext switching context from " << tls->id
|
||||
<< " to " << cuda_context->id();
|
||||
|
||||
to_restore_ = (tls->depth == 1 ? nullptr : tls->context);
|
||||
|
||||
to_restore_ = tls->context;
|
||||
// Set the context and update thread local.
|
||||
FAIL_IF_CUDA_RES_ERROR(cuCtxSetCurrent(cuda_context->context()),
|
||||
"Failed setting context");
|
||||
|
76
tensorflow/stream_executor/cuda/cuda_driver_test.cc
Normal file
76
tensorflow/stream_executor/cuda/cuda_driver_test.cc
Normal file
@ -0,0 +1,76 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "tensorflow/stream_executor/cuda/cuda_driver.h"
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace stream_executor {
|
||||
namespace gpu {
|
||||
|
||||
// Aborts with a fatal log message when a CUDA driver API call has failed.
//
// `result` is the status returned by the driver call. `file` and `line`
// identify the call site and are included in the log message. Returns normally
// only when `result` is CUDA_SUCCESS.
void CheckCuda(CUresult result, const char* file, int line) {
  if (result == CUDA_SUCCESS) {
    return;
  }
  // cuGetErrorName / cuGetErrorString set their output pointer to NULL when
  // the error code is not recognized. Streaming a null `const char*` is
  // undefined behavior, so substitute placeholder text in that case.
  const char* name = nullptr;
  if (cuGetErrorName(result, &name) != CUDA_SUCCESS || name == nullptr) {
    name = "UNKNOWN_CUDA_ERROR";
  }
  const char* message = nullptr;
  if (cuGetErrorString(result, &message) != CUDA_SUCCESS ||
      message == nullptr) {
    message = "unrecognized error code";
  }
  LOG(FATAL) << file << "(" << line << "): " << name << ", " << message;
}
|
||||
|
||||
// Aborts with a fatal log message when a CUDA runtime API call has failed.
//
// `result` is the status returned by the runtime call. `file` and `line`
// identify the call site and are included in the log message. Returns normally
// only when `result` is cudaSuccess.
void CheckCuda(cudaError_t result, const char* file, int line) {
  if (result != cudaSuccess) {
    LOG(FATAL) << file << "(" << line << "): " << cudaGetErrorName(result)
               << ", " << cudaGetErrorString(result);
  }
}
|
||||
|
||||
// Passes a CUDA status to CheckCuda together with the file and line of the
// call site, so a failure is reported at the place the call was made.
#define CHECK_CUDA(result) CheckCuda((result), __FILE__, __LINE__)
|
||||
|
||||
// Verifies that an outermost ScopedActivateContext makes its context current
// even when other code changed the thread's current CUDA context after a
// previous scope ended.
TEST(CudaDriverTest, ScopedActivateContextTest) {
  CHECK_CUDA(cuInit(0));
  CUdevice device;
  CHECK_CUDA(cuDeviceGet(&device, 0));
  CUcontext context0, context1;
  CHECK_CUDA(cuCtxCreate(&context0, 0, device));
  CHECK_CUDA(cuCtxCreate(&context1, 0, device));
  {
    // Keep the wrapper in its own scope so it is gone before the underlying
    // CUDA context is destroyed below.
    GpuContext se_context1(context1, /*id=*/101);
    {
      ScopedActivateContext scope(&se_context1);
      CUcontext c;
      CHECK_CUDA(cuCtxGetCurrent(&c));
      EXPECT_EQ(c, context1);
    }
    // Simulate unrelated code switching the thread's current context between
    // the two scopes.
    CHECK_CUDA(cuCtxSetCurrent(context0));
    // ScopedActivateContext must correctly set the CUDA context even if some
    // other code changes the context between the two scopes.
    {
      ScopedActivateContext scope(&se_context1);
      CUcontext c;
      CHECK_CUDA(cuCtxGetCurrent(&c));
      EXPECT_EQ(c, context1);
    }
  }
  // Release the contexts created above; the original test leaked both.
  CHECK_CUDA(cuCtxDestroy(context1));
  CHECK_CUDA(cuCtxDestroy(context0));
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace stream_executor
|
||||
|
||||
#endif // GOOGLE_CUDA
|
Loading…
Reference in New Issue
Block a user