[XLA] Predicate Reduce(Dot(....)) under enable_dot_strength_reduction.

PiperOrigin-RevId: 295896500
Change-Id: I07fda5d17b160f8ea1492c71dee9b6d58204d50b
This commit is contained in:
Blake Hechtman 2020-02-18 22:05:49 -08:00 committed by TensorFlower Gardener
parent eebf50dd9e
commit 7c1bc443fa

View File

@ -3727,7 +3727,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
// Convert Reduce(Dot(X,Y)) to Dot(X,Y) if any of the dimensions reduced were
// batch dimensions of the dot. The transformation supports reducing other
// dimensions as well.
if (Match(arg, m::Dot(&dot, m::Op(&lhs), m::Op(&rhs)).WithOneUser()) &&
if (options_.enable_dot_strength_reduction() &&
Match(arg, m::Dot(&dot, m::Op(&lhs), m::Op(&rhs)).WithOneUser()) &&
Match(reduce->to_apply()->root_instruction(),
m::Add(m::Parameter(), m::Parameter())) &&
absl::c_any_of(reduce->dimensions(), [&](int64 dim) {