From ef71383cf2d8d8241d813ac11695a26c34ae3ceb Mon Sep 17 00:00:00 2001
From: Ben Barsdell
Date: Tue, 6 Jun 2017 10:14:50 -0700
Subject: [PATCH] Support for CUDA 9.0

Add explicit __syncwarp to bias_op
- Makes warp-synchronous code safe on Volta

Add sync mask to __shfl intrinsics

Add libdevice bytecode paths for CUDA 9
- In CUDA 9, all supported architectures are merged into a single file

Update code gating for CUDA 9

Add sm_70 to the lookup table used by XLA

Change the default sm arch from 20 to 30. Fix for NVPTX not yet supporting sm_70.

Remove unnecessary cuda decorators from defaulted constructors

Use updated NCCL for CUDA 9 fp16 support
---
 .../xla/service/gpu/ir_emitter_unnested.cc    |  2 +-
 .../gpu/llvm_gpu_backend/gpu_backend_lib.cc   | 23 ++++--
 tensorflow/core/kernels/BUILD                 |  2 +-
 tensorflow/core/kernels/bias_op_gpu.cu.cc     | 23 +++---
 tensorflow/core/kernels/cwise_ops.h           |  4 +-
 .../core/kernels/depthwise_conv_op_gpu.cu.cc  | 30 ++++++--
 .../core/platform/cuda_libdevice_path_test.cc |  2 +-
 tensorflow/core/util/cuda_kernel_helper.h     | 76 ++++++++++++++-----
 tensorflow/workspace.bzl                      |  8 +-
 9 files changed, 119 insertions(+), 51 deletions(-)

diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 749badf3f23..6d800a976d1 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -996,7 +996,7 @@ Status IrEmitterUnnested::EmitRowReduction(
 //     for (shuffle_distance = 16; shuffle_distance > 0; shuffle_distance /= 2)
 //       partial_result = Reducer(
 //           partial_result,
-//           __shfl_down(partial_result, shuffle_distance));
+//           __shfl_down_sync(CUDA_WARP_ALL, partial_result, shuffle_distance));
 //     if (lane_id == 0)
 //       AtomicReducer(&output[y], partial_result);
 //   }
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
index 2e7765c4c61..5d650b872fa 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
@@ -71,7 +71,17 @@ const int kDefaultInlineThreshold = 1100;
 // Gets the libdevice filename for a particular compute capability. When
 // presented with a GPU we don't recognize, we just return the libdevice from
 // compute_20.
-static string GetLibdeviceFilename(std::pair<int, int> compute_capability) {
+static string GetLibdeviceFilename(const string& libdevice_dir_path,
+                                   std::pair<int, int> compute_capability) {
+  // Since CUDA 9.0, all GPU versions are included in a single file.
+  const char* unified_libdevice_filename = "libdevice.10.bc";
+  std::vector<string> unified_libdevice_files;
+  tensorflow::Env::Default()->GetMatchingPaths(
+      tensorflow::io::JoinPath(libdevice_dir_path, unified_libdevice_filename),
+      &unified_libdevice_files);
+  if (unified_libdevice_files.size() == 1) {
+    return unified_libdevice_filename;
+  }
   // There are only four libdevice files: compute_{20,30,35,50}. Each GPU
   // version gets mapped to one of these. Note in particular that sm_60 and
   // sm_61 map to libdevice.compute_30.
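[Editor's note] The hunks above migrate XLA's generated row reduction, and the rest of this patch, to the CUDA 9 *_sync warp intrinsics. As a reading aid, here is a minimal, self-contained sketch of that warp-reduction pattern written as ordinary CUDA. It is illustrative only and not part of the patch: the kernel name, the one-warp-per-block launch shape, and the bare 0xffffffffu mask are assumptions of the example.

#include <cuda.h>

// Illustrative kernel: each 32-thread block reduces 32 consecutive floats.
__global__ void WarpSumSketch(const float* in, float* out) {
  int lane_id = threadIdx.x & 31;
  float partial_result = in[blockIdx.x * 32 + lane_id];
  for (int shuffle_distance = 16; shuffle_distance > 0; shuffle_distance /= 2) {
#if CUDA_VERSION >= 9000
    // CUDA 9: every lane of the warp executes the shuffle, so pass the
    // full-warp mask explicitly.
    partial_result += __shfl_down_sync(0xffffffffu, partial_result,
                                       shuffle_distance);
#else
    // Pre-CUDA 9: the legacy intrinsic has no mask argument.
    partial_result += __shfl_down(partial_result, shuffle_distance);
#endif
  }
  if (lane_id == 0) out[blockIdx.x] = partial_result;
}

On pre-Volta hardware both branches behave identically here; on Volta the mask is what makes the warp-synchronous assumption explicit, which is the point of the changes in this patch.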
@@ -101,7 +111,7 @@ static string GetLibdeviceFilename(std::pair<int, int> compute_capability) {
 }
 
 // Gets the GPU name as it's known to LLVM for a given compute capability. If
-// we see an unrecognized compute capability, we return "sm_20".
+// we see an unrecognized compute capability, we return "sm_30".
 static string GetSmName(std::pair<int, int> compute_capability) {
   static auto* m = new std::map<std::pair<int, int>, int>({{{2, 0}, 20},
                                                            {{2, 1}, 21},
@@ -114,8 +124,10 @@ static string GetSmName(std::pair<int, int> compute_capability) {
                                                            {{5, 3}, 53},
                                                            {{6, 0}, 60},
                                                            {{6, 1}, 61},
-                                                           {{6, 2}, 62}});
-  int sm_version = 20;
+                                                           {{6, 2}, 62},
+                                                           // TODO: Change this to 70 once LLVM NVPTX supports it
+                                                           {{7, 0}, 60}});
+  int sm_version = 30;
   auto it = m->find(compute_capability);
   if (it != m->end()) {
     sm_version = it->second;
@@ -306,7 +318,8 @@ tensorflow::Status LinkLibdeviceIfNecessary(
   llvm::Linker linker(*module);
 
   string libdevice_path = tensorflow::io::JoinPath(
-      libdevice_dir_path, GetLibdeviceFilename(compute_capability));
+      libdevice_dir_path, GetLibdeviceFilename(libdevice_dir_path,
+                                               compute_capability));
   TF_RETURN_IF_ERROR(tensorflow::Env::Default()->FileExists(libdevice_path));
   VLOG(1) << "Linking with libdevice from: " << libdevice_path;
   std::unique_ptr<llvm::Module> libdevice_module =
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index efc5d7c553a..2110e6a6aaa 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -2932,7 +2932,7 @@ tf_kernel_library(
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:nn_ops_op_lib",
-    ],
+    ] + if_cuda(["@cub_archive//:cub"]),
 )
 
 tf_kernel_library(
diff --git a/tensorflow/core/kernels/bias_op_gpu.cu.cc b/tensorflow/core/kernels/bias_op_gpu.cu.cc
index ddc2d457b0e..42f3db1d79d 100644
--- a/tensorflow/core/kernels/bias_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/bias_op_gpu.cu.cc
@@ -173,15 +173,20 @@ __global__ void BiasGradNCHW_SharedAtomics(const T* output_backprop,
   // Accumulate the results in the shared memory into the first element.
   // No syncthreads is needed since this is only in the same warp.
   int32 thread_index = threadIdx.x;
-  if (thread_index < 16) s_data[thread_index] += s_data[thread_index + 16];
-  if (thread_index < 8) s_data[thread_index] += s_data[thread_index + 8];
-  if (thread_index < 4) s_data[thread_index] += s_data[thread_index + 4];
-  if (thread_index < 2) s_data[thread_index] += s_data[thread_index + 2];
-  if (thread_index < 1) s_data[thread_index] += s_data[thread_index + 1];
-
-  // The first thread writes out the accumulated result to the global location.
-  if (thread_index == 0) {
-    CudaAtomicAdd(bias_backprop + bias_index, T(s_data[0]));
+  if (thread_index < 16) {
+    s_data[thread_index] += s_data[thread_index + 16];
+    __syncwarp(0xFFFF);
+    if (thread_index < 8) s_data[thread_index] += s_data[thread_index + 8];
+    __syncwarp(0xFF);
+    if (thread_index < 4) s_data[thread_index] += s_data[thread_index + 4];
+    __syncwarp(0xF);
+    if (thread_index < 2) s_data[thread_index] += s_data[thread_index + 2];
+    __syncwarp(0x3);
+    if (thread_index == 0) {
+      T val = T(s_data[0] + s_data[1]);
+      // The first thread writes out the accumulated result to global location.
+      CudaAtomicAdd(bias_backprop + bias_index, val);
+    }
   }
 }
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
index d935331904d..ada39eae38f 100644
--- a/tensorflow/core/kernels/cwise_ops.h
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -139,7 +139,7 @@ struct scalar_left : private Binary {
   typedef Tout result_type;
   const Tin* left;
 
-  EIGEN_DEVICE_FUNC inline scalar_left(const scalar_left& other) = default;
+  inline scalar_left(const scalar_left& other) = default;
 
   template <typename... Args>
   EIGEN_DEVICE_FUNC inline explicit scalar_left(const Tin* c, Args... args)
@@ -169,7 +169,7 @@ struct scalar_right : private Binary {
   typedef Tout result_type;
   const Tin* right;
 
-  EIGEN_DEVICE_FUNC inline scalar_right(const scalar_right& other) = default;
+  inline scalar_right(const scalar_right& other) = default;
 
   template <typename... Args>
   EIGEN_DEVICE_FUNC inline explicit scalar_right(const Tin* c, Args... args)
diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
index fcfcd188d2d..1de7d6a2c02 100644
--- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/cuda_kernel_helper.h"
 #include "tensorflow/core/util/tensor_format.h"
+#include "external/cub_archive/cub/util_ptx.cuh"
 
 #if !defined(_MSC_VER)
 #define UNROLL _Pragma("unroll")
@@ -1015,6 +1016,21 @@ __global__ void __launch_bounds__(640, 2)
   }
 }
 
+// Device function to compute sub-warp sum reduction for a power-of-two group
+// of neighboring threads.
+template <int kWidth, typename T>
+__device__ __forceinline__ T WarpSumReduce(T val) {
+  // Support only power-of-two widths.
+  assert(__popc(kWidth) == 1);
+  int sub_warp = cub::LaneId() / kWidth;
+  int zeros = sub_warp * kWidth;
+  unsigned mask = ((1U << kWidth) - 1) << zeros;
+  for (int delta = kWidth / 2; delta > 0; delta /= 2) {
+    val += CudaShuffleXor(mask, val, delta);
+  }
+  return val;
+}
+
 // CUDA kernel to compute the depthwise convolution backward w.r.t. filter in
 // NHWC format, tailored for small images up to 32x32. Stride and depth
 // multiplier must be 1. Padding must be 'SAME'. Only use this kernel if
@@ -1127,6 +1143,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
   // Note: the condition to reach this is uniform across the entire block.
   __syncthreads();
 
+  unsigned active_threads = CudaBallot(CUDA_WARP_ALL, depth_in_range);
   if (depth_in_range) {
     const T* const out_ptr = inout_offset + output;
@@ -1140,7 +1157,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
       T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
       // Warp-accumulate pixels of the same depth and write to accumulator.
       for (int delta = 16; delta >= kBlockSlices; delta /= 2) {
-        val += CudaShuffleDown(val, delta);
+        val += CudaShuffleDown(active_threads, val, delta);
       }
       if (!(thread_idx & 32 - kBlockSlices) /* lane_idx < kBlockSlices */) {
         *accum_ptr = val;
@@ -1164,9 +1181,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
       if (filter_depth < in_depth) {
         T val = accum_data[i];
         // Warp-accumulate the pixels of the same depth from the accumulator.
-        for (int delta = kAccumPixels / 2; delta > 0; delta /= 2) {
-          val += CudaShuffleDown(val, delta);
-        }
+        val = WarpSumReduce<kAccumPixels>(val);
         if (!(thread_idx & kAccumPixels - 1)) {
           CudaAtomicAdd(filter_offset + filter, val);
         }
@@ -1382,6 +1397,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
   // Note: the condition to reach this is uniform across the entire block.
   __syncthreads();
 
+  unsigned active_threads = CudaBallot(CUDA_WARP_ALL, slice_in_range);
   if (slice_in_range) {
     const T* const out_ptr = inout_offset + output;
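[Editor's note] The WarpSumReduce helper and the CudaBallot/active_threads changes above both rely on building a shuffle mask that names exactly the lanes taking part in a reduction. Here is a minimal sketch of the sub-warp masked pattern written with raw intrinsics. It is illustrative only and not part of the patch: the kernel name, the group size kWidth = 8, and the launch shape (blockDim.x a multiple of 32) are assumptions of the example.

#include <cuda.h>

// Illustrative kernel: every aligned group of 8 neighboring lanes sums its
// values, mirroring the mask computation used by WarpSumReduce.
__global__ void SubWarpSumSketch(const float* in, float* out) {
  const int kWidth = 8;                 // power-of-two sub-warp group size
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  int lane = threadIdx.x & 31;          // lane index within the warp
  int sub_warp = lane / kWidth;         // which group this lane belongs to
  // Bitmask naming exactly the kWidth lanes of this group.
  unsigned mask = ((1u << kWidth) - 1) << (sub_warp * kWidth);
  float val = in[idx];
  // Butterfly reduction; delta < kWidth keeps the exchange inside the group.
  for (int delta = kWidth / 2; delta > 0; delta /= 2) {
#if CUDA_VERSION >= 9000
    val += __shfl_xor_sync(mask, val, delta);
#else
    val += __shfl_xor(val, delta);
#endif
  }
  if (lane % kWidth == 0) out[idx / kWidth] = val;  // one result per group
}

The same reasoning applies to the depthwise kernels above: active_threads is computed with a ballot before the divergent branch, so the mask passed to CudaShuffleDown names only lanes that are guaranteed to execute the shuffle.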
@@ -1395,7 +1411,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
       T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
       // Warp-accumulate pixels of the same depth and write to accumulator.
       for (int delta = 16 / kBlockSlices; delta > 0; delta /= 2) {
-        val += CudaShuffleDown(val, delta);
+        val += CudaShuffleDown(active_threads, val, delta);
       }
       if (!(thread_idx & 32 / kBlockSlices - 1)) {
         *accum_ptr = val;
@@ -1419,9 +1435,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
       if (filter_depth < in_depth) {
         T val = accum_data[i];
         // Warp-accumulate pixels of the same depth from the accumulator.
-        for (int delta = kAccumPixels / 2; delta > 0; delta /= 2) {
-          val += CudaShuffleDown(val, delta);
-        }
+        val = WarpSumReduce<kAccumPixels>(val);
         if (!(thread_idx & kAccumPixels - 1)) {
           CudaAtomicAdd(filter_offset + filter, val);
         }
diff --git a/tensorflow/core/platform/cuda_libdevice_path_test.cc b/tensorflow/core/platform/cuda_libdevice_path_test.cc
index 86295592a8b..639f6804ea2 100644
--- a/tensorflow/core/platform/cuda_libdevice_path_test.cc
+++ b/tensorflow/core/platform/cuda_libdevice_path_test.cc
@@ -27,7 +27,7 @@ TEST(CudaLibdevicePathTest, LibdevicePath) {
   VLOG(2) << "Libdevice root = " << LibdeviceRoot();
   std::vector<string> libdevice_files;
   TF_EXPECT_OK(Env::Default()->GetMatchingPaths(
-      io::JoinPath(LibdeviceRoot(), "libdevice.compute_*.bc"),
+      io::JoinPath(LibdeviceRoot(), "libdevice.*.bc"),
       &libdevice_files));
   EXPECT_LT(0, libdevice_files.size());
 }
diff --git a/tensorflow/core/util/cuda_kernel_helper.h b/tensorflow/core/util/cuda_kernel_helper.h
index ee651944052..55d0523dbfa 100644
--- a/tensorflow/core/util/cuda_kernel_helper.h
+++ b/tensorflow/core/util/cuda_kernel_helper.h
@@ -25,6 +25,29 @@ limitations under the License.
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/stream_executor.h"
 #include "tensorflow/core/platform/types.h"
+#include "cuda/include/cuda.h"
+
+// Mask for all 32 threads in a warp.
+#define CUDA_WARP_ALL 0xFFFFFFFF
+
+#if defined(CUDA_VERSION) && CUDA_VERSION < 9000
+// CUDA 9.0 introduces a new, light-weight barrier synchronization primitive
+// that operates at the warp scope. This is required to ensure visibility of
+// reads/writes among threads that can make independent progress on Volta.
+// For previous CUDA versions these synchronizations are not necessary, and we
+// define an empty function as a convenience for backward compatibility.
+__device__ inline void __syncwarp(unsigned mask = CUDA_WARP_ALL) {}
+
+// CUDA 9.0 deprecates the warp-intrinsic functions (shfl, ballot, etc.) in
+// favor of synchronizing versions. These ensure that all warp lanes specified
+// in the mask execute the intrinsic in convergence. Here we provide legacy
+// mappings to the less-verbose routines provided in previous versions of CUDA.
+#define __ballot_sync(mask, predicate) __ballot(predicate)
+#define __shfl_sync(mask, val, srcLane, width) __shfl(val, srcLane, width)
+#define __shfl_down_sync(mask, val, delta, width) __shfl_down(val, delta, width)
+#define __shfl_up_sync(mask, val, delta, width) __shfl_up(val, delta, width)
+#define __shfl_xor_sync(mask, val, laneMask, width) __shfl_xor(val, laneMask, width)
+#endif
 
 // Usage of GetCudaLaunchConfig, GetCuda2DLaunchConfig, and
 // GetCuda3DLaunchConfig:
@@ -603,82 +626,95 @@ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T tf_max(const T& x, const T& y) {
   return x < y ? y : x;
 }
 
+__device__ EIGEN_ALWAYS_INLINE unsigned CudaBallot(unsigned mask,
+                                                   int predicate) {
+  return __ballot_sync(mask, predicate);
+}
+
 template <typename T>
-__device__ EIGEN_ALWAYS_INLINE T CudaShuffle(T value, int srcLane,
+__device__ EIGEN_ALWAYS_INLINE T CudaShuffle(unsigned mask, T value,
+                                             int srcLane,
                                              int width = warpSize) {
-  return __shfl(value, srcLane, width);
+  return __shfl_sync(mask, value, srcLane, width);
 }
 
 // Variant of the (undocumented) version from the CUDA SDK, but using unsigned
 // instead of float for lo and hi (which is incorrect with ftz, for example).
 // A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
 // TODO(csigg): remove when the bug is fixed in the next CUDA release.
-__device__ EIGEN_ALWAYS_INLINE double CudaShuffle(double value, int srcLane,
+__device__ EIGEN_ALWAYS_INLINE double CudaShuffle(unsigned mask,
+                                                  double value, int srcLane,
                                                   int width = warpSize) {
   unsigned lo, hi;
   asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
-  hi = __shfl(hi, srcLane, width);
-  lo = __shfl(lo, srcLane, width);
+  hi = __shfl_sync(mask, hi, srcLane, width);
+  lo = __shfl_sync(mask, lo, srcLane, width);
   asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
   return value;
 }
 
 template <typename T>
-__device__ EIGEN_ALWAYS_INLINE T CudaShuffleUp(T value, int delta,
+__device__ EIGEN_ALWAYS_INLINE T CudaShuffleUp(unsigned mask,
+                                               T value, int delta,
                                                int width = warpSize) {
-  return __shfl_up(value, delta, width);
+  return __shfl_up_sync(mask, value, delta, width);
 }
 
 // Variant of the (undocumented) version from the CUDA SDK, but using unsigned
 // instead of float for lo and hi (which is incorrect with ftz, for example).
 // A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
 // TODO(csigg): remove when the bug is fixed in the next CUDA release.
-__device__ EIGEN_ALWAYS_INLINE double CudaShuffleUp(double value, int delta,
+__device__ EIGEN_ALWAYS_INLINE double CudaShuffleUp(unsigned mask,
+                                                    double value, int delta,
                                                     int width = warpSize) {
   unsigned lo, hi;
   asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
-  hi = __shfl_up(hi, delta, width);
-  lo = __shfl_up(lo, delta, width);
+  hi = __shfl_up_sync(mask, hi, delta, width);
+  lo = __shfl_up_sync(mask, lo, delta, width);
   asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
   return value;
 }
 
 template <typename T>
-__device__ EIGEN_ALWAYS_INLINE T CudaShuffleDown(T value, int delta,
+__device__ EIGEN_ALWAYS_INLINE T CudaShuffleDown(unsigned mask,
+                                                 T value, int delta,
                                                  int width = warpSize) {
-  return __shfl_down(value, delta, width);
+  return __shfl_down_sync(mask, value, delta, width);
 }
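[Editor's note] The double overloads in this hunk split the 64-bit value into two 32-bit halves, shuffle each half, and reassemble the result; the patch keeps the original inline-PTX form. For readers, the same idea can be expressed with CUDA's reinterpret intrinsics, as in the hypothetical helper below. This sketch is not part of the patch; the function name is made up, and it assumes the CUDA 9 guard used elsewhere in this header.

#include <cuda.h>

// Same hi/lo split as the double overloads nearby, using reinterpret
// intrinsics instead of inline PTX (illustrative only).
__device__ inline double ShuffleDownDoubleSketch(unsigned mask, double value,
                                                 int delta,
                                                 int width = warpSize) {
  int hi = __double2hiint(value);   // upper 32 bits of the double
  int lo = __double2loint(value);   // lower 32 bits of the double
#if CUDA_VERSION >= 9000
  hi = __shfl_down_sync(mask, hi, delta, width);
  lo = __shfl_down_sync(mask, lo, delta, width);
#else
  hi = __shfl_down(hi, delta, width);
  lo = __shfl_down(lo, delta, width);
#endif
  return __hiloint2double(hi, lo);  // reassemble the shuffled double
}

Because the halves travel as integers rather than floats, flush-to-zero cannot corrupt them, which is the issue the "incorrect with ftz" comment above refers to.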
 
 // Variant of the (undocumented) version from the CUDA SDK, but using unsigned
 // instead of float for lo and hi (which is incorrect with ftz, for example).
 // A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
 // TODO(csigg): remove when the bug is fixed in the next CUDA release.
-__device__ EIGEN_ALWAYS_INLINE double CudaShuffleDown(double value, int delta,
+__device__ EIGEN_ALWAYS_INLINE double CudaShuffleDown(unsigned mask,
+                                                      double value, int delta,
                                                       int width = warpSize) {
   unsigned lo, hi;
   asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
-  hi = __shfl_down(hi, delta, width);
-  lo = __shfl_down(lo, delta, width);
+  hi = __shfl_down_sync(mask, hi, delta, width);
+  lo = __shfl_down_sync(mask, lo, delta, width);
   asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
   return value;
 }
 
 template <typename T>
-__device__ EIGEN_ALWAYS_INLINE T CudaShuffleXor(T value, int laneMask,
+__device__ EIGEN_ALWAYS_INLINE T CudaShuffleXor(unsigned mask,
+                                                T value, int laneMask,
                                                 int width = warpSize) {
-  return __shfl_xor(value, laneMask, width);
+  return __shfl_xor_sync(mask, value, laneMask, width);
 }
 
 // Variant of the (undocumented) version from the CUDA SDK, but using unsigned
 // instead of float for lo and hi (which is incorrect with ftz, for example).
 // A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
 // TODO(csigg): remove when the bug is fixed in the next CUDA release.
-__device__ EIGEN_ALWAYS_INLINE double CudaShuffleXor(double value, int laneMask,
+__device__ EIGEN_ALWAYS_INLINE double CudaShuffleXor(unsigned mask,
+                                                     double value, int laneMask,
                                                      int width = warpSize) {
   unsigned lo, hi;
   asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
-  hi = __shfl_xor(hi, laneMask, width);
-  lo = __shfl_xor(lo, laneMask, width);
+  hi = __shfl_xor_sync(mask, hi, laneMask, width);
+  lo = __shfl_xor_sync(mask, lo, laneMask, width);
   asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
   return value;
 }
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 176719fabb4..483743daab9 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -623,11 +623,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
   temp_workaround_http_archive(
       name = "nccl_archive",
       urls = [
-          "http://mirror.bazel.build/github.com/nvidia/nccl/archive/ccfc4567dc3e2a37fb42cfbc64d10eb526e7da7b.tar.gz",
-          "https://github.com/nvidia/nccl/archive/ccfc4567dc3e2a37fb42cfbc64d10eb526e7da7b.tar.gz",
+          "http://mirror.bazel.build/github.com/nvidia/nccl/archive/29a1a916dc14bb2c00feed3d4820d51fa85be1e6.tar.gz",
+          "https://github.com/nvidia/nccl/archive/29a1a916dc14bb2c00feed3d4820d51fa85be1e6.tar.gz",
       ],
-      sha256 = "6c34a0862d9f8ed4ad5984c6a8206b351957bb14cf6ad7822720f285f4aada04",
-      strip_prefix = "nccl-ccfc4567dc3e2a37fb42cfbc64d10eb526e7da7b",
+      sha256 = "6387030e37d14762f87eefbc86ee527293ec04745c66ccd820cf7fc0fdc23f92",
+      strip_prefix = "nccl-29a1a916dc14bb2c00feed3d4820d51fa85be1e6",
       build_file = str(Label("//third_party:nccl.BUILD")),
       repository = tf_repo_name,
   )
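[Editor's note] To close, here is a hedged sketch of how kernel code inside the TensorFlow tree would use the updated mask-first helpers from cuda_kernel_helper.h once this patch is applied. It is illustrative only and not part of the patch: the kernel name and the 32-threads-per-block launch shape are assumptions, and it assumes the helpers remain in namespace tensorflow as the rest of the header does.

#include "tensorflow/core/util/cuda_kernel_helper.h"

using namespace tensorflow;  // assumption: helpers live in this namespace

// Illustrative kernel: a full-warp butterfly sum using CudaShuffleXor. All 32
// lanes participate, so CUDA_WARP_ALL is the correct mask on both CUDA 8 and
// CUDA 9 builds.
__global__ void ButterflySumSketch(const float* in, float* out) {
  int lane = threadIdx.x & 31;
  float val = in[blockIdx.x * 32 + lane];
  for (int lane_mask = 16; lane_mask > 0; lane_mask /= 2) {
    val += CudaShuffleXor(CUDA_WARP_ALL, val, lane_mask);
  }
  // After the butterfly, every lane holds the warp-wide sum.
  if (lane == 0) out[blockIdx.x] = val;
}

On CUDA 8 builds the compatibility macros defined earlier in the header map the *_sync calls back to the legacy intrinsics, so callers can be written once against the new signatures.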