diff --git a/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir b/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir index 3302ec560ed..0813ee8db90 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir @@ -285,6 +285,26 @@ func @main(%arg0: tensor>>, %arg1: tensor) -> tensor<2xf32> { + // CHECK-NOT: tf.MlirLocalVarOp + // CHECK-NOT: tf.AssignVariableOp + %0 = "tf.MlirLocalVarOp"() : () -> tensor>> + %1 = "tf._SomeOp"() : () -> tensor<2xf32> + "tf.AssignVariableOp"(%0, %1) : (tensor>>, tensor<2xf32>) -> () + %2 = "tf.PartitionedCall"(%0) {config = "", config_proto = "", executor_type = "", f = @callee} : (tensor>>) -> tensor<2xf32> + return %2 : tensor<2xf32> +} +func @callee(%arg0: tensor>>) -> tensor<2xf32> attributes {sym_visibility = "private"} { + %0 = "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + + // ----- // Tests main function with multiple blocks. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc index 1562e9bb0a5..3cd316cf92d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc @@ -1340,9 +1340,15 @@ LogicalResult ResourceLiftingForFunctionalControlFlow(FuncOp function) { llvm::SmallDenseMap lifted_partitioned_call_callees; - return HoistForControlFlow(&function.front(), - cast(function.getParentOp()), - &lifted_partitioned_call_callees); + if (failed(HoistForControlFlow(&function.front(), + cast(function.getParentOp()), + &lifted_partitioned_call_callees))) + return failure(); + + // Clean up and canonicalize to remove dead local variables as some local + // variables might be dead after hoisting resource loads/stores from control + // flow ops. + return TF::CleanupAndCanonicalizeForResourceOpLifting(function); } } // namespace TF