Canonicalize Not(...) more aggressively in deadness analysis
This change makes deadness analysis canonicalize Not expressions more aggressively using DeMorgan's law. It uses this added power to more aggressively detect And and Or expressions whose operands include a predicate and its negation. The main motivation for doing this is to simplify ~#true to #false and ~#false to #true, but I can't test that directly before another related change lands. PiperOrigin-RevId: 230356857
This commit is contained in:
parent
80f47caee4
commit
5d99fe9b65
@ -333,16 +333,19 @@ class PredicateFactory {
|
||||
}
|
||||
|
||||
Predicate* MakeNotPredicate(Predicate* pred) {
|
||||
SignatureForNot signature = pred;
|
||||
auto it = interned_not_instances_.find(signature);
|
||||
if (it == interned_not_instances_.end()) {
|
||||
std::unique_ptr<Predicate> new_pred = Make<NotPredicate>(pred);
|
||||
Predicate* new_pred_ptr = new_pred.get();
|
||||
interned_not_instances_.emplace(signature, std::move(new_pred));
|
||||
return new_pred_ptr;
|
||||
} else {
|
||||
return it->second.get();
|
||||
auto it = make_not_predicate_cache_.find(pred);
|
||||
if (it != make_not_predicate_cache_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
Predicate* result = MakeNotPredicateImpl(pred);
|
||||
|
||||
bool insert_successful =
|
||||
make_not_predicate_cache_.insert({pred, result}).second;
|
||||
(void)insert_successful;
|
||||
DCHECK(insert_successful);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
Predicate* MakeAndRecurrencePredicate(Predicate* start, Predicate* step) {
|
||||
@ -378,7 +381,52 @@ class PredicateFactory {
|
||||
Predicate* MakeTrue() { return MakeAndPredicate({}); }
|
||||
Predicate* MakeFalse() { return MakeOrPredicate({}); }
|
||||
|
||||
~PredicateFactory() {
|
||||
DCHECK_EQ(stack_depth_, 0) << "Unnested IncrementStackDepth?";
|
||||
}
|
||||
|
||||
private:
|
||||
Predicate* MakeNotPredicateImpl(Predicate* pred) {
|
||||
IncrementStackDepth stack_frame(this);
|
||||
if (!stack_frame.HasOverflowed()) {
|
||||
if (Predicate* simplified = SimplifyUsingDeMorgan(pred)) {
|
||||
return simplified;
|
||||
}
|
||||
|
||||
// ~~A => A
|
||||
if (auto* not_pred = dynamic_cast<NotPredicate*>(pred)) {
|
||||
return not_pred->operand();
|
||||
}
|
||||
}
|
||||
|
||||
SignatureForNot signature = pred;
|
||||
auto it = interned_not_instances_.find(signature);
|
||||
if (it == interned_not_instances_.end()) {
|
||||
std::unique_ptr<Predicate> new_pred = Make<NotPredicate>(pred);
|
||||
Predicate* new_pred_ptr = new_pred.get();
|
||||
interned_not_instances_.emplace(signature, std::move(new_pred));
|
||||
return new_pred_ptr;
|
||||
} else {
|
||||
return it->second.get();
|
||||
}
|
||||
}
|
||||
|
||||
Predicate* SimplifyUsingDeMorgan(Predicate* pred) {
|
||||
// ~(A & B & C & ...) => ~A | ~B | ~C | ~...
|
||||
// ~(A | B | C | ...) -> ~A & ~B & ~C & ~...
|
||||
Predicate::Kind kind = pred->kind();
|
||||
|
||||
if (kind == Predicate::Kind::kAnd || kind == Predicate::Kind::kOr) {
|
||||
std::vector<Predicate*> new_operands;
|
||||
absl::c_transform(pred->GetOperands(), std::back_inserter(new_operands),
|
||||
[&](Predicate* p) { return MakeNotPredicate(p); });
|
||||
return kind == Predicate::Kind::kOr ? MakeAndPredicate(new_operands)
|
||||
: MakeOrPredicate(new_operands);
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <typename PredicateT, typename... Args>
|
||||
std::unique_ptr<Predicate> Make(Args&&... args) {
|
||||
return std::unique_ptr<PredicateT>(
|
||||
@ -422,6 +470,36 @@ class PredicateFactory {
|
||||
}
|
||||
};
|
||||
|
||||
// Used to limit recursion to avoid blowing up the stack and cap compile time.
|
||||
class IncrementStackDepth {
|
||||
public:
|
||||
explicit IncrementStackDepth(PredicateFactory* parent) : parent_(parent) {
|
||||
parent_->stack_depth_++;
|
||||
}
|
||||
|
||||
bool HasOverflowed() const {
|
||||
const int kMaxStackDepth = 8;
|
||||
return parent_->stack_depth_ >= kMaxStackDepth;
|
||||
}
|
||||
|
||||
~IncrementStackDepth() { parent_->stack_depth_--; }
|
||||
|
||||
private:
|
||||
PredicateFactory* parent_;
|
||||
};
|
||||
|
||||
// A cache for the MakeNotPredicate function.
|
||||
//
|
||||
// NB! This is *not* the same as `interned_not_instances_`.
|
||||
// `interned_not_instances_` maps ensures pointer identity for `NotPredicate`
|
||||
// instances, i.e., it ensures there at most one instance of Not(predicate)
|
||||
// for any given predicate whereas `make_not_predicate_cache_` simply caches
|
||||
// the result of the `MakeNotPredicate` function. The values in
|
||||
// `interned_not_instances_` are always instance of `NotPredicate` whereas the
|
||||
// values in `make_not_predicate_cache_` may not be (for instance it will map
|
||||
// Not(Not(A)) to A).
|
||||
absl::flat_hash_map<Predicate*, Predicate*> make_not_predicate_cache_;
|
||||
|
||||
absl::flat_hash_map<SignatureForAndOr, std::unique_ptr<Predicate>,
|
||||
HashSignatureForAndOr>
|
||||
interned_and_or_instances_;
|
||||
@ -432,6 +510,7 @@ class PredicateFactory {
|
||||
absl::flat_hash_map<SignatureForSymbol, std::unique_ptr<Predicate>,
|
||||
HashSignatureForSymbol>
|
||||
interned_symbol_instances_;
|
||||
int stack_depth_ = 0;
|
||||
};
|
||||
|
||||
Predicate* PredicateFactory::MakeInternedAndOr(
|
||||
@ -466,6 +545,13 @@ Predicate* PredicateFactory::MakeAndOrImpl(
|
||||
absl::Span<Predicate* const> operands, bool is_and) {
|
||||
Predicate::Kind pred_kind =
|
||||
is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr;
|
||||
|
||||
IncrementStackDepth stack_frame(this);
|
||||
if (stack_frame.HasOverflowed()) {
|
||||
return MakeInternedAndOr(
|
||||
std::vector<Predicate*>(operands.begin(), operands.end()), pred_kind);
|
||||
}
|
||||
|
||||
Predicate::Kind other_pred_kind =
|
||||
is_and ? Predicate::Kind::kOr : Predicate::Kind::kAnd;
|
||||
absl::flat_hash_set<Predicate*> simplified_ops_set;
|
||||
@ -494,16 +580,31 @@ Predicate* PredicateFactory::MakeAndOrImpl(
|
||||
|
||||
// Simplify "A&~A=>False" and "A|~A=>True".
|
||||
absl::flat_hash_set<Predicate*> negated_ops;
|
||||
for (Predicate* op : simplified_ops) {
|
||||
if (op->kind() == Predicate::Kind::kNot) {
|
||||
negated_ops.insert(dynamic_cast<NotPredicate&>(*op).operand());
|
||||
}
|
||||
}
|
||||
|
||||
for (Predicate* op : simplified_ops) {
|
||||
if (negated_ops.count(op)) {
|
||||
// Simple case:
|
||||
//
|
||||
// A & ~A & ... == False
|
||||
// A | ~A | ... == True
|
||||
return is_and ? MakeFalse() : MakeTrue();
|
||||
}
|
||||
|
||||
Predicate* negated_op = MakeNotPredicate(op);
|
||||
if (negated_op->kind() == pred_kind) {
|
||||
// Slightly more complicated case:
|
||||
//
|
||||
// (~A | ~B | ~C) & A & B & C & ... ==
|
||||
// ~(A & B & C) & (A & B & C) & ... == False
|
||||
//
|
||||
// (~A & ~B & ~C) | A | B | C | ... ==
|
||||
// ~(A | B | C) | (A | B | C) | ... == True
|
||||
if (absl::c_all_of(negated_op->GetOperands(), [&](Predicate* p) {
|
||||
return simplified_ops_set.contains(p);
|
||||
})) {
|
||||
return is_and ? MakeFalse() : MakeTrue();
|
||||
}
|
||||
}
|
||||
negated_ops.insert(negated_op);
|
||||
}
|
||||
|
||||
// If all ops contain the same subop, then factor it out thanks to the
|
||||
|
@ -818,5 +818,44 @@ TEST(DeadnessAnalysisTest, RecvVsSwitchText) {
|
||||
EXPECT_EQ(predicate_map[logical_and_output_0], "(recv:0 & *recv:0)");
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, DeMorgan) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
|
||||
Output cond_0 = ops::Placeholder(root.WithOpName("cond_0"), DT_BOOL);
|
||||
Output cond_1 = ops::Placeholder(root.WithOpName("cond_1"), DT_BOOL);
|
||||
Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
|
||||
|
||||
ops::Switch sw_0(root.WithOpName("switch_0"), value, cond_0);
|
||||
ops::Switch sw_1(root.WithOpName("switch_1"), value, cond_1);
|
||||
|
||||
Output and_0_1 =
|
||||
ops::Add(root.WithOpName("and_0_1"), sw_0.output_true, sw_1.output_true);
|
||||
|
||||
Output or_not0_not1 = ops::Merge(root.WithOpName("or_not0_not1"),
|
||||
{sw_0.output_false, sw_1.output_false})
|
||||
.output;
|
||||
|
||||
// Predicate(should_always_be_dead) =
|
||||
// (A & B) & (~A | ~B) = (A & B) & ~(A & B) = False
|
||||
Output should_always_be_dead =
|
||||
ops::Add(root.WithOpName("should_always_be_dead"), and_0_1, or_not0_not1);
|
||||
|
||||
// Predicate(should_always_be_dead) =
|
||||
// (A & B) | (~A | ~B) = (A & B) | ~(A & B) = True
|
||||
Output should_always_be_alive =
|
||||
ops::Merge(root.WithOpName("should_always_be_alive"),
|
||||
{and_0_1, or_not0_not1})
|
||||
.output;
|
||||
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
PredicateMapTy predicate_map;
|
||||
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));
|
||||
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(should_always_be_dead)], "#false");
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(should_always_be_alive)], "#true");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user