Simplify broadcast plus compare
PiperOrigin-RevId: 316927924 Change-Id: If7f3ce209cbff3720c19a60d3e713167e1c6e8c6
This commit is contained in:
parent
5e7fc9584a
commit
c6a3ab159f
@ -2815,6 +2815,28 @@ Status AlgebraicSimplifierVisitor::HandleCompare(HloInstruction* compare) {
|
||||
HloInstruction* lhs;
|
||||
HloInstruction* rhs;
|
||||
CHECK(Match(compare, m::Compare(m::Op(&lhs), m::Op(&rhs))));
|
||||
{
|
||||
// compare(broadcast(a) + x, broadcast(b)) ==>
|
||||
// compare(x, broadcast(b-a))
|
||||
HloInstruction *x, *a, *b;
|
||||
if (Match(compare,
|
||||
m::Compare(
|
||||
m::AddAnyOrder(m::Op(&x), m::Broadcast(m::Op(&a).WithShape(
|
||||
m::Shape().IsScalar()))),
|
||||
m::Broadcast(m::Op(&b).WithShape(m::Shape().IsScalar()))))) {
|
||||
if (ShapeUtil::ElementIsSigned(x->shape())) {
|
||||
HloInstruction* sub =
|
||||
computation_->AddInstruction(HloInstruction::CreateBinary(
|
||||
b->shape(), HloOpcode::kSubtract, b, a));
|
||||
HloInstruction* broadcast = computation_->AddInstruction(
|
||||
HloInstruction::CreateBroadcast(x->shape(), sub, {}));
|
||||
HloInstruction* new_compare = computation_->AddInstruction(
|
||||
HloInstruction::CreateCompare(compare->shape(), x, broadcast,
|
||||
compare->comparison_direction()));
|
||||
return ReplaceInstruction(compare, new_compare);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (compare->comparison_direction() == ComparisonDirection::kLt &&
|
||||
lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) {
|
||||
|
Loading…
Reference in New Issue
Block a user