Fix crash when trying to look through blockarg

PiperOrigin-RevId: 303178864
Change-Id: I1eb1560c3e236e5b566c5b69be98ebf74582c549
This commit is contained in:
Jacques Pienaar 2020-03-26 13:09:07 -07:00 committed by TensorFlower Gardener
parent 3dc8f81602
commit 6e920629c5
2 changed files with 22 additions and 3 deletions

View File

@ -38,6 +38,24 @@ func @test_single_branch_direct_t() -> tensor<i32> {
return %0 : tensor<i32>
}
// CHECK-LABEL: test_single_branch_direct_arg_f
// CHECK: Switch
// CHECK: tf.AddV2
func @test_single_branch_direct_arg_f(%pred : tensor<i1>) -> tensor<i32> {
%cst_0 = constant dense<10> : tensor<i32>
%cst_1 = constant dense<1> : tensor<i32>
%0 = tf_executor.graph {
%7:3 = tf_executor.Switch %cst_0, %pred : tensor<i32>
%8:2 = tf_executor.island {
%12 = "tf.AddV2"(%7#1, %cst_1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
tf_executor.yield %12 : tensor<i32>
}
%11:3 = tf_executor.Merge %7#0, %8#0 : tensor<i32> {N = 2 : i64}
tf_executor.fetch %11#0 : tensor<i32>
}
return %0 : tensor<i32>
}
// pred ? x + 1 : x - 1
// CHECK-LABEL: ControlFlowTest.testCond_1f
// CHECK-NOT: Switch
@ -330,4 +348,4 @@ func @switch_with_send_recv() {
tf_executor.fetch
}
return
}
}

View File

@ -67,7 +67,7 @@ class SwitchFoldPass : public mlir::FunctionPass<SwitchFoldPass> {
// Returns the defining op for a value looking through islands.
static Operation* GetDefiningOp(Value val) {
Operation* op = val.getDefiningOp();
auto island_op = dyn_cast<tf_executor::IslandOp>(op);
auto island_op = dyn_cast_or_null<tf_executor::IslandOp>(op);
if (!island_op) return op;
auto yield_op = island_op.GetYield();
auto index = val.cast<mlir::OpResult>().getResultNumber();
@ -84,7 +84,8 @@ static Operation* GetDefiningOp(Value val) {
static Value LookThroughIdentityOp(Value pred_val) {
if (!pred_val) return pred_val;
auto op = GetDefiningOp(pred_val);
if (auto id_op = dyn_cast<TF::IdentityOp>(op)) pred_val = id_op.input();
if (auto id_op = dyn_cast_or_null<TF::IdentityOp>(op))
pred_val = id_op.input();
return pred_val;
}