[XLA] algsimplify: Cache scalar add computations per type
Otherwise we'd generate invalid HLO if there's a dot of different types being strength reduced in one run of algsimplify. PiperOrigin-RevId: 313060898 Change-Id: I6e0c3332654f4bfad7590297b66f839c3538115b
This commit is contained in:
parent
bdef91bcff
commit
b583e81bd4
@ -472,8 +472,9 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
|
||||
HloInstruction* dot);
|
||||
|
||||
HloComputation* GetOrCreateScalarAddComputation(PrimitiveType type) {
|
||||
if (scalar_add_computation_) {
|
||||
return scalar_add_computation_;
|
||||
HloComputation*& scalar_add_computation = scalar_add_computations_[type];
|
||||
if (scalar_add_computation) {
|
||||
return scalar_add_computation;
|
||||
}
|
||||
|
||||
HloComputation::Builder b("scalar_add_computation");
|
||||
@ -485,9 +486,9 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
|
||||
HloInstruction::CreateParameter(1, shape, "scalar_rhs"));
|
||||
auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary(
|
||||
shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs));
|
||||
scalar_add_computation_ =
|
||||
scalar_add_computation =
|
||||
computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
|
||||
return scalar_add_computation_;
|
||||
return scalar_add_computation;
|
||||
}
|
||||
|
||||
// Tries to fold a kPad in the input or filter into the convolution
|
||||
@ -528,8 +529,8 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
|
||||
// Whether algebraic simplification has occurred.
|
||||
bool changed_ = false;
|
||||
|
||||
// Cached computation for adding two scalar F32.
|
||||
HloComputation* scalar_add_computation_ = nullptr;
|
||||
// Cached computation for adding two scalars of a given type.
|
||||
absl::flat_hash_map<PrimitiveType, HloComputation*> scalar_add_computations_;
|
||||
|
||||
AlgebraicSimplifier* simplifier_ = nullptr;
|
||||
};
|
||||
|
@ -6520,5 +6520,23 @@ TEST_F(AlgebraicSimplifierTest, ScalarDividePredicate) {
|
||||
m::Broadcast(m::Divide(m::ConstantScalar(1), m::Parameter(1))))));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, MultipleDotStrengthReductions) {
|
||||
constexpr char kModuleStr[] = R"(
|
||||
HloModule test
|
||||
ENTRY test {
|
||||
a = c64[2,2] parameter(0)
|
||||
b = c64[2] parameter(1)
|
||||
cd = c64[2] dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
||||
c = f64[2,2] parameter(2)
|
||||
d = f64[2] parameter(3)
|
||||
dd = f64[2] dot(c, d), lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
||||
ROOT tuple = (c64[2], f64[2]) tuple(cd, dd)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||
EXPECT_EQ(3, m->computation_count());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user