Fix bug in FoldTransposeIntoMatMul arithmetic optimization.

Previously, the optimization could leave the node map in an inconsistent state:
a non-folded input would still list the pre-optimized node among its outputs.
If that non-folded input was subsequently optimized in the same pass of the
ArithmeticOptimizer, we could end up with an incorrect graph. To fix, we
ensure that the non-folded input (if any) is rewired to the new node.

PiperOrigin-RevId: 261222411
This commit is contained in:
Derek Murray 2019-08-01 16:33:59 -07:00 committed by TensorFlower Gardener
parent c3f5d9c263
commit 21488b7bca

View File

@ -2248,6 +2248,8 @@ class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage {
FlipBooleanAttr(attr_a, new_op);
new_op->set_input(0, a->input(0));
ctx().node_map->UpdateInput(new_op->name(), a->name(), a->input(0));
} else {
ctx().node_map->UpdateOutput(a->name(), node->name(), new_op->name());
}
if (b_is_foldable) {
@ -2256,6 +2258,8 @@ class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage {
FlipBooleanAttr(attr_b, new_op);
new_op->set_input(1, b->input(0));
ctx().node_map->UpdateInput(new_op->name(), b->name(), b->input(0));
} else {
ctx().node_map->UpdateOutput(b->name(), node->name(), new_op->name());
}
std::vector<const NodeDef*> deps_to_forward = {node};