Fix bug in FoldTransposeIntoMatMul arithmetic optimization.
Previously, the optimization would leave the node map in an inconsistent state: a non-folded input would continue to consider the pre-optimized node as its output. If the non-folded input was subsequently optimized in the same pass of the ArithmeticOptimizer, we could end up with an incorrect graph. To fix, we ensure that the non-folded input (if any) is rewired to the new node. PiperOrigin-RevId: 261222411
This commit is contained in:
parent
c3f5d9c263
commit
21488b7bca
@ -2248,6 +2248,8 @@ class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage {
|
|||||||
FlipBooleanAttr(attr_a, new_op);
|
FlipBooleanAttr(attr_a, new_op);
|
||||||
new_op->set_input(0, a->input(0));
|
new_op->set_input(0, a->input(0));
|
||||||
ctx().node_map->UpdateInput(new_op->name(), a->name(), a->input(0));
|
ctx().node_map->UpdateInput(new_op->name(), a->name(), a->input(0));
|
||||||
|
} else {
|
||||||
|
ctx().node_map->UpdateOutput(a->name(), node->name(), new_op->name());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (b_is_foldable) {
|
if (b_is_foldable) {
|
||||||
@ -2256,6 +2258,8 @@ class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage {
|
|||||||
FlipBooleanAttr(attr_b, new_op);
|
FlipBooleanAttr(attr_b, new_op);
|
||||||
new_op->set_input(1, b->input(0));
|
new_op->set_input(1, b->input(0));
|
||||||
ctx().node_map->UpdateInput(new_op->name(), b->name(), b->input(0));
|
ctx().node_map->UpdateInput(new_op->name(), b->name(), b->input(0));
|
||||||
|
} else {
|
||||||
|
ctx().node_map->UpdateOutput(b->name(), node->name(), new_op->name());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<const NodeDef*> deps_to_forward = {node};
|
std::vector<const NodeDef*> deps_to_forward = {node};
|
||||||
|
Loading…
x
Reference in New Issue
Block a user