Add ResourceApplyAdagradV2 op and decomposition
PiperOrigin-RevId: 304511700
Change-Id: I26b4f10269416ae68d40a0bb8e0ea463d42d2765
parent 0bde6a68f2
commit 7d4df951af
@@ -6006,6 +6006,30 @@ Resize `images` to `size` using nearest neighbor interpolation.
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
+def TF_ResourceApplyAdagradV2Op : TF_Op<"ResourceApplyAdagradV2", []> {
+  let summary = "Update '*var' according to the adagrad scheme.";
+
+  let description = [{
+accum += grad * grad
+var -= lr * grad * (1 / (sqrt(accum) + epsilon))
+  }];
+
+  let arguments = (ins
+    TF_ResourceTensor:$var,
+    TF_ResourceTensor:$accum,
+    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr,
+    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$epsilon,
+    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad,
+
+    DefaultValuedAttr<BoolAttr, "false">:$use_locking,
+    DefaultValuedAttr<BoolAttr, "true">:$update_slots
+  );
+
+  let results = (outs);
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>;
+}
+
 def TF_ResourceApplyAdamOp : TF_Op<"ResourceApplyAdam", []> {
   let summary = "Update '*var' according to the Adam algorithm.";
 
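For reference, the TableGen definition above wraps the existing TF runtime op, which can be exercised through tf.raw_ops.ResourceApplyAdagradV2. A minimal sketch, assuming a TF 2.x eager environment (the variable values are illustrative, not part of this change):

import tensorflow as tf

var = tf.Variable([1.0, 2.0])
accum = tf.Variable([0.1, 0.1])

# In-place update: accum += grad * grad, then
# var -= lr * grad / (sqrt(accum) + epsilon).
tf.raw_ops.ResourceApplyAdagradV2(
    var=var.handle,
    accum=accum.handle,
    lr=tf.constant(0.01),
    epsilon=tf.constant(1e-7),
    grad=tf.constant([0.5, -0.5]),
    use_locking=False,
    update_slots=True)

print(var.numpy(), accum.numpy())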
@@ -147,6 +147,37 @@ func @decompose_resource_apply_keras_momentum_nesterov(%arg0: tensor<f32>, %arg1
 
 // -----
 
+// Tests that composite tf.ResourceApplyAdagradV2 operation is decomposed.
+
+// CHECK-LABEL: func @decompose_resource_apply_adagradv2
+// CHECK-SAME: ([[LR:%.*]]: tensor<f32>, [[EPSILON:%.*]]: tensor<f32>, [[GRAD:%.*]]: tensor<f32>)
+func @decompose_resource_apply_adagradv2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> () {
+
+  // CHECK: [[VAR_HANDLE:%.*]] = "tf.VarHandleOp"()
+  // CHECK: [[ACC_HANDLE:%.*]] = "tf.VarHandleOp"()
+  // CHECK: [[GRAD_SQUARE:%.*]] = "tf.Mul"([[GRAD]], [[GRAD]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK: [[OLD_ACC:%.*]] = "tf.ReadVariableOp"([[ACC_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
+  // CHECK: [[NEW_ACC:%.*]] = "tf.AddV2"([[OLD_ACC]], [[GRAD_SQUARE]]) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
+  // CHECK: [[LR_MULTIPLY:%.*]] = "tf.Mul"([[LR]], [[GRAD]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK: [[SQRT:%.*]] = "tf.Sqrt"([[NEW_ACC]]) : (tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: [[DIVISOR:%.*]] = "tf.AddV2"([[SQRT]], [[EPSILON]]) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
+  // CHECK: [[VAR_DELTA:%.*]] = "tf.Div"([[LR_MULTIPLY]], [[DIVISOR]]) : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: [[OLD_VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
+  // CHECK: [[NEW_VAR:%.*]] = "tf.Sub"([[OLD_VAR]], [[VAR_DELTA]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[NEW_VAR]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
+  // CHECK: "tf.AssignVariableOp"([[ACC_HANDLE]], [[NEW_ACC]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
+
+  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+  %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+
+  "tf.ResourceApplyAdagradV2"(%0, %1, %arg0, %arg1, %arg2) {update_slots = true, use_locking = true} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
+
+  return
+}
+
+// -----
+
 // Tests that composite tf.ResourceApplyAdam (non-Nesterov) operation is
 // decomposed.
 
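The CHECK lines above walk the fused update one primitive at a time. A quick NumPy sketch (illustrative values, not part of this change) showing why the decomposed sequence reproduces the fused AdagradV2 formula:

import numpy as np

var = np.array([1.0, 2.0], dtype=np.float32)
accum = np.array([0.1, 0.1], dtype=np.float32)
grad = np.array([0.5, -0.5], dtype=np.float32)
lr, epsilon = np.float32(0.01), np.float32(1e-7)

# Decomposed sequence, mirroring the CHECK lines.
grad_square = grad * grad               # tf.Mul
new_accum = accum + grad_square         # tf.AddV2
lr_multiply = lr * grad                 # tf.Mul
divisor = np.sqrt(new_accum) + epsilon  # tf.Sqrt, tf.AddV2
var_delta = lr_multiply / divisor       # tf.Div
new_var = var - var_delta               # tf.Sub

# Fused formula from the op's description.
np.testing.assert_allclose(new_accum, accum + grad * grad)
np.testing.assert_allclose(
    new_var, var - lr * grad / (np.sqrt(accum + grad * grad) + epsilon))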
@@ -22,7 +22,7 @@ class GetScalarOfType<int value> : NativeCodeCall<
   "GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">;
 
 // Creates a tf.ReadVariable op that reads a resource `$2` that has the same
-// element type as `$1`. The op created will use location of `$1`.
+// element type as `$1`. The op created will use location of `$0`.
 def CreateTFReadVariableOp: NativeCodeCall<
     "$_builder.create<TF::ReadVariableOp>("
     " $0.getLoc(),"
@@ -118,6 +118,32 @@ def DecomposeResourceApplyKerasMomentumOpNesterov :
     ]
   >;
 
+// Pattern to Decompose ResourceApplyAdagradV2.
+// This decomposition is only correct inside XLA as it ignores use_locking
+// attribute.
+// accum <- accum + grad * grad
+// variable <- variable - lr * grad / (sqrt(accum) + epsilon)
+def DecomposeResourceApplyAdagradV2 :
+  Pattern<
+    (TF_ResourceApplyAdagradV2Op:$src_op
+       $var_resource, $accum_resource, $lr, $epsilon, $grad, BoolAttr:$_,
+       ConstBoolAttrTrue:$update_slots),
+    [
+      (TF_AddV2Op:$new_accum
+        (CreateTFReadVariableOp $src_op, $grad, $accum_resource),
+        (TF_MulOp $grad, $grad)
+      ),
+      (TF_AssignSubVariableOp
+        $var_resource,
+        (TF_DivOp
+          (TF_MulOp $lr, $grad),
+          (TF_AddV2Op (TF_SqrtOp $new_accum), $epsilon)
+        )
+      ),
+      (TF_AssignVariableOp $accum_resource, $new_accum),
+    ]
+  >;
+
 // Pattern to Decompose ResourceApplyAdam without Nesterov momentum.
 // This decomposition is only correct inside XLA as it ignores use_locking
 // attribute.
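Note that the pattern only matches ConstBoolAttrTrue for $update_slots. A short NumPy sketch of the semantics assumed here (based on the reference TF kernel; the helper name is hypothetical): when update_slots is false the accumulator is left untouched, so that case is not covered by this rewrite:

import numpy as np

def apply_adagrad_v2(var, accum, lr, epsilon, grad, update_slots=True):
    # Pure-NumPy model of ResourceApplyAdagradV2 (illustrative only).
    if update_slots:
        accum = accum + grad * grad  # accum <- accum + grad * grad
    var = var - lr * grad / (np.sqrt(accum) + epsilon)  # var update
    return var, accum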