diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index a37b0812259..3bbd988f76e 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -2248,6 +2248,8 @@ class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage { FlipBooleanAttr(attr_a, new_op); new_op->set_input(0, a->input(0)); ctx().node_map->UpdateInput(new_op->name(), a->name(), a->input(0)); + } else { + ctx().node_map->UpdateOutput(a->name(), node->name(), new_op->name()); } if (b_is_foldable) { @@ -2256,6 +2258,8 @@ class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage { FlipBooleanAttr(attr_b, new_op); new_op->set_input(1, b->input(0)); ctx().node_map->UpdateInput(new_op->name(), b->name(), b->input(0)); + } else { + ctx().node_map->UpdateOutput(b->name(), node->name(), new_op->name()); } std::vector deps_to_forward = {node};