diff --git a/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir b/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir index c6c15de3ee0..6a5d702accd 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir @@ -168,6 +168,50 @@ func @testIf2Result(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { // ----- +// Do not skip extern incompatible cast for trivial transform. + +func private @testIf1Then(tensor<*xf32>) -> tensor<*xf32> +func private @testIf1Else(tensor<*xf32>) -> tensor<*xf32> +func @testIfExternIncompatibleCastTrivialTransform(%arg0: tensor, %arg1: tensor<2xi64>) -> tensor<2xf32> { + // CHECK: %[[CAST:.*]] = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2xi64>) -> tensor<*xf32> + // CHECK: "tf.If"(%arg0, %[[CAST]]) {else_branch = @testIf1Else, {{.+}} then_branch = @testIf1Then} + %1 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2xi64>) -> tensor<*xf32> + %0 = "tf.IfRegion"(%arg0) ( { + %2 = call @testIf1Then(%1) : (tensor<*xf32>) -> tensor<*xf32> + "tf.Yield"(%2) : (tensor<*xf32>) -> () + }, { + %2 = call @testIf1Else(%1) : (tensor<*xf32>) -> tensor<*xf32> + "tf.Yield"(%2) : (tensor<*xf32>) -> () + }) {is_stateless = false} : (tensor) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// ----- + +// Do not skip incompatible cast for trivial transform. + +// CHECK: func private @tf.IfRegion_else(%arg0: tensor<2xi64>) -> tensor<*xf32> +// CHECK-NEXT: "tf.Cast" +// CHECK: func private @tf.IfRegion_then(%arg0: tensor<2xi64>) -> tensor<*xf32> +// CHECK-NEXT: "tf.Cast" +func private @testIf1Then(tensor<*xf32>) -> tensor<*xf32> +func private @testIf1Else(tensor<*xf32>) -> tensor<*xf32> +func @testIfIncompatibleCastTrivialTransform(%arg0: tensor, %arg1: tensor<2xi64>) -> tensor<2xf32> { + // CHECK: "tf.If"(%arg0, %arg1) {else_branch = @tf.IfRegion_else{{.+}}then_branch = @tf.IfRegion_then} + %0 = "tf.IfRegion"(%arg0) ( { + %1 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2xi64>) -> tensor<*xf32> + %2 = call @testIf1Then(%1) : (tensor<*xf32>) -> tensor<*xf32> + "tf.Yield"(%2) : (tensor<*xf32>) -> () + }, { + %1 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2xi64>) -> tensor<*xf32> + %2 = call @testIf1Else(%1) : (tensor<*xf32>) -> tensor<*xf32> + "tf.Yield"(%2) : (tensor<*xf32>) -> () + }) {is_stateless = false} : (tensor) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// ----- + // No inputs, some outputs for IfRegion // CHECK: func private @tf.IfRegion_else() -> tensor<2xf32> // CHECK-NEXT: constant dense<1.000000e+00> @@ -558,6 +602,37 @@ func @testWhileRegionTrivialMultipleCasts(%arg0 : tensor<*xf32>, %arg1 : tensor< // ----- +// Almost trivially transformable with incompatible cast +// CHECK: func private @tf.WhileRegion_body +// CHECK-NEXT: "tf.Cast" +// CHECK: func private @tf.WhileRegion_cond +// CHECK-NEXT: "tf.Cast" +// CHECK-LABEL: testWhileRegionIncompatibleCast +func private @while_cond(%arg0 : tensor<4xf32>, %arg1 : tensor) -> tensor +func private @while_body(%arg0 : tensor<4xf32>, %arg1 : tensor) -> (tensor<4xi64>, tensor) +func @testWhileRegionIncompatibleCast(%arg0 : tensor<*xi64>, %arg1 : tensor) -> tensor<*xi64> { + // CHECK: [[Result:%.*]]:2 = "tf.While"(%arg0, %arg1) {body = @tf.WhileRegion_body, cond = @tf.WhileRegion_cond + %0:2 = "tf.WhileRegion"(%arg0, %arg1) ( + { + ^bb0(%carg0: tensor<*xi64>, %carg1: tensor): + %cond_cast = "tf.Cast"(%carg0) : (tensor<*xi64>) -> tensor<4xf32> + %cond = call @while_cond(%cond_cast, %carg1) : (tensor<4xf32>, tensor) -> tensor + "tf.Yield"(%cond) : (tensor) -> () + }, + { + // loop body + ^bb0(%barg0: tensor<*xi64>, %barg1: tensor): + %bdy_cast = "tf.Cast"(%barg0) : (tensor<*xi64>) -> tensor<4xf32> + %bdy:2 = call @while_body(%bdy_cast, %barg1) : (tensor<4xf32>, tensor) -> (tensor<4xi64>, tensor) + "tf.Yield"(%bdy#0, %bdy#1) : (tensor<4xi64>, tensor) -> () + } + ) { is_stateless = false } : (tensor<*xi64>, tensor) -> (tensor<*xi64>, tensor) + // CHECK: return [[Result]]#0 + return %0#0 : tensor<*xi64> +} + +// ----- + // Almost trivially transformable with extern values // CHECK: func private @tf.WhileRegion_body // CHECK: call @while_body diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc b/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc index 31839ac6386..90ba1e48d40 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc @@ -217,8 +217,16 @@ bool MatchCallArgs(CallOp first, CallOp second, ArgMatcherFn matcher) { for (auto it : llvm::zip(first.getArgOperands(), second.getArgOperands())) { // Get the defining Op, skipping over casts. auto get_defining_op = [](Value value) { - while (llvm::isa_and_nonnull(value.getDefiningOp())) - value = cast(value.getDefiningOp()).getOperand(); + while (auto cast_op = + llvm::dyn_cast_or_null(value.getDefiningOp())) { + // Consider cast compatibility in case + // %cast = "tf.Cast"(%0) : (tensor<2xi64>) -> tensor<2xf32> + // is skipped. + if (cast_op.SrcT() != cast_op.DstT()) { + break; + } + value = cast_op.getOperand(); + } return value; }; Value first_arg = get_defining_op(std::get<0>(it));