[XLA] Strength reduce cvt(pred) / bcast(f32) to bcast(1 / f32) * cvt(pred)

This allows us to reduce the number of redundant divides.

PiperOrigin-RevId: 312407220
Change-Id: Id6ac5322d2eeecd1a40aee0e53b2c814220726d0
This commit is contained in:
David Majnemer 2020-05-19 20:15:55 -07:00 committed by TensorFlower Gardener
parent 4a37d3fecd
commit 361470d24a
2 changed files with 36 additions and 0 deletions

View File

@ -1488,6 +1488,22 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
return ReplaceInstruction(divide, new_divide);
}
// If X is a convert from pred, then
// X / broadcast(Y) => broadcast(1/Y) * X
if (Match(divide,
m::Divide(
m::Convert(&a,
m::Op().WithShape(m::Shape().WithElementType(PRED))),
m::Broadcast(m::Op(&b).WithShape(m::Shape().IsScalar()))))) {
TF_ASSIGN_OR_RETURN(
auto recip, MakeBinaryHlo(HloOpcode::kDivide, MakeScalarLike(b, 1), b));
auto recip_bcast = computation_->AddInstruction(
HloInstruction::CreateBroadcast(divide->shape(), recip, {}));
TF_ASSIGN_OR_RETURN(auto mul,
MakeBinaryHlo(HloOpcode::kMultiply, recip_bcast, a));
return ReplaceInstruction(divide, mul);
}
return Status::OK();
}

View File

@ -6481,5 +6481,25 @@ TEST_F(AlgebraicSimplifierTest, SwapConvOperands) {
EXPECT_EQ(conv->window().dimensions(1).padding_high(), 1);
}
// Verifies the rewrite of a predicate-derived dividend over a broadcast scalar:
//   convert(pred) / broadcast(y)  =>  broadcast(1 / y) * convert(pred)
// The simplifier must report a change, and the new root must be a multiply
// (operands in either order) of the original convert and a broadcast of the
// scalar reciprocal.
TEST_F(AlgebraicSimplifierTest, ScalarDividePredicate) {
  const char* kModuleStr = R"(
HloModule m
test {
p0 = pred[2] parameter(0)
cvt = f32[2] convert(p0)
p1 = f32[] parameter(1)
bcast = f32[2] broadcast(p1), dimensions={}
ROOT div = f32[2] divide(cvt, bcast)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr));

  // Run the pass once; it must report that the module was modified.
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The divide at the root should now be a multiply by the reciprocal.
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root,
      GmockMatch(m::MultiplyAnyOrder(
          m::Convert(m::Parameter(0)),
          m::Broadcast(m::Divide(m::ConstantScalar(1), m::Parameter(1))))));
}
} // namespace
} // namespace xla