Decompose ResourceApplyProximalAdagrad op
Also, remove a common subexpression from the Select op's branches in the tf2xla kernel; the simplified form is easier to express as an MLIR rewriter pattern.

PiperOrigin-RevId: 345904245
Change-Id: I06d4dcb52acfa2184a5220478d669262858bc3e1
parent a5d5a36e4c
commit c99adee9ec
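For context (not part of the diff itself), the elementwise update performed by tf.ResourceApplyProximalAdagrad, which the new pattern below expands into primitive TF ops, follows the op's documented semantics, with the l1 <= 0 branch written the way the code in this change implements it:

\begin{aligned}
accum &\leftarrow accum + grad^{2} \\
prox\_v &\leftarrow var - lr \cdot grad \cdot \frac{1}{\sqrt{accum}} \\
var &\leftarrow
  \begin{cases}
    \dfrac{\operatorname{sign}(prox\_v)\,\max\bigl(|prox\_v| - lr \cdot l1,\ 0\bigr)}{1 + lr \cdot l2} & \text{if } l1 > 0 \\[6pt]
    \dfrac{prox\_v}{1 + lr \cdot l2} & \text{otherwise}
  \end{cases}
\end{aligned}

The l1 > 0 branch is built from Sign, Abs, Maximum and a select on l1 > 0; the shared division by 1 + lr * l2 is applied once after the select, which is the same simplification made in the tf2xla kernel further down.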
@@ -517,3 +517,44 @@ func @decompose_variable_shape_no_subtype(%input: tensor<!tf.resource>) -> tenso
// CHECK-NOT: "tf.Shape"
return %0 : tensor<3xi32>
}
// -----
// Tests that composite tf.ResourceApplyProximalAdagrad operation is decomposed.
// CHECK-LABEL: @decompose_resource_apply_proximal_adagrad_op
// CHECK-SAME: (%[[LR:.*]]: tensor<f32>, %[[L1:.*]]: tensor<f32>, %[[L2:.*]]: tensor<f32>, %[[GRAD:.*]]: tensor<4xf32>)
func @decompose_resource_apply_proximal_adagrad_op(%lr: tensor<f32>, %l1: tensor<f32>, %l2: tensor<f32>, %grad: tensor<4xf32>) -> () {
%var = "tf.VarHandleOp"() {container = "c", shared_name = "var"} : () -> tensor<*x!tf.resource<tensor<4xf32>>>
%accum = "tf.VarHandleOp"() {container = "c", shared_name = "accum"} : () -> tensor<*x!tf.resource<tensor<4xf32>>>
// CHECK-DAG: %[[ONE:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK-DAG: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK-DAG: %[[VAR_HANDLE:.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "var"} : () -> tensor<*x!tf.resource<tensor<4xf32>>>
// CHECK-DAG: %[[ACCUM_HANDLE:.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "accum"} : () -> tensor<*x!tf.resource<tensor<4xf32>>>
// CHECK-DAG: %[[GRAD_SQUARE:.*]] = "tf.Square"(%[[GRAD]]) : (tensor<4xf32>) -> tensor<4xf32>
// CHECK-DAG: %[[ACCUM:.*]] = "tf.ReadVariableOp"(%[[ACCUM_HANDLE]]) : (tensor<*x!tf.resource<tensor<4xf32>>>) -> tensor<4xf32>
// CHECK-DAG: %[[ACCUM_NEW:.*]] = "tf.AddV2"(%[[ACCUM]], %[[GRAD_SQUARE]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-DAG: %[[RSQRT_ACCUM:.*]] = "tf.Rsqrt"(%[[ACCUM_NEW]]) : (tensor<4xf32>) -> tensor<4xf32>
// CHECK-DAG: %[[ADAGRAD_LR:.*]] = "tf.Mul"(%[[LR]], %[[RSQRT_ACCUM]]) : (tensor<f32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-DAG: %[[DELTA:.*]] = "tf.Mul"(%[[GRAD]], %[[ADAGRAD_LR]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-DAG: %[[VAR:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE]]) : (tensor<*x!tf.resource<tensor<4xf32>>>) -> tensor<4xf32>
// CHECK-DAG: %[[PROX:.*]] = "tf.Sub"(%[[VAR]], %[[DELTA]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-DAG: %[[SIGN:.*]] = "tf.Sign"(%[[PROX]]) : (tensor<4xf32>) -> tensor<4xf32>
// CHECK-DAG: %[[ABS:.*]] = "tf.Abs"(%[[PROX]]) : (tensor<4xf32>) -> tensor<4xf32>
// CHECK-DAG: %[[SCALED_L1:.*]] = "tf.Mul"(%[[ADAGRAD_LR]], %[[L1]]) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
// CHECK-DAG: %[[PROX_NEW:.*]] = "tf.Sub"(%[[ABS]], %[[SCALED_L1]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-DAG: %[[MAX:.*]] = "tf.Maximum"(%[[PROX_NEW]], %[[ZERO]]) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
// CHECK-DAG: %[[SIGNED:.*]] = "tf.Mul"(%[[SIGN]], %[[MAX]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-DAG: %[[GT:.*]] = "tf.Greater"(%[[L1]], %[[ZERO]]) : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK-DAG: %[[NUMERATOR:.*]] = "tf.SelectV2"(%[[GT]], %[[SIGNED]], %[[PROX]]) : (tensor<i1>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-DAG: %[[SCALED_L2:.*]] = "tf.Mul"(%[[ADAGRAD_LR]], %[[L2]]) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
// CHECK-DAG: %[[DENOMINATOR:.*]] = "tf.Add"(%[[ONE]], %[[SCALED_L2]]) : (tensor<f32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-DAG: %[[VAR_NEW:.*]] = "tf.Div"(%[[NUMERATOR]], %[[DENOMINATOR]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-DAG: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[VAR_NEW]]) : (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>) -> ()
// CHECK-DAG: "tf.AssignVariableOp"(%[[ACCUM_HANDLE]], %[[ACCUM_NEW]]) : (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>) -> ()
"tf.ResourceApplyProximalAdagrad"(%var, %accum, %lr, %l1, %l2, %grad) {use_locking = false} : (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<*x!tf.resource<tensor<4xf32>>>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<4xf32>) -> ()
return
}
@@ -447,3 +447,30 @@ def DecomposeResourceApplyRMSProp :
(TF_AssignSubVariableOp $var_resource, $mom_new)
]
>;
def DecomposeResourceApplyProximalAdagrad :
Pattern<
(TF_ResourceApplyProximalAdagradOp:$src_op
$var_resource, $accum_resource, $lr, $l1, $l2, $grad,
ConstBoolAttrFalse:$use_locking
),
[(TF_ConstOp:$one (GetScalarOfType<1> $grad)),
(TF_ConstOp:$zero (GetScalarOfType<0> $grad)),
(TF_AddV2Op:$accum_new
(CreateTFReadVariableOp $src_op, $grad, $accum_resource),
(TF_SquareOp $grad)),
(TF_MulOp:$adagrad_lr $lr, (TF_RsqrtOp $accum_new)),
(TF_SubOp:$prox_var
(CreateTFReadVariableOp $src_op, $grad, $var_resource),
(TF_MulOp $grad, $adagrad_lr)),
(TF_MulOp:$l1_gt_zero (TF_SignOp $prox_var),
(TF_MaximumOp
(TF_SubOp (TF_AbsOp $prox_var), (TF_MulOp $adagrad_lr, $l1)), $zero)),
(TF_SelectV2Op:$var_numerator (TF_GreaterOp $l1, $zero),
$l1_gt_zero, $prox_var),
(TF_DivOp:$var_new
$var_numerator, (TF_AddOp $one, (TF_MulOp $adagrad_lr, $l2))),
(TF_AssignVariableOp $var_resource, $var_new),
(TF_AssignVariableOp $accum_resource, $accum_new)
]
>;
@@ -62,11 +62,11 @@ xla::XlaOp ProximalGradientDescentUpdate(xla::XlaOp var, xla::XlaOp lr,
xla::XlaOp one = xla::ScalarLike(lr, 1.0);
xla::XlaOp zero = xla::ScalarLike(lr, 0.0);
xla::XlaOp prox_var = var - grad * lr;
-  xla::XlaOp l1_gt_zero = xla::Sign(prox_var) *
-                          xla::Max(xla::Abs(prox_var) - lr * l1, zero) /
-                          (one + lr * l2);
-  xla::XlaOp l1_le_zero = prox_var / (one + lr * l2);
-  return xla::Select(xla::Gt(l1, zero), l1_gt_zero, l1_le_zero);
+  xla::XlaOp l1_gt_zero =
+      xla::Sign(prox_var) * xla::Max(xla::Abs(prox_var) - lr * l1, zero);
+  xla::XlaOp l1_le_zero = prox_var;
+  return xla::Select(xla::Gt(l1, zero), l1_gt_zero, l1_le_zero) /
+         (one + lr * l2);
}
class ResourceApplyProximalGradientDescent : public XlaOpKernel {
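The tf2xla change above hoists the shared divisor out of the two Select branches. The rewrite is justified by the elementwise identity that division distributes over a lanewise select (a sketch of the reasoning, not code from the commit):

\operatorname{select}(p,\ a,\ b) \,/\, d \;=\; \operatorname{select}\bigl(p,\ a/d,\ b/d\bigr)

With a = sign(prox_var) * max(|prox_var| - lr * l1, zero), b = prox_var, and d = one + lr * l2, the divisor is applied once to the Select result, which is exactly the shape the DRR pattern expresses as a single tf.Div after tf.SelectV2.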