Stop using DSO loader for CUDA SDK libraries

The DSO loader intermediate layer is no longer required, so it is removed here.
Change: 145589281
A. Unique TensorFlower 2017-01-25 13:32:47 -08:00 committed by TensorFlower Gardener
parent 3dae99f979
commit 191658d54f
9 changed files with 488 additions and 737 deletions
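
For readers skimming the diff below: previously every cuDNN and CUDA driver entry point went through a generated "DynLoadShim" that opened the DSO at runtime and resolved each symbol on first use; after this change the wrappers simply forward to the symbols provided by the linker, and the BUILD file lists the CUDA SDK libraries as ordinary deps instead of runtime data (dropping the -ldl linkopt). The following is a minimal sketch of the two patterns, condensed from the PERFTOOLS_GPUTOOLS_CUDNN_WRAP macros shown in the diff; the namespace names and the use of plain dlopen/dlsym in place of StreamExecutor's port::Env::GetSymbolFromLibrary are illustrative assumptions, not code from this commit.

// Sketch only: condensed from the wrapper macros in the diff below.
// cudnnGetVersion is used as the example entry point.
#include <dlfcn.h>
#include <cudnn.h>
#include <cstddef>
#include <iostream>

namespace old_dynload {
// Before: resolve the symbol out of a dlopen()ed DSO on first call and cache
// the function pointer (the real shim logs and dies if either step fails).
inline size_t cudnnGetVersion() {
  static void* dso = dlopen("libcudnn.so", RTLD_LAZY);
  static auto* fn =
      reinterpret_cast<size_t (*)()>(dlsym(dso, "cudnnGetVersion"));
  return fn();
}
}  // namespace old_dynload

namespace new_wrap {
// After: forward directly to the linked symbol; the BUILD file now lists
// @local_config_cuda//cuda:cudnn in deps, so ::cudnnGetVersion resolves at
// link time and no DSO loader layer is needed.
inline size_t cudnnGetVersion() { return ::cudnnGetVersion(); }
}  // namespace new_wrap

int main() {
  std::cout << "loaded cuDNN version: " << new_wrap::cudnnGetVersion() << "\n";
}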


@ -45,21 +45,20 @@ cc_library(
exclude = ["cuda/cuda_platform_id.cc"],
),
),
data = [
"//tensorflow/core:cuda",
"@local_config_cuda//cuda:cublas",
"@local_config_cuda//cuda:cudnn",
"@local_config_cuda//cuda:cufft",
"@local_config_cuda//cuda:curand",
],
linkopts = [
"-ldl",
],
visibility = ["//visibility:public"],
deps = [
":stream_executor",
"//tensorflow/core:cuda",
"//tensorflow/core:lib",
"@local_config_cuda//cuda:cublas",
"@local_config_cuda//cuda:cuda_driver",
"@local_config_cuda//cuda:cuda_headers",
"@local_config_cuda//cuda:cudnn",
"@local_config_cuda//cuda:cufft",
"@local_config_cuda//cuda:curand",
],
alwayslink = 1,
)

File diff suppressed because it is too large


@ -107,7 +107,7 @@ string ToString(cudnnStatus_t status) {
}
}
namespace dynload {
namespace wrap {
static port::ThreadPool* InitCudnnThreadpool() {
port::ThreadPool* cudnn_threadpool_;
@ -130,54 +130,15 @@ static port::ThreadPool* GetCudaThreadpool() {
return cudnn_threadpool;
}
// Retrieves the CUDNN DSO, dies on failure.
void* GetDsoHandle() {
static auto result = internal::CachedDsoLoader::GetCudnnDsoHandle();
return result.ValueOrDie();
}
static void* DynLoadGetVersionOrDie() {
void* f;
port::Status s = port::Env::Default()->GetSymbolFromLibrary(
GetDsoHandle(), "cudnnGetVersion", &f);
if (f == nullptr) {
LOG(FATAL) << "could not find cudnnGetVersion in cudnn DSO; dlerror: "
<< s.error_message();
}
return f;
}
// Calls cudnnGetVersion in the loaded DSO.
size_t cudnnGetVersion() {
static void* f = DynLoadGetVersionOrDie();
auto callable = reinterpret_cast<size_t (*)(void)>(f);
return callable();
}
#define PERFTOOLS_GPUTOOLS_CUDNN_WRAP(__name) \
struct DynLoadShim__##__name { \
static const char* kName; \
typedef std::add_pointer<decltype(::__name)>::type FuncPointerT; \
static FuncPointerT LoadOrDie() { \
void* f; \
port::Status s = port::Env::Default()->GetSymbolFromLibrary( \
GetDsoHandle(), kName, &f); \
CHECK(s.ok()) << "could not find " << kName \
<< " in cudnn DSO; dlerror: " << s.error_message(); \
return reinterpret_cast<FuncPointerT>(f); \
} \
static FuncPointerT DynLoad() { \
static FuncPointerT f = LoadOrDie(); \
return f; \
} \
template <typename... Args> \
cudnnStatus_t operator()(CUDAExecutor* parent, Args... args) { \
cuda::ScopedActivateExecutorContext sac{parent}; \
cudnnStatus_t retval = DynLoad()(args...); \
return retval; \
} \
} __name; \
const char* DynLoadShim__##__name::kName = #__name;
#define PERFTOOLS_GPUTOOLS_CUDNN_WRAP(__name) \
struct WrapperShim__##__name { \
template <typename... Args> \
cudnnStatus_t operator()(CUDAExecutor* parent, Args... args) { \
cuda::ScopedActivateExecutorContext sac{parent}; \
cudnnStatus_t retval = ::__name(args...); \
return retval; \
} \
} __name;
// clang-format off
#define CUDNN_DNN_ROUTINE_EACH(__macro) \
@ -278,7 +239,7 @@ CUDNN_DNN_ROUTINE_EACH_R5(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
#undef CUDNN_DNN_ROUTINE_EACH
} // namespace dynload
} // namespace wrap
namespace {
@ -347,19 +308,19 @@ CudnnSupport::CudnnSupport(CUDAExecutor* parent)
: parent_(parent), dnn_handle_(nullptr) {}
CudnnSupport::~CudnnSupport() {
auto status = dynload::cudnnDestroy(parent_, ToHandle(dnn_handle_));
auto status = wrap::cudnnDestroy(parent_, ToHandle(dnn_handle_));
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "could not destroy cudnn handle: " << ToString(status);
}
}
port::Status CudnnSupport::Init() {
auto status = dynload::cudnnCreate(
auto status = wrap::cudnnCreate(
parent_, reinterpret_cast<cudnnHandle_t*>(&dnn_handle_));
if (status == CUDNN_STATUS_SUCCESS) {
// Check whether loaded version of CuDNN matches what the source
// was built with.
size_t loaded_version = dynload::cudnnGetVersion();
size_t loaded_version = ::cudnnGetVersion();
size_t loaded_compat_version = cudnnCompatibilityVersion(loaded_version);
size_t compiled_compat_version = cudnnCompatibilityVersion(CUDNN_VERSION);
bool library_loaded_matches_source =
@ -416,8 +377,7 @@ class ScopedTensorDescriptor {
const BatchDescriptor& batch_descriptor,
cudnnDataType_t elem_type)
: parent_(parent), handle_(nullptr) {
cudnnStatus_t status =
dynload::cudnnCreateTensorDescriptor(parent_, &handle_);
cudnnStatus_t status = wrap::cudnnCreateTensorDescriptor(parent_, &handle_);
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "could not create cudnn tensor descriptor: "
<< ToString(status);
@ -447,8 +407,8 @@ class ScopedTensorDescriptor {
&CheckedNarrowing<int64, int>);
std::transform(dims64.cbegin(), dims64.cend(), dims.begin(),
&CheckedNarrowing<int64, int>);
status = dynload::cudnnSetTensorNdDescriptor(
parent_, handle_, elem_type, nd, dims.data(), strides.data());
status = wrap::cudnnSetTensorNdDescriptor(parent_, handle_, elem_type, nd,
dims.data(), strides.data());
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "could not set cudnn tensor descriptor: "
@ -457,8 +417,7 @@ class ScopedTensorDescriptor {
}
~ScopedTensorDescriptor() {
cudnnStatus_t status =
dynload::cudnnDestroyTensorDescriptor(parent_, handle_);
cudnnStatus_t status = wrap::cudnnDestroyTensorDescriptor(parent_, handle_);
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "could not destroy cudnn tensor descriptor: "
<< ToString(status);
@ -482,8 +441,7 @@ class ScopedFilterDescriptor {
const BatchDescriptor& batch_descriptor,
cudnnDataType_t elem_type)
: parent_(parent), handle_(nullptr) {
cudnnStatus_t status =
dynload::cudnnCreateFilterDescriptor(parent_, &handle_);
cudnnStatus_t status = wrap::cudnnCreateFilterDescriptor(parent_, &handle_);
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "could not create cudnn filter descriptor: "
<< ToString(status);
@ -512,11 +470,11 @@ class ScopedFilterDescriptor {
const auto& spatial_dims = filter_descriptor.input_filter_dims();
std::copy(spatial_dims.begin(), spatial_dims.end(), dims.begin() + 2);
status = dynload::cudnnSetFilterNdDescriptor(parent_, handle_, elem_type,
status = wrap::cudnnSetFilterNdDescriptor(parent_, handle_, elem_type,
#if CUDNN_VERSION >= 5000
format,
format,
#endif
dims.size(), dims.data());
dims.size(), dims.data());
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "could not set cudnn filter descriptor: "
<< ToString(status);
@ -524,8 +482,7 @@ class ScopedFilterDescriptor {
}
~ScopedFilterDescriptor() {
cudnnStatus_t status =
dynload::cudnnDestroyFilterDescriptor(parent_, handle_);
cudnnStatus_t status = wrap::cudnnDestroyFilterDescriptor(parent_, handle_);
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "could not destroy cudnn filter descriptor: "
<< ToString(status);
@ -553,7 +510,7 @@ class ScopedConvolutionDescriptor {
cudnnDataType_t data_type)
: parent_(parent), handle_(nullptr) {
cudnnStatus_t status =
dynload::cudnnCreateConvolutionDescriptor(parent_, &handle_);
wrap::cudnnCreateConvolutionDescriptor(parent_, &handle_);
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "could not create cudnn convolution descriptor: "
<< ToString(status);
@ -574,7 +531,7 @@ class ScopedConvolutionDescriptor {
&CheckedNarrowing<int64, int>);
std::vector<int> upscale(convolution_descriptor.ndims(), 1);
status = dynload::cudnnSetConvolutionNdDescriptor(
status = wrap::cudnnSetConvolutionNdDescriptor(
parent_, handle_, convolution_descriptor.ndims(), padding.data(),
strides.data(), upscale.data(),
// NOTE(keveman): cuDNN supports convolution and cross correlation.
@ -590,7 +547,7 @@ class ScopedConvolutionDescriptor {
~ScopedConvolutionDescriptor() {
cudnnStatus_t status =
dynload::cudnnDestroyConvolutionDescriptor(parent_, handle_);
wrap::cudnnDestroyConvolutionDescriptor(parent_, handle_);
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "could not destroy cudnn convolution descriptor: "
<< ToString(status);
@ -614,7 +571,7 @@ class ScopedPoolingDescriptor {
const PoolingDescriptor& pooling_descriptor)
: parent_(parent), handle_(nullptr) {
cudnnStatus_t status =
dynload::cudnnCreatePoolingDescriptor(parent_, &handle_);
wrap::cudnnCreatePoolingDescriptor(parent_, &handle_);
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "could not create cudnn pooling descriptor: "
<< ToString(status);
@ -634,7 +591,7 @@ class ScopedPoolingDescriptor {
&CheckedNarrowing<int64, int>);
std::transform(shape64.cbegin(), shape64.cend(), shape.begin(),
&CheckedNarrowing<int64, int>);
status = dynload::cudnnSetPoolingNdDescriptor(
status = wrap::cudnnSetPoolingNdDescriptor(
parent_, handle_,
(pooling_descriptor.mode() == dnn::PoolingMode::kMaximum
? CUDNN_POOLING_MAX
@ -651,7 +608,7 @@ class ScopedPoolingDescriptor {
}
~ScopedPoolingDescriptor() {
cudnnStatus_t status =
dynload::cudnnDestroyPoolingDescriptor(parent_, handle_);
wrap::cudnnDestroyPoolingDescriptor(parent_, handle_);
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "could not destroy cudnn pooling descriptor: "
<< ToString(status);
@ -673,7 +630,7 @@ class ScopedNormalizeDescriptor {
ScopedNormalizeDescriptor(CUDAExecutor* parent,
const NormalizeDescriptor& normalize_descriptor)
: parent_(parent), handle_(nullptr) {
cudnnStatus_t status = dynload::cudnnCreateLRNDescriptor(parent_, &handle_);
cudnnStatus_t status = wrap::cudnnCreateLRNDescriptor(parent_, &handle_);
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "could not create cudnn LRN descriptor: "
<< ToString(status);
@ -699,15 +656,15 @@ class ScopedNormalizeDescriptor {
double lrnBeta = normalize_descriptor.beta();
double lrnK = normalize_descriptor.bias();
status = dynload::cudnnSetLRNDescriptor(parent_, handle_, lrnN, lrnAlpha,
lrnBeta, lrnK);
status = wrap::cudnnSetLRNDescriptor(parent_, handle_, lrnN, lrnAlpha,
lrnBeta, lrnK);
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "could not set cudnn LRN descriptor: " << ToString(status);
}
}
~ScopedNormalizeDescriptor() {
cudnnStatus_t status = dynload::cudnnDestroyLRNDescriptor(parent_, handle_);
cudnnStatus_t status = wrap::cudnnDestroyLRNDescriptor(parent_, handle_);
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "could not destroy cudnn LRN descriptor: "
<< ToString(status);
@ -733,7 +690,7 @@ class ScopedActivationDescriptor {
double value_max)
: parent_(parent), handle_(nullptr) {
cudnnStatus_t status =
dynload::cudnnCreateActivationDescriptor(parent_, &handle_);
wrap::cudnnCreateActivationDescriptor(parent_, &handle_);
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "could not create cudnn activation descriptor: "
<< ToString(status);
@ -766,9 +723,8 @@ class ScopedActivationDescriptor {
// Always propagate nans.
cudnnNanPropagation_t nan_propagation = CUDNN_PROPAGATE_NAN;
status = dynload::cudnnSetActivationDescriptor(
parent_, handle_,
mode, nan_propagation, relu_ceiling);
status = wrap::cudnnSetActivationDescriptor(parent_, handle_, mode,
nan_propagation, relu_ceiling);
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "could not set cudnn activation descriptor: "
<< ToString(status);
@ -777,7 +733,7 @@ class ScopedActivationDescriptor {
~ScopedActivationDescriptor() {
cudnnStatus_t status =
dynload::cudnnDestroyActivationDescriptor(parent_, handle_);
wrap::cudnnDestroyActivationDescriptor(parent_, handle_);
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "could not destroy cudnn activation descriptor: "
<< ToString(status);
@ -892,7 +848,7 @@ class CudnnDropoutDescriptor : public CudnnDescriptorCommon<void> {
ScratchAllocator* state_allocator)
: parent_(parent), handle_(nullptr) {
cudnnStatus_t status;
status = dynload::cudnnCreateDropoutDescriptor(parent_, &handle_);
status = wrap::cudnnCreateDropoutDescriptor(parent_, &handle_);
CUDNN_RETURN_IF_FAIL(status, "Failed to create dropout descriptor");
if (dropout == 0.f) {
@ -902,8 +858,8 @@ class CudnnDropoutDescriptor : public CudnnDescriptorCommon<void> {
DeviceMemory<uint8> state_memory;
if (state_allocator) {
size_t state_sizes_in_bytes = 0;
status = dynload::cudnnDropoutGetStatesSize(parent_, cudnn_handle,
&state_sizes_in_bytes);
status = wrap::cudnnDropoutGetStatesSize(parent_, cudnn_handle,
&state_sizes_in_bytes);
CUDNN_RETURN_IF_FAIL(status, "Failed to query dropout state sizes");
auto allocated =
@ -917,16 +873,16 @@ class CudnnDropoutDescriptor : public CudnnDescriptorCommon<void> {
return;
}
}
status = dynload::cudnnSetDropoutDescriptor(parent_, handle_, cudnn_handle,
dropout, state_memory.opaque(),
state_memory.size(), seed);
status = wrap::cudnnSetDropoutDescriptor(parent_, handle_, cudnn_handle,
dropout, state_memory.opaque(),
state_memory.size(), seed);
CUDNN_RETURN_IF_FAIL(status, "Failed to set dropout descriptor");
}
~CudnnDropoutDescriptor() {
if (handle_) {
cudnnStatus_t status =
dynload::cudnnDestroyDropoutDescriptor(parent_, handle_);
wrap::cudnnDestroyDropoutDescriptor(parent_, handle_);
CUDNN_RETURN_IF_FAIL(status, "Failed to destroy Cudnn dropout handle: ");
}
}
@ -952,8 +908,7 @@ class CudnnRnnParamsDescriptor : public CudnnDescriptorCommon<void> {
CudnnRnnParamsDescriptor(CUDAExecutor* parent, cudnnHandle_t cudnn_handle,
const CudnnRnnDescriptor& rnn_desc);
~CudnnRnnParamsDescriptor() {
cudnnStatus_t status =
dynload::cudnnDestroyFilterDescriptor(parent_, handle_);
cudnnStatus_t status = wrap::cudnnDestroyFilterDescriptor(parent_, handle_);
CUDNN_RETURN_IF_FAIL(status, "Failed to destroy RNN filter desciptor");
}
cudnnFilterDescriptor_t handle() const {
@ -1009,10 +964,9 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
}
// Create the RNN handle
cudnnStatus_t status =
dynload::cudnnCreateRNNDescriptor(parent_, &rnn_desc_);
cudnnStatus_t status = wrap::cudnnCreateRNNDescriptor(parent_, &rnn_desc_);
CUDNN_RETURN_IF_FAIL(status, "Unable to create RNN descriptor");
status = dynload::cudnnSetRNNDescriptor(
status = wrap::cudnnSetRNNDescriptor(
parent, rnn_desc_ /*rnnDesc*/, hidden_size /*hiddenSize*/,
num_layers /*numLayers*/, dropout_handle() /*dropoutDesc*/,
input_mode /*inputMode*/, direction_mode /*direction*/,
@ -1030,7 +984,7 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
~CudnnRnnDescriptor() override {
if (rnn_desc_) {
cudnnStatus_t status =
dynload::cudnnDestroyRNNDescriptor(parent_, rnn_desc_);
wrap::cudnnDestroyRNNDescriptor(parent_, rnn_desc_);
CUDNN_RETURN_IF_FAIL(status, "Unable to destroy RNN descriptor");
}
}
@ -1091,18 +1045,18 @@ CudnnRnnParamsDescriptor::CudnnRnnParamsDescriptor(
cudnnTensorDescriptor_t input_desc = nullptr;
{
// Query the params size.
auto status = dynload::cudnnCreateTensorDescriptor(parent, &input_desc);
auto status = wrap::cudnnCreateTensorDescriptor(parent, &input_desc);
CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to create tensor descriptor");
int dims[] = {1, rnn_desc.input_size(), 1};
int strides[] = {dims[1] * dims[2], dims[2], 1};
status = dynload::cudnnSetTensorNdDescriptor(
status = wrap::cudnnSetTensorNdDescriptor(
parent, input_desc /*tensorDesc*/, rnn_desc.data_type() /*dataType*/,
sizeof(dims) / sizeof(dims[0]) /*nbDims*/, dims /*dimA*/,
strides /*strideA*/);
CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to set tensor descriptor");
size_t params_size = 0;
status = dynload::cudnnGetRNNParamsSize(
status = wrap::cudnnGetRNNParamsSize(
parent, cudnn_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/,
input_desc /*xDesc*/, &params_size /*sizeInBytes*/,
rnn_desc.data_type() /*dataType*/);
@ -1112,10 +1066,10 @@ CudnnRnnParamsDescriptor::CudnnRnnParamsDescriptor(
{
// Create the params descriptor.
auto status = dynload::cudnnCreateFilterDescriptor(parent, &handle_);
auto status = wrap::cudnnCreateFilterDescriptor(parent, &handle_);
CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to create RNN filter descriptor");
int dims[] = {static_cast<int>(params_size_in_bytes_), 1, 1};
status = dynload::cudnnSetFilterNdDescriptor(
status = wrap::cudnnSetFilterNdDescriptor(
parent, handle_ /*filterDesc*/, rnn_desc.data_type() /*dataType*/,
CUDNN_TENSOR_NCHW /*format*/, sizeof(dims) / sizeof(dims[0]) /*nbDims*/,
dims /*filterDimA*/);
@ -1127,14 +1081,14 @@ CudnnRnnParamsDescriptor::CudnnRnnParamsDescriptor(
int region_count_per_layer = GetRegionCountPerLayer();
cudnnFilterDescriptor_t region_desc_handle = nullptr;
auto status =
dynload::cudnnCreateFilterDescriptor(parent, &region_desc_handle);
wrap::cudnnCreateFilterDescriptor(parent, &region_desc_handle);
CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to create filter descriptor");
for (int layer = 0; layer < rnn_desc.num_layers(); layer++) {
for (int region = 0; region < region_count_per_layer; region++) {
for (int type = 0; type < 2; type++) {
void* offset = nullptr;
if (type == 0) {
status = dynload::cudnnGetRNNLinLayerMatrixParams(
status = wrap::cudnnGetRNNLinLayerMatrixParams(
parent, cudnn_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/,
layer /*layer*/, input_desc /*xDesc*/, handle_ /*wDesc*/,
nullptr /*w*/, region /*linLayerID*/,
@ -1143,7 +1097,7 @@ CudnnRnnParamsDescriptor::CudnnRnnParamsDescriptor(
CUDNN_RETURN_IF_FAIL(
status, "Cudnn fails to call cudnnGetRNNLinLayerMatrixParams");
} else {
status = dynload::cudnnGetRNNLinLayerBiasParams(
status = wrap::cudnnGetRNNLinLayerBiasParams(
parent, cudnn_handle /*rnnDesc*/, rnn_desc.handle() /*rnnDesc*/,
layer /*layer*/, input_desc /*xDesc*/, handle_ /*wDesc*/,
nullptr /*w*/, region /*linLayerID*/,
@ -1156,7 +1110,7 @@ CudnnRnnParamsDescriptor::CudnnRnnParamsDescriptor(
cudnnDataType_t data_type;
cudnnTensorFormat_t tensor_format;
int n_dims;
status = dynload::cudnnGetFilterNdDescriptor(
status = wrap::cudnnGetFilterNdDescriptor(
parent, region_desc_handle /*filterDesc*/,
sizeof(dims) / sizeof(dims[0]) /*nbDimsRequested*/,
&data_type /*dataType*/, &tensor_format /*format*/,
@ -1173,13 +1127,13 @@ CudnnRnnParamsDescriptor::CudnnRnnParamsDescriptor(
}
}
}
status = dynload::cudnnDestroyFilterDescriptor(parent, region_desc_handle);
status = wrap::cudnnDestroyFilterDescriptor(parent, region_desc_handle);
CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to destroy filter descriptor");
}
{
// Release the dummy input tensor descriptor.
auto status = dynload::cudnnDestroyTensorDescriptor(parent, input_desc);
auto status = wrap::cudnnDestroyTensorDescriptor(parent, input_desc);
CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to destroy tensor descriptor");
}
}
@ -1218,12 +1172,11 @@ class CudnnRnnSequenceTensorDescriptor
SetFailure(port::Status(port::error::UNKNOWN, error_msg));
return;
}
cudnnStatus_t status =
dynload::cudnnCreateTensorDescriptor(parent, &handle);
cudnnStatus_t status = wrap::cudnnCreateTensorDescriptor(parent, &handle);
CUDNN_RETURN_IF_FAIL(status, "Failed to create tensor descriptor");
int dims[] = {batch_size, data_size, 1};
int strides[] = {dims[1] * dims[2], dims[2], 1};
status = dynload::cudnnSetTensorNdDescriptor(
status = wrap::cudnnSetTensorNdDescriptor(
parent, handle /*tensorDesc*/, data_type /*dataType*/,
sizeof(dims) / sizeof(dims[0]) /*nbDims*/, dims /*dimA*/,
strides /*strideA*/);
@ -1235,7 +1188,7 @@ class CudnnRnnSequenceTensorDescriptor
~CudnnRnnSequenceTensorDescriptor() override {
// Only the first one needs to be destroyed. All others are the same.
cudnnStatus_t status =
dynload::cudnnDestroyTensorDescriptor(parent_, handles_[0]);
wrap::cudnnDestroyTensorDescriptor(parent_, handles_[0]);
CUDNN_RETURN_IF_FAIL(status, "Failed to destroy sequence tensor desciptor");
}
@ -1272,12 +1225,11 @@ class CudnnRnnStateTensorDescriptor
batch_size_(batch_size),
data_size_(data_size),
data_type_(data_type) {
cudnnStatus_t status =
dynload::cudnnCreateTensorDescriptor(parent, &handle_);
cudnnStatus_t status = wrap::cudnnCreateTensorDescriptor(parent, &handle_);
CUDNN_RETURN_IF_FAIL(status, "Failed to create tensor descriptor");
int dims[] = {num_layers, batch_size, data_size};
int strides[] = {dims[1] * dims[2], dims[2], 1};
status = dynload::cudnnSetTensorNdDescriptor(
status = wrap::cudnnSetTensorNdDescriptor(
parent, handle_ /*tensorDesc*/, data_type /*dataType*/,
sizeof(dims) / sizeof(dims[0]) /*nbDims*/, dims /*dimA*/,
strides /*strideA*/);
@ -1287,7 +1239,7 @@ class CudnnRnnStateTensorDescriptor
~CudnnRnnStateTensorDescriptor() override {
if (!handle_) {
cudnnStatus_t status =
dynload::cudnnDestroyTensorDescriptor(parent_, handle_);
wrap::cudnnDestroyTensorDescriptor(parent_, handle_);
CUDNN_RETURN_IF_FAIL(status, "Unable to destroy RNN state tensor");
}
}
@ -1387,7 +1339,7 @@ bool CheckRNNParameterSize(CUDAExecutor* parent, cudnnHandle_t cudnn_handle,
const CudnnRnnDescriptor& rnn_desc,
const CudnnRnnSequenceTensorDescriptor& input_desc) {
size_t params_size_in_bytes = 0;
cudnnStatus_t status = dynload::cudnnGetRNNParamsSize(
cudnnStatus_t status = wrap::cudnnGetRNNParamsSize(
parent, cudnn_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/,
input_desc.handles()[0] /*xDesc*/, &params_size_in_bytes /*sizeInBytes*/,
rnn_desc.data_type() /*dataType*/);
@ -1407,7 +1359,7 @@ bool CreateRnnWorkspace(Stream* stream, CUDAExecutor* parent,
DeviceMemory<uint8>* workspace) {
// Query the workspace size.
size_t workspace_size_in_bytes = 0;
cudnnStatus_t status = dynload::cudnnGetRNNWorkspaceSize(
cudnnStatus_t status = wrap::cudnnGetRNNWorkspaceSize(
parent, cudnn_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/,
input_desc.seq_length() /*seqLength*/, input_desc.handles() /*xDesc*/,
&workspace_size_in_bytes /*sizeInBytes*/);
@ -1482,7 +1434,7 @@ bool CudnnSupport::DoRnnForwardImpl(
DeviceMemory<uint8> reserve_space;
if (is_training) {
size_t reserve_space_size_in_bytes = 0;
cudnnStatus_t status = dynload::cudnnGetRNNTrainingReserveSize(
cudnnStatus_t status = wrap::cudnnGetRNNTrainingReserveSize(
parent_, ToHandle(dnn_handle_) /*handle*/,
rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/,
input_desc.handles() /*xDesc*/,
@ -1505,7 +1457,7 @@ bool CudnnSupport::DoRnnForwardImpl(
// make the forward call
if (!is_training) {
cudnnStatus_t status = dynload::cudnnRNNForwardInference(
cudnnStatus_t status = wrap::cudnnRNNForwardInference(
parent_, ToHandle(dnn_handle_) /*handle*/,
rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/,
input_desc.handles() /*xDesc*/, input_data.opaque() /*x*/,
@ -1523,7 +1475,7 @@ bool CudnnSupport::DoRnnForwardImpl(
return false;
}
} else {
cudnnStatus_t status = dynload::cudnnRNNForwardTraining(
cudnnStatus_t status = wrap::cudnnRNNForwardTraining(
parent_, ToHandle(dnn_handle_) /*handle*/,
rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/,
input_desc.handles() /*xDesc*/, input_data.opaque() /*x*/,
@ -1600,7 +1552,7 @@ bool CudnnSupport::DoRnnBackwardImpl(
}
// make the backward data call
cudnnStatus_t status = dynload::cudnnRNNBackwardData(
cudnnStatus_t status = wrap::cudnnRNNBackwardData(
parent_, ToHandle(dnn_handle_) /*handle*/, rnn_desc.handle() /*rnnDesc*/,
model_dims.seq_length /*seqLength*/, output_desc.handles() /*yDesc*/,
output_data.opaque() /*y*/, output_desc.handles() /*dyDesc*/,
@ -1628,7 +1580,7 @@ bool CudnnSupport::DoRnnBackwardImpl(
// Clear the dw to zeros.
stream->ThenMemZero(params_backprop_data, params_backprop_data->size());
// make the backward weight call
status = dynload::cudnnRNNBackwardWeights(
status = wrap::cudnnRNNBackwardWeights(
parent_, ToHandle(dnn_handle_) /*handle*/,
rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/,
input_desc.handles() /*xDesc*/, input_data.opaque() /*x*/,
@ -1845,8 +1797,8 @@ bool CudnnSupport::DoConvolveImpl(
CUDNN_DATA_FLOAT};
mutex_lock lock{dnn_handle_mutex_};
auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status);
}
@ -1876,7 +1828,7 @@ bool CudnnSupport::DoConvolveImpl(
}
cudnnConvolutionFwdAlgo_t algo_to_use;
status = dynload::cudnnGetConvolutionForwardAlgorithm(
status = wrap::cudnnGetConvolutionForwardAlgorithm(
parent_, ToHandle(dnn_handle_), input_nd.handle(),
filter.handle(), conv.handle(), output_nd.handle(),
/*preference=*/preference,
@ -1893,7 +1845,7 @@ bool CudnnSupport::DoConvolveImpl(
if (scratch_allocator != nullptr) {
size_t size_in_bytes;
status = dynload::cudnnGetConvolutionForwardWorkspaceSize(
status = wrap::cudnnGetConvolutionForwardWorkspaceSize(
parent_, ToHandle(dnn_handle_), /*srcDesc=*/input_nd.handle(),
/*filterDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
/*destDesc=*/output_nd.handle(), /*algo=*/algo,
@ -1917,7 +1869,7 @@ bool CudnnSupport::DoConvolveImpl(
algo = ToConvForwardAlgo(algorithm_config.algorithm());
size_t size_in_bytes;
status = dynload::cudnnGetConvolutionForwardWorkspaceSize(
status = wrap::cudnnGetConvolutionForwardWorkspaceSize(
parent_, ToHandle(dnn_handle_), /*srcDesc=*/input_nd.handle(),
/*filterDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
/*destDesc=*/output_nd.handle(), /*algo=*/algo,
@ -1967,7 +1919,7 @@ bool CudnnSupport::DoConvolveImpl(
return false;
}
}
status = dynload::cudnnConvolutionForward(
status = wrap::cudnnConvolutionForward(
parent_, ToHandle(dnn_handle_),
/*alpha=*/&alpha, /*srcDesc=*/input_nd.handle(),
/*srcData=*/input_data.opaque(), /*filterDesc=*/filter.handle(),
@ -2078,8 +2030,8 @@ bool CudnnSupport::DoBatchNormalizationForwardImpl(
bool is_training, std::function<const DeviceMemory<T>&()> var_to_inv_var,
std::function<void()> inv_var_to_var) {
mutex_lock lock{dnn_handle_mutex_};
auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
return false;
@ -2096,7 +2048,7 @@ bool CudnnSupport::DoBatchNormalizationForwardImpl(
if (is_training) {
stream->ThenMemZero(batch_mean, batch_mean->size());
stream->ThenMemZero(batch_var, batch_var->size());
status = dynload::cudnnBatchNormalizationForwardTraining(
status = wrap::cudnnBatchNormalizationForwardTraining(
parent_, ToHandle(dnn_handle_), mode, &one, &zero,
x_descriptor.handle(), x.opaque(), x_descriptor.handle(), y->opaque(),
scale_offset_descriptor.handle(), scale.opaque(), offset.opaque(), 1.0,
@ -2113,7 +2065,7 @@ bool CudnnSupport::DoBatchNormalizationForwardImpl(
#else
const void* maybe_inv_var = estimated_variance.opaque();
#endif
status = dynload::cudnnBatchNormalizationForwardInference(
status = wrap::cudnnBatchNormalizationForwardInference(
parent_, ToHandle(dnn_handle_), mode, &one, &zero,
x_descriptor.handle(), x.opaque(), x_descriptor.handle(), y->opaque(),
scale_offset_descriptor.handle(), scale.opaque(), offset.opaque(),
@ -2150,8 +2102,8 @@ bool CudnnSupport::DoBatchNormalizationBackwardImpl(
DeviceMemory<T>* x_backprop, DeviceMemory<T>* scale_backprop,
DeviceMemory<T>* offset_backprop) {
mutex_lock lock{dnn_handle_mutex_};
auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
return false;
@ -2165,7 +2117,7 @@ bool CudnnSupport::DoBatchNormalizationBackwardImpl(
float one = 1.0;
float zero = 0.0;
status = dynload::cudnnBatchNormalizationBackward(
status = wrap::cudnnBatchNormalizationBackward(
parent_, ToHandle(dnn_handle_), mode, &one, &zero, &one, &zero,
x_descriptor.handle(), x.opaque(), x_descriptor.handle(),
y_backprop.opaque(), x_descriptor.handle(), x_backprop->opaque(),
@ -2249,7 +2201,7 @@ DeviceMemory<T> CudnnSupport::MaybeTransformLayout(
float alpha = 1.0f;
float beta = 0.0f;
auto status = dynload::cudnnTransformTensor(
auto status = wrap::cudnnTransformTensor(
parent_, ToHandle(dnn_handle_), &alpha, orig_out_back_nd.handle(),
backward_output_data.opaque(), &beta, transformed_out_back_nd.handle(),
(*transform_scratch)->mutable_device_memory()->opaque());
@ -2275,8 +2227,8 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
mutex_lock lock{dnn_handle_mutex_};
auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status);
}
@ -2328,7 +2280,7 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
}
cudnnConvolutionBwdDataAlgo_t algo_to_use;
cudnnStatus_t status = dynload::cudnnGetConvolutionBackwardDataAlgorithm(
cudnnStatus_t status = wrap::cudnnGetConvolutionBackwardDataAlgorithm(
parent_, ToHandle(dnn_handle_),
/*filterDesc=*/filter.handle(),
/*diffDesc=*/out_back_nd.handle(),
@ -2347,7 +2299,7 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
if (scratch_allocator != nullptr) {
size_t size_in_bytes;
status = dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
status = wrap::cudnnGetConvolutionBackwardDataWorkspaceSize(
parent_, ToHandle(dnn_handle_),
/*filterDesc=*/filter.handle(),
/*diffDesc=*/out_back_nd.handle(),
@ -2373,7 +2325,7 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
// An algorithm has been specified.
algo = ToConvBackwardDataAlgo(algorithm_config.algorithm());
size_t size_in_bytes;
status = dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
status = wrap::cudnnGetConvolutionBackwardDataWorkspaceSize(
parent_, ToHandle(dnn_handle_),
/*filterDesc=*/filter.handle(),
/*diffDesc=*/out_back_nd.handle(),
@ -2423,9 +2375,9 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
}
#if CUDNN_VERSION >= 5000
status = dynload::cudnnConvolutionBackwardData(
status = wrap::cudnnConvolutionBackwardData(
#else
status = dynload::cudnnConvolutionBackwardData_v3(
status = wrap::cudnnConvolutionBackwardData_v3(
#endif
parent_, ToHandle(dnn_handle_),
/*alpha=*/&alpha,
@ -2508,8 +2460,8 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
mutex_lock lock{dnn_handle_mutex_};
auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status);
}
@ -2566,16 +2518,15 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
}
cudnnConvolutionBwdFilterAlgo_t algo_to_use;
cudnnStatus_t status =
dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
parent_, ToHandle(dnn_handle_),
/*srcDesc=*/input_nd.handle(),
/*diffDesc=*/out_back_nd.handle(),
/*convDesc=*/conv.handle(),
/*gradDesc=*/filter.handle(),
/*preference=*/preference,
/*memoryLimitInBytes=*/memory_limit_bytes,
/*algo=*/&algo_to_use);
cudnnStatus_t status = wrap::cudnnGetConvolutionBackwardFilterAlgorithm(
parent_, ToHandle(dnn_handle_),
/*srcDesc=*/input_nd.handle(),
/*diffDesc=*/out_back_nd.handle(),
/*convDesc=*/conv.handle(),
/*gradDesc=*/filter.handle(),
/*preference=*/preference,
/*memoryLimitInBytes=*/memory_limit_bytes,
/*algo=*/&algo_to_use);
CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Unable to find a suitable "
"algorithm for doing backward "
"filter convolution";
@ -2586,7 +2537,7 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
if (scratch_allocator != nullptr) {
size_t size_in_bytes;
status = dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
status = wrap::cudnnGetConvolutionBackwardFilterWorkspaceSize(
parent_, ToHandle(dnn_handle_), /*srcDesc=*/input_nd.handle(),
/*diffDesc=*/out_back_nd.handle(), /*convDesc=*/conv.handle(),
/*gradDesc=*/filter.handle(), /*algo=*/algo,
@ -2610,7 +2561,7 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
algo = ToConvBackwardFilterAlgo(algorithm_config.algorithm());
size_t size_in_bytes;
status = dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
status = wrap::cudnnGetConvolutionBackwardFilterWorkspaceSize(
parent_, ToHandle(dnn_handle_), /*srcDesc=*/input_nd.handle(),
/*diffDesc=*/out_back_nd.handle(), /*convDesc=*/conv.handle(),
/*gradDesc=*/filter.handle(), /*algo=*/algo,
@ -2658,9 +2609,9 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
}
#if CUDNN_VERSION >= 5000
status = dynload::cudnnConvolutionBackwardFilter(
status = wrap::cudnnConvolutionBackwardFilter(
#else
status = dynload::cudnnConvolutionBackwardFilter_v3(
status = wrap::cudnnConvolutionBackwardFilter_v3(
#endif
parent_, ToHandle(dnn_handle_), /*alpha=*/&alpha,
/*srcDesc=*/input_nd.handle(),
@ -2737,8 +2688,8 @@ bool CudnnSupport::DoConvolveBackwardBiasImpl(
const dnn::BatchDescriptor& bias_descriptor,
DeviceMemory<T>* backward_bias_data) {
mutex_lock lock{dnn_handle_mutex_};
auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status);
}
@ -2753,7 +2704,7 @@ bool CudnnSupport::DoConvolveBackwardBiasImpl(
// Beta is the scaling factor for output.
float beta = 0.0;
status = dynload::cudnnConvolutionBackwardBias(
status = wrap::cudnnConvolutionBackwardBias(
parent_, ToHandle(dnn_handle_), &alpha, input_nd.handle(),
input_data.opaque(), &beta, bias_nd.handle(),
backward_bias_data->opaque());
@ -2961,8 +2912,8 @@ bool CudnnSupport::DoBiasAdd(Stream* stream,
}
mutex_lock lock{dnn_handle_mutex_};
auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
return false;
@ -2972,13 +2923,12 @@ bool CudnnSupport::DoBiasAdd(Stream* stream,
const float beta = 1.0f;
#if CUDNN_VERSION >= 5000
status = dynload::cudnnAddTensor(
status = wrap::cudnnAddTensor(
#else
status = dynload::cudnnAddTensor_v3(
status = wrap::cudnnAddTensor_v3(
#endif
parent_, ToHandle(dnn_handle_), &alpha, bias_descriptor.handle(),
biases.opaque(), &beta, input_descriptor.handle(),
output_data->opaque());
biases.opaque(), &beta, input_descriptor.handle(), output_data->opaque());
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "stream " << stream << " could not enqueue bias addition.";
@ -2995,8 +2945,8 @@ bool CudnnSupport::DoActivate(Stream* stream,
DeviceMemory<float>* output_data,
uint64 options) {
mutex_lock lock{dnn_handle_mutex_};
auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
return false;
@ -3039,7 +2989,7 @@ bool CudnnSupport::DoActivate(Stream* stream,
float alpha = 1.0;
// Beta is the output scaling factor.
float beta = 0.0;
status = dynload::cudnnActivationForward(
status = wrap::cudnnActivationForward(
parent_, ToHandle(dnn_handle_),
#if CUDNN_VERSION >= 5000
activation_desc.handle(),
@ -3064,8 +3014,8 @@ bool CudnnSupport::DoPoolForward(
const dnn::BatchDescriptor& output_dimensions,
DeviceMemory<float>* output_data) {
mutex_lock lock{dnn_handle_mutex_};
auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
return false;
@ -3080,7 +3030,7 @@ bool CudnnSupport::DoPoolForward(
ScopedTensorDescriptor dest_desc{parent_, output_dimensions,
CUDNN_DATA_FLOAT};
ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions};
status = dynload::cudnnPoolingForward(
status = wrap::cudnnPoolingForward(
parent_, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
src_desc.handle(), input_data.opaque(), &beta, dest_desc.handle(),
output_data->opaque());
@ -3099,8 +3049,8 @@ bool CudnnSupport::DoPoolForward(
const dnn::BatchDescriptor& output_dimensions,
DeviceMemory<Eigen::half>* output_data) {
mutex_lock lock{dnn_handle_mutex_};
auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
return false;
@ -3114,7 +3064,7 @@ bool CudnnSupport::DoPoolForward(
ScopedTensorDescriptor src_desc{parent_, input_dimensions, CUDNN_DATA_HALF};
ScopedTensorDescriptor dest_desc{parent_, output_dimensions, CUDNN_DATA_HALF};
ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions};
status = dynload::cudnnPoolingForward(
status = wrap::cudnnPoolingForward(
parent_, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
src_desc.handle(), input_data.opaque(), &beta, dest_desc.handle(),
output_data->opaque());
@ -3135,8 +3085,8 @@ bool CudnnSupport::DoPoolBackward(
const DeviceMemory<float>& input_diff_data,
DeviceMemory<float>* output_diff_data) {
mutex_lock lock{dnn_handle_mutex_};
auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
return false;
@ -3151,7 +3101,7 @@ bool CudnnSupport::DoPoolBackward(
ScopedTensorDescriptor dest_desc{parent_, output_dimensions,
CUDNN_DATA_FLOAT};
ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions};
status = dynload::cudnnPoolingBackward(
status = wrap::cudnnPoolingBackward(
parent_, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
dest_desc.handle(), output_data.opaque(), dest_desc.handle(),
input_diff_data.opaque(), src_desc.handle(), input_data.opaque(), &beta,
@ -3173,8 +3123,8 @@ bool CudnnSupport::DoPoolBackward(
const DeviceMemory<Eigen::half>& input_diff_data,
DeviceMemory<Eigen::half>* output_diff_data) {
mutex_lock lock{dnn_handle_mutex_};
auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
return false;
@ -3188,7 +3138,7 @@ bool CudnnSupport::DoPoolBackward(
ScopedTensorDescriptor src_desc{parent_, input_dimensions, CUDNN_DATA_HALF};
ScopedTensorDescriptor dest_desc{parent_, output_dimensions, CUDNN_DATA_HALF};
ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions};
status = dynload::cudnnPoolingBackward(
status = wrap::cudnnPoolingBackward(
parent_, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
dest_desc.handle(), output_data.opaque(), dest_desc.handle(),
input_diff_data.opaque(), src_desc.handle(), input_data.opaque(), &beta,
@ -3224,8 +3174,8 @@ bool CudnnSupport::DoNormalizeWithDimensions(
// Launch the normalization.
mutex_lock lock{dnn_handle_mutex_};
auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
return false;
@ -3239,7 +3189,7 @@ bool CudnnSupport::DoNormalizeWithDimensions(
// Beta is the scaling factor for output.
float beta = 0.0f;
status = dynload::cudnnLRNCrossChannelForward(
status = wrap::cudnnLRNCrossChannelForward(
parent_, ToHandle(dnn_handle_), normalize.handle(),
CUDNN_LRN_CROSS_CHANNEL_DIM1, &alpha, dims.handle(), input_data.opaque(),
&beta, dims.handle(), output_data->opaque());
@ -3267,8 +3217,8 @@ bool CudnnSupport::DoNormalizeBackwardWithDimensions(
}
mutex_lock lock{dnn_handle_mutex_};
auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
return false;
@ -3280,7 +3230,7 @@ bool CudnnSupport::DoNormalizeBackwardWithDimensions(
float alpha = 1.0f;
float beta = 0.0f;
status = dynload::cudnnLRNCrossChannelBackward(
status = wrap::cudnnLRNCrossChannelBackward(
parent_, ToHandle(dnn_handle_), normalize.handle(),
CUDNN_LRN_CROSS_CHANNEL_DIM1, &alpha, dims.handle(),
normalized_data.opaque(), dims.handle(),
@ -3402,7 +3352,7 @@ bool CudnnSupport::DeriveOutputBatchDescriptor(
int dn = batch_descriptor.ndims() + 2;
std::vector<int> dims(dn); // in BDYX
auto status = dynload::cudnnGetConvolutionNdForwardOutputDim(
auto status = wrap::cudnnGetConvolutionNdForwardOutputDim(
parent_, conv.handle(), input_nd.handle(), filter.handle(), dn,
dims.data());
if (status != CUDNN_STATUS_SUCCESS) {
@ -3458,12 +3408,6 @@ void initialize_cudnn() {
<< status.error_message();
}
// Prime the cuDNN DSO. The loader will log more information.
auto statusor = gpu::internal::CachedDsoLoader::GetCudnnDsoHandle();
if (!statusor.ok()) {
LOG(INFO) << "Unable to load cuDNN DSO";
}
gpu::PluginRegistry::Instance()->SetDefaultFactory(gpu::cuda::kCudaPlatformId,
gpu::PluginKind::kDnn,
gpu::cuda::kCuDnnPlugin);


@ -20,9 +20,7 @@ limitations under the License.
#include <stdlib.h>
#include <set>
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
#include "tensorflow/stream_executor/dso_loader.h"
#include "tensorflow/stream_executor/lib/casts.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/error.h"
@ -58,106 +56,6 @@ namespace perftools {
namespace gputools {
namespace cuda {
namespace dynload {
#define PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(__name) \
struct DynLoadShim__##__name { \
static const char *kName; \
using FuncPointerT = std::add_pointer<decltype(::__name)>::type; \
static void *GetDsoHandle() { \
static auto status = internal::CachedDsoLoader::GetLibcudaDsoHandle(); \
return status.ValueOrDie(); \
} \
static FuncPointerT LoadOrDie() { \
void *f; \
port::Status s = port::Env::Default()->GetSymbolFromLibrary( \
GetDsoHandle(), kName, &f); \
CHECK(s.ok()) << "could not find " << kName \
<< " in libcuda DSO; dlerror: " << s.error_message(); \
return reinterpret_cast<FuncPointerT>(f); \
} \
static FuncPointerT DynLoad() { \
static FuncPointerT f = LoadOrDie(); \
return f; \
} \
template <typename... Args> \
CUresult operator()(Args... args) { \
return DynLoad()(args...); \
} \
} __name; \
const char *DynLoadShim__##__name::kName = #__name;
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuCtxCreate_v2);
#if CUDA_VERSION >= 7000
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuDevicePrimaryCtxRetain);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuDevicePrimaryCtxRelease);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuDevicePrimaryCtxSetFlags);
#endif
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuCtxDestroy);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuCtxEnablePeerAccess);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuCtxGetCurrent);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuCtxGetDevice);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuCtxGetSharedMemConfig);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuCtxPopCurrent_v2);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuCtxSetCurrent);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuCtxSetSharedMemConfig);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuCtxSynchronize);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuDeviceComputeCapability);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuDeviceCanAccessPeer);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuDeviceGet);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuDeviceGetAttribute);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuDeviceGetCount);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuDeviceGetName);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuDeviceGetPCIBusId);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuDeviceGetProperties);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuDeviceTotalMem);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuDriverGetVersion);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuEventCreate);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuEventDestroy_v2);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuEventElapsedTime);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuEventQuery);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuEventRecord);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuEventSynchronize);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuFuncGetAttribute);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuFuncSetCacheConfig);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuGetErrorName);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuGetErrorString);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuInit);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuLaunchKernel);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemAlloc_v2);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemcpyDtoD_v2);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemcpyDtoH_v2);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemcpyHtoD_v2);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemcpyDtoDAsync_v2);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemcpyDtoHAsync_v2);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemcpyHtoDAsync_v2);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemGetAddressRange_v2);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemFree_v2);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemFreeHost);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemGetInfo_v2);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemHostAlloc);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemHostRegister_v2);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemHostUnregister);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemsetD32_v2);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemsetD32Async);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemsetD8_v2);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemsetD8Async);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuModuleGetFunction);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuModuleGetGlobal_v2);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuModuleLoadDataEx);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuModuleLoadFatBinary);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuModuleUnload);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuOccupancyMaxActiveBlocksPerMultiprocessor);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuPointerGetAttribute);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuStreamAddCallback);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuStreamCreate);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuStreamDestroy_v2);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuStreamQuery);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuStreamSynchronize);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuStreamWaitEvent);
} // namespace dynload
namespace {
// Manages the singleton map of contexts that we've created, mapping
@ -374,7 +272,7 @@ namespace {
// Call cuCtxtSynchronize and crash if it doesn't succeed.
void SynchronizeOrDie() {
auto res = dynload::cuCtxSynchronize();
auto res = cuCtxSynchronize();
if (res != CUDA_SUCCESS) {
LOG(FATAL) << "Synchronize found "
<< ToString(res) << " :: " << port::CurrentStackTrace();
@ -410,7 +308,7 @@ ScopedActivateContext::ScopedActivateContext(CudaContext* cuda_context) {
to_restore_ = (tls->depth == 1 ? nullptr : tls->context);
// Set the context and update thread local.
CHECK_EQ(CUDA_SUCCESS, dynload::cuCtxSetCurrent(cuda_context->context()));
CHECK_EQ(CUDA_SUCCESS, cuCtxSetCurrent(cuda_context->context()));
tls->id = cuda_context->id();
tls->context = cuda_context;
}
@ -435,7 +333,7 @@ ScopedActivateContext::~ScopedActivateContext() {
}
// Set context and update thread local.
CHECK_EQ(CUDA_SUCCESS, dynload::cuCtxSetCurrent(to_restore_->context()));
CHECK_EQ(CUDA_SUCCESS, cuCtxSetCurrent(to_restore_->context()));
tls->id = to_restore_->id();
tls->context = to_restore_;
}
@ -496,10 +394,8 @@ static port::Status InternalInit() {
CUresult res = CUDA_ERROR_NO_DEVICE;
if (FLAGS_gpuexec_cuda_driver_inject_init_error) {
LOG(ERROR) << "injecting CUDA init error; initialization will fail";
} else if (internal::CachedDsoLoader::GetLibcudaDsoHandle().ok()) {
// We only call cuInit if we can dynload libcuda.
res = dynload::cuInit(0 /* = flags */);
} else {
res = cuInit(0 /* = flags */);
}
if (res == CUDA_SUCCESS) {
@ -532,7 +428,7 @@ static port::Status InternalInit() {
/* static */ port::Status CUDADriver::GetDevice(int device_ordinal,
CUdevice *device) {
CUresult res = dynload::cuDeviceGet(device, device_ordinal);
CUresult res = cuDeviceGet(device, device_ordinal);
if (res == CUDA_SUCCESS) {
return port::Status::OK();
}
@ -546,8 +442,7 @@ static port::Status InternalInit() {
string *device_name) {
static const size_t kCharLimit = 64;
port::InlinedVector<char, 4> chars(kCharLimit);
CUresult res =
dynload::cuDeviceGetName(chars.begin(), kCharLimit - 1, device);
CUresult res = cuDeviceGetName(chars.begin(), kCharLimit - 1, device);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "failed to get device name for " << device << ": "
<< ToString(res);
@ -603,13 +498,13 @@ bool DeviceOptionsToContextFlags(DeviceOptions device_options, int *flags) {
// context creation: see http://b/13248943
#if CUDA_VERSION >= 7000
res = dynload::cuDevicePrimaryCtxSetFlags(device, flags);
res = dynload::cuDevicePrimaryCtxRetain(&new_context, device);
res = cuDevicePrimaryCtxSetFlags(device, flags);
res = cuDevicePrimaryCtxRetain(&new_context, device);
#else
res = dynload::cuCtxCreate_v2(&new_context, flags, device);
res = cuCtxCreate(&new_context, flags, device);
#endif
}
CHECK_EQ(CUDA_SUCCESS, dynload::cuCtxSetCurrent(former_context));
CHECK_EQ(CUDA_SUCCESS, cuCtxSetCurrent(former_context));
if (res == CUDA_SUCCESS) {
*context = CreatedContexts::Add(new_context);
@ -642,14 +537,14 @@ bool DeviceOptionsToContextFlags(DeviceOptions device_options, int *flags) {
}
#if CUDA_VERSION >= 7000
CUcontext former_context = CurrentContext();
CUresult res = dynload::cuCtxSetCurrent(context->context());
CUresult res = cuCtxSetCurrent(context->context());
CUdevice device;
dynload::cuCtxGetDevice(&device);
dynload::cuCtxSetCurrent(former_context);
cuCtxGetDevice(&device);
cuCtxSetCurrent(former_context);
res = dynload::cuDevicePrimaryCtxRelease(device);
res = cuDevicePrimaryCtxRelease(device);
#else
CUresult res = dynload::cuCtxDestroy_v2(context->context());
CUresult res = cuCtxDestroy(context->context());
#endif
if (res != CUDA_SUCCESS) {
@ -662,7 +557,7 @@ bool DeviceOptionsToContextFlags(DeviceOptions device_options, int *flags) {
/* static */ bool CUDADriver::FuncGetAttribute(CUfunction_attribute attribute,
CUfunction func,
int *attribute_value) {
CUresult res = dynload::cuFuncGetAttribute(attribute_value, attribute, func);
CUresult res = cuFuncGetAttribute(attribute_value, attribute, func);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "failed to query kernel attribute. kernel: " << func
<< ", attribute: " << attribute;
@ -673,7 +568,7 @@ bool DeviceOptionsToContextFlags(DeviceOptions device_options, int *flags) {
/* static */ bool CUDADriver::FuncSetCacheConfig(CUfunction function,
CUfunc_cache cache_config) {
CUresult res = dynload::cuFuncSetCacheConfig(function, cache_config);
CUresult res = cuFuncSetCacheConfig(function, cache_config);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "failed to set CUDA kernel cache config. kernel: " << function
<< ", config: " << cache_config << ", result: " << ToString(res);
@ -687,10 +582,10 @@ bool DeviceOptionsToContextFlags(DeviceOptions device_options, int *flags) {
CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
CUsharedconfig shared_mem_config;
ScopedActivateContext activation{context};
CUresult result = dynload::cuCtxGetSharedMemConfig(&shared_mem_config);
CUresult result = cuCtxGetSharedMemConfig(&shared_mem_config);
if (result != CUDA_SUCCESS) {
CUdevice device;
dynload::cuCtxGetDevice(&device);
cuCtxGetDevice(&device);
LOG(ERROR) << "failed to get CUDA device shared memory config. "
<< "Context device ID: " << device
<< ", result: " << ToString(result);
@ -704,10 +599,10 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
/* static */ port::Status CUDADriver::ContextSetSharedMemConfig(
CudaContext* context, CUsharedconfig shared_mem_config) {
ScopedActivateContext activation{context};
CUresult result = dynload::cuCtxSetSharedMemConfig(shared_mem_config);
CUresult result = cuCtxSetSharedMemConfig(shared_mem_config);
if (result != CUDA_SUCCESS) {
CUdevice device;
dynload::cuCtxGetDevice(&device);
cuCtxGetDevice(&device);
LOG(ERROR) << "failed to set CUDA device shared memory config. "
<< "Context device ID: " << device
<< ", config: " << shared_mem_config
@ -730,9 +625,9 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
<< " gdy: " << grid_dim_y << " gdz: " << grid_dim_z
<< " bdx: " << block_dim_x << " bdy: " << block_dim_y
<< " bdz: " << block_dim_z;
CUresult res = dynload::cuLaunchKernel(
function, grid_dim_x, grid_dim_y, grid_dim_z, block_dim_x, block_dim_y,
block_dim_z, shared_mem_bytes, stream, kernel_params, extra);
CUresult res = cuLaunchKernel(function, grid_dim_x, grid_dim_y, grid_dim_z,
block_dim_x, block_dim_y, block_dim_z,
shared_mem_bytes, stream, kernel_params, extra);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "failed to launch CUDA kernel: " << function
<< "; result: " << ToString(res);
@ -746,7 +641,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
const char *cubin_bytes,
CUmodule *module) {
ScopedActivateContext activation{context};
CUresult result = dynload::cuModuleLoadFatBinary(module, cubin_bytes);
CUresult result = cuModuleLoadFatBinary(module, cubin_bytes);
if (result != CUDA_SUCCESS) {
return port::Status{port::error::INTERNAL,
"failed to load in-memory CUBIN: " + ToString(result)};
@ -789,8 +684,8 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
// TODO(leary) Need to see if NVIDIA can expunge the leakiness in their
// module loading: see http://b/13248943
res = dynload::cuModuleLoadDataEx(module, ptx_data, ARRAYSIZE(options),
options, option_values);
res = cuModuleLoadDataEx(module, ptx_data, ARRAYSIZE(options), options,
option_values);
}
// The PTX JIT mutates the values in the option values array to reflect the
@ -829,7 +724,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
CUdeviceptr location,
uint8 value, size_t size) {
ScopedActivateContext activation{context};
CUresult res = dynload::cuMemsetD8_v2(location, value, size);
CUresult res = cuMemsetD8(location, value, size);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "failed to memset memory: " << ToString(res);
return false;
@ -842,7 +737,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
uint32 value,
size_t uint32_count) {
ScopedActivateContext activation{context};
CUresult res = dynload::cuMemsetD32_v2(location, value, uint32_count);
CUresult res = cuMemsetD32(location, value, uint32_count);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "failed to memset memory: " << ToString(res);
return false;
@ -856,8 +751,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
size_t uint32_count,
CUstream stream) {
ScopedActivateContext activation{context};
CUresult res =
dynload::cuMemsetD8Async(location, value, uint32_count, stream);
CUresult res = cuMemsetD8Async(location, value, uint32_count, stream);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "failed to enqueue async memset operation: " << ToString(res);
return false;
@ -872,8 +766,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
size_t uint32_count,
CUstream stream) {
ScopedActivateContext activation{context};
CUresult res =
dynload::cuMemsetD32Async(location, value, uint32_count, stream);
CUresult res = cuMemsetD32Async(location, value, uint32_count, stream);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "failed to enqueue async memset operation: " << ToString(res);
return false;
@ -887,8 +780,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
StreamCallback callback,
void *data) {
// Note: flags param is required to be zero according to CUDA 6.0.
CUresult res =
dynload::cuStreamAddCallback(stream, callback, data, 0 /* = flags */);
CUresult res = cuStreamAddCallback(stream, callback, data, 0 /* = flags */);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "unable to add host callback: " << ToString(res);
return false;
@ -902,7 +794,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
CUfunction *function) {
ScopedActivateContext activated{context};
CHECK(module != nullptr && kernel_name != nullptr);
CUresult res = dynload::cuModuleGetFunction(function, module, kernel_name);
CUresult res = cuModuleGetFunction(function, module, kernel_name);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "failed to get PTX kernel \"" << kernel_name
<< "\" from module: " << ToString(res);
@ -920,8 +812,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
ScopedActivateContext activated{context};
CHECK(module != nullptr && symbol_name != nullptr &&
(dptr != nullptr || bytes != nullptr));
CUresult res =
dynload::cuModuleGetGlobal_v2(dptr, bytes, module, symbol_name);
CUresult res = cuModuleGetGlobal(dptr, bytes, module, symbol_name);
if (res != CUDA_SUCCESS) {
// symbol may not be found in the current module, but it may reside in
// another module.
@ -936,7 +827,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
/* static */ void CUDADriver::UnloadModule(CudaContext *context,
CUmodule module) {
ScopedActivateContext activated{context};
CUresult res = dynload::cuModuleUnload(module);
CUresult res = cuModuleUnload(module);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "failed to unload module " << module
<< "; leaking: " << ToString(res);
@ -947,7 +838,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
CudaContext* context) {
ScopedActivateContext activated{context};
CUdevice device = -1;
CUresult result = dynload::cuCtxGetDevice(&device);
CUresult result = cuCtxGetDevice(&device);
if (result == CUDA_SUCCESS) {
return device;
}
@ -963,7 +854,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
// up synchronization with respect to memsets and any other things that have
// to occur on the default stream?
ScopedActivateContext activated{context};
CUresult res = dynload::cuStreamCreate(out, 0);
CUresult res = cuStreamCreate(out, 0);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "could not allocate CUDA stream for context " << context
<< ": " << ToString(res);
@ -982,7 +873,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
}
ScopedActivateContext activated{context};
CUresult res = dynload::cuStreamDestroy_v2(*stream);
CUresult res = cuStreamDestroy(*stream);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "failed to destroy CUDA stream for context " << context
<< ": " << ToString(res);
@ -997,7 +888,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
uint64 bytes) {
ScopedActivateContext activated{context};
CUdeviceptr result = 0;
CUresult res = dynload::cuMemAlloc_v2(&result, bytes);
CUresult res = cuMemAlloc(&result, bytes);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "failed to allocate "
<< port::HumanReadableNumBytes::ToString(bytes) << " (" << bytes
@ -1014,7 +905,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
void *location) {
ScopedActivateContext activation{context};
CUdeviceptr pointer = port::bit_cast<CUdeviceptr>(location);
CUresult res = dynload::cuMemFree_v2(pointer);
CUresult res = cuMemFree(pointer);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "failed to free device memory at " << location
<< "; result: " << ToString(res);
@ -1028,8 +919,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
ScopedActivateContext activation{context};
void *host_mem = nullptr;
// "Portable" memory is visible to all CUDA contexts. Safe for our use model.
CUresult res =
dynload::cuMemHostAlloc(&host_mem, bytes, CU_MEMHOSTALLOC_PORTABLE);
CUresult res = cuMemHostAlloc(&host_mem, bytes, CU_MEMHOSTALLOC_PORTABLE);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "failed to alloc " << bytes
<< " bytes on host: " << ToString(res);
@ -1040,7 +930,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
/* static */ void CUDADriver::HostDeallocate(CudaContext* context,
void *location) {
ScopedActivateContext activation{context};
CUresult res = dynload::cuMemFreeHost(location);
CUresult res = cuMemFreeHost(location);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "error deallocating host memory at " << location << ": "
<< ToString(res);
@ -1052,7 +942,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
ScopedActivateContext activation{context};
// "Portable" memory is visible to all CUDA contexts. Safe for our use model.
CUresult res =
dynload::cuMemHostRegister(location, bytes, CU_MEMHOSTREGISTER_PORTABLE);
cuMemHostRegister(location, bytes, CU_MEMHOSTREGISTER_PORTABLE);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "error registering host memory at " << location << ": "
<< ToString(res);
@ -1064,7 +954,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
/* static */ bool CUDADriver::HostUnregister(CudaContext* context,
void *location) {
ScopedActivateContext activation{context};
CUresult res = dynload::cuMemHostUnregister(location);
CUresult res = cuMemHostUnregister(location);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "error unregistering host memory at " << location << ": "
<< ToString(res);
@ -1081,7 +971,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
}
ScopedActivateContext activated{context};
CUresult res = dynload::cuEventDestroy_v2(*event);
CUresult res = cuEventDestroy(*event);
*event = nullptr;
switch (res) {
@ -1105,7 +995,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
CUevent event,
CUstream stream) {
ScopedActivateContext activated{context};
CUresult res = dynload::cuEventRecord(event, stream);
CUresult res = cuEventRecord(event, stream);
switch (res) {
case CUDA_SUCCESS:
return port::Status::OK();
@ -1126,7 +1016,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
/* static */ port::StatusOr<CUresult> CUDADriver::QueryEvent(
CudaContext *context, CUevent event) {
ScopedActivateContext activated{context};
CUresult res = dynload::cuEventQuery(event);
CUresult res = cuEventQuery(event);
if (res != CUDA_SUCCESS && res != CUDA_ERROR_NOT_READY) {
return port::Status{
port::error::INTERNAL,
@ -1142,12 +1032,12 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
ScopedActivateContext activated{context};
// The stop event must have completed in order for cuEventElapsedTime to
// work.
CUresult res = dynload::cuEventSynchronize(stop);
CUresult res = cuEventSynchronize(stop);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "failed to synchronize the stop event: " << ToString(res);
return false;
}
res = dynload::cuEventElapsedTime(elapsed_milliseconds, start, stop);
res = cuEventElapsedTime(elapsed_milliseconds, start, stop);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "failed to get elapsed time between events: "
<< ToString(res);
@ -1161,7 +1051,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
CUstream stream,
CUevent event) {
ScopedActivateContext activation{context};
CUresult res = dynload::cuStreamWaitEvent(stream, event, 0 /* = flags */);
CUresult res = cuStreamWaitEvent(stream, event, 0 /* = flags */);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "could not wait stream on event: " << ToString(res);
return false;
@ -1172,7 +1062,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
/* static */ bool CUDADriver::SynchronizeContext(CudaContext* context) {
ScopedActivateContext activation{context};
CUresult res = dynload::cuCtxSynchronize();
CUresult res = cuCtxSynchronize();
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "could not synchronize on CUDA context: " << ToString(res)
<< " :: " << port::CurrentStackTrace();
@ -1186,7 +1076,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
CUstream stream) {
ScopedActivateContext activated{context};
CHECK(stream != nullptr);
CUresult res = dynload::cuStreamSynchronize(stream);
CUresult res = cuStreamSynchronize(stream);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "could not synchronize on CUDA stream: " << ToString(res)
<< " :: " << port::CurrentStackTrace();
@ -1201,7 +1091,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
CUstream stream) {
ScopedActivateContext activated{context};
CHECK(stream != nullptr);
CUresult res = dynload::cuStreamQuery(stream);
CUresult res = cuStreamQuery(stream);
if (res == CUDA_SUCCESS) {
return true;
}
@ -1217,7 +1107,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
CUdeviceptr gpu_src,
uint64 size) {
ScopedActivateContext activation{context};
CUresult res = dynload::cuMemcpyDtoH_v2(host_dst, gpu_src, size);
CUresult res = cuMemcpyDtoH(host_dst, gpu_src, size);
if (res != CUDA_SUCCESS) {
return port::InternalError(
port::Printf("failed to synchronous memcpy from device to host: %s; "
@ -1235,7 +1125,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
const void *host_src,
uint64 size) {
ScopedActivateContext activation{context};
CUresult res = dynload::cuMemcpyHtoD_v2(gpu_dst, host_src, size);
CUresult res = cuMemcpyHtoD(gpu_dst, host_src, size);
if (res != CUDA_SUCCESS) {
return port::InternalError(port::Printf(
"failed to synchronous memcpy from host to device: %s; GPU dst: %p;"
@ -1252,7 +1142,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
CUdeviceptr gpu_src,
uint64 size) {
ScopedActivateContext activation{context};
CUresult res = dynload::cuMemcpyDtoD_v2(gpu_dst, gpu_src, size);
CUresult res = cuMemcpyDtoD(gpu_dst, gpu_src, size);
if (res != CUDA_SUCCESS) {
return port::InternalError(port::Printf(
"failed to synchronous memcpy from host to device: %s; GPU dst: %p; "
@ -1270,7 +1160,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
uint64 size,
CUstream stream) {
ScopedActivateContext activation{context};
CUresult res = dynload::cuMemcpyDtoHAsync_v2(host_dst, gpu_src, size, stream);
CUresult res = cuMemcpyDtoHAsync(host_dst, gpu_src, size, stream);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << port::Printf(
"failed to enqueue async memcpy from device to host: %s; host dst: %p; "
@ -1290,7 +1180,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
uint64 size,
CUstream stream) {
ScopedActivateContext activation{context};
CUresult res = dynload::cuMemcpyHtoDAsync_v2(gpu_dst, host_src, size, stream);
CUresult res = cuMemcpyHtoDAsync(gpu_dst, host_src, size, stream);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << port::Printf(
"failed to enqueue async memcpy from host to device: %s; GPU dst: %p; "
@ -1309,8 +1199,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
uint64 size,
CUstream stream) {
ScopedActivateContext activation{context};
CUresult result =
dynload::cuMemcpyDtoDAsync_v2(gpu_dst, gpu_src, size, stream);
CUresult result = cuMemcpyDtoDAsync(gpu_dst, gpu_src, size, stream);
if (result != CUDA_SUCCESS) {
LOG(ERROR) << port::Printf(
"failed to enqueue async memcpy from device to device: %s"
@ -1346,7 +1235,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
}
ScopedActivateContext activated{context};
CUresult res = dynload::cuEventCreate(result, cuflags);
CUresult res = cuEventCreate(result, cuflags);
if (res == CUDA_SUCCESS) {
return port::Status::OK();
@ -1362,7 +1251,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
/* static */ int CUDADriver::GetDeviceCount() {
int device_count = 0;
CUresult res = dynload::cuDeviceGetCount(&device_count);
CUresult res = cuDeviceGetCount(&device_count);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "could not retrieve CUDA device count: " << ToString(res);
return 0;
@ -1377,8 +1266,8 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
/* static */ port::StatusOr<CudaContext*> CUDADriver::GetPointerContext(
CUdeviceptr pointer) {
CudaContext* context = nullptr;
CUresult result = dynload::cuPointerGetAttribute(
&context, CU_POINTER_ATTRIBUTE_CONTEXT, pointer);
CUresult result =
cuPointerGetAttribute(&context, CU_POINTER_ATTRIBUTE_CONTEXT, pointer);
if (result == CUDA_SUCCESS) {
CHECK(context != nullptr) << "success should entail non-null context";
return context;
@ -1393,8 +1282,8 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
/* static */ port::StatusOr<MemorySpace> CUDADriver::GetPointerMemorySpace(
CUdeviceptr pointer) {
unsigned int value;
CUresult result = dynload::cuPointerGetAttribute(
&value, CU_POINTER_ATTRIBUTE_MEMORY_TYPE, pointer);
CUresult result =
cuPointerGetAttribute(&value, CU_POINTER_ATTRIBUTE_MEMORY_TYPE, pointer);
if (result == CUDA_SUCCESS) {
switch (value) {
case CU_MEMORYTYPE_DEVICE:
@ -1417,7 +1306,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
/* static */ port::Status CUDADriver::GetPointerAddressRange(CUdeviceptr dptr,
CUdeviceptr *base,
size_t *size) {
CUresult result = dynload::cuMemGetAddressRange(base, size, dptr);
CUresult result = cuMemGetAddressRange(base, size, dptr);
if (result == CUDA_SUCCESS) {
return port::Status::OK();
} else if (result == CUDA_ERROR_NOT_FOUND) {
@ -1451,8 +1340,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
CUdevice device) {
*cc_major = 0;
*cc_minor = 0;
CUresult result =
dynload::cuDeviceComputeCapability(cc_major, cc_minor, device);
CUresult result = cuDeviceComputeCapability(cc_major, cc_minor, device);
if (result == CUDA_SUCCESS) {
return port::Status::OK();
}
@ -1469,7 +1357,7 @@ template <typename T>
static port::StatusOr<T> GetSimpleAttribute(CUdevice device,
CUdevice_attribute attribute) {
int value = -1;
CUresult result = dynload::cuDeviceGetAttribute(&value, attribute, device);
CUresult result = cuDeviceGetAttribute(&value, attribute, device);
if (result != CUDA_SUCCESS) {
return port::Status{
port::error::NOT_FOUND,
@ -1524,24 +1412,24 @@ static port::StatusOr<T> GetSimpleAttribute(CUdevice device,
/* static */ bool CUDADriver::GetGridLimits(int *x, int *y, int *z,
CUdevice device) {
int value;
CUresult res = dynload::cuDeviceGetAttribute(
&value, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X, device);
CUresult res =
cuDeviceGetAttribute(&value, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X, device);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "failed to query max grid dim x: " << ToString(res);
return false;
}
*x = value;
res = dynload::cuDeviceGetAttribute(
&value, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y, device);
res =
cuDeviceGetAttribute(&value, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y, device);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "failed to query max grid dim y: " << ToString(res);
return false;
}
*y = value;
res = dynload::cuDeviceGetAttribute(
&value, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z, device);
res =
cuDeviceGetAttribute(&value, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z, device);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "failed to query max grid dim z: " << ToString(res);
return false;
@ -1551,7 +1439,7 @@ static port::StatusOr<T> GetSimpleAttribute(CUdevice device,
}
/* static */ bool CUDADriver::GetDriverVersion(int *driver_version) {
CUresult res = dynload::cuDriverGetVersion(driver_version);
CUresult res = cuDriverGetVersion(driver_version);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "failed to query driver version: " << ToString(res);
return false;
@ -1562,8 +1450,7 @@ static port::StatusOr<T> GetSimpleAttribute(CUdevice device,
/* static */ bool CUDADriver::GetDeviceProperties(CUdevprop *device_properties,
int device_ordinal) {
CUresult res =
dynload::cuDeviceGetProperties(device_properties, device_ordinal);
CUresult res = cuDeviceGetProperties(device_properties, device_ordinal);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "failed to query device properties: " << ToString(res);
return false;
@ -1574,8 +1461,8 @@ static port::StatusOr<T> GetSimpleAttribute(CUdevice device,
/* static */ bool CUDADriver::IsEccEnabled(CUdevice device, bool *result) {
int value = -1;
CUresult res = dynload::cuDeviceGetAttribute(
&value, CU_DEVICE_ATTRIBUTE_ECC_ENABLED, device);
CUresult res =
cuDeviceGetAttribute(&value, CU_DEVICE_ATTRIBUTE_ECC_ENABLED, device);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "failed to query ECC status: " << ToString(res);
return false;
@ -1591,7 +1478,7 @@ static port::StatusOr<T> GetSimpleAttribute(CUdevice device,
ScopedActivateContext activation{context};
size_t free = 0;
size_t total = 0;
CUresult res = dynload::cuMemGetInfo_v2(&free, &total);
CUresult res = cuMemGetInfo(&free, &total);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "failed to query device memory info: " << ToString(res);
return false;
@ -1605,7 +1492,7 @@ static port::StatusOr<T> GetSimpleAttribute(CUdevice device,
/* static */ bool CUDADriver::GetDeviceTotalMemory(CUdevice device,
uint64 *result) {
size_t value = -1;
CUresult res = dynload::cuDeviceTotalMem_v2(&value, device);
CUresult res = cuDeviceTotalMem(&value, device);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "failed to query total available memory: " << ToString(res);
return false;
@ -1620,8 +1507,7 @@ static port::StatusOr<T> GetSimpleAttribute(CUdevice device,
static const int kBufferSize = 64;
port::InlinedVector<char, 4> chars(kBufferSize);
chars[kBufferSize - 1] = '\0';
CUresult res =
dynload::cuDeviceGetPCIBusId(chars.begin(), kBufferSize - 1, device);
CUresult res = cuDeviceGetPCIBusId(chars.begin(), kBufferSize - 1, device);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "failed to query PCI bus id for device: " << ToString(res);
return pci_bus_id;
@ -1649,7 +1535,7 @@ static port::StatusOr<T> GetSimpleAttribute(CUdevice device,
<< to_device.status();
return false;
}
CUresult res = dynload::cuDeviceCanAccessPeer(
CUresult res = cuDeviceCanAccessPeer(
&can_access_peer, from_device.ValueOrDie(), to_device.ValueOrDie());
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "failed to detect peer access capability: " << ToString(res);
@ -1666,8 +1552,7 @@ static port::StatusOr<T> GetSimpleAttribute(CUdevice device,
}
ScopedActivateContext activated{from};
CUresult result =
dynload::cuCtxEnablePeerAccess(to->context(), 0 /* = flags */);
CUresult result = cuCtxEnablePeerAccess(to->context(), 0 /* = flags */);
if (result != CUDA_SUCCESS &&
result != CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED) {
return port::Status{
@ -1685,7 +1570,7 @@ static port::StatusOr<T> GetSimpleAttribute(CUdevice device,
ScopedActivateContext activation{context};
int max_blocks;
CUresult result = dynload::cuOccupancyMaxActiveBlocksPerMultiprocessor(
CUresult result = cuOccupancyMaxActiveBlocksPerMultiprocessor(
&max_blocks, kernel, threads_per_block, dynamic_shared_memory_bytes);
if (result != CUDA_SUCCESS) {
return port::Status{
@ -1699,7 +1584,7 @@ static port::StatusOr<T> GetSimpleAttribute(CUdevice device,
/* static */ CUcontext CUDADriver::CurrentContextOrDie() {
CUcontext current = nullptr;
CUresult result = dynload::cuCtxGetCurrent(&current);
CUresult result = cuCtxGetCurrent(&current);
if (result != CUDA_SUCCESS) {
LOG(FATAL) << "failed to query current context: " << ToString(result);
}
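A side note on the CUDADriver hunks above, since the renames are easy to misread as behavioral changes: the removed dynload:: shims looked up the versioned driver entry points (cuMemsetD8_v2, cuStreamDestroy_v2, cuMemcpyDtoH_v2, cuMemGetInfo_v2, and so on) by name at runtime, while the new call sites spell the unversioned names. Assuming the build sees a reasonably recent cuda.h (API version 3020/4000 or later), the header remaps the unversioned names onto the same _v2 symbols, so the compiled calls land on identical driver entry points. A rough sketch of that header machinery, not copied from any particular toolkit release:

// Illustrative only: the exact guard macros and the full list vary by CUDA
// toolkit version; see cuda.h in the installed toolkit.
#if defined(__CUDA_API_VERSION_INTERNAL) || __CUDA_API_VERSION >= 3020
#define cuMemAlloc        cuMemAlloc_v2
#define cuMemFree         cuMemFree_v2
#define cuMemcpyDtoH      cuMemcpyDtoH_v2
#define cuMemsetD8        cuMemsetD8_v2
#define cuMemGetInfo      cuMemGetInfo_v2
#define cuModuleGetGlobal cuModuleGetGlobal_v2
#define cuDeviceTotalMem  cuDeviceTotalMem_v2
#endif
#if defined(__CUDA_API_VERSION_INTERNAL) || __CUDA_API_VERSION >= 4000
#define cuStreamDestroy   cuStreamDestroy_v2
#define cuEventDestroy    cuEventDestroy_v2
#endif
// So a call written as cuMemsetD8(location, value, size) still resolves to
// the cuMemsetD8_v2 entry point that the removed shim used to dlsym().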


@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
#include "tensorflow/stream_executor/cuda/cuda_stream.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/dso_loader.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/initialize.h"
#include "tensorflow/stream_executor/lib/status.h"
@ -38,36 +37,21 @@ namespace cuda {
PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuFftPlugin);
namespace dynload {
namespace wrap {
// This macro wraps a global identifier, given by __name, in a callable
// structure that loads the DLL symbol out of the DSO handle in a thread-safe
// manner on first use. This dynamic loading technique is used to avoid DSO
// dependencies on vendor libraries which may or may not be available in the
// deployed binary environment.
#define PERFTOOLS_GPUTOOLS_CUFFT_WRAP(__name) \
struct DynLoadShim__##__name { \
static const char *kName; \
using FuncPointerT = std::add_pointer<decltype(::__name)>::type; \
static void *GetDsoHandle() { \
static auto status = internal::CachedDsoLoader::GetCufftDsoHandle(); \
return status.ValueOrDie(); \
} \
static FuncPointerT DynLoad() { \
static void *f; \
port::Status s = port::Env::Default()->GetSymbolFromLibrary( \
GetDsoHandle(), kName, &f); \
CHECK(s.ok()) << "could not find " << kName \
<< " in cuFFT DSO; dlerror: " << s.error_message(); \
return reinterpret_cast<FuncPointerT>(f); \
} \
template <typename... Args> \
cufftResult operator()(CUDAExecutor *parent, Args... args) { \
cuda::ScopedActivateExecutorContext sac{parent}; \
return DynLoad()(args...); \
} \
} __name; \
const char *DynLoadShim__##__name::kName = #__name;
#define PERFTOOLS_GPUTOOLS_CUFFT_WRAP(__name) \
struct WrapperShim__##__name { \
template <typename... Args> \
cufftResult operator()(CUDAExecutor *parent, Args... args) { \
cuda::ScopedActivateExecutorContext sac{parent}; \
return ::__name(args...); \
} \
} __name;
#define CUFFT_ROUTINE_EACH(__macro) \
__macro(cufftDestroy) __macro(cufftSetStream) __macro(cufftPlan1d) \
@ -78,7 +62,7 @@ namespace dynload {
CUFFT_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_CUFFT_WRAP)
} // namespace dynload
} // namespace wrap
namespace {
@ -106,7 +90,7 @@ cufftType CUDAFftType(fft::Type type) {
// Associates the given stream with the given cuFFT plan.
bool SetStream(CUDAExecutor *parent, cufftHandle plan, Stream *stream) {
auto ret = dynload::cufftSetStream(parent, plan, AsCUDAStreamValue(stream));
auto ret = wrap::cufftSetStream(parent, plan, AsCUDAStreamValue(stream));
if (ret != CUFFT_SUCCESS) {
LOG(ERROR) << "failed to run cuFFT routine cufftSetStream: " << ret;
return false;
@ -118,8 +102,8 @@ bool SetStream(CUDAExecutor *parent, cufftHandle plan, Stream *stream) {
CUDAFftPlan::CUDAFftPlan(CUDAExecutor *parent, uint64 num_x, fft::Type type)
: parent_(parent), fft_type_(type) {
auto ret = dynload::cufftPlan1d(parent, &plan_, num_x, CUDAFftType(type),
1 /* = batch */);
auto ret = wrap::cufftPlan1d(parent, &plan_, num_x, CUDAFftType(type),
1 /* = batch */);
if (ret != CUFFT_SUCCESS) {
LOG(ERROR) << "failed to create cuFFT 1d plan:" << ret;
}
@ -128,8 +112,7 @@ CUDAFftPlan::CUDAFftPlan(CUDAExecutor *parent, uint64 num_x, fft::Type type)
CUDAFftPlan::CUDAFftPlan(CUDAExecutor *parent, uint64 num_x, uint64 num_y,
fft::Type type)
: parent_(parent), fft_type_(type) {
auto ret =
dynload::cufftPlan2d(parent, &plan_, num_x, num_y, CUDAFftType(type));
auto ret = wrap::cufftPlan2d(parent, &plan_, num_x, num_y, CUDAFftType(type));
if (ret != CUFFT_SUCCESS) {
LOG(ERROR) << "failed to create cuFFT 2d plan:" << ret;
}
@ -138,8 +121,8 @@ CUDAFftPlan::CUDAFftPlan(CUDAExecutor *parent, uint64 num_x, uint64 num_y,
CUDAFftPlan::CUDAFftPlan(CUDAExecutor *parent, uint64 num_x, uint64 num_y,
uint64 num_z, fft::Type type)
: parent_(parent), fft_type_(type) {
auto ret = dynload::cufftPlan3d(parent, &plan_, num_x, num_y, num_z,
CUDAFftType(type));
auto ret =
wrap::cufftPlan3d(parent, &plan_, num_x, num_y, num_z, CUDAFftType(type));
if (ret != CUFFT_SUCCESS) {
LOG(ERROR) << "failed to create cuFFT 3d plan:" << ret;
}
@ -161,7 +144,7 @@ CUDAFftPlan::CUDAFftPlan(CUDAExecutor *parent, int rank, uint64 *elem_count,
output_embed_[i] = output_embed[i];
}
}
auto ret = dynload::cufftPlanMany(
auto ret = wrap::cufftPlanMany(
parent, &plan_, rank, elem_count_, input_embed ? input_embed_ : nullptr,
input_stride, input_distance, output_embed ? output_embed_ : nullptr,
output_stride, output_distance, CUDAFftType(type), batch_count);
@ -170,7 +153,7 @@ CUDAFftPlan::CUDAFftPlan(CUDAExecutor *parent, int rank, uint64 *elem_count,
}
}
CUDAFftPlan::~CUDAFftPlan() { dynload::cufftDestroy(parent_, plan_); }
CUDAFftPlan::~CUDAFftPlan() { wrap::cufftDestroy(parent_, plan_); }
int CUDAFftPlan::GetFftDirection() const {
switch (fft_type_) {
@ -277,25 +260,25 @@ bool CUDAFft::DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan,
return true;
}
#define PERFTOOLS_GPUTOOLS_CUDA_DEFINE_FFT(__type, __fft_type1, __fft_type2, \
__fft_type3) \
bool CUDAFft::DoFft(Stream *stream, fft::Plan *plan, \
const DeviceMemory<std::complex<__type>> &input, \
DeviceMemory<std::complex<__type>> *output) { \
return DoFftWithDirectionInternal( \
stream, plan, dynload::cufftExec##__fft_type1, input, output); \
} \
bool CUDAFft::DoFft(Stream *stream, fft::Plan *plan, \
const DeviceMemory<__type> &input, \
DeviceMemory<std::complex<__type>> *output) { \
return DoFftInternal(stream, plan, dynload::cufftExec##__fft_type2, input, \
output); \
} \
bool CUDAFft::DoFft(Stream *stream, fft::Plan *plan, \
const DeviceMemory<std::complex<__type>> &input, \
DeviceMemory<__type> *output) { \
return DoFftInternal(stream, plan, dynload::cufftExec##__fft_type3, input, \
output); \
#define PERFTOOLS_GPUTOOLS_CUDA_DEFINE_FFT(__type, __fft_type1, __fft_type2, \
__fft_type3) \
bool CUDAFft::DoFft(Stream *stream, fft::Plan *plan, \
const DeviceMemory<std::complex<__type>> &input, \
DeviceMemory<std::complex<__type>> *output) { \
return DoFftWithDirectionInternal( \
stream, plan, wrap::cufftExec##__fft_type1, input, output); \
} \
bool CUDAFft::DoFft(Stream *stream, fft::Plan *plan, \
const DeviceMemory<__type> &input, \
DeviceMemory<std::complex<__type>> *output) { \
return DoFftInternal(stream, plan, wrap::cufftExec##__fft_type2, input, \
output); \
} \
bool CUDAFft::DoFft(Stream *stream, fft::Plan *plan, \
const DeviceMemory<std::complex<__type>> &input, \
DeviceMemory<__type> *output) { \
return DoFftInternal(stream, plan, wrap::cufftExec##__fft_type3, input, \
output); \
}
PERFTOOLS_GPUTOOLS_CUDA_DEFINE_FFT(float, C2C, R2C, C2R)
@ -332,12 +315,6 @@ REGISTER_MODULE_INITIALIZER(register_cufft, {
<< status.error_message();
}
// Prime the cuFFT DSO. The loader will log more information.
auto statusor = gpu::internal::CachedDsoLoader::GetCufftDsoHandle();
if (!statusor.ok()) {
LOG(INFO) << "Unable to load cuFFT DSO.";
}
gpu::PluginRegistry::Instance()->SetDefaultFactory(gpu::cuda::kCudaPlatformId,
gpu::PluginKind::kFft,
gpu::cuda::kCuFftPlugin);
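For readers expanding the new cuFFT macro by eye: PERFTOOLS_GPUTOOLS_CUFFT_WRAP(cufftPlan1d) from the hunk above generates roughly the shim below (hand-expanded sketch, with the surrounding wrap namespace elided). This is why every wrap:: call site passes the owning CUDAExecutor as an extra leading argument that never reaches cuFFT itself:

// Hand-expanded sketch of PERFTOOLS_GPUTOOLS_CUFFT_WRAP(cufftPlan1d).
struct WrapperShim__cufftPlan1d {
  template <typename... Args>
  cufftResult operator()(CUDAExecutor *parent, Args... args) {
    // Make parent's CUDA context current for the duration of the call.
    cuda::ScopedActivateExecutorContext sac{parent};
    // Forward to the real entry point, now resolved by the linker rather
    // than through a cached dlopen()/dlsym() handle.
    return ::cufftPlan1d(args...);
  }
} cufftPlan1d;

// A call site then reads, e.g.:
//   wrap::cufftPlan1d(parent, &plan_, num_x, CUDAFftType(type), 1 /* batch */);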


@ -1059,19 +1059,6 @@ DeviceDescription *CUDAExecutor::PopulateDeviceDescription() const {
namespace gpu = ::perftools::gputools;
void initialize_cuda_gpu_executor() {
port::StatusOr<void *> status =
gpu::internal::CachedDsoLoader::GetLibcudaDsoHandle();
if (!status.ok()) {
gpu::cuda::Diagnostician::LogDriverVersionInformation();
LOG(INFO) << "LD_LIBRARY_PATH: " << getenv("LD_LIBRARY_PATH");
LOG(INFO) << "failed to find libcuda.so on this system: "
<< status.status();
}
// TODO(b/22689637): Temporary until users are migrated off of PlatformKind.
gpu::PluginRegistry::Instance()->MapPlatformKindToId(
gpu::PlatformKind::kCuda, gpu::cuda::kCudaPlatformId);
*gpu::internal::MakeCUDAExecutorImplementation() = [](
const gpu::PluginConfig &config) {
return new gpu::cuda::CUDAExecutor{config};


@ -21,7 +21,6 @@ limitations under the License.
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
#include "tensorflow/stream_executor/cuda/cuda_stream.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/dso_loader.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/initialize.h"
#include "tensorflow/stream_executor/lib/status.h"
@ -61,35 +60,16 @@ namespace cuda {
PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuRandPlugin);
namespace dynload {
namespace wrap {
#define PERFTOOLS_GPUTOOLS_CURAND_WRAP(__name) \
struct DynLoadShim__##__name { \
static const char *kName; \
using FuncPointerT = std::add_pointer<decltype(::__name)>::type; \
static void *GetDsoHandle() { \
static auto status = internal::CachedDsoLoader::GetCurandDsoHandle(); \
return status.ValueOrDie(); \
} \
static FuncPointerT LoadOrDie() { \
void *f; \
port::Status s = port::Env::Default()->GetSymbolFromLibrary( \
GetDsoHandle(), kName, &f); \
CHECK(s.ok()) << "could not find " << kName \
<< " in curand DSO; dlerror: " << s.error_message(); \
return reinterpret_cast<FuncPointerT>(f); \
} \
static FuncPointerT DynLoad() { \
static FuncPointerT f = LoadOrDie(); \
return f; \
} \
template <typename... Args> \
curandStatus_t operator()(CUDAExecutor *parent, Args... args) { \
cuda::ScopedActivateExecutorContext sac{parent}; \
return DynLoad()(args...); \
} \
} __name; \
const char *DynLoadShim__##__name::kName = #__name;
#define PERFTOOLS_GPUTOOLS_CURAND_WRAP(__name) \
struct WrapperShim__##__name { \
template <typename... Args> \
curandStatus_t operator()(CUDAExecutor *parent, Args... args) { \
cuda::ScopedActivateExecutorContext sac{parent}; \
return ::__name(args...); \
} \
} __name;
PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandCreateGenerator);
PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandDestroyGenerator);
@ -101,7 +81,7 @@ PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandSetGeneratorOffset);
PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandGenerateNormal);
PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandGenerateNormalDouble);
} // namespace dynload
} // namespace wrap
template <typename T>
string TypeString();
@ -130,7 +110,7 @@ CUDARng::CUDARng(CUDAExecutor *parent) : parent_(parent), rng_(nullptr) {}
CUDARng::~CUDARng() {
if (rng_ != nullptr) {
dynload::curandDestroyGenerator(parent_, rng_);
wrap::curandDestroyGenerator(parent_, rng_);
}
}
@ -139,7 +119,7 @@ bool CUDARng::Init() {
CHECK(rng_ == nullptr);
curandStatus_t ret =
dynload::curandCreateGenerator(parent_, &rng_, CURAND_RNG_PSEUDO_DEFAULT);
wrap::curandCreateGenerator(parent_, &rng_, CURAND_RNG_PSEUDO_DEFAULT);
if (ret != CURAND_STATUS_SUCCESS) {
LOG(ERROR) << "failed to create random number generator: " << ret;
return false;
@ -151,7 +131,7 @@ bool CUDARng::Init() {
bool CUDARng::SetStream(Stream *stream) {
curandStatus_t ret =
dynload::curandSetStream(parent_, rng_, AsCUDAStreamValue(stream));
wrap::curandSetStream(parent_, rng_, AsCUDAStreamValue(stream));
if (ret != CURAND_STATUS_SUCCESS) {
LOG(ERROR) << "failed to set stream for random generation: " << ret;
return false;
@ -189,11 +169,11 @@ bool CUDARng::DoPopulateRandUniformInternal(Stream *stream,
curandStatus_t ret;
if (std::is_same<T, float>::value ||
std::is_same<T, std::complex<float>>::value) {
ret = dynload::curandGenerateUniform(
ret = wrap::curandGenerateUniform(
parent_, rng_, reinterpret_cast<float *>(CUDAMemoryMutable(v)),
element_count);
} else {
ret = dynload::curandGenerateUniformDouble(
ret = wrap::curandGenerateUniformDouble(
parent_, rng_, reinterpret_cast<double *>(CUDAMemoryMutable(v)),
element_count);
}
@ -252,13 +232,13 @@ bool CUDARng::DoPopulateRandGaussianInternal(Stream *stream, ElemT mean,
bool CUDARng::DoPopulateRandGaussian(Stream *stream, float mean, float stddev,
DeviceMemory<float> *v) {
return DoPopulateRandGaussianInternal(stream, mean, stddev, v,
dynload::curandGenerateNormal);
wrap::curandGenerateNormal);
}
bool CUDARng::DoPopulateRandGaussian(Stream *stream, double mean, double stddev,
DeviceMemory<double> *v) {
return DoPopulateRandGaussianInternal(stream, mean, stddev, v,
dynload::curandGenerateNormalDouble);
wrap::curandGenerateNormalDouble);
}
bool CUDARng::SetSeed(Stream *stream, const uint8 *seed, uint64 seed_bytes) {
@ -275,14 +255,14 @@ bool CUDARng::SetSeed(Stream *stream, const uint8 *seed, uint64 seed_bytes) {
// Requires 8 bytes of seed data; checked in RngSupport::CheckSeed (above)
// (which itself requires 16 for API consistency with host RNG fallbacks).
curandStatus_t ret = dynload::curandSetPseudoRandomGeneratorSeed(
curandStatus_t ret = wrap::curandSetPseudoRandomGeneratorSeed(
parent_, rng_, *(reinterpret_cast<const uint64 *>(seed)));
if (ret != CURAND_STATUS_SUCCESS) {
LOG(ERROR) << "failed to set rng seed: " << ret;
return false;
}
ret = dynload::curandSetGeneratorOffset(parent_, rng_, 0);
ret = wrap::curandSetGeneratorOffset(parent_, rng_, 0);
if (ret != CURAND_STATUS_SUCCESS) {
LOG(ERROR) << "failed to reset rng position: " << ret;
return false;
@ -326,12 +306,6 @@ REGISTER_MODULE_INITIALIZER(register_curand, {
<< status.error_message();
}
// Prime the cuRAND DSO. The loader will log more information.
auto statusor = gpu::internal::CachedDsoLoader::GetCurandDsoHandle();
if (!statusor.ok()) {
LOG(INFO) << "Unable to load cuRAND DSO.";
}
gpu::PluginRegistry::Instance()->SetDefaultFactory(gpu::cuda::kCudaPlatformId,
gpu::PluginKind::kRng,
gpu::cuda::kCuRandPlugin);
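The cuRAND changes follow the same shape as the cuFFT ones: each wrap:: shim activates the executor's context and forwards to the directly linked library. A compressed usage sketch of the wrapped generator lifecycle as it appears across the hunks above (the names parent, device_ptr, and element_count are placeholders, not identifiers from the commit):

// Illustrative only: the generator lifecycle through the wrap:: shims.
curandGenerator_t rng = nullptr;
if (wrap::curandCreateGenerator(parent, &rng, CURAND_RNG_PSEUDO_DEFAULT) ==
    CURAND_STATUS_SUCCESS) {
  wrap::curandSetPseudoRandomGeneratorSeed(parent, rng, 1234ULL /* seed */);
  wrap::curandSetGeneratorOffset(parent, rng, 0);
  wrap::curandGenerateUniform(parent, rng, device_ptr, element_count);
  wrap::curandDestroyGenerator(parent, rng);
}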


@ -57,6 +57,13 @@ cc_library(
visibility = ["//visibility:public"],
)
cc_library(
name = "cuda_driver",
srcs = ["lib/%{cuda_driver_lib}"],
includes = ["include/"],
visibility = ["//visibility:public"],
)
cc_library(
name = "cudart",
srcs = ["lib/%{cudart_lib}"],


@ -407,6 +407,9 @@ def _find_cuda_lib(lib, repository_ctx, cpu_value, basedir, version="",
file_name = _lib_name(lib, cpu_value, version, static)
if cpu_value == "Linux":
path = repository_ctx.path("%s/lib64/%s" % (basedir, file_name))
if path.exists:
return struct(file_name=file_name, path=str(path.realpath))
path = repository_ctx.path("%s/lib64/stubs/%s" % (basedir, file_name))
if path.exists:
return struct(file_name=file_name, path=str(path.realpath))
path = repository_ctx.path(
@ -492,6 +495,7 @@ def _find_libs(repository_ctx, cuda_config):
cudnn_ext = ".%s" % cudnn_version if cudnn_version else ""
cpu_value = cuda_config.cpu_value
return {
"cuda": _find_cuda_lib("cuda", repository_ctx, cpu_value, cuda_config.cuda_toolkit_path),
"cudart": _find_cuda_lib(
"cudart", repository_ctx, cpu_value, cuda_config.cuda_toolkit_path,
cuda_config.cuda_version),
@ -648,6 +652,7 @@ def _create_dummy_repository(repository_ctx):
})
_tpl(repository_ctx, "cuda:BUILD",
{
"%{cuda_driver_lib}": _lib_name("cuda", cpu_value),
"%{cudart_static_lib}": _lib_name("cudart_static", cpu_value,
static=True),
"%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value),
@ -660,6 +665,7 @@ def _create_dummy_repository(repository_ctx):
})
_tpl(repository_ctx, "cuda:BUILD",
{
"%{cuda_driver_lib}": _lib_name("cuda", cpu_value),
"%{cudart_static_lib}": _lib_name("cudart_static", cpu_value,
static=True),
"%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value),
@ -683,6 +689,7 @@ def _create_dummy_repository(repository_ctx):
repository_ctx.file("cuda/include/cublas.h", "")
repository_ctx.file("cuda/include/cudnn.h", "")
repository_ctx.file("cuda/extras/CUPTI/include/cupti.h", "")
repository_ctx.file("cuda/lib/%s" % _lib_name("cuda", cpu_value))
repository_ctx.file("cuda/lib/%s" % _lib_name("cudart", cpu_value))
repository_ctx.file("cuda/lib/%s" % _lib_name("cudart_static", cpu_value))
repository_ctx.file("cuda/lib/%s" % _lib_name("cublas", cpu_value))
@ -756,6 +763,7 @@ def _create_cuda_repository(repository_ctx):
})
_tpl(repository_ctx, "cuda:BUILD",
{
"%{cuda_driver_lib}": cuda_libs["cuda"].file_name,
"%{cudart_static_lib}": cuda_libs["cudart_static"].file_name,
"%{cudart_static_linkopt}": _cudart_static_linkopt(
cuda_config.cpu_value),