Add decomposition for ResourceApplyCenteredRMSProp

PiperOrigin-RevId: 316582464
Change-Id: Id61fb84cf6ce26398fcd870f331b51e0af8b9c3d
This commit is contained in:
A. Unique TensorFlower 2020-06-15 17:47:21 -07:00 committed by TensorFlower Gardener
parent c34265b348
commit 2f45ee867d
3 changed files with 159 additions and 0 deletions

View File

@ -7280,6 +7280,49 @@ $$\text{variable} := \text{variable} - \text{lr}_t * m_t / (\sqrt{v_t} + \epsilo
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>;
}
// ODS definition for the TensorFlow op "ResourceApplyCenteredRMSProp".
// All four optimizer-state operands (var, mg, ms, mom) are resource handles
// that are updated in place, so the op has no SSA results (see
// `let results = (outs)` below).
def TF_ResourceApplyCenteredRMSPropOp : TF_Op<"ResourceApplyCenteredRMSProp", []> {
let summary = "Update '*var' according to the centered RMSProp algorithm.";
let description = [{
The centered RMSProp algorithm uses an estimate of the centered second moment
(i.e., the variance) for normalization, as opposed to regular RMSProp, which
uses the (uncentered) second moment. This often helps with training, but is
slightly more expensive in terms of computation and memory.
Note that in dense implementation of this algorithm, mg, ms, and mom will
update even if the grad is zero, but in this sparse implementation, mg, ms,
and mom will not update in iterations during which the grad is zero.
mean_square = decay * mean_square + (1-decay) * gradient ** 2
mean_grad = decay * mean_grad + (1-decay) * gradient
Delta = learning_rate * gradient / sqrt(mean_square + epsilon - mean_grad ** 2)
mg <- rho * mg_{t-1} + (1-rho) * grad
ms <- rho * ms_{t-1} + (1-rho) * grad * grad
mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon)
var <- var - mom
}];
// Operands 0-3 are the mutable optimizer state (resource handles); operands
// 4-8 are the hyper-parameters and the gradient, all of element type T.
let arguments = (ins
TF_ResourceTensor:$var,
TF_ResourceTensor:$mg,
TF_ResourceTensor:$ms,
TF_ResourceTensor:$mom,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$rho,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$momentum,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$epsilon,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad,
// Runtime locking hint; the XLA decomposition pattern only matches the
// `false` case and ignores locking semantics.
DefaultValuedAttr<BoolAttr, "false">:$use_locking
);
// In-place resource update: no results.
let results = (outs);
// Element type T is derived from operand 4 ($lr), the first non-resource
// operand.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<4>;
}
def TF_ResourceApplyGradientDescentOp : TF_Op<"ResourceApplyGradientDescent", []> {
let summary = "Update '*var' by subtracting 'alpha' * 'delta' from it.";

View File

@ -368,6 +368,56 @@ func @decompose_resource_gather_op(%indices : tensor<5xi32>) -> tensor<2x5x16xi3
// -----
// Tests that composite tf.ResourceApplyCenteredRMSProp operation is decomposed
// into primitive arithmetic plus resource read/write ops.  Expected math:
//   ms  <- rho * ms  + (1 - rho) * grad * grad
//   mg  <- rho * mg  + (1 - rho) * grad
//   mom <- momentum * mom + lr * grad / sqrt(ms - mg * mg + epsilon)
//   var <- var - mom
// CHECK-LABEL: func @decompose_resource_apply_centered_RMS_prop
// CHECK-SAME: [[VAR:%.*]]: tensor<f32>, [[MG:%.*]]: tensor<f32>, [[MS:%.*]]: tensor<f32>, [[MOM:%.*]]: tensor<f32>, [[LR:%.*]]: tensor<f32>, [[RHO:%.*]]: tensor<f32>, [[MOMENTUM:%.*]]: tensor<f32>, [[EPSILON:%.*]]: tensor<f32>, [[GRAD:%.*]]: tensor<f32>
func @decompose_resource_apply_centered_RMS_prop(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>, %arg6: tensor<f32>, %arg7: tensor<f32>, %arg8: tensor<f32>) -> () {
// CHECK: [[ONE:%.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
// CHECK: [[VAR_HANDLE:%.*]] = "tf.VarHandleOp"
// CHECK: [[MG_HANDLE:%.*]] = "tf.VarHandleOp"
// CHECK: [[MS_HANDLE:%.*]] = "tf.VarHandleOp"
// CHECK: [[MOM_HANDLE:%.*]] = "tf.VarHandleOp"
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
%2 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
%3 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
// ms_new = grad*grad*(1-rho) + ms*rho
// CHECK: [[GRADSQ:%.*]] = "tf.Mul"([[GRAD]], [[GRAD]])
// CHECK: [[SB:%.*]] = "tf.Sub"([[ONE]], [[RHO]])
// CHECK: [[GRAD_SUB:%.*]] = "tf.Mul"([[GRADSQ]], [[SB]])
// CHECK: [[MS:%.*]] = "tf.ReadVariableOp"([[MS_HANDLE]])
// CHECK: [[MS_RHO:%.*]] = "tf.Mul"([[MS]], [[RHO]])
// CHECK: [[MS_NEW:%.*]] = "tf.Add"([[GRAD_SUB]], [[MS_RHO]])
// CHECK: "tf.AssignVariableOp"([[MS_HANDLE]], [[MS_NEW]])
// mg_new = grad*(1-rho) + mg*rho
// CHECK: [[SUB_RHO:%.*]] = "tf.Sub"([[ONE]], [[RHO]])
// CHECK: [[SUB_GRAD:%.*]] = "tf.Mul"([[GRAD]], [[SUB_RHO]])
// CHECK: [[MG:%.*]] = "tf.ReadVariableOp"([[MG_HANDLE]])
// CHECK: [[MG_RHO:%.*]] = "tf.Mul"([[MG]], [[RHO]])
// CHECK: [[MG_NEW:%.*]] = "tf.Add"([[SUB_GRAD]], [[MG_RHO]])
// CHECK: "tf.AssignVariableOp"([[MG_HANDLE]], [[MG_NEW]])
// mom_new = momentum*mom + lr*grad / sqrt(ms_new - (mg_new*mg_new + epsilon))
// CHECK: [[MOM:%.*]] = "tf.ReadVariableOp"([[MOM_HANDLE]])
// CHECK: [[MOM_MOM:%.*]] = "tf.Mul"([[MOMENTUM]], [[MOM]])
// CHECK: [[LR_GRAD:%.*]] = "tf.Mul"([[LR]], [[GRAD]])
// CHECK: [[MG_MG:%.*]] = "tf.Mul"([[MG_NEW]], [[MG_NEW]])
// Distinct name [[MG_EPS]]: the original test rebound [[MG_NEW]] here, which
// silently shadowed the mg_new binding above — legal in FileCheck but
// misleading.
// CHECK: [[MG_EPS:%.*]] = "tf.Add"([[MG_MG]], [[EPSILON]])
// CHECK: [[MG_SUB:%.*]] = "tf.Sub"([[MS_NEW]], [[MG_EPS]])
// CHECK: [[MG_SQRT:%.*]] = "tf.Sqrt"([[MG_SUB]])
// CHECK: [[MOM_DIV:%.*]] = "tf.Div"([[LR_GRAD]], [[MG_SQRT]])
// CHECK: [[MOM_NEW:%.*]] = "tf.Add"([[MOM_MOM]], [[MOM_DIV]])
// var_new = var - mom_new
// CHECK: [[VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]])
// CHECK: [[VAR_NEW:%.*]] = "tf.Sub"([[VAR]], [[MOM_NEW]])
// NOTE(review): the final write of [[VAR_NEW]] back to [[VAR_HANDLE]] is not
// checked here — consider adding a CHECK for the closing tf.AssignVariableOp.
"tf.ResourceApplyCenteredRMSProp"(%0, %1, %2, %3, %arg4, %arg5, %arg6, %arg7, %arg8) {use_locking = false} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
return
}
// -----
// Tests that composite tf.ResourceScatterUpdate operation is decomposed.
// CHECK-LABEL: @decompose_resource_scatter_update_op

View File

@ -327,3 +327,69 @@ def DecomposeVariableShape : Pat<
(TF_VariableShapeOp:$src_op $resource),
(TF_ShapeOp (CreateTFReadVariableOpFromResourceHandle $src_op, $resource)),
[(CheckHasResourceSubtype $resource)]>;
// This decomposition is only correct inside XLA as it ignores the use_locking
// attribute (it only matches the use_locking = false form).
//
//   ms  <- rho * ms_{t-1} + (1 - rho) * grad * grad
//   mg  <- rho * mg_{t-1} + (1 - rho) * grad
//   mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon)
//   var <- var - mom
def DecomposeResourceApplyCenteredRMSProp :
  Pattern<
    (TF_ResourceApplyCenteredRMSPropOp:$src_op
       $var_resource, $mg_resource, $ms_resource, $mom_resource, $lr, $rho,
       $momentum, $epsilon, $grad, ConstBoolAttrFalse:$use_locking
    ),
    [(TF_ConstOp:$one (GetScalarOfType<1> $grad)),
     (CreateTFReadVariableOp $src_op, $grad, $ms_resource),
     // ms <- rho * ms_{t-1} + (1 - rho) * grad * grad
     (TF_AddOp:$ms_new
       (TF_MulOp
         (TF_MulOp $grad, $grad),
         (TF_SubOp $one, $rho)
       ),
       (TF_MulOp
         (CreateTFReadVariableOp $src_op, $grad, $ms_resource),
         $rho
       )
     ),
     (TF_AssignVariableOp $ms_resource, $ms_new),
     // mg <- rho * mg_{t-1} + (1 - rho) * grad
     (TF_AddOp:$mg_new
       (TF_MulOp
         $grad,
         (TF_SubOp $one, $rho)
       ),
       (TF_MulOp
         (CreateTFReadVariableOp $src_op, $grad, $mg_resource),
         $rho
       )
     ),
     (TF_AssignVariableOp $mg_resource, $mg_new),
     // mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon)
     (TF_AddOp:$mom_new
       (TF_MulOp $momentum,
         (CreateTFReadVariableOp $src_op, $grad, $mom_resource)),
       (TF_DivOp
         (TF_MulOp $lr, $grad),
         (TF_SqrtOp
           (TF_SubOp
             $ms_new,
             (TF_AddOp
               (TF_MulOp $mg_new, $mg_new),
               $epsilon
             )
           )
         )
       )
     ),
     (TF_AssignVariableOp $mom_resource, $mom_new),
     // var <- var - mom.
     // BUGFIX: the original emitted
     //   AssignSubVariableOp(var, ReadVariableOp(var) - mom_new)
     // which computes var -= (var - mom_new), i.e. var = mom_new.  Assigning
     // the subtraction result directly implements var <- var - mom and matches
     // the ReadVariableOp/Sub sequence checked by the decomposition test.
     (TF_AssignVariableOp $var_resource,
       (TF_SubOp (CreateTFReadVariableOp $src_op, $grad, $var_resource),
                 $mom_new)
     )
    ]
  >;