[XLA] Convert a broadcasted denominator of a divide into a broadcast of a reciprocal.

PiperOrigin-RevId: 248839769
This commit is contained in:
Blake Hechtman 2019-05-17 23:21:27 -07:00 committed by TensorFlower Gardener
parent 564cc016d1
commit fe270f8c0f
2 changed files with 37 additions and 0 deletions

View File

@ -1122,6 +1122,20 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
return ReplaceInstruction(divide, new_divide);
}
// A / Broaddcast(B) => A * Broadcast(1/B)
if (Match(divide, m::Divide(m::Op(&a), m::Broadcast(&c, m::Op(&b))))) {
auto one = MakeBroadcastHlo(
computation_->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::One(b->shape().element_type()))),
{}, b->shape().dimensions());
TF_ASSIGN_OR_RETURN(auto recip, MakeBinaryHlo(HloOpcode::kDivide, one, b));
auto recip_broadcast =
MakeBroadcastHlo(recip, c->dimensions(), c->shape().dimensions());
TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kMultiply, a,
recip_broadcast));
return ReplaceInstruction(divide, new_divide);
}
// (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C)
if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)),
m::Divide(m::Op(&c), m::Op(&d))))) {

View File

@ -5061,6 +5061,29 @@ TEST_F(AlgebraicSimplifierTest, DividedByConstantInstructionWithoutLayout) {
EXPECT_THAT(root, GmockMatch(m::Multiply()));
}
TEST_F(AlgebraicSimplifierTest, DivOfBroadcast) {
const char* hlo_string = R"(
HloModule module
ENTRY test {
p0 = f32[10] parameter(0)
b = f32[30,10] broadcast(f32[10] p0), dimensions={1}
p1 = f32[30,10] parameter(1)
ROOT d = f32[30,10] divide(p1,b)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, GmockMatch(m::Multiply(
m::Parameter(1),
m::Broadcast(m::Divide(m::Broadcast(m::Constant()),
m::Parameter(0))))));
}
// Test that 1/sqrt(X) is simplified to rsqrt(X).
TEST_F(AlgebraicSimplifierTest, RecipSqrt) {
const char* kModuleStr = R"(