Merge pull request #12502 from nluehr/cuda-9.0

Support for CUDA 9.0
Authored by zheng-xq on 2017-09-11 18:22:32 -07:00; committed by GitHub
commit 5541ef4fbb
9 changed files with 119 additions and 51 deletions

View File

@ -999,7 +999,7 @@ Status IrEmitterUnnested::EmitRowReduction(
// for (shuffle_distance = 16; shuffle_distance > 0; shuffle_distance /= 2)
// partial_result = Reducer(
// partial_result,
// __shfl_down(partial_result, shuffle_distance));
// __shfl_down_sync(CUDA_WARP_ALL, partial_result, shuffle_distance));
// if (lane_id == 0)
// AtomicReducer(&output[y], partial_result);
// }
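The comment above sketches the standard shuffle-based warp reduction that EmitRowReduction generates. As a concrete illustration, here is a minimal hand-written CUDA kernel following the same pattern, specialized to addition (the kernel and its names are illustrative, not part of this PR; assumes CUDA 9.0+ and a block size that is a multiple of 32):

#define CUDA_WARP_ALL 0xFFFFFFFF  // mask naming all 32 lanes of a warp

// Each warp sums 32 consecutive floats; lane 0 adds the partial sum to *out.
__global__ void WarpSumKernel(const float* in, float* out, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  float partial_result = (idx < n) ? in[idx] : 0.0f;
  // Tree reduction: halve the shuffle distance each step; after the loop,
  // lane 0 holds the sum of the warp's 32 values.
  for (int shuffle_distance = 16; shuffle_distance > 0; shuffle_distance /= 2) {
    partial_result +=
        __shfl_down_sync(CUDA_WARP_ALL, partial_result, shuffle_distance);
  }
  if ((threadIdx.x & 31) == 0) {  // lane_id == 0
    atomicAdd(out, partial_result);
  }
}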

View File

@ -71,7 +71,17 @@ const int kDefaultInlineThreshold = 1100;
// Gets the libdevice filename for a particular compute capability. When
// presented with a GPU we don't recognize, we just return the libdevice from
// compute_20.
static string GetLibdeviceFilename(std::pair<int, int> compute_capability) {
static string GetLibdeviceFilename(const string& libdevice_dir_path,
std::pair<int, int> compute_capability) {
// Since CUDA 9.0, all GPU versions are included in a single file.
const char* unified_libdevice_filename = "libdevice.10.bc";
std::vector<string> unified_libdevice_files;
tensorflow::Env::Default()->GetMatchingPaths(
tensorflow::io::JoinPath(libdevice_dir_path, unified_libdevice_filename),
&unified_libdevice_files);
if (unified_libdevice_files.size() == 1) {
return unified_libdevice_filename;
}
// There are only four libdevice files: compute_{20,30,35,50}. Each GPU
// version gets mapped to one of these. Note in particular that sm_60 and
// sm_61 map to libdevice.compute_30.
@ -101,7 +111,7 @@ static string GetLibdeviceFilename(std::pair<int, int> compute_capability) {
}
// Gets the GPU name as it's known to LLVM for a given compute capability. If
// we see an unrecognized compute capability, we return "sm_20".
// we see an unrecognized compute capability, we return "sm_30".
static string GetSmName(std::pair<int, int> compute_capability) {
static auto* m = new std::map<std::pair<int, int>, int>({{{2, 0}, 20},
{{2, 1}, 21},
@ -114,8 +124,10 @@ static string GetSmName(std::pair<int, int> compute_capability) {
{{5, 3}, 53},
{{6, 0}, 60},
{{6, 1}, 61},
{{6, 2}, 62}});
int sm_version = 20;
{{6, 2}, 62},
// TODO: Change this to 70 once LLVM NVPTX supports it.
{{7, 0}, 60}});
int sm_version = 30;
auto it = m->find(compute_capability);
if (it != m->end()) {
sm_version = it->second;
@ -306,7 +318,8 @@ tensorflow::Status LinkLibdeviceIfNecessary(
llvm::Linker linker(*module);
string libdevice_path = tensorflow::io::JoinPath(
libdevice_dir_path, GetLibdeviceFilename(compute_capability));
libdevice_dir_path, GetLibdeviceFilename(libdevice_dir_path,
compute_capability));
TF_RETURN_IF_ERROR(tensorflow::Env::Default()->FileExists(libdevice_path));
VLOG(1) << "Linking with libdevice from: " << libdevice_path;
std::unique_ptr<llvm::Module> libdevice_module =
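For context on the probe above: CUDA 9.0 installs a single unified libdevice.10.bc, whereas earlier toolkits install per-architecture files named libdevice.compute_XX.10.bc, and that is the distinction GetLibdeviceFilename now makes. A minimal standalone sketch of the same decision, using std::filesystem in place of tensorflow::Env (an illustrative substitution; the function name is hypothetical):

#include <filesystem>
#include <string>

// Prefer the unified CUDA 9.0 libdevice file if the directory contains one;
// otherwise fall back to the per-architecture naming of older toolkits.
std::string ChooseLibdeviceFilename(const std::string& libdevice_dir) {
  namespace fs = std::filesystem;
  const char* unified = "libdevice.10.bc";  // single file since CUDA 9.0
  if (fs::exists(fs::path(libdevice_dir) / unified)) {
    return unified;
  }
  // Older toolkits ship only libdevice.compute_{20,30,35,50}.10.bc; the real
  // code maps the compute capability onto one of those, with compute_20 as
  // the conservative default for unrecognized GPUs.
  return "libdevice.compute_20.10.bc";
}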

View File

@ -2936,7 +2936,7 @@ tf_kernel_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:nn_ops_op_lib",
],
] + if_cuda(["@cub_archive//:cub"]),
)
tf_kernel_library(

View File

@ -173,15 +173,20 @@ __global__ void BiasGradNCHW_SharedAtomics(const T* output_backprop,
// Accumulate the results in the shared memory into the first element.
// No syncthreads is needed since this is only in the same warp.
int32 thread_index = threadIdx.x;
if (thread_index < 16) s_data[thread_index] += s_data[thread_index + 16];
if (thread_index < 8) s_data[thread_index] += s_data[thread_index + 8];
if (thread_index < 4) s_data[thread_index] += s_data[thread_index + 4];
if (thread_index < 2) s_data[thread_index] += s_data[thread_index + 2];
if (thread_index < 1) s_data[thread_index] += s_data[thread_index + 1];
// The first thread writes out the accumulated result to the global location.
if (thread_index == 0) {
CudaAtomicAdd(bias_backprop + bias_index, T(s_data[0]));
if (thread_index < 16) {
s_data[thread_index] += s_data[thread_index + 16];
__syncwarp(0xFFFF);
if (thread_index < 8) s_data[thread_index] += s_data[thread_index + 8];
__syncwarp(0xFF);
if (thread_index < 4) s_data[thread_index] += s_data[thread_index + 4];
__syncwarp(0xF);
if (thread_index < 2) s_data[thread_index] += s_data[thread_index + 2];
__syncwarp(0x3);
if (thread_index == 0) {
T val = T(s_data[0] + s_data[1]);
// The first thread writes out the accumulated result to the global location.
CudaAtomicAdd(bias_backprop + bias_index, val);
}
}
}
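The rewritten tail above interleaves a __syncwarp after each step of the shared-memory tree reduction so that one lane's write is visible before another lane reads it, which matters under Volta's independent thread scheduling; the masks narrow (0xFFFF, 0xFF, ...) because only the low lanes remain inside the branch. A self-contained sketch of the same idea that instead keeps all 32 lanes at the barrier and uses the default full mask (hypothetical helper; assumes CUDA 9.0+ or the no-op __syncwarp shim added elsewhere in this PR):

// Sums the 32 floats in s[0..31] into s[0]. Must be called by all 32 lanes
// of one warp, with s pointing at that warp's 32-float shared-memory slice.
__device__ float WarpSharedSum(float* s) {
  int lane = threadIdx.x & 31;
  for (int offset = 16; offset > 0; offset /= 2) {
    if (lane < offset) s[lane] += s[lane + offset];
    __syncwarp();  // order this round's writes before the next round's reads
  }
  return s[0];
}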

View File

@ -139,7 +139,7 @@ struct scalar_left : private Binary {
typedef Tout result_type;
const Tin* left;
EIGEN_DEVICE_FUNC inline scalar_left(const scalar_left& other) = default;
inline scalar_left(const scalar_left& other) = default;
template <typename... Args>
EIGEN_DEVICE_FUNC inline explicit scalar_left(const Tin* c, Args... args)
@ -169,7 +169,7 @@ struct scalar_right : private Binary {
typedef Tout result_type;
const Tin* right;
EIGEN_DEVICE_FUNC inline scalar_right(const scalar_right& other) = default;
inline scalar_right(const scalar_right& other) = default;
template <typename... Args>
EIGEN_DEVICE_FUNC inline explicit scalar_right(const Tin* c, Args... args)

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/tensor_format.h"
#include "external/cub_archive/cub/util_ptx.cuh"
#if !defined(_MSC_VER)
#define UNROLL _Pragma("unroll")
@ -1015,6 +1016,21 @@ __global__ void __launch_bounds__(640, 2)
}
}
// Device function to compute sub-warp sum reduction for a power-of-two group of
// neighboring threads.
template <int kWidth, typename T>
__device__ __forceinline__ T WarpSumReduce(T val) {
// Support only power-of-two widths.
assert(__popc(kWidth) == 1);
int sub_warp = cub::LaneId() / kWidth;
int zeros = sub_warp * kWidth;
unsigned mask = ((1U << kWidth) - 1) << zeros;
for (int delta = kWidth / 2; delta > 0; delta /= 2) {
val += CudaShuffleXor(mask, val, delta);
}
return val;
}
// CUDA kernel to compute the depthwise convolution backward w.r.t. filter in
// NHWC format, tailored for small images up to 32x32. Stride and depth
// multiplier must be 1. Padding must be 'SAME'. Only use this kernel if
@ -1127,6 +1143,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
// Note: the condition to reach this is uniform across the entire block.
__syncthreads();
unsigned active_threads = CudaBallot(CUDA_WARP_ALL, depth_in_range);
if (depth_in_range) {
const T* const out_ptr = inout_offset + output;
@ -1140,7 +1157,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
// Warp-accumulate pixels of the same depth and write to accumulator.
for (int delta = 16; delta >= kBlockSlices; delta /= 2) {
val += CudaShuffleDown(val, delta);
val += CudaShuffleDown(active_threads, val, delta);
}
if (!(thread_idx & 32 - kBlockSlices) /* lane_idx < kBlockSlices */) {
*accum_ptr = val;
@ -1164,9 +1181,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
if (filter_depth < in_depth) {
T val = accum_data[i];
// Warp-accumulate the pixels of the same depth from the accumulator.
for (int delta = kAccumPixels / 2; delta > 0; delta /= 2) {
val += CudaShuffleDown(val, delta);
}
val = WarpSumReduce<kAccumPixels>(val);
if (!(thread_idx & kAccumPixels - 1)) {
CudaAtomicAdd(filter_offset + filter, val);
}
@ -1382,6 +1397,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
// Note: the condition to reach this is uniform across the entire block.
__syncthreads();
unsigned active_threads = CudaBallot(CUDA_WARP_ALL, slice_in_range);
if (slice_in_range) {
const T* const out_ptr = inout_offset + output;
@ -1395,7 +1411,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
// Warp-accumulate pixels of the same depth and write to accumulator.
for (int delta = 16 / kBlockSlices; delta > 0; delta /= 2) {
val += CudaShuffleDown(val, delta);
val += CudaShuffleDown(active_threads, val, delta);
}
if (!(thread_idx & 32 / kBlockSlices - 1)) {
*accum_ptr = val;
@ -1419,9 +1435,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
if (filter_depth < in_depth) {
T val = accum_data[i];
// Warp-accumulate pixels of the same depth from the accumulator.
for (int delta = kAccumPixels / 2; delta > 0; delta /= 2) {
val += CudaShuffleDown(val, delta);
}
val = WarpSumReduce<kAccumPixels>(val);
if (!(thread_idx & kAccumPixels - 1)) {
CudaAtomicAdd(filter_offset + filter, val);
}
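WarpSumReduce above restricts a butterfly (XOR) shuffle reduction to each aligned power-of-two group of kWidth neighboring lanes, deriving the participation mask from the lane id. The same idea in self-contained form, written against the raw CUDA 9 intrinsic instead of the CudaShuffleXor wrapper (illustrative name; assumes all lanes of each group call it together):

// Sums 'val' across each aligned group of kWidth neighboring lanes; every
// lane of the group ends up holding the group's total.
template <int kWidth>
__device__ float SubWarpSum(float val) {
  static_assert(kWidth > 0 && (kWidth & (kWidth - 1)) == 0,
                "kWidth must be a power of two");
  int lane = threadIdx.x & 31;
  // Bits covering exactly the kWidth lanes of this thread's sub-group.
  unsigned mask = (kWidth == 32 ? 0xFFFFFFFFu : (1u << kWidth) - 1)
                  << (lane & ~(kWidth - 1));
  for (int delta = kWidth / 2; delta > 0; delta /= 2) {
    val += __shfl_xor_sync(mask, val, delta);  // butterfly exchange
  }
  return val;
}

Unlike a shuffle-down reduction, the XOR butterfly leaves the total in every lane of the group, which is why the kernels above can follow WarpSumReduce with a simple lane test before the atomic add.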

View File

@ -27,7 +27,7 @@ TEST(CudaLibdevicePathTest, LibdevicePath) {
VLOG(2) << "Libdevice root = " << LibdeviceRoot();
std::vector<string> libdevice_files;
TF_EXPECT_OK(Env::Default()->GetMatchingPaths(
io::JoinPath(LibdeviceRoot(), "libdevice.compute_*.bc"),
io::JoinPath(LibdeviceRoot(), "libdevice.*.bc"),
&libdevice_files));
EXPECT_LT(0, libdevice_files.size());
}

View File

@ -25,6 +25,29 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
#include "cuda/include/cuda.h"
// Mask for all 32 threads in a warp.
#define CUDA_WARP_ALL 0xFFFFFFFF
#if defined(CUDA_VERSION) && CUDA_VERSION < 9000
// CUDA 9.0 introduces a new, lightweight barrier synchronization primitive
// that operates at warp scope. It is required to ensure visibility of
// reads/writes among threads that can make independent progress on Volta.
// For previous CUDA versions these synchronizations are not necessary, and
// we define an empty function as a convenience for backward compatibility.
__device__ inline void __syncwarp(unsigned mask = CUDA_WARP_ALL) {}
// CUDA 9.0 deprecates the warp-intrinsic functions (shfl, ballot, etc.) in
// favor of synchronizing versions. These ensure that all warp lanes specified
// in the mask execute the intrinsic in convergence. Here we map the new
// *_sync names to the less verbose routines provided by previous versions of
// CUDA.
#define __ballot_sync(mask, predicate) __ballot(predicate)
#define __shfl_sync(mask, val, srcLane, width) __shfl(val, srcLane, width)
#define __shfl_down_sync(mask, val, delta, width) __shfl_down(val, delta, width)
#define __shfl_up_sync(mask, val, delta, width) __shfl_up(val, delta, width)
#define __shfl_xor_sync(mask, val, laneMask, width) __shfl_xor(val, laneMask, width)
#endif
// Usage of GetCudaLaunchConfig, GetCuda2DLaunchConfig, and
// GetCuda3DLaunchConfig:
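The effect of the CUDA_VERSION < 9000 block above is that device code can be written once against the CUDA 9 names: on older toolkits the macros drop the mask and fall back to __ballot / __shfl, and __syncwarp() resolves to the empty inline function. A small hedged example of code relying on that shim (the helper and its name are illustrative; assumes all 32 lanes of the warp call it together):

// Returns, in every lane, the 'val' held by the lowest-numbered lane whose
// 'pred' is true, or the caller's own 'val' if no lane's pred is true.
__device__ float BroadcastFirstMatch(float val, bool pred) {
  unsigned votes = __ballot_sync(CUDA_WARP_ALL, pred);  // -> __ballot(pred)
  if (votes == 0) return val;                           // uniform across warp
  int src_lane = __ffs(votes) - 1;                      // lowest voting lane
  return __shfl_sync(CUDA_WARP_ALL, val, src_lane, warpSize);  // -> __shfl
}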
@ -603,82 +626,95 @@ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T tf_max(const T& x, const T& y) {
return x < y ? y : x;
}
__device__ EIGEN_ALWAYS_INLINE unsigned CudaBallot(unsigned mask,
int predicate) {
return __ballot_sync(mask, predicate);
}
template <typename T>
__device__ EIGEN_ALWAYS_INLINE T CudaShuffle(T value, int srcLane,
__device__ EIGEN_ALWAYS_INLINE T CudaShuffle(unsigned mask, T value,
int srcLane,
int width = warpSize) {
return __shfl(value, srcLane, width);
return __shfl_sync(mask, value, srcLane, width);
}
// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
// TODO(csigg): remove when the bug is fixed in the next CUDA release.
__device__ EIGEN_ALWAYS_INLINE double CudaShuffle(double value, int srcLane,
__device__ EIGEN_ALWAYS_INLINE double CudaShuffle(unsigned mask,
double value, int srcLane,
int width = warpSize) {
unsigned lo, hi;
asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
hi = __shfl(hi, srcLane, width);
lo = __shfl(lo, srcLane, width);
hi = __shfl_sync(mask, hi, srcLane, width);
lo = __shfl_sync(mask, lo, srcLane, width);
asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
return value;
}
template <typename T>
__device__ EIGEN_ALWAYS_INLINE T CudaShuffleUp(T value, int delta,
__device__ EIGEN_ALWAYS_INLINE T CudaShuffleUp(unsigned mask,
T value, int delta,
int width = warpSize) {
return __shfl_up(value, delta, width);
return __shfl_up_sync(mask, value, delta, width);
}
// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
// TODO(csigg): remove when the bug is fixed in the next CUDA release.
__device__ EIGEN_ALWAYS_INLINE double CudaShuffleUp(double value, int delta,
__device__ EIGEN_ALWAYS_INLINE double CudaShuffleUp(unsigned mask,
double value, int delta,
int width = warpSize) {
unsigned lo, hi;
asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
hi = __shfl_up(hi, delta, width);
lo = __shfl_up(lo, delta, width);
hi = __shfl_up_sync(mask, hi, delta, width);
lo = __shfl_up_sync(mask, lo, delta, width);
asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
return value;
}
template <typename T>
__device__ EIGEN_ALWAYS_INLINE T CudaShuffleDown(T value, int delta,
__device__ EIGEN_ALWAYS_INLINE T CudaShuffleDown(unsigned mask,
T value, int delta,
int width = warpSize) {
return __shfl_down(value, delta, width);
return __shfl_down_sync(mask, value, delta, width);
}
// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
// TODO(csigg): remove when the bug is fixed in the next CUDA release.
__device__ EIGEN_ALWAYS_INLINE double CudaShuffleDown(double value, int delta,
__device__ EIGEN_ALWAYS_INLINE double CudaShuffleDown(unsigned mask,
double value, int delta,
int width = warpSize) {
unsigned lo, hi;
asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
hi = __shfl_down(hi, delta, width);
lo = __shfl_down(lo, delta, width);
hi = __shfl_down_sync(mask, hi, delta, width);
lo = __shfl_down_sync(mask, lo, delta, width);
asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
return value;
}
template <typename T>
__device__ EIGEN_ALWAYS_INLINE T CudaShuffleXor(T value, int laneMask,
__device__ EIGEN_ALWAYS_INLINE T CudaShuffleXor(unsigned mask,
T value, int laneMask,
int width = warpSize) {
return __shfl_xor(value, laneMask, width);
return __shfl_xor_sync(mask, value, laneMask, width);
}
// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
// TODO(csigg): remove when the bug is fixed in the next CUDA release.
__device__ EIGEN_ALWAYS_INLINE double CudaShuffleXor(double value, int laneMask,
__device__ EIGEN_ALWAYS_INLINE double CudaShuffleXor(unsigned mask,
double value, int laneMask,
int width = warpSize) {
unsigned lo, hi;
asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
hi = __shfl_xor(hi, laneMask, width);
lo = __shfl_xor(lo, laneMask, width);
hi = __shfl_xor_sync(mask, hi, laneMask, width);
lo = __shfl_xor_sync(mask, lo, laneMask, width);
asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
return value;
}
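The depthwise-convolution kernels in this PR call these wrappers in a two-step pattern: CudaBallot collects the mask of lanes that will enter a divergent branch while the warp is still converged, and that mask is then passed to every CudaShuffleDown inside the branch so the participating lanes execute the shuffle together. A reduced sketch of that calling convention (hypothetical function; like the kernels above, it assumes the caller only consumes results from lanes whose shuffled-in sources were all in range):

__device__ void AccumulateIfInRange(float val, bool in_range, float* accum) {
  // Mask of lanes that will take the branch; the whole warp executes this.
  unsigned active = CudaBallot(CUDA_WARP_ALL, in_range);
  if (in_range) {
    // Warp-accumulate within the active lanes. Lanes near the top of the
    // active range may read undefined values from inactive source lanes,
    // so only the low lane's result is written out, as in the kernels above.
    for (int delta = 16; delta > 0; delta /= 2) {
      val += CudaShuffleDown(active, val, delta);
    }
    if ((threadIdx.x & 31) == 0) {
      CudaAtomicAdd(accum, val);
    }
  }
}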

View File

@ -627,11 +627,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
temp_workaround_http_archive(
name = "nccl_archive",
urls = [
"http://mirror.bazel.build/github.com/nvidia/nccl/archive/ccfc4567dc3e2a37fb42cfbc64d10eb526e7da7b.tar.gz",
"https://github.com/nvidia/nccl/archive/ccfc4567dc3e2a37fb42cfbc64d10eb526e7da7b.tar.gz",
"http://mirror.bazel.build/github.com/nvidia/nccl/archive/29a1a916dc14bb2c00feed3d4820d51fa85be1e6.tar.gz",
"https://github.com/nvidia/nccl/archive/29a1a916dc14bb2c00feed3d4820d51fa85be1e6.tar.gz",
],
sha256 = "6c34a0862d9f8ed4ad5984c6a8206b351957bb14cf6ad7822720f285f4aada04",
strip_prefix = "nccl-ccfc4567dc3e2a37fb42cfbc64d10eb526e7da7b",
sha256 = "6387030e37d14762f87eefbc86ee527293ec04745c66ccd820cf7fc0fdc23f92",
strip_prefix = "nccl-29a1a916dc14bb2c00feed3d4820d51fa85be1e6",
build_file = str(Label("//third_party:nccl.BUILD")),
repository = tf_repo_name,
)