diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index f1c4ea2df75..aa00860c5a5 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -209,12 +209,11 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { } // Do not fuse a producer if the other operands of the fusion are // reachable from the producer, this would create a cycle. - if (std::any_of(consumer_operands.begin(), consumer_operands.end(), - [&](HloInstruction* operand) { - return producer != operand && - reachability()->IsReachable(producer, operand); - })) { - continue; + if (c_any_of(consumer_operands, [&](HloInstruction* operand) { + return producer != operand && + reachability()->IsReachable(producer, operand); + })) { + break; } to_fuse.insert(producer); potential_fusion_list.emplace_back(producer, consumer); @@ -229,15 +228,10 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { for (auto& fusion_pair : potential_fusion_list) { HloInstruction* producer = fusion_pair.first; HloInstruction* consumer = fusion_pair.second; - bool fusable = true; - for (size_t i = 0; i < consumer->operand_count(); ++i) { - if (producer != consumer->operand(i) && - reachability()->IsReachable(producer, consumer->operand(i))) { - fusable = false; - break; - } - } - if (fusable) { + if (!c_any_of(consumer->operands(), [&](HloInstruction* operand) { + return producer != operand && + reachability()->IsReachable(producer, operand); + })) { UpdateReachability(producer, consumer, instrs_to_update_reachability); fusion_list.push_back(fusion_pair); }