[XLA] algsimplify: Cache scalar add computations per type
Otherwise we'd generate invalid HLO if there's a dot of different types being strength reduced in one run of algsimplify. PiperOrigin-RevId: 313060898 Change-Id: I6e0c3332654f4bfad7590297b66f839c3538115b
This commit is contained in:
parent
bdef91bcff
commit
b583e81bd4
tensorflow/compiler/xla/service
@ -472,8 +472,9 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
|
|||||||
HloInstruction* dot);
|
HloInstruction* dot);
|
||||||
|
|
||||||
HloComputation* GetOrCreateScalarAddComputation(PrimitiveType type) {
|
HloComputation* GetOrCreateScalarAddComputation(PrimitiveType type) {
|
||||||
if (scalar_add_computation_) {
|
HloComputation*& scalar_add_computation = scalar_add_computations_[type];
|
||||||
return scalar_add_computation_;
|
if (scalar_add_computation) {
|
||||||
|
return scalar_add_computation;
|
||||||
}
|
}
|
||||||
|
|
||||||
HloComputation::Builder b("scalar_add_computation");
|
HloComputation::Builder b("scalar_add_computation");
|
||||||
@ -485,9 +486,9 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
|
|||||||
HloInstruction::CreateParameter(1, shape, "scalar_rhs"));
|
HloInstruction::CreateParameter(1, shape, "scalar_rhs"));
|
||||||
auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary(
|
auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary(
|
||||||
shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs));
|
shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs));
|
||||||
scalar_add_computation_ =
|
scalar_add_computation =
|
||||||
computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
|
computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
|
||||||
return scalar_add_computation_;
|
return scalar_add_computation;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tries to fold a kPad in the input or filter into the convolution
|
// Tries to fold a kPad in the input or filter into the convolution
|
||||||
@ -528,8 +529,8 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
|
|||||||
// Whether algebraic simplification has occurred.
|
// Whether algebraic simplification has occurred.
|
||||||
bool changed_ = false;
|
bool changed_ = false;
|
||||||
|
|
||||||
// Cached computation for adding two scalar F32.
|
// Cached computation for adding two scalars of a given type.
|
||||||
HloComputation* scalar_add_computation_ = nullptr;
|
absl::flat_hash_map<PrimitiveType, HloComputation*> scalar_add_computations_;
|
||||||
|
|
||||||
AlgebraicSimplifier* simplifier_ = nullptr;
|
AlgebraicSimplifier* simplifier_ = nullptr;
|
||||||
};
|
};
|
||||||
|
@ -6520,5 +6520,23 @@ TEST_F(AlgebraicSimplifierTest, ScalarDividePredicate) {
|
|||||||
m::Broadcast(m::Divide(m::ConstantScalar(1), m::Parameter(1))))));
|
m::Broadcast(m::Divide(m::ConstantScalar(1), m::Parameter(1))))));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(AlgebraicSimplifierTest, MultipleDotStrengthReductions) {
|
||||||
|
constexpr char kModuleStr[] = R"(
|
||||||
|
HloModule test
|
||||||
|
ENTRY test {
|
||||||
|
a = c64[2,2] parameter(0)
|
||||||
|
b = c64[2] parameter(1)
|
||||||
|
cd = c64[2] dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
||||||
|
c = f64[2,2] parameter(2)
|
||||||
|
d = f64[2] parameter(3)
|
||||||
|
dd = f64[2] dot(c, d), lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
||||||
|
ROOT tuple = (c64[2], f64[2]) tuple(cd, dd)
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||||
|
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||||
|
EXPECT_EQ(3, m->computation_count());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
Loading…
Reference in New Issue
Block a user