diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index c288d613e29..dad7408df3e 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -747,8 +747,8 @@ REGISTER_XLA_OP(Name("ResourceApplyCenteredRMSProp") .TypeConstraint("T", kFloatAndComplexTypes), ResourceApplyCenteredRMSProp); -void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, - bool has_l2_shrinkage) { +void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, bool has_l2_shrinkage, + bool multiply_linear_by_lr) { xla::XlaBuilder* b = ctx->builder(); TensorShape var_shape, accum_shape, linear_shape; @@ -840,9 +840,19 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, xla::XlaOp new_accum = accum + xla::Square(grad); xla::XlaOp new_accum_lr_pow = xla::Pow(new_accum, -lr_power); xla::XlaOp accum_lr_pow = xla::Pow(accum, -lr_power); - linear = linear + grad_to_use - (new_accum_lr_pow - accum_lr_pow) / lr * var; - xla::XlaOp linear_clipped = xla::Clamp(-l1, linear, l1); - xla::XlaOp quadratic = new_accum_lr_pow / lr + two * l2; + if (multiply_linear_by_lr) { + linear = + linear + grad_to_use * lr - (new_accum_lr_pow - accum_lr_pow) * var; + } else { + linear = + linear + grad_to_use - (new_accum_lr_pow - accum_lr_pow) / lr * var; + } + xla::XlaOp linear_clipped = + (multiply_linear_by_lr ? xla::Clamp(-l1 * lr, linear, l1 * lr) + : xla::Clamp(-l1, linear, l1)); + xla::XlaOp quadratic = + (multiply_linear_by_lr ? 
new_accum_lr_pow + two * l2 * lr + : new_accum_lr_pow / lr + two * l2); var = (linear_clipped - linear) / quadratic; accum = new_accum; @@ -855,14 +865,20 @@ class ResourceApplyFtrl : public XlaOpKernel { public: explicit ResourceApplyFtrl(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + OP_REQUIRES_OK( + ctx, ctx->GetAttr("multiply_linear_by_lr", &multiply_linear_by_lr_)); } void Compile(XlaOpKernelContext* ctx) override { - CompileFtrl(ctx, dtype_, /*has_l2_shrinkage=*/false); + CompileFtrl(ctx, dtype_, /*has_l2_shrinkage=*/false, + /*multiply_linear_by_lr=*/multiply_linear_by_lr_); } private: DataType dtype_; + + // Whether to keep the "linear" slot variable multiplied by the learning rate. + bool multiply_linear_by_lr_; }; REGISTER_XLA_OP(Name("ResourceApplyFtrl").TypeConstraint("T", kFloatTypes), ResourceApplyFtrl); @@ -871,14 +887,20 @@ class ResourceApplyFtrlV2 : public XlaOpKernel { public: explicit ResourceApplyFtrlV2(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + OP_REQUIRES_OK( + ctx, ctx->GetAttr("multiply_linear_by_lr", &multiply_linear_by_lr_)); } void Compile(XlaOpKernelContext* ctx) override { - CompileFtrl(ctx, dtype_, /*has_l2_shrinkage=*/true); + CompileFtrl(ctx, dtype_, /*has_l2_shrinkage=*/true, + /*multiply_linear_by_lr=*/multiply_linear_by_lr_); } private: DataType dtype_; + + // Whether to keep the "linear" slot variable multiplied by the learning rate. 
+ bool multiply_linear_by_lr_; }; REGISTER_XLA_OP(Name("ResourceApplyFtrlV2").TypeConstraint("T", kFloatTypes), ResourceApplyFtrlV2); diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc index 68641b37733..e2418dda59a 100644 --- a/tensorflow/core/kernels/training_ops.cc +++ b/tensorflow/core/kernels/training_ops.cc @@ -248,6 +248,47 @@ struct ApplyFtrlV2 { } }; +template +struct ApplyFtrlV2MultiplyLinearByLr { + void operator()(const CPUDevice& d, typename TTypes::Flat var, + typename TTypes::Flat accum, + typename TTypes::Flat linear, + typename TTypes::ConstFlat grad, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar l1, + typename TTypes::ConstScalar l2, + typename TTypes::ConstScalar l2_shrinkage, + typename TTypes::ConstScalar lr_power) { + auto grad_with_shrinkage = grad + static_cast(2) * l2_shrinkage() * var; + auto new_accum = accum + grad * grad; + // special case for which lr_power=-0.5. + if (lr_power() == static_cast(-0.5)) { + linear.device(d) += + grad_with_shrinkage * lr() - (new_accum.sqrt() - accum.sqrt()) * var; + } else { + linear.device(d) += + grad_with_shrinkage * lr() - + (new_accum.pow(-lr_power()) - accum.pow(-lr_power())) * var; + } + auto x = (linear.constant(l1() * lr()) * linear.sign() - linear); + if (lr_power() == static_cast(-0.5)) { + auto y = + new_accum.sqrt() + linear.constant(static_cast(2) * l2() * lr()); + auto pre_shrink = x / y; + var.device(d) = (linear.abs() > linear.constant(l1() * lr())) + .select(pre_shrink, var.constant(static_cast(0))); + + } else { + auto y = new_accum.pow(-lr_power()) + + linear.constant(static_cast(2) * l2() * lr()); + auto pre_shrink = x / y; + var.device(d) = (linear.abs() > linear.constant(l1() * lr())) + .select(pre_shrink, var.constant(static_cast(0))); + } + accum.device(d) += grad * grad; + } +}; + template struct ApplyFtrl { void operator()(const CPUDevice& d, typename TTypes::Flat var, @@ -286,6 +327,44 @@ struct ApplyFtrl { } }; 
+template +struct ApplyFtrlMultiplyLinearByLr { + void operator()(const CPUDevice& d, typename TTypes::Flat var, + typename TTypes::Flat accum, + typename TTypes::Flat linear, + typename TTypes::ConstFlat grad, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar l1, + typename TTypes::ConstScalar l2, + typename TTypes::ConstScalar lr_power) { + auto new_accum = accum + grad.square(); + // special case for which lr_power=-0.5. + if (lr_power() == static_cast(-0.5)) { + linear.device(d) += grad * lr() - (new_accum.sqrt() - accum.sqrt()) * var; + } else { + linear.device(d) += + grad * lr() - + (new_accum.pow(-lr_power()) - accum.pow(-lr_power())) * var; + } + auto x = (linear.constant(l1()) * lr() * linear.sign() - linear); + if (lr_power() == static_cast(-0.5)) { + auto y = + new_accum.sqrt() + linear.constant(static_cast(2) * l2() * lr()); + auto pre_shrink = x / y; + var.device(d) = (linear.abs() > linear.constant(l1() * lr())) + .select(pre_shrink, var.constant(static_cast(0))); + + } else { + auto y = new_accum.pow(-lr_power()) + + linear.constant(static_cast(2) * l2() * lr()); + auto pre_shrink = x / y; + var.device(d) = (linear.abs() > linear.constant(l1() * lr())) + .select(pre_shrink, var.constant(static_cast(0))); + } + accum.device(d) += grad.square(); + } +}; + template struct ApplyMomentum { void operator()(const CPUDevice& d, typename TTypes::Flat var, @@ -1556,16 +1635,28 @@ namespace { template inline T FtrlCompute(const T& accum, const T& linear, const T& lr, const T& l1, - const T& l2, const T& lr_power) { + const T& l2, const T& lr_power, + const bool multiply_linear_by_lr) { T quadratic; - if (lr_power == static_cast(-0.5)) { - quadratic = Eigen::numext::sqrt(accum) / lr + static_cast(2) * l2; + if (multiply_linear_by_lr) { + if (lr_power == static_cast(-0.5)) { + quadratic = Eigen::numext::sqrt(accum) + static_cast(2) * l2 * lr; + } else { + quadratic = + Eigen::numext::pow(accum, -lr_power) + static_cast(2) * l2 * lr; + } + auto 
l1_reg_adjust = std::max(std::min(linear, l1 * lr), -l1 * lr); + return (l1_reg_adjust - linear) / quadratic; } else { - quadratic = - Eigen::numext::pow(accum, -lr_power) / lr + static_cast<T>(2) * l2; + if (lr_power == static_cast<T>(-0.5)) { + quadratic = Eigen::numext::sqrt(accum) / lr + static_cast<T>(2) * l2; + } else { + quadratic = + Eigen::numext::pow(accum, -lr_power) / lr + static_cast<T>(2) * l2; + } + auto l1_reg_adjust = std::max(std::min(linear, l1), -l1); + return (l1_reg_adjust - linear) / quadratic; } - auto l1_reg_adjust = std::max(std::min(linear, l1), -l1); - return (l1_reg_adjust - linear) / quadratic; } } // namespace @@ -2392,6 +2483,8 @@ class ApplyFtrlOp : public OpKernel { public: explicit ApplyFtrlOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); + OP_REQUIRES_OK( + ctx, ctx->GetAttr("multiply_linear_by_lr", &multiply_linear_by_lr_)); } void Compute(OpKernelContext* ctx) override { @@ -2466,10 +2559,22 @@ class ApplyFtrlOp : public OpKernel { errors::InvalidArgument("l2 shrinkage regularization strength " "is not a scalar: ", l2_shrinkage.shape().DebugString())); - functor::ApplyFtrlV2<Device, T>()( + if (multiply_linear_by_lr_) { + functor::ApplyFtrlV2MultiplyLinearByLr<Device, T>()( + device, var.flat<T>(), accum.flat<T>(), linear.flat<T>(), + grad.flat<T>(), lr.scalar<T>(), l1.scalar<T>(), l2.scalar<T>(), + l2_shrinkage.scalar<T>(), lr_power.scalar<T>()); + } else { + functor::ApplyFtrlV2<Device, T>()( + device, var.flat<T>(), accum.flat<T>(), linear.flat<T>(), + grad.flat<T>(), lr.scalar<T>(), l1.scalar<T>(), l2.scalar<T>(), + l2_shrinkage.scalar<T>(), lr_power.scalar<T>()); + } + } else if (multiply_linear_by_lr_) { + functor::ApplyFtrlMultiplyLinearByLr<Device, T>()( device, var.flat<T>(), accum.flat<T>(), linear.flat<T>(), grad.flat<T>(), lr.scalar<T>(), l1.scalar<T>(), l2.scalar<T>(), - l2_shrinkage.scalar<T>(), lr_power.scalar<T>()); + lr_power.scalar<T>()); } else { functor::ApplyFtrl<Device, T>()(device, var.flat<T>(), accum.flat<T>(), linear.flat<T>(), grad.flat<T>(), @@ -2482,6 +2587,7 @@ class ApplyFtrlOp : public 
OpKernel { private: bool use_exclusive_lock_; + bool multiply_linear_by_lr_; }; #define REGISTER_KERNELS(D, T) \ @@ -2559,7 +2665,16 @@ namespace functor { typename TTypes::ConstScalar l1, typename TTypes::ConstScalar l2, \ typename TTypes::ConstScalar l2_shrinkage, \ typename TTypes::ConstScalar lr_power); \ - extern template struct ApplyFtrlV2; + extern template struct ApplyFtrlV2; \ + template <> \ + void ApplyFtrlV2MultiplyLinearByLr::operator()( \ + const GPUDevice& d, typename TTypes::Flat var, \ + typename TTypes::Flat accum, typename TTypes::Flat linear, \ + typename TTypes::ConstFlat grad, typename TTypes::ConstScalar lr, \ + typename TTypes::ConstScalar l1, typename TTypes::ConstScalar l2, \ + typename TTypes::ConstScalar l2_shrinkage, \ + typename TTypes::ConstScalar lr_power); \ + extern template struct ApplyFtrlV2MultiplyLinearByLr; DECLARE_GPU_SPEC(Eigen::half); DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); @@ -2579,6 +2694,8 @@ class SparseApplyFtrlOp : public OpKernel { public: explicit SparseApplyFtrlOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); + OP_REQUIRES_OK( + ctx, ctx->GetAttr("multiply_linear_by_lr", &multiply_linear_by_lr_)); } void Compute(OpKernelContext* ctx) override TF_NO_THREAD_SAFETY_ANALYSIS { @@ -2714,24 +2831,53 @@ class SparseApplyFtrlOp : public OpKernel { // eigen tensor library. 
#define COMPUTE_FTRL(grad, grad_maybe_with_shrinkage) \ auto new_accum = accum + grad.square(); \ - if (lr_power_scalar == static_cast(-0.5)) { \ - linear += grad_maybe_with_shrinkage - \ - (new_accum.sqrt() - accum.sqrt()) / lr_scalar * var; \ + if (multiply_linear_by_lr_) { \ + if (lr_power_scalar == static_cast(-0.5)) { \ + linear += grad_maybe_with_shrinkage * lr_scalar - \ + (new_accum.sqrt() - accum.sqrt()) * var; \ + } else { \ + linear += \ + grad_maybe_with_shrinkage * lr_scalar - \ + (new_accum.pow(-lr_power_scalar) - accum.pow(-lr_power_scalar)) * \ + var; \ + } \ } else { \ - linear += grad_maybe_with_shrinkage - (new_accum.pow(-lr_power_scalar) - \ - accum.pow(-lr_power_scalar)) / \ - lr_scalar * var; \ + if (lr_power_scalar == static_cast(-0.5)) { \ + linear += grad_maybe_with_shrinkage - \ + (new_accum.sqrt() - accum.sqrt()) / lr_scalar * var; \ + } else { \ + linear += grad_maybe_with_shrinkage - (new_accum.pow(-lr_power_scalar) - \ + accum.pow(-lr_power_scalar)) / \ + lr_scalar * var; \ + } \ } \ - auto l1_reg_adjust = linear.cwiseMin(l1_scalar).cwiseMax(-l1_scalar); \ + auto l1_reg_adjust = \ + (multiply_linear_by_lr_ \ + ? 
linear.cwiseMin(l1_scalar * lr_scalar) \ + .cwiseMax(-l1_scalar * lr_scalar) \ + : linear.cwiseMin(l1_scalar).cwiseMax(-l1_scalar)); \ auto x = l1_reg_adjust - linear; \ - if (lr_power_scalar == static_cast(-0.5)) { \ - auto y = new_accum.sqrt() / new_accum.constant(lr_scalar) + \ - linear.constant(static_cast(2) * l2_scalar); \ - var = x / y; \ + if (multiply_linear_by_lr_) { \ + if (lr_power_scalar == static_cast(-0.5)) { \ + auto y = new_accum.sqrt() + \ + linear.constant(static_cast(2) * l2_scalar * lr_scalar); \ + var = x / y; \ + } else { \ + auto y = new_accum.pow(-lr_power_scalar) + \ + linear.constant(static_cast(2) * l2_scalar * lr_scalar); \ + var = x / y; \ + } \ } else { \ - auto y = new_accum.pow(-lr_power_scalar) / new_accum.constant(lr_scalar) + \ - linear.constant(static_cast(2) * l2_scalar); \ - var = x / y; \ + if (lr_power_scalar == static_cast(-0.5)) { \ + auto y = new_accum.sqrt() / new_accum.constant(lr_scalar) + \ + linear.constant(static_cast(2) * l2_scalar); \ + var = x / y; \ + } else { \ + auto y = \ + new_accum.pow(-lr_power_scalar) / new_accum.constant(lr_scalar) + \ + linear.constant(static_cast(2) * l2_scalar); \ + var = x / y; \ + } \ } \ accum += grad.square(); @@ -2781,10 +2927,13 @@ class SparseApplyFtrlOp : public OpKernel { T updated_a = a + grad_flat(i) * grad_flat(i); using Eigen::numext::pow; T sigma = pow(updated_a, -lr_power_scalar) - pow(a, -lr_power_scalar); - sigma /= lr_scalar; - T updated_l = l + g - sigma * v; + if (!multiply_linear_by_lr_) { + sigma /= lr_scalar; + } + T updated_l = (multiply_linear_by_lr_ ? 
l + g * lr_scalar - sigma * v + : l + g - sigma * v); v = FtrlCompute(updated_a, updated_l, lr_scalar, l1_scalar, l2_scalar, - lr_power_scalar); + lr_power_scalar, multiply_linear_by_lr_); a = updated_a; l = updated_l; } @@ -2796,6 +2945,7 @@ class SparseApplyFtrlOp : public OpKernel { private: bool use_exclusive_lock_; + bool multiply_linear_by_lr_; }; #define REGISTER_KERNELS(T, Tindices) \ diff --git a/tensorflow/core/kernels/training_ops.h b/tensorflow/core/kernels/training_ops.h index 53109f5f920..ef44b5f9659 100644 --- a/tensorflow/core/kernels/training_ops.h +++ b/tensorflow/core/kernels/training_ops.h @@ -113,6 +113,18 @@ struct ApplyFtrl { typename TTypes::ConstScalar lr_power); }; +template +struct ApplyFtrlMultiplyLinearByLr { + void operator()(const Device& d, typename TTypes::Flat var, + typename TTypes::Flat accum, + typename TTypes::Flat linear, + typename TTypes::ConstFlat grad, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar l1, + typename TTypes::ConstScalar l2, + typename TTypes::ConstScalar lr_power); +}; + template struct ApplyFtrlV2 { void operator()(const Device& d, typename TTypes::Flat var, @@ -126,6 +138,19 @@ struct ApplyFtrlV2 { typename TTypes::ConstScalar lr_power); }; +template +struct ApplyFtrlV2MultiplyLinearByLr { + void operator()(const Device& d, typename TTypes::Flat var, + typename TTypes::Flat accum, + typename TTypes::Flat linear, + typename TTypes::ConstFlat grad, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar l1, + typename TTypes::ConstScalar l2, + typename TTypes::ConstScalar l2_shrinkage, + typename TTypes::ConstScalar lr_power); +}; + template struct ApplyMomentum { void operator()(const Device& d, typename TTypes::Flat var, diff --git a/tensorflow/core/kernels/training_ops_gpu.cu.cc b/tensorflow/core/kernels/training_ops_gpu.cu.cc index a50180ca1e1..1e53cfed777 100644 --- a/tensorflow/core/kernels/training_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/training_ops_gpu.cu.cc @@ -215,6 
+215,41 @@ struct ApplyFtrl { } }; +template +struct ApplyFtrlMultiplyLinearByLr { + void operator()(const GPUDevice& d, typename TTypes::Flat var, + typename TTypes::Flat accum, + typename TTypes::Flat linear, + typename TTypes::ConstFlat grad, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar l1, + typename TTypes::ConstScalar l2, + typename TTypes::ConstScalar lr_power) { + Eigen::array::Tensor::Index, 1> bcast; + bcast[0] = grad.dimension(0); + Eigen::Sizes<1> single; + + auto lr_bcast = lr.reshape(single).broadcast(bcast); + auto l1_lr_bcast = (l1 * lr).reshape(single).broadcast(bcast); + auto l2_lr_bcast = (l2 * lr).reshape(single).broadcast(bcast); + auto lr_power_bcast = -lr_power.reshape(single).broadcast(bcast); + const auto two = static_cast(2.0); + + auto new_accum = accum + grad.square(); + auto accum_power = accum.binaryExpr(lr_power_bcast, + Eigen::internal::scalar_pow_op()); + auto new_accum_power = new_accum.binaryExpr( + lr_power_bcast, Eigen::internal::scalar_pow_op()); + linear.device(d) += grad * lr_bcast - (new_accum_power - accum_power) * var; + auto x = (l1_lr_bcast * linear.sign() - linear); + auto y = new_accum_power + linear.constant(two) * l2_lr_bcast; + auto pre_shrink = x / y; + var.device(d) = (linear.abs() > l1_lr_bcast) + .select(pre_shrink, var.constant(static_cast(0))); + accum.device(d) += grad.square(); + } +}; + template struct ApplyFtrlV2 { void operator()(const GPUDevice& d, typename TTypes::Flat var, @@ -255,6 +290,46 @@ struct ApplyFtrlV2 { } }; +template +struct ApplyFtrlV2MultiplyLinearByLr { + void operator()(const GPUDevice& d, typename TTypes::Flat var, + typename TTypes::Flat accum, + typename TTypes::Flat linear, + typename TTypes::ConstFlat grad, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar l1, + typename TTypes::ConstScalar l2, + typename TTypes::ConstScalar l2_shrinkage, + typename TTypes::ConstScalar lr_power) { + Eigen::array::Tensor::Index, 1> bcast; + bcast[0] = 
grad.dimension(0); + Eigen::Sizes<1> single; + + auto l2_shrinkage_bcast = l2_shrinkage.reshape(single).broadcast(bcast); + auto lr_bcast = lr.reshape(single).broadcast(bcast); + auto l1_lr_bcast = (l1 * lr).reshape(single).broadcast(bcast); + auto l2_lr_bcast = (l2 * lr).reshape(single).broadcast(bcast); + auto lr_power_bcast = -lr_power.reshape(single).broadcast(bcast); + const auto two = static_cast(2.0); + + auto new_accum = accum + grad.square(); + auto accum_power = accum.binaryExpr(lr_power_bcast, + Eigen::internal::scalar_pow_op()); + auto new_accum_power = new_accum.binaryExpr( + lr_power_bcast, Eigen::internal::scalar_pow_op()); + auto grad_with_shrinkage = + grad + (var.constant(two) * l2_shrinkage_bcast * var); + linear.device(d) += + grad_with_shrinkage * lr_bcast - (new_accum_power - accum_power) * var; + auto x = (l1_lr_bcast * linear.sign() - linear); + auto y = new_accum_power + linear.constant(two) * l2_lr_bcast; + auto pre_shrink = x / y; + var.device(d) = (linear.abs() > l1_lr_bcast) + .select(pre_shrink, var.constant(static_cast(0))); + accum.device(d) += grad.square(); + } +}; + template struct ApplyMomentum { void operator()(const GPUDevice& d, typename TTypes::Flat var, @@ -565,10 +640,18 @@ template struct functor::ApplyFtrl; template struct functor::ApplyFtrl; template struct functor::ApplyFtrl; +template struct functor::ApplyFtrlMultiplyLinearByLr; +template struct functor::ApplyFtrlMultiplyLinearByLr; +template struct functor::ApplyFtrlMultiplyLinearByLr; + template struct functor::ApplyFtrlV2; template struct functor::ApplyFtrlV2; template struct functor::ApplyFtrlV2; +template struct functor::ApplyFtrlV2MultiplyLinearByLr; +template struct functor::ApplyFtrlV2MultiplyLinearByLr; +template struct functor::ApplyFtrlV2MultiplyLinearByLr; + template struct functor::ApplyMomentum; template struct functor::ApplyMomentum; template struct functor::ApplyMomentum; diff --git a/tensorflow/core/ops/training_ops.cc 
b/tensorflow/core/ops/training_ops.cc index 620c9a2b49f..3cb6c71884a 100644 --- a/tensorflow/core/ops/training_ops.cc +++ b/tensorflow/core/ops/training_ops.cc @@ -559,6 +559,7 @@ REGISTER_OP("ApplyFtrl") .Output("out: Ref(T)") .Attr("T: numbertype") .Attr("use_locking: bool = false") + .Attr("multiply_linear_by_lr: bool = false") .SetShapeFn(ApplyFtrlShapeFn); REGISTER_OP("SparseApplyFtrl") @@ -575,6 +576,7 @@ REGISTER_OP("SparseApplyFtrl") .Attr("T: numbertype") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = false") + .Attr("multiply_linear_by_lr: bool = false") .SetShapeFn(ApplyFtrlShapeFn); REGISTER_OP("ResourceApplyFtrl") @@ -588,6 +590,7 @@ REGISTER_OP("ResourceApplyFtrl") .Input("lr_power: T") .Attr("T: numbertype") .Attr("use_locking: bool = false") + .Attr("multiply_linear_by_lr: bool = false") .SetShapeFn(ApplyFtrlShapeFn); REGISTER_OP("ResourceSparseApplyFtrl") @@ -603,6 +606,7 @@ REGISTER_OP("ResourceSparseApplyFtrl") .Attr("T: numbertype") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = false") + .Attr("multiply_linear_by_lr: bool = false") .SetShapeFn(ApplyFtrlShapeFn); REGISTER_OP("ApplyFtrlV2") @@ -618,6 +622,7 @@ REGISTER_OP("ApplyFtrlV2") .Output("out: Ref(T)") .Attr("T: numbertype") .Attr("use_locking: bool = false") + .Attr("multiply_linear_by_lr: bool = false") .SetShapeFn(ApplyFtrlShapeFn); REGISTER_OP("SparseApplyFtrlV2") @@ -635,6 +640,7 @@ REGISTER_OP("SparseApplyFtrlV2") .Attr("T: numbertype") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = false") + .Attr("multiply_linear_by_lr: bool = false") .SetShapeFn(ApplyFtrlShapeFn); REGISTER_OP("ResourceApplyFtrlV2") @@ -649,6 +655,7 @@ REGISTER_OP("ResourceApplyFtrlV2") .Input("lr_power: T") .Attr("T: numbertype") .Attr("use_locking: bool = false") + .Attr("multiply_linear_by_lr: bool = false") .SetShapeFn(ApplyFtrlShapeFn); REGISTER_OP("ResourceSparseApplyFtrlV2") @@ -665,6 +672,7 @@ REGISTER_OP("ResourceSparseApplyFtrlV2") .Attr("T: numbertype") 
.Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = false") + .Attr("multiply_linear_by_lr: bool = false") .SetShapeFn(ApplyFtrlShapeFn); template diff --git a/tensorflow/python/training/training_ops_test.py b/tensorflow/python/training/training_ops_test.py index 8ba6abdcf95..118636c551e 100644 --- a/tensorflow/python/training/training_ops_test.py +++ b/tensorflow/python/training/training_ops_test.py @@ -129,6 +129,61 @@ class TrainingOpsTest(TensorFlowTestCase): self.assertAllClose(linear_update, self.evaluate(linear)) self.assertAllClose(expected_out, out) + def _testTypesForFtrlMultiplyLinearByLr(self, + x, + y, + z, + lr, + grad, + use_gpu=None, + l1=0.0, + l2=0.0, + lr_power=-0.5): + self.setUp() + with self.session(use_gpu=use_gpu): + var = variables.VariableV1(x) + accum = variables.VariableV1(y) + linear = variables.VariableV1(z) + self.evaluate(variables.global_variables_initializer()) + + self.assertAllCloseAccordingToType(x, self.evaluate(var)) + apply_ftrl = ( + training_ops.apply_ftrl( + var, + accum, + linear, + grad, + lr, + l1, + l2, + lr_power, + multiply_linear_by_lr=True)) + out = self.evaluate(apply_ftrl) + self.assertShapeEqual(out, apply_ftrl) + accum_update = y + grad * grad + linear_update = z + grad * lr - (accum_update**(-lr_power) - y** + (-lr_power)) * x + quadratic = accum_update**(-lr_power) + 2 * l2 * lr + expected_out = np.array([ + (np.sign(linear_update[i]) * l1 * lr - linear_update[i]) / + (quadratic[i]) if np.abs(linear_update[i]) > l1 * lr else 0.0 + for i in range(linear_update.size) + ]) + self.assertAllCloseAccordingToType(accum_update, self.evaluate(accum)) + if x.dtype == np.float16: + # The calculations here really are not very precise in float16. + self.assertAllClose( + linear_update, self.evaluate(linear), rtol=2e-2, atol=2e-2) + self.assertAllClose(expected_out, out, rtol=2e-2, atol=2e-2) + elif x.dtype == np.float32: + # The calculations here not sufficiently precise in float32. 
+ self.assertAllClose( + linear_update, self.evaluate(linear), rtol=1e-5, atol=1e-5) + self.assertAllClose(expected_out, out, rtol=1e-5, atol=1e-5) + else: + self.assertAllClose(linear_update, self.evaluate(linear)) + self.assertAllClose(expected_out, out) + @test_util.run_v1_only("b/120545219") def testApplyAdagrad(self): for (dtype, use_gpu) in itertools.product( @@ -151,6 +206,19 @@ class TrainingOpsTest(TensorFlowTestCase): grad = np.arange(100).astype(dtype) self._testTypesForFtrl(x, y, z, lr, grad, use_gpu=False, l1=l1, l2=l2) + @test_util.run_v1_only("b/120545219") + def testApplyFtrlMultiplyLinearByLr(self): + for dtype in [np.float16, np.float32, np.float64]: + x = np.arange(100).astype(dtype) + y = np.arange(1, 101).astype(dtype) + z = np.arange(102, 202).astype(dtype) + lr = np.array(2.0).astype(dtype) + l1 = np.array(3.0).astype(dtype) + l2 = np.array(4.0).astype(dtype) + grad = np.arange(100).astype(dtype) + self._testTypesForFtrlMultiplyLinearByLr( + x, y, z, lr, grad, use_gpu=False, l1=l1, l2=l2) + def _testTypesForSparseAdagrad(self, x, y, lr, grad, indices): self.setUp() with self.session(use_gpu=False): @@ -203,6 +271,47 @@ class TrainingOpsTest(TensorFlowTestCase): out = self.evaluate(sparse_apply_ftrl) self.assertShapeEqual(out, sparse_apply_ftrl) + for (i, index) in enumerate(indices): + self.assertAllCloseAccordingToType( + x[index] - lr * grad[i] * + (y[index] + grad[i] * grad[i])**(lr_power), + self.evaluate(var)[index]) + self.assertAllCloseAccordingToType(y[index] + grad[i] * grad[i], + self.evaluate(accum)[index]) + + def _testTypesForSparseFtrlMultiplyLinearByLr(self, + x, + y, + z, + lr, + grad, + indices, + l1=0.0, + l2=0.0, + lr_power=-0.5): + self.setUp() + with self.session(use_gpu=False): + var = variables.VariableV1(x) + accum = variables.VariableV1(y) + linear = variables.VariableV1(z) + self.evaluate(variables.global_variables_initializer()) + + self.assertAllCloseAccordingToType(x, self.evaluate(var)) + sparse_apply_ftrl = ( + 
training_ops.sparse_apply_ftrl( + var, + accum, + linear, + grad, + constant_op.constant(indices, self._toType(indices.dtype)), + lr, + l1, + l2, + lr_power=lr_power, + multiply_linear_by_lr=True)) + out = self.evaluate(sparse_apply_ftrl) + self.assertShapeEqual(out, sparse_apply_ftrl) + for (i, index) in enumerate(indices): self.assertAllCloseAccordingToType( x[index] - lr * grad[i] * (y[index] + grad[i] * grad[i])** @@ -255,6 +364,23 @@ class TrainingOpsTest(TensorFlowTestCase): indices = np.array([0, 2]).astype(index_type) self._testTypesForSparseFtrl(x, y, z, lr, grad, indices) + @test_util.run_v1_only("b/120545219") + def testSparseApplyFtrlMultiplyLinearByLrDim1(self): + for (dtype, + index_type) in itertools.product([np.float16, np.float32, np.float64], + [np.int32, np.int64]): + x_val = [[0.0], [0.0], [0.0]] + y_val = [[4.0], [5.0], [6.0]] + z_val = [[0.0], [0.0], [0.0]] + x = np.array(x_val).astype(dtype) + y = np.array(y_val).astype(dtype) + z = np.array(z_val).astype(dtype) + lr = np.array(2.0).astype(dtype) + grad_val = [[1.5], [2.5]] + grad = np.array(grad_val).astype(dtype) + indices = np.array([0, 2]).astype(index_type) + self._testTypesForSparseFtrlMultiplyLinearByLr(x, y, z, lr, grad, indices) + @test_util.run_v1_only("b/120545219") def testApplyAdam(self): for dtype, use_gpu in itertools.product( diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index af2a47fb3b9..3428796ea75 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -150,11 +150,11 @@ tf_module { } member_method { name: "ApplyFtrl" - argspec: "args=[\'var\', \'accum\', \'linear\', \'grad\', \'lr\', \'l1\', \'l2\', \'lr_power\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + argspec: "args=[\'var\', \'accum\', \'linear\', \'grad\', \'lr\', \'l1\', \'l2\', \'lr_power\', \'use_locking\', 
\'multiply_linear_by_lr\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " } member_method { name: "ApplyFtrlV2" - argspec: "args=[\'var\', \'accum\', \'linear\', \'grad\', \'lr\', \'l1\', \'l2\', \'l2_shrinkage\', \'lr_power\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + argspec: "args=[\'var\', \'accum\', \'linear\', \'grad\', \'lr\', \'l1\', \'l2\', \'l2_shrinkage\', \'lr_power\', \'use_locking\', \'multiply_linear_by_lr\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " } member_method { name: "ApplyGradientDescent" @@ -3414,11 +3414,11 @@ tf_module { } member_method { name: "ResourceApplyFtrl" - argspec: "args=[\'var\', \'accum\', \'linear\', \'grad\', \'lr\', \'l1\', \'l2\', \'lr_power\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + argspec: "args=[\'var\', \'accum\', \'linear\', \'grad\', \'lr\', \'l1\', \'l2\', \'lr_power\', \'use_locking\', \'multiply_linear_by_lr\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " } member_method { name: "ResourceApplyFtrlV2" - argspec: "args=[\'var\', \'accum\', \'linear\', \'grad\', \'lr\', \'l1\', \'l2\', \'l2_shrinkage\', \'lr_power\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + argspec: "args=[\'var\', \'accum\', \'linear\', \'grad\', \'lr\', \'l1\', \'l2\', \'l2_shrinkage\', \'lr_power\', \'use_locking\', \'multiply_linear_by_lr\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " } member_method { name: "ResourceApplyGradientDescent" @@ -3526,11 +3526,11 @@ tf_module { } member_method { name: "ResourceSparseApplyFtrl" - argspec: "args=[\'var\', \'accum\', \'linear\', \'grad\', \'indices\', \'lr\', \'l1\', \'l2\', \'lr_power\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + argspec: 
"args=[\'var\', \'accum\', \'linear\', \'grad\', \'indices\', \'lr\', \'l1\', \'l2\', \'lr_power\', \'use_locking\', \'multiply_linear_by_lr\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " } member_method { name: "ResourceSparseApplyFtrlV2" - argspec: "args=[\'var\', \'accum\', \'linear\', \'grad\', \'indices\', \'lr\', \'l1\', \'l2\', \'l2_shrinkage\', \'lr_power\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + argspec: "args=[\'var\', \'accum\', \'linear\', \'grad\', \'indices\', \'lr\', \'l1\', \'l2\', \'l2_shrinkage\', \'lr_power\', \'use_locking\', \'multiply_linear_by_lr\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " } member_method { name: "ResourceSparseApplyKerasMomentum" @@ -4030,11 +4030,11 @@ tf_module { } member_method { name: "SparseApplyFtrl" - argspec: "args=[\'var\', \'accum\', \'linear\', \'grad\', \'indices\', \'lr\', \'l1\', \'l2\', \'lr_power\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + argspec: "args=[\'var\', \'accum\', \'linear\', \'grad\', \'indices\', \'lr\', \'l1\', \'l2\', \'lr_power\', \'use_locking\', \'multiply_linear_by_lr\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " } member_method { name: "SparseApplyFtrlV2" - argspec: "args=[\'var\', \'accum\', \'linear\', \'grad\', \'indices\', \'lr\', \'l1\', \'l2\', \'l2_shrinkage\', \'lr_power\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + argspec: "args=[\'var\', \'accum\', \'linear\', \'grad\', \'indices\', \'lr\', \'l1\', \'l2\', \'l2_shrinkage\', \'lr_power\', \'use_locking\', \'multiply_linear_by_lr\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " } member_method { name: "SparseApplyMomentum" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt 
b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index af2a47fb3b9..3428796ea75 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -150,11 +150,11 @@ tf_module { } member_method { name: "ApplyFtrl" - argspec: "args=[\'var\', \'accum\', \'linear\', \'grad\', \'lr\', \'l1\', \'l2\', \'lr_power\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + argspec: "args=[\'var\', \'accum\', \'linear\', \'grad\', \'lr\', \'l1\', \'l2\', \'lr_power\', \'use_locking\', \'multiply_linear_by_lr\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " } member_method { name: "ApplyFtrlV2" - argspec: "args=[\'var\', \'accum\', \'linear\', \'grad\', \'lr\', \'l1\', \'l2\', \'l2_shrinkage\', \'lr_power\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + argspec: "args=[\'var\', \'accum\', \'linear\', \'grad\', \'lr\', \'l1\', \'l2\', \'l2_shrinkage\', \'lr_power\', \'use_locking\', \'multiply_linear_by_lr\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " } member_method { name: "ApplyGradientDescent" @@ -3414,11 +3414,11 @@ tf_module { } member_method { name: "ResourceApplyFtrl" - argspec: "args=[\'var\', \'accum\', \'linear\', \'grad\', \'lr\', \'l1\', \'l2\', \'lr_power\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + argspec: "args=[\'var\', \'accum\', \'linear\', \'grad\', \'lr\', \'l1\', \'l2\', \'lr_power\', \'use_locking\', \'multiply_linear_by_lr\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " } member_method { name: "ResourceApplyFtrlV2" - argspec: "args=[\'var\', \'accum\', \'linear\', \'grad\', \'lr\', \'l1\', \'l2\', \'l2_shrinkage\', \'lr_power\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + 
argspec: "args=[\'var\', \'accum\', \'linear\', \'grad\', \'lr\', \'l1\', \'l2\', \'l2_shrinkage\', \'lr_power\', \'use_locking\', \'multiply_linear_by_lr\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " } member_method { name: "ResourceApplyGradientDescent" @@ -3526,11 +3526,11 @@ tf_module { } member_method { name: "ResourceSparseApplyFtrl" - argspec: "args=[\'var\', \'accum\', \'linear\', \'grad\', \'indices\', \'lr\', \'l1\', \'l2\', \'lr_power\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + argspec: "args=[\'var\', \'accum\', \'linear\', \'grad\', \'indices\', \'lr\', \'l1\', \'l2\', \'lr_power\', \'use_locking\', \'multiply_linear_by_lr\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " } member_method { name: "ResourceSparseApplyFtrlV2" - argspec: "args=[\'var\', \'accum\', \'linear\', \'grad\', \'indices\', \'lr\', \'l1\', \'l2\', \'l2_shrinkage\', \'lr_power\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + argspec: "args=[\'var\', \'accum\', \'linear\', \'grad\', \'indices\', \'lr\', \'l1\', \'l2\', \'l2_shrinkage\', \'lr_power\', \'use_locking\', \'multiply_linear_by_lr\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " } member_method { name: "ResourceSparseApplyKerasMomentum" @@ -4030,11 +4030,11 @@ tf_module { } member_method { name: "SparseApplyFtrl" - argspec: "args=[\'var\', \'accum\', \'linear\', \'grad\', \'indices\', \'lr\', \'l1\', \'l2\', \'lr_power\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + argspec: "args=[\'var\', \'accum\', \'linear\', \'grad\', \'indices\', \'lr\', \'l1\', \'l2\', \'lr_power\', \'use_locking\', \'multiply_linear_by_lr\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " } member_method { name: "SparseApplyFtrlV2" - argspec: 
"args=[\'var\', \'accum\', \'linear\', \'grad\', \'indices\', \'lr\', \'l1\', \'l2\', \'l2_shrinkage\', \'lr_power\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + argspec: "args=[\'var\', \'accum\', \'linear\', \'grad\', \'indices\', \'lr\', \'l1\', \'l2\', \'l2_shrinkage\', \'lr_power\', \'use_locking\', \'multiply_linear_by_lr\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " } member_method { name: "SparseApplyMomentum"