diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index df8dccb2163..6131a729441 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -7280,6 +7280,49 @@ $$\text{variable} := \text{variable} - \text{lr}_t * m_t / (\sqrt{v_t} + \epsilo TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>; } +def TF_ResourceApplyCenteredRMSPropOp : TF_Op<"ResourceApplyCenteredRMSProp", []> { + let summary = "Update '*var' according to the centered RMSProp algorithm."; + + let description = [{ +The centered RMSProp algorithm uses an estimate of the centered second moment +(i.e., the variance) for normalization, as opposed to regular RMSProp, which +uses the (uncentered) second moment. This often helps with training, but is +slightly more expensive in terms of computation and memory. + +Note that in dense implementation of this algorithm, mg, ms, and mom will +update even if the grad is zero, but in this sparse implementation, mg, ms, +and mom will not update in iterations during which the grad is zero. + +mean_square = decay * mean_square + (1-decay) * gradient ** 2 +mean_grad = decay * mean_grad + (1-decay) * gradient + +Delta = learning_rate * gradient / sqrt(mean_square + epsilon - mean_grad ** 2) + +mg <- rho * mg_{t-1} + (1-rho) * grad +ms <- rho * ms_{t-1} + (1-rho) * grad * grad +mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon) +var <- var - mom + }]; + + let arguments = (ins + TF_ResourceTensor:$var, + TF_ResourceTensor:$mg, + TF_ResourceTensor:$ms, + TF_ResourceTensor:$mom, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$rho, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$momentum, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$epsilon, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, + + DefaultValuedAttr:$use_locking + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<4>; +} + def TF_ResourceApplyGradientDescentOp : TF_Op<"ResourceApplyGradientDescent", []> { let summary = "Update '*var' by subtracting 'alpha' * 'delta' from it."; diff --git a/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir index 7a2e5173247..25dfda25358 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir @@ -368,6 +368,56 @@ func @decompose_resource_gather_op(%indices : tensor<5xi32>) -> tensor<2x5x16xi3 // ----- +// Tests that composite tf.ResourceApplyCenteredRMSProp operation is decomposed. + +// CHECK-LABEL: func @decompose_resource_apply_centered_RMS_prop +// CHECK-SAME: [[VAR:%.*]]: tensor, [[MG:%.*]]: tensor, [[MS:%.*]]: tensor, [[MOM:%.*]]: tensor, [[LR:%.*]]: tensor, [[RHO:%.*]]: tensor, [[MOMENTUM:%.*]]: tensor, [[EPSILON:%.*]]: tensor, [[GRAD:%.*]]: tensor +func @decompose_resource_apply_centered_RMS_prop(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor) -> () { + // CHECK: [[ONE:%.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor} + // CHECK: [[VAR_HANDLE:%.*]] = "tf.VarHandleOp" + // CHECK: [[MG_HANDLE:%.*]] = "tf.VarHandleOp" + // CHECK: [[MS_HANDLE:%.*]] = "tf.VarHandleOp" + // CHECK: [[MOM_HANDLE:%.*]] = "tf.VarHandleOp" + %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> + %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> + %2 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> + %3 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> + + // CHECK: [[GRADSQ:%.*]] = "tf.Mul"([[GRAD]], [[GRAD]]) + // CHECK: [[SB:%.*]] = "tf.Sub"([[ONE]], [[RHO]]) + // CHECK: [[GRAD_SUB:%.*]] = "tf.Mul"([[GRADSQ]], [[SB]]) + // CHECK: [[MS:%.*]] = "tf.ReadVariableOp"([[MS_HANDLE]]) + // CHECK: [[MS_RHO:%.*]] = "tf.Mul"([[MS]], [[RHO]]) + // CHECK: [[MS_NEW:%.*]] = "tf.Add"([[GRAD_SUB]], [[MS_RHO]]) + // CHECK: "tf.AssignVariableOp"([[MS_HANDLE]], [[MS_NEW]]) + + // CHECK: [[SUB_RHO:%.*]] = "tf.Sub"([[ONE]], [[RHO]]) + // CHECK: [[SUB_GRAD:%.*]] = "tf.Mul"([[GRAD]], [[SUB_RHO]]) + // CHECK: [[MG:%.*]] = "tf.ReadVariableOp"([[MG_HANDLE]]) + // CHECK: [[MG_RHO:%.*]] = "tf.Mul"([[MG]], [[RHO]]) + // CHECK: [[MG_NEW:%.*]] = "tf.Add"([[SUB_GRAD]], [[MG_RHO]]) + // CHECK: "tf.AssignVariableOp"([[MG_HANDLE]], [[MG_NEW]]) + + // CHECK: [[MOM:%.*]] = "tf.ReadVariableOp"([[MOM_HANDLE]]) + // CHECK: [[MOM_MOM:%.*]] = "tf.Mul"([[MOMENTUM]], [[MOM]]) + // CHECK: [[LR_GRAD:%.*]] = "tf.Mul"([[LR]], [[GRAD]]) + + // CHECK: [[MG_MG:%.*]] = "tf.Mul"([[MG_NEW]], [[MG_NEW]]) + // CHECK: [[MG_NEW:%.*]] = "tf.Add"([[MG_MG]], [[EPSILON]]) + // CHECK: [[MG_SUB:%.*]] = "tf.Sub"([[MS_NEW]], [[MG_NEW]]) + // CHECK: [[MG_SQRT:%.*]] = "tf.Sqrt"([[MG_SUB]]) + // CHECK: [[MOM_DIV:%.*]] = "tf.Div"([[LR_GRAD]], [[MG_SQRT]]) + // CHECK: [[MOM_NEW:%.*]] = "tf.Add"([[MOM_MOM]], [[MOM_DIV]]) + + // CHECK: [[VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) + // CHECK: [[VAR_NEW:%.*]] = "tf.Sub"([[VAR]], [[MOM_NEW]]) + + "tf.ResourceApplyCenteredRMSProp"(%0, %1, %2, %3, %arg4, %arg5, %arg6, %arg7, %arg8) {use_locking = false} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor, tensor, tensor, tensor, tensor) -> () + return +} + +// ----- + // Tests that composite tf.ResourceScatterUpdate operation is decomposed. // CHECK-LABEL: @decompose_resource_scatter_update_op diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td index 3869a1a7fa3..0dd7d778e31 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td @@ -327,3 +327,69 @@ def DecomposeVariableShape : Pat< (TF_VariableShapeOp:$src_op $resource), (TF_ShapeOp (CreateTFReadVariableOpFromResourceHandle $src_op, $resource)), [(CheckHasResourceSubtype $resource)]>; + +// This decomposition is only correct inside XLA as it ignores use_locking +// attribute. +// ms <- rho * ms_{t-1} + (1-rho) * grad * grad +// mg = grad * (one - rho) + mg * rho; +// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon) +// +def DecomposeResourceApplyCenteredRMSProp : + Pattern< + (TF_ResourceApplyCenteredRMSPropOp:$src_op + $var_resource, $mg_resource, $ms_resource, $mom_resource, $lr, $rho, $momentum, $epsilon, + $grad, ConstBoolAttrFalse:$use_locking + ), + [(TF_ConstOp:$one (GetScalarOfType<1> $grad)), + (CreateTFReadVariableOp $src_op, $grad, $ms_resource), + (TF_AddOp:$ms_new + (TF_MulOp + (TF_MulOp $grad, $grad), + (TF_SubOp $one, $rho) + ), + (TF_MulOp + (CreateTFReadVariableOp $src_op, $grad, $ms_resource), + $rho + ) + ), + (TF_AssignVariableOp $ms_resource, $ms_new), + // mg = grad * (one - rho) + mg * rho; + (TF_AddOp:$mg_new + (TF_MulOp + $grad, + (TF_SubOp $one, $rho) + ), + (TF_MulOp + (CreateTFReadVariableOp $src_op, $grad, $mg_resource), + $rho + ) + ), + (TF_AssignVariableOp $mg_resource, $mg_new), + // mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon) + (TF_AddOp:$mom_new + (TF_MulOp $momentum, + (CreateTFReadVariableOp $src_op, $grad, $mom_resource)), + (TF_DivOp + (TF_MulOp $lr, $grad), + (TF_SqrtOp + (TF_SubOp + $ms_new, + (TF_AddOp + (TF_MulOp + $mg_new, + $mg_new + ), + $epsilon + ) + ) + ) + ) + ), + (TF_AssignVariableOp $mom_resource, $mom_new), + // var <- var - mom + (TF_AssignSubVariableOp $var_resource, + (TF_SubOp (CreateTFReadVariableOp $src_op, $grad, $var_resource), + $mom_new) + ) + ] + >;