Decompose ResourceApplyProximalAdagrad op

Also, hoist the common subexpression out of the Select op's branches in the tf2xla kernel; the simplified form is easier to express as an MLIR rewrite pattern.

PiperOrigin-RevId: 345904245
Change-Id: I06d4dcb52acfa2184a5220478d669262858bc3e1
Smit Hinsu 2020-12-05 19:29:35 -08:00 committed by TensorFlower Gardener
parent a5d5a36e4c
commit c99adee9ec
3 changed files with 73 additions and 5 deletions


@@ -517,3 +517,44 @@ func @decompose_variable_shape_no_subtype(%input: tensor<!tf.resource>) -> tensor<3xi32> {
// CHECK-NOT: "tf.Shape"
return %0 : tensor<3xi32>
}
// -----
// Tests that the composite tf.ResourceApplyProximalAdagrad operation is decomposed.
// CHECK-LABEL: @decompose_resource_apply_proximal_adagrad_op
// CHECK-SAME: (%[[LR:.*]]: tensor<f32>, %[[L1:.*]]: tensor<f32>, %[[L2:.*]]: tensor<f32>, %[[GRAD:.*]]: tensor<4xf32>)
func @decompose_resource_apply_proximal_adagrad_op(%lr: tensor<f32>, %l1: tensor<f32>, %l2: tensor<f32>, %grad: tensor<4xf32>) -> () {
%var = "tf.VarHandleOp"() {container = "c", shared_name = "var"} : () -> tensor<*x!tf.resource<tensor<4xf32>>>
%accum = "tf.VarHandleOp"() {container = "c", shared_name = "accum"} : () -> tensor<*x!tf.resource<tensor<4xf32>>>
// CHECK-DAG: %[[ONE:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK-DAG: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK-DAG: %[[VAR_HANDLE:.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "var"} : () -> tensor<*x!tf.resource<tensor<4xf32>>>
// CHECK-DAG: %[[ACCUM_HANDLE:.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "accum"} : () -> tensor<*x!tf.resource<tensor<4xf32>>>
// CHECK-DAG: %[[GRAD_SQUARE:.*]] = "tf.Square"(%[[GRAD]]) : (tensor<4xf32>) -> tensor<4xf32>
// CHECK-DAG: %[[ACCUM:.*]] = "tf.ReadVariableOp"(%[[ACCUM_HANDLE]]) : (tensor<*x!tf.resource<tensor<4xf32>>>) -> tensor<4xf32>
// CHECK-DAG: %[[ACCUM_NEW:.*]] = "tf.AddV2"(%[[ACCUM]], %[[GRAD_SQUARE]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-DAG: %[[RSQRT_ACCUM:.*]] = "tf.Rsqrt"(%[[ACCUM_NEW]]) : (tensor<4xf32>) -> tensor<4xf32>
// CHECK-DAG: %[[ADAGRAD_LR:.*]] = "tf.Mul"(%[[LR]], %[[RSQRT_ACCUM]]) : (tensor<f32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-DAG: %[[DELTA:.*]] = "tf.Mul"(%[[GRAD]], %[[ADAGRAD_LR]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-DAG: %[[VAR:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE]]) : (tensor<*x!tf.resource<tensor<4xf32>>>) -> tensor<4xf32>
// CHECK-DAG: %[[PROX:.*]] = "tf.Sub"(%[[VAR]], %[[DELTA]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-DAG: %[[SIGN:.*]] = "tf.Sign"(%[[PROX]]) : (tensor<4xf32>) -> tensor<4xf32>
// CHECK-DAG: %[[ABS:.*]] = "tf.Abs"(%[[PROX]]) : (tensor<4xf32>) -> tensor<4xf32>
// CHECK-DAG: %[[SCALED_L1:.*]] = "tf.Mul"(%[[ADAGRAD_LR]], %[[L1]]) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
// CHECK-DAG: %[[PROX_NEW:.*]] = "tf.Sub"(%[[ABS]], %[[SCALED_L1]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-DAG: %[[MAX:.*]] = "tf.Maximum"(%[[PROX_NEW]], %[[ZERO]]) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
// CHECK-DAG: %[[SIGNED:.*]] = "tf.Mul"(%[[SIGN]], %[[MAX]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-DAG: %[[GT:.*]] = "tf.Greater"(%[[L1]], %[[ZERO]]) : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK-DAG: %[[NUMERATOR:.*]] = "tf.SelectV2"(%[[GT]], %[[SIGNED]], %[[PROX]]) : (tensor<i1>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-DAG: %[[SCALED_L2:.*]] = "tf.Mul"(%[[ADAGRAD_LR]], %[[L2]]) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
// CHECK-DAG: %[[DENOMINATOR:.*]] = "tf.Add"(%[[ONE]], %[[SCALED_L2]]) : (tensor<f32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-DAG: %[[VAR_NEW:.*]] = "tf.Div"(%[[NUMERATOR]], %[[DENOMINATOR]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-DAG: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[VAR_NEW]]) : (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>) -> ()
// CHECK-DAG: "tf.AssignVariableOp"(%[[ACCUM_HANDLE]], %[[ACCUM_NEW]]) : (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>) -> ()
"tf.ResourceApplyProximalAdagrad"(%var, %accum, %lr, %l1, %l2, %grad) {use_locking = false} : (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<*x!tf.resource<tensor<4xf32>>>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<4xf32>) -> ()
return
}
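
For readability, the update that the CHECK-DAG lines above encode can be summarized with a scalar sketch (an illustrative standalone snippet, not part of the change; the comments mirror the FileCheck capture names):

#include <algorithm>
#include <cmath>

// Scalar version of the decomposed ProximalAdagrad step checked above.
void ProximalAdagradStep(float& var, float& accum, float lr, float l1,
                         float l2, float grad) {
  accum += grad * grad;                      // ACCUM_NEW
  float adagrad_lr = lr / std::sqrt(accum);  // ADAGRAD_LR = lr * rsqrt(accum)
  float prox = var - grad * adagrad_lr;      // PROX
  float sign = prox > 0 ? 1.0f : (prox < 0 ? -1.0f : 0.0f);               // SIGN
  float shrunk = sign * std::max(std::abs(prox) - adagrad_lr * l1, 0.0f); // SIGNED
  float numerator = l1 > 0 ? shrunk : prox;    // SelectV2 on GT
  var = numerator / (1.0f + adagrad_lr * l2);  // VAR_NEW
}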


@@ -447,3 +447,30 @@ def DecomposeResourceApplyRMSProp :
(TF_AssignSubVariableOp $var_resource, $mom_new)
]
>;
def DecomposeResourceApplyProximalAdagrad :
  Pattern<
    (TF_ResourceApplyProximalAdagradOp:$src_op
       $var_resource, $accum_resource, $lr, $l1, $l2, $grad,
       ConstBoolAttrFalse:$use_locking
    ),
    [(TF_ConstOp:$one (GetScalarOfType<1> $grad)),
     (TF_ConstOp:$zero (GetScalarOfType<0> $grad)),
     (TF_AddV2Op:$accum_new
       (CreateTFReadVariableOp $src_op, $grad, $accum_resource),
       (TF_SquareOp $grad)),
     (TF_MulOp:$adagrad_lr $lr, (TF_RsqrtOp $accum_new)),
     (TF_SubOp:$prox_var
       (CreateTFReadVariableOp $src_op, $grad, $var_resource),
       (TF_MulOp $grad, $adagrad_lr)),
     (TF_MulOp:$l1_gt_zero (TF_SignOp $prox_var),
       (TF_MaximumOp
         (TF_SubOp (TF_AbsOp $prox_var), (TF_MulOp $adagrad_lr, $l1)), $zero)),
     (TF_SelectV2Op:$var_numerator (TF_GreaterOp $l1, $zero),
       $l1_gt_zero, $prox_var),
     (TF_DivOp:$var_new
       $var_numerator, (TF_AddOp $one, (TF_MulOp $adagrad_lr, $l2))),
     (TF_AssignVariableOp $var_resource, $var_new),
     (TF_AssignVariableOp $accum_resource, $accum_new)
    ]
  >;
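
One detail worth noting: the TF_GreaterOp/TF_SelectV2Op pair guards the soft-threshold branch, since sign(prox) * max(|prox| - adagrad_lr * l1, 0) collapses to prox only when l1 is exactly zero and overshoots it for negative l1. A minimal standalone check of that claim (hypothetical snippet, not part of the change):

#include <algorithm>
#include <cassert>
#include <cmath>

int main() {
  float prox = -0.75f, adagrad_lr = 0.3f;
  float sign = prox > 0 ? 1.0f : (prox < 0 ? -1.0f : 0.0f);
  // l1 == 0: the thresholded branch equals prox itself.
  assert(std::abs(sign * std::max(std::abs(prox) - adagrad_lr * 0.0f, 0.0f) - prox) < 1e-6f);
  // l1 < 0: subtracting a negative grows |prox|, so the branch no longer equals prox;
  // the SelectV2 on (l1 > 0) falls back to prox in that case.
  assert(sign * std::max(std::abs(prox) - adagrad_lr * -0.5f, 0.0f) != prox);
  return 0;
}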


@@ -62,11 +62,11 @@ xla::XlaOp ProximalGradientDescentUpdate(xla::XlaOp var, xla::XlaOp lr,
   xla::XlaOp one = xla::ScalarLike(lr, 1.0);
   xla::XlaOp zero = xla::ScalarLike(lr, 0.0);
   xla::XlaOp prox_var = var - grad * lr;
-  xla::XlaOp l1_gt_zero = xla::Sign(prox_var) *
-                              xla::Max(xla::Abs(prox_var) - lr * l1, zero) /
-                          (one + lr * l2);
-  xla::XlaOp l1_le_zero = prox_var / (one + lr * l2);
-  return xla::Select(xla::Gt(l1, zero), l1_gt_zero, l1_le_zero);
+  xla::XlaOp l1_gt_zero =
+      xla::Sign(prox_var) * xla::Max(xla::Abs(prox_var) - lr * l1, zero);
+  xla::XlaOp l1_le_zero = prox_var;
+  return xla::Select(xla::Gt(l1, zero), l1_gt_zero, l1_le_zero) /
+         (one + lr * l2);
 }
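
The kernel change relies on division distributing over both branches of the select: each branch was divided by the same (one + lr * l2), so the divisor can be applied once to the selected value, matching the form used in the MLIR pattern above. A minimal standalone check on plain doubles (hypothetical snippet, not part of the change):

#include <cassert>
#include <cmath>

int main() {
  double a = 3.0, b = -2.0, d = 1.5;  // d stands in for (one + lr * l2)
  bool cond = true;                   // result of Gt(l1, zero)
  double per_branch = cond ? a / d : b / d;  // old form: divide in each branch
  double hoisted = (cond ? a : b) / d;       // new form: divide the selected value
  assert(std::abs(per_branch - hoisted) < 1e-12);
  cond = false;
  assert(std::abs((cond ? a / d : b / d) - (cond ? a : b) / d) < 1e-12);
  return 0;
}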
class ResourceApplyProximalGradientDescent : public XlaOpKernel {