Merge pull request #34437 from xinan-jiang:fix/grappler-arithmetic-update-consumers
PiperOrigin-RevId: 288380990 Change-Id: If564fdbf763820172c81e448e97b065c3cd92041
This commit is contained in:
commit
6ed5c4b7f7
@ -1762,7 +1762,8 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
|
||||
const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name);
|
||||
for (NodeDef* consumer : consumers) {
|
||||
for (int i = 0; i < consumer->input_size(); ++i) {
|
||||
if (consumer->input(i) == node_name) {
|
||||
if (consumer->input(i) == node_name &&
|
||||
consumer->name() != NodeName(new_input)) {
|
||||
consumer->set_input(i, new_input);
|
||||
ctx().node_map->UpdateInput(consumer->name(), node_name, new_input);
|
||||
}
|
||||
@ -2907,7 +2908,8 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
|
||||
const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name);
|
||||
for (NodeDef* consumer : consumers) {
|
||||
for (int i = 0; i < consumer->input_size(); ++i) {
|
||||
if (consumer->input(i) == node_name && consumer->name() != new_input) {
|
||||
if (consumer->input(i) == node_name &&
|
||||
consumer->name() != NodeName(new_input)) {
|
||||
consumer->set_input(i, new_input);
|
||||
ctx().node_map->UpdateInput(consumer->name(), node_name, new_input);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user