adjust gather ops launch config. for NCF model, this means ~20% gain. (due to grid size from 80->160 on volta).
PiperOrigin-RevId: 312373706 Change-Id: I2413d301ec170e6e90eeae025e4bb17fccd5abbb
This commit is contained in:
parent
119aa03c76
commit
f0eb6dff6f
@ -92,13 +92,18 @@ struct GatherFunctor<GPUDevice, T, Index> {
|
||||
const int64 indices_size = indices.size();
|
||||
const int64 slice_size = params.dimension(2);
|
||||
|
||||
GpuLaunchConfig config = GetGpuLaunchConfig(out_size, d);
|
||||
if (is_axis_zero) {
|
||||
GpuLaunchConfig config = GetGpuLaunchConfig(
|
||||
out_size, d, &GatherOpKernel<T, Index, true>,
|
||||
/*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0);
|
||||
TF_CHECK_OK(GpuLaunchKernel(
|
||||
GatherOpKernel<T, Index, true>, config.block_count,
|
||||
config.thread_per_block, 0, d.stream(), params.data(), indices.data(),
|
||||
out.data(), gather_dim_size, indices_size, slice_size, out_size));
|
||||
} else {
|
||||
GpuLaunchConfig config = GetGpuLaunchConfig(
|
||||
out_size, d, &GatherOpKernel<T, Index, false>,
|
||||
/*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0);
|
||||
TF_CHECK_OK(GpuLaunchKernel(
|
||||
GatherOpKernel<T, Index, false>, config.block_count,
|
||||
config.thread_per_block, 0, d.stream(), params.data(), indices.data(),
|
||||
|
Loading…
Reference in New Issue
Block a user