Add decomposition for ResourceApplyCenteredRMSProp
PiperOrigin-RevId: 316582464 Change-Id: Id61fb84cf6ce26398fcd870f331b51e0af8b9c3d
This commit is contained in:
parent
c34265b348
commit
2f45ee867d
|
@ -7280,6 +7280,49 @@ $$\text{variable} := \text{variable} - \text{lr}_t * m_t / (\sqrt{v_t} + \epsilo
|
|||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>;
|
||||
}
|
||||
|
||||
def TF_ResourceApplyCenteredRMSPropOp : TF_Op<"ResourceApplyCenteredRMSProp", []> {
|
||||
let summary = "Update '*var' according to the centered RMSProp algorithm.";
|
||||
|
||||
let description = [{
|
||||
The centered RMSProp algorithm uses an estimate of the centered second moment
|
||||
(i.e., the variance) for normalization, as opposed to regular RMSProp, which
|
||||
uses the (uncentered) second moment. This often helps with training, but is
|
||||
slightly more expensive in terms of computation and memory.
|
||||
|
||||
Note that in dense implementation of this algorithm, mg, ms, and mom will
|
||||
update even if the grad is zero, but in this sparse implementation, mg, ms,
|
||||
and mom will not update in iterations during which the grad is zero.
|
||||
|
||||
mean_square = decay * mean_square + (1-decay) * gradient ** 2
|
||||
mean_grad = decay * mean_grad + (1-decay) * gradient
|
||||
|
||||
Delta = learning_rate * gradient / sqrt(mean_square + epsilon - mean_grad ** 2)
|
||||
|
||||
mg <- rho * mg_{t-1} + (1-rho) * grad
|
||||
ms <- rho * ms_{t-1} + (1-rho) * grad * grad
|
||||
mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon)
|
||||
var <- var - mom
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_ResourceTensor:$var,
|
||||
TF_ResourceTensor:$mg,
|
||||
TF_ResourceTensor:$ms,
|
||||
TF_ResourceTensor:$mom,
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr,
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$rho,
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$momentum,
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$epsilon,
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad,
|
||||
|
||||
DefaultValuedAttr<BoolAttr, "false">:$use_locking
|
||||
);
|
||||
|
||||
let results = (outs);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<4>;
|
||||
}
|
||||
|
||||
def TF_ResourceApplyGradientDescentOp : TF_Op<"ResourceApplyGradientDescent", []> {
|
||||
let summary = "Update '*var' by subtracting 'alpha' * 'delta' from it.";
|
||||
|
||||
|
|
|
@ -368,6 +368,56 @@ func @decompose_resource_gather_op(%indices : tensor<5xi32>) -> tensor<2x5x16xi3
|
|||
|
||||
// -----
|
||||
|
||||
// Tests that composite tf.ResourceApplyCenteredRMSProp operation is decomposed.
|
||||
|
||||
// CHECK-LABEL: func @decompose_resource_apply_centered_RMS_prop
|
||||
// CHECK-SAME: [[VAR:%.*]]: tensor<f32>, [[MG:%.*]]: tensor<f32>, [[MS:%.*]]: tensor<f32>, [[MOM:%.*]]: tensor<f32>, [[LR:%.*]]: tensor<f32>, [[RHO:%.*]]: tensor<f32>, [[MOMENTUM:%.*]]: tensor<f32>, [[EPSILON:%.*]]: tensor<f32>, [[GRAD:%.*]]: tensor<f32>
|
||||
func @decompose_resource_apply_centered_RMS_prop(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>, %arg6: tensor<f32>, %arg7: tensor<f32>, %arg8: tensor<f32>) -> () {
|
||||
// CHECK: [[ONE:%.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
|
||||
// CHECK: [[VAR_HANDLE:%.*]] = "tf.VarHandleOp"
|
||||
// CHECK: [[MG_HANDLE:%.*]] = "tf.VarHandleOp"
|
||||
// CHECK: [[MS_HANDLE:%.*]] = "tf.VarHandleOp"
|
||||
// CHECK: [[MOM_HANDLE:%.*]] = "tf.VarHandleOp"
|
||||
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
|
||||
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
|
||||
%2 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
|
||||
%3 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
|
||||
|
||||
// CHECK: [[GRADSQ:%.*]] = "tf.Mul"([[GRAD]], [[GRAD]])
|
||||
// CHECK: [[SB:%.*]] = "tf.Sub"([[ONE]], [[RHO]])
|
||||
// CHECK: [[GRAD_SUB:%.*]] = "tf.Mul"([[GRADSQ]], [[SB]])
|
||||
// CHECK: [[MS:%.*]] = "tf.ReadVariableOp"([[MS_HANDLE]])
|
||||
// CHECK: [[MS_RHO:%.*]] = "tf.Mul"([[MS]], [[RHO]])
|
||||
// CHECK: [[MS_NEW:%.*]] = "tf.Add"([[GRAD_SUB]], [[MS_RHO]])
|
||||
// CHECK: "tf.AssignVariableOp"([[MS_HANDLE]], [[MS_NEW]])
|
||||
|
||||
// CHECK: [[SUB_RHO:%.*]] = "tf.Sub"([[ONE]], [[RHO]])
|
||||
// CHECK: [[SUB_GRAD:%.*]] = "tf.Mul"([[GRAD]], [[SUB_RHO]])
|
||||
// CHECK: [[MG:%.*]] = "tf.ReadVariableOp"([[MG_HANDLE]])
|
||||
// CHECK: [[MG_RHO:%.*]] = "tf.Mul"([[MG]], [[RHO]])
|
||||
// CHECK: [[MG_NEW:%.*]] = "tf.Add"([[SUB_GRAD]], [[MG_RHO]])
|
||||
// CHECK: "tf.AssignVariableOp"([[MG_HANDLE]], [[MG_NEW]])
|
||||
|
||||
// CHECK: [[MOM:%.*]] = "tf.ReadVariableOp"([[MOM_HANDLE]])
|
||||
// CHECK: [[MOM_MOM:%.*]] = "tf.Mul"([[MOMENTUM]], [[MOM]])
|
||||
// CHECK: [[LR_GRAD:%.*]] = "tf.Mul"([[LR]], [[GRAD]])
|
||||
|
||||
// CHECK: [[MG_MG:%.*]] = "tf.Mul"([[MG_NEW]], [[MG_NEW]])
|
||||
// CHECK: [[MG_NEW:%.*]] = "tf.Add"([[MG_MG]], [[EPSILON]])
|
||||
// CHECK: [[MG_SUB:%.*]] = "tf.Sub"([[MS_NEW]], [[MG_NEW]])
|
||||
// CHECK: [[MG_SQRT:%.*]] = "tf.Sqrt"([[MG_SUB]])
|
||||
// CHECK: [[MOM_DIV:%.*]] = "tf.Div"([[LR_GRAD]], [[MG_SQRT]])
|
||||
// CHECK: [[MOM_NEW:%.*]] = "tf.Add"([[MOM_MOM]], [[MOM_DIV]])
|
||||
|
||||
// CHECK: [[VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]])
|
||||
// CHECK: [[VAR_NEW:%.*]] = "tf.Sub"([[VAR]], [[MOM_NEW]])
|
||||
|
||||
"tf.ResourceApplyCenteredRMSProp"(%0, %1, %2, %3, %arg4, %arg5, %arg6, %arg7, %arg8) {use_locking = false} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that composite tf.ResourceScatterUpdate operation is decomposed.
|
||||
|
||||
// CHECK-LABEL: @decompose_resource_scatter_update_op
|
||||
|
|
|
@ -327,3 +327,69 @@ def DecomposeVariableShape : Pat<
|
|||
(TF_VariableShapeOp:$src_op $resource),
|
||||
(TF_ShapeOp (CreateTFReadVariableOpFromResourceHandle $src_op, $resource)),
|
||||
[(CheckHasResourceSubtype $resource)]>;
|
||||
|
||||
// This decomposition is only correct inside XLA as it ignores use_locking
|
||||
// attribute.
|
||||
// ms <- rho * ms_{t-1} + (1-rho) * grad * grad
|
||||
// mg = grad * (one - rho) + mg * rho;
|
||||
// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon)
|
||||
//
|
||||
def DecomposeResourceApplyCenteredRMSProp :
|
||||
Pattern<
|
||||
(TF_ResourceApplyCenteredRMSPropOp:$src_op
|
||||
$var_resource, $mg_resource, $ms_resource, $mom_resource, $lr, $rho, $momentum, $epsilon,
|
||||
$grad, ConstBoolAttrFalse:$use_locking
|
||||
),
|
||||
[(TF_ConstOp:$one (GetScalarOfType<1> $grad)),
|
||||
(CreateTFReadVariableOp $src_op, $grad, $ms_resource),
|
||||
(TF_AddOp:$ms_new
|
||||
(TF_MulOp
|
||||
(TF_MulOp $grad, $grad),
|
||||
(TF_SubOp $one, $rho)
|
||||
),
|
||||
(TF_MulOp
|
||||
(CreateTFReadVariableOp $src_op, $grad, $ms_resource),
|
||||
$rho
|
||||
)
|
||||
),
|
||||
(TF_AssignVariableOp $ms_resource, $ms_new),
|
||||
// mg = grad * (one - rho) + mg * rho;
|
||||
(TF_AddOp:$mg_new
|
||||
(TF_MulOp
|
||||
$grad,
|
||||
(TF_SubOp $one, $rho)
|
||||
),
|
||||
(TF_MulOp
|
||||
(CreateTFReadVariableOp $src_op, $grad, $mg_resource),
|
||||
$rho
|
||||
)
|
||||
),
|
||||
(TF_AssignVariableOp $mg_resource, $mg_new),
|
||||
// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon)
|
||||
(TF_AddOp:$mom_new
|
||||
(TF_MulOp $momentum,
|
||||
(CreateTFReadVariableOp $src_op, $grad, $mom_resource)),
|
||||
(TF_DivOp
|
||||
(TF_MulOp $lr, $grad),
|
||||
(TF_SqrtOp
|
||||
(TF_SubOp
|
||||
$ms_new,
|
||||
(TF_AddOp
|
||||
(TF_MulOp
|
||||
$mg_new,
|
||||
$mg_new
|
||||
),
|
||||
$epsilon
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
),
|
||||
(TF_AssignVariableOp $mom_resource, $mom_new),
|
||||
// var <- var - mom
|
||||
(TF_AssignSubVariableOp $var_resource,
|
||||
(TF_SubOp (CreateTFReadVariableOp $src_op, $grad, $var_resource),
|
||||
$mom_new)
|
||||
)
|
||||
]
|
||||
>;
|
||||
|
|
Loading…
Reference in New Issue