diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir index 43cf8486b60..213ca402f56 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir @@ -112,26 +112,6 @@ func @internal_resource() -> tensor<*xi32> { // ----- -// Tests that pass fails when there are remaining resource operationss that can -// not be lifted. - -func @lifting_failure() -> tensor<*xi32> { - - %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> - - // expected-error @+1 {{has remaining resource inputs that can not be lifted}} - %1 = "tf_device.cluster"() ( { - %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32> - %3 = "tf.SomeResourceOp"(%0, %2) : (tensor<*x!tf.resource>, tensor<*xi32>) -> tensor<*xi32> - "tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> () - tf_device.return %3 : tensor<*xi32> - }) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32> - - return %1 : tensor<*xi32> -} - -// ----- - // Tests that pass lifts resource reads/writes from a loop, and removed unused // resources. @@ -347,30 +327,6 @@ func @while_cond(%arg0: tensor<*x!tf.resource>>) -> tensor { // ----- -// Tests that pass reports error on unsupported ops in loop body. - -func @cluster_with_loop() -> () { - %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> - "tf_device.cluster"() ( { - %1 = "tf.While"(%0) { - body = @while_body, cond = @while_cond, device = "", is_stateless = false} - : (tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) - tf_device.return - }) {cluster_attr = "cluster_attr"} : () -> () - return -} -func @while_body(%arg0: tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) { - // expected-error @+1 {{found unsupported operations on resource.}} - "tf._UnknownOp"(%arg0) : (tensor<*x!tf.resource>>) -> () - return %arg0 : tensor<*x!tf.resource>> -} -func @while_cond(%arg0: tensor<*x!tf.resource>>) -> tensor { - %read = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>>) -> tensor - return %read : tensor -} - -// ----- - // Tests that pass reports error on unsupported ops in loop cond. func @cluster_with_loop() -> () { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc index 702455d156d..b5d4d94b7dc 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc @@ -330,15 +330,6 @@ LogicalResult HoistResourceOpsFromCluster(tf_device::ClusterOp cluster, getUsedValuesDefinedAbove(new_cluster.body(), new_cluster.body(), captured_values); - for (Value v : captured_values) { - auto tensor_type = v.getType().dyn_cast(); - if (!tensor_type) continue; - if (!tensor_type.getElementType().isa()) continue; - - return new_cluster.emitOpError() - << "has remaining resource inputs that can not be lifted"; - } - return success(); } @@ -361,29 +352,23 @@ LogicalResult FindResourceArgUseInfo( ResourceArgUseInfo info; info.used = false; info.updated = false; - bool do_not_touch = false; + bool read_or_assigned = false; for (auto user : arg.getUsers()) { if (user == return_op) continue; + info.used = true; if (auto read = llvm::dyn_cast(user)) { - info.used = true; + read_or_assigned = true; info.data_type = read.getType(); continue; } if (auto assign = llvm::dyn_cast(user)) { - info.used = true; + read_or_assigned = true; info.updated = true; info.data_type = assign.value().getType(); continue; } - if (isa(user)) { - // Stacks will be handled by a separate pass. - do_not_touch = true; - break; - } - user->emitOpError("found unsupported operations on resource."); - return failure(); } - if (!do_not_touch) (*result)[arg.getArgNumber()] = info; + if (!info.used || read_or_assigned) (*result)[arg.getArgNumber()] = info; } return success(); }