diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.cc b/tensorflow/compiler/xla/service/conditional_simplifier.cc index 199bc787b83..bff68574e1a 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier.cc @@ -144,6 +144,15 @@ StatusOr TryRemoveConditional(HloInstruction* conditional) { << conditional->ToShortString(); return false; } + + bool branch_empty = + ComputationIsEmptyWithArrayRoot(conditional->branch_computation(0)) || + ComputationIsEmptyWithArrayRoot(conditional->branch_computation(1)); + // Empty branch is faster to execute than select. + if (branch_empty) { + return false; + } + HloInstruction* true_call_op = create_call(0); HloInstruction* false_call_op = create_call(1); auto condition_broadcast = [&](const Shape& shape) { @@ -160,13 +169,6 @@ StatusOr TryRemoveConditional(HloInstruction* conditional) { hlo->shape().tuple_shapes(i), hlo, i)); }; - bool branch_empty = - ComputationIsEmptyWithArrayRoot(conditional->branch_computation(0)) || - ComputationIsEmptyWithArrayRoot(conditional->branch_computation(1)); - // Empty branch is faster to execute than select. - if (branch_empty) { - return false; - } std::function select = [&](HloInstruction* t, HloInstruction* f) { if (f->shape().IsToken()) {