[ROCm] Adding support to depthwise_conv_op
parent 4bb0f6e87f
commit 65985751a9
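Note (not part of the diff): the change follows one pattern throughout the op's source file, its header, the GPU kernel header, and the per-type GPU instantiation files. CUDA-only preprocessor guards are widened to `#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM`, and CUDA-specific helpers (CUDA_1D_KERNEL_LOOP, CudaAtomicAdd, CudaShuffleXorSync, CudaBallotSync, cub::LaneId) are replaced by their GPU-generic counterparts (GPU_1D_KERNEL_LOOP, GpuAtomicAdd, GpuShuffleXorSync, GpuBallotSync, GpuLaneId), so the same kernels build under both nvcc and the ROCm/HIP toolchain. A minimal sketch of the nested-guard idiom, assuming GOOGLE_CUDA and TENSORFLOW_USE_ROCM are defined by the respective builds:

    #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM   // code shared by both GPU backends

    #if GOOGLE_CUDA
    #include "third_party/gpus/cudnn/cudnn.h"  // cuDNN stays CUDA-only
    #endif

    #include "tensorflow/core/platform/stream_executor.h"  // used by both backends

    #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM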
@@ -38,10 +38,14 @@ limitations under the License.
 #include "tensorflow/core/util/use_cudnn.h"
 #include "tensorflow/core/util/work_sharder.h"
 
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
 #if GOOGLE_CUDA
 #include "third_party/gpus/cudnn/cudnn.h"
+#endif
+
 #include "tensorflow/core/platform/stream_executor.h"
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 namespace tensorflow {
 
@@ -246,7 +250,7 @@ extern template struct LaunchConv2DOp<CPUDevice, Eigen::half>;
 extern template struct LaunchConv2DOp<CPUDevice, float>;
 extern template struct LaunchConv2DOp<CPUDevice, double>;
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 // Extern template instantiated in conv_ops.cc.
 extern template struct LaunchConv2DOp<GPUDevice, Eigen::half>;
@@ -461,7 +465,7 @@ TF_CALL_float(REGISTER_CPU_KERNEL);
 TF_CALL_double(REGISTER_CPU_KERNEL);
 #endif
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define REGISTER_GPU_KERNEL(T) \
   REGISTER_KERNEL_BUILDER(     \
@@ -494,6 +498,6 @@ TF_CALL_half(REGISTER_GROUPED_CONV_KERNEL);
 TF_CALL_float(REGISTER_GROUPED_CONV_KERNEL);
 TF_CALL_double(REGISTER_GROUPED_CONV_KERNEL);
 #endif  // CUDNN_VERSION
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 }  // namespace tensorflow
@@ -80,7 +80,7 @@ struct LaunchDepthwiseConvBackpropFilterOp {
                   TensorFormat data_format);
 };
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 template <typename T>
 struct LaunchDepthwiseConvOp<Eigen::GpuDevice, T> {
   void operator()(OpKernelContext* ctx, const DepthwiseArgs& args,
@@ -16,11 +16,10 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_KERNELS_DEPTHWISE_CONV_OP_GPU_H_
 #define TENSORFLOW_CORE_KERNELS_DEPTHWISE_CONV_OP_GPU_H_
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "third_party/cub/util_ptx.cuh"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/kernels/depthwise_conv_op.h"
 #include "tensorflow/core/platform/types.h"
@@ -79,7 +78,7 @@ inline EIGEN_DEVICE_FUNC bool CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(
 // convolution depending on a template argument of this enum.
 enum DepthwiseConv2dDirection { DIRECTION_FORWARD, DIRECTION_BACKWARD };
 
-// A Cuda kernel to compute the depthwise convolution forward pass
+// A Gpu kernel to compute the depthwise convolution forward pass
 // in NHWC format.
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
           int kKnownDepthMultiplier>
@@ -103,7 +102,7 @@ __global__ void __launch_bounds__(1024, 2)
   const int out_width = args.out_cols;
   const int out_depth = args.out_depth;
 
-  CUDA_1D_KERNEL_LOOP(thread_id, num_outputs) {
+  GPU_1D_KERNEL_LOOP(thread_id, num_outputs) {
     // Compute the indexes of this thread in the output.
     const int out_channel = thread_id % out_depth;
     const int out_col = (thread_id / out_depth) % out_width;
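Note (not part of the diff): GPU_1D_KERNEL_LOOP is the backend-neutral replacement for CUDA_1D_KERNEL_LOOP. A minimal sketch of the grid-stride idiom it stands for, assuming the conventional formulation rather than TensorFlow's exact helper definitions:

    // Sketch only: a grid-stride loop compiles under both nvcc and hipcc.
    #define GPU_1D_KERNEL_LOOP_SKETCH(i, n)                          \
      for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n);   \
           i += blockDim.x * gridDim.x)

    // Usage inside a kernel: each thread strides over the flat index space.
    __global__ void FillZero(int n, float* out) {
      GPU_1D_KERNEL_LOOP_SKETCH(idx, n) { out[idx] = 0.0f; }
    }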
@@ -192,8 +191,10 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall(
   typedef typename detail::PseudoHalfType<T>::Type S;
   assert(CanLaunchDepthwiseConv2dGPUSmall(args));
   // Holds block plus halo and filter data for blockDim.x depths.
-  GPU_DYNAMIC_SHARED_MEM_DECL(8, unsigned char, shared_memory);
   static_assert(sizeof(S) <= 8, "Insufficient alignment detected");
+
+  GPU_DYNAMIC_SHARED_MEM_DECL(8, unsigned char, shared_memory);
+
   S* const shared_data = reinterpret_cast<S*>(shared_memory);
 
   const int num_batches = args.batch;
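Note (not part of the diff): dynamically sized shared memory is declared differently under CUDA and HIP, which is why a macro such as GPU_DYNAMIC_SHARED_MEM_DECL is used instead of a raw `extern __shared__` declaration. A hedged sketch of what such a macro can expand to (an illustration of the idiom, not TensorFlow's exact definition):

    // Sketch only: portable dynamic shared-memory declaration.
    #if GOOGLE_CUDA
    #define GPU_DYNAMIC_SHARED_MEM_DECL_SKETCH(ALIGN, TYPE, NAME) \
      extern __shared__ __align__(ALIGN) TYPE NAME[]
    #elif TENSORFLOW_USE_ROCM
    #define GPU_DYNAMIC_SHARED_MEM_DECL_SKETCH(ALIGN, TYPE, NAME) \
      HIP_DYNAMIC_SHARED(TYPE, NAME)  // HIP's helper for dynamic shared memory
    #endif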
@@ -323,7 +324,7 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall(
   }
 }
 
-// A Cuda kernel to compute the depthwise convolution forward pass
+// A Gpu kernel to compute the depthwise convolution forward pass
 // in NCHW format.
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
           int kKnownDepthMultiplier>
@@ -347,7 +348,7 @@ __global__ void __launch_bounds__(1024, 2)
   const int out_width = args.out_cols;
   const int out_depth = args.out_depth;
 
-  CUDA_1D_KERNEL_LOOP(thread_id, num_outputs) {
+  GPU_1D_KERNEL_LOOP(thread_id, num_outputs) {
     // Compute the indexes of this thread in the output.
     //
     // We want coalesced reads so we make sure that each warp reads
@@ -480,8 +481,10 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall(
   typedef typename detail::PseudoHalfType<T>::Type S;
   assert(CanLaunchDepthwiseConv2dGPUSmall(args));
   // Holds block plus halo and filter data for blockDim.z depths.
-  GPU_DYNAMIC_SHARED_MEM_DECL(8, unsigned char, shared_memory);
   static_assert(sizeof(S) <= 8, "Insufficient alignment detected");
+
+  GPU_DYNAMIC_SHARED_MEM_DECL(8, unsigned char, shared_memory);
+
   S* const shared_data = reinterpret_cast<S*>(shared_memory);
 
   const int num_batches = args.batch;
@@ -779,7 +782,7 @@ Status LaunchDepthwiseConv2dGPU(OpKernelContext* ctx, const DepthwiseArgs& args,
   }
 }
 
-// A simple launch pad to launch the Cuda kernel for depthwise convolution.
+// A simple launch pad to launch the Gpu kernel for depthwise convolution.
 template <typename T>
 void LaunchDepthwiseConvOp<GpuDevice, T>::operator()(OpKernelContext* ctx,
                                                      const DepthwiseArgs& args,
@@ -795,7 +798,7 @@ void LaunchDepthwiseConvOp<GpuDevice, T>::operator()(OpKernelContext* ctx,
   }
 }
 
-// A Cuda kernel to compute the depthwise convolution backprop w.r.t. input.
+// A GPU kernel to compute the depthwise convolution backprop w.r.t. input.
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
           int kKnownDepthMultiplier>
 __global__ void __launch_bounds__(640, 2)
@@ -819,7 +822,7 @@ __global__ void __launch_bounds__(640, 2)
   const int out_width = args.out_cols;
   const int out_depth = args.out_depth;
 
-  CUDA_1D_KERNEL_LOOP(thread_id, num_in_backprop) {
+  GPU_1D_KERNEL_LOOP(thread_id, num_in_backprop) {
     // Compute the indexes of this thread in the output.
     const int in_channel = thread_id % in_depth;
     const int in_col = (thread_id / in_depth) % in_width;
@@ -891,7 +894,7 @@ __global__ void __launch_bounds__(640, 2)
 
   // TODO(vrv): Consider assigning threads to output and using
   // atomics for accumulation, similar to the filter case.
-  CUDA_1D_KERNEL_LOOP(thread_id, num_in_backprop) {
+  GPU_1D_KERNEL_LOOP(thread_id, num_in_backprop) {
     // Compute the indexes of this thread in the input.
     const int in_col = thread_id % in_width;
    const int in_row = (thread_id / in_width) % in_height;
@@ -998,7 +1001,7 @@ Status LaunchDepthwiseConv2dBackpropInputGPU(OpKernelContext* ctx,
   }
 }
 
-// A simple launch pad to launch the Cuda kernel for depthwise convolution.
+// A simple launch pad to launch the Gpu kernel for depthwise convolution.
 template <typename T>
 void LaunchDepthwiseConvBackpropInputOp<GpuDevice, T>::operator()(
     OpKernelContext* ctx, const DepthwiseArgs& args, const T* out_backprop,
@@ -1014,7 +1017,7 @@ void LaunchDepthwiseConvBackpropInputOp<GpuDevice, T>::operator()(
   }
 }
 
-// A Cuda kernel to compute the depthwise convolution backprop w.r.t. filter.
+// A GPU kernel to compute the depthwise convolution backprop w.r.t. filter.
 // TODO: Add fp32 accumulation to half calls of this function. This addition
 // is non-trivial as the partial sums are added directly to the output
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
@@ -1041,7 +1044,7 @@ __global__ void __launch_bounds__(640, 2)
   const int out_width = args.out_cols;
   const int out_depth = args.out_depth;
 
-  CUDA_1D_KERNEL_LOOP(thread_id, num_out_backprop) {
+  GPU_1D_KERNEL_LOOP(thread_id, num_out_backprop) {
     // Compute the indexes of this thread in the output.
     const int out_channel = thread_id % out_depth;
     const int out_col = (thread_id / out_depth) % out_width;
@@ -1081,7 +1084,7 @@ __global__ void __launch_bounds__(640, 2)
             (dm + depth_multiplier *
                       (in_channel +
                        in_depth * (filter_col + filter_width * filter_row)));
-        CudaAtomicAdd(addr, partial_sum);
+        GpuAtomicAdd(addr, partial_sum);
       }
     }
   } else {
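Note (not part of the diff): GpuAtomicAdd is the backend-neutral replacement for CudaAtomicAdd. A minimal sketch, assuming the plain hardware atomicAdd suffices for the type in question (TensorFlow's real wrapper also handles types such as Eigen::half):

    // Sketch only: atomicAdd is available under both the CUDA and HIP toolchains.
    template <typename T>
    __device__ inline T GpuAtomicAddSketch(T* address, T value) {
      return atomicAdd(address, value);
    }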
@@ -1112,7 +1115,7 @@ __global__ void __launch_bounds__(640, 2)
         // contention on the destination; 2. Have each thread compute one
         // gradient for an element in the filters. This should work well
         // when the input depth is big and filter size is not too small.
-        CudaAtomicAdd(addr, partial_sum);
+        GpuAtomicAdd(addr, partial_sum);
       }
     }
   }
@@ -1123,14 +1126,18 @@ __global__ void __launch_bounds__(640, 2)
 // Device function to compute sub-warp sum reduction for a power-of-two group of
 // neighboring threads.
 template <int kWidth, typename T>
+#if GOOGLE_CUDA
 __device__ __forceinline__ T WarpSumReduce(T val) {
+#elif TENSORFLOW_USE_ROCM
+__device__ inline T WarpSumReduce(T val) {
+#endif
   // support only power-of-two widths.
   assert(__popc(kWidth) == 1);
-  int sub_warp = cub::LaneId() / kWidth;
+  int sub_warp = GpuLaneId() / kWidth;
   int zeros = sub_warp * kWidth;
   unsigned mask = ((1UL << kWidth) - 1) << zeros;
   for (int delta = kWidth / 2; delta > 0; delta /= 2) {
-    val += CudaShuffleXorSync(mask, val, delta);
+    val += GpuShuffleXorSync(mask, val, delta);
   }
   return val;
 }
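Note (not part of the diff): `__forceinline__` is a CUDA-specific qualifier, hence the `#if GOOGLE_CUDA` / `#elif TENSORFLOW_USE_ROCM` split around the WarpSumReduce signature, and the cub lane-id and CUDA shuffle helpers give way to GpuLaneId and GpuShuffleXorSync. A hedged sketch of how such warp helpers are commonly dispatched per backend (illustrative only, not the actual TensorFlow definitions):

    // Sketch only: per-backend warp intrinsics behind a common name.
    #if GOOGLE_CUDA
    __device__ inline unsigned GpuLaneIdSketch() {
      unsigned lane;
      asm("mov.u32 %0, %%laneid;" : "=r"(lane));  // read the PTX lane-id register
      return lane;
    }
    __device__ inline float GpuShuffleXorSyncSketch(unsigned mask, float v, int m) {
      return __shfl_xor_sync(mask, v, m);  // warp-synchronous shuffle (CUDA 9+)
    }
    #elif TENSORFLOW_USE_ROCM
    __device__ inline unsigned GpuLaneIdSketch() {
      return __lane_id();  // HIP lane id; wavefronts are 64 lanes wide
    }
    __device__ inline float GpuShuffleXorSyncSketch(unsigned /*mask*/, float v, int m) {
      return __shfl_xor(v, m);  // HIP shuffle; the sync mask has no direct analogue
    }
    #endif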
@@ -1158,8 +1165,10 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
   typedef typename detail::PseudoHalfType<T>::Type S;
   assert(CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, blockDim.z));
   // Holds block plus halo and filter data for blockDim.x depths.
-  GPU_DYNAMIC_SHARED_MEM_DECL(8, unsigned char, shared_memory);
   static_assert(sizeof(S) <= 8, "Insufficient alignment detected");
+
+  GPU_DYNAMIC_SHARED_MEM_DECL(8, unsigned char, shared_memory);
+
   S* const shared_data = reinterpret_cast<S*>(shared_memory);
 
   const int num_batches = args.batch;
@@ -1253,7 +1262,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
 
   // Note: the condition to reach this is uniform across the entire block.
   __syncthreads();
-  unsigned active_threads = CudaBallotSync(kCudaWarpAll, channel_in_range);
+  unsigned active_threads = GpuBallotSync(kCudaWarpAll, channel_in_range);
 
   if (channel_in_range) {
     const T* const out_ptr = inout_offset + output;
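Note (not part of the diff): GpuBallotSync is the backend-neutral ballot. A simplified sketch of how the `__ballot_sync` / `__ballot` difference can be papered over is below; it is an assumption rather than TensorFlow's definition, and it deliberately ignores that a ROCm wavefront is 64 lanes wide, which real code must account for (for instance by widening the mask and result types):

    // Sketch only: per-backend ballot; truncates the 64-bit HIP result for brevity.
    __device__ inline unsigned GpuBallotSyncSketch(unsigned mask, int predicate) {
    #if GOOGLE_CUDA
      return __ballot_sync(mask, predicate);              // 32-bit warp ballot
    #elif TENSORFLOW_USE_ROCM
      return static_cast<unsigned>(__ballot(predicate));  // mask is unused under HIP
    #endif
    }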
@@ -1268,7 +1277,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
         S val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
         // Warp-accumulate pixels of the same depth and write to accumulator.
         for (int delta = 16; delta >= kBlockDepth; delta /= 2) {
-          val += CudaShuffleXorSync(active_threads, val, delta);
+          val += GpuShuffleXorSync(active_threads, val, delta);
         }
         if (!(thread_idx & 32 - kBlockDepth) /* lane_idx < kBlockDepth */) {
           *accum_ptr = val;
@@ -1294,14 +1303,14 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
       // Warp-accumulate the pixels of the same depth from the accumulator.
       val = WarpSumReduce<kAccumPixels>(val);
       if (!(thread_idx & kAccumPixels - 1)) {
-        CudaAtomicAdd(filter_offset + filter, static_cast<T>(val));
+        GpuAtomicAdd(filter_offset + filter, static_cast<T>(val));
        }
      }
    }
  }
 }
 
-// A Cuda kernel to compute the depthwise convolution backprop w.r.t. filter.
+// A Gpu kernel to compute the depthwise convolution backprop w.r.t. filter.
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
           int kKnownDepthMultiplier>
 __global__ void __launch_bounds__(640, 2)
@@ -1326,7 +1335,7 @@ __global__ void __launch_bounds__(640, 2)
   const int out_width = args.out_cols;
   const int out_depth = args.out_depth;
 
-  CUDA_1D_KERNEL_LOOP(thread_id, num_out_backprop) {
+  GPU_1D_KERNEL_LOOP(thread_id, num_out_backprop) {
     // Compute the indexes of this thread in the output.
     const int out_col = thread_id % out_width;
     const int out_row = (thread_id / out_width) % out_height;
@@ -1370,7 +1379,7 @@ __global__ void __launch_bounds__(640, 2)
             (dm + depth_multiplier *
                       (in_channel +
                        in_depth * (filter_col + filter_width * filter_row)));
-        CudaAtomicAdd(addr, partial_sum);
+        GpuAtomicAdd(addr, partial_sum);
       }
     }
   } else {
@@ -1402,7 +1411,7 @@ __global__ void __launch_bounds__(640, 2)
         // contention on the destination; 2. Have each thread compute one
         // gradient for an element in the filters. This should work well
         // when the input depth is big and filter size is not too small.
-        CudaAtomicAdd(addr, partial_sum);
+        GpuAtomicAdd(addr, partial_sum);
       }
     }
   }
@@ -1521,7 +1530,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
 
   // Note: the condition to reach this is uniform across the entire block.
   __syncthreads();
-  unsigned active_threads = CudaBallotSync(kCudaWarpAll, channel_in_range);
+  unsigned active_threads = GpuBallotSync(kCudaWarpAll, channel_in_range);
 
   if (channel_in_range) {
     const T* const out_ptr = inout_offset + output;
@@ -1536,7 +1545,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
         S val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
         // Warp-accumulate pixels of the same depth and write to accumulator.
         for (int delta = 16 / kBlockDepth; delta > 0; delta /= 2) {
-          val += CudaShuffleXorSync(active_threads, val, delta);
+          val += GpuShuffleXorSync(active_threads, val, delta);
         }
         if (!(thread_idx & 32 / kBlockDepth - 1)) {
           *accum_ptr = val;  // kBlockDepth threads per warp.
@@ -1563,7 +1572,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
       // Warp-accumulate pixels of the same depth from the accumulator.
       val = WarpSumReduce<kAccumPixels>(val);
       if (!(thread_idx & kAccumPixels - 1)) {
-        CudaAtomicAdd(filter_offset + filter, static_cast<T>(val));
+        GpuAtomicAdd(filter_offset + filter, static_cast<T>(val));
       }
     }
   }
@@ -1745,7 +1754,7 @@ Status LaunchDepthwiseConv2dBackpropFilterGPU(
   }
 }
 
-// A simple launch pad to launch the Cuda kernel for depthwise convolution.
+// A simple launch pad to launch the Gpu kernel for depthwise convolution.
 template <typename T>
 void LaunchDepthwiseConvBackpropFilterOp<GpuDevice, T>::operator()(
     OpKernelContext* ctx, const DepthwiseArgs& args, const T* out_backprop,
@@ -1769,6 +1778,6 @@ void LaunchDepthwiseConvBackpropFilterOp<GpuDevice, T>::operator()(
   }
 }
 }  // namespace tensorflow
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #endif  // TENSORFLOW_CORE_KERNELS_DEPTHWISE_CONV_OP_GPU_H_
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
 
 #include "tensorflow/core/kernels/depthwise_conv_op.h"
@@ -27,4 +27,4 @@ template struct LaunchDepthwiseConvBackpropInputOp<GpuDevice, double>;
 template struct LaunchDepthwiseConvBackpropFilterOp<GpuDevice, double>;
 }  // namespace tensorflow
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
 
 #include "tensorflow/core/kernels/depthwise_conv_op.h"
@@ -27,4 +27,4 @@ template struct LaunchDepthwiseConvBackpropInputOp<GpuDevice, float>;
 template struct LaunchDepthwiseConvBackpropFilterOp<GpuDevice, float>;
 }  // namespace tensorflow
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
 
 #include "tensorflow/core/kernels/depthwise_conv_op.h"
@@ -27,4 +27,4 @@ template struct LaunchDepthwiseConvBackpropInputOp<GpuDevice, Eigen::half>;
 template struct LaunchDepthwiseConvBackpropFilterOp<GpuDevice, Eigen::half>;
 }  // namespace tensorflow
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM