Automated g4 rollback of changelist 177799252
PiperOrigin-RevId: 177989542
parent 33e3da538a
commit 21e831dc4a
@@ -34,9 +34,9 @@ namespace functor {
__global__ void ReduceSliceDeviceKernel##reduceop( \
Cuda3DLaunchConfig config, Index indices_width, Index bound, \
const T begin, const Index *indices, const T *input, T *out) { \
CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count.x, X) { \
CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count.y, Y) { \
CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count.z, Z) { \
CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) { \
CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) { \
CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) { \
Index outidx = x * config.virtual_thread_count.y * \
config.virtual_thread_count.z + \
y * config.virtual_thread_count.z + z; \
@@ -68,9 +68,8 @@ namespace functor {
if (sizex * sizey * sizez == 0) { \
return; \
} \
Cuda3DLaunchConfig config = GetCuda3DLaunchConfig( \
sizex, sizey, sizez, d, ReduceSliceDeviceKernel##reduceop<T, Index>, \
0, 0); \
Cuda3DLaunchConfig config = GetCuda3DLaunchConfig(sizex, sizey, sizez, d,\
ReduceSliceDeviceKernel##reduceop<T, Index>, 0, 0); \
\
ReduceSliceDeviceKernel##reduceop<T, Index> \
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>( \
@@ -1847,13 +1847,6 @@ cc_library(
],
)

tf_cuda_library(
name = "cuda_device_functions",
hdrs = ["util/cuda_device_functions.h"],
visibility = ["//visibility:public"],
deps = [":framework_lite"],
)

# TODO(josh11b): Is this needed, or can we just use ":protos_all_cc"?
cc_library(
name = "protos_cc",
@@ -173,13 +173,19 @@ __global__ void BiasGradNCHW_SharedAtomics(const T* output_backprop,
// Accumulate the results in the shared memory into the first element.
// No syncthreads is needed since this is only in the same warp.
int32 thread_index = threadIdx.x;
if (thread_index < 32) {
AccT data = s_data[thread_index];
for (int32 offset = warpSize / 2; offset > 0; offset /= 2) {
data += CudaShuffleDownSync(kCudaWarpAll, data, offset);
}
if (thread_index < 16) {
s_data[thread_index] += s_data[thread_index + 16];
__syncwarp(0xFFFF);
if (thread_index < 8) s_data[thread_index] += s_data[thread_index + 8];
__syncwarp(0xFF);
if (thread_index < 4) s_data[thread_index] += s_data[thread_index + 4];
__syncwarp(0xF);
if (thread_index < 2) s_data[thread_index] += s_data[thread_index + 2];
__syncwarp(0x3);
if (thread_index == 0) {
CudaAtomicAdd(bias_backprop + bias_index, T(data));
T val = T(s_data[0] + s_data[1]);
// The first thread writes out the accumulated result to global location.
CudaAtomicAdd(bias_backprop + bias_index, val);
}
}
}
}
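For illustration only (not part of this change): the shuffle-based variant being rolled back computes a full-warp sum with __shfl_down_sync, the raw CUDA 9 intrinsic behind CudaShuffleDownSync. A minimal self-contained sketch, assuming a full 32-lane warp and float accumulators:

  __device__ float WarpSum(float v) {
    // After each step, lane i holds v[i] + v[i + offset]; lane 0 ends up
    // with the sum of all 32 lanes.
    for (int offset = warpSize / 2; offset > 0; offset /= 2) {
      v += __shfl_down_sync(0xffffffffu, v, offset);
    }
    return v;
  }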
@@ -34,7 +34,6 @@ limitations under the License.

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;
using Eigen::GpuDevice;

// Returns whether depthwise convolution forward or backward input pass can be
@@ -1029,7 +1028,7 @@ __device__ __forceinline__ T WarpSumReduce(T val) {
int zeros = sub_warp * kWidth;
unsigned mask = ((1UL << kWidth) - 1) << zeros;
for (int delta = kWidth / 2; delta > 0; delta /= 2) {
val += CudaShuffleXorSync(mask, val, delta);
val += CudaShuffleXor(mask, val, delta);
}
return val;
}
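For illustration only (not part of this change): WarpSumReduce above reduces within sub-warps of kWidth consecutive lanes using a masked xor (butterfly) shuffle. A standalone sketch of the same pattern written against the raw CUDA 9 intrinsic, assuming kWidth is a power of two that divides 32:

  template <int kWidth>
  __device__ float SubWarpSum(float v) {
    const int lane = threadIdx.x & 31;
    const unsigned mask = static_cast<unsigned>(
        ((1ull << kWidth) - 1) << (lane / kWidth * kWidth));
    // Butterfly exchange: after log2(kWidth) steps every lane in the
    // kWidth-lane group holds the group sum.
    for (int delta = kWidth / 2; delta > 0; delta /= 2) {
      v += __shfl_xor_sync(mask, v, delta);
    }
    return v;
  }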
@@ -1146,7 +1145,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(

// Note: the condition to reach this is uniform across the entire block.
__syncthreads();
unsigned active_threads = CudaBallotSync(kCudaWarpAll, depth_in_range);
unsigned active_threads = CudaBallot(CUDA_WARP_ALL, depth_in_range);

if (depth_in_range) {
const T* const out_ptr = inout_offset + output;
@@ -1160,7 +1159,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
// Warp-accumulate pixels of the same depth and write to accumulator.
for (int delta = 16; delta >= kBlockSlices; delta /= 2) {
val += CudaShuffleDownSync(active_threads, val, delta);
val += CudaShuffleDown(active_threads, val, delta);
}
if (!(thread_idx & 32 - kBlockSlices) /* lane_idx < kBlockSlices */) {
*accum_ptr = val;
@@ -1400,7 +1399,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(

// Note: the condition to reach this is uniform across the entire block.
__syncthreads();
unsigned active_threads = CudaBallotSync(kCudaWarpAll, slice_in_range);
unsigned active_threads = CudaBallot(CUDA_WARP_ALL, slice_in_range);

if (slice_in_range) {
const T* const out_ptr = inout_offset + output;
@@ -1414,7 +1413,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
// Warp-accumulate pixels of the same depth and write to accumulator.
for (int delta = 16 / kBlockSlices; delta > 0; delta /= 2) {
val += CudaShuffleDownSync(active_threads, val, delta);
val += CudaShuffleDown(active_threads, val, delta);
}
if (!(thread_idx & 32 / kBlockSlices - 1)) {
*accum_ptr = val;
@@ -55,27 +55,6 @@ struct LeftUpdate<T, scatter_nd_op::UpdateOp::SUB> {
}
};

// Specializations for std::complex, updating real and imaginary part
// individually. Even though this is not an atomic op anymore, it is safe
// because there is only one type of op per kernel.
template <typename T>
struct LeftUpdate<std::complex<T>, scatter_nd_op::UpdateOp::ADD> {
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(
std::complex<T>* out, const std::complex<T>& val) {
T* ptr = reinterpret_cast<T*>(out);
CudaAtomicAdd(ptr, val.real());
CudaAtomicAdd(ptr, val.imag());
}
};

template <typename T>
struct LeftUpdate<std::complex<T>, scatter_nd_op::UpdateOp::SUB> {
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(
std::complex<T>* out, const std::complex<T>& val) {
LeftUpdate<std::complex<T>, scatter_nd_op::UpdateOp::ADD>()(out, -val);
}
};

} // namespace

template <typename T, typename Index, scatter_nd_op::UpdateOp op, int IXDIM>
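For illustration only (not part of this change): the removed specializations update the real and imaginary parts with two independent scalar atomics, which is safe only because a single kind of update is in flight per kernel. A standalone sketch of the same idea using CUDA's cuFloatComplex instead of std::complex:

  #include <cuComplex.h>

  __device__ void AtomicAddComplex(cuFloatComplex* out, cuFloatComplex val) {
    float* parts = reinterpret_cast<float*>(out);
    atomicAdd(parts + 0, cuCrealf(val));  // real part
    atomicAdd(parts + 1, cuCimagf(val));  // imaginary part
  }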
@@ -63,8 +63,8 @@ __global__ void ComputeValueOfVKernel(Cuda2DLaunchConfig config, int64 m,
int64 ldu, const Scalar* M,
const Scalar* U, const Scalar* S,
Scalar* V) {
CUDA_AXIS_KERNEL_LOOP(batch, config.virtual_thread_count.x, X) {
CUDA_AXIS_KERNEL_LOOP(i, config.virtual_thread_count.y, Y) {
CUDA_AXIS_KERNEL_LOOP(batch, config.virtual_thread_count, x) {
CUDA_AXIS_KERNEL_LOOP(i, config.virtual_thread_count, y) {
Scalar v = M[i + m * batch] * U[ldu * (i + m * batch)] * S[batch];
CudaAtomicAdd(V + batch, v);
}
@@ -1,418 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_CUDA_DEVICE_FUNCTIONS_H_
#define TENSORFLOW_CORE_UTIL_CUDA_DEVICE_FUNCTIONS_H_

/**
* Wrappers and helpers for CUDA device code.
*
* Wraps the warp-cooperative intrinsics introduced in CUDA 9 to provide
* backwards compatibility, see go/volta-porting for details.
* Provides atomic operations on types that aren't natively supported.
*/

#if GOOGLE_CUDA

#include <algorithm>
#include <complex>
#include "cuda/include/cuda.h"
#include "cuda/include/device_functions.h"
#include "tensorflow/core/platform/types.h"

#if __CUDACC_VER_MAJOR__ >= 9
#include "cuda/include/cuda_fp16.h"
#elif __CUDACC_VER__ >= 7050
#include "cuda/include/cuda_fp16.h"
#else
#endif

namespace tensorflow {

namespace detail {

// Helper for range-based for loop using 'delta' increments.
// Usage: see CudaGridRange?() functions below.
template <typename T>
class CudaGridRange {
struct Iterator {
__device__ Iterator(T index, T delta) : index_(index), delta_(delta) {}
__device__ T operator*() const { return index_; }
__device__ Iterator& operator++() {
index_ += delta_;
return *this;
}
__device__ bool operator!=(const Iterator& other) const {
bool greater = index_ > other.index_;
bool less = index_ < other.index_;
// Anything past an end iterator (delta_ == 0) is equal.
// In range-based for loops, this optimizes to 'return less'.
if (!other.delta_) {
return less;
}
if (!delta_) {
return greater;
}
return less || greater;
}

private:
T index_;
const T delta_;
};

public:
__device__ CudaGridRange(T begin, T delta, T end)
: begin_(begin), delta_(delta), end_(end) {}

__device__ Iterator begin() const { return Iterator{begin_, delta_}; }
__device__ Iterator end() const { return Iterator{end_, 0}; }

private:
T begin_;
T delta_;
T end_;
};

} // namespace detail

// Helper to visit indices in the range 0 <= i < count, using the x-coordinate
// of the global thread index. That is, each index i is visited by all threads
// with the same x-coordinate.
// Usage: for(int i : CudaGridRangeX(count)) { visit(i); }
template <typename T>
__device__ detail::CudaGridRange<T> CudaGridRangeX(T count) {
return detail::CudaGridRange<T>(blockIdx.x * blockDim.x + threadIdx.x,
gridDim.x * blockDim.x, count);
}

// Helper to visit indices in the range 0 <= i < count using the y-coordinate.
// Usage: for(int i : CudaGridRangeY(count)) { visit(i); }
template <typename T>
__device__ detail::CudaGridRange<T> CudaGridRangeY(T count) {
return detail::CudaGridRange<T>(blockIdx.y * blockDim.y + threadIdx.y,
gridDim.y * blockDim.y, count);
}

// Helper to visit indices in the range 0 <= i < count using the z-coordinate.
// Usage: for(int i : CudaGridRangeZ(count)) { visit(i); }
template <typename T>
__device__ detail::CudaGridRange<T> CudaGridRangeZ(T count) {
return detail::CudaGridRange<T>(blockIdx.z * blockDim.z + threadIdx.z,
gridDim.z * blockDim.z, count);
}
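For illustration only (not part of this change): CudaGridRangeX wraps the classic grid-stride loop in iterator form so it can be used with range-based for. The equivalent loop written out by hand, for a hypothetical element-wise kernel:

  __global__ void ScaleKernel(int count, float alpha, float* data) {
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < count;
         i += gridDim.x * blockDim.x) {
      data[i] *= alpha;  // same traversal as: for (int i : CudaGridRangeX(count))
    }
  }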
// Mask for all 32 threads in a warp.
const unsigned kCudaWarpAll = 0xffffffff;

// On sm_6x and earlier, verifies that all bits in mask corresponding to active
// threads of the warp are set. It does not verify the converse (bits of
// inactive threads are not set), because all syncs are unblocked when a thread
// exits the kernel, but the ballot of inactive (including exited) threads
// returns 0.
__device__ inline void CudaVerifySyncMask(unsigned mask) {
#if __CUDA_ARCH__ < 700
assert(0 == (__ballot(1) & ~mask)); // Active threads must have mask bit set.
#endif
}

// For all *_sync wrappers below, it is illegal to synchronize threads from
// different program locations, because that is not supported before sm_70.
// Code that requires sm_70 (and CUDA 9) may use the intrinsic directly.

// Wrapper for __syncwarp.
__device__ inline void CudaSyncWarp(unsigned mask = kCudaWarpAll) {
CudaVerifySyncMask(mask);
#if CUDA_VERSION >= 9000
__syncwarp(mask);
#endif
}

// Wrapper for __ballot_sync.
__device__ inline unsigned CudaBallotSync(unsigned mask, int pred) {
CudaVerifySyncMask(mask);
#if CUDA_VERSION >= 9000
return __ballot_sync(mask, pred);
#else
return __ballot(pred);
#endif
}

// Wrapper for __any_sync.
__device__ inline int CudaAnySync(unsigned mask, int pred) {
CudaVerifySyncMask(mask);
#if CUDA_VERSION >= 9000
return __any_sync(mask, pred);
#else
return __any(pred);
#endif
}

// Wrapper for __all_sync.
__device__ inline int CudaAllSync(unsigned mask, int pred) {
CudaVerifySyncMask(mask);
#if CUDA_VERSION >= 9000
return __all_sync(mask, pred);
#else
return __all(pred);
#endif
}

// Wrapper for __shfl_sync.
template <typename T>
__device__ T CudaShuffleSync(unsigned mask, T value, int src_lane,
int width = warpSize) {
CudaVerifySyncMask(mask);
#if CUDA_VERSION >= 9000
return __shfl_sync(mask, value, src_lane, width);
#else
return __shfl(value, src_lane, width);
#endif
}

// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// See b/69446944.
__device__ inline double CudaShuffleSync(unsigned mask, double value,
int src_lane, int width = warpSize) {
unsigned lo, hi;
asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
hi = CudaShuffleSync(mask, hi, src_lane, width);
lo = CudaShuffleSync(mask, lo, src_lane, width);
asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
return value;
}

// Wrapper for __shfl_up_sync.
template <typename T>
__device__ inline T CudaShuffleUpSync(unsigned mask, T value, int delta,
int width = warpSize) {
CudaVerifySyncMask(mask);
#if CUDA_VERSION >= 9000
return __shfl_up_sync(mask, value, delta, width);
#else
return __shfl_up(value, delta, width);
#endif
}

// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// See b/69446944.
__device__ inline double CudaShuffleUpSync(unsigned mask, double value,
int delta, int width = warpSize) {
unsigned lo, hi;
asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
hi = CudaShuffleUpSync(mask, hi, delta, width);
lo = CudaShuffleUpSync(mask, lo, delta, width);
asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
return value;
}

// Wrapper for __shfl_down_sync.
template <typename T>
__device__ inline T CudaShuffleDownSync(unsigned mask, T value, int delta,
int width = warpSize) {
CudaVerifySyncMask(mask);
#if CUDA_VERSION >= 9000
return __shfl_down_sync(mask, value, delta, width);
#else
return __shfl_down(value, delta, width);
#endif
}

// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// See b/69446944.
__device__ inline double CudaShuffleDownSync(unsigned mask, double value,
int delta, int width = warpSize) {
unsigned lo, hi;
asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
hi = CudaShuffleDownSync(mask, hi, delta, width);
lo = CudaShuffleDownSync(mask, lo, delta, width);
asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
return value;
}

// Wrapper for __shfl_xor_sync.
template <typename T>
__device__ T CudaShuffleXorSync(unsigned mask, T value, int lane_mask,
int width = warpSize) {
CudaVerifySyncMask(mask);
#if CUDA_VERSION >= 9000
return __shfl_xor_sync(mask, value, lane_mask, width);
#else
return __shfl_xor(value, lane_mask, width);
#endif
}

// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// See b/69446944.
__device__ inline double CudaShuffleXorSync(unsigned mask, double value,
int lane_mask,
int width = warpSize) {
unsigned lo, hi;
asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
hi = CudaShuffleXorSync(mask, hi, lane_mask, width);
lo = CudaShuffleXorSync(mask, lo, lane_mask, width);
asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
return value;
}
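For illustration only (not part of this change): the double overloads above split the 64-bit value into two 32-bit halves with inline PTX, shuffle each half, and reassemble. The same split can be written with CUDA's documented reinterpretation intrinsics; a sketch:

  __device__ double ShuffleDouble(unsigned mask, double value, int src_lane,
                                  int width = warpSize) {
    int lo = __double2loint(value);
    int hi = __double2hiint(value);
    lo = __shfl_sync(mask, lo, src_lane, width);
    hi = __shfl_sync(mask, hi, src_lane, width);
    return __hiloint2double(hi, lo);
  }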
// Wrapper for __ldg.
template <typename T>
__host__ __device__ T CudaLdg(const T* address) {
#if __CUDA_ARCH__ >= 350
return __ldg(address);
#else
return *address;
#endif
}

__host__ __device__ inline bool CudaLdg(const bool* address) {
return CudaLdg(reinterpret_cast<const char*>(address)) != 0;
}

__host__ __device__ inline std::complex<float> CudaLdg(
const std::complex<float>* address) {
#if __CUDA_ARCH__ >= 350
float2 mem = __ldg(reinterpret_cast<const float2*>(address));
return std::complex<float>(mem.x, mem.y);
#else
return *address;
#endif
}

__host__ __device__ inline std::complex<double> CudaLdg(
const std::complex<double>* address) {
#if __CUDA_ARCH__ >= 350
double2 mem = __ldg(reinterpret_cast<const double2*>(address));
return std::complex<double>(mem.x, mem.y);
#else
return *address;
#endif
}

// Zeroes count elements starting at ptr using all threads of a 1-D grid.
// Note: this function does not synchronize, and therefore the memory range is
// not guaranteed to be zero until the next kernel launch.
template <typename T>
__global__ void SetZero(const int count, T* ptr) {
// Check that the grid is one dimensional and index doesn't overflow.
assert(blockDim.y == 1 && blockDim.z == 1);
assert(blockDim.x * gridDim.x / blockDim.x == gridDim.x);
for (int i : CudaGridRangeX(count)) {
ptr[i] = T(0);
}
}

namespace detail {
// Helper function for atomic accumulation implemented as CAS.
template <typename T, typename F>
__device__ T CudaAtomicCasHelper(T* ptr, F accumulate) {
T old = *ptr;
T assumed;
do {
assumed = old;
old = atomicCAS(ptr, assumed, accumulate(assumed));
} while (assumed != old);
return old;
}

// Overload for floating point (using integer comparison to handle NaN
// correctly).
template <typename F>
__device__ float CudaAtomicCasHelper(float* ptr, F accumulate) {
return __float_as_int(
CudaAtomicCasHelper(reinterpret_cast<int32*>(ptr), [accumulate](int32 a) {
return __float_as_int(accumulate(__int_as_float(a)));
}));
}
template <typename F>
__device__ double CudaAtomicCasHelper(double* ptr, F accumulate) {
return __longlong_as_double(CudaAtomicCasHelper(
reinterpret_cast<tensorflow::uint64*>(ptr),
[accumulate](tensorflow::uint64 a) {
return __double_as_longlong(accumulate(__longlong_as_double(a)));
}));
}
} // namespace detail
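For illustration only (not part of this change): CudaAtomicCasHelper turns atomicCAS into a generic read-modify-write loop. The same pattern specialized by hand to an atomic max on float, an op CUDA does not provide natively, shows what the helper expands to for one concrete accumulate functor:

  __device__ float AtomicMaxFloat(float* ptr, float value) {
    int* addr = reinterpret_cast<int*>(ptr);
    int old = *addr;
    int assumed;
    do {
      assumed = old;
      const float updated = fmaxf(__int_as_float(assumed), value);
      old = atomicCAS(addr, assumed, __float_as_int(updated));
    } while (assumed != old);
    return __int_as_float(old);  // previous value, like the integer atomicMax
  }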
// CUDA provides atomic ops, but not for all types. We provide wrappers
// for some ops and provide implementation for all reasonable types.

template <typename T>
__device__ T CudaAtomicAdd(T* ptr, T value) {
return atomicAdd(ptr, value);
}
#if __CUDA_ARCH__ < 600
__device__ inline double CudaAtomicAdd(double* ptr, double value) {
return detail::CudaAtomicCasHelper(ptr,
[value](double a) { return a + value; });
}
#elif __clang__
// Clang cannot compile __nvvm_atom_add_gen_d builtin yet, use inline PTX.
// see https://reviews.llvm.org/D39638
__device__ inline double CudaAtomicAdd(double* ptr, double value) {
double result;
asm volatile("atom.add.f64 %0, [%1], %2;"
: "=d"(result)
: "l"(ptr), "d"(value)
: "memory");
return result;
}
#endif

template <typename T>
__device__ T CudaAtomicSub(T* ptr, T value) {
return atomicSub(ptr, value);
}
// Specializations of subtraction which add the negative value.
__device__ inline float CudaAtomicSub(float* ptr, float value) {
return CudaAtomicAdd(ptr, -value);
}
__device__ inline double CudaAtomicSub(double* ptr, double value) {
return CudaAtomicAdd(ptr, -value);
}
__device__ inline tensorflow::uint64 CudaAtomicSub(tensorflow::uint64* ptr,
tensorflow::uint64 value) {
return CudaAtomicAdd(ptr, -value);
}

template <typename T>
__device__ T CudaAtomicMax(T* ptr, T value) {
return atomicMax(ptr, value);
}
#if __CUDA_ARCH__ < 320
__device__ inline tensorflow::uint64 CudaAtomicMax(tensorflow::uint64* ptr,
tensorflow::uint64 value) {
return detail::CudaAtomicCasHelper(
ptr, [value](tensorflow::uint64 a) { return max(a, value); });
}
#endif

template <typename T>
__device__ inline T CudaAtomicMul(T* ptr, T value) {
return detail::CudaAtomicCasHelper(ptr, [value](T a) { return a * value; });
}
template <typename T>
__device__ inline T CudaAtomicDiv(T* ptr, T value) {
return detail::CudaAtomicCasHelper(ptr, [value](T a) { return a / value; });
}

} // namespace tensorflow

#endif // GOOGLE_CUDA
#endif // TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_
@@ -18,125 +18,299 @@ limitations under the License.

#if GOOGLE_CUDA

#include "tensorflow/core/util/cuda_device_functions.h"
#include "tensorflow/core/util/cuda_launch_config.h"
#include <algorithm>

// Deprecated, use 'for(int i : CudaGridRangeX(n))' instead.
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i : ::tensorflow::CudaGridRangeX<int>(n))
// Deprecated, use 'for(int i : CudaGridRange?(n))' instead.
#define CUDA_AXIS_KERNEL_LOOP(i, n, axis) \
for (int i : ::tensorflow::CudaGridRange##axis<int>(n))
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "cuda/include/cuda.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
template <typename T>
__host__ __device__ inline T ldg(const T* ptr) {
return CudaLdg(ptr);
}
// Mask for all 32 threads in a warp.
#define CUDA_WARP_ALL 0xFFFFFFFF

template <typename T>
__host__ __device__ inline const T& tf_min(const T& x, const T& y) {
return x < y ? x : y;
}
#if defined(CUDA_VERSION) && CUDA_VERSION < 9000
// CUDA 9.0 introduces a new, light-weight barrier synchronization primitive
// that operates at the warp-scope. This is required to ensure visibility of
// reads/writes among threads that can make independent progress on Volta.
// For previous CUDA versions these synchronizations are not necessary, and we
// define an empty function as a convenience for backward compatibility.
__device__ inline void __syncwarp(unsigned mask = CUDA_WARP_ALL) {}

template <typename T>
__host__ __device__ inline const T& tf_max(const T& x, const T& y) {
return x < y ? y : x;
}
// CUDA 9.0 deprecates the warp-intrinsic functions (shfl, ballot, etc.) in
// favor of synchronizing versions. These ensure that all warp lanes specified
// in mask execute the intrinsic in convergence. Here we provide legacy mappings
// to the less-verbose routines provided in previous versions of CUDA.
#define __ballot_sync(mask, predicate) __ballot(predicate)
#define __shfl_sync(mask, val, srcLane, width) __shfl(val, srcLane, width)
#define __shfl_down_sync(mask, val, delta, width) __shfl_down(val, delta, width)
#define __shfl_up_sync(mask, val, delta, width) __shfl_up(val, delta, width)
#define __shfl_xor_sync(mask, val, laneMask, width) \
__shfl_xor(val, laneMask, width)
#endif

// Overloads of the above functions for float and double.
__host__ __device__ inline float tf_min(float x, float y) {
return fminf(x, y);
}
__host__ __device__ inline double tf_min(double x, double y) {
return fmin(x, y);
}
__host__ __device__ inline float tf_max(float x, float y) {
return fmaxf(x, y);
}
__host__ __device__ inline double tf_max(double x, double y) {
return fmax(x, y);
}
__device__ inline Eigen::half CudaShuffleSync(unsigned mask, Eigen::half value,
int src_lane,
int width = warpSize) {
return Eigen::half(
CudaShuffleSync(mask, static_cast<uint16>(value), src_lane, width));
}

__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleUpSync(
unsigned mask, Eigen::half value, int delta, int width = warpSize) {
return Eigen::half(
CudaShuffleUpSync(mask, static_cast<uint16>(value), delta, width));
}

__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleDownSync(
unsigned mask, Eigen::half value, int delta, int width = warpSize) {
return Eigen::half(
CudaShuffleDownSync(mask, static_cast<uint16>(value), delta, width));
}

__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleXorSync(
unsigned mask, Eigen::half value, int lane_mask, int width = warpSize) {
return Eigen::half(
CudaShuffleXorSync(mask, static_cast<uint16>(value), lane_mask, width));
}

namespace detail {
// Overload of above function for half. Note that we don't have
// atomicCAS() for anything less than 32 bits, so we need to include the
// other 16 bits in the operation.
// Usage of GetCudaLaunchConfig, GetCuda2DLaunchConfig, and
// GetCuda3DLaunchConfig:
//
// This version is going to be very slow
// under high concurrency, since most threads will be spinning on failing
// their compare-and-swap tests. (The fact that we get false sharing on the
// neighboring fp16 makes this even worse.) If you are doing a large reduction,
// you are much better off with doing the intermediate steps in fp32 and then
// switching to fp16 as late as you can in the calculations.
// There are two versions of GetCudaLaunchConfig and GetCuda2DLaunchConfig, one
// version uses heuristics without any knowledge of the device kernel, the other
// version uses cudaOccupancyMaxPotentialBlockSize to determine the theoretical
// launch parameters that maximize occupancy. Currently, only the maximum
// occupancy version of GetCuda3DLaunchConfig is available.
//
// Note: Assumes little endian.
template <typename F>
__device__ Eigen::half CudaAtomicCasHelper(Eigen::half* ptr, F accumulate) {
namespace half_impl = Eigen::half_impl;
intptr_t intptr = reinterpret_cast<intptr_t>(ptr);
if (intptr & 0x3) {
assert(!(intptr & 0x1));
// The half is in the second part of the uint32 (upper 16 bits).
uint32* address = reinterpret_cast<uint32*>(intptr - 2);
uint32 result = CudaAtomicCasHelper(address, [accumulate](uint32 a) {
Eigen::half acc = accumulate(
half_impl::__half_raw{static_cast<unsigned short>(a >> 16)});
uint32_t upper = static_cast<half_impl::__half_raw>(acc).x;
return (upper << 16) | (a & 0xffff);
});
return half_impl::__half_raw{static_cast<uint16>(result >> 16)};
} else {
// The half is in the first part of the uint32 (lower 16 bits).
uint32* address = reinterpret_cast<uint32*>(intptr);
uint32 result = CudaAtomicCasHelper(address, [accumulate](uint32 a) {
Eigen::half acc = accumulate(
half_impl::__half_raw{static_cast<unsigned short>(a & 0xffff)});
uint32_t lower = static_cast<half_impl::__half_raw>(acc).x;
return (a & 0xffff0000) | lower;
});
return half_impl::__half_raw{static_cast<uint16>(result & 0xffff)};
// For large number of work elements, the convention is that each kernel would
// iterate through its assigned range. The return value of GetCudaLaunchConfig
// is struct CudaLaunchConfig, which contains all the information needed for the
// kernel launch, including: virtual number of threads, the number of threads
// per block and the number of blocks used inside <<< >>> of a kernel
// launch. GetCuda2DLaunchConfig and GetCuda3DLaunchConfig do the same thing
// as GetCudaLaunchConfig. The only difference is the dimension. The macros
// CUDA_1D_KERNEL_LOOP and CUDA_AXIS_KERNEL_LOOP can be used for the inner loop.
//
/* Sample code:

__global__ void MyKernel1D(CudaLaunchConfig config, other_args...) {
CUDA_1D_KERNEL_LOOP(x, config.virtual_thread_count) {
do_your_job_here;
}
}
} // namespace detail

__device__ inline Eigen::half CudaAtomicAdd(Eigen::half* ptr,
Eigen::half value) {
return detail::CudaAtomicCasHelper(
ptr, [value](Eigen::half a) { return a + value; });
__global__ void MyKernel2D(Cuda2DLaunchConfig config, other_args...) {
CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
do_your_job_here;
}
}
}
__device__ inline Eigen::half CudaAtomicSub(Eigen::half* ptr,
Eigen::half value) {
return detail::CudaAtomicCasHelper(
ptr, [value](Eigen::half a) { return a - value; });

__global__ void MyKernel3D(Cuda3DLaunchConfig config, other_args...) {
CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) {
do_your_job_here;
}
}
}
}

void MyDriverFunc(const GPUDevice &d) {
// use heuristics
CudaLaunchConfig cfg1 = GetCudaLaunchConfig(10240, d);
MyKernel1D <<<config.block_count,
config.thread_per_block, 0, d.stream()>>> (cfg1, other_args...);
Cuda2DLaunchConfig cfg2 = GetCuda2DLaunchConfig(10240, 10240, d);
MyKernel2D <<<config.block_count,
config.thread_per_block, 0, d.stream()>>> (cfg2, other_args...);
Cuda3DLaunchConfig cfg3 = GetCuda3DLaunchConfig(4096, 4096, 100, d);
MyKernel3D <<<config.block_count,
config.thread_per_block, 0, d.stream()>>> (cfg3, other_args...);

// maximize occupancy
CudaLaunchConfig cfg4 = GetCudaLaunchConfig(10240, d, MyKernel1D, 0, 0 );
MyKernel1D <<<config.block_count,
config.thread_per_block, 0, d.stream()>>> (cfg4, other_args...);
Cuda2DLaunchConfig cfg5 = GetCuda2DLaunchConfig(10240, 10240, d,
MyKernel1D, 0, 0);
MyKernel2D <<<config.block_count,
config.thread_per_block, 0, d.stream()>>> (cfg5, other_args...);
Cuda3DLaunchConfig cfg6 = GetCuda3DLaunchConfig(4096, 4096, 100, d,
MyKernel1D, 0, 0);
MyKernel3D <<<config.block_count,
config.thread_per_block, 0, d.stream()>>> (cfg6, other_args...);
}

// See the test for this for more examples:
//
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/cuda_kernel_helper_test.cu.cc

*/
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
i += blockDim.x * gridDim.x)

#define CUDA_AXIS_KERNEL_LOOP(i, n, axis) \
for (int i = blockIdx.axis * blockDim.axis + threadIdx.axis; i < n.axis; \
i += blockDim.axis * gridDim.axis)

#define DIV_UP(a, b) (((a) + (b)-1) / (b))

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

struct CudaLaunchConfig {
// Logical number of threads that work on the elements. If each logical
// thread works on exactly a single element, this is the same as the working
// element count.
int virtual_thread_count = -1;
// Number of threads per block.
int thread_per_block = -1;
// Number of blocks for Cuda kernel launch.
int block_count = -1;
};

// Calculate the Cuda launch config we should use for a kernel launch.
// This is assuming the kernel is quite simple and will largely be
// memory-limited.
// REQUIRES: work_element_count > 0.
inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
const GPUDevice& d) {
CHECK_GT(work_element_count, 0);
CudaLaunchConfig config;
const int virtual_thread_count = work_element_count;
const int physical_thread_count = std::min(
d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor(),
virtual_thread_count);
const int thread_per_block = std::min(1024, d.maxCudaThreadsPerBlock());
const int block_count =
std::min(DIV_UP(physical_thread_count, thread_per_block),
d.getNumCudaMultiProcessors());

config.virtual_thread_count = virtual_thread_count;
config.thread_per_block = thread_per_block;
config.block_count = block_count;
return config;
}
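For illustration only (not part of this change): a worked example of the heuristic above, assuming a hypothetical device with 56 multiprocessors, 2048 resident threads per multiprocessor, a 1024-thread block limit, and work_element_count = 10,000,000:

  physical_thread_count = min(56 * 2048, 10000000)      = 114688
  thread_per_block      = min(1024, 1024)               = 1024
  block_count           = min(DIV_UP(114688, 1024), 56) = min(112, 56) = 56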
// Calculate the Cuda launch config we should use for a kernel launch. This
// variant takes the resource limits of func into account to maximize occupancy.
// REQUIRES: work_element_count > 0.
template <typename DeviceFunc>
inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
const GPUDevice& d, DeviceFunc func,
size_t dynamic_shared_memory_size,
int block_size_limit) {
CHECK_GT(work_element_count, 0);
CudaLaunchConfig config;
int block_count = 0;
int thread_per_block = 0;

cudaError_t err = cudaOccupancyMaxPotentialBlockSize(
&block_count, &thread_per_block, func, dynamic_shared_memory_size,
block_size_limit);
CHECK_EQ(err, cudaSuccess);

block_count =
std::min(block_count, DIV_UP(work_element_count, thread_per_block));

config.virtual_thread_count = work_element_count;
config.thread_per_block = thread_per_block;
config.block_count = block_count;
return config;
}

struct Cuda2DLaunchConfig {
dim3 virtual_thread_count = dim3(0, 0, 0);
dim3 thread_per_block = dim3(0, 0, 0);
dim3 block_count = dim3(0, 0, 0);
};

inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(int xdim, int ydim,
const GPUDevice& d) {
Cuda2DLaunchConfig config;

if (xdim <= 0 || ydim <= 0) {
return config;
}

const int kThreadsPerBlock = 256;
int block_cols = std::min(xdim, kThreadsPerBlock);
// ok to round down here and just do more loops in the kernel
int block_rows = std::max(kThreadsPerBlock / block_cols, 1);

const int physical_thread_count =
d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor();

const int max_blocks = std::max(physical_thread_count / kThreadsPerBlock, 1);

config.virtual_thread_count = dim3(xdim, ydim, 1);
config.thread_per_block = dim3(block_cols, block_rows, 1);

int grid_x = std::min(DIV_UP(xdim, block_cols), max_blocks);

config.block_count = dim3(
grid_x, std::min(max_blocks / grid_x, std::max(ydim / block_rows, 1)), 1);
return config;
}

// Calculate the Cuda 2D and 3D launch config we should use for a kernel launch.
// This variant takes the resource limits of func into account to maximize
// occupancy.
using Cuda3DLaunchConfig = Cuda2DLaunchConfig;

template <typename DeviceFunc>
inline Cuda3DLaunchConfig GetCuda3DLaunchConfig(
int xdim, int ydim, int zdim, const GPUDevice& d, DeviceFunc func,
size_t dynamic_shared_memory_size, int block_size_limit) {
Cuda3DLaunchConfig config;

if (xdim <= 0 || ydim <= 0 || zdim <= 0) {
return config;
}

int dev;
cudaGetDevice(&dev);
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, dev);
int xthreadlimit = deviceProp.maxThreadsDim[0];
int ythreadlimit = deviceProp.maxThreadsDim[1];
int zthreadlimit = deviceProp.maxThreadsDim[2];
int xgridlimit = deviceProp.maxGridSize[0];
int ygridlimit = deviceProp.maxGridSize[1];
int zgridlimit = deviceProp.maxGridSize[2];

int block_count = 0;
int thread_per_block = 0;
cudaError_t err = cudaOccupancyMaxPotentialBlockSize(
&block_count, &thread_per_block, func, dynamic_shared_memory_size,
block_size_limit);
CHECK_EQ(err, cudaSuccess);

#define MIN3(a, b, c) std::min((a), std::min((b), (c)))
int threadsx = MIN3(xdim, thread_per_block, xthreadlimit);
int threadsy =
MIN3(ydim, std::max(thread_per_block / threadsx, 1), ythreadlimit);
int threadsz =
MIN3(zdim, std::max(thread_per_block / (threadsx * threadsy), 1),
zthreadlimit);

int blocksx = MIN3(block_count, DIV_UP(xdim, threadsx), xgridlimit);
int blocksy =
MIN3(DIV_UP(block_count, blocksx), DIV_UP(ydim, threadsy), ygridlimit);
int blocksz = MIN3(DIV_UP(block_count, (blocksx * blocksy)),
DIV_UP(zdim, threadsz), zgridlimit);
#undef MIN3

config.virtual_thread_count = dim3(xdim, ydim, zdim);
config.thread_per_block = dim3(threadsx, threadsy, threadsz);
config.block_count = dim3(blocksx, blocksy, blocksz);
return config;
}
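For illustration only (not part of this change): both occupancy-based helpers above delegate to cudaOccupancyMaxPotentialBlockSize from the CUDA runtime. A minimal direct call, with MyKernel as a hypothetical kernel:

  __global__ void MyKernel(int n, float* data);

  void QueryLaunchShape() {
    int min_grid_size = 0;
    int block_size = 0;
    // block_size maximizes occupancy for MyKernel; min_grid_size is the
    // smallest grid that can reach that occupancy.
    cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, MyKernel,
                                       /*dynamicSMemSize=*/0,
                                       /*blockSizeLimit=*/0);
  }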
template <typename DeviceFunc>
inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(
int xdim, int ydim, const GPUDevice& d, DeviceFunc func,
size_t dynamic_shared_memory_size, int block_size_limit) {
return GetCuda3DLaunchConfig(xdim, ydim, 1, d, func,
dynamic_shared_memory_size, block_size_limit);
}

// Returns a raw reference to the current cuda stream. Required by a
// number of kernel calls (for which StreamInterface* does not work), i.e.
// CUB and certain cublas primitives.
inline const cudaStream_t& GetCudaStream(OpKernelContext* context) {
const cudaStream_t* ptr = CHECK_NOTNULL(
reinterpret_cast<const cudaStream_t*>(context->op_device_context()
->stream()
->implementation()
->CudaStreamMemberHack()));
return *ptr;
}

namespace cuda_helper {

template <typename IntType>
__device__ IntType upper_bound(IntType* first, IntType count, IntType val) {
IntType* orig = first;
@@ -156,8 +330,481 @@ __device__ IntType upper_bound(IntType* first, IntType count, IntType val) {

return first - orig;
}

} // namespace cuda_helper

template <typename T>
__device__ __host__ inline T ldg(const T* address) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
return __ldg(address);
#else
return *address;
#endif
}

template <>
__device__ __host__ inline std::complex<float> ldg(
const std::complex<float>* address) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
float2 mem = __ldg(reinterpret_cast<const float2*>(address));
return std::complex<float>(mem.x, mem.y);
#else
return *address;
#endif
}

template <>
__device__ __host__ inline std::complex<double> ldg(
const std::complex<double>* address) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
double2 mem = __ldg(reinterpret_cast<const double2*>(address));
return std::complex<double>(mem.x, mem.y);
#else
return *address;
#endif
}

template <>
__device__ __host__ inline Eigen::half ldg(const Eigen::half* address) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
return Eigen::half_impl::raw_uint16_to_half(
__ldg(reinterpret_cast<const uint16_t*>(address)));
#else
return *address;
#endif
}

template <>
__device__ __host__ inline bool ldg(const bool* address) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
return *reinterpret_cast<const bool*>(
__ldg(reinterpret_cast<const char*>(address)));
#else
return *address;
#endif
}

// CUDA provides atomic ops, but not for all types. We provide wrappers
// for some ops and provide implementation for all reasonable types.
#define CUDA_ATOMIC_WRAPPER(op, T) \
__device__ __forceinline__ T CudaAtomic##op(T* address, T val)

#define USE_CUDA_ATOMIC(op, T) \
CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); }

// For atomicAdd.
USE_CUDA_ATOMIC(Add, int32);
USE_CUDA_ATOMIC(Add, uint32);
USE_CUDA_ATOMIC(Add, uint64);
USE_CUDA_ATOMIC(Add, float);

// For atomicMax.
USE_CUDA_ATOMIC(Max, int32);
USE_CUDA_ATOMIC(Max, uint32);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
USE_CUDA_ATOMIC(Max, uint64);
#else
// The uint64 overload of atomicMax() is only available for __CUDA_ARCH__ >=
// 350. If not satisfied, we provide a custom implementation using atomicCAS().
CUDA_ATOMIC_WRAPPER(Max, uint64) {
uint64* address_as_ull = reinterpret_cast<uint64*>(address);
uint64 old = *address_as_ull, assumed;

do {
assumed = old;
old = atomicCAS(address_as_ull, assumed, max(val, assumed));
} while (assumed != old);

return old;
}
#endif

// Custom implementation of atomicAdd for double.
// This implementation is copied from CUDA manual.
CUDA_ATOMIC_WRAPPER(Add, double) {
uint64* address_as_ull = reinterpret_cast<uint64*>(address);
uint64 old = *address_as_ull, assumed;

do {
assumed = old;
old = atomicCAS(address_as_ull, assumed,
__double_as_longlong(val + __longlong_as_double(assumed)));

// Note: uses integer comparison to avoid hang in case of NaN
} while (assumed != old);

return __longlong_as_double(old);
}

// Custom implementation of atomicAdd for std::complex<float>.
// This implementation performs two atomic additions on the components.
CUDA_ATOMIC_WRAPPER(Add, std::complex<float>) {
#if defined(__CUDA_ARCH__)
#if __CUDA_ARCH__ >= 350
float2* addr_as_float2 = reinterpret_cast<float2*>(address);
float2* val_as_float2 = reinterpret_cast<float2*>(&val);
CudaAtomicAdd(&(addr_as_float2->x), val_as_float2->x);
CudaAtomicAdd(&(addr_as_float2->y), val_as_float2->y);
#else
static_assert(sizeof(std::complex<float>) == 2 * sizeof(float),
"Unable to compile CudaAtomicAdd for complex64 because "
"sizeof(complex64) != 2*sizeof(float32)");
float* addr_as_float = reinterpret_cast<float*>(address);
float* val_as_float = reinterpret_cast<float*>(&val);
CudaAtomicAdd(addr_as_float, *val_as_float);
CudaAtomicAdd(addr_as_float + 1, *(val_as_float + 1));
#endif
#endif
return *address;
}

// Custom implementation of atomicAdd for std::complex<double>.
// This implementation performs two atomic additions on the components
// using the double atomic wrapper above.
CUDA_ATOMIC_WRAPPER(Add, complex128) {
#if defined(__CUDA_ARCH__)
#if __CUDA_ARCH__ >= 350
double2* addr_as_double2 = reinterpret_cast<double2*>(address);
double2* val_as_double2 = reinterpret_cast<double2*>(&val);
CudaAtomicAdd(&(addr_as_double2->x), val_as_double2->x);
CudaAtomicAdd(&(addr_as_double2->y), val_as_double2->y);
#else
static_assert(sizeof(std::complex<double>) == 2 * sizeof(double),
"Unable to compile CudaAtomicAdd for complex128 because "
"sizeof(complex128) != 2*sizeof(float64)");
double* addr_as_double = reinterpret_cast<double*>(address);
double* val_as_double = reinterpret_cast<double*>(&val);
CudaAtomicAdd(addr_as_double, *val_as_double);
CudaAtomicAdd(addr_as_double + 1, *(val_as_double + 1));
#endif
#endif
return *address;
}

// Helper functions for CudaAtomicAdd(half*, half), below.
//
// Note that if __CUDA_ARCH__ >= 530, we could probably use __hadd2()
// for a more efficient implementation, assuming that adding -0.0
// will never harm the neighboring value. In this version, we take special
// care to guarantee the bits of the untouched value are unchanged.
inline __device__ uint32 add_to_low_half(uint32 val, float x) {
Eigen::half low_half;
low_half.x = static_cast<uint16>(val & 0xffffu);
low_half = static_cast<Eigen::half>(static_cast<float>(low_half) + x);
return (val & 0xffff0000u) | low_half.x;
}

inline __device__ uint32 add_to_high_half(uint32 val, float x) {
Eigen::half high_half;
high_half.x = static_cast<uint16>(val >> 16);
high_half = static_cast<Eigen::half>(static_cast<float>(high_half) + x);
return (val & 0xffffu) | (high_half.x << 16);
}

// Custom implementation of atomicAdd for half. Note that we don't have
// atomicCAS() for anything less than 32 bits, so we need to include the
// other 16 bits in the operation.
//
// Unlike the other atomic adds, this version is going to be very slow
// under high concurrency, since most threads will be spinning on failing
// their compare-and-swap tests. (The fact that we get false sharing on the
// neighboring fp16 makes this even worse.) If you are doing a large reduction,
// you are much better off with doing the intermediate steps in fp32 and then
// switching to fp16 as late as you can in the calculations.
//
// Note: Assumes little endian.
CUDA_ATOMIC_WRAPPER(Add, Eigen::half) {
float val_as_float(val);
intptr_t address_int = reinterpret_cast<intptr_t>(address);
if ((address_int & 0x2) == 0) {
// The half is in the first part of the uint32 (lower 16 bits).
uint32* address_as_uint32 = reinterpret_cast<uint32*>(address);
assert(((intptr_t)address_as_uint32 & 0x3) == 0);
uint32 old = *address_as_uint32, assumed;

do {
assumed = old;
old = atomicCAS(address_as_uint32, assumed,
add_to_low_half(assumed, val_as_float));

// Note: uses integer comparison to avoid hang in case of NaN
} while (assumed != old);

Eigen::half ret;
ret.x = old & 0xffffu;
return ret;
} else {
// The half is in the second part of the uint32 (upper 16 bits).
uint32* address_as_uint32 = reinterpret_cast<uint32*>(address_int - 2);
assert(((intptr_t)address_as_uint32 & 0x3) == 0);
uint32 old = *address_as_uint32, assumed;

do {
assumed = old;
old = atomicCAS(address_as_uint32, assumed,
add_to_high_half(assumed, val_as_float));

// Note: uses integer comparison to avoid hang in case of NaN
} while (assumed != old);

Eigen::half ret;
ret.x = old >> 16;
return ret;
}
}
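For illustration only (not part of this change): the CAS spin above exists because 16-bit atomics were unavailable. Newer toolchains do expose a native half-precision atomicAdd (to my knowledge sm_70+ with CUDA 10 or later); a hedged sketch using the CUDA __half type rather than Eigen::half:

  #include <cuda_fp16.h>

  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
  __device__ __half AtomicAddHalf(__half* address, __half val) {
    return atomicAdd(address, val);  // hardware fp16 atomic, no CAS loop needed
  }
  #endif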
template <typename T>
__global__ void SetZero(const int nthreads, T* bottom_diff) {
CUDA_1D_KERNEL_LOOP(index, nthreads) { *(bottom_diff + index) = T(0); }
}

// For atomicSub.

// Custom implementation for sub by just negating the value.
#define WRAPPED_ATOMIC_SUB(T) \
CUDA_ATOMIC_WRAPPER(Sub, T) { return CudaAtomicAdd(address, -val); }

WRAPPED_ATOMIC_SUB(uint64);
WRAPPED_ATOMIC_SUB(int32);
WRAPPED_ATOMIC_SUB(uint32);
WRAPPED_ATOMIC_SUB(Eigen::half);
WRAPPED_ATOMIC_SUB(float);
WRAPPED_ATOMIC_SUB(double);

CUDA_ATOMIC_WRAPPER(Sub, complex64) {
const std::complex<float> Tneg(-val.real(), -val.imag());
return CudaAtomicAdd(address, Tneg);
}

CUDA_ATOMIC_WRAPPER(Sub, complex128) {
const std::complex<double> Tneg(-val.real(), -val.imag());
return CudaAtomicAdd(address, Tneg);
}

#undef WRAPPED_ATOMIC_SUB
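For illustration only (not part of this change): after the CUDA_ATOMIC_WRAPPER macro above is applied, WRAPPED_ATOMIC_SUB(float) expands to an ordinary negate-and-add wrapper:

  __device__ __forceinline__ float CudaAtomicSub(float* address, float val) {
    return CudaAtomicAdd(address, -val);
  }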
// For atomicMul.
|
||||
CUDA_ATOMIC_WRAPPER(Mul, int32) {
|
||||
int32 old = *address, assumed;
|
||||
do {
|
||||
assumed = old;
|
||||
old = atomicCAS(address, assumed, val * assumed);
|
||||
} while (assumed != old);
|
||||
return old;
|
||||
}
|
||||
|
||||
CUDA_ATOMIC_WRAPPER(Mul, uint32) {
|
||||
uint32 old = *address, assumed;
|
||||
do {
|
||||
assumed = old;
|
||||
old = atomicCAS(address, assumed, val * assumed);
|
||||
} while (assumed != old);
|
||||
return old;
|
||||
}
|
||||
|
||||
CUDA_ATOMIC_WRAPPER(Mul, uint64) {
|
||||
uint64 old = *address, assumed;
|
||||
do {
|
||||
assumed = old;
|
||||
old = atomicCAS(address, assumed, val * assumed);
|
||||
} while (assumed != old);
|
||||
return old;
|
||||
}
|
||||
|
||||
CUDA_ATOMIC_WRAPPER(Mul, float) {
|
||||
int32* address_as_int = reinterpret_cast<int32*>(address);
|
||||
int32 old = *address_as_int, assumed;
|
||||
do {
|
||||
assumed = old;
|
||||
old = atomicCAS(address_as_int, assumed,
|
||||
__float_as_int(val * __int_as_float(assumed)));
|
||||
} while (assumed != old);
|
||||
return __int_as_float(old);
|
||||
}
|
||||
|
||||
CUDA_ATOMIC_WRAPPER(Mul, double) {
|
||||
uint64* address_as_ull = reinterpret_cast<uint64*>(address);
|
||||
uint64 old = *address_as_ull, assumed;
|
||||
do {
|
||||
assumed = old;
|
||||
old = atomicCAS(address_as_ull, assumed,
|
||||
__double_as_longlong(val * __longlong_as_double(assumed)));
|
||||
} while (assumed != old);
|
||||
return __longlong_as_double(old);
|
||||
}
|
||||
|
||||
// For atomicDiv.
|
||||
CUDA_ATOMIC_WRAPPER(Div, int32) {
|
||||
int32 old = *address, assumed;
|
||||
do {
|
||||
assumed = old;
|
||||
old = atomicCAS(address, assumed, assumed / val);
|
||||
} while (assumed != old);
|
||||
return old;
|
||||
}
|
||||
|
||||
CUDA_ATOMIC_WRAPPER(Div, uint32) {
|
||||
uint32 old = *address, assumed;
|
||||
do {
|
||||
assumed = old;
|
||||
old = atomicCAS(address, assumed, assumed / val);
|
||||
} while (assumed != old);
|
||||
return old;
|
||||
}
|
||||
|
||||
CUDA_ATOMIC_WRAPPER(Div, uint64) {
|
||||
uint64 old = *address, assumed;
|
||||
do {
|
||||
assumed = old;
|
||||
old = atomicCAS(address, assumed, assumed / val);
|
||||
} while (assumed != old);
|
||||
return old;
|
||||
}
|
||||
|
||||
CUDA_ATOMIC_WRAPPER(Div, float) {
|
||||
int32* address_as_int = reinterpret_cast<int32*>(address);
|
||||
int32 old = *address_as_int, assumed;
|
||||
do {
|
||||
assumed = old;
|
||||
old = atomicCAS(address_as_int, assumed,
|
||||
__float_as_int(__int_as_float(assumed) / val));
|
||||
} while (assumed != old);
|
||||
return __int_as_float(old);
|
||||
}
|
||||
|
||||
CUDA_ATOMIC_WRAPPER(Div, double) {
|
||||
uint64* address_as_ull = reinterpret_cast<uint64*>(address);
|
||||
uint64 old = *address_as_ull, assumed;
|
||||
do {
|
||||
assumed = old;
|
||||
old = atomicCAS(address_as_ull, assumed,
|
||||
__double_as_longlong(__longlong_as_double(assumed) / val));
|
||||
} while (assumed != old);
|
||||
return __longlong_as_double(old);
|
||||
}
|
||||
|
||||
#undef USE_CUDA_ATOMIC
|
||||
#undef CUDA_ATOMIC_WRAPPER
|
||||
|
||||
template <typename T>
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T tf_min(const T& x, const T& y) {
|
||||
return x > y ? y : x;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T tf_max(const T& x, const T& y) {
|
||||
return x < y ? y : x;
|
||||
}
|
||||
|
||||
__device__ EIGEN_ALWAYS_INLINE unsigned CudaBallot(unsigned mask,
|
||||
int predicate) {
|
||||
return __ballot_sync(mask, predicate);
|
||||
}

template <typename T>
__device__ EIGEN_ALWAYS_INLINE T CudaShuffle(unsigned mask, T value,
                                             int srcLane,
                                             int width = warpSize) {
  return __shfl_sync(mask, value, srcLane, width);
}

// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
// TODO(csigg): remove when the bug is fixed in the next CUDA release.
__device__ EIGEN_ALWAYS_INLINE double CudaShuffle(unsigned mask, double value,
                                                  int srcLane,
                                                  int width = warpSize) {
  unsigned lo, hi;
  asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
  hi = __shfl_sync(mask, hi, srcLane, width);
  lo = __shfl_sync(mask, lo, srcLane, width);
  asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
  return value;
}

template <typename T>
__device__ EIGEN_ALWAYS_INLINE T CudaShuffleUp(unsigned mask, T value,
                                               int delta,
                                               int width = warpSize) {
  return __shfl_up_sync(mask, value, delta, width);
}

// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
// TODO(csigg): remove when the bug is fixed in the next CUDA release.
__device__ EIGEN_ALWAYS_INLINE double CudaShuffleUp(unsigned mask, double value,
                                                    int delta,
                                                    int width = warpSize) {
  unsigned lo, hi;
  asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
  hi = __shfl_up_sync(mask, hi, delta, width);
  lo = __shfl_up_sync(mask, lo, delta, width);
  asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
  return value;
}

template <typename T>
__device__ EIGEN_ALWAYS_INLINE T CudaShuffleDown(unsigned mask, T value,
                                                 int delta,
                                                 int width = warpSize) {
  return __shfl_down_sync(mask, value, delta, width);
}

__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleDown(
    unsigned mask, Eigen::half value, int delta, int width = warpSize) {
  return Eigen::half(
      __shfl_down_sync(mask, static_cast<uint16>(value), delta, width));
}

// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
// TODO(csigg): remove when the bug is fixed in the next CUDA release.
__device__ EIGEN_ALWAYS_INLINE double CudaShuffleDown(unsigned mask,
                                                      double value, int delta,
                                                      int width = warpSize) {
  unsigned lo, hi;
  asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
  hi = __shfl_down_sync(mask, hi, delta, width);
  lo = __shfl_down_sync(mask, lo, delta, width);
  asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
  return value;
}
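
// CudaShuffleDown is the usual building block for warp-level reductions:
// each step adds the value held `offset` lanes above, halving the active
// range until lane 0 holds the total. A minimal sketch, assuming all 32
// lanes participate (0xffffffffu is the full-warp mask) and WarpSum is a
// hypothetical helper:
//
//   template <typename T>
//   __device__ T WarpSum(T val) {
//     for (int offset = warpSize / 2; offset > 0; offset /= 2) {
//       val += CudaShuffleDown(0xffffffffu, val, offset);
//     }
//     return val;  // only lane 0 holds the complete sum
//   }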

template <typename T>
__device__ EIGEN_ALWAYS_INLINE T CudaShuffleXor(unsigned mask, T value,
                                                int laneMask,
                                                int width = warpSize) {
  return __shfl_xor_sync(mask, value, laneMask, width);
}

__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleXor(
    unsigned mask, Eigen::half value, int laneMask, int width = warpSize) {
  return Eigen::half(
      __shfl_xor_sync(mask, static_cast<uint16>(value), laneMask, width));
}

// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
// TODO(csigg): remove when the bug is fixed in the next CUDA release.
__device__ EIGEN_ALWAYS_INLINE double CudaShuffleXor(unsigned mask,
                                                     double value, int laneMask,
                                                     int width = warpSize) {
  unsigned lo, hi;
  asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
  hi = __shfl_xor_sync(mask, hi, laneMask, width);
  lo = __shfl_xor_sync(mask, lo, laneMask, width);
  asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
  return value;
}
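
// CudaShuffleXor exchanges values between lanes whose IDs differ by
// `laneMask`, which gives a butterfly reduction where every lane, not just
// lane 0, ends up with the full result. A minimal sketch assuming a full
// warp (0xffffffffu mask) and a hypothetical WarpAllSum helper:
//
//   template <typename T>
//   __device__ T WarpAllSum(T val) {
//     for (int lane_mask = warpSize / 2; lane_mask > 0; lane_mask /= 2) {
//       val += CudaShuffleXor(0xffffffffu, val, lane_mask);
//     }
//     return val;  // identical in every participating lane
//   }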

}  // namespace tensorflow

#undef DIV_UP

#endif  // GOOGLE_CUDA

#endif  // TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_

@ -52,11 +52,11 @@ __global__ void Count1D(CudaLaunchConfig config, int bufsize, int* outbuf) {
  }
}
__global__ void Count2D(Cuda2DLaunchConfig config, int bufsize, int* outbuf) {
  CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count.x, X) {
  CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
    if (x < 0) {  // x might overflow when testing extreme case
      break;
    }
    CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count.y, Y) {
    CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
      if (y < 0) {  // y might overflow when testing extreme case
        break;
      }
@ -66,15 +66,15 @@ __global__ void Count2D(Cuda2DLaunchConfig config, int bufsize, int* outbuf) {
  }
}
__global__ void Count3D(Cuda3DLaunchConfig config, int bufsize, int* outbuf) {
  CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count.x, X) {
  CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
    if (x < 0) {  // x might overflow when testing extreme case
      break;
    }
    CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count.y, Y) {
    CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
      if (y < 0) {  // y might overflow when testing extreme case
        break;
      }
      CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count.z, Z) {
      CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) {
        if (z < 0) {  // z might overflow when testing extreme case
          break;
        }
@ -94,7 +94,7 @@ class CudaLaunchConfigTest : public ::testing::Test {
  const int bufsize = 1024;
  int* outbuf = nullptr;
  Eigen::CudaStreamDevice stream;
  Eigen::GpuDevice d = Eigen::GpuDevice(&stream);
  GPUDevice d = GPUDevice(&stream);

  virtual void SetUp() {
    cudaError_t err = cudaMallocManaged(&outbuf, sizeof(int) * bufsize);
@ -1,284 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_CUDA_LAUNCH_CONFIG_H_
#define TENSORFLOW_CORE_UTIL_CUDA_LAUNCH_CONFIG_H_

#if GOOGLE_CUDA

#include <algorithm>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "cuda/include/cuda.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"

// Usage of GetCudaLaunchConfig, GetCuda2DLaunchConfig, and
// GetCuda3DLaunchConfig:
//
// There are two versions of GetCudaLaunchConfig and GetCuda2DLaunchConfig: one
// version uses heuristics without any knowledge of the device kernel, the
// other version uses cudaOccupancyMaxPotentialBlockSize to determine the
// theoretical launch parameters that maximize occupancy. Currently, only the
// maximum-occupancy version of GetCuda3DLaunchConfig is available.
//
// For a large number of work elements, the convention is that each kernel
// iterates through its assigned range. The return value of GetCudaLaunchConfig
// is struct CudaLaunchConfig, which contains all the information needed for
// the kernel launch: the virtual number of threads, the number of threads per
// block, and the number of blocks to pass inside <<< >>> of a kernel launch.
// GetCuda2DLaunchConfig and GetCuda3DLaunchConfig do the same thing as
// GetCudaLaunchConfig; the only difference is the dimensionality. The macros
// CUDA_1D_KERNEL_LOOP and CUDA_AXIS_KERNEL_LOOP can be used to write the inner
// loop.
//
/* Sample code:

__global__ void MyKernel1D(CudaLaunchConfig config, other_args...) {
  CUDA_1D_KERNEL_LOOP(x, config.virtual_thread_count) {
    do_your_job_here;
  }
}

__global__ void MyKernel2D(Cuda2DLaunchConfig config, other_args...) {
  CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
    CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
      do_your_job_here;
    }
  }
}

__global__ void MyKernel3D(Cuda3DLaunchConfig config, other_args...) {
  CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
    CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
      CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) {
        do_your_job_here;
      }
    }
  }
}

void MyDriverFunc(const Eigen::GpuDevice &d) {
  // use heuristics
  CudaLaunchConfig cfg1 = GetCudaLaunchConfig(10240, d);
  MyKernel1D <<<cfg1.block_count,
                cfg1.thread_per_block, 0, d.stream()>>> (cfg1, other_args...);
  Cuda2DLaunchConfig cfg2 = GetCuda2DLaunchConfig(10240, 10240, d);
  MyKernel2D <<<cfg2.block_count,
                cfg2.thread_per_block, 0, d.stream()>>> (cfg2, other_args...);

  // maximize occupancy
  CudaLaunchConfig cfg3 = GetCudaLaunchConfig(10240, d, MyKernel1D, 0, 0);
  MyKernel1D <<<cfg3.block_count,
                cfg3.thread_per_block, 0, d.stream()>>> (cfg3, other_args...);
  Cuda2DLaunchConfig cfg4 = GetCuda2DLaunchConfig(10240, 10240, d,
                                                  MyKernel2D, 0, 0);
  MyKernel2D <<<cfg4.block_count,
                cfg4.thread_per_block, 0, d.stream()>>> (cfg4, other_args...);
  Cuda3DLaunchConfig cfg5 = GetCuda3DLaunchConfig(4096, 4096, 100, d,
                                                  MyKernel3D, 0, 0);
  MyKernel3D <<<cfg5.block_count,
                cfg5.thread_per_block, 0, d.stream()>>> (cfg5, other_args...);
}

// See the test for more examples:
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/cuda_kernel_helper_test.cu.cc

*/

namespace tensorflow {

inline int DivUp(int a, int b) { return (a + b - 1) / b; }

struct CudaLaunchConfig {
  // Logical number of threads that work on the elements. If each logical
  // thread works on exactly a single element, this is the same as the working
  // element count.
  int virtual_thread_count = -1;
  // Number of threads per block.
  int thread_per_block = -1;
  // Number of blocks for Cuda kernel launch.
  int block_count = -1;
};

// Calculate the Cuda launch config we should use for a kernel launch.
// This is assuming the kernel is quite simple and will largely be
// memory-limited.
// REQUIRES: work_element_count > 0.
inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
                                            const Eigen::GpuDevice& d) {
  CHECK_GT(work_element_count, 0);
  CudaLaunchConfig config;
  const int virtual_thread_count = work_element_count;
  const int physical_thread_count = std::min(
      d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor(),
      virtual_thread_count);
  const int thread_per_block = std::min(1024, d.maxCudaThreadsPerBlock());
  const int block_count =
      std::min(DivUp(physical_thread_count, thread_per_block),
               d.getNumCudaMultiProcessors());

  config.virtual_thread_count = virtual_thread_count;
  config.thread_per_block = thread_per_block;
  config.block_count = block_count;
  return config;
}
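
// Worked example of the heuristic above for a hypothetical device with
// 56 SMs, 2048 resident threads per SM, and 1024 threads per block, given
// work_element_count = 1 << 20:
//   physical_thread_count = min(56 * 2048, 1048576)          = 114688
//   thread_per_block      = min(1024, 1024)                  = 1024
//   block_count           = min(DivUp(114688, 1024), 56)     = min(112, 56) = 56
// The launch covers 56 * 1024 = 57344 physical threads, so each thread
// strides over roughly 18 elements inside CUDA_1D_KERNEL_LOOP.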

// Calculate the Cuda launch config we should use for a kernel launch. This
// variant takes the resource limits of func into account to maximize
// occupancy.
// REQUIRES: work_element_count > 0.
template <typename DeviceFunc>
inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
                                            const Eigen::GpuDevice& d,
                                            DeviceFunc func,
                                            size_t dynamic_shared_memory_size,
                                            int block_size_limit) {
  CHECK_GT(work_element_count, 0);
  CudaLaunchConfig config;
  int block_count = 0;
  int thread_per_block = 0;

  cudaError_t err = cudaOccupancyMaxPotentialBlockSize(
      &block_count, &thread_per_block, func, dynamic_shared_memory_size,
      block_size_limit);
  CHECK_EQ(err, cudaSuccess);

  block_count =
      std::min(block_count, DivUp(work_element_count, thread_per_block));

  config.virtual_thread_count = work_element_count;
  config.thread_per_block = thread_per_block;
  config.block_count = block_count;
  return config;
}

struct Cuda2DLaunchConfig {
  dim3 virtual_thread_count = dim3(0, 0, 0);
  dim3 thread_per_block = dim3(0, 0, 0);
  dim3 block_count = dim3(0, 0, 0);
};

inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(int xdim, int ydim,
                                                const Eigen::GpuDevice& d) {
  Cuda2DLaunchConfig config;

  if (xdim <= 0 || ydim <= 0) {
    return config;
  }

  const int kThreadsPerBlock = 256;
  int block_cols = std::min(xdim, kThreadsPerBlock);
  // ok to round down here and just do more loops in the kernel
  int block_rows = std::max(kThreadsPerBlock / block_cols, 1);

  const int physical_thread_count =
      d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor();

  const int max_blocks = std::max(physical_thread_count / kThreadsPerBlock, 1);

  config.virtual_thread_count = dim3(xdim, ydim, 1);
  config.thread_per_block = dim3(block_cols, block_rows, 1);

  int grid_x = std::min(DivUp(xdim, block_cols), max_blocks);

  config.block_count = dim3(
      grid_x, std::min(max_blocks / grid_x, std::max(ydim / block_rows, 1)), 1);
  return config;
}

// Calculate the Cuda 2D and 3D launch config we should use for a kernel
// launch. This variant takes the resource limits of func into account to
// maximize occupancy.
using Cuda3DLaunchConfig = Cuda2DLaunchConfig;

template <typename DeviceFunc>
inline Cuda3DLaunchConfig GetCuda3DLaunchConfig(
    int xdim, int ydim, int zdim, const Eigen::GpuDevice& d, DeviceFunc func,
    size_t dynamic_shared_memory_size, int block_size_limit) {
  Cuda3DLaunchConfig config;

  if (xdim <= 0 || ydim <= 0 || zdim <= 0) {
    return config;
  }

  int dev;
  cudaGetDevice(&dev);
  cudaDeviceProp deviceProp;
  cudaGetDeviceProperties(&deviceProp, dev);
  int xthreadlimit = deviceProp.maxThreadsDim[0];
  int ythreadlimit = deviceProp.maxThreadsDim[1];
  int zthreadlimit = deviceProp.maxThreadsDim[2];
  int xgridlimit = deviceProp.maxGridSize[0];
  int ygridlimit = deviceProp.maxGridSize[1];
  int zgridlimit = deviceProp.maxGridSize[2];

  int block_count = 0;
  int thread_per_block = 0;
  cudaError_t err = cudaOccupancyMaxPotentialBlockSize(
      &block_count, &thread_per_block, func, dynamic_shared_memory_size,
      block_size_limit);
  CHECK_EQ(err, cudaSuccess);

  auto min3 = [](int a, int b, int c) { return std::min(a, std::min(b, c)); };

  int threadsx = min3(xdim, thread_per_block, xthreadlimit);
  int threadsy =
      min3(ydim, std::max(thread_per_block / threadsx, 1), ythreadlimit);
  int threadsz =
      min3(zdim, std::max(thread_per_block / (threadsx * threadsy), 1),
           zthreadlimit);

  int blocksx = min3(block_count, DivUp(xdim, threadsx), xgridlimit);
  int blocksy =
      min3(DivUp(block_count, blocksx), DivUp(ydim, threadsy), ygridlimit);
  int blocksz = min3(DivUp(block_count, (blocksx * blocksy)),
                     DivUp(zdim, threadsz), zgridlimit);

  config.virtual_thread_count = dim3(xdim, ydim, zdim);
  config.thread_per_block = dim3(threadsx, threadsy, threadsz);
  config.block_count = dim3(blocksx, blocksy, blocksz);
  return config;
}

template <typename DeviceFunc>
inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(
    int xdim, int ydim, const Eigen::GpuDevice& d, DeviceFunc func,
    size_t dynamic_shared_memory_size, int block_size_limit) {
  return GetCuda3DLaunchConfig(xdim, ydim, 1, d, func,
                               dynamic_shared_memory_size, block_size_limit);
}

// Returns a raw reference to the current cuda stream. Required by a
// number of kernel calls (for which StreamInterface* does not work), e.g.
// CUB and certain cublas primitives.
inline const cudaStream_t& GetCudaStream(OpKernelContext* context) {
  const cudaStream_t* ptr = CHECK_NOTNULL(
      reinterpret_cast<const cudaStream_t*>(context->op_device_context()
                                                ->stream()
                                                ->implementation()
                                                ->CudaStreamMemberHack()));
  return *ptr;
}
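
// GetCudaStream is what lets CUB and raw cuBLAS calls run on the op's own
// stream instead of the default stream. A minimal sketch, assuming CUB's
// cub::DeviceReduce::Sum interface and hypothetical d_in/d_out/num_items
// buffers inside a Compute() method with an OpKernelContext* context:
//
//   const cudaStream_t& stream = GetCudaStream(context);
//   size_t temp_bytes = 0;
//   // First call only sizes the scratch space; second call runs the sum.
//   cub::DeviceReduce::Sum(nullptr, temp_bytes, d_in, d_out, num_items,
//                          stream);
//   // ... allocate temp_storage of temp_bytes via the TF allocator ...
//   cub::DeviceReduce::Sum(temp_storage, temp_bytes, d_in, d_out, num_items,
//                          stream);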

}  // namespace tensorflow

#endif  // GOOGLE_CUDA

#endif  // TENSORFLOW_CORE_UTIL_CUDA_LAUNCH_CONFIG_H_