[XLA] Strength reduce cvt(pred) / bcast(f32) to bcast(1 / f32) * cvt(pred)
This allows us to reduce the number of redundant divides. PiperOrigin-RevId: 312407220 Change-Id: Id6ac5322d2eeecd1a40aee0e53b2c814220726d0
This commit is contained in:
parent
4a37d3fecd
commit
361470d24a
@ -1488,6 +1488,22 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
|
||||
return ReplaceInstruction(divide, new_divide);
|
||||
}
|
||||
|
||||
// If X is a convert from pred, then
|
||||
// X / broadcast(Y) => broadcast(1/Y) * X
|
||||
if (Match(divide,
|
||||
m::Divide(
|
||||
m::Convert(&a,
|
||||
m::Op().WithShape(m::Shape().WithElementType(PRED))),
|
||||
m::Broadcast(m::Op(&b).WithShape(m::Shape().IsScalar()))))) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto recip, MakeBinaryHlo(HloOpcode::kDivide, MakeScalarLike(b, 1), b));
|
||||
auto recip_bcast = computation_->AddInstruction(
|
||||
HloInstruction::CreateBroadcast(divide->shape(), recip, {}));
|
||||
TF_ASSIGN_OR_RETURN(auto mul,
|
||||
MakeBinaryHlo(HloOpcode::kMultiply, recip_bcast, a));
|
||||
return ReplaceInstruction(divide, mul);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -6481,5 +6481,25 @@ TEST_F(AlgebraicSimplifierTest, SwapConvOperands) {
|
||||
EXPECT_EQ(conv->window().dimensions(1).padding_high(), 1);
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, ScalarDividePredicate) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
p0 = pred[2] parameter(0)
|
||||
cvt = f32[2] convert(p0)
|
||||
p1 = f32[] parameter(1)
|
||||
bcast = f32[2] broadcast(p1), dimensions={}
|
||||
ROOT div = f32[2] divide(cvt, bcast)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||
EXPECT_THAT(
|
||||
m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::MultiplyAnyOrder(
|
||||
m::Convert(m::Parameter(0)),
|
||||
m::Broadcast(m::Divide(m::ConstantScalar(1), m::Parameter(1))))));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user