diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index 42104b951ac..a2ecbca124c 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -352,9 +352,9 @@ class ResourceApplyRMSProp : public XlaOpKernel { b->Sub(XlaHelpers::FloatLiteral(b, type, 1.0), rho))); xla::ComputationDataHandle new_mom = b->Add(b->Mul(mom, momentum), - b->Div(b->Mul(grad, lr), + b->Mul(b->Mul(grad, lr), b->Pow(b->Add(new_ms, epsilon), - XlaHelpers::FloatLiteral(b, type, 0.5)))); + XlaHelpers::FloatLiteral(b, type, -0.5)))); xla::ComputationDataHandle new_var = b->Sub(var, new_mom); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, new_var));