Add support of multiple ReadVariable ops after casting in TF canonicalize pass
For example, %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource> %1 = "tf.ReadVariableOp"(%1) {device = ""} : (tensor<*x!tf.resource>) -> tensor<f32> %1 = "tf.ReadVariableOp"(%1) {device = ""} : (tensor<*x!tf.resource>) -> tensor<f32> PiperOrigin-RevId: 305146713 Change-Id: I2e5db7778a69da6a103e0ec35ebe9ba518ba424c
This commit is contained in:
parent
8424ef8160
commit
47816940ea
@ -199,7 +199,6 @@ func @testAddV2OfNegLeft(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> te
|
||||
%0 = "tf.Neg"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
%1 = "tf.AddV2"(%0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
return %1: tensor<8x16xf32>
|
||||
|
||||
// CHECK: %0 = "tf.Sub"(%arg1, %arg0) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
// CHECK: return %0
|
||||
}
|
||||
@ -419,8 +418,8 @@ func @ToBool_0DScalar(%arg0: tensor<i1>) -> tensor<i1> {
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testReadVariableOfOfCast
|
||||
func @testReadVariableOfOfCast(%arg0: tensor<!tf.resource<tensor<8x40xf32>>>) -> tensor<8x40xf32> {
|
||||
// CHECK-LABEL: testReadVariableOpOfCast
|
||||
func @testReadVariableOpOfCast(%arg0: tensor<!tf.resource<tensor<8x40xf32>>>) -> tensor<8x40xf32> {
|
||||
%0 = "tf.Cast"(%arg0) : (tensor<!tf.resource<tensor<8x40xf32>>>) -> tensor<*x!tf.resource>
|
||||
%1 = "tf.ReadVariableOp"(%0) : (tensor<*x!tf.resource>) -> tensor<8x40xf32>
|
||||
return %1: tensor<8x40xf32>
|
||||
@ -429,8 +428,8 @@ func @testReadVariableOfOfCast(%arg0: tensor<!tf.resource<tensor<8x40xf32>>>) ->
|
||||
// CHECK: return %0
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testReadVariableOfOfCastWithTruncate
|
||||
func @testReadVariableOfOfCastWithTruncate(%arg0: tensor<!tf.resource<tensor<8x40xf32>>>) -> tensor<8x40xf32> {
|
||||
// CHECK-LABEL: testReadVariableOpOfCastWithTruncate
|
||||
func @testReadVariableOpOfCastWithTruncate(%arg0: tensor<!tf.resource<tensor<8x40xf32>>>) -> tensor<8x40xf32> {
|
||||
%0 = "tf.Cast"(%arg0) {Truncate = true} : (tensor<!tf.resource<tensor<8x40xf32>>>) -> tensor<*x!tf.resource>
|
||||
%1 = "tf.ReadVariableOp"(%0) : (tensor<*x!tf.resource>) -> tensor<8x40xf32>
|
||||
return %1: tensor<8x40xf32>
|
||||
@ -439,8 +438,8 @@ func @testReadVariableOfOfCastWithTruncate(%arg0: tensor<!tf.resource<tensor<8x4
|
||||
// CHECK: return %0
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testReadVariableOfOfCastMultiUse
|
||||
func @testReadVariableOfOfCastMultiUse(%arg0: tensor<!tf.resource<tensor<f32>>>) -> tensor<f32> {
|
||||
// CHECK-LABEL: testReadVariableOpOfCastMultiUse
|
||||
func @testReadVariableOpOfCastMultiUse(%arg0: tensor<!tf.resource<tensor<f32>>>) -> tensor<f32> {
|
||||
%0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource>
|
||||
%1 = "tf.ReadVariableOp"(%0) : (tensor<*x!tf.resource>) -> tensor<f32>
|
||||
"tf.AssignVariableOp"(%0, %1) : (tensor<*x!tf.resource>, tensor<f32>) -> ()
|
||||
@ -452,3 +451,14 @@ func @testReadVariableOfOfCastMultiUse(%arg0: tensor<!tf.resource<tensor<f32>>>)
|
||||
// CHECK: return %1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testMultiReadVariableOpsOfCast
|
||||
func @testMultiReadVariableOpsOfCast(%arg0: tensor<!tf.resource<tensor<f32>>>) -> tensor<f32> {
|
||||
%0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource>
|
||||
%1 = "tf.ReadVariableOp"(%0) : (tensor<*x!tf.resource>) -> tensor<f32>
|
||||
%2 = "tf.ReadVariableOp"(%0) : (tensor<*x!tf.resource>) -> tensor<f32>
|
||||
return %2: tensor<f32>
|
||||
|
||||
// CHECK: %0 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||
// CHECK: %1 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||
// CHECK: return %1
|
||||
}
|
||||
|
@ -27,9 +27,10 @@ def SingleResultAndOperandHaveSameType : Constraint<
|
||||
|
||||
def IsRank2Tensor : Type<HasAnyRankOfPred<[2]>, "Rank 2 tensor">;
|
||||
|
||||
// Checks if the value has only one user.
|
||||
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
|
||||
|
||||
// Checks if all the users is ReadVariableOp.
|
||||
def HasOnlyReadVariableOpUsers : Constraint<
|
||||
CPred<"llvm::all_of($0.getUsers(), [](mlir::OpOperand op) { "
|
||||
"return llvm::isa<mlir::TF::ReadVariableOp>(op.getOwner()); })">>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Add op patterns.
|
||||
@ -185,5 +186,5 @@ def XdivyWithSqrtDivisor : Pat<(TF_XdivyOp $arg0, (TF_SqrtOp $arg1)),
|
||||
// Cast op followed by a ReadVariable op can be folded into the ReadVariable
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ReadVariableOfCast : Pat<(TF_ReadVariableOp (TF_CastOp:$output $x, BoolAttr:$Truncate)), (TF_ReadVariableOp $x), [(HasOneUse $output)]>;
|
||||
def ReadVariableOfCast : Pat<(TF_ReadVariableOp (TF_CastOp:$output $x, BoolAttr:$Truncate)), (TF_ReadVariableOp $x), [(HasOnlyReadVariableOpUsers $output)]>;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user