Roll PR #39577 (cuDNN v8 support) forward with a fix:
Handle the '64_' prefix in the cuDNN version on Windows. PiperOrigin-RevId: 314385732 Change-Id: I9b36497dbe460e2e7cad3d755cf9ce8d41647ce3
This commit is contained in:
		
							parent
							
								
									e7b9f9149e
								
							
						
					
					
						commit
						255f590ab6
					
				
							
								
								
									
										85
									
								
								tensorflow/stream_executor/cuda/cuda_dnn.cc
									
									
									
									
									
										
										
										Executable file → Normal file
									
								
							
							
						
						
									
										85
									
								
								tensorflow/stream_executor/cuda/cuda_dnn.cc
									
									
									
									
									
										
										
										Executable file → Normal file
									
								
							@ -1278,6 +1278,18 @@ port::Status CheckAndFetchProjectionWeights(
 | 
			
		||||
  cudnnRNNMode_t mode;
 | 
			
		||||
  cudnnRNNAlgo_t algo;
 | 
			
		||||
  cudnnDataType_t data_type;
 | 
			
		||||
#if CUDNN_VERSION >= 8000
 | 
			
		||||
  RETURN_IF_CUDNN_ERROR(cudnnGetRNNDescriptor_v6(
 | 
			
		||||
      /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
 | 
			
		||||
      /*hiddenSize=*/&hidden_size_v,
 | 
			
		||||
      /*numLayers=*/&num_layers_v,
 | 
			
		||||
      /*dropoutDesc=*/&dropout_desc,
 | 
			
		||||
      /*inputMode=*/&input_mode,
 | 
			
		||||
      /*direction=*/&direction,
 | 
			
		||||
      /*mode=*/&mode,
 | 
			
		||||
      /*algo=*/&algo,
 | 
			
		||||
      /*mathPrec=*/&data_type));
 | 
			
		||||
#else
 | 
			
		||||
  RETURN_IF_CUDNN_ERROR(cudnnGetRNNDescriptor(
 | 
			
		||||
      /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
 | 
			
		||||
      /*hiddenSize=*/&hidden_size_v,
 | 
			
		||||
@ -1287,7 +1299,8 @@ port::Status CheckAndFetchProjectionWeights(
 | 
			
		||||
      /*direction=*/&direction,
 | 
			
		||||
      /*mode=*/&mode,
 | 
			
		||||
      /*algo=*/&algo,
 | 
			
		||||
      /*dataType=*/&data_type));
 | 
			
		||||
      /*mathPrec=*/&data_type));
 | 
			
		||||
#endif
 | 
			
		||||
  int rec_proj_size_v;
 | 
			
		||||
  int out_proj_size_v;
 | 
			
		||||
  RETURN_IF_CUDNN_ERROR(cudnnGetRNNProjectionLayers(
 | 
			
		||||
@ -2424,6 +2437,28 @@ port::StatusOr<cudnnConvolutionFwdAlgo_t> GetCudnnConvolutionForwardAlgo(
 | 
			
		||||
    const CudnnFilterDescriptor& filter, const CudnnConvolutionDescriptor& conv,
 | 
			
		||||
    const CudnnTensorDescriptor& output_nd, bool specify_workspace_limit,
 | 
			
		||||
    size_t memory_limit_bytes) {
 | 
			
		||||
#if CUDNN_VERSION >= 8000
 | 
			
		||||
  const int num_requested_algos = 5;
 | 
			
		||||
  int num_returned_algos = 0;
 | 
			
		||||
  cudnnConvolutionFwdAlgoPerf_t perf_results[num_requested_algos];
 | 
			
		||||
 | 
			
		||||
  RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardAlgorithm_v7(
 | 
			
		||||
      cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(),
 | 
			
		||||
      output_nd.handle(), num_requested_algos, &num_returned_algos,
 | 
			
		||||
      perf_results));
 | 
			
		||||
 | 
			
		||||
  size_t mem_limit = specify_workspace_limit ? memory_limit_bytes : 0ULL;
 | 
			
		||||
  for (int r = 0; r < num_returned_algos; r++) {
 | 
			
		||||
    if (perf_results[r].status == CUDNN_STATUS_SUCCESS &&
 | 
			
		||||
        perf_results[r].algo != CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
 | 
			
		||||
        perf_results[r].memory <= mem_limit) {
 | 
			
		||||
      return perf_results[r].algo;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  return port::Status(port::error::INTERNAL,
 | 
			
		||||
                      "cudnnGetConvolutionForwardAlgorithm_v7 returned "
 | 
			
		||||
                      "no suitable algorithms. This could be a cudnn bug.");
 | 
			
		||||
#else
 | 
			
		||||
  cudnnConvolutionFwdPreference_t preference =
 | 
			
		||||
      specify_workspace_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
 | 
			
		||||
                              : CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
 | 
			
		||||
@ -2432,6 +2467,7 @@ port::StatusOr<cudnnConvolutionFwdAlgo_t> GetCudnnConvolutionForwardAlgo(
 | 
			
		||||
      cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(),
 | 
			
		||||
      output_nd.handle(), preference, memory_limit_bytes, &algo_to_use));
 | 
			
		||||
  return algo_to_use;
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
port::StatusOr<cudnnConvolutionBwdDataAlgo_t>
 | 
			
		||||
@ -2442,6 +2478,29 @@ GetCudnnConvolutionBackwardDataAlgo(const CudnnHandle& cudnn,
 | 
			
		||||
                                    const CudnnTensorDescriptor& output_nd,
 | 
			
		||||
                                    bool specify_workspace_limit,
 | 
			
		||||
                                    size_t memory_limit_bytes) {
 | 
			
		||||
#if CUDNN_VERSION >= 8000
 | 
			
		||||
  const int num_requested_algos = 5;
 | 
			
		||||
  int num_returned_algos = 0;
 | 
			
		||||
  cudnnConvolutionBwdDataAlgoPerf_t perf_results[num_requested_algos];
 | 
			
		||||
 | 
			
		||||
  RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataAlgorithm_v7(
 | 
			
		||||
      cudnn.handle(), filter.handle(), output_nd.handle(), conv.handle(),
 | 
			
		||||
      input_nd.handle(), num_requested_algos, &num_returned_algos,
 | 
			
		||||
      perf_results));
 | 
			
		||||
 | 
			
		||||
  size_t mem_limit = specify_workspace_limit ? memory_limit_bytes : 0ULL;
 | 
			
		||||
  for (int r = 0; r < num_returned_algos; r++) {
 | 
			
		||||
    if (perf_results[r].status == CUDNN_STATUS_SUCCESS &&
 | 
			
		||||
        perf_results[r].algo !=
 | 
			
		||||
            CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED &&
 | 
			
		||||
        perf_results[r].memory <= mem_limit) {
 | 
			
		||||
      return perf_results[r].algo;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  return port::Status(port::error::INTERNAL,
 | 
			
		||||
                      "cudnnGetConvolutionBackwardDataAlgorithm_v7 returned "
 | 
			
		||||
                      "no suitable algorithms. This could be a cudnn bug.");
 | 
			
		||||
#else
 | 
			
		||||
  cudnnConvolutionBwdDataPreference_t preference =
 | 
			
		||||
      specify_workspace_limit
 | 
			
		||||
          ? CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT
 | 
			
		||||
@ -2451,6 +2510,7 @@ GetCudnnConvolutionBackwardDataAlgo(const CudnnHandle& cudnn,
 | 
			
		||||
      cudnn.handle(), filter.handle(), output_nd.handle(), conv.handle(),
 | 
			
		||||
      input_nd.handle(), preference, memory_limit_bytes, &algo_to_use));
 | 
			
		||||
  return algo_to_use;
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
port::StatusOr<cudnnConvolutionBwdFilterAlgo_t>
 | 
			
		||||
@ -2461,6 +2521,28 @@ GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn,
 | 
			
		||||
                                      const CudnnTensorDescriptor& output_nd,
 | 
			
		||||
                                      bool specify_workspace_limit,
 | 
			
		||||
                                      size_t memory_limit_bytes) {
 | 
			
		||||
#if CUDNN_VERSION >= 8000
 | 
			
		||||
  const int num_requested_algos = 5;
 | 
			
		||||
  int num_returned_algos = 0;
 | 
			
		||||
  cudnnConvolutionBwdFilterAlgoPerf_t perf_results[num_requested_algos];
 | 
			
		||||
 | 
			
		||||
  RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterAlgorithm_v7(
 | 
			
		||||
      cudnn.handle(), input_nd.handle(), output_nd.handle(), conv.handle(),
 | 
			
		||||
      filter.handle(), num_requested_algos, &num_returned_algos, perf_results));
 | 
			
		||||
 | 
			
		||||
  size_t mem_limit = specify_workspace_limit ? memory_limit_bytes : 0ULL;
 | 
			
		||||
  for (int r = 0; r < num_returned_algos; r++) {
 | 
			
		||||
    if (perf_results[r].status == CUDNN_STATUS_SUCCESS &&
 | 
			
		||||
        perf_results[r].algo !=
 | 
			
		||||
            CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED &&
 | 
			
		||||
        perf_results[r].memory <= mem_limit) {
 | 
			
		||||
      return perf_results[r].algo;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  return port::Status(port::error::INTERNAL,
 | 
			
		||||
                      "cudnnGetConvolutionBackwardFilterAlgorithm_v7 returned "
 | 
			
		||||
                      "no suitable algorithms. This could be a cudnn bug.");
 | 
			
		||||
#else
 | 
			
		||||
  cudnnConvolutionBwdFilterPreference_t preference =
 | 
			
		||||
      specify_workspace_limit
 | 
			
		||||
          ? CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT
 | 
			
		||||
@ -2470,6 +2552,7 @@ GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn,
 | 
			
		||||
      cudnn.handle(), input_nd.handle(), output_nd.handle(), conv.handle(),
 | 
			
		||||
      filter.handle(), preference, memory_limit_bytes, &algo_to_use));
 | 
			
		||||
  return algo_to_use;
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace(
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										3316
									
								
								tensorflow/stream_executor/cuda/cudnn_8_0.inc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3316
									
								
								tensorflow/stream_executor/cuda/cudnn_8_0.inc
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							@ -51,15 +51,17 @@ cudnnStatus_t GetSymbolNotFoundError() { return CUDNN_STATUS_INTERNAL_ERROR; }
 | 
			
		||||
#error cuDNN version earlier than 6 is not supported.
 | 
			
		||||
#elif CUDNN_MAJOR < 7
 | 
			
		||||
#include "tensorflow/stream_executor/cuda/cudnn_6_0.inc"
 | 
			
		||||
#elif CUDNN_MINOR < 1
 | 
			
		||||
#elif CUDNN_MAJOR == 7 && CUDNN_MINOR < 1
 | 
			
		||||
#include "tensorflow/stream_executor/cuda/cudnn_7_0.inc"
 | 
			
		||||
// 2 instead of 3: see https://github.com/tensorflow/tensorflow/issues/32350
 | 
			
		||||
#elif CUDNN_MINOR < 2
 | 
			
		||||
#elif CUDNN_MAJOR == 7 && CUDNN_MINOR < 2
 | 
			
		||||
#include "tensorflow/stream_executor/cuda/cudnn_7_1.inc"
 | 
			
		||||
#elif CUDNN_MINOR < 4
 | 
			
		||||
#elif CUDNN_MAJOR == 7 && CUDNN_MINOR < 4
 | 
			
		||||
#include "tensorflow/stream_executor/cuda/cudnn_7_3.inc"
 | 
			
		||||
#elif CUDNN_MINOR < 6
 | 
			
		||||
#elif CUDNN_MAJOR == 7 && CUDNN_MINOR < 6
 | 
			
		||||
#include "tensorflow/stream_executor/cuda/cudnn_7_4.inc"
 | 
			
		||||
#else
 | 
			
		||||
#elif CUDNN_MAJOR == 7
 | 
			
		||||
#include "tensorflow/stream_executor/cuda/cudnn_7_6.inc"
 | 
			
		||||
#else
 | 
			
		||||
#include "tensorflow/stream_executor/cuda/cudnn_8_0.inc"
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										25
									
								
								third_party/gpus/cuda_configure.bzl
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										25
									
								
								third_party/gpus/cuda_configure.bzl
									
									
									
									
										vendored
									
									
								
							@ -1069,11 +1069,32 @@ def _create_local_cuda_repository(repository_ctx):
 | 
			
		||||
        ],
 | 
			
		||||
    ))
 | 
			
		||||
 | 
			
		||||
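    # Select the headers based on the cuDNN version (strip '64_' for Windows).
    # NOTE(review): rsplit("_", 1)[0] yields the text before the last underscore
    # ("64" for "64_8"), so the prefix is kept rather than stripped; [-1] looks
    # like the intended index. The comparison is also lexicographic on strings,
    # so a two-digit major version such as "10" would sort below "8". Confirm.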
 | 
			
		||||
    if cuda_config.cudnn_version.rsplit("_", 1)[0] < "8":
 | 
			
		||||
        cudnn_headers = ["cudnn.h"]
 | 
			
		||||
    else:
 | 
			
		||||
        cudnn_headers = [
 | 
			
		||||
            "cudnn_adv_infer.h",
 | 
			
		||||
            "cudnn_adv_train.h",
 | 
			
		||||
            "cudnn_cnn_infer.h",
 | 
			
		||||
            "cudnn_cnn_train.h",
 | 
			
		||||
            "cudnn_ops_infer.h",
 | 
			
		||||
            "cudnn_ops_train.h",
 | 
			
		||||
            "cudnn.h",
 | 
			
		||||
            "cudnn_version.h",
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
    cudnn_srcs = []
 | 
			
		||||
    cudnn_outs = []
 | 
			
		||||
    for header in cudnn_headers:
 | 
			
		||||
        cudnn_srcs.append(cudnn_header_dir + "/" + header)
 | 
			
		||||
        cudnn_outs.append("cudnn/include/" + header)
 | 
			
		||||
 | 
			
		||||
    copy_rules.append(make_copy_files_rule(
 | 
			
		||||
        repository_ctx,
 | 
			
		||||
        name = "cudnn-include",
 | 
			
		||||
        srcs = [cudnn_header_dir + "/cudnn.h"],
 | 
			
		||||
        outs = ["cudnn/include/cudnn.h"],
 | 
			
		||||
        srcs = cudnn_srcs,
 | 
			
		||||
        outs = cudnn_outs,
 | 
			
		||||
    ))
 | 
			
		||||
 | 
			
		||||
    # Set up BUILD file for cuda/
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										24
									
								
								third_party/gpus/find_cuda_config.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										24
									
								
								third_party/gpus/find_cuda_config.py
									
									
									
									
										vendored
									
									
								
							@ -219,17 +219,20 @@ def _find_library(base_paths, library_name, required_version):
 | 
			
		||||
  return _find_file(base_paths, _library_paths(), filepattern)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _find_versioned_file(base_paths, relative_paths, filepattern,
 | 
			
		||||
def _find_versioned_file(base_paths, relative_paths, filepatterns,
 | 
			
		||||
                         required_version, get_version):
 | 
			
		||||
  """Returns first valid path to a file that matches the requested version."""
 | 
			
		||||
  if type(filepatterns) not in [list, tuple]:
 | 
			
		||||
    filepatterns = [filepatterns]
 | 
			
		||||
  for path in _cartesian_product(base_paths, relative_paths):
 | 
			
		||||
    for file in glob.glob(os.path.join(path, filepattern)):
 | 
			
		||||
      actual_version = get_version(file)
 | 
			
		||||
      if _matches_version(actual_version, required_version):
 | 
			
		||||
        return file, actual_version
 | 
			
		||||
    for filepattern in filepatterns:
 | 
			
		||||
      for file in glob.glob(os.path.join(path, filepattern)):
 | 
			
		||||
        actual_version = get_version(file)
 | 
			
		||||
        if _matches_version(actual_version, required_version):
 | 
			
		||||
          return file, actual_version
 | 
			
		||||
  raise _not_found_error(
 | 
			
		||||
      base_paths, relative_paths,
 | 
			
		||||
      filepattern + " matching version '%s'" % required_version)
 | 
			
		||||
      ", ".join(filepatterns) + " matching version '%s'" % required_version)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _find_header(base_paths, header_name, required_version, get_version):
 | 
			
		||||
@ -426,12 +429,13 @@ def _find_cufft_config(base_paths, required_version, cuda_version):
 | 
			
		||||
def _find_cudnn_config(base_paths, required_version):
 | 
			
		||||
 | 
			
		||||
  def get_header_version(path):
 | 
			
		||||
    version = (
 | 
			
		||||
    version = [
 | 
			
		||||
        _get_header_version(path, name)
 | 
			
		||||
        for name in ("CUDNN_MAJOR", "CUDNN_MINOR", "CUDNN_PATCHLEVEL"))
 | 
			
		||||
    return ".".join(version)
 | 
			
		||||
        for name in ("CUDNN_MAJOR", "CUDNN_MINOR", "CUDNN_PATCHLEVEL")]
 | 
			
		||||
    return ".".join(version) if version[0] else None
 | 
			
		||||
 | 
			
		||||
  header_path, header_version = _find_header(base_paths, "cudnn.h",
 | 
			
		||||
  header_path, header_version = _find_header(base_paths,
 | 
			
		||||
                                             ("cudnn.h", "cudnn_version.h"),
 | 
			
		||||
                                             required_version,
 | 
			
		||||
                                             get_header_version)
 | 
			
		||||
  cudnn_version = header_version.split(".")[0]
 | 
			
		||||
 | 
			
		||||
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user