Simplify broadcast plus compare

PiperOrigin-RevId: 316927924
Change-Id: If7f3ce209cbff3720c19a60d3e713167e1c6e8c6
This commit is contained in:
Yunxing Dai 2020-06-17 11:25:14 -07:00 committed by TensorFlower Gardener
parent 5e7fc9584a
commit c6a3ab159f

View File

@ -2815,6 +2815,28 @@ Status AlgebraicSimplifierVisitor::HandleCompare(HloInstruction* compare) {
HloInstruction* lhs;
HloInstruction* rhs;
CHECK(Match(compare, m::Compare(m::Op(&lhs), m::Op(&rhs))));
{
// compare(broadcast(a) + x, broadcast(b)) ==>
// compare(x, broadcast(b-a))
HloInstruction *x, *a, *b;
if (Match(compare,
m::Compare(
m::AddAnyOrder(m::Op(&x), m::Broadcast(m::Op(&a).WithShape(
m::Shape().IsScalar()))),
m::Broadcast(m::Op(&b).WithShape(m::Shape().IsScalar()))))) {
if (ShapeUtil::ElementIsSigned(x->shape())) {
HloInstruction* sub =
computation_->AddInstruction(HloInstruction::CreateBinary(
b->shape(), HloOpcode::kSubtract, b, a));
HloInstruction* broadcast = computation_->AddInstruction(
HloInstruction::CreateBroadcast(x->shape(), sub, {}));
HloInstruction* new_compare = computation_->AddInstruction(
HloInstruction::CreateCompare(compare->shape(), x, broadcast,
compare->comparison_direction()));
return ReplaceInstruction(compare, new_compare);
}
}
}
if (compare->comparison_direction() == ComparisonDirection::kLt &&
lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) {