From 21488b7bca8313abe12fbdd4fc7cf4387a26d7bc Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Thu, 1 Aug 2019 16:33:59 -0700 Subject: [PATCH] Fix bug in FoldTransposeIntoMatMul arithmetic optimization. Previously, the optimization would leave the node map in an inconsistent state: a non-folded input would continue to consider the pre-optimized node as its output. If the non-folded input was subsequently optimized in the same pass of the ArithmeticOptimizer, we could end up with an incorrect graph. To fix, we ensure that the non-folded input (if any) is rewired to the new node. PiperOrigin-RevId: 261222411 --- tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index a37b0812259..3bbd988f76e 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -2248,6 +2248,8 @@ class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage { FlipBooleanAttr(attr_a, new_op); new_op->set_input(0, a->input(0)); ctx().node_map->UpdateInput(new_op->name(), a->name(), a->input(0)); + } else { + ctx().node_map->UpdateOutput(a->name(), node->name(), new_op->name()); } if (b_is_foldable) { @@ -2256,6 +2258,8 @@ class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage { FlipBooleanAttr(attr_b, new_op); new_op->set_input(1, b->input(0)); ctx().node_map->UpdateInput(new_op->name(), b->name(), b->input(0)); + } else { + ctx().node_map->UpdateOutput(b->name(), node->name(), new_op->name()); } std::vector deps_to_forward = {node};