[XLA] Convert Abs(a)*Abs(a) to a*a and add an option to allow for numerically unsafe algebraic simplifications
PiperOrigin-RevId: 330288395 Change-Id: Iece65fa1cc28a9eb5bcebc2faf8d34235c47e56a
This commit is contained in:
parent
bb49eafc08
commit
e9f0135b7e
@ -913,7 +913,7 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
|
||||
(Match(lhs, m::Multiply(m::Op(&c), m::Op(&a))) &&
|
||||
Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b))))) &&
|
||||
(ShapeUtil::ElementIsIntegral(add->shape()) ||
|
||||
IsAllFpConstantPowerOf2(c))) {
|
||||
options_.enable_floats_are_real() || IsAllFpConstantPowerOf2(c))) {
|
||||
return ReplaceWithNewInstruction(
|
||||
add, HloInstruction::CreateBinary(
|
||||
add->shape(), HloOpcode::kMultiply,
|
||||
@ -2710,6 +2710,17 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
{
|
||||
HloInstruction* abs_operand;
|
||||
if (lhs == rhs && Match(lhs, m::Abs(m::Op(&abs_operand))) &&
|
||||
!ShapeUtil::ElementIsComplex(abs_operand->shape())) {
|
||||
TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(0, abs_operand));
|
||||
TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(1, abs_operand));
|
||||
changed_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
HloInstruction *convert_operand, *operand;
|
||||
// Mul(Convert(Pred), operand) => select(pred, operand, 0)
|
||||
|
||||
@ -97,6 +97,14 @@ class AlgebraicSimplifierOptions {
|
||||
return enable_scalar_multiply_reduction_;
|
||||
}
|
||||
|
||||
// Also the algebraic simplifer to treat floating point values like real
|
||||
// numbers.
|
||||
void set_enable_floats_are_real(bool enable_floats_are_real) {
|
||||
enable_floats_are_real_ = enable_floats_are_real;
|
||||
}
|
||||
|
||||
bool enable_floats_are_real() const { return enable_floats_are_real_; }
|
||||
|
||||
// If enable_window_reduce_replacement is true, the kReduceWindow instruction
|
||||
// can be optimized by replacement with simpler operations.
|
||||
void set_enable_window_reduce_to_reduce_replacement(
|
||||
@ -158,6 +166,7 @@ class AlgebraicSimplifierOptions {
|
||||
bool enable_conv_simplification_{true};
|
||||
bool enable_conv_operand_swap_{true};
|
||||
bool enable_scalar_multiply_reduction_{false};
|
||||
bool enable_floats_are_real_{false};
|
||||
bool enable_window_reduce_to_reduce_replacement_{true};
|
||||
bool enable_reduce_of_reshape_{true};
|
||||
bool replace_transpose_with_bitcast_{true};
|
||||
|
||||
@ -117,6 +117,22 @@ TEST_F(AlgebraicSimplifierTest, FactorFpAddition) {
|
||||
m::ConstantScalar(0.125))));
|
||||
}
|
||||
|
||||
// (Abs(A)) * (Abs(A)) => (A*A)
|
||||
TEST_F(AlgebraicSimplifierTest, SquareOfAbs) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
p = f32[] parameter(0)
|
||||
a = f32[] abs(p)
|
||||
ROOT z = f32[] multiply(a, a)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||
EXPECT_THAT(m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
|
||||
}
|
||||
|
||||
// (A*C1) * (B*C2) => (A*B)*(C1*C2)
|
||||
TEST_F(AlgebraicSimplifierTest, MultiplyChain) {
|
||||
const char* kModuleStr = R"(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user