diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 5f50c2b303b..cfbcb5a4fe2 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -3727,7 +3727,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { // Convert Reduce(Dot(X,Y)) to Dot(X,Y) if any of the dimensions reduced were // batch dimensions of the dot. The transformation supports reducing other // dimensions as well. - if (Match(arg, m::Dot(&dot, m::Op(&lhs), m::Op(&rhs)).WithOneUser()) && + if (options_.enable_dot_strength_reduction() && + Match(arg, m::Dot(&dot, m::Op(&lhs), m::Op(&rhs)).WithOneUser()) && Match(reduce->to_apply()->root_instruction(), m::Add(m::Parameter(), m::Parameter())) && absl::c_any_of(reduce->dimensions(), [&](int64 dim) {