[XLA] Strength reduce cvt(pred) / bcast(f32) to bcast(1 / f32) * cvt(pred)
This allows us to reduce the number of redundant divides. PiperOrigin-RevId: 312407220 Change-Id: Id6ac5322d2eeecd1a40aee0e53b2c814220726d0
This commit is contained in:
parent
4a37d3fecd
commit
361470d24a
@ -1488,6 +1488,22 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
|
|||||||
return ReplaceInstruction(divide, new_divide);
|
return ReplaceInstruction(divide, new_divide);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If X is a convert from pred, then
|
||||||
|
// X / broadcast(Y) => broadcast(1/Y) * X
|
||||||
|
if (Match(divide,
|
||||||
|
m::Divide(
|
||||||
|
m::Convert(&a,
|
||||||
|
m::Op().WithShape(m::Shape().WithElementType(PRED))),
|
||||||
|
m::Broadcast(m::Op(&b).WithShape(m::Shape().IsScalar()))))) {
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
auto recip, MakeBinaryHlo(HloOpcode::kDivide, MakeScalarLike(b, 1), b));
|
||||||
|
auto recip_bcast = computation_->AddInstruction(
|
||||||
|
HloInstruction::CreateBroadcast(divide->shape(), recip, {}));
|
||||||
|
TF_ASSIGN_OR_RETURN(auto mul,
|
||||||
|
MakeBinaryHlo(HloOpcode::kMultiply, recip_bcast, a));
|
||||||
|
return ReplaceInstruction(divide, mul);
|
||||||
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6481,5 +6481,25 @@ TEST_F(AlgebraicSimplifierTest, SwapConvOperands) {
|
|||||||
EXPECT_EQ(conv->window().dimensions(1).padding_high(), 1);
|
EXPECT_EQ(conv->window().dimensions(1).padding_high(), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(AlgebraicSimplifierTest, ScalarDividePredicate) {
|
||||||
|
const char* kModuleStr = R"(
|
||||||
|
HloModule m
|
||||||
|
test {
|
||||||
|
p0 = pred[2] parameter(0)
|
||||||
|
cvt = f32[2] convert(p0)
|
||||||
|
p1 = f32[] parameter(1)
|
||||||
|
bcast = f32[2] broadcast(p1), dimensions={}
|
||||||
|
ROOT div = f32[2] divide(cvt, bcast)
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||||
|
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||||
|
EXPECT_THAT(
|
||||||
|
m->entry_computation()->root_instruction(),
|
||||||
|
GmockMatch(m::MultiplyAnyOrder(
|
||||||
|
m::Convert(m::Parameter(0)),
|
||||||
|
m::Broadcast(m::Divide(m::ConstantScalar(1), m::Parameter(1))))));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
Loading…
Reference in New Issue
Block a user