Run canonicalization again after resource op lifting to remove dead local variables.
PiperOrigin-RevId: 335725813 Change-Id: Ic1d60a695e560d47f3ddfaddd47190133ef37d4a
This commit is contained in:
parent
9fd6313f37
commit
11b687edfe
@ -285,6 +285,26 @@ func @main(%arg0: tensor<!tf.resource<tensor<f32>>>, %arg1: tensor<!tf.resource<
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Tests if local variables that are dead after resource op lifting are removed.
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @main
|
||||||
|
func @main(%arg0: tensor<i32>) -> tensor<2xf32> {
|
||||||
|
// CHECK-NOT: tf.MlirLocalVarOp
|
||||||
|
// CHECK-NOT: tf.AssignVariableOp
|
||||||
|
%0 = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<2xf32>>>
|
||||||
|
%1 = "tf._SomeOp"() : () -> tensor<2xf32>
|
||||||
|
"tf.AssignVariableOp"(%0, %1) : (tensor<!tf.resource<tensor<2xf32>>>, tensor<2xf32>) -> ()
|
||||||
|
%2 = "tf.PartitionedCall"(%0) {config = "", config_proto = "", executor_type = "", f = @callee} : (tensor<!tf.resource<tensor<2xf32>>>) -> tensor<2xf32>
|
||||||
|
return %2 : tensor<2xf32>
|
||||||
|
}
|
||||||
|
func @callee(%arg0: tensor<!tf.resource<tensor<2xf32>>>) -> tensor<2xf32> attributes {sym_visibility = "private"} {
|
||||||
|
%0 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<2xf32>>>) -> tensor<2xf32>
|
||||||
|
return %0 : tensor<2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// Tests main function with multiple blocks.
|
// Tests main function with multiple blocks.
|
||||||
|
@ -1340,9 +1340,15 @@ LogicalResult ResourceLiftingForFunctionalControlFlow(FuncOp function) {
|
|||||||
|
|
||||||
llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>
|
llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>
|
||||||
lifted_partitioned_call_callees;
|
lifted_partitioned_call_callees;
|
||||||
return HoistForControlFlow(&function.front(),
|
if (failed(HoistForControlFlow(&function.front(),
|
||||||
cast<ModuleOp>(function.getParentOp()),
|
cast<ModuleOp>(function.getParentOp()),
|
||||||
&lifted_partitioned_call_callees);
|
&lifted_partitioned_call_callees)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Clean up and canonicalize to remove dead local variables as some local
|
||||||
|
// variables might be dead after hoisting resource loads/stores from control
|
||||||
|
// flow ops.
|
||||||
|
return TF::CleanupAndCanonicalizeForResourceOpLifting(function);
|
||||||
}
|
}
|
||||||
} // namespace TF
|
} // namespace TF
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user