Add ResourceApplyAdagradV2 op and decomposition
PiperOrigin-RevId: 304511700
Change-Id: I26b4f10269416ae68d40a0bb8e0ea463d42d2765
@@ -6006,6 +6006,30 @@ Resize `images` to `size` using nearest neighbor interpolation.
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
+def TF_ResourceApplyAdagradV2Op : TF_Op<"ResourceApplyAdagradV2", []> {
+  let summary = "Update '*var' according to the adagrad scheme.";
+
+  let description = [{
+accum += grad * grad
+var -= lr * grad * (1 / (sqrt(accum) + epsilon))
+  }];
+
+  let arguments = (ins
+    TF_ResourceTensor:$var,
+    TF_ResourceTensor:$accum,
+    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr,
+    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$epsilon,
+    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad,
+
+    DefaultValuedAttr<BoolAttr, "false">:$use_locking,
+    DefaultValuedAttr<BoolAttr, "true">:$update_slots
+  );
+
+  let results = (outs);
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>;
+}
+
 def TF_ResourceApplyAdamOp : TF_Op<"ResourceApplyAdam", []> {
   let summary = "Update '*var' according to the Adam algorithm.";
 
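For intuition, here is the update rule from the op description sketched in NumPy. This is an illustrative sketch, not TensorFlow's kernel: the in-place array updates stand in for the op's side effects on the `var` and `accum` resources, and the `update_slots=False` branch reflects that attribute's documented meaning of skipping the accumulator update.

import numpy as np

def resource_apply_adagrad_v2(var, accum, lr, epsilon, grad, update_slots=True):
    # accum += grad * grad  (the slot update; skipped when update_slots is false)
    if update_slots:
        accum += grad * grad
    # var -= lr * grad * (1 / (sqrt(accum) + epsilon)); the epsilon operand,
    # new in the V2 op, keeps the divisor away from zero.
    var -= lr * grad / (np.sqrt(accum) + epsilon)

var = np.array([1.0, 2.0])
accum = np.array([0.1, 0.1])
resource_apply_adagrad_v2(var, accum, lr=0.01, epsilon=1e-7, grad=np.array([0.5, -0.5]))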
@@ -147,6 +147,37 @@ func @decompose_resource_apply_keras_momentum_nesterov(%arg0: tensor<f32>, %arg1
 
 // -----
 
+// Tests that composite tf.ResourceApplyAdagradV2 operation is decomposed.
+
+// CHECK-LABEL: func @decompose_resource_apply_adagradv2
+// CHECK-SAME: ([[LR:%.*]]: tensor<f32>, [[EPSILON:%.*]]: tensor<f32>, [[GRAD:%.*]]: tensor<f32>)
+func @decompose_resource_apply_adagradv2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> () {
+
+  // CHECK: [[VAR_HANDLE:%.*]] = "tf.VarHandleOp"()
+  // CHECK: [[ACC_HANDLE:%.*]] = "tf.VarHandleOp"()
+  // CHECK: [[GRAD_SQUARE:%.*]] = "tf.Mul"([[GRAD]], [[GRAD]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK: [[OLD_ACC:%.*]] = "tf.ReadVariableOp"([[ACC_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
+  // CHECK: [[NEW_ACC:%.*]] = "tf.AddV2"([[OLD_ACC]], [[GRAD_SQUARE]]) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
+  // CHECK: [[LR_MULTIPLY:%.*]] = "tf.Mul"([[LR]], [[GRAD]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK: [[SQRT:%.*]] = "tf.Sqrt"([[NEW_ACC]]) : (tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: [[DIVISOR:%.*]] = "tf.AddV2"([[SQRT]], [[EPSILON]]) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
+  // CHECK: [[VAR_DELTA:%.*]] = "tf.Div"([[LR_MULTIPLY]], [[DIVISOR]]) : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: [[OLD_VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
+  // CHECK: [[NEW_VAR:%.*]] = "tf.Sub"([[OLD_VAR]], [[VAR_DELTA]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[NEW_VAR]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
+  // CHECK: "tf.AssignVariableOp"([[ACC_HANDLE]], [[NEW_ACC]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
+
+  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+  %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+
+  "tf.ResourceApplyAdagradV2"(%0, %1, %arg0, %arg1, %arg2) {update_slots = true, use_locking = true} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
+
+  return
+}
+
+// -----
+
 // Tests that composite tf.ResourceApplyAdam (non-Nesterov) operation is
 // decomposed.
 
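A note on the test's conventions: the `// -----` markers split the file into independent cases for `-split-input-file` runs (the RUN line itself is outside this hunk), and the `[[NAME:%.*]]` captures let FileCheck verify the dataflow between the decomposed ops without depending on brittle SSA value numbering.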
@@ -22,7 +22,7 @@ class GetScalarOfType<int value> : NativeCodeCall<
   "GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">;
 
 // Creates a tf.ReadVariable op that reads a resource `$2` that has the same
-// element type as `$1`. The op created will use location of `$1`.
+// element type as `$1`. The op created will use location of `$0`.
 def CreateTFReadVariableOp: NativeCodeCall<
   "$_builder.create<TF::ReadVariableOp>("
   "    $0.getLoc(),"
@@ -118,6 +118,32 @@ def DecomposeResourceApplyKerasMomentumOpNesterov :
     ]
   >;
 
+// Pattern to Decompose ResourceApplyAdagradV2.
+// This decomposition is only correct inside XLA as it ignores use_locking
+// attribute.
+// accum <- accum + grad * grad
+// variable <- variable - lr * grad / (sqrt(accum) + epsilon)
+def DecomposeResourceApplyAdagradV2 :
+  Pattern<
+    (TF_ResourceApplyAdagradV2Op:$src_op
+       $var_resource, $accum_resource, $lr, $epsilon, $grad, BoolAttr:$_,
+       ConstBoolAttrTrue:$update_slots),
+    [
+      (TF_AddV2Op:$new_accum
+        (CreateTFReadVariableOp $src_op, $grad, $accum_resource),
+        (TF_MulOp $grad, $grad)
+      ),
+      (TF_AssignSubVariableOp
+        $var_resource,
+        (TF_DivOp
+          (TF_MulOp $lr, $grad),
+          (TF_AddV2Op (TF_SqrtOp $new_accum), $epsilon)
+        )
+      ),
+      (TF_AssignVariableOp $accum_resource, $new_accum),
+    ]
+  >;
+
 // Pattern to Decompose ResourceApplyAdam without Nesterov momentum.
 // This decomposition is only correct inside XLA as it ignores use_locking
 // attribute.
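Read as dataflow, the pattern emits exactly the op sequence the new test checks: read the accumulator, add `grad * grad`, subtract the scaled delta from the variable, and write the accumulator back. Because `use_locking` is bound to `BoolAttr:$_` it is matched but ignored, and only `ConstBoolAttrTrue:$update_slots` is accepted, so ops with `update_slots = false` are left undecomposed. A minimal Python rendering of the emitted sequence, with hypothetical `read_variable`/`assign_variable`/`assign_sub_variable` helpers standing in for the tf resource ops:

import numpy as np

# Hypothetical stand-ins for the resource ops emitted by the pattern.
def read_variable(handle): return handle["value"]         # tf.ReadVariableOp
def assign_variable(handle, v): handle["value"] = v       # tf.AssignVariableOp
def assign_sub_variable(handle, v): handle["value"] -= v  # tf.AssignSubVariableOp

def decomposed_adagrad_v2(var_handle, accum_handle, lr, epsilon, grad):
    # (TF_AddV2Op:$new_accum (CreateTFReadVariableOp ...) (TF_MulOp $grad, $grad))
    new_accum = read_variable(accum_handle) + grad * grad
    # (TF_AssignSubVariableOp $var_resource,
    #   (TF_DivOp (TF_MulOp $lr, $grad),
    #             (TF_AddV2Op (TF_SqrtOp $new_accum), $epsilon)))
    assign_sub_variable(var_handle, lr * grad / (np.sqrt(new_accum) + epsilon))
    # (TF_AssignVariableOp $accum_resource, $new_accum)
    assign_variable(accum_handle, new_accum)

var = {"value": np.array([1.0, 2.0])}
accum = {"value": np.array([0.1, 0.1])}
decomposed_adagrad_v2(var, accum, lr=0.01, epsilon=1e-7, grad=np.array([0.5, -0.5]))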