Fix crash when trying to look through blockarg
PiperOrigin-RevId: 303178864 Change-Id: I1eb1560c3e236e5b566c5b69be98ebf74582c549
This commit is contained in:
parent
3dc8f81602
commit
6e920629c5
@ -38,6 +38,24 @@ func @test_single_branch_direct_t() -> tensor<i32> {
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: test_single_branch_direct_arg_f
|
||||
// CHECK: Switch
|
||||
// CHECK: tf.AddV2
|
||||
func @test_single_branch_direct_arg_f(%pred : tensor<i1>) -> tensor<i32> {
|
||||
%cst_0 = constant dense<10> : tensor<i32>
|
||||
%cst_1 = constant dense<1> : tensor<i32>
|
||||
%0 = tf_executor.graph {
|
||||
%7:3 = tf_executor.Switch %cst_0, %pred : tensor<i32>
|
||||
%8:2 = tf_executor.island {
|
||||
%12 = "tf.AddV2"(%7#1, %cst_1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
tf_executor.yield %12 : tensor<i32>
|
||||
}
|
||||
%11:3 = tf_executor.Merge %7#0, %8#0 : tensor<i32> {N = 2 : i64}
|
||||
tf_executor.fetch %11#0 : tensor<i32>
|
||||
}
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
|
||||
// pred ? x + 1 : x - 1
|
||||
// CHECK-LABEL: ControlFlowTest.testCond_1f
|
||||
// CHECK-NOT: Switch
|
||||
@ -330,4 +348,4 @@ func @switch_with_send_recv() {
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -67,7 +67,7 @@ class SwitchFoldPass : public mlir::FunctionPass<SwitchFoldPass> {
|
||||
// Returns the defining op for a value looking through islands.
|
||||
static Operation* GetDefiningOp(Value val) {
|
||||
Operation* op = val.getDefiningOp();
|
||||
auto island_op = dyn_cast<tf_executor::IslandOp>(op);
|
||||
auto island_op = dyn_cast_or_null<tf_executor::IslandOp>(op);
|
||||
if (!island_op) return op;
|
||||
auto yield_op = island_op.GetYield();
|
||||
auto index = val.cast<mlir::OpResult>().getResultNumber();
|
||||
@ -84,7 +84,8 @@ static Operation* GetDefiningOp(Value val) {
|
||||
static Value LookThroughIdentityOp(Value pred_val) {
|
||||
if (!pred_val) return pred_val;
|
||||
auto op = GetDefiningOp(pred_val);
|
||||
if (auto id_op = dyn_cast<TF::IdentityOp>(op)) pred_val = id_op.input();
|
||||
if (auto id_op = dyn_cast_or_null<TF::IdentityOp>(op))
|
||||
pred_val = id_op.input();
|
||||
return pred_val;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user