Automated g4 rollback of changelist 177799252

PiperOrigin-RevId: 177989542
A. Unique TensorFlower 2017-12-05 11:57:53 -08:00 committed by TensorFlower Gardener
parent 33e3da538a
commit 21e831dc4a
10 changed files with 782 additions and 861 deletions

View File

@@ -34,9 +34,9 @@ namespace functor {
__global__ void ReduceSliceDeviceKernel##reduceop( \
Cuda3DLaunchConfig config, Index indices_width, Index bound, \
const T begin, const Index *indices, const T *input, T *out) { \
CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count.x, X) { \
CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count.y, Y) { \
CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count.z, Z) { \
CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) { \
CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) { \
CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) { \
Index outidx = x * config.virtual_thread_count.y * \
config.virtual_thread_count.z + \
y * config.virtual_thread_count.z + z; \
@@ -68,9 +68,8 @@ namespace functor {
if (sizex * sizey * sizez == 0) { \
return; \
} \
Cuda3DLaunchConfig config = GetCuda3DLaunchConfig( \
sizex, sizey, sizez, d, ReduceSliceDeviceKernel##reduceop<T, Index>, \
0, 0); \
Cuda3DLaunchConfig config = GetCuda3DLaunchConfig(sizex, sizey, sizez, d,\
ReduceSliceDeviceKernel##reduceop<T, Index>, 0, 0); \
\
ReduceSliceDeviceKernel##reduceop<T, Index> \
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>( \

View File

@@ -1847,13 +1847,6 @@ cc_library(
],
)
tf_cuda_library(
name = "cuda_device_functions",
hdrs = ["util/cuda_device_functions.h"],
visibility = ["//visibility:public"],
deps = [":framework_lite"],
)
# TODO(josh11b): Is this needed, or can we just use ":protos_all_cc"?
cc_library(
name = "protos_cc",

View File

@@ -173,13 +173,19 @@ __global__ void BiasGradNCHW_SharedAtomics(const T* output_backprop,
// Accumulate the results in the shared memory into the first element.
// No syncthreads is needed since this is only in the same warp.
int32 thread_index = threadIdx.x;
if (thread_index < 32) {
AccT data = s_data[thread_index];
for (int32 offset = warpSize / 2; offset > 0; offset /= 2) {
data += CudaShuffleDownSync(kCudaWarpAll, data, offset);
}
if (thread_index < 16) {
s_data[thread_index] += s_data[thread_index + 16];
__syncwarp(0xFFFF);
if (thread_index < 8) s_data[thread_index] += s_data[thread_index + 8];
__syncwarp(0xFF);
if (thread_index < 4) s_data[thread_index] += s_data[thread_index + 4];
__syncwarp(0xF);
if (thread_index < 2) s_data[thread_index] += s_data[thread_index + 2];
__syncwarp(0x3);
if (thread_index == 0) {
CudaAtomicAdd(bias_backprop + bias_index, T(data));
T val = T(s_data[0] + s_data[1]);
// The first thread writes out the accumulated result to the global location.
CudaAtomicAdd(bias_backprop + bias_index, val);
}
}
}

View File

@@ -34,7 +34,6 @@ limitations under the License.
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
using Eigen::GpuDevice;
// Returns whether depthwise convolution forward or backward input pass can be
@@ -1029,7 +1028,7 @@ __device__ __forceinline__ T WarpSumReduce(T val) {
int zeros = sub_warp * kWidth;
unsigned mask = ((1UL << kWidth) - 1) << zeros;
for (int delta = kWidth / 2; delta > 0; delta /= 2) {
val += CudaShuffleXorSync(mask, val, delta);
val += CudaShuffleXor(mask, val, delta);
}
return val;
}
@@ -1146,7 +1145,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
// Note: the condition to reach this is uniform across the entire block.
__syncthreads();
unsigned active_threads = CudaBallotSync(kCudaWarpAll, depth_in_range);
unsigned active_threads = CudaBallot(CUDA_WARP_ALL, depth_in_range);
if (depth_in_range) {
const T* const out_ptr = inout_offset + output;
@@ -1160,7 +1159,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
// Warp-accumulate pixels of the same depth and write to accumulator.
for (int delta = 16; delta >= kBlockSlices; delta /= 2) {
val += CudaShuffleDownSync(active_threads, val, delta);
val += CudaShuffleDown(active_threads, val, delta);
}
if (!(thread_idx & 32 - kBlockSlices) /* lane_idx < kBlockSlices */) {
*accum_ptr = val;
@@ -1400,7 +1399,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
// Note: the condition to reach this is uniform across the entire block.
__syncthreads();
unsigned active_threads = CudaBallotSync(kCudaWarpAll, slice_in_range);
unsigned active_threads = CudaBallot(CUDA_WARP_ALL, slice_in_range);
if (slice_in_range) {
const T* const out_ptr = inout_offset + output;
@@ -1414,7 +1413,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
// Warp-accumulate pixels of the same depth and write to accumulator.
for (int delta = 16 / kBlockSlices; delta > 0; delta /= 2) {
val += CudaShuffleDownSync(active_threads, val, delta);
val += CudaShuffleDown(active_threads, val, delta);
}
if (!(thread_idx & 32 / kBlockSlices - 1)) {
*accum_ptr = val;

View File

@@ -55,27 +55,6 @@ struct LeftUpdate<T, scatter_nd_op::UpdateOp::SUB> {
}
};
// Specializations for std::complex, updating real and imaginary part
// individually. Even though this is not an atomic op anymore, it is safe
// because there is only one type of op per kernel.
template <typename T>
struct LeftUpdate<std::complex<T>, scatter_nd_op::UpdateOp::ADD> {
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(
std::complex<T>* out, const std::complex<T>& val) {
T* ptr = reinterpret_cast<T*>(out);
CudaAtomicAdd(ptr, val.real());
CudaAtomicAdd(ptr, val.imag());
}
};
template <typename T>
struct LeftUpdate<std::complex<T>, scatter_nd_op::UpdateOp::SUB> {
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(
std::complex<T>* out, const std::complex<T>& val) {
LeftUpdate<std::complex<T>, scatter_nd_op::UpdateOp::ADD>()(out, -val);
}
};
} // namespace
template <typename T, typename Index, scatter_nd_op::UpdateOp op, int IXDIM>

View File

@@ -63,8 +63,8 @@ __global__ void ComputeValueOfVKernel(Cuda2DLaunchConfig config, int64 m,
int64 ldu, const Scalar* M,
const Scalar* U, const Scalar* S,
Scalar* V) {
CUDA_AXIS_KERNEL_LOOP(batch, config.virtual_thread_count.x, X) {
CUDA_AXIS_KERNEL_LOOP(i, config.virtual_thread_count.y, Y) {
CUDA_AXIS_KERNEL_LOOP(batch, config.virtual_thread_count, x) {
CUDA_AXIS_KERNEL_LOOP(i, config.virtual_thread_count, y) {
Scalar v = M[i + m * batch] * U[ldu * (i + m * batch)] * S[batch];
CudaAtomicAdd(V + batch, v);
}

View File

@@ -1,418 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_UTIL_CUDA_DEVICE_FUNCTIONS_H_
#define TENSORFLOW_CORE_UTIL_CUDA_DEVICE_FUNCTIONS_H_
/**
* Wrappers and helpers for CUDA device code.
*
* Wraps the warp-cooperative intrinsics introduced in CUDA 9 to provide
* backwards compatibility, see go/volta-porting for details.
* Provides atomic operations on types that aren't natively supported.
*/
#if GOOGLE_CUDA
#include <algorithm>
#include <complex>
#include "cuda/include/cuda.h"
#include "cuda/include/device_functions.h"
#include "tensorflow/core/platform/types.h"
#if __CUDACC_VER_MAJOR__ >= 9
#include "cuda/include/cuda_fp16.h"
#elif __CUDACC_VER__ >= 7050
#include "cuda/include/cuda_fp16.h"
#else
#endif
namespace tensorflow {
namespace detail {
// Helper for range-based for loop using 'delta' increments.
// Usage: see CudaGridRange?() functions below.
template <typename T>
class CudaGridRange {
struct Iterator {
__device__ Iterator(T index, T delta) : index_(index), delta_(delta) {}
__device__ T operator*() const { return index_; }
__device__ Iterator& operator++() {
index_ += delta_;
return *this;
}
__device__ bool operator!=(const Iterator& other) const {
bool greater = index_ > other.index_;
bool less = index_ < other.index_;
// Anything past an end iterator (delta_ == 0) is equal.
// In range-based for loops, this optimizes to 'return less'.
if (!other.delta_) {
return less;
}
if (!delta_) {
return greater;
}
return less || greater;
}
private:
T index_;
const T delta_;
};
public:
__device__ CudaGridRange(T begin, T delta, T end)
: begin_(begin), delta_(delta), end_(end) {}
__device__ Iterator begin() const { return Iterator{begin_, delta_}; }
__device__ Iterator end() const { return Iterator{end_, 0}; }
private:
T begin_;
T delta_;
T end_;
};
} // namespace detail
// Helper to visit indices in the range 0 <= i < count, using the x-coordinate
// of the global thread index. That is, each index i is visited by all threads
// with the same x-coordinate.
// Usage: for(int i : CudaGridRangeX(count)) { visit(i); }
template <typename T>
__device__ detail::CudaGridRange<T> CudaGridRangeX(T count) {
return detail::CudaGridRange<T>(blockIdx.x * blockDim.x + threadIdx.x,
gridDim.x * blockDim.x, count);
}
// Helper to visit indices in the range 0 <= i < count using the y-coordinate.
// Usage: for(int i : CudaGridRangeY(count)) { visit(i); }
template <typename T>
__device__ detail::CudaGridRange<T> CudaGridRangeY(T count) {
return detail::CudaGridRange<T>(blockIdx.y * blockDim.y + threadIdx.y,
gridDim.y * blockDim.y, count);
}
// Helper to visit indices in the range 0 <= i < count using the z-coordinate.
// Usage: for(int i : CudaGridRangeZ(count)) { visit(i); }
template <typename T>
__device__ detail::CudaGridRange<T> CudaGridRangeZ(T count) {
return detail::CudaGridRange<T>(blockIdx.z * blockDim.z + threadIdx.z,
gridDim.z * blockDim.z, count);
}
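For orientation, a minimal sketch of a kernel written against these helpers; ScaleMatrixExample is a hypothetical kernel, not part of this header, and each (row, col) element is visited exactly once across the grid-stride loops.
template <typename T>
__global__ void ScaleMatrixExample(int rows, int cols, T alpha, T* data) {
  // Grid-stride loops: threads sharing a y-coordinate cover the same rows,
  // threads sharing an x-coordinate cover the same columns.
  for (int row : CudaGridRangeY(rows)) {
    for (int col : CudaGridRangeX(cols)) {
      data[row * cols + col] *= alpha;
    }
  }
}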
// Mask for all 32 threads in a warp.
const unsigned kCudaWarpAll = 0xffffffff;
// On sm_6x and earlier, verifies that all bits in mask corresponding to active
// threads of the warp are set. It does not verify the converse (bits of
// inactive threads are not set), because all syncs are unblocked when a thread
// exits the kernel, but the ballot of inactive (including exited) threads
// returns 0.
__device__ inline void CudaVerifySyncMask(unsigned mask) {
#if __CUDA_ARCH__ < 700
assert(0 == (__ballot(1) & ~mask)); // Active threads must have mask bit set.
#endif
}
// For all *_sync wrappers below, it is illegal to synchronize threads from
// different program locations, because that is not supported before sm_70.
// Code that requires sm_70 (and CUDA 9) may use the intrinsic directly.
// Wrapper for __syncwarp.
__device__ inline void CudaSyncWarp(unsigned mask = kCudaWarpAll) {
CudaVerifySyncMask(mask);
#if CUDA_VERSION >= 9000
__syncwarp(mask);
#endif
}
// Wrapper for __ballot_sync.
__device__ inline unsigned CudaBallotSync(unsigned mask, int pred) {
CudaVerifySyncMask(mask);
#if CUDA_VERSION >= 9000
return __ballot_sync(mask, pred);
#else
return __ballot(pred);
#endif
}
// Wrapper for __any_sync.
__device__ inline int CudaAnySync(unsigned mask, int pred) {
CudaVerifySyncMask(mask);
#if CUDA_VERSION >= 9000
return __any_sync(mask, pred);
#else
return __any(pred);
#endif
}
// Wrapper for __all_sync.
__device__ inline int CudaAllSync(unsigned mask, int pred) {
CudaVerifySyncMask(mask);
#if CUDA_VERSION >= 9000
return __all_sync(mask, pred);
#else
return __all(pred);
#endif
}
// Wrapper for __shfl_sync.
template <typename T>
__device__ T CudaShuffleSync(unsigned mask, T value, int src_lane,
int width = warpSize) {
CudaVerifySyncMask(mask);
#if CUDA_VERSION >= 9000
return __shfl_sync(mask, value, src_lane, width);
#else
return __shfl(value, src_lane, width);
#endif
}
// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// See b/69446944.
__device__ inline double CudaShuffleSync(unsigned mask, double value,
int src_lane, int width = warpSize) {
unsigned lo, hi;
asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
hi = CudaShuffleSync(mask, hi, src_lane, width);
lo = CudaShuffleSync(mask, lo, src_lane, width);
asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
return value;
}
// Wrapper for __shfl_up_sync.
template <typename T>
__device__ inline T CudaShuffleUpSync(unsigned mask, T value, int delta,
int width = warpSize) {
CudaVerifySyncMask(mask);
#if CUDA_VERSION >= 9000
return __shfl_up_sync(mask, value, delta, width);
#else
return __shfl_up(value, delta, width);
#endif
}
// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// See b/69446944.
__device__ inline double CudaShuffleUpSync(unsigned mask, double value,
int delta, int width = warpSize) {
unsigned lo, hi;
asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
hi = CudaShuffleUpSync(mask, hi, delta, width);
lo = CudaShuffleUpSync(mask, lo, delta, width);
asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
return value;
}
// Wrapper for __shfl_down_sync.
template <typename T>
__device__ inline T CudaShuffleDownSync(unsigned mask, T value, int delta,
int width = warpSize) {
CudaVerifySyncMask(mask);
#if CUDA_VERSION >= 9000
return __shfl_down_sync(mask, value, delta, width);
#else
return __shfl_down(value, delta, width);
#endif
}
// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// See b/69446944.
__device__ inline double CudaShuffleDownSync(unsigned mask, double value,
int delta, int width = warpSize) {
unsigned lo, hi;
asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
hi = CudaShuffleDownSync(mask, hi, delta, width);
lo = CudaShuffleDownSync(mask, lo, delta, width);
asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
return value;
}
// Wrapper for __shfl_xor_sync.
template <typename T>
__device__ T CudaShuffleXorSync(unsigned mask, T value, int lane_mask,
int width = warpSize) {
CudaVerifySyncMask(mask);
#if CUDA_VERSION >= 9000
return __shfl_xor_sync(mask, value, lane_mask, width);
#else
return __shfl_xor(value, lane_mask, width);
#endif
}
// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// See b/69446944.
__device__ inline double CudaShuffleXorSync(unsigned mask, double value,
int lane_mask,
int width = warpSize) {
unsigned lo, hi;
asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
hi = CudaShuffleXorSync(mask, hi, lane_mask, width);
lo = CudaShuffleXorSync(mask, lo, lane_mask, width);
asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
return value;
}
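A minimal sketch of the intended usage of the wrappers above: a full-warp sum reduction in the same style as the bias-gradient kernel in this change, where every lane participates and the full-warp mask is therefore passed to each shuffle. FullWarpSumExample is a hypothetical helper.
template <typename T>
__device__ T FullWarpSumExample(T value) {
  // All 32 lanes take part, so kCudaWarpAll is a valid mask for every shuffle.
  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
    value += CudaShuffleDownSync(kCudaWarpAll, value, offset);
  }
  return value;  // Lane 0 now holds the sum over the whole warp.
}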
// Wrapper for __ldg.
template <typename T>
__host__ __device__ T CudaLdg(const T* address) {
#if __CUDA_ARCH__ >= 350
return __ldg(address);
#else
return *address;
#endif
}
__host__ __device__ inline bool CudaLdg(const bool* address) {
return CudaLdg(reinterpret_cast<const char*>(address)) != 0;
}
__host__ __device__ inline std::complex<float> CudaLdg(
const std::complex<float>* address) {
#if __CUDA_ARCH__ >= 350
float2 mem = __ldg(reinterpret_cast<const float2*>(address));
return std::complex<float>(mem.x, mem.y);
#else
return *address;
#endif
}
__host__ __device__ inline std::complex<double> CudaLdg(
const std::complex<double>* address) {
#if __CUDA_ARCH__ >= 350
double2 mem = __ldg(reinterpret_cast<const double2*>(address));
return std::complex<double>(mem.x, mem.y);
#else
return *address;
#endif
}
// Zeroes count elements starting at ptr using all threads of a 1-D grid.
// Note: this function does not synchronize, and therefore the memory range is
// not guaranteed to be zero until the next kernel launch.
template <typename T>
__global__ void SetZero(const int count, T* ptr) {
// Check that the grid is one dimensional and index doesn't overflow.
assert(blockDim.y == 1 && blockDim.z == 1);
assert(blockDim.x * gridDim.x / blockDim.x == gridDim.x);
for (int i : CudaGridRangeX(count)) {
ptr[i] = T(0);
}
}
namespace detail {
// Helper function for atomic accumulation implemented as CAS.
template <typename T, typename F>
__device__ T CudaAtomicCasHelper(T* ptr, F accumulate) {
T old = *ptr;
T assumed;
do {
assumed = old;
old = atomicCAS(ptr, assumed, accumulate(assumed));
} while (assumed != old);
return old;
}
// Overload for floating point (using integer comparison to handle NaN
// correctly).
template <typename F>
__device__ float CudaAtomicCasHelper(float* ptr, F accumulate) {
return __int_as_float(
CudaAtomicCasHelper(reinterpret_cast<int32*>(ptr), [accumulate](int32 a) {
return __float_as_int(accumulate(__int_as_float(a)));
}));
}
template <typename F>
__device__ double CudaAtomicCasHelper(double* ptr, F accumulate) {
return __longlong_as_double(CudaAtomicCasHelper(
reinterpret_cast<tensorflow::uint64*>(ptr),
[accumulate](tensorflow::uint64 a) {
return __double_as_longlong(accumulate(__longlong_as_double(a)));
}));
}
} // namespace detail
// CUDA provides atomic ops, but not for all types. We provide wrappers
// for some ops and provide implementation for all reasonable types.
template <typename T>
__device__ T CudaAtomicAdd(T* ptr, T value) {
return atomicAdd(ptr, value);
}
#if __CUDA_ARCH__ < 600
__device__ inline double CudaAtomicAdd(double* ptr, double value) {
return detail::CudaAtomicCasHelper(ptr,
[value](double a) { return a + value; });
}
#elif __clang__
// Clang cannot compile __nvvm_atom_add_gen_d builtin yet, use inline PTX.
// see https://reviews.llvm.org/D39638
__device__ inline double CudaAtomicAdd(double* ptr, double value) {
double result;
asm volatile("atom.add.f64 %0, [%1], %2;"
: "=d"(result)
: "l"(ptr), "d"(value)
: "memory");
return result;
}
#endif
template <typename T>
__device__ T CudaAtomicSub(T* ptr, T value) {
return atomicSub(ptr, value);
}
// Specializations of subtraction which add the negative value.
__device__ inline float CudaAtomicSub(float* ptr, float value) {
return CudaAtomicAdd(ptr, -value);
}
__device__ inline double CudaAtomicSub(double* ptr, double value) {
return CudaAtomicAdd(ptr, -value);
}
__device__ inline tensorflow::uint64 CudaAtomicSub(tensorflow::uint64* ptr,
tensorflow::uint64 value) {
return CudaAtomicAdd(ptr, -value);
}
template <typename T>
__device__ T CudaAtomicMax(T* ptr, T value) {
return atomicMax(ptr, value);
}
#if __CUDA_ARCH__ < 320
__device__ inline tensorflow::uint64 CudaAtomicMax(tensorflow::uint64* ptr,
tensorflow::uint64 value) {
return detail::CudaAtomicCasHelper(
ptr, [value](tensorflow::uint64 a) { return max(a, value); });
}
#endif
template <typename T>
__device__ inline T CudaAtomicMul(T* ptr, T value) {
return detail::CudaAtomicCasHelper(ptr, [value](T a) { return a * value; });
}
template <typename T>
__device__ inline T CudaAtomicDiv(T* ptr, T value) {
return detail::CudaAtomicCasHelper(ptr, [value](T a) { return a / value; });
}
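A minimal sketch of how the same CAS helper can back an op CUDA does not provide natively; CudaAtomicMinExample is a hypothetical addition written in the same pattern as CudaAtomicMul and CudaAtomicDiv above.
__device__ inline float CudaAtomicMinExample(float* ptr, float value) {
  // The float overload of CudaAtomicCasHelper loops on an integer
  // compare-and-swap, so a NaN at *ptr cannot make the loop spin forever.
  return detail::CudaAtomicCasHelper(
      ptr, [value](float a) { return fminf(a, value); });
}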
} // namespace tensorflow
#endif // GOOGLE_CUDA
#endif // TENSORFLOW_CORE_UTIL_CUDA_DEVICE_FUNCTIONS_H_

View File

@@ -18,125 +18,299 @@ limitations under the License.
#if GOOGLE_CUDA
#include "tensorflow/core/util/cuda_device_functions.h"
#include "tensorflow/core/util/cuda_launch_config.h"
#include <algorithm>
// Deprecated, use 'for(int i : CudaGridRangeX(n))' instead.
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i : ::tensorflow::CudaGridRangeX<int>(n))
// Deprecated, use 'for(int i : CudaGridRange?(n))' instead.
#define CUDA_AXIS_KERNEL_LOOP(i, n, axis) \
for (int i : ::tensorflow::CudaGridRange##axis<int>(n))
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "cuda/include/cuda.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
template <typename T>
__host__ __device__ inline T ldg(const T* ptr) {
return CudaLdg(ptr);
}
// Mask for all 32 threads in a warp.
#define CUDA_WARP_ALL 0xFFFFFFFF
template <typename T>
__host__ __device__ inline const T& tf_min(const T& x, const T& y) {
return x < y ? x : y;
}
#if defined(CUDA_VERSION) && CUDA_VERSION < 9000
// CUDA 9.0 introduces a new, light-weight barrier synchronization primitive
// that operates at the warp-scope. This is required to ensure visibility of
// reads/writes among threads that can make independent progress on Volta.
// For previous CUDA versions these synchronizations are not necessary, and we
// define an empty function as a convenience for backward compatibility.
__device__ inline void __syncwarp(unsigned mask = CUDA_WARP_ALL) {}
template <typename T>
__host__ __device__ inline const T& tf_max(const T& x, const T& y) {
return x < y ? y : x;
}
// CUDA 9.0 deprecates the warp-intrinsic functions (shfl, ballot, etc.) in
// favor of synchronizing versions. These ensure that all warp lanes specified
// in mask execute the intrinsic in convergence. Here we provide legacy mappings
// to the less-verbose routines provided in previous versions of CUDA.
#define __ballot_sync(mask, predicate) __ballot(predicate)
#define __shfl_sync(mask, val, srcLane, width) __shfl(val, srcLane, width)
#define __shfl_down_sync(mask, val, delta, width) __shfl_down(val, delta, width)
#define __shfl_up_sync(mask, val, delta, width) __shfl_up(val, delta, width)
#define __shfl_xor_sync(mask, val, laneMask, width) \
__shfl_xor(val, laneMask, width)
#endif
// Overloads of the above functions for float and double.
__host__ __device__ inline float tf_min(float x, float y) {
return fminf(x, y);
}
__host__ __device__ inline double tf_min(double x, double y) {
return fmin(x, y);
}
__host__ __device__ inline float tf_max(float x, float y) {
return fmaxf(x, y);
}
__host__ __device__ inline double tf_max(double x, double y) {
return fmax(x, y);
}
__device__ inline Eigen::half CudaShuffleSync(unsigned mask, Eigen::half value,
int src_lane,
int width = warpSize) {
return Eigen::half(
CudaShuffleSync(mask, static_cast<uint16>(value), src_lane, width));
}
__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleUpSync(
unsigned mask, Eigen::half value, int delta, int width = warpSize) {
return Eigen::half(
CudaShuffleUpSync(mask, static_cast<uint16>(value), delta, width));
}
__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleDownSync(
unsigned mask, Eigen::half value, int delta, int width = warpSize) {
return Eigen::half(
CudaShuffleDownSync(mask, static_cast<uint16>(value), delta, width));
}
__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleXorSync(
unsigned mask, Eigen::half value, int lane_mask, int width = warpSize) {
return Eigen::half(
CudaShuffleXorSync(mask, static_cast<uint16>(value), lane_mask, width));
}
namespace detail {
// Overload of above function for half. Note that we don't have
// atomicCAS() for anything less than 32 bits, so we need to include the
// other 16 bits in the operation.
// Usage of GetCudaLaunchConfig, GetCuda2DLaunchConfig, and
// GetCuda3DLaunchConfig:
//
// This version is going to be very slow
// under high concurrency, since most threads will be spinning on failing
// their compare-and-swap tests. (The fact that we get false sharing on the
// neighboring fp16 makes this even worse.) If you are doing a large reduction,
// you are much better off with doing the intermediate steps in fp32 and then
// switching to fp16 as late as you can in the calculations.
// There are two versions of GetCudaLaunchConfig and GetCuda2DLaunchConfig, one
// version uses heuristics without any knowledge of the device kernel, the other
// version uses cudaOccupancyMaxPotentialBlockSize to determine the theoretical
// launch parameters that maximize occupancy. Currently, only the maximum
// occupancy version of GetCuda3DLaunchConfig is available.
//
// Note: Assumes little endian.
template <typename F>
__device__ Eigen::half CudaAtomicCasHelper(Eigen::half* ptr, F accumulate) {
namespace half_impl = Eigen::half_impl;
intptr_t intptr = reinterpret_cast<intptr_t>(ptr);
if (intptr & 0x3) {
assert(!(intptr & 0x1));
// The half is in the second part of the uint32 (upper 16 bits).
uint32* address = reinterpret_cast<uint32*>(intptr - 2);
uint32 result = CudaAtomicCasHelper(address, [accumulate](uint32 a) {
Eigen::half acc = accumulate(
half_impl::__half_raw{static_cast<unsigned short>(a >> 16)});
uint32_t upper = static_cast<half_impl::__half_raw>(acc).x;
return (upper << 16) | (a & 0xffff);
});
return half_impl::__half_raw{static_cast<uint16>(result >> 16)};
} else {
// The half is in the first part of the uint32 (lower 16 bits).
uint32* address = reinterpret_cast<uint32*>(intptr);
uint32 result = CudaAtomicCasHelper(address, [accumulate](uint32 a) {
Eigen::half acc = accumulate(
half_impl::__half_raw{static_cast<unsigned short>(a & 0xffff)});
uint32_t lower = static_cast<half_impl::__half_raw>(acc).x;
return (a & 0xffff0000) | lower;
});
return half_impl::__half_raw{static_cast<uint16>(result & 0xffff)};
// For large number of work elements, the convention is that each kernel would
// iterate through its assigned range. The return value of GetCudaLaunchConfig
// is struct CudaLaunchConfig, which contains all the information needed for the
// kernel launch, including: the virtual number of threads, the number of
// threads per block, and the number of blocks used inside <<< >>> of a kernel
// launch. GetCuda2DLaunchConfig and GetCuda3DLaunchConfig do the same thing
// as GetCudaLaunchConfig; the only difference is the dimension. The macros
// CUDA_1D_KERNEL_LOOP and CUDA_AXIS_KERNEL_LOOP can be used for the inner loops.
//
/* Sample code:
__global__ void MyKernel1D(CudaLaunchConfig config, other_args...) {
CUDA_1D_KERNEL_LOOP(x, config.virtual_thread_count) {
do_your_job_here;
}
}
} // namespace detail
__device__ inline Eigen::half CudaAtomicAdd(Eigen::half* ptr,
Eigen::half value) {
return detail::CudaAtomicCasHelper(
ptr, [value](Eigen::half a) { return a + value; });
__global__ void MyKernel2D(Cuda2DLaunchConfig config, other_args...) {
CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
do_your_job_here;
}
}
}
__device__ inline Eigen::half CudaAtomicSub(Eigen::half* ptr,
Eigen::half value) {
return detail::CudaAtomicCasHelper(
ptr, [value](Eigen::half a) { return a - value; });
__global__ void MyKernel3D(Cuda3DLaunchConfig config, other_args...) {
CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) {
do_your_job_here;
}
}
}
}
void MyDriverFunc(const GPUDevice &d) {
// use heuristics
CudaLaunchConfig cfg1 = GetCudaLaunchConfig(10240, d);
MyKernel1D <<<config.block_count,
config.thread_per_block, 0, d.stream()>>> (cfg1, other_args...);
Cuda2DLaunchConfig cfg2 = GetCuda2DLaunchConfig(10240, 10240, d);
MyKernel2D <<<config.block_count,
config.thread_per_block, 0, d.stream()>>> (cfg2, other_args...);
Cuda3DLaunchConfig cfg3 = GetCuda3DLaunchConfig(4096, 4096, 100, d);
MyKernel3D <<<config.block_count,
config.thread_per_block, 0, d.stream()>>> (cfg3, other_args...);
// maximize occupancy
CudaLaunchConfig cfg4 = GetCudaLaunchConfig(10240, d, MyKernel1D, 0, 0 );
MyKernel1D <<<config.block_count,
config.thread_per_block, 0, d.stream()>>> (cfg4, other_args...);
Cuda2DLaunchConfig cfg5 = GetCuda2DLaunchConfig(10240, 10240, d,
MyKernel1D, 0, 0);
MyKernel2D <<<config.block_count,
config.thread_per_block, 0, d.stream()>>> (cfg5, other_args...);
Cuda3DLaunchConfig cfg6 = GetCuda3DLaunchConfig(4096, 4096, 100, d,
MyKernel1D, 0, 0);
MyKernel3D <<<config.block_count,
config.thread_per_block, 0, d.stream()>>> (cfg6, other_args...);
}
// See the test for this for more examples:
//
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/cuda_kernel_helper_test.cu.cc
*/
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
i += blockDim.x * gridDim.x)
#define CUDA_AXIS_KERNEL_LOOP(i, n, axis) \
for (int i = blockIdx.axis * blockDim.axis + threadIdx.axis; i < n.axis; \
i += blockDim.axis * gridDim.axis)
#define DIV_UP(a, b) (((a) + (b)-1) / (b))
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
struct CudaLaunchConfig {
// Logical number of threads that work on the elements. If each logical
// thread works on exactly a single element, this is the same as the working
// element count.
int virtual_thread_count = -1;
// Number of threads per block.
int thread_per_block = -1;
// Number of blocks for Cuda kernel launch.
int block_count = -1;
};
// Calculate the Cuda launch config we should use for a kernel launch.
// This is assuming the kernel is quite simple and will largely be
// memory-limited.
// REQUIRES: work_element_count > 0.
inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
const GPUDevice& d) {
CHECK_GT(work_element_count, 0);
CudaLaunchConfig config;
const int virtual_thread_count = work_element_count;
const int physical_thread_count = std::min(
d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor(),
virtual_thread_count);
const int thread_per_block = std::min(1024, d.maxCudaThreadsPerBlock());
const int block_count =
std::min(DIV_UP(physical_thread_count, thread_per_block),
d.getNumCudaMultiProcessors());
config.virtual_thread_count = virtual_thread_count;
config.thread_per_block = thread_per_block;
config.block_count = block_count;
return config;
}
// Calculate the Cuda launch config we should use for a kernel launch. This
// variant takes the resource limits of func into account to maximize occupancy.
// REQUIRES: work_element_count > 0.
template <typename DeviceFunc>
inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
const GPUDevice& d, DeviceFunc func,
size_t dynamic_shared_memory_size,
int block_size_limit) {
CHECK_GT(work_element_count, 0);
CudaLaunchConfig config;
int block_count = 0;
int thread_per_block = 0;
cudaError_t err = cudaOccupancyMaxPotentialBlockSize(
&block_count, &thread_per_block, func, dynamic_shared_memory_size,
block_size_limit);
CHECK_EQ(err, cudaSuccess);
block_count =
std::min(block_count, DIV_UP(work_element_count, thread_per_block));
config.virtual_thread_count = work_element_count;
config.thread_per_block = thread_per_block;
config.block_count = block_count;
return config;
}
struct Cuda2DLaunchConfig {
dim3 virtual_thread_count = dim3(0, 0, 0);
dim3 thread_per_block = dim3(0, 0, 0);
dim3 block_count = dim3(0, 0, 0);
};
inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(int xdim, int ydim,
const GPUDevice& d) {
Cuda2DLaunchConfig config;
if (xdim <= 0 || ydim <= 0) {
return config;
}
const int kThreadsPerBlock = 256;
int block_cols = std::min(xdim, kThreadsPerBlock);
// ok to round down here and just do more loops in the kernel
int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
const int physical_thread_count =
d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor();
const int max_blocks = std::max(physical_thread_count / kThreadsPerBlock, 1);
config.virtual_thread_count = dim3(xdim, ydim, 1);
config.thread_per_block = dim3(block_cols, block_rows, 1);
int grid_x = std::min(DIV_UP(xdim, block_cols), max_blocks);
config.block_count = dim3(
grid_x, std::min(max_blocks / grid_x, std::max(ydim / block_rows, 1)), 1);
return config;
}
// Calculate the Cuda 2D and 3D launch config we should use for a kernel launch.
// This variant takes the resource limits of func into account to maximize
// occupancy.
using Cuda3DLaunchConfig = Cuda2DLaunchConfig;
template <typename DeviceFunc>
inline Cuda3DLaunchConfig GetCuda3DLaunchConfig(
int xdim, int ydim, int zdim, const GPUDevice& d, DeviceFunc func,
size_t dynamic_shared_memory_size, int block_size_limit) {
Cuda3DLaunchConfig config;
if (xdim <= 0 || ydim <= 0 || zdim <= 0) {
return config;
}
int dev;
cudaGetDevice(&dev);
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, dev);
int xthreadlimit = deviceProp.maxThreadsDim[0];
int ythreadlimit = deviceProp.maxThreadsDim[1];
int zthreadlimit = deviceProp.maxThreadsDim[2];
int xgridlimit = deviceProp.maxGridSize[0];
int ygridlimit = deviceProp.maxGridSize[1];
int zgridlimit = deviceProp.maxGridSize[2];
int block_count = 0;
int thread_per_block = 0;
cudaError_t err = cudaOccupancyMaxPotentialBlockSize(
&block_count, &thread_per_block, func, dynamic_shared_memory_size,
block_size_limit);
CHECK_EQ(err, cudaSuccess);
#define MIN3(a, b, c) std::min((a), std::min((b), (c)))
int threadsx = MIN3(xdim, thread_per_block, xthreadlimit);
int threadsy =
MIN3(ydim, std::max(thread_per_block / threadsx, 1), ythreadlimit);
int threadsz =
MIN3(zdim, std::max(thread_per_block / (threadsx * threadsy), 1),
zthreadlimit);
int blocksx = MIN3(block_count, DIV_UP(xdim, threadsx), xgridlimit);
int blocksy =
MIN3(DIV_UP(block_count, blocksx), DIV_UP(ydim, threadsy), ygridlimit);
int blocksz = MIN3(DIV_UP(block_count, (blocksx * blocksy)),
DIV_UP(zdim, threadsz), zgridlimit);
#undef MIN3
config.virtual_thread_count = dim3(xdim, ydim, zdim);
config.thread_per_block = dim3(threadsx, threadsy, threadsz);
config.block_count = dim3(blocksx, blocksy, blocksz);
return config;
}
template <typename DeviceFunc>
inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(
int xdim, int ydim, const GPUDevice& d, DeviceFunc func,
size_t dynamic_shared_memory_size, int block_size_limit) {
return GetCuda3DLaunchConfig(xdim, ydim, 1, d, func,
dynamic_shared_memory_size, block_size_limit);
}
// Returns a raw reference to the current cuda stream. Required by a
// number of kernel calls (for which StreamInterface* does not work), e.g.
// CUB and certain cublas primitives.
inline const cudaStream_t& GetCudaStream(OpKernelContext* context) {
const cudaStream_t* ptr = CHECK_NOTNULL(
reinterpret_cast<const cudaStream_t*>(context->op_device_context()
->stream()
->implementation()
->CudaStreamMemberHack()));
return *ptr;
}
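A minimal sketch of why the raw stream is needed, assuming CUB is available and that temp_storage_bytes was already obtained from CUB's usual size-query call; SumOnContextStream and its buffer arguments are hypothetical.
#include "cub/cub.cuh"  // Assumed CUB umbrella header; the path may differ per build.
inline void SumOnContextStream(OpKernelContext* context, const float* d_in,
                               float* d_out, int num_items,
                               void* d_temp_storage,
                               size_t temp_storage_bytes) {
  // CUB (and several cuBLAS entry points) expect a cudaStream_t, which is
  // exactly what GetCudaStream() extracts from the op's device context.
  const cudaStream_t& stream = GetCudaStream(context);
  cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, d_in, d_out,
                         num_items, stream);
}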
namespace cuda_helper {
template <typename IntType>
__device__ IntType upper_bound(IntType* first, IntType count, IntType val) {
IntType* orig = first;
@@ -156,8 +330,481 @@ __device__ IntType upper_bound(IntType* first, IntType count, IntType val) {
return first - orig;
}
} // namespace cuda_helper
template <typename T>
__device__ __host__ inline T ldg(const T* address) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
return __ldg(address);
#else
return *address;
#endif
}
template <>
__device__ __host__ inline std::complex<float> ldg(
const std::complex<float>* address) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
float2 mem = __ldg(reinterpret_cast<const float2*>(address));
return std::complex<float>(mem.x, mem.y);
#else
return *address;
#endif
}
template <>
__device__ __host__ inline std::complex<double> ldg(
const std::complex<double>* address) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
double2 mem = __ldg(reinterpret_cast<const double2*>(address));
return std::complex<double>(mem.x, mem.y);
#else
return *address;
#endif
}
template <>
__device__ __host__ inline Eigen::half ldg(const Eigen::half* address) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
return Eigen::half_impl::raw_uint16_to_half(
__ldg(reinterpret_cast<const uint16_t*>(address)));
#else
return *address;
#endif
}
template <>
__device__ __host__ inline bool ldg(const bool* address) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
return __ldg(reinterpret_cast<const char*>(address)) != 0;
#else
return *address;
#endif
}
// CUDA provides atomic ops, but not for all types. We provide wrappers
// for some ops and provide implementation for all reasonable types.
#define CUDA_ATOMIC_WRAPPER(op, T) \
__device__ __forceinline__ T CudaAtomic##op(T* address, T val)
#define USE_CUDA_ATOMIC(op, T) \
CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); }
// For atomicAdd.
USE_CUDA_ATOMIC(Add, int32);
USE_CUDA_ATOMIC(Add, uint32);
USE_CUDA_ATOMIC(Add, uint64);
USE_CUDA_ATOMIC(Add, float);
// For atomicMax.
USE_CUDA_ATOMIC(Max, int32);
USE_CUDA_ATOMIC(Max, uint32);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
USE_CUDA_ATOMIC(Max, uint64);
#else
// The uint64 overload of atomicMax() is only available for __CUDA_ARCH__ >=
// 350. If not satisfied, we provide a custom implementation using atomicCAS().
CUDA_ATOMIC_WRAPPER(Max, uint64) {
uint64* address_as_ull = reinterpret_cast<uint64*>(address);
uint64 old = *address_as_ull, assumed;
do {
assumed = old;
old = atomicCAS(address_as_ull, assumed, max(val, assumed));
} while (assumed != old);
return old;
}
#endif
// Custom implementation of atomicAdd for double.
// This implementation is copied from CUDA manual.
CUDA_ATOMIC_WRAPPER(Add, double) {
uint64* address_as_ull = reinterpret_cast<uint64*>(address);
uint64 old = *address_as_ull, assumed;
do {
assumed = old;
old = atomicCAS(address_as_ull, assumed,
__double_as_longlong(val + __longlong_as_double(assumed)));
// Note: uses integer comparison to avoid hang in case of NaN
} while (assumed != old);
return __longlong_as_double(old);
}
// Custom implementation of atomicAdd for std::complex<float>.
// This implementation performs two atomic additions on the components.
CUDA_ATOMIC_WRAPPER(Add, std::complex<float>) {
#if defined(__CUDA_ARCH__)
#if __CUDA_ARCH__ >= 350
float2* addr_as_float2 = reinterpret_cast<float2*>(address);
float2* val_as_float2 = reinterpret_cast<float2*>(&val);
CudaAtomicAdd(&(addr_as_float2->x), val_as_float2->x);
CudaAtomicAdd(&(addr_as_float2->y), val_as_float2->y);
#else
static_assert(sizeof(std::complex<float>) == 2 * sizeof(float),
"Unable to compile CudaAtomicAdd for complex64 because "
"sizeof(complex64) != 2*sizeof(float32)");
float* addr_as_float = reinterpret_cast<float*>(address);
float* val_as_float = reinterpret_cast<float*>(&val);
CudaAtomicAdd(addr_as_float, *val_as_float);
CudaAtomicAdd(addr_as_float + 1, *(val_as_float + 1));
#endif
#endif
return *address;
}
// Custom implementation of atomicAdd for std::complex<double>.
// This implementation performs two atomic additions on the components
// using the double atomic wrapper above.
CUDA_ATOMIC_WRAPPER(Add, complex128) {
#if defined(__CUDA_ARCH__)
#if __CUDA_ARCH__ >= 350
double2* addr_as_double2 = reinterpret_cast<double2*>(address);
double2* val_as_double2 = reinterpret_cast<double2*>(&val);
CudaAtomicAdd(&(addr_as_double2->x), val_as_double2->x);
CudaAtomicAdd(&(addr_as_double2->y), val_as_double2->y);
#else
static_assert(sizeof(std::complex<double>) == 2 * sizeof(double),
"Unable to compile CudaAtomicAdd for complex128 because "
"sizeof(complex128) != 2*sizeof(float64)");
double* addr_as_double = reinterpret_cast<double*>(address);
double* val_as_double = reinterpret_cast<double*>(&val);
CudaAtomicAdd(addr_as_double, *val_as_double);
CudaAtomicAdd(addr_as_double + 1, *(val_as_double + 1));
#endif
#endif
return *address;
}
// Helper functions for CudaAtomicAdd(half*, half), below.
//
// Note that if __CUDA_ARCH__ >= 530, we could probably use __hadd2()
// for a more efficient implementation, assuming that adding -0.0
// will never harm the neighboring value. In this version, we take special
// care to guarantee the bits of the untouched value are unchanged.
inline __device__ uint32 add_to_low_half(uint32 val, float x) {
Eigen::half low_half;
low_half.x = static_cast<uint16>(val & 0xffffu);
low_half = static_cast<Eigen::half>(static_cast<float>(low_half) + x);
return (val & 0xffff0000u) | low_half.x;
}
inline __device__ uint32 add_to_high_half(uint32 val, float x) {
Eigen::half high_half;
high_half.x = static_cast<uint16>(val >> 16);
high_half = static_cast<Eigen::half>(static_cast<float>(high_half) + x);
return (val & 0xffffu) | (high_half.x << 16);
}
// Custom implementation of atomicAdd for half. Note that we don't have
// atomicCAS() for anything less than 32 bits, so we need to include the
// other 16 bits in the operation.
//
// Unlike the other atomic adds, this version is going to be very slow
// under high concurrency, since most threads will be spinning on failing
// their compare-and-swap tests. (The fact that we get false sharing on the
// neighboring fp16 makes this even worse.) If you are doing a large reduction,
// you are much better off with doing the intermediate steps in fp32 and then
// switching to fp16 as late as you can in the calculations.
//
// Note: Assumes little endian.
CUDA_ATOMIC_WRAPPER(Add, Eigen::half) {
float val_as_float(val);
intptr_t address_int = reinterpret_cast<intptr_t>(address);
if ((address_int & 0x2) == 0) {
// The half is in the first part of the uint32 (lower 16 bits).
uint32* address_as_uint32 = reinterpret_cast<uint32*>(address);
assert(((intptr_t)address_as_uint32 & 0x3) == 0);
uint32 old = *address_as_uint32, assumed;
do {
assumed = old;
old = atomicCAS(address_as_uint32, assumed,
add_to_low_half(assumed, val_as_float));
// Note: uses integer comparison to avoid hang in case of NaN
} while (assumed != old);
Eigen::half ret;
ret.x = old & 0xffffu;
return ret;
} else {
// The half is in the second part of the uint32 (upper 16 bits).
uint32* address_as_uint32 = reinterpret_cast<uint32*>(address_int - 2);
assert(((intptr_t)address_as_uint32 & 0x3) == 0);
uint32 old = *address_as_uint32, assumed;
do {
assumed = old;
old = atomicCAS(address_as_uint32, assumed,
add_to_high_half(assumed, val_as_float));
// Note: uses integer comparison to avoid hang in case of NaN
} while (assumed != old);
Eigen::half ret;
ret.x = old >> 16;
return ret;
}
}
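Given the cost described above, a minimal sketch of the recommended pattern: accumulate in fp32 and convert to half only for the final atomic; SumToHalfExample is a hypothetical kernel.
__global__ void SumToHalfExample(const float* in, int n, Eigen::half* out) {
  float partial = 0.f;  // Per-thread partial sum kept in fp32.
  CUDA_1D_KERNEL_LOOP(i, n) { partial += ldg(in + i); }
  // One half-precision atomic per thread instead of one per element.
  CudaAtomicAdd(out, static_cast<Eigen::half>(partial));
}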
template <typename T>
__global__ void SetZero(const int nthreads, T* bottom_diff) {
CUDA_1D_KERNEL_LOOP(index, nthreads) { *(bottom_diff + index) = T(0); }
}
// For atomicSub.
// Custom implementation for sub by just negating the value.
#define WRAPPED_ATOMIC_SUB(T) \
CUDA_ATOMIC_WRAPPER(Sub, T) { return CudaAtomicAdd(address, -val); }
WRAPPED_ATOMIC_SUB(uint64);
WRAPPED_ATOMIC_SUB(int32);
WRAPPED_ATOMIC_SUB(uint32);
WRAPPED_ATOMIC_SUB(Eigen::half);
WRAPPED_ATOMIC_SUB(float);
WRAPPED_ATOMIC_SUB(double);
CUDA_ATOMIC_WRAPPER(Sub, complex64) {
const std::complex<float> Tneg(-val.real(), -val.imag());
return CudaAtomicAdd(address, Tneg);
}
CUDA_ATOMIC_WRAPPER(Sub, complex128) {
const std::complex<double> Tneg(-val.real(), -val.imag());
return CudaAtomicAdd(address, Tneg);
}
#undef WRAPPED_ATOMIC_SUB
// For atomicMul.
CUDA_ATOMIC_WRAPPER(Mul, int32) {
int32 old = *address, assumed;
do {
assumed = old;
old = atomicCAS(address, assumed, val * assumed);
} while (assumed != old);
return old;
}
CUDA_ATOMIC_WRAPPER(Mul, uint32) {
uint32 old = *address, assumed;
do {
assumed = old;
old = atomicCAS(address, assumed, val * assumed);
} while (assumed != old);
return old;
}
CUDA_ATOMIC_WRAPPER(Mul, uint64) {
uint64 old = *address, assumed;
do {
assumed = old;
old = atomicCAS(address, assumed, val * assumed);
} while (assumed != old);
return old;
}
CUDA_ATOMIC_WRAPPER(Mul, float) {
int32* address_as_int = reinterpret_cast<int32*>(address);
int32 old = *address_as_int, assumed;
do {
assumed = old;
old = atomicCAS(address_as_int, assumed,
__float_as_int(val * __int_as_float(assumed)));
} while (assumed != old);
return __int_as_float(old);
}
CUDA_ATOMIC_WRAPPER(Mul, double) {
uint64* address_as_ull = reinterpret_cast<uint64*>(address);
uint64 old = *address_as_ull, assumed;
do {
assumed = old;
old = atomicCAS(address_as_ull, assumed,
__double_as_longlong(val * __longlong_as_double(assumed)));
} while (assumed != old);
return __longlong_as_double(old);
}
// For atomicDiv.
CUDA_ATOMIC_WRAPPER(Div, int32) {
int32 old = *address, assumed;
do {
assumed = old;
old = atomicCAS(address, assumed, assumed / val);
} while (assumed != old);
return old;
}
CUDA_ATOMIC_WRAPPER(Div, uint32) {
uint32 old = *address, assumed;
do {
assumed = old;
old = atomicCAS(address, assumed, assumed / val);
} while (assumed != old);
return old;
}
CUDA_ATOMIC_WRAPPER(Div, uint64) {
uint64 old = *address, assumed;
do {
assumed = old;
old = atomicCAS(address, assumed, assumed / val);
} while (assumed != old);
return old;
}
CUDA_ATOMIC_WRAPPER(Div, float) {
int32* address_as_int = reinterpret_cast<int32*>(address);
int32 old = *address_as_int, assumed;
do {
assumed = old;
old = atomicCAS(address_as_int, assumed,
__float_as_int(__int_as_float(assumed) / val));
} while (assumed != old);
return __int_as_float(old);
}
CUDA_ATOMIC_WRAPPER(Div, double) {
uint64* address_as_ull = reinterpret_cast<uint64*>(address);
uint64 old = *address_as_ull, assumed;
do {
assumed = old;
old = atomicCAS(address_as_ull, assumed,
__double_as_longlong(__longlong_as_double(assumed) / val));
} while (assumed != old);
return __longlong_as_double(old);
}
#undef USE_CUDA_ATOMIC
#undef CUDA_ATOMIC_WRAPPER
template <typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T tf_min(const T& x, const T& y) {
return x > y ? y : x;
}
template <typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T tf_max(const T& x, const T& y) {
return x < y ? y : x;
}
__device__ EIGEN_ALWAYS_INLINE unsigned CudaBallot(unsigned mask,
int predicate) {
return __ballot_sync(mask, predicate);
}
template <typename T>
__device__ EIGEN_ALWAYS_INLINE T CudaShuffle(unsigned mask, T value,
int srcLane,
int width = warpSize) {
return __shfl_sync(mask, value, srcLane, width);
}
// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
// TODO(csigg): remove when the bug is fixed in the next CUDA release.
__device__ EIGEN_ALWAYS_INLINE double CudaShuffle(unsigned mask, double value,
int srcLane,
int width = warpSize) {
unsigned lo, hi;
asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
hi = __shfl_sync(mask, hi, srcLane, width);
lo = __shfl_sync(mask, lo, srcLane, width);
asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
return value;
}
template <typename T>
__device__ EIGEN_ALWAYS_INLINE T CudaShuffleUp(unsigned mask, T value,
int delta,
int width = warpSize) {
return __shfl_up_sync(mask, value, delta, width);
}
// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
// TODO(csigg): remove when the bug is fixed in the next CUDA release.
__device__ EIGEN_ALWAYS_INLINE double CudaShuffleUp(unsigned mask, double value,
int delta,
int width = warpSize) {
unsigned lo, hi;
asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
hi = __shfl_up_sync(mask, hi, delta, width);
lo = __shfl_up_sync(mask, lo, delta, width);
asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
return value;
}
template <typename T>
__device__ EIGEN_ALWAYS_INLINE T CudaShuffleDown(unsigned mask, T value,
int delta,
int width = warpSize) {
return __shfl_down_sync(mask, value, delta, width);
}
__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleDown(
unsigned mask, Eigen::half value, int delta, int width = warpSize) {
return Eigen::half(
__shfl_down_sync(mask, static_cast<uint16>(value), delta, width));
}
// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
// TODO(csigg): remove when the bug is fixed in the next CUDA release.
__device__ EIGEN_ALWAYS_INLINE double CudaShuffleDown(unsigned mask,
double value, int delta,
int width = warpSize) {
unsigned lo, hi;
asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
hi = __shfl_down_sync(mask, hi, delta, width);
lo = __shfl_down_sync(mask, lo, delta, width);
asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
return value;
}
template <typename T>
__device__ EIGEN_ALWAYS_INLINE T CudaShuffleXor(unsigned mask, T value,
int laneMask,
int width = warpSize) {
return __shfl_xor_sync(mask, value, laneMask, width);
}
__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleXor(
unsigned mask, Eigen::half value, int laneMask, int width = warpSize) {
return Eigen::half(
__shfl_xor_sync(mask, static_cast<uint16>(value), laneMask, width));
}
// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
// TODO(csigg): remove when the bug is fixed in the next CUDA release.
__device__ EIGEN_ALWAYS_INLINE double CudaShuffleXor(unsigned mask,
double value, int laneMask,
int width = warpSize) {
unsigned lo, hi;
asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
hi = __shfl_xor_sync(mask, hi, laneMask, width);
lo = __shfl_xor_sync(mask, lo, laneMask, width);
asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
return value;
}
} // namespace tensorflow
#undef DIV_UP
#endif // GOOGLE_CUDA
#endif // TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_

View File

@@ -52,11 +52,11 @@ __global__ void Count1D(CudaLaunchConfig config, int bufsize, int* outbuf) {
}
}
__global__ void Count2D(Cuda2DLaunchConfig config, int bufsize, int* outbuf) {
CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count.x, X) {
CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
if (x < 0) { // x might overflow when testing extreme case
break;
}
CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count.y, Y) {
CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
if (y < 0) { // y might overflow when testing extreme case
break;
}
@@ -66,15 +66,15 @@ __global__ void Count2D(Cuda2DLaunchConfig config, int bufsize, int* outbuf) {
}
}
__global__ void Count3D(Cuda3DLaunchConfig config, int bufsize, int* outbuf) {
CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count.x, X) {
CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
if (x < 0) { // x might overflow when testing extreme case
break;
}
CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count.y, Y) {
CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
if (y < 0) { // y might overflow when testing extreme case
break;
}
CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count.z, Z) {
CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) {
if (z < 0) { // z might overflow when testing extreme case
break;
}
@@ -94,7 +94,7 @@ class CudaLaunchConfigTest : public ::testing::Test {
const int bufsize = 1024;
int* outbuf = nullptr;
Eigen::CudaStreamDevice stream;
Eigen::GpuDevice d = Eigen::GpuDevice(&stream);
GPUDevice d = GPUDevice(&stream);
virtual void SetUp() {
cudaError_t err = cudaMallocManaged(&outbuf, sizeof(int) * bufsize);

View File

@@ -1,284 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_UTIL_CUDA_LAUNCH_CONFIG_H_
#define TENSORFLOW_CORE_UTIL_CUDA_LAUNCH_CONFIG_H_
#if GOOGLE_CUDA
#include <algorithm>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "cuda/include/cuda.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
// Usage of GetCudaLaunchConfig, GetCuda2DLaunchConfig, and
// GetCuda3DLaunchConfig:
//
// There are two versions of GetCudaLaunchConfig and GetCuda2DLaunchConfig, one
// version uses heuristics without any knowledge of the device kernel, the other
// version uses cudaOccupancyMaxPotentialBlockSize to determine the theoretical
// launch parameters that maximize occupancy. Currently, only the maximum
// occupancy version of GetCuda3DLaunchConfig is available.
//
// For large number of work elements, the convention is that each kernel would
// iterate through its assigned range. The return value of GetCudaLaunchConfig
// is struct CudaLaunchConfig, which contains all the information needed for the
// kernel launch, including: the virtual number of threads, the number of
// threads per block, and the number of blocks used inside <<< >>> of a kernel
// launch. GetCuda2DLaunchConfig and GetCuda3DLaunchConfig do the same thing
// as GetCudaLaunchConfig; the only difference is the dimension. The macros
// CUDA_1D_KERNEL_LOOP and CUDA_AXIS_KERNEL_LOOP can be used for the inner loops.
//
/* Sample code:
__global__ void MyKernel1D(CudaLaunchConfig config, other_args...) {
CUDA_1D_KERNEL_LOOP(x, config.virtual_thread_count) {
do_your_job_here;
}
}
__global__ void MyKernel2D(Cuda2DLaunchConfig config, other_args...) {
CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
do_your_job_here;
}
}
}
__global__ void MyKernel3D(Cuda3DLaunchConfig config, other_args...) {
CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) {
do_your_job_here;
}
}
}
}
void MyDriverFunc(const Eigen::GpuDevice &d) {
// use heuristics
CudaLaunchConfig cfg1 = GetCudaLaunchConfig(10240, d);
MyKernel1D <<<config.block_count,
config.thread_per_block, 0, d.stream()>>> (cfg1, other_args...);
Cuda2DLaunchConfig cfg2 = GetCuda2DLaunchConfig(10240, 10240, d);
MyKernel2D <<<config.block_count,
config.thread_per_block, 0, d.stream()>>> (cfg2, other_args...);
Cuda3DLaunchConfig cfg3 = GetCuda3DLaunchConfig(4096, 4096, 100, d);
MyKernel3D <<<config.block_count,
config.thread_per_block, 0, d.stream()>>> (cfg3, other_args...);
// maximize occupancy
CudaLaunchConfig cfg4 = GetCudaLaunchConfig(10240, d, MyKernel1D, 0, 0 );
MyKernel1D <<<config.block_count,
config.thread_per_block, 0, d.stream()>>> (cfg4, other_args...);
Cuda2DLaunchConfig cfg5 = GetCuda2DLaunchConfig(10240, 10240, d,
MyKernel1D, 0, 0);
MyKernel2D <<<config.block_count,
config.thread_per_block, 0, d.stream()>>> (cfg5, other_args...);
Cuda3DLaunchConfig cfg6 = GetCuda3DLaunchConfig(4096, 4096, 100, d,
MyKernel1D, 0, 0);
MyKernel3D <<<config.block_count,
config.thread_per_block, 0, d.stream()>>> (cfg6, other_args...);
}
// See the test for this for more examples:
//
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/cuda_kernel_helper_test.cu.cc
*/
namespace tensorflow {
inline int DivUp(int a, int b) { return (a + b - 1) / b; }
struct CudaLaunchConfig {
// Logical number of threads that work on the elements. If each logical
// thread works on exactly a single element, this is the same as the working
// element count.
int virtual_thread_count = -1;
// Number of threads per block.
int thread_per_block = -1;
// Number of blocks for Cuda kernel launch.
int block_count = -1;
};
// Calculate the Cuda launch config we should use for a kernel launch.
// This is assuming the kernel is quite simple and will largely be
// memory-limited.
// REQUIRES: work_element_count > 0.
inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
const Eigen::GpuDevice& d) {
CHECK_GT(work_element_count, 0);
CudaLaunchConfig config;
const int virtual_thread_count = work_element_count;
const int physical_thread_count = std::min(
d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor(),
virtual_thread_count);
const int thread_per_block = std::min(1024, d.maxCudaThreadsPerBlock());
const int block_count =
std::min(DivUp(physical_thread_count, thread_per_block),
d.getNumCudaMultiProcessors());
config.virtual_thread_count = virtual_thread_count;
config.thread_per_block = thread_per_block;
config.block_count = block_count;
return config;
}
// Calculate the Cuda launch config we should use for a kernel launch. This
// variant takes the resource limits of func into account to maximize occupancy.
// REQUIRES: work_element_count > 0.
template <typename DeviceFunc>
inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
const Eigen::GpuDevice& d,
DeviceFunc func,
size_t dynamic_shared_memory_size,
int block_size_limit) {
CHECK_GT(work_element_count, 0);
CudaLaunchConfig config;
int block_count = 0;
int thread_per_block = 0;
cudaError_t err = cudaOccupancyMaxPotentialBlockSize(
&block_count, &thread_per_block, func, dynamic_shared_memory_size,
block_size_limit);
CHECK_EQ(err, cudaSuccess);
block_count =
std::min(block_count, DivUp(work_element_count, thread_per_block));
config.virtual_thread_count = work_element_count;
config.thread_per_block = thread_per_block;
config.block_count = block_count;
return config;
}
struct Cuda2DLaunchConfig {
dim3 virtual_thread_count = dim3(0, 0, 0);
dim3 thread_per_block = dim3(0, 0, 0);
dim3 block_count = dim3(0, 0, 0);
};
inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(int xdim, int ydim,
const Eigen::GpuDevice& d) {
Cuda2DLaunchConfig config;
if (xdim <= 0 || ydim <= 0) {
return config;
}
const int kThreadsPerBlock = 256;
int block_cols = std::min(xdim, kThreadsPerBlock);
// ok to round down here and just do more loops in the kernel
int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
const int physical_thread_count =
d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor();
const int max_blocks = std::max(physical_thread_count / kThreadsPerBlock, 1);
config.virtual_thread_count = dim3(xdim, ydim, 1);
config.thread_per_block = dim3(block_cols, block_rows, 1);
int grid_x = std::min(DivUp(xdim, block_cols), max_blocks);
config.block_count = dim3(
grid_x, std::min(max_blocks / grid_x, std::max(ydim / block_rows, 1)), 1);
return config;
}
// Calculate the Cuda 2D and 3D launch config we should use for a kernel launch.
// This variant takes the resource limits of func into account to maximize
// occupancy.
using Cuda3DLaunchConfig = Cuda2DLaunchConfig;
template <typename DeviceFunc>
inline Cuda3DLaunchConfig GetCuda3DLaunchConfig(
int xdim, int ydim, int zdim, const Eigen::GpuDevice& d, DeviceFunc func,
size_t dynamic_shared_memory_size, int block_size_limit) {
Cuda3DLaunchConfig config;
if (xdim <= 0 || ydim <= 0 || zdim <= 0) {
return config;
}
int dev;
cudaGetDevice(&dev);
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, dev);
int xthreadlimit = deviceProp.maxThreadsDim[0];
int ythreadlimit = deviceProp.maxThreadsDim[1];
int zthreadlimit = deviceProp.maxThreadsDim[2];
int xgridlimit = deviceProp.maxGridSize[0];
int ygridlimit = deviceProp.maxGridSize[1];
int zgridlimit = deviceProp.maxGridSize[2];
int block_count = 0;
int thread_per_block = 0;
cudaError_t err = cudaOccupancyMaxPotentialBlockSize(
&block_count, &thread_per_block, func, dynamic_shared_memory_size,
block_size_limit);
CHECK_EQ(err, cudaSuccess);
auto min3 = [](int a, int b, int c) { return std::min(a, std::min(b, c)); };
int threadsx = min3(xdim, thread_per_block, xthreadlimit);
int threadsy =
min3(ydim, std::max(thread_per_block / threadsx, 1), ythreadlimit);
int threadsz =
min3(zdim, std::max(thread_per_block / (threadsx * threadsy), 1),
zthreadlimit);
int blocksx = min3(block_count, DivUp(xdim, threadsx), xgridlimit);
int blocksy =
min3(DivUp(block_count, blocksx), DivUp(ydim, threadsy), ygridlimit);
int blocksz = min3(DivUp(block_count, (blocksx * blocksy)),
DivUp(zdim, threadsz), zgridlimit);
config.virtual_thread_count = dim3(xdim, ydim, zdim);
config.thread_per_block = dim3(threadsx, threadsy, threadsz);
config.block_count = dim3(blocksx, blocksy, blocksz);
return config;
}
template <typename DeviceFunc>
inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(
int xdim, int ydim, const Eigen::GpuDevice& d, DeviceFunc func,
size_t dynamic_shared_memory_size, int block_size_limit) {
return GetCuda3DLaunchConfig(xdim, ydim, 1, d, func,
dynamic_shared_memory_size, block_size_limit);
}
// Returns a raw reference to the current cuda stream. Required by a
// number of kernel calls (for which StreamInterface* does not work), e.g.
// CUB and certain cublas primitives.
inline const cudaStream_t& GetCudaStream(OpKernelContext* context) {
const cudaStream_t* ptr = CHECK_NOTNULL(
reinterpret_cast<const cudaStream_t*>(context->op_device_context()
->stream()
->implementation()
->CudaStreamMemberHack()));
return *ptr;
}
} // namespace tensorflow
#endif // GOOGLE_CUDA
#endif // TENSORFLOW_CORE_UTIL_CUDA_LAUNCH_CONFIG_H_