Verify resource subtype is supplied for tf.VarHandleOp.

If no resource subtype is supplied, the dtype attribute cannot be derived.
Verify that a resource subtype is supplied so that accessing the dtype
attribute of a tf.VarHandleOp of type tensor<*x!tf.resource> produces a
verification failure instead of an assert failure.

PiperOrigin-RevId: 350824057
Change-Id: Ia340af12ab5d6f21b89a469c48aab606c41e2705
A. Unique TensorFlower 2021-01-08 13:19:55 -08:00 committed by TensorFlower Gardener
parent 9fa88362d4
commit b03f57f650
10 changed files with 123 additions and 139 deletions
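
As a minimal illustration of the commit message above (the handle types and value names here are examples, not lines taken from this diff): with no subtype on the resource type there is nothing to identify the element type of the tensor the variable holds, so no dtype attribute can be derived; with exactly one subtype, both dtype and shape follow from it.

// No subtype: the element type of the held tensor is unknown, so dtype cannot be derived.
%no_subtype = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
// One subtype (tensor<i32>): dtype is i32 and the shape is a scalar.
%with_subtype = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<!tf.resource<tensor<i32>>>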

@ -568,6 +568,26 @@ def TF_DerivedResultShapeAttr : DerivedAttr<"ShapedType",
"return (*getOperation()->result_type_begin()).cast<ShapedType>();",
[{ mlir::TF::ShapeAttr::get($_ctx, $_self) }]>;
// A derived attribute that returns the element type of the tensor held by a
// named resource-type operand or result.
class TF_DerivedOperandOrResultHandleTypeAttr<string name> : DerivedTypeAttr<
"auto resource_type =\n"
" mlir::getElementTypeOrSelf(this->" # name # "())\n"
" .cast<TF::ResourceType>();\n"
"assert(!resource_type.getSubtypes().empty() && \"unknown type\");\n"
"return mlir::getElementTypeOrSelf(*resource_type.getSubtypes().begin());">;
// A derived attribute that returns the shape of the tensor held by a named
// resource-type operand or result.
class TF_DerivedOperandOrResultHandleShapeAttr<string name> : DerivedAttr<
"ShapedType",
"auto resource_type =\n"
" mlir::getElementTypeOrSelf(this->" # name # "())\n"
" .cast<TF::ResourceType>();\n"
"assert(!resource_type.getSubtypes().empty() && \"unknown shape\");\n"
"return resource_type.getSubtypes().begin()->cast<ShapedType>();",
[{ mlir::TF::ShapeAttr::get($_ctx, $_self) }]>;
def TF_IntTypeAttr : TypeAttrBase<"IntegerType", "integer type"> {
let returnType = "Type";
}
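
A hedged sketch of what these derived attributes resolve to for a concrete handle (the tensor<8x16xf32> subtype and value names below are assumptions, not part of this diff): the single subtype of the named resource operand or result is the tensor type the variable holds, so its element type and its shape are what the two classes above return.

// The handle's single subtype, tensor<8x16xf32>, supplies the derived element
// type (f32) and the derived shape (8x16).
%handle = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<!tf.resource<tensor<8x16xf32>>>
%value = "tf.ReadVariableOp"(%handle) : (tensor<!tf.resource<tensor<8x16xf32>>>) -> tensor<8x16xf32>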

@ -865,35 +865,16 @@ Example:
Res<TF_ResourceTensor, "", [TF_VariableAlloc]>:$resource
);
let verifier = [{
// VarHandleOp requires the resource handle supply a single subtype from
// which to derive the dtype and shape attributes.
if (resource_type().getSubtypes().size() != 1) {
return emitOpError(
"must have exactly one subtype in the result resource type");
}
return success();
}];
DerivedTypeAttr dtype = DerivedTypeAttr<
"return getElementTypeOrSelf(resource_subtype());">;
DerivedAttr shape = DerivedAttr<
"ShapedType",
"return resource_subtype().cast<ShapedType>();",
[{ mlir::TF::ShapeAttr::get($_ctx, $_self) }]>;
TF_DerivedOperandOrResultHandleTypeAttr dtype =
TF_DerivedOperandOrResultHandleTypeAttr<"resource">;
TF_DerivedOperandOrResultHandleShapeAttr shape =
TF_DerivedOperandOrResultHandleShapeAttr<"resource">;
let extraClassDeclaration = [{
// TF_ResourceHandleAllocatorInterface:
ResourceHandleValueAndId GetResourceHandleValueAndId(
llvm::SmallDenseMap<ResourceHandle, int64_t> &resource_handle_id_map,
int64_t &next_id);
TensorType resource_subtype() { return resource_type().getSubtypes()[0]; }
ResourceType resource_type() {
return getElementTypeOrSelf(resource()).cast<TF::ResourceType>();
}
}];
}
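
For reference, the verifier text in this hunk rejects a VarHandleOp whose result resource type carries no subtype; a minimal test-style sketch (attribute values are placeholders) would be:

// expected-error @+1 {{must have exactly one subtype in the result resource type}}
%0 = "tf.VarHandleOp"() {container = "", shared_name = "x"} : () -> tensor<*x!tf.resource>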

@ -28,7 +28,7 @@ func @decompose_use_subtype() {
// CHECK-LABEL: func @decompose_assign_add_variable_op
func @decompose_assign_add_variable_op() -> () {
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<!tf.resource<tensor<i32>>>
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
// CHECK: %[[ONE:[0-9]*]] = "tf.Const"() {value = dense<1> : tensor<i32>}
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"
@ -36,7 +36,7 @@ func @decompose_assign_add_variable_op() -> () {
// CHECK: "tf.AssignVariableOp"
%1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
"tf.AssignAddVariableOp"(%0, %1) {dtype = "tfdtype$DT_INT32"} : (tensor<!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
"tf.AssignAddVariableOp"(%0, %1) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<i32>) -> ()
return
}
@ -49,7 +49,7 @@ func @decompose_assign_add_variable_op() -> () {
// CHECK-LABEL: func @decompose_assign_sub_variable_op
func @decompose_assign_sub_variable_op() -> () {
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<!tf.resource<tensor<i32>>>
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
// CHECK: %[[ONE:[0-9]*]] = "tf.Const"() {value = dense<1> : tensor<i32>}
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"
@ -57,7 +57,7 @@ func @decompose_assign_sub_variable_op() -> () {
// CHECK: "tf.AssignVariableOp"
%1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
"tf.AssignSubVariableOp"(%0, %1) {dtype = "tfdtype$DT_INT32"} : (tensor<!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
"tf.AssignSubVariableOp"(%0, %1) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<i32>) -> ()
return
}
@ -70,7 +70,7 @@ func @decompose_assign_sub_variable_op() -> () {
// CHECK-SAME: (%[[DELTA:.*]]: tensor<f32>)
func @decompose_resource_apply_gradient_descent(%arg0: tensor<f32>) -> () {
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<!tf.resource<tensor<f32>>>
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
// CHECK: %[[ALPHA:[0-9]*]] = "tf.Const"
// CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
@ -80,7 +80,7 @@ func @decompose_resource_apply_gradient_descent(%arg0: tensor<f32>) -> () {
// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[SUB]])
%1 = "tf.Const"() {T = f32, value = dense<[0.5]> : tensor<1xf32>} : () -> tensor<f32>
"tf.ResourceApplyGradientDescent"(%0, %1, %arg0) {use_locking = false} : (tensor<!tf.resource<tensor<f32>>>, tensor<f32>, tensor<f32>) -> ()
"tf.ResourceApplyGradientDescent"(%0, %1, %arg0) {use_locking = false} : (tensor<*x!tf.resource>, tensor<f32>, tensor<f32>) -> ()
return
}
@ -96,8 +96,8 @@ func @decompose_resource_apply_momentum_non_nesterov(%arg0: tensor<f32>, %arg1:
// CHECK: [[VAR_HANDLE:%.*]] = "tf.VarHandleOp"
// CHECK: [[ACCUM_HANDLE:%.*]] = "tf.VarHandleOp"
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<!tf.resource<tensor<f32>>>
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<!tf.resource<tensor<f32>>>
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
// CHECK: [[ACCUM:%.*]] = "tf.ReadVariableOp"([[ACCUM_HANDLE]])
// CHECK: [[ACCUM_MOMENTUM:%.*]] = "tf.Mul"([[ACCUM]], [[MOMENTUM]])
@ -107,7 +107,7 @@ func @decompose_resource_apply_momentum_non_nesterov(%arg0: tensor<f32>, %arg1:
// CHECK: [[VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]])
// CHECK: [[VAR_NEW:%.*]] = "tf.Sub"([[VAR]], [[ACCUM_NEW_LR]])
// CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[VAR_NEW]])
"tf.ResourceApplyMomentum"(%0, %1, %arg0, %arg1, %arg2) {use_locking = false, use_nesterov = false} : (tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<f32>>>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
"tf.ResourceApplyMomentum"(%0, %1, %arg0, %arg1, %arg2) {use_locking = false, use_nesterov = false} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
return
}
@ -122,8 +122,8 @@ func @decompose_resource_apply_momentum_nesterov(%arg0: tensor<f32>, %arg1: tens
// CHECK: [[VAR_HANDLE:%.*]] = "tf.VarHandleOp"
// CHECK: [[ACCUM_HANDLE:%.*]] = "tf.VarHandleOp"
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<!tf.resource<tensor<f32>>>
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<!tf.resource<tensor<f32>>>
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
// CHECK: [[ACCUM:%.*]] = "tf.ReadVariableOp"([[ACCUM_HANDLE]])
// CHECK: [[ACCUM_MOMENTUM:%.*]] = "tf.Mul"([[ACCUM]], [[MOMENTUM]])
@ -136,7 +136,7 @@ func @decompose_resource_apply_momentum_nesterov(%arg0: tensor<f32>, %arg1: tens
// CHECK: [[VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]])
// CHECK: [[VAR_NEW:%.*]] = "tf.Sub"([[VAR]], [[DELTA]])
// CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[VAR_NEW]])
"tf.ResourceApplyMomentum"(%0, %1, %arg0, %arg1, %arg2) {use_locking = false, use_nesterov = true} : (tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<f32>>>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
"tf.ResourceApplyMomentum"(%0, %1, %arg0, %arg1, %arg2) {use_locking = false, use_nesterov = true} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
return
}
@ -151,10 +151,10 @@ func @decompose_resource_apply_keras_momentum_non_nesterov(%arg0: tensor<f32>, %
// CHECK: %[[VAR_HANDLE:[0-9]*]] = "tf.VarHandleOp"
// CHECK: %[[ACCUM_HANDLE:[0-9]*]] = "tf.VarHandleOp"
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
// CHECK: %[[ACCUM:[0-9]*]] = "tf.ReadVariableOp"(%[[ACCUM_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
// CHECK: %[[ACCUM:[0-9]*]] = "tf.ReadVariableOp"(%[[ACCUM_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
// CHECK: %[[ACCUM_MOMENTUM:[0-9]*]] = "tf.Mul"(%[[ACCUM]], %[[MOMENTUM]])
// CHECK: %[[GRAD_LR:[0-9]*]] = "tf.Mul"(%[[GRAD]], %[[LR]])
// CHECK: %[[NEW_ACCUM:[0-9]*]] = "tf.Sub"(%[[ACCUM_MOMENTUM]], %[[GRAD_LR]])
@ -164,7 +164,7 @@ func @decompose_resource_apply_keras_momentum_non_nesterov(%arg0: tensor<f32>, %
// CHECK: %[[NEW_VAR:[0-9]*]] = "tf.AddV2"(%[[VAR]], %[[NEW_ACCUM]])
// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[NEW_VAR]])
"tf.ResourceApplyKerasMomentum"(%0, %1, %arg0, %arg1, %arg2) {use_locking = false, use_nesterov = false} : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*x!tf.resource<tensor<*xf32>>>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
"tf.ResourceApplyKerasMomentum"(%0, %1, %arg0, %arg1, %arg2) {use_locking = false, use_nesterov = false} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
return
}
@ -180,10 +180,10 @@ func @decompose_resource_apply_keras_momentum_nesterov(%arg0: tensor<f32>, %arg1
// CHECK: %[[VAR_HANDLE:[0-9]*]] = "tf.VarHandleOp"
// CHECK: %[[ACCUM_HANDLE:[0-9]*]] = "tf.VarHandleOp"
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
// CHECK: %[[ACCUM:[0-9]*]] = "tf.ReadVariableOp"(%[[ACCUM_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
// CHECK: %[[ACCUM:[0-9]*]] = "tf.ReadVariableOp"(%[[ACCUM_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
// CHECK: %[[ACCUM_MOMENTUM:[0-9]*]] = "tf.Mul"(%[[ACCUM]], %[[MOMENTUM]])
// CHECK: %[[GRAD_LR:[0-9]*]] = "tf.Mul"(%[[GRAD]], %[[LR]])
// CHECK: %[[NEW_ACCUM:[0-9]*]] = "tf.Sub"(%[[ACCUM_MOMENTUM]], %[[GRAD_LR]])
@ -195,7 +195,7 @@ func @decompose_resource_apply_keras_momentum_nesterov(%arg0: tensor<f32>, %arg1
// CHECK: %[[NEW_VAR:[0-9]*]] = "tf.AddV2"(%[[VAR]], %[[NEW_DELTA]])
// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[NEW_VAR]])
"tf.ResourceApplyKerasMomentum"(%0, %1, %arg0, %arg1, %arg2) {use_locking = false, use_nesterov = true} : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*x!tf.resource<tensor<*xf32>>>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
"tf.ResourceApplyKerasMomentum"(%0, %1, %arg0, %arg1, %arg2) {use_locking = false, use_nesterov = true} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
return
}
@ -212,21 +212,21 @@ func @decompose_resource_apply_adagradv2(%arg0: tensor<f32>, %arg1: tensor<f32>,
// CHECK: [[VAR_HANDLE:%.*]] = "tf.VarHandleOp"()
// CHECK: [[ACC_HANDLE:%.*]] = "tf.VarHandleOp"()
// CHECK: [[GRAD_SQUARE:%.*]] = "tf.Mul"([[GRAD]], [[GRAD]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: [[OLD_ACC:%.*]] = "tf.ReadVariableOp"([[ACC_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
// CHECK: [[OLD_ACC:%.*]] = "tf.ReadVariableOp"([[ACC_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
// CHECK: [[NEW_ACC:%.*]] = "tf.AddV2"([[OLD_ACC]], [[GRAD_SQUARE]]) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
// CHECK: [[LR_MULTIPLY:%.*]] = "tf.Mul"([[LR]], [[GRAD]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: [[SQRT:%.*]] = "tf.Sqrt"([[NEW_ACC]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: [[DIVISOR:%.*]] = "tf.AddV2"([[SQRT]], [[EPSILON]]) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
// CHECK: [[VAR_DELTA:%.*]] = "tf.Div"([[LR_MULTIPLY]], [[DIVISOR]]) : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: [[OLD_VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
// CHECK: [[OLD_VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
// CHECK: [[NEW_VAR:%.*]] = "tf.Sub"(%9, %8) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[NEW_VAR]]) : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
// CHECK: "tf.AssignVariableOp"([[ACC_HANDLE]], [[NEW_ACC]]) : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
// CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[NEW_VAR]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
// CHECK: "tf.AssignVariableOp"([[ACC_HANDLE]], [[NEW_ACC]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
"tf.ResourceApplyAdagradV2"(%0, %1, %arg0, %arg1, %arg2) {update_slots = true, use_locking = true} : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*x!tf.resource<tensor<*xf32>>>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
"tf.ResourceApplyAdagradV2"(%0, %1, %arg0, %arg1, %arg2) {update_slots = true, use_locking = true} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
return
}
@ -236,22 +236,22 @@ func @decompose_resource_apply_adagradv2(%arg0: tensor<f32>, %arg1: tensor<f32>,
// CHECK-SAME: (%[[LR:.*]]: tensor<f32>, %[[GRAD:.*]]: tensor<f32>)
func @decompose_resource_apply_adagrad(%arg0: tensor<f32>, %arg1: tensor<f32>) -> () {
// CHECK: %[[VAR_HANDLE:.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
// CHECK: %[[ACCUM_HANDLE:.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
// CHECK: %[[VAR_HANDLE:.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
// CHECK: %[[ACCUM_HANDLE:.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
// CHECK: %[[GRAD_SQUARE:.*]] = "tf.Mul"(%[[GRAD]], %[[GRAD]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: %[[ACCUM:.*]] = "tf.ReadVariableOp"(%[[ACCUM_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
// CHECK: %[[ACCUM:.*]] = "tf.ReadVariableOp"(%[[ACCUM_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
// CHECK: %[[ACCUM_NEW:.*]] = "tf.AddV2"(%[[ACCUM]], %[[GRAD_SQUARE]]) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
// CHECK: %[[LR_MULTIPLY:.*]] = "tf.Mul"(%[[LR]], %[[GRAD]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: %[[SQRT:.*]] = "tf.Sqrt"(%[[ACCUM_NEW]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: %[[DIV:.*]] = "tf.Div"(%[[LR_MULTIPLY]], %[[SQRT]]) : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: %[[VAR:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
// CHECK: %[[VAR:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
// CHECK: %[[VAR_NEW:.*]] = "tf.Sub"(%[[VAR]], %[[DIV]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[VAR_NEW]]) : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
// CHECK: "tf.AssignVariableOp"(%[[ACCUM_HANDLE]], %[[ACCUM_NEW]]) : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[VAR_NEW]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
// CHECK: "tf.AssignVariableOp"(%[[ACCUM_HANDLE]], %[[ACCUM_NEW]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
"tf.ResourceApplyAdagrad"(%0, %1, %arg0, %arg1) {update_slots = true, use_locking = true} : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*x!tf.resource<tensor<*xf32>>>, tensor<f32>, tensor<f32>) -> ()
"tf.ResourceApplyAdagrad"(%0, %1, %arg0, %arg1) {update_slots = true, use_locking = true} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>) -> ()
return
}
@ -274,12 +274,12 @@ func @decompose_resource_apply_adam_non_nesterov(%arg0: tensor<f32>, %arg1: tens
// CHECK: [[ONE_MINUS_BETA1_POWER:%.*]] = "tf.Sub"([[ONE]], [[BETA1_POWER]])
// CHECK: [[ALPHA_NO_LR:%.*]] = "tf.Div"([[SQRT_ONE_MINUS_BETA2_POWER]], [[ONE_MINUS_BETA1_POWER]])
// CHECK: [[ALPHA:%.*]] = "tf.Mul"([[LR]], [[ALPHA_NO_LR]])
// CHECK: [[OLD_M:%.*]] = "tf.ReadVariableOp"([[M_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
// CHECK: [[OLD_M:%.*]] = "tf.ReadVariableOp"([[M_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
// CHECK: [[BETA1_OLD_M:%.*]] = "tf.Mul"([[BETA1]], [[OLD_M]])
// CHECK: [[ONE_MINUS_BETA1:%.*]] = "tf.Sub"([[ONE]], [[BETA1]])
// CHECK: [[ONE_MINUS_BETA1_GRAD:%.*]] = "tf.Mul"([[ONE_MINUS_BETA1]], [[GRAD]])
// CHECK: [[NEW_M:%.*]] = "tf.AddV2"([[BETA1_OLD_M]], [[ONE_MINUS_BETA1_GRAD]])
// CHECK: [[OLD_V:%.*]] = "tf.ReadVariableOp"([[V_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
// CHECK: [[OLD_V:%.*]] = "tf.ReadVariableOp"([[V_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
// CHECK: [[BETA2_OLD_V:%.*]] = "tf.Mul"([[BETA2]], [[OLD_V]])
// CHECK: [[ONE_MINUS_BETA2:%.*]] = "tf.Sub"([[ONE]], [[BETA2]])
// CHECK: [[GRAD_SQUARE:%.*]] = "tf.Square"([[GRAD]])
@ -289,17 +289,17 @@ func @decompose_resource_apply_adam_non_nesterov(%arg0: tensor<f32>, %arg1: tens
// CHECK: [[SQRT_NEW_V:%.*]] = "tf.Sqrt"([[NEW_V]])
// CHECK: [[SQRT_NEW_V_EPSILON:%.*]] = "tf.AddV2"([[SQRT_NEW_V]], [[EPSILON]])
// CHECK: [[VAR_DELTA:%.*]] = "tf.Div"([[ALPHA_NEW_M]], [[SQRT_NEW_V_EPSILON]])
// CHECK: [[OLD_VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
// CHECK: [[OLD_VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
// CHECK: [[NEW_VAR:%.*]] = "tf.Sub"([[OLD_VAR]], [[VAR_DELTA]])
// CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[NEW_VAR]])
// CHECK: "tf.AssignVariableOp"([[M_HANDLE]], [[NEW_M]])
// CHECK: "tf.AssignVariableOp"([[V_HANDLE]], [[NEW_V]])
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
%2 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
%2 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
"tf.ResourceApplyAdam"(%0, %1, %2, %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) {use_locking = false, use_nesterov = false} : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*x!tf.resource<tensor<*xf32>>>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
"tf.ResourceApplyAdam"(%0, %1, %2, %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) {use_locking = false, use_nesterov = false} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
return
}
@ -322,12 +322,12 @@ func @decompose_resource_apply_adam_nesterov(%arg0: tensor<f32>, %arg1: tensor<f
// CHECK: [[VAL_84:%.*]] = "tf.Sub"([[ONE]], [[BETA1_POWER]])
// CHECK: [[VAL_85:%.*]] = "tf.Div"([[VAL_83]], [[VAL_84]])
// CHECK: [[VAL_86:%.*]] = "tf.Mul"([[LR]], [[VAL_85]])
// CHECK: [[OLD_M:%.*]] = "tf.ReadVariableOp"([[M_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
// CHECK: [[OLD_M:%.*]] = "tf.ReadVariableOp"([[M_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
// CHECK: [[VAL_88:%.*]] = "tf.Mul"([[BETA1]], [[OLD_M]])
// CHECK: [[VAL_89:%.*]] = "tf.Sub"([[ONE]], [[BETA1]])
// CHECK: [[VAL_90:%.*]] = "tf.Mul"([[VAL_89]], [[GRAD]])
// CHECK: [[NEW_M:%.*]] = "tf.AddV2"([[VAL_88]], [[VAL_90]])
// CHECK: [[OLD_V:%.*]] = "tf.ReadVariableOp"([[V_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
// CHECK: [[OLD_V:%.*]] = "tf.ReadVariableOp"([[V_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
// CHECK: [[VAL_93:%.*]] = "tf.Mul"([[BETA2]], [[OLD_V]])
// CHECK: [[VAL_94:%.*]] = "tf.Sub"([[ONE]], [[BETA2]])
// CHECK: [[VAL_95:%.*]] = "tf.Square"([[GRAD]])
@ -341,17 +341,17 @@ func @decompose_resource_apply_adam_nesterov(%arg0: tensor<f32>, %arg1: tensor<f
// CHECK: [[VAL_103:%.*]] = "tf.Sqrt"([[NEW_V]])
// CHECK: [[VAL_104:%.*]] = "tf.AddV2"([[VAL_103]], [[EPSILON]])
// CHECK: [[VAL_105:%.*]] = "tf.Div"([[VAL_102]], [[VAL_104]])
// CHECK: [[OLD_VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
// CHECK: [[OLD_VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
// CHECK: [[NEW_VAR:%.*]] = "tf.Sub"([[OLD_VAR]], [[VAL_105]])
// CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[NEW_VAR]]) : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
// CHECK: "tf.AssignVariableOp"([[M_HANDLE]], [[NEW_M]]) : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
// CHECK: "tf.AssignVariableOp"([[V_HANDLE]], [[NEW_V]]) : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
// CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[NEW_VAR]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
// CHECK: "tf.AssignVariableOp"([[M_HANDLE]], [[NEW_M]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
// CHECK: "tf.AssignVariableOp"([[V_HANDLE]], [[NEW_V]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
%2 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
%2 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
"tf.ResourceApplyAdam"(%0, %1, %2, %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) {use_locking = false, use_nesterov = true} : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*x!tf.resource<tensor<*xf32>>>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
"tf.ResourceApplyAdam"(%0, %1, %2, %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) {use_locking = false, use_nesterov = true} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
return
}
@ -366,12 +366,12 @@ func @decompose_resource_gather_op(%indices : tensor<?xi32>) -> tensor<*xi32> {
// CHECK: [[ZERO:%.+]] = "tf.Const"() {value = dense<0> : tensor<i64>}
// CHECK: [[VAR:%.+]] = "tf.VarHandleOp"
%resource = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xi32>>>
%resource = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
// CHECK: [[READVAR:%.+]] = "tf.ReadVariableOp"([[VAR]])
// CHECK: [[GATHER:%.+]] = "tf.GatherV2"([[READVAR]], [[INDEX]], [[ZERO]]) {batch_dims = 0 : i64} : (tensor<*xi32>, tensor<?xi32>, tensor<i64>) -> tensor<*xi32>
// CHECK: return [[GATHER]]
%0 = "tf.ResourceGather"(%resource, %indices) : (tensor<*x!tf.resource<tensor<*xi32>>>, tensor<?xi32>) -> (tensor<*xi32>)
%0 = "tf.ResourceGather"(%resource, %indices) : (tensor<*x!tf.resource>, tensor<?xi32>) -> (tensor<*xi32>)
return %0: tensor<*xi32>
}
@ -403,10 +403,10 @@ func @decompose_resource_apply_centered_RMS_prop(%arg0: tensor<f32>, %arg1: tens
// CHECK: [[MG_HANDLE:%.*]] = "tf.VarHandleOp"
// CHECK: [[MS_HANDLE:%.*]] = "tf.VarHandleOp"
// CHECK: [[MOM_HANDLE:%.*]] = "tf.VarHandleOp"
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
%2 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
%3 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
%2 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
%3 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
// CHECK: [[GRADSQ:%.*]] = "tf.Mul"([[GRAD]], [[GRAD]])
// CHECK: [[SB:%.*]] = "tf.Sub"([[ONE]], [[RHO]])
@ -438,7 +438,7 @@ func @decompose_resource_apply_centered_RMS_prop(%arg0: tensor<f32>, %arg1: tens
// CHECK: [[VAR_NEW:%.*]] = "tf.Sub"([[VAR]], [[MOM_NEW]])
// CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[VAR_NEW]])
"tf.ResourceApplyCenteredRMSProp"(%0, %1, %2, %3, %arg4, %arg5, %arg6, %arg7, %arg8) {use_locking = false} : (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
"tf.ResourceApplyCenteredRMSProp"(%0, %1, %2, %3, %arg4, %arg5, %arg6, %arg7, %arg8) {use_locking = false} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
return
}
// -----
@ -477,12 +477,12 @@ func @decompose_resource_apply_RMS_prop(%arg0: tensor<*x!tf.resource>, %arg1: te
// CHECK-SAME: ([[INDEX:%.+]]: tensor<2x?xi32>, [[UPDATE:%.+]]: tensor<?x?x?xi32>)
func @decompose_resource_scatter_update_op(%indices : tensor<2x?xi32>, %updates: tensor<?x?x?xi32>) {
// CHECK: [[VAR:%.+]] = "tf.VarHandleOp"
%resource = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xi32>>>
%resource = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
// CHECK: [[READ:%.+]] = "tf.ReadVariableOp"([[VAR]])
// CHECK: [[TENSOR:%.+]] = "tf.TensorScatterUpdate"([[READ]], [[INDEX]], [[UPDATE]]) : (tensor<*xi32>, tensor<2x?xi32>, tensor<?x?x?xi32>) -> tensor<*xi32>
// CHECK: "tf.AssignVariableOp"([[VAR]], [[TENSOR]])
"tf.ResourceScatterUpdate"(%resource, %indices, %updates) : (tensor<*x!tf.resource<tensor<*xi32>>>, tensor<2x?xi32>, tensor<?x?x?xi32>) -> ()
"tf.ResourceScatterUpdate"(%resource, %indices, %updates) : (tensor<*x!tf.resource>, tensor<2x?xi32>, tensor<?x?x?xi32>) -> ()
return
}

@ -340,7 +340,7 @@ func @main(%arg0: tensor<!tf.resource>) {
// Tests main function with invalid VarHandleOp resource subtype.
func @main() {
// expected-error @+1 {{must have exactly one subtype in the result resource type}}
// expected-error@+1 {{expects resource type to have one subtype, got '!tf.resource'}}
%0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource>
return
}

@ -12,18 +12,18 @@ func @main() {
// -----
// CHECK-LABEL: func @no_args
// CHECK-SAME: (%arg0: tensor<!tf.resource<tensor<f32>>> {tf.resource_name = "x"})
// CHECK-SAME: (%arg0: tensor<!tf.resource> {tf.resource_name = "x"})
// CHECK-NOT: "tf.VarHandleOp"
func @no_args() {
%0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource<tensor<f32>>>
%0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource>
return
}
// CHECK-LABEL: func @some_args
// CHECK-SAME: (%arg0: tensor<i1>, %arg1: tensor<!tf.resource<tensor<f32>>> {tf.resource_name = "x"})
// CHECK-SAME: (%arg0: tensor<i1>, %arg1: tensor<!tf.resource> {tf.resource_name = "x"})
// CHECK-NOT: "tf.VarHandleOp"
func @some_args(%arg0: tensor<i1>) {
%0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource<tensor<f32>>>
%0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource>
return
}

@ -6,7 +6,7 @@
func @only_resource_load() -> tensor<*xi32> {
// CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xi32>>>
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]])
// CHECK: "tf_device.cluster"
@ -16,7 +16,7 @@ func @only_resource_load() -> tensor<*xi32> {
// CHECK-SAME: () -> tensor<*xi32>
%1 = "tf_device.cluster"() ( {
%2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>) -> tensor<*xi32>
%2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
%3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>)
tf_device.return %3 : tensor<*xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32>
@ -32,7 +32,7 @@ func @only_resource_load() -> tensor<*xi32> {
func @only_resource_store() -> tensor<*xi32> {
// CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xi32>>>
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
// CHECK: %[[CLUSTER_RES:[0-9]*]]:2 = "tf_device.cluster"
// CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"()
@ -43,7 +43,7 @@ func @only_resource_store() -> tensor<*xi32> {
%1 = "tf_device.cluster"() ( {
%2 = "tf.SomeComputation"() : () -> (tensor<*xi32>)
"tf.AssignVariableOp"(%0, %2) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>, tensor<*xi32>) -> ()
"tf.AssignVariableOp"(%0, %2) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
tf_device.return %2 : tensor<*xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32>
@ -59,7 +59,7 @@ func @only_resource_store() -> tensor<*xi32> {
func @same_resource_load_and_store() -> tensor<*xi32> {
// CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xi32>>>
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]])
// CHECK: %[[CLUSTER_RES:[0-9]*]]:2 = "tf_device.cluster"
@ -70,9 +70,9 @@ func @same_resource_load_and_store() -> tensor<*xi32> {
// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[CLUSTER_RES]]#1)
%1 = "tf_device.cluster"() ( {
%2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>) -> tensor<*xi32>
%2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
%3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>)
"tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>, tensor<*xi32>) -> ()
"tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
tf_device.return %3 : tensor<*xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32>
@ -89,7 +89,7 @@ func @same_resource_load_and_store() -> tensor<*xi32> {
func @same_resource_load_and_store_cast() -> tensor<1xi32> {
// CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xi32>>>
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]])
// CHECK: %[[CLUSTER_RES:[0-9]*]]:2 = "tf_device.cluster"
@ -101,10 +101,10 @@ func @same_resource_load_and_store_cast() -> tensor<1xi32> {
// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[CLUSTER_RES]]#1)
%1 = "tf_device.cluster"() ( {
%2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>) -> tensor<1xi32>
%2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<1xi32>
%3 = "tf.SomeComputation"(%2) : (tensor<1xi32>) -> (tensor<*xi32>)
"tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>, tensor<*xi32>) -> ()
%4 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>) -> tensor<1xi32>
"tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
%4 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<1xi32>
tf_device.return %4 : tensor<1xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<1xi32>
@ -123,16 +123,16 @@ func @internal_resource() -> tensor<*xi32> {
%0 = "tf_device.cluster"() ( {
// CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xi32>>>
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]])
%2 = "tf.ReadVariableOp"(%1) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>) -> tensor<*xi32>
%2 = "tf.ReadVariableOp"(%1) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
// CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]])
%3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>)
// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[COMPUTE_RES]])
"tf.AssignVariableOp"(%1, %3) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>, tensor<*xi32>) -> ()
"tf.AssignVariableOp"(%1, %3) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
// CHECK: tf_device.return %[[COMPUTE_RES]]
tf_device.return %3 : tensor<*xi32>
@ -1006,10 +1006,10 @@ func @test_unsupported_resource_op() -> tensor<*xi32> {
// CHECK: tf_device.return
// CHECK: {cluster_attr = "cluster_attr"}
// CHECK: return
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xi32>>>
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
%1 = "tf_device.cluster"() ( {
%2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>) -> tensor<*xi32>
"tf.SomeResourceOperation"(%0) : (tensor<*x!tf.resource<tensor<*xi32>>>) -> ()
%2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
"tf.SomeResourceOperation"(%0) : (tensor<*x!tf.resource>) -> ()
%3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>)
tf_device.return %3 : tensor<*xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32>
@ -1032,12 +1032,12 @@ func @test_unsupported_resource_op_in_if(%arg0: tensor<i1>) -> tensor<*xi32> {
// CHECK-SAME: else_branch = @else_fn, is_stateless = true, then_branch = @then_fn
// CHECK: tf_device.return
// CHECK: return
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xi32>>>
%1 = "tf.VarHandleOp"() {container = "d", shared_name = "w"} : () -> tensor<*x!tf.resource<tensor<*xi32>>>
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
%1 = "tf.VarHandleOp"() {container = "d", shared_name = "w"} : () -> tensor<*x!tf.resource>
%2 = "tf_device.cluster"() ( {
%3 = "tf.If"(%arg0, %0, %1)
{ else_branch = @else_fn, then_branch = @then_fn, is_stateless = true}
: (tensor<i1>, tensor<*x!tf.resource<tensor<*xi32>>>, tensor<*x!tf.resource<tensor<*xi32>>>) -> tensor<*xi32>
: (tensor<i1>, tensor<*x!tf.resource>, tensor<*x!tf.resource>) -> tensor<*xi32>
tf_device.return %3 : tensor<*xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32>
return %2 : tensor<*xi32>

@ -4135,27 +4135,3 @@ func @testInvalidTPUExecuteAndUpdateVariables(%arg0: tensor<!tf.resource<tensor<
"tf.TPUExecuteAndUpdateVariables"(%arg0, %arg1) {device_var_reads_indices = [0], device_var_updates_indices = [-2]} : (tensor<!tf.resource<tensor<i32>>>, tensor<3x!tf.string>) -> ()
return
}
// -----
// Valid VarHandleOp operation.
// CHECK-LABEL: func @testVarHandleOp
func @testVarHandleOp() -> tensor<!tf.resource<tensor<*xf32>>> {
%0 = "tf.VarHandleOp"() {
container = "",
shared_name = "cd2c89b7-88b7-44c8-ad83-06c2a9158347"
} : () -> tensor<!tf.resource<tensor<*xf32>>>
return %0 : tensor<!tf.resource<tensor<*xf32>>>
}
// -----
// VarHandleOp operation missing the required resource subtype.
func @testVarHandleOp() -> tensor<*x!tf.resource> {
// expected-error @+1 {{must have exactly one subtype in the result resource type}}
%0 = "tf.VarHandleOp"() {
container = "",
shared_name = "cd2c89b7-88b7-44c8-ad83-06c2a9158347"
} : () -> tensor<*x!tf.resource>
return %0 : tensor<*x!tf.resource>
}

@ -184,11 +184,11 @@ func @var_handle_on_tpu_iter_on_cpu() -> tensor<i32> {
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<2x!tf.string>
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
%var = "tf.VarHandleOp"() {container = "c", shared_name = "v", device = "/device:TPU:0"} : () -> tensor<!tf.resource<tensor<3x3x1x32xf32>>>
%var = "tf.VarHandleOp"() {container = "c", shared_name = "v", device = "/device:TPU:0"} : () -> tensor<*x!tf.resource>
// CHECK-NOT: "tf.TPUGetLayoutOp"
// CHECK-NOT: "tf.TPUCopyWithLayout"
%2:2 = "tf.IteratorGetNext"(%var) {device = "/device:CPU:0"}
: (tensor<!tf.resource<tensor<3x3x1x32xf32>>>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
: (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
"tf_device.launch"() ( {
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
tf_device.return

@ -103,9 +103,16 @@ llvm::SmallSet<llvm::StringRef, 1> GetCompositeResourceUserNames(
return composite_users;
}
// Checks that the only users of `tf.VarHandleOp` are
// `tf.ReadVariableOp` and `tf.AssignVariableOp`.
// Checks if `tf.VarHandleOp` has a valid resource subtype and its users are of
// `tf.ReadVariableOp` and `tf.AssignVariableOp` only.
mlir::LogicalResult ValidateVarHandle(TF::VarHandleOp var_handle_op) {
auto resource_type =
getElementTypeOrSelf(var_handle_op.getType()).cast<TF::ResourceType>();
if (resource_type.getSubtypes().size() != 1)
return var_handle_op.emitOpError()
<< "expects resource type to have one subtype, got "
<< resource_type;
auto composite_ops = GetCompositeResourceUserNames(var_handle_op);
if (!composite_ops.empty())
return var_handle_op.emitOpError()
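
A hypothetical input that would trip this check, mirroring the expected-error update earlier in this diff (attribute values are placeholders):

// expected-error@+1 {{expects resource type to have one subtype, got '!tf.resource'}}
%0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource>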

@ -45,7 +45,7 @@ Status CompositeOpExpansion::Run(EagerOperation* orig_op,
// isn't a composite op. The following ops are explicitly skipped here because
// their "no-op" expansion is known to cause problems in some cases.
static const char* kOpsToSkip[] = {"IdentityOp", "NoOp", "OptionalHasValue",
"OptionalGetValue"};
"OptionalGetValue", "VarHandleOp"};
for (const char* skip : kOpsToSkip) {
if (absl::StartsWith(orig_op->op_name(), skip)) return Status::OK();
}