adjust gather ops launch config. for NCF model, this means ~20% gain. (due to grid size from 80->160 on volta).

PiperOrigin-RevId: 312373706
Change-Id: I2413d301ec170e6e90eeae025e4bb17fccd5abbb
This commit is contained in:
A. Unique TensorFlower 2020-05-19 16:08:28 -07:00 committed by TensorFlower Gardener
parent 119aa03c76
commit f0eb6dff6f

View File

@ -92,13 +92,18 @@ struct GatherFunctor<GPUDevice, T, Index> {
const int64 indices_size = indices.size();
const int64 slice_size = params.dimension(2);
GpuLaunchConfig config = GetGpuLaunchConfig(out_size, d);
if (is_axis_zero) {
GpuLaunchConfig config = GetGpuLaunchConfig(
out_size, d, &GatherOpKernel<T, Index, true>,
/*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0);
TF_CHECK_OK(GpuLaunchKernel(
GatherOpKernel<T, Index, true>, config.block_count,
config.thread_per_block, 0, d.stream(), params.data(), indices.data(),
out.data(), gather_dim_size, indices_size, slice_size, out_size));
} else {
GpuLaunchConfig config = GetGpuLaunchConfig(
out_size, d, &GatherOpKernel<T, Index, false>,
/*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0);
TF_CHECK_OK(GpuLaunchKernel(
GatherOpKernel<T, Index, false>, config.block_count,
config.thread_per_block, 0, d.stream(), params.data(), indices.data(),