[XLA] algsimplify: Cache scalar add computations per type

Otherwise we'd generate invalid HLO if there's a dot of different types being
strength reduced in one run of algsimplify.

PiperOrigin-RevId: 313060898
Change-Id: I6e0c3332654f4bfad7590297b66f839c3538115b
This commit is contained in:
Benjamin Kramer 2020-05-25 05:07:13 -07:00 committed by TensorFlower Gardener
parent bdef91bcff
commit b583e81bd4
2 changed files with 25 additions and 6 deletions

View File

@ -472,8 +472,9 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
HloInstruction* dot);
HloComputation* GetOrCreateScalarAddComputation(PrimitiveType type) {
if (scalar_add_computation_) {
return scalar_add_computation_;
HloComputation*& scalar_add_computation = scalar_add_computations_[type];
if (scalar_add_computation) {
return scalar_add_computation;
}
HloComputation::Builder b("scalar_add_computation");
@ -485,9 +486,9 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
HloInstruction::CreateParameter(1, shape, "scalar_rhs"));
auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary(
shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs));
scalar_add_computation_ =
scalar_add_computation =
computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
return scalar_add_computation_;
return scalar_add_computation;
}
// Tries to fold a kPad in the input or filter into the convolution
@ -528,8 +529,8 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
// Whether algebraic simplification has occurred.
bool changed_ = false;
// Cached computation for adding two scalar F32.
HloComputation* scalar_add_computation_ = nullptr;
// Cached computation for adding two scalars of a given type.
absl::flat_hash_map<PrimitiveType, HloComputation*> scalar_add_computations_;
AlgebraicSimplifier* simplifier_ = nullptr;
};

View File

@ -6520,5 +6520,23 @@ TEST_F(AlgebraicSimplifierTest, ScalarDividePredicate) {
m::Broadcast(m::Divide(m::ConstantScalar(1), m::Parameter(1))))));
}
TEST_F(AlgebraicSimplifierTest, MultipleDotStrengthReductions) {
constexpr char kModuleStr[] = R"(
HloModule test
ENTRY test {
a = c64[2,2] parameter(0)
b = c64[2] parameter(1)
cd = c64[2] dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
c = f64[2,2] parameter(2)
d = f64[2] parameter(3)
dd = f64[2] dot(c, d), lhs_contracting_dims={1}, rhs_contracting_dims={0}
ROOT tuple = (c64[2], f64[2]) tuple(cd, dd)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
EXPECT_EQ(3, m->computation_count());
}
} // namespace
} // namespace xla