diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index b9502ffb45e..7a9462e88b7 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -2178,6 +2178,8 @@ class FoldMultiplyIntoConv : public ArithmeticOptimizerStage { // Check that value preserving chain is the only consumer of the Mul output. TF_RETURN_IF_TRUE(!IsAnyMul(*source)); TF_RETURN_IF_TRUE(NumNonControlOutputs(*source, *ctx().node_map) != 1); + // And that Mul is not in the preserve set. + TF_RETURN_IF_TRUE(IsInPreserveSet(*source)); const NodeDef* mul = source; int input_idx = 0;