[XLA] Convert a broadcasted denominator of a divide into a broadcast of a reciprocal.
PiperOrigin-RevId: 248839769
This commit is contained in:
parent
564cc016d1
commit
fe270f8c0f
@ -1122,6 +1122,20 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
|
||||
return ReplaceInstruction(divide, new_divide);
|
||||
}
|
||||
|
||||
// A / Broaddcast(B) => A * Broadcast(1/B)
|
||||
if (Match(divide, m::Divide(m::Op(&a), m::Broadcast(&c, m::Op(&b))))) {
|
||||
auto one = MakeBroadcastHlo(
|
||||
computation_->AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::One(b->shape().element_type()))),
|
||||
{}, b->shape().dimensions());
|
||||
TF_ASSIGN_OR_RETURN(auto recip, MakeBinaryHlo(HloOpcode::kDivide, one, b));
|
||||
auto recip_broadcast =
|
||||
MakeBroadcastHlo(recip, c->dimensions(), c->shape().dimensions());
|
||||
TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kMultiply, a,
|
||||
recip_broadcast));
|
||||
return ReplaceInstruction(divide, new_divide);
|
||||
}
|
||||
|
||||
// (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C)
|
||||
if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)),
|
||||
m::Divide(m::Op(&c), m::Op(&d))))) {
|
||||
|
@ -5061,6 +5061,29 @@ TEST_F(AlgebraicSimplifierTest, DividedByConstantInstructionWithoutLayout) {
|
||||
EXPECT_THAT(root, GmockMatch(m::Multiply()));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, DivOfBroadcast) {
|
||||
const char* hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
ENTRY test {
|
||||
p0 = f32[10] parameter(0)
|
||||
b = f32[30,10] broadcast(f32[10] p0), dimensions={1}
|
||||
p1 = f32[30,10] parameter(1)
|
||||
ROOT d = f32[30,10] divide(p1,b)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
|
||||
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||
auto root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(root, GmockMatch(m::Multiply(
|
||||
m::Parameter(1),
|
||||
m::Broadcast(m::Divide(m::Broadcast(m::Constant()),
|
||||
m::Parameter(0))))));
|
||||
}
|
||||
|
||||
// Test that 1/sqrt(X) is simplified to rsqrt(X).
|
||||
TEST_F(AlgebraicSimplifierTest, RecipSqrt) {
|
||||
const char* kModuleStr = R"(
|
||||
|
Loading…
Reference in New Issue
Block a user