diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 9a088702e19..4a9d2907642 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -1762,7 +1762,8 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { const std::set consumers = ctx().node_map->GetOutputs(node_name); for (NodeDef* consumer : consumers) { for (int i = 0; i < consumer->input_size(); ++i) { - if (consumer->input(i) == node_name) { + if (consumer->input(i) == node_name && + consumer->name() != NodeName(new_input)) { consumer->set_input(i, new_input); ctx().node_map->UpdateInput(consumer->name(), node_name, new_input); } @@ -2907,7 +2908,8 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage { const std::set consumers = ctx().node_map->GetOutputs(node_name); for (NodeDef* consumer : consumers) { for (int i = 0; i < consumer->input_size(); ++i) { - if (consumer->input(i) == node_name && consumer->name() != new_input) { + if (consumer->input(i) == node_name && + consumer->name() != NodeName(new_input)) { consumer->set_input(i, new_input); ctx().node_map->UpdateInput(consumer->name(), node_name, new_input); }