From fa1259ed252ceada5e82358682c98b4a9b212cd7 Mon Sep 17 00:00:00 2001 From: Xinan Jiang Date: Wed, 20 Nov 2019 15:27:10 +0800 Subject: [PATCH] [Grappler] Fix comparison between node name and input in function UpdateConsumers --- tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index d2ff480c29d..07f264c6f21 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -1740,7 +1740,8 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { const std::set consumers = ctx().node_map->GetOutputs(node_name); for (NodeDef* consumer : consumers) { for (int i = 0; i < consumer->input_size(); ++i) { - if (consumer->input(i) == node_name) { + if (consumer->input(i) == node_name && + consumer->name() != NodeName(new_input)) { consumer->set_input(i, new_input); ctx().node_map->UpdateInput(consumer->name(), node_name, new_input); } @@ -2876,7 +2877,8 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage { const std::set consumers = ctx().node_map->GetOutputs(node_name); for (NodeDef* consumer : consumers) { for (int i = 0; i < consumer->input_size(); ++i) { - if (consumer->input(i) == node_name && consumer->name() != new_input) { + if (consumer->input(i) == node_name && + consumer->name() != NodeName(new_input)) { consumer->set_input(i, new_input); ctx().node_map->UpdateInput(consumer->name(), node_name, new_input); }