Add ResourceApplyAdagradV2 op and decomposition

PiperOrigin-RevId: 304511700
Change-Id: I26b4f10269416ae68d40a0bb8e0ea463d42d2765
This commit is contained in:
HyoukJoong Lee 2020-04-02 17:31:55 -07:00 committed by TensorFlower Gardener
parent 0bde6a68f2
commit 7d4df951af
3 changed files with 82 additions and 1 deletion

View File

@ -6006,6 +6006,30 @@ Resize `images` to `size` using nearest neighbor interpolation.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
// TableGen definition of the TF ResourceApplyAdagradV2 op. The op mutates the
// `var` and `accum` resources in place, so it declares no SSA results.
def TF_ResourceApplyAdagradV2Op : TF_Op<"ResourceApplyAdagradV2", []> {
let summary = "Update '*var' according to the adagrad scheme.";
let description = [{
accum += grad * grad
var -= lr * grad * (1 / (sqrt(accum) + epsilon))
}];
// `var` and `accum` are opaque resource handles; `lr`, `epsilon`, and `grad`
// carry the numeric payload and share the derived element type T.
let arguments = (ins
TF_ResourceTensor:$var,
TF_ResourceTensor:$accum,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$epsilon,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad,
// NOTE(review): use_locking is accepted but the decompose-resource-ops
// lowering ignores it (see DecomposeResourceApplyAdagradV2).
DefaultValuedAttr<BoolAttr, "false">:$use_locking,
DefaultValuedAttr<BoolAttr, "true">:$update_slots
);
// Updates happen as side effects on the resources: no results.
let results = (outs);
// T is derived from operand 2 (`lr`), not from the resource operands.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>;
}
def TF_ResourceApplyAdamOp : TF_Op<"ResourceApplyAdam", []> {
let summary = "Update '*var' according to the Adam algorithm.";

View File

@ -147,6 +147,37 @@ func @decompose_resource_apply_keras_momentum_nesterov(%arg0: tensor<f32>, %arg1
// -----
// Tests that composite tf.ResourceApplyAdagradV2 operation is decomposed.
// Expected expansion:
//   accum += grad * grad
//   var   -= lr * grad / (sqrt(accum) + epsilon)
// CHECK-LABEL: func @decompose_resource_apply_adagradv2
// CHECK-SAME: ([[LR:%.*]]: tensor<f32>, [[EPSILON:%.*]]: tensor<f32>, [[GRAD:%.*]]: tensor<f32>)
func @decompose_resource_apply_adagradv2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> () {
// CHECK: [[VAR_HANDLE:%.*]] = "tf.VarHandleOp"()
// CHECK: [[ACC_HANDLE:%.*]] = "tf.VarHandleOp"()
// CHECK: [[GRAD_SQUARE:%.*]] = "tf.Mul"([[GRAD]], [[GRAD]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: [[OLD_ACC:%.*]] = "tf.ReadVariableOp"([[ACC_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
// CHECK: [[NEW_ACC:%.*]] = "tf.AddV2"([[OLD_ACC]], [[GRAD_SQUARE]]) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
// CHECK: [[LR_MULTIPLY:%.*]] = "tf.Mul"([[LR]], [[GRAD]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: [[SQRT:%.*]] = "tf.Sqrt"([[NEW_ACC]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: [[DIVISOR:%.*]] = "tf.AddV2"([[SQRT]], [[EPSILON]]) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
// CHECK: [[VAR_DELTA:%.*]] = "tf.Div"([[LR_MULTIPLY]], [[DIVISOR]]) : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: [[OLD_VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
// Match via the captured FileCheck variables rather than hard-coded SSA ids
// (%9, %8), which break whenever unrelated ops shift the numbering.
// CHECK: [[NEW_VAR:%.*]] = "tf.Sub"([[OLD_VAR]], [[VAR_DELTA]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[NEW_VAR]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
// CHECK: "tf.AssignVariableOp"([[ACC_HANDLE]], [[NEW_ACC]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
"tf.ResourceApplyAdagradV2"(%0, %1, %arg0, %arg1, %arg2) {update_slots = true, use_locking = true} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
return
}
// -----
// Tests that composite tf.ResourceApplyAdam (non-Nesterov) operation is
// decomposed.

View File

@ -22,7 +22,7 @@ class GetScalarOfType<int value> : NativeCodeCall<
"GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">;
// Creates a tf.ReadVariable op that reads a resource `$2` that has the same
// element type as `$1`. The op created will use location of `$1`.
// element type as `$1`. The op created will use location of `$0`.
def CreateTFReadVariableOp: NativeCodeCall<
"$_builder.create<TF::ReadVariableOp>("
" $0.getLoc(),"
@ -118,6 +118,32 @@ def DecomposeResourceApplyKerasMomentumOpNesterov :
]
>;
// Pattern to Decompose ResourceApplyAdagrad.
// This decomposition is only correct inside XLA as it ignores use_locking
// attribute.
// accum <- accum + grad * grad
// variable <- variable - lr * grad / (sqrt(accum) + epsilon)
// DRR pattern expanding ResourceApplyAdagradV2 into primitive resource
// reads/writes and arithmetic:
//   accum <- accum + grad * grad
//   var   <- var - lr * grad / (sqrt(accum) + epsilon)
// Only correct inside XLA, since the use_locking attribute is ignored.
def DecomposeResourceApplyAdagradV2 :
Pattern<
// Matches only when update_slots is the constant true; use_locking is
// captured as an anonymous BoolAttr and deliberately dropped.
(TF_ResourceApplyAdagradV2Op:$src_op
$var_resource, $accum_resource, $lr, $epsilon, $grad, BoolAttr:$_,
ConstBoolAttrTrue:$update_slots),
[
// new_accum = accum + grad * grad. CreateTFReadVariableOp reads the
// accumulator resource with the element type of `grad`, at $src_op's loc.
(TF_AddV2Op:$new_accum
(CreateTFReadVariableOp $src_op, $grad, $accum_resource),
(TF_MulOp $grad, $grad)
),
// var -= lr * grad / (sqrt(new_accum) + epsilon); note the subtrahend is
// built from the *updated* accumulator, matching the op's semantics.
(TF_AssignSubVariableOp
$var_resource,
(TF_DivOp
(TF_MulOp $lr, $grad),
(TF_AddV2Op (TF_SqrtOp $new_accum), $epsilon)
)
),
// Finally persist the updated accumulator back to its resource.
(TF_AssignVariableOp $accum_resource, $new_accum),
]
>;
// Pattern to Decompose ResourceApplyAdam without Nesterov momentum.
// This decomposition is only correct inside XLA as it ignores use_locking
// attribute.