Simplify broadcast plus compare
PiperOrigin-RevId: 316927924 Change-Id: If7f3ce209cbff3720c19a60d3e713167e1c6e8c6
This commit is contained in:
parent
5e7fc9584a
commit
c6a3ab159f
@ -2815,6 +2815,28 @@ Status AlgebraicSimplifierVisitor::HandleCompare(HloInstruction* compare) {
|
|||||||
HloInstruction* lhs;
|
HloInstruction* lhs;
|
||||||
HloInstruction* rhs;
|
HloInstruction* rhs;
|
||||||
CHECK(Match(compare, m::Compare(m::Op(&lhs), m::Op(&rhs))));
|
CHECK(Match(compare, m::Compare(m::Op(&lhs), m::Op(&rhs))));
|
||||||
|
{
|
||||||
|
// compare(broadcast(a) + x, broadcast(b)) ==>
|
||||||
|
// compare(x, broadcast(b-a))
|
||||||
|
HloInstruction *x, *a, *b;
|
||||||
|
if (Match(compare,
|
||||||
|
m::Compare(
|
||||||
|
m::AddAnyOrder(m::Op(&x), m::Broadcast(m::Op(&a).WithShape(
|
||||||
|
m::Shape().IsScalar()))),
|
||||||
|
m::Broadcast(m::Op(&b).WithShape(m::Shape().IsScalar()))))) {
|
||||||
|
if (ShapeUtil::ElementIsSigned(x->shape())) {
|
||||||
|
HloInstruction* sub =
|
||||||
|
computation_->AddInstruction(HloInstruction::CreateBinary(
|
||||||
|
b->shape(), HloOpcode::kSubtract, b, a));
|
||||||
|
HloInstruction* broadcast = computation_->AddInstruction(
|
||||||
|
HloInstruction::CreateBroadcast(x->shape(), sub, {}));
|
||||||
|
HloInstruction* new_compare = computation_->AddInstruction(
|
||||||
|
HloInstruction::CreateCompare(compare->shape(), x, broadcast,
|
||||||
|
compare->comparison_direction()));
|
||||||
|
return ReplaceInstruction(compare, new_compare);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (compare->comparison_direction() == ComparisonDirection::kLt &&
|
if (compare->comparison_direction() == ComparisonDirection::kLt &&
|
||||||
lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) {
|
lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user