Merge pull request #46798 from WindQAQ:fix-region-control-flow-to-functional-incompatible-cast
PiperOrigin-RevId: 355878955 Change-Id: I60e43fd0204bc513df25836853ed313a56aecc1e
This commit is contained in:
commit
0b2f0c4064
tensorflow/compiler/mlir/tensorflow
@ -168,6 +168,50 @@ func @testIf2Result(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
|
||||
// -----
|
||||
|
||||
// A tf.Cast that changes the element type and is defined outside the region (extern) must not be skipped when matching for the trivial transform.
|
||||
|
||||
// Bodyless declarations of the functions called from the two branches below.
func private @testIf1Then(tensor<*xf32>) -> tensor<*xf32>
|
||||
func private @testIf1Else(tensor<*xf32>) -> tensor<*xf32>
|
||||
// Test case: the tf.Cast (tensor<2xi64> -> tensor<*xf32>) lives outside the
// tf.IfRegion and its result is the argument of the call in both branches.
// The expectations inside the body require that the cast is kept, that its
// result is passed as the tf.If operand, and that @testIf1Then / @testIf1Else
// are referenced directly as the then/else branches.
func @testIfExternIncompatibleCastTrivialTransform(%arg0: tensor<i1>, %arg1: tensor<2xi64>) -> tensor<2xf32> {
|
||||
// CHECK: %[[CAST:.*]] = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2xi64>) -> tensor<*xf32>
|
||||
// CHECK: "tf.If"(%arg0, %[[CAST]]) {else_branch = @testIf1Else, {{.+}} then_branch = @testIf1Then}
|
||||
// Cast defined outside the region; %1 is consumed by both branches.
%1 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2xi64>) -> tensor<*xf32>
|
||||
%0 = "tf.IfRegion"(%arg0) ( {
|
||||
// Then branch: a single call whose result is yielded.
%2 = call @testIf1Then(%1) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
"tf.Yield"(%2) : (tensor<*xf32>) -> ()
|
||||
}, {
|
||||
// Else branch: same shape as the then branch, different callee.
%2 = call @testIf1Else(%1) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
"tf.Yield"(%2) : (tensor<*xf32>) -> ()
|
||||
}) {is_stateless = false} : (tensor<i1>) -> tensor<2xf32>
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
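// A tf.Cast inside the region that changes the element type must not be skipped when matching for the trivial transform; the branches are outlined instead.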
|
||||
|
||||
// CHECK: func private @tf.IfRegion_else(%arg0: tensor<2xi64>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: "tf.Cast"
|
||||
// CHECK: func private @tf.IfRegion_then(%arg0: tensor<2xi64>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: "tf.Cast"
|
||||
// Bodyless declarations of the functions called from the two branches below.
func private @testIf1Then(tensor<*xf32>) -> tensor<*xf32>
|
||||
func private @testIf1Else(tensor<*xf32>) -> tensor<*xf32>
|
||||
// Test case: each branch of the tf.IfRegion performs its own tf.Cast
// (tensor<2xi64> -> tensor<*xf32>) on %arg1 before calling the external
// function. The expectation inside the body requires tf.If to take %arg1
// directly and to reference @tf.IfRegion_then / @tf.IfRegion_else as its
// branches rather than the called functions.
func @testIfIncompatibleCastTrivialTransform(%arg0: tensor<i1>, %arg1: tensor<2xi64>) -> tensor<2xf32> {
|
||||
// CHECK: "tf.If"(%arg0, %arg1) {else_branch = @tf.IfRegion_else{{.+}}then_branch = @tf.IfRegion_then}
|
||||
%0 = "tf.IfRegion"(%arg0) ( {
|
||||
// Then branch: cast inside the region, then call and yield.
%1 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2xi64>) -> tensor<*xf32>
|
||||
%2 = call @testIf1Then(%1) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
"tf.Yield"(%2) : (tensor<*xf32>) -> ()
|
||||
}, {
|
||||
// Else branch: same structure as the then branch, different callee.
%1 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2xi64>) -> tensor<*xf32>
|
||||
%2 = call @testIf1Else(%1) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
"tf.Yield"(%2) : (tensor<*xf32>) -> ()
|
||||
}) {is_stateless = false} : (tensor<i1>) -> tensor<2xf32>
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// No inputs, some outputs for IfRegion
|
||||
// CHECK: func private @tf.IfRegion_else() -> tensor<2xf32>
|
||||
// CHECK-NEXT: constant dense<1.000000e+00>
|
||||
@ -558,6 +602,37 @@ func @testWhileRegionTrivialMultipleCasts(%arg0 : tensor<*xf32>, %arg1 : tensor<
|
||||
|
||||
// -----
|
||||
|
||||
// Almost trivially transformable, except that the casts in the regions change the element type (incompatible cast), so the cond and body are outlined.
|
||||
// CHECK: func private @tf.WhileRegion_body
|
||||
// CHECK-NEXT: "tf.Cast"
|
||||
// CHECK: func private @tf.WhileRegion_cond
|
||||
// CHECK-NEXT: "tf.Cast"
|
||||
// CHECK-LABEL: testWhileRegionIncompatibleCast
|
||||
// Bodyless declarations of the functions called from the cond and body regions
// below; both take a tensor<4xf32> as their first argument.
func private @while_cond(%arg0 : tensor<4xf32>, %arg1 : tensor<i32>) -> tensor<i1>
|
||||
func private @while_body(%arg0 : tensor<4xf32>, %arg1 : tensor<i32>) -> (tensor<4xi64>, tensor<i32>)
|
||||
// Test case: both regions of the tf.WhileRegion cast their tensor<*xi64> block
// argument to tensor<4xf32> (i64 -> f32 element type) before calling the
// external function. The expectation inside the body requires tf.While to
// reference @tf.WhileRegion_body / @tf.WhileRegion_cond rather than
// @while_body / @while_cond.
func @testWhileRegionIncompatibleCast(%arg0 : tensor<*xi64>, %arg1 : tensor<i32>) -> tensor<*xi64> {
|
||||
// CHECK: [[Result:%.*]]:2 = "tf.While"(%arg0, %arg1) {body = @tf.WhileRegion_body, cond = @tf.WhileRegion_cond
|
||||
%0:2 = "tf.WhileRegion"(%arg0, %arg1) (
|
||||
{
|
||||
// Condition region: cast the first block argument, call, yield the i1 result.
^bb0(%carg0: tensor<*xi64>, %carg1: tensor<i32>):
|
||||
%cond_cast = "tf.Cast"(%carg0) : (tensor<*xi64>) -> tensor<4xf32>
|
||||
%cond = call @while_cond(%cond_cast, %carg1) : (tensor<4xf32>, tensor<i32>) -> tensor<i1>
|
||||
"tf.Yield"(%cond) : (tensor<i1>) -> ()
|
||||
},
|
||||
{
|
||||
// loop body
|
||||
^bb0(%barg0: tensor<*xi64>, %barg1: tensor<i32>):
|
||||
%bdy_cast = "tf.Cast"(%barg0) : (tensor<*xi64>) -> tensor<4xf32>
|
||||
%bdy:2 = call @while_body(%bdy_cast, %barg1) : (tensor<4xf32>, tensor<i32>) -> (tensor<4xi64>, tensor<i32>)
|
||||
"tf.Yield"(%bdy#0, %bdy#1) : (tensor<4xi64>, tensor<i32>) -> ()
|
||||
}
|
||||
) { is_stateless = false } : (tensor<*xi64>, tensor<i32>) -> (tensor<*xi64>, tensor<i32>)
|
||||
// CHECK: return [[Result]]#0
|
||||
// Only the first loop result is returned.
return %0#0 : tensor<*xi64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Almost trivially transformable with extern values
|
||||
// CHECK: func private @tf.WhileRegion_body
|
||||
// CHECK: call @while_body
|
||||
|
@ -217,8 +217,16 @@ bool MatchCallArgs(CallOp first, CallOp second, ArgMatcherFn matcher) {
|
||||
for (auto it : llvm::zip(first.getArgOperands(), second.getArgOperands())) {
|
||||
// Get the defining Op, skipping over casts.
|
||||
auto get_defining_op = [](Value value) {
|
||||
while (llvm::isa_and_nonnull<CastOp>(value.getDefiningOp()))
|
||||
value = cast<CastOp>(value.getDefiningOp()).getOperand();
|
||||
while (auto cast_op =
|
||||
llvm::dyn_cast_or_null<CastOp>(value.getDefiningOp())) {
|
||||
// Consider cast compatibility in case
|
||||
// %cast = "tf.Cast"(%0) : (tensor<2xi64>) -> tensor<2xf32>
|
||||
// is skipped.
|
||||
if (cast_op.SrcT() != cast_op.DstT()) {
|
||||
break;
|
||||
}
|
||||
value = cast_op.getOperand();
|
||||
}
|
||||
return value;
|
||||
};
|
||||
Value first_arg = get_defining_op(std::get<0>(it));
|
||||
|
Loading…
Reference in New Issue
Block a user