Add/fix comments in SparseApplyFtrl kernel

This commit is contained in:
Ben Barsdell 2020-12-08 12:37:45 +11:00
parent e71259ab83
commit e34b6cef37
2 changed files with 9 additions and 5 deletions

View File

@ -2704,7 +2704,6 @@ REGISTER_KERNELS(GPU, double);
#undef REGISTER_CPU_KERNELS
#undef REGISTER_KERNELS
// Note, this op works on cpu only.
template <typename Device, typename T, typename Tindex, bool has_l2_shrinkage>
class SparseApplyFtrlOp : public OpKernel {
public:

View File

@ -658,10 +658,15 @@ struct SparseApplyFtrl<GPUDevice, T, Tindex, has_l2_shrinkage> {
GpuLaunchConfig config = GetGpuLaunchConfig(grad_size, d);
return GpuLaunchKernel(
SparseApplyFtrlKernel<T, Tindex, has_l2_shrinkage>, config.block_count,
config.thread_per_block, 0, d.stream(), var.data(), accum.data(),
linear.data(), lr.data(), l1.data(), l2.data(), l2_shrinkage.data(),
lr_power.data(), grad.data(), indices.data(), first_dim_size, grad_size,
indices_size, multiply_linear_by_lr);
config.thread_per_block, 0, d.stream(), /*var=*/var.data(),
/*accum=*/accum.data(),
/*linear=*/linear.data(), /*lr=*/lr.data(), /*l1=*/l1.data(),
/*l2=*/l2.data(), /*l2_shrinkage=*/l2_shrinkage.data(),
/*lr_power=*/lr_power.data(), /*grad=*/grad.data(),
/*indices=*/indices.data(), /*param_rows=*/first_dim_size,
/*updates_size=*/grad_size,
/*indices_size=*/indices_size,
/*multiply_linear_by_lr=*/multiply_linear_by_lr);
}
};