Merge pull request from WindQAQ:fix-region-control-flow-to-functional-incompatible-cast

PiperOrigin-RevId: 355878955
Change-Id: I60e43fd0204bc513df25836853ed313a56aecc1e
This commit is contained in:
TensorFlower Gardener 2021-02-05 10:35:17 -08:00
commit 0b2f0c4064
2 changed files with 85 additions and 2 deletions
tensorflow/compiler/mlir/tensorflow

View File

@ -168,6 +168,50 @@ func @testIf2Result(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
// -----
// Do not skip extern incompatible cast for trivial transform.
func private @testIf1Then(tensor<*xf32>) -> tensor<*xf32>
func private @testIf1Else(tensor<*xf32>) -> tensor<*xf32>
func @testIfExternIncompatibleCastTrivialTransform(%arg0: tensor<i1>, %arg1: tensor<2xi64>) -> tensor<2xf32> {
// CHECK: %[[CAST:.*]] = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2xi64>) -> tensor<*xf32>
// CHECK: "tf.If"(%arg0, %[[CAST]]) {else_branch = @testIf1Else, {{.+}} then_branch = @testIf1Then}
%1 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2xi64>) -> tensor<*xf32>
%0 = "tf.IfRegion"(%arg0) ( {
%2 = call @testIf1Then(%1) : (tensor<*xf32>) -> tensor<*xf32>
"tf.Yield"(%2) : (tensor<*xf32>) -> ()
}, {
%2 = call @testIf1Else(%1) : (tensor<*xf32>) -> tensor<*xf32>
"tf.Yield"(%2) : (tensor<*xf32>) -> ()
}) {is_stateless = false} : (tensor<i1>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// -----
// Do not skip incompatible cast for trivial transform.
// CHECK: func private @tf.IfRegion_else(%arg0: tensor<2xi64>) -> tensor<*xf32>
// CHECK-NEXT: "tf.Cast"
// CHECK: func private @tf.IfRegion_then(%arg0: tensor<2xi64>) -> tensor<*xf32>
// CHECK-NEXT: "tf.Cast"
func private @testIf1Then(tensor<*xf32>) -> tensor<*xf32>
func private @testIf1Else(tensor<*xf32>) -> tensor<*xf32>
func @testIfIncompatibleCastTrivialTransform(%arg0: tensor<i1>, %arg1: tensor<2xi64>) -> tensor<2xf32> {
// CHECK: "tf.If"(%arg0, %arg1) {else_branch = @tf.IfRegion_else{{.+}}then_branch = @tf.IfRegion_then}
%0 = "tf.IfRegion"(%arg0) ( {
%1 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2xi64>) -> tensor<*xf32>
%2 = call @testIf1Then(%1) : (tensor<*xf32>) -> tensor<*xf32>
"tf.Yield"(%2) : (tensor<*xf32>) -> ()
}, {
%1 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2xi64>) -> tensor<*xf32>
%2 = call @testIf1Else(%1) : (tensor<*xf32>) -> tensor<*xf32>
"tf.Yield"(%2) : (tensor<*xf32>) -> ()
}) {is_stateless = false} : (tensor<i1>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// -----
// No inputs, some outputs for IfRegion
// CHECK: func private @tf.IfRegion_else() -> tensor<2xf32>
// CHECK-NEXT: constant dense<1.000000e+00>
@ -558,6 +602,37 @@ func @testWhileRegionTrivialMultipleCasts(%arg0 : tensor<*xf32>, %arg1 : tensor<
// -----
// Almost trivially transformable with incompatible cast
// CHECK: func private @tf.WhileRegion_body
// CHECK-NEXT: "tf.Cast"
// CHECK: func private @tf.WhileRegion_cond
// CHECK-NEXT: "tf.Cast"
// CHECK-LABEL: testWhileRegionIncompatibleCast
func private @while_cond(%arg0 : tensor<4xf32>, %arg1 : tensor<i32>) -> tensor<i1>
func private @while_body(%arg0 : tensor<4xf32>, %arg1 : tensor<i32>) -> (tensor<4xi64>, tensor<i32>)
func @testWhileRegionIncompatibleCast(%arg0 : tensor<*xi64>, %arg1 : tensor<i32>) -> tensor<*xi64> {
// CHECK: [[Result:%.*]]:2 = "tf.While"(%arg0, %arg1) {body = @tf.WhileRegion_body, cond = @tf.WhileRegion_cond
%0:2 = "tf.WhileRegion"(%arg0, %arg1) (
{
^bb0(%carg0: tensor<*xi64>, %carg1: tensor<i32>):
%cond_cast = "tf.Cast"(%carg0) : (tensor<*xi64>) -> tensor<4xf32>
%cond = call @while_cond(%cond_cast, %carg1) : (tensor<4xf32>, tensor<i32>) -> tensor<i1>
"tf.Yield"(%cond) : (tensor<i1>) -> ()
},
{
// loop body
^bb0(%barg0: tensor<*xi64>, %barg1: tensor<i32>):
%bdy_cast = "tf.Cast"(%barg0) : (tensor<*xi64>) -> tensor<4xf32>
%bdy:2 = call @while_body(%bdy_cast, %barg1) : (tensor<4xf32>, tensor<i32>) -> (tensor<4xi64>, tensor<i32>)
"tf.Yield"(%bdy#0, %bdy#1) : (tensor<4xi64>, tensor<i32>) -> ()
}
) { is_stateless = false } : (tensor<*xi64>, tensor<i32>) -> (tensor<*xi64>, tensor<i32>)
// CHECK: return [[Result]]#0
return %0#0 : tensor<*xi64>
}
// -----
// Almost trivially transformable with extern values
// CHECK: func private @tf.WhileRegion_body
// CHECK: call @while_body

View File

@ -217,8 +217,16 @@ bool MatchCallArgs(CallOp first, CallOp second, ArgMatcherFn matcher) {
for (auto it : llvm::zip(first.getArgOperands(), second.getArgOperands())) {
// Get the defining Op, skipping over casts.
auto get_defining_op = [](Value value) {
while (llvm::isa_and_nonnull<CastOp>(value.getDefiningOp()))
value = cast<CastOp>(value.getDefiningOp()).getOperand();
while (auto cast_op =
llvm::dyn_cast_or_null<CastOp>(value.getDefiningOp())) {
// Consider cast compatibility in case
// %cast = "tf.Cast"(%0) : (tensor<2xi64>) -> tensor<2xf32>
// is skipped.
if (cast_op.SrcT() != cast_op.DstT()) {
break;
}
value = cast_op.getOperand();
}
return value;
};
Value first_arg = get_defining_op(std::get<0>(it));