Replacing GetCudaLaunchConfig and CudaLaunchKernel with their Gpu equivalents.
PiperOrigin-RevId: 256377258
commit 6f0584298b
parent 4bc0fe3cf4
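
For readers unfamiliar with these helpers, the mechanical pattern applied in every hunk below is roughly the following sketch. This is not code from the commit: SomeKernel, LaunchSomeKernel, and the launch arguments are illustrative placeholders, and the include path is an assumption about where the Gpu-prefixed helpers live.

// Minimal sketch of the rename, assuming GpuLaunchConfig, GetGpuLaunchConfig,
// GpuLaunchKernel, and TF_CHECK_OK are brought into scope by this header.
#include "tensorflow/core/util/gpu_kernel_helper.h"

// Hypothetical element-wise kernel standing in for the kernels touched here.
template <typename T>
__global__ void SomeKernel(int nthreads, const T* in, T* out) {
  // Plain grid-stride loop over the requested number of virtual threads.
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < nthreads;
       i += blockDim.x * gridDim.x) {
    out[i] = in[i];
  }
}

template <typename T>
void LaunchSomeKernel(const Eigen::GpuDevice& d, const T* in, T* out,
                      int total_count) {
  // Before this change (CUDA-only spellings):
  //   GpuLaunchConfig config = GetCudaLaunchConfig(total_count, d);
  //   TF_CHECK_OK(CudaLaunchKernel(SomeKernel<T>, config.block_count,
  //                                config.thread_per_block, 0, d.stream(),
  //                                config.virtual_thread_count, in, out));
  // After this change (GPU-neutral spellings, same argument order:
  // kernel, grid, block, shared-memory bytes, stream, kernel arguments).
  GpuLaunchConfig config = GetGpuLaunchConfig(total_count, d);
  TF_CHECK_OK(GpuLaunchKernel(SomeKernel<T>, config.block_count,
                              config.thread_per_block, 0, d.stream(),
                              config.virtual_thread_count, in, out));
}

In every hunk the launch arguments themselves are unchanged; only the helper names (and, in a few places, clang-format line wrapping) differ between the removed and added lines.
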
@@ -81,17 +81,17 @@ void BiasGPU<T>::compute(const GPUDevice& d, const T* input, const T* bias,
   if (total_count == 0) {
     return;
   }
-  GpuLaunchConfig config = GetCudaLaunchConfig(total_count, d);
+  GpuLaunchConfig config = GetGpuLaunchConfig(total_count, d);
   if (data_format == FORMAT_NHWC) {
-    TF_CHECK_OK(CudaLaunchKernel(BiasNHWCKernel<T>, config.block_count,
-                                 config.thread_per_block, 0, d.stream(),
-                                 config.virtual_thread_count, input, bias,
-                                 output, bias_size));
+    TF_CHECK_OK(GpuLaunchKernel(BiasNHWCKernel<T>, config.block_count,
+                                config.thread_per_block, 0, d.stream(),
+                                config.virtual_thread_count, input, bias,
+                                output, bias_size));
   } else {
-    TF_CHECK_OK(CudaLaunchKernel(BiasNCHWKernel<T>, config.block_count,
-                                 config.thread_per_block, 0, d.stream(),
-                                 config.virtual_thread_count, input, bias,
-                                 output, bias_size, image_size));
+    TF_CHECK_OK(GpuLaunchKernel(BiasNCHWKernel<T>, config.block_count,
+                                config.thread_per_block, 0, d.stream(),
+                                config.virtual_thread_count, input, bias,
+                                output, bias_size, image_size));
   }
 }
 
@@ -204,7 +204,7 @@ void BiasGradGPU<T>::compute(const GPUDevice& d, const T* output_backprop,
     return;
   }
   static constexpr int32 kWarpSize = 32;
-  GpuLaunchConfig config = GetCudaLaunchConfig(total_count, d);
+  GpuLaunchConfig config = GetGpuLaunchConfig(total_count, d);
 
   const int max_shared_memory_size = d.sharedMemPerBlock() / 2;
   int32 shared_memory_size = 0;
@@ -214,10 +214,10 @@ void BiasGradGPU<T>::compute(const GPUDevice& d, const T* output_backprop,
   // Check if we have enough shared memory.
   if (shared_memory_size <= max_shared_memory_size) {
     if (data_format == FORMAT_NHWC) {
-      TF_CHECK_OK(CudaLaunchKernel(BiasGradNHWC_SharedAtomics<T>,
-                                   config.block_count, config.thread_per_block,
-                                   shared_memory_size, d.stream(), total_count,
-                                   output_backprop, bias_backprop, bias_size));
+      TF_CHECK_OK(GpuLaunchKernel(BiasGradNHWC_SharedAtomics<T>,
+                                  config.block_count, config.thread_per_block,
+                                  shared_memory_size, d.stream(), total_count,
+                                  output_backprop, bias_backprop, bias_size));
     } else {
       // Round up the block count to multiple of bias_size.
       int group_size = (config.block_count + bias_size - 1) / bias_size;
@@ -225,24 +225,24 @@ void BiasGradGPU<T>::compute(const GPUDevice& d, const T* output_backprop,
       if (config.thread_per_block < kWarpSize) {
         config.thread_per_block = kWarpSize;
       }
-      TF_CHECK_OK(CudaLaunchKernel(
-          BiasGradNCHW_SharedAtomics<T>, config.block_count,
-          config.thread_per_block, 0, d.stream(), output_backprop,
-          bias_backprop, batch, bias_size, image_size, group_size));
+      TF_CHECK_OK(GpuLaunchKernel(BiasGradNCHW_SharedAtomics<T>,
+                                  config.block_count, config.thread_per_block,
+                                  0, d.stream(), output_backprop, bias_backprop,
+                                  batch, bias_size, image_size, group_size));
     }
   } else {
    // Note that even if we don't have enough shared memory to fit the entire
    // output block, it is possible to process one group of elements at a time.
    // But for now, we simply fall back to the naive implementation.
     if (data_format == FORMAT_NHWC) {
-      TF_CHECK_OK(CudaLaunchKernel(
+      TF_CHECK_OK(GpuLaunchKernel(
           BiasGradNHWC_Naive<T>, config.block_count, config.thread_per_block, 0,
           d.stream(), total_count, output_backprop, bias_backprop, bias_size));
     } else {
-      TF_CHECK_OK(CudaLaunchKernel(BiasGradNCHW_Naive<T>, config.block_count,
-                                   config.thread_per_block, 0, d.stream(),
-                                   total_count, output_backprop, bias_backprop,
-                                   bias_size, image_size));
+      TF_CHECK_OK(GpuLaunchKernel(BiasGradNCHW_Naive<T>, config.block_count,
-                                  config.thread_per_block, 0, d.stream(),
+                                  config.thread_per_block, 0, d.stream(),
+                                  total_count, output_backprop, bias_backprop,
+                                  bias_size, image_size));
     }
   }
 }

@@ -656,12 +656,12 @@ Status LaunchDepthwiseConv2dGPUSmall(OpKernelContext* ctx,
       kBlockDepth * (tile_pixels + filter_pixels) * sizeof(S);
   const int num_outputs = args.out_rows * args.out_cols * block_count;
   auto device = ctx->eigen_gpu_device();
-  GpuLaunchConfig config = GetCudaLaunchConfigFixedBlockSize(
+  GpuLaunchConfig config = GetGpuLaunchConfigFixedBlockSize(
       num_outputs, device, kernel, shared_memory_size,
       block_dim.x * block_dim.y * block_dim.z);
-  TF_CHECK_OK(CudaLaunchKernel(kernel, config.block_count, block_dim,
-                               shared_memory_size, device.stream(), args, input,
-                               filter, output));
+  TF_CHECK_OK(GpuLaunchKernel(kernel, config.block_count, block_dim,
+                              shared_memory_size, device.stream(), args, input,
+                              filter, output));
   return Status::OK();
 }
 
@@ -751,10 +751,10 @@ Status LaunchDepthwiseConv2dGPU(OpKernelContext* ctx, const DepthwiseArgs& args,
       kKnownDepthMultiplier < 0
           ? std::numeric_limits<int>::max()
          : device.getNumGpuMultiProcessors();
-  TF_CHECK_OK(CudaLaunchKernel(kernel,
-                               std::min(max_block_count, config.block_count),
-                               config.thread_per_block, 0, device.stream(),
-                               args, input, filter, output, num_outputs));
+  TF_CHECK_OK(GpuLaunchKernel(kernel,
+                              std::min(max_block_count, config.block_count),
+                              config.thread_per_block, 0, device.stream(), args,
+                              input, filter, output, num_outputs));
   return Status::OK();
 }
 
@@ -969,7 +969,7 @@ Status LaunchDepthwiseConv2dBackpropInputGPU(OpKernelContext* ctx,
   auto device = ctx->eigen_gpu_device();
   GpuLaunchConfig config =
       GetGpuLaunchConfig(num_in_backprop, device, kernel, 0, 0);
-  TF_CHECK_OK(CudaLaunchKernel(
+  TF_CHECK_OK(GpuLaunchKernel(
       kernel, config.block_count, config.thread_per_block, 0, device.stream(),
       args, out_backprop, filter, in_backprop, num_in_backprop));
   return Status::OK();
@@ -1611,12 +1611,12 @@ Status TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
                                " is not supported");
   }
   const int num_out_backprop = args.out_rows * args.out_cols * block_count;
-  GpuLaunchConfig config = GetCudaLaunchConfigFixedBlockSize(
+  GpuLaunchConfig config = GetGpuLaunchConfigFixedBlockSize(
      num_out_backprop, device, kernel, shared_memory_size,
       block_dim.x * block_dim.y * block_dim.z);
-  TF_CHECK_OK(CudaLaunchKernel(kernel, config.block_count, block_dim,
-                               shared_memory_size, device.stream(), args,
-                               out_backprop, input, filter_backprop));
+  TF_CHECK_OK(GpuLaunchKernel(kernel, config.block_count, block_dim,
+                              shared_memory_size, device.stream(), args,
+                              out_backprop, input, filter_backprop));
   return Status::OK();
 }
 
@@ -1717,7 +1717,7 @@ Status LaunchDepthwiseConv2dBackpropFilterGPU(
   auto device = ctx->eigen_gpu_device();
   GpuLaunchConfig config =
       GetGpuLaunchConfig(num_out_backprop, device, kernel, 0, 0);
-  TF_CHECK_OK(CudaLaunchKernel(
+  TF_CHECK_OK(GpuLaunchKernel(
       kernel, config.block_count, config.thread_per_block, 0, device.stream(),
       args, out_backprop, input, filter_backprop, num_out_backprop));
   return Status::OK();

@@ -128,9 +128,9 @@ struct DeterminantFromPivotedLUFunctor<GPUDevice, Scalar> {
                   int* info) {
     const int64 num_matrices = output.size();
     const int64 n = lu_factor.dimension(2);
-    GpuLaunchConfig config = GetCudaLaunchConfig(num_matrices, device);
+    GpuLaunchConfig config = GetGpuLaunchConfig(num_matrices, device);
 
-    TF_CHECK_OK(CudaLaunchKernel(
+    TF_CHECK_OK(GpuLaunchKernel(
         DeterminantFromPivotedLUKernel<Scalar, /*compute_log_abs_det=*/false>,
         config.block_count, config.thread_per_block, 0, device.stream(),
         config.virtual_thread_count, n, lu_factor.data(), pivots, nullptr,
@@ -151,8 +151,8 @@ struct LogDeterminantFromPivotedLUFunctor<GPUDevice, Scalar> {
                   typename TTypes<Scalar, 1>::Tensor log_abs_det) {
     const int64 num_matrices = sign.size();
     const int64 n = lu_factor.dimension(2);
-    GpuLaunchConfig config = GetCudaLaunchConfig(num_matrices, device);
-    TF_CHECK_OK(CudaLaunchKernel(
+    GpuLaunchConfig config = GetGpuLaunchConfig(num_matrices, device);
+    TF_CHECK_OK(GpuLaunchKernel(
         DeterminantFromPivotedLUKernel<Scalar, /*compute_log_abs_det=*/true>,
         config.block_count, config.thread_per_block, 0, device.stream(),
         config.virtual_thread_count, n, lu_factor.data(), pivots, sign.data(),

@@ -118,13 +118,13 @@ struct InTopKFunctor<GPUDevice, T, TargetT> {
     const auto& d = context->eigen_device<GPUDevice>();
 
     // Compute a mask for all predictions.
-    CudaLaunchConfig config = GetCudaLaunchConfig(num_targets * num_classes, d);
-    OP_REQUIRES_OK(context, CudaLaunchKernel(
-                                ComputePredictionMaskKernel<T, TargetT>,
-                                config.block_count, config.thread_per_block, 0,
-                                d.stream(), predictions.data(), targets.data(),
-                                predictions_mask.flat<int64>().data(),
-                                num_targets, num_classes));
+    CudaLaunchConfig config = GetGpuLaunchConfig(num_targets * num_classes, d);
+    OP_REQUIRES_OK(
+        context, GpuLaunchKernel(ComputePredictionMaskKernel<T, TargetT>,
+                                 config.block_count, config.thread_per_block, 0,
+                                 d.stream(), predictions.data(), targets.data(),
+                                 predictions_mask.flat<int64>().data(),
+                                 num_targets, num_classes));
 
     // Reduce prediction masks to number of predictions larger than the target
     // prediction, or to the negative value if we can't compute an answer.

@@ -222,8 +222,8 @@ class LuOpGpu : public AsyncOpKernel {
     int* pivots_ptr = pivots.flat<int>().data();
     Tidx* permutation_indices_ptr =
         permutation_indices->template flat<Tidx>().data();
-    GpuLaunchConfig cfgPivots = GetCudaLaunchConfig(batch_size, device);
-    TF_CHECK_OK(CudaLaunchKernel(
+    GpuLaunchConfig cfgPivots = GetGpuLaunchConfig(batch_size, device);
+    TF_CHECK_OK(GpuLaunchKernel(
         ComputePermutationFromTranspositionsKernel<Tidx>, cfgPivots.block_count,
         cfgPivots.thread_per_block, 0, device.stream(), cfgPivots, num_rows,
         pivots_ptr, permutation_indices_ptr));

@@ -241,7 +241,7 @@ Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
       DataType::DT_INT32, TensorShape({max_nms_mask_size}), &d_nms_mask));
   // reset data sensitive tensors
   auto device = context->eigen_gpu_device();
-  auto config = GetCudaLaunchConfig(d_nms_mask.NumElements(), device);
+  auto config = GetGpuLaunchConfig(d_nms_mask.NumElements(), device);
   TF_CHECK_OK(GpuLaunchKernel(SetZero<int>, config.block_count,
                               config.thread_per_block, 0, device.stream(),
                               config.virtual_thread_count,
@@ -425,7 +425,7 @@ class NonMaxSuppressionV2GPUOp : public OpKernel {
                                            &d_sorted_boxes));
 
     // this will return sorted scores and their indices
-    auto config = GetCudaLaunchConfig(num_boxes, device);
+    auto config = GetGpuLaunchConfig(num_boxes, device);
     // initialize box and score indices
     TF_CHECK_OK(GpuLaunchKernel(Iota<int>, config.block_count,
                                 config.thread_per_block, 0, device.stream(),
@@ -472,7 +472,7 @@ class NonMaxSuppressionV2GPUOp : public OpKernel {
                    context->allocate_output(0, TensorShape({num_outputs}),
                                             &output_indices));
     if (num_outputs == 0) return;
-    config = GetCudaLaunchConfig(num_outputs, device);
+    config = GetGpuLaunchConfig(num_outputs, device);
     TF_CHECK_OK(GpuLaunchKernel(
         IndexMultiSelect<int, int>, config.block_count, config.thread_per_block,
         0, device.stream(), config.virtual_thread_count,

@@ -104,9 +104,9 @@ struct ReluGrad<Device, Eigen::half> {
     if (count == 0) return;
     int32 half2_count = Eigen::divup(count, 2);
     constexpr int32 kThreadInBlock = 512;
-    GpuLaunchConfig config = GetCudaLaunchConfigFixedBlockSize(
+    GpuLaunchConfig config = GetGpuLaunchConfigFixedBlockSize(
         half2_count, d, ReluGradHalfKernel, 0, kThreadInBlock);
-    TF_CHECK_OK(CudaLaunchKernel(
+    TF_CHECK_OK(GpuLaunchKernel(
         ReluGradHalfKernel, config.block_count, config.thread_per_block, 0,
         d.stream(), gradient.data(), feature.data(), backprop.data(), count));
   }
@@ -133,9 +133,9 @@ struct Relu<Device, qint8> {
 
     int32 vect_count = Eigen::divup(count, 4);
     constexpr int32 kThreadInBlock = 512;
-    GpuLaunchConfig config = GetCudaLaunchConfigFixedBlockSize(
+    GpuLaunchConfig config = GetGpuLaunchConfigFixedBlockSize(
         vect_count, d, Relu_int8x4_kernel, 0, kThreadInBlock);
-    TF_CHECK_OK(CudaLaunchKernel(
+    TF_CHECK_OK(GpuLaunchKernel(
         Relu_int8x4_kernel, config.block_count, config.thread_per_block, 0,
         d.stream(), vect_count, reinterpret_cast<const int32*>(input.data()),
         reinterpret_cast<int32*>(output.data())));

@@ -71,7 +71,7 @@ struct Roll<GPUDevice, T> {
   d.memcpyHostToDevice(thres_buf, threshold.data(), thres_bytes);
   d.memcpyHostToDevice(range_buf, dim_range.data(), range_bytes);
 
-  CudaLaunchConfig cfg = GetCudaLaunchConfig(num_elements, d);
+  CudaLaunchConfig cfg = GetGpuLaunchConfig(num_elements, d);
 
   TF_CHECK_OK(GpuLaunchKernel(RollKernel<T>, cfg.block_count,
                               cfg.thread_per_block, 0, d.stream(),

@@ -70,12 +70,12 @@ void UpdateVariableAndFill_Philox<GPUDevice, Distribution>::operator()(
   // maximize occupancy
   const int kGroupSize = Distribution::kResultElementCount;
   int work_element_count = (output_size + kGroupSize - 1) / kGroupSize;
-  GpuLaunchConfig cfg = GetCudaLaunchConfig(work_element_count, d,
-                                            FillKernel<Distribution>, 0, 0);
+  GpuLaunchConfig cfg =
+      GetGpuLaunchConfig(work_element_count, d, FillKernel<Distribution>, 0, 0);
 
   int zero = 0;
   cudaMemcpyToSymbol(thread_counter, &zero, sizeof(int));
-  TF_CHECK_OK(CudaLaunchKernel(
+  TF_CHECK_OK(GpuLaunchKernel(
       FillKernel<Distribution>, cfg.block_count, cfg.thread_per_block, 0,
       d.stream(), dist, state_size, output_size, state_data, output_data));
 }

@@ -196,16 +196,16 @@ class SvdOpGpu : public AsyncOpKernel {
       const GPUDevice& d = context->eigen_device<GPUDevice>();
       d.memset(outputV_ptr, 0, batch_size * sizeof(Scalar));
       Gpu2DLaunchConfig cfg2D = GetCuda2DLaunchConfig(batch_size, m, d);
-      TF_CHECK_OK(CudaLaunchKernel(ComputeValueOfVKernel<Scalar>,
-                                   cfg2D.block_count, cfg2D.thread_per_block, 0,
-                                   d.stream(), cfg2D, m, full_matrices_ ? m : p,
-                                   input_copy.flat<Scalar>().data(),
-                                   outputU_ptr, outputS_ptr, outputV_ptr));
+      TF_CHECK_OK(GpuLaunchKernel(ComputeValueOfVKernel<Scalar>,
+                                  cfg2D.block_count, cfg2D.thread_per_block, 0,
+                                  d.stream(), cfg2D, m, full_matrices_ ? m : p,
+                                  input_copy.flat<Scalar>().data(), outputU_ptr,
+                                  outputS_ptr, outputV_ptr));
       // 2. clamp V to -1 or +1
-      GpuLaunchConfig cfg1D = GetCudaLaunchConfig(batch_size, d);
-      TF_CHECK_OK(CudaLaunchKernel(ExtractSignOfVKernel<Scalar>,
-                                   cfg1D.block_count, cfg1D.thread_per_block, 0,
-                                   d.stream(), cfg1D, outputV_ptr));
+      GpuLaunchConfig cfg1D = GetGpuLaunchConfig(batch_size, d);
+      TF_CHECK_OK(GpuLaunchKernel(ExtractSignOfVKernel<Scalar>,
+                                  cfg1D.block_count, cfg1D.thread_per_block, 0,
+                                  d.stream(), cfg1D, outputV_ptr));
     }
 
     if (compute_uv_) {

@@ -402,9 +402,9 @@ cudaError LaunchTopKKernel(const cudaStream_t& stream, int num_shards,
   // We are limited by the amount of shared memory we have per block.
   auto shared_memory_size = (num_shards + 1) * k * sizeof(Entry<T>);
 
-  TF_CHECK_OK(CudaLaunchKernel(TopKKernel<T>, batch_size, num_shards,
-                               shared_memory_size, stream, input, length, k,
-                               sorted, output, indices));
+  TF_CHECK_OK(GpuLaunchKernel(TopKKernel<T>, batch_size, num_shards,
+                              shared_memory_size, stream, input, length, k,
+                              sorted, output, indices));
   return cudaGetLastError();
 }
 

@@ -77,7 +77,7 @@ class TridiagonalMatMulOpGpu : public OpKernel {
     OP_REQUIRES_OK(context, context->allocate_output(0, rhs.shape(), &output));
 
     const Eigen::GpuDevice& device = context->eigen_device<Eigen::GpuDevice>();
-    CudaLaunchConfig cfg = GetCudaLaunchConfig(1, device);
+    CudaLaunchConfig cfg = GetGpuLaunchConfig(1, device);
     TF_CHECK_OK(GpuLaunchKernel(
         TridiagonalMatMulKernel<Scalar>, cfg.block_count, cfg.thread_per_block,
         0, device.stream(), batch_size, m, n, superdiag.flat<Scalar>().data(),

@@ -228,13 +228,13 @@ class TridiagonalSolveOpGpuLinalg : public LinearAlgebraOp<Scalar> {
   void SolveForSizeOneOrTwo(OpKernelContext* context, const Scalar* diagonals,
                             const Scalar* rhs, Scalar* output, int m, int k) {
     const Eigen::GpuDevice& device = context->eigen_device<Eigen::GpuDevice>();
-    GpuLaunchConfig cfg = GetCudaLaunchConfig(1, device);
+    GpuLaunchConfig cfg = GetGpuLaunchConfig(1, device);
     bool* not_invertible_dev;
     cudaMalloc(&not_invertible_dev, sizeof(bool));
-    TF_CHECK_OK(CudaLaunchKernel(SolveForSizeOneOrTwoKernel<Scalar>,
-                                 cfg.block_count, cfg.thread_per_block, 0,
-                                 device.stream(), m, diagonals, rhs, k, output,
-                                 not_invertible_dev));
+    TF_CHECK_OK(GpuLaunchKernel(SolveForSizeOneOrTwoKernel<Scalar>,
+                                cfg.block_count, cfg.thread_per_block, 0,
+                                device.stream(), m, diagonals, rhs, k, output,
+                                not_invertible_dev));
     bool not_invertible_host;
     cudaMemcpy(&not_invertible_host, not_invertible_dev, sizeof(bool),
                cudaMemcpyDeviceToHost);