diff --git a/tensorflow/compiler/mlir/tensorflow/tests/fold-switch.mlir b/tensorflow/compiler/mlir/tensorflow/tests/fold-switch.mlir index 73ae30c7831..0b9e995b386 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/fold-switch.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/fold-switch.mlir @@ -38,6 +38,24 @@ func @test_single_branch_direct_t() -> tensor { return %0 : tensor } +// CHECK-LABEL: test_single_branch_direct_arg_f +// CHECK: Switch +// CHECK: tf.AddV2 +func @test_single_branch_direct_arg_f(%pred : tensor) -> tensor { + %cst_0 = constant dense<10> : tensor + %cst_1 = constant dense<1> : tensor + %0 = tf_executor.graph { + %7:3 = tf_executor.Switch %cst_0, %pred : tensor + %8:2 = tf_executor.island { + %12 = "tf.AddV2"(%7#1, %cst_1) : (tensor, tensor) -> tensor + tf_executor.yield %12 : tensor + } + %11:3 = tf_executor.Merge %7#0, %8#0 : tensor {N = 2 : i64} + tf_executor.fetch %11#0 : tensor + } + return %0 : tensor +} + // pred ? x + 1 : x - 1 // CHECK-LABEL: ControlFlowTest.testCond_1f // CHECK-NOT: Switch @@ -330,4 +348,4 @@ func @switch_with_send_recv() { tf_executor.fetch } return -} \ No newline at end of file +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc index 7d0e7e20e5d..30444b88677 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc @@ -67,7 +67,7 @@ class SwitchFoldPass : public mlir::FunctionPass { // Returns the defining op for a value looking through islands. static Operation* GetDefiningOp(Value val) { Operation* op = val.getDefiningOp(); - auto island_op = dyn_cast(op); + auto island_op = dyn_cast_or_null(op); if (!island_op) return op; auto yield_op = island_op.GetYield(); auto index = val.cast().getResultNumber(); @@ -84,7 +84,8 @@ static Operation* GetDefiningOp(Value val) { static Value LookThroughIdentityOp(Value pred_val) { if (!pred_val) return pred_val; auto op = GetDefiningOp(pred_val); - if (auto id_op = dyn_cast(op)) pred_val = id_op.input(); + if (auto id_op = dyn_cast_or_null(op)) + pred_val = id_op.input(); return pred_val; }