Merge pull request #34437 from xinan-jiang:fix/grappler-arithmetic-update-consumers
PiperOrigin-RevId: 288380990 Change-Id: If564fdbf763820172c81e448e97b065c3cd92041
This commit is contained in:
commit
6ed5c4b7f7
@ -1762,7 +1762,8 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
|
|||||||
const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name);
|
const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name);
|
||||||
for (NodeDef* consumer : consumers) {
|
for (NodeDef* consumer : consumers) {
|
||||||
for (int i = 0; i < consumer->input_size(); ++i) {
|
for (int i = 0; i < consumer->input_size(); ++i) {
|
||||||
if (consumer->input(i) == node_name) {
|
if (consumer->input(i) == node_name &&
|
||||||
|
consumer->name() != NodeName(new_input)) {
|
||||||
consumer->set_input(i, new_input);
|
consumer->set_input(i, new_input);
|
||||||
ctx().node_map->UpdateInput(consumer->name(), node_name, new_input);
|
ctx().node_map->UpdateInput(consumer->name(), node_name, new_input);
|
||||||
}
|
}
|
||||||
@ -2907,7 +2908,8 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
|
|||||||
const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name);
|
const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name);
|
||||||
for (NodeDef* consumer : consumers) {
|
for (NodeDef* consumer : consumers) {
|
||||||
for (int i = 0; i < consumer->input_size(); ++i) {
|
for (int i = 0; i < consumer->input_size(); ++i) {
|
||||||
if (consumer->input(i) == node_name && consumer->name() != new_input) {
|
if (consumer->input(i) == node_name &&
|
||||||
|
consumer->name() != NodeName(new_input)) {
|
||||||
consumer->set_input(i, new_input);
|
consumer->set_input(i, new_input);
|
||||||
ctx().node_map->UpdateInput(consumer->name(), node_name, new_input);
|
ctx().node_map->UpdateInput(consumer->name(), node_name, new_input);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user