Merge pull request #34437 from xinan-jiang:fix/grappler-arithmetic-update-consumers

PiperOrigin-RevId: 288380990
Change-Id: If564fdbf763820172c81e448e97b065c3cd92041
This commit is contained in:
TensorFlower Gardener 2020-01-06 14:53:33 -08:00
commit 6ed5c4b7f7

View File

@ -1762,7 +1762,8 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name); const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name);
for (NodeDef* consumer : consumers) { for (NodeDef* consumer : consumers) {
for (int i = 0; i < consumer->input_size(); ++i) { for (int i = 0; i < consumer->input_size(); ++i) {
if (consumer->input(i) == node_name) { if (consumer->input(i) == node_name &&
consumer->name() != NodeName(new_input)) {
consumer->set_input(i, new_input); consumer->set_input(i, new_input);
ctx().node_map->UpdateInput(consumer->name(), node_name, new_input); ctx().node_map->UpdateInput(consumer->name(), node_name, new_input);
} }
@ -2907,7 +2908,8 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name); const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name);
for (NodeDef* consumer : consumers) { for (NodeDef* consumer : consumers) {
for (int i = 0; i < consumer->input_size(); ++i) { for (int i = 0; i < consumer->input_size(); ++i) {
if (consumer->input(i) == node_name && consumer->name() != new_input) { if (consumer->input(i) == node_name &&
consumer->name() != NodeName(new_input)) {
consumer->set_input(i, new_input); consumer->set_input(i, new_input);
ctx().node_map->UpdateInput(consumer->name(), node_name, new_input); ctx().node_map->UpdateInput(consumer->name(), node_name, new_input);
} }