[XLA] Convert Abs(a)*Abs(a) to a*a and add an option to allow for numerically unsafe algebraic simplifications

PiperOrigin-RevId: 330288395
Change-Id: Iece65fa1cc28a9eb5bcebc2faf8d34235c47e56a
This commit is contained in:
A. Unique TensorFlower 2020-09-06 12:44:09 -07:00 committed by TensorFlower Gardener
parent bb49eafc08
commit e9f0135b7e
3 changed files with 37 additions and 1 deletions

View File

@ -913,7 +913,7 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
(Match(lhs, m::Multiply(m::Op(&c), m::Op(&a))) &&
Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b))))) &&
(ShapeUtil::ElementIsIntegral(add->shape()) ||
IsAllFpConstantPowerOf2(c))) {
options_.enable_floats_are_real() || IsAllFpConstantPowerOf2(c))) {
return ReplaceWithNewInstruction(
add, HloInstruction::CreateBinary(
add->shape(), HloOpcode::kMultiply,
@ -2710,6 +2710,17 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
return Status::OK();
}
{
HloInstruction* abs_operand;
if (lhs == rhs && Match(lhs, m::Abs(m::Op(&abs_operand))) &&
!ShapeUtil::ElementIsComplex(abs_operand->shape())) {
TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(0, abs_operand));
TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(1, abs_operand));
changed_ = true;
return Status::OK();
}
}
{
HloInstruction *convert_operand, *operand;
// Mul(Convert(Pred), operand) => select(pred, operand, 0)

View File

@ -97,6 +97,14 @@ class AlgebraicSimplifierOptions {
return enable_scalar_multiply_reduction_;
}
// Also the algebraic simplifer to treat floating point values like real
// numbers.
void set_enable_floats_are_real(bool enable_floats_are_real) {
enable_floats_are_real_ = enable_floats_are_real;
}
bool enable_floats_are_real() const { return enable_floats_are_real_; }
// If enable_window_reduce_replacement is true, the kReduceWindow instruction
// can be optimized by replacement with simpler operations.
void set_enable_window_reduce_to_reduce_replacement(
@ -158,6 +166,7 @@ class AlgebraicSimplifierOptions {
bool enable_conv_simplification_{true};
bool enable_conv_operand_swap_{true};
bool enable_scalar_multiply_reduction_{false};
bool enable_floats_are_real_{false};
bool enable_window_reduce_to_reduce_replacement_{true};
bool enable_reduce_of_reshape_{true};
bool replace_transpose_with_bitcast_{true};

View File

@ -117,6 +117,22 @@ TEST_F(AlgebraicSimplifierTest, FactorFpAddition) {
m::ConstantScalar(0.125))));
}
// (Abs(A)) * (Abs(A)) => (A*A)
TEST_F(AlgebraicSimplifierTest, SquareOfAbs) {
const char* kModuleStr = R"(
HloModule m
test {
p = f32[] parameter(0)
a = f32[] abs(p)
ROOT z = f32[] multiply(a, a)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
}
// (A*C1) * (B*C2) => (A*B)*(C1*C2)
TEST_F(AlgebraicSimplifierTest, MultiplyChain) {
const char* kModuleStr = R"(