[XLA] Convert a broadcasted denominator of a divide into a broadcast of a reciprocal.
PiperOrigin-RevId: 248839769
This commit is contained in:
parent
564cc016d1
commit
fe270f8c0f
@ -1122,6 +1122,20 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
|
|||||||
return ReplaceInstruction(divide, new_divide);
|
return ReplaceInstruction(divide, new_divide);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A / Broaddcast(B) => A * Broadcast(1/B)
|
||||||
|
if (Match(divide, m::Divide(m::Op(&a), m::Broadcast(&c, m::Op(&b))))) {
|
||||||
|
auto one = MakeBroadcastHlo(
|
||||||
|
computation_->AddInstruction(HloInstruction::CreateConstant(
|
||||||
|
LiteralUtil::One(b->shape().element_type()))),
|
||||||
|
{}, b->shape().dimensions());
|
||||||
|
TF_ASSIGN_OR_RETURN(auto recip, MakeBinaryHlo(HloOpcode::kDivide, one, b));
|
||||||
|
auto recip_broadcast =
|
||||||
|
MakeBroadcastHlo(recip, c->dimensions(), c->shape().dimensions());
|
||||||
|
TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kMultiply, a,
|
||||||
|
recip_broadcast));
|
||||||
|
return ReplaceInstruction(divide, new_divide);
|
||||||
|
}
|
||||||
|
|
||||||
// (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C)
|
// (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C)
|
||||||
if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)),
|
if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)),
|
||||||
m::Divide(m::Op(&c), m::Op(&d))))) {
|
m::Divide(m::Op(&c), m::Op(&d))))) {
|
||||||
|
@ -5061,6 +5061,29 @@ TEST_F(AlgebraicSimplifierTest, DividedByConstantInstructionWithoutLayout) {
|
|||||||
EXPECT_THAT(root, GmockMatch(m::Multiply()));
|
EXPECT_THAT(root, GmockMatch(m::Multiply()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(AlgebraicSimplifierTest, DivOfBroadcast) {
|
||||||
|
const char* hlo_string = R"(
|
||||||
|
HloModule module
|
||||||
|
|
||||||
|
ENTRY test {
|
||||||
|
p0 = f32[10] parameter(0)
|
||||||
|
b = f32[30,10] broadcast(f32[10] p0), dimensions={1}
|
||||||
|
p1 = f32[30,10] parameter(1)
|
||||||
|
ROOT d = f32[30,10] divide(p1,b)
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
|
ParseAndReturnVerifiedModule(hlo_string));
|
||||||
|
|
||||||
|
HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
|
||||||
|
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||||
|
auto root = module->entry_computation()->root_instruction();
|
||||||
|
EXPECT_THAT(root, GmockMatch(m::Multiply(
|
||||||
|
m::Parameter(1),
|
||||||
|
m::Broadcast(m::Divide(m::Broadcast(m::Constant()),
|
||||||
|
m::Parameter(0))))));
|
||||||
|
}
|
||||||
|
|
||||||
// Test that 1/sqrt(X) is simplified to rsqrt(X).
|
// Test that 1/sqrt(X) is simplified to rsqrt(X).
|
||||||
TEST_F(AlgebraicSimplifierTest, RecipSqrt) {
|
TEST_F(AlgebraicSimplifierTest, RecipSqrt) {
|
||||||
const char* kModuleStr = R"(
|
const char* kModuleStr = R"(
|
||||||
|
Loading…
Reference in New Issue
Block a user