Fix bug in FoldTransposeIntoMatMul arithmetic optimization.
Previously, the optimization would leave the node map in an inconsistent state: a non-folded input would continue to consider the pre-optimized node as its output. If the non-folded input was subsequently optimized in the same pass of the ArithmeticOptimizer, we could end up with an incorrect graph. To fix, we ensure that the non-folded input (if any) is rewired to the new node. PiperOrigin-RevId: 261222411
This commit is contained in:
parent
c3f5d9c263
commit
21488b7bca
@ -2248,6 +2248,8 @@ class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage {
|
||||
FlipBooleanAttr(attr_a, new_op);
|
||||
new_op->set_input(0, a->input(0));
|
||||
ctx().node_map->UpdateInput(new_op->name(), a->name(), a->input(0));
|
||||
} else {
|
||||
ctx().node_map->UpdateOutput(a->name(), node->name(), new_op->name());
|
||||
}
|
||||
|
||||
if (b_is_foldable) {
|
||||
@ -2256,6 +2258,8 @@ class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage {
|
||||
FlipBooleanAttr(attr_b, new_op);
|
||||
new_op->set_input(1, b->input(0));
|
||||
ctx().node_map->UpdateInput(new_op->name(), b->name(), b->input(0));
|
||||
} else {
|
||||
ctx().node_map->UpdateOutput(b->name(), node->name(), new_op->name());
|
||||
}
|
||||
|
||||
std::vector<const NodeDef*> deps_to_forward = {node};
|
||||
|
Loading…
Reference in New Issue
Block a user