Add/fix comments in SparseApplyFtrl kernel
This commit is contained in:
parent
e71259ab83
commit
e34b6cef37
tensorflow/core/kernels
@ -2704,7 +2704,6 @@ REGISTER_KERNELS(GPU, double);
|
||||
#undef REGISTER_CPU_KERNELS
|
||||
#undef REGISTER_KERNELS
|
||||
|
||||
// Note, this op works on cpu only.
|
||||
template <typename Device, typename T, typename Tindex, bool has_l2_shrinkage>
|
||||
class SparseApplyFtrlOp : public OpKernel {
|
||||
public:
|
||||
|
@ -658,10 +658,15 @@ struct SparseApplyFtrl<GPUDevice, T, Tindex, has_l2_shrinkage> {
|
||||
GpuLaunchConfig config = GetGpuLaunchConfig(grad_size, d);
|
||||
return GpuLaunchKernel(
|
||||
SparseApplyFtrlKernel<T, Tindex, has_l2_shrinkage>, config.block_count,
|
||||
config.thread_per_block, 0, d.stream(), var.data(), accum.data(),
|
||||
linear.data(), lr.data(), l1.data(), l2.data(), l2_shrinkage.data(),
|
||||
lr_power.data(), grad.data(), indices.data(), first_dim_size, grad_size,
|
||||
indices_size, multiply_linear_by_lr);
|
||||
config.thread_per_block, 0, d.stream(), /*var=*/var.data(),
|
||||
/*accum=*/accum.data(),
|
||||
/*linear=*/linear.data(), /*lr=*/lr.data(), /*l1=*/l1.data(),
|
||||
/*l2=*/l2.data(), /*l2_shrinkage=*/l2_shrinkage.data(),
|
||||
/*lr_power=*/lr_power.data(), /*grad=*/grad.data(),
|
||||
/*indices=*/indices.data(), /*param_rows=*/first_dim_size,
|
||||
/*updates_size=*/grad_size,
|
||||
/*indices_size=*/indices_size,
|
||||
/*multiply_linear_by_lr=*/multiply_linear_by_lr);
|
||||
}
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user