Intern predicate pointers

This is a performance optimization.

PiperOrigin-RevId: 205280010
This commit is contained in:
Sanjoy Das 2018-07-19 12:11:20 -07:00 committed by TensorFlower Gardener
parent 34db47821a
commit f330a1c892

View File

@ -44,10 +44,6 @@ class Predicate {
enum class Kind { kAnd, kOr, kNot, kSymbol };
virtual string ToString() const = 0;
virtual bool operator==(const Predicate& other) const = 0;
// Inequality is the negation of the (virtual) operator== of this predicate.
virtual bool operator!=(const Predicate& other) const {
  const bool same = (*this == other);
  return !same;
}
int64 hash() const { return hash_; }
virtual Kind kind() const = 0;
@ -58,6 +54,8 @@ class Predicate {
private:
const int64 hash_;
TF_DISALLOW_COPY_AND_ASSIGN(Predicate);
};
int64 HashPredicateSequence(Predicate::Kind kind,
@ -69,19 +67,6 @@ int64 HashPredicateSequence(Predicate::Kind kind,
return hash;
}
// Returns true iff `lhs` and `rhs` have the same number of elements and every
// pair of predicates at the same index compares equal by value.  Elements are
// compared through Predicate::operator!=, and the scan stops at the first
// mismatch.
bool PredicateSequenceEqual(gtl::ArraySlice<Predicate*> lhs,
                            gtl::ArraySlice<Predicate*> rhs) {
  bool all_equal = lhs.size() == rhs.size();
  for (int64 idx = 0; all_equal && idx < lhs.size(); idx++) {
    all_equal = !(*lhs[idx] != *rhs[idx]);
  }
  return all_equal;
}
// Represents a logical conjunction of a set of predicates.
class AndPredicate : public Predicate {
public:
@ -102,17 +87,9 @@ class AndPredicate : public Predicate {
return strings::StrCat("(", str_util::Join(operands_str, " & "), ")");
}
// Structural equality: `other` must also be an AND predicate and its operand
// sequence must match ours element by element.
bool operator==(const Predicate& other) const override {
  if (other.kind() != Kind::kAnd) {
    return false;
  }
  const auto& other_and = dynamic_cast<const AndPredicate&>(other);
  return PredicateSequenceEqual(other_and.operands(), operands());
}
Kind kind() const override { return Kind::kAnd; }
const tensorflow::gtl::ArraySlice<Predicate*> operands() const {
return operands_;
}
const gtl::ArraySlice<Predicate*> operands() const { return operands_; }
private:
std::vector<Predicate*> operands_;
@ -138,16 +115,8 @@ class OrPredicate : public Predicate {
return strings::StrCat("(", str_util::Join(operands_str, " | "), ")");
}
// Structural equality: `other` must also be an OR predicate and its operand
// sequence must match ours element by element.
bool operator==(const Predicate& other) const override {
  if (other.kind() != Kind::kOr) {
    return false;
  }
  const auto& other_or = dynamic_cast<const OrPredicate&>(other);
  return PredicateSequenceEqual(other_or.operands(), operands());
}
Kind kind() const override { return Kind::kOr; }
const tensorflow::gtl::ArraySlice<Predicate*> operands() const {
return operands_;
}
const gtl::ArraySlice<Predicate*> operands() const { return operands_; }
private:
std::vector<Predicate*> operands_;
@ -164,11 +133,6 @@ class NotPredicate : public Predicate {
return strings::StrCat("~", operand()->ToString());
}
// Structural equality: `other` must also be a NOT predicate and the negated
// operands must compare equal by value.
bool operator==(const Predicate& other) const override {
  if (other.kind() != Kind::kNot) {
    return false;
  }
  const Predicate* other_operand =
      dynamic_cast<const NotPredicate&>(other).operand();
  return *other_operand == *operand();
}
Kind kind() const override { return Kind::kNot; }
Predicate* operand() const { return operand_; }
@ -188,14 +152,6 @@ class SymbolPredicate : public Predicate {
must_be_true_(must_be_true) {}
string ToString() const override { return tensor_id_.ToString(); }
// Structural equality: `other` must also be a symbol predicate with the same
// must_be_true() flag and the same tensor id.
bool operator==(const Predicate& other) const override {
  if (other.kind() != Kind::kSymbol) {
    return false;
  }
  const auto& other_symbol = dynamic_cast<const SymbolPredicate&>(other);
  if (must_be_true() != other_symbol.must_be_true()) {
    return false;
  }
  return other_symbol.tensor_id() == tensor_id();
}
Kind kind() const override { return Kind::kSymbol; }
// If `must_be_true()` is true this SymbolPredicate represents the proposition
@ -225,16 +181,37 @@ class PredicateFactory {
// Returns the predicate for the conjunction of `operands`; the work is
// delegated to MakeAndOrImpl with is_and set.
Predicate* MakeAndPredicate(gtl::ArraySlice<Predicate*> operands) {
  const bool is_and = true;
  return MakeAndOrImpl(operands, is_and);
}
// Returns the predicate for the disjunction of `operands`; the work is
// delegated to MakeAndOrImpl with is_and cleared.
Predicate* MakeOrPredicate(gtl::ArraySlice<Predicate*> operands) {
  const bool is_and = false;
  return MakeAndOrImpl(operands, is_and);
}
Predicate* MakeNotPredicate(Predicate* pred) {
return Make<NotPredicate>(pred);
SignatureForNot signature = pred;
auto it = interned_not_instances_.find(signature);
if (it == interned_not_instances_.end()) {
std::unique_ptr<Predicate> new_pred = Make<NotPredicate>(pred);
Predicate* new_pred_ptr = new_pred.get();
interned_not_instances_.emplace(signature, std::move(new_pred));
return new_pred_ptr;
} else {
return it->second.get();
}
}
Predicate* MakeSymbolPredicate(TensorId tensor_id, bool must_be_true) {
return Make<SymbolPredicate>(tensor_id, must_be_true);
SignatureForSymbol signature = {tensor_id, must_be_true};
auto it = interned_symbol_instances_.find(signature);
if (it == interned_symbol_instances_.end()) {
std::unique_ptr<Predicate> new_pred =
Make<SymbolPredicate>(tensor_id, must_be_true);
Predicate* new_pred_ptr = new_pred.get();
interned_symbol_instances_.emplace(std::move(signature),
std::move(new_pred));
return new_pred_ptr;
} else {
return it->second.get();
}
}
// The "true" predicate is built as an AND predicate with no operands.
Predicate* MakeTrue() {
  return MakeAndPredicate({});
}
@ -242,29 +219,53 @@ class PredicateFactory {
private:
template <typename PredicateT, typename... Args>
Predicate* Make(Args... args) {
std::unique_ptr<PredicateT> pred(
std::unique_ptr<Predicate> Make(Args&&... args) {
return std::unique_ptr<PredicateT>(
new PredicateT(std::forward<Args>(args)...));
predicate_storage_.emplace_back(std::move(pred));
return predicate_storage_.back().get();
}
Predicate* MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands, bool is_and);
struct PredicatePtrHash {
size_t operator()(const Predicate* pred) const { return pred->hash(); }
};
// Predicate instances are interned, meaning that there is only a single
// instance of a Predicate object with a given content. This makes checking
// for structural equality super-cheap -- we can just compare pointers.
//
// We intern predicates by maintaining maps (the interned_and_or_instances_,
// interned_not_instances_ and interned_symbol_instances_ fields) from the
// content of a Predicate to the only instance of said predicate we allow to
// exist. These maps also double up as storage for the owning pointers to
// predicate instances.
struct PredicatePtrEq {
size_t operator()(const Predicate* a, const Predicate* b) const {
return *a == *b;
using SignatureForAndOr =
std::pair<Predicate::Kind, gtl::ArraySlice<Predicate*>>;
using SignatureForNot = Predicate*;
using SignatureForSymbol = std::pair<SafeTensorId, bool>;
struct HashSignatureForAndOr {
size_t operator()(const SignatureForAndOr& signature) const {
size_t hash = ::tensorflow::hash<Predicate::Kind>()(signature.first);
for (Predicate* p : signature.second) {
hash = Hash64Combine(hash, ::tensorflow::hash<Predicate*>()(p));
}
return hash;
}
};
using PredicateSet =
gtl::FlatSet<Predicate*, PredicatePtrHash, PredicatePtrEq>;
// Hash functor for SignatureForSymbol keys: combines the hash of the tensor id
// (signature.first) with the hash of the boolean flag (signature.second).
struct HashSignatureForSymbol {
  size_t operator()(const SignatureForSymbol& signature) const {
    const auto tensor_id_hash = SafeTensorId::Hasher()(signature.first);
    const auto flag_hash = ::tensorflow::hash<bool>()(signature.second);
    return Hash64Combine(tensor_id_hash, flag_hash);
  }
};
std::vector<std::unique_ptr<Predicate>> predicate_storage_;
gtl::FlatMap<SignatureForAndOr, std::unique_ptr<Predicate>,
HashSignatureForAndOr>
interned_and_or_instances_;
gtl::FlatMap<SignatureForNot, std::unique_ptr<Predicate>>
interned_not_instances_;
gtl::FlatMap<SignatureForSymbol, std::unique_ptr<Predicate>,
HashSignatureForSymbol>
interned_symbol_instances_;
};
// Common code to create AndPredicate or OrPredicate instances.
@ -272,7 +273,7 @@ Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands,
bool is_and) {
Predicate::Kind pred_kind =
is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr;
PredicateSet simplified_ops_set;
gtl::FlatSet<Predicate*> simplified_ops_set;
std::vector<Predicate*> simplified_ops;
for (Predicate* op : operands) {
// Simplify A&A => A and A|A => A.
@ -300,7 +301,7 @@ Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands,
}
// Simplify "A&~A=>False" and "A|~A=>True".
PredicateSet negated_ops;
gtl::FlatSet<Predicate*> negated_ops;
for (Predicate* op : simplified_ops) {
if (op->kind() == Predicate::Kind::kNot) {
negated_ops.insert(dynamic_cast<NotPredicate&>(*op).operand());
@ -317,8 +318,26 @@ Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands,
simplified_ops.begin(), simplified_ops.end(),
[](Predicate* a, Predicate* b) { return a->hash() < b->hash(); });
return is_and ? Make<AndPredicate>(std::move(simplified_ops))
: Make<OrPredicate>(std::move(simplified_ops));
auto it = interned_and_or_instances_.find({pred_kind, simplified_ops});
if (it == interned_and_or_instances_.end()) {
simplified_ops.shrink_to_fit();
// NB! Because we'll use a non-owning reference to simplified_ops in the
// key for interned_and_or_instances_ we need to be careful to std::move()
// it all the way through.
gtl::ArraySlice<Predicate*> operands_slice = simplified_ops;
std::unique_ptr<Predicate> new_pred =
is_and ? Make<AndPredicate>(std::move(simplified_ops))
: Make<OrPredicate>(std::move(simplified_ops));
Predicate* new_pred_ptr = new_pred.get();
CHECK(interned_and_or_instances_
.emplace(SignatureForAndOr(pred_kind, operands_slice),
std::move(new_pred))
.second);
return new_pred_ptr;
} else {
return it->second.get();
}
}
class DeadnessAnalysisImpl : public DeadnessAnalysis {
@ -491,8 +510,9 @@ bool DeadnessAnalysisImpl::HasInputsWithMismatchingDeadness(const Node& node) {
// Today we just compare the predicates for equality (with some
// canonicalization/simplification happening before) but we could be more
// sophisticated here if need be.
if (pred != nullptr && *pred != *it->second) {
// sophisticated here if need be. Comparing pointers is sufficient because
// we intern Predicate instances by their content.
if (pred != nullptr && pred != it->second) {
if (vlog_) {
VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name()
<< ") -> true";