Add/fix comments in SparseApplyFtrl kernel
This commit is contained in:
parent
e71259ab83
commit
e34b6cef37
@ -2704,7 +2704,6 @@ REGISTER_KERNELS(GPU, double);
|
|||||||
#undef REGISTER_CPU_KERNELS
|
#undef REGISTER_CPU_KERNELS
|
||||||
#undef REGISTER_KERNELS
|
#undef REGISTER_KERNELS
|
||||||
|
|
||||||
// Note, this op works on cpu only.
|
|
||||||
template <typename Device, typename T, typename Tindex, bool has_l2_shrinkage>
|
template <typename Device, typename T, typename Tindex, bool has_l2_shrinkage>
|
||||||
class SparseApplyFtrlOp : public OpKernel {
|
class SparseApplyFtrlOp : public OpKernel {
|
||||||
public:
|
public:
|
||||||
|
@ -658,10 +658,15 @@ struct SparseApplyFtrl<GPUDevice, T, Tindex, has_l2_shrinkage> {
|
|||||||
GpuLaunchConfig config = GetGpuLaunchConfig(grad_size, d);
|
GpuLaunchConfig config = GetGpuLaunchConfig(grad_size, d);
|
||||||
return GpuLaunchKernel(
|
return GpuLaunchKernel(
|
||||||
SparseApplyFtrlKernel<T, Tindex, has_l2_shrinkage>, config.block_count,
|
SparseApplyFtrlKernel<T, Tindex, has_l2_shrinkage>, config.block_count,
|
||||||
config.thread_per_block, 0, d.stream(), var.data(), accum.data(),
|
config.thread_per_block, 0, d.stream(), /*var=*/var.data(),
|
||||||
linear.data(), lr.data(), l1.data(), l2.data(), l2_shrinkage.data(),
|
/*accum=*/accum.data(),
|
||||||
lr_power.data(), grad.data(), indices.data(), first_dim_size, grad_size,
|
/*linear=*/linear.data(), /*lr=*/lr.data(), /*l1=*/l1.data(),
|
||||||
indices_size, multiply_linear_by_lr);
|
/*l2=*/l2.data(), /*l2_shrinkage=*/l2_shrinkage.data(),
|
||||||
|
/*lr_power=*/lr_power.data(), /*grad=*/grad.data(),
|
||||||
|
/*indices=*/indices.data(), /*param_rows=*/first_dim_size,
|
||||||
|
/*updates_size=*/grad_size,
|
||||||
|
/*indices_size=*/indices_size,
|
||||||
|
/*multiply_linear_by_lr=*/multiply_linear_by_lr);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user