From f330a1c8925a4a33bd0ea451656cfd80772979c3 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Thu, 19 Jul 2018 12:11:20 -0700 Subject: [PATCH] Intern predicate pointers This is a performance optimization. PiperOrigin-RevId: 205280010 --- tensorflow/compiler/jit/deadness_analysis.cc | 158 +++++++++++-------- 1 file changed, 89 insertions(+), 69 deletions(-) diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index b2d119029a4..d81e5fe9008 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -44,10 +44,6 @@ class Predicate { enum class Kind { kAnd, kOr, kNot, kSymbol }; virtual string ToString() const = 0; - virtual bool operator==(const Predicate& other) const = 0; - virtual bool operator!=(const Predicate& other) const { - return !(*this == other); - } int64 hash() const { return hash_; } virtual Kind kind() const = 0; @@ -58,6 +54,8 @@ class Predicate { private: const int64 hash_; + + TF_DISALLOW_COPY_AND_ASSIGN(Predicate); }; int64 HashPredicateSequence(Predicate::Kind kind, @@ -69,19 +67,6 @@ int64 HashPredicateSequence(Predicate::Kind kind, return hash; } -bool PredicateSequenceEqual(gtl::ArraySlice<Predicate*> lhs, - gtl::ArraySlice<Predicate*> rhs) { - if (lhs.size() != rhs.size()) { - return false; - } - for (int64 i = 0; i < lhs.size(); i++) { - if (*lhs[i] != *rhs[i]) { - return false; - } - } - return true; -} - // Represents a logical conjunction of a set of predicates. 
class AndPredicate : public Predicate { public: @@ -102,17 +87,9 @@ class AndPredicate : public Predicate { return strings::StrCat("(", str_util::Join(operands_str, " & "), ")"); } - bool operator==(const Predicate& other) const override { - return other.kind() == Kind::kAnd && - PredicateSequenceEqual( - dynamic_cast<const AndPredicate&>(other).operands(), operands()); - } - Kind kind() const override { return Kind::kAnd; } - const tensorflow::gtl::ArraySlice<Predicate*> operands() const { - return operands_; - } + const gtl::ArraySlice<Predicate*> operands() const { return operands_; } private: std::vector<Predicate*> operands_; @@ -138,16 +115,8 @@ class OrPredicate : public Predicate { return strings::StrCat("(", str_util::Join(operands_str, " | "), ")"); } - bool operator==(const Predicate& other) const override { - return other.kind() == Kind::kOr && - PredicateSequenceEqual( - dynamic_cast<const OrPredicate&>(other).operands(), operands()); - } - Kind kind() const override { return Kind::kOr; } - const tensorflow::gtl::ArraySlice<Predicate*> operands() const { - return operands_; - } + const gtl::ArraySlice<Predicate*> operands() const { return operands_; } private: std::vector<Predicate*> operands_; @@ -164,11 +133,6 @@ class NotPredicate : public Predicate { return strings::StrCat("~", operand()->ToString()); } - bool operator==(const Predicate& other) const override { - return other.kind() == Kind::kNot && - *dynamic_cast<const NotPredicate&>(other).operand() == *operand(); - } - Kind kind() const override { return Kind::kNot; } Predicate* operand() const { return operand_; } @@ -188,14 +152,6 @@ class SymbolPredicate : public Predicate { must_be_true_(must_be_true) {} string ToString() const override { return tensor_id_.ToString(); } - bool operator==(const Predicate& other) const override { - return other.kind() == Kind::kSymbol && - must_be_true() == - dynamic_cast<const SymbolPredicate&>(other).must_be_true() && - dynamic_cast<const SymbolPredicate&>(other).tensor_id() == - tensor_id(); - } - Kind kind() const override { return Kind::kSymbol; } // If `must_be_true()` is true this SymbolPredicate represents the proposition @@ 
-225,16 +181,37 @@ class PredicateFactory { Predicate* MakeAndPredicate(gtl::ArraySlice<Predicate*> operands) { return MakeAndOrImpl(operands, /*is_and=*/true); } + Predicate* MakeOrPredicate(gtl::ArraySlice<Predicate*> operands) { return MakeAndOrImpl(operands, /*is_and=*/false); } Predicate* MakeNotPredicate(Predicate* pred) { - return Make<NotPredicate>(pred); + SignatureForNot signature = pred; + auto it = interned_not_instances_.find(signature); + if (it == interned_not_instances_.end()) { + std::unique_ptr<Predicate> new_pred = Make<NotPredicate>(pred); + Predicate* new_pred_ptr = new_pred.get(); + interned_not_instances_.emplace(signature, std::move(new_pred)); + return new_pred_ptr; + } else { + return it->second.get(); + } } Predicate* MakeSymbolPredicate(TensorId tensor_id, bool must_be_true) { - return Make<SymbolPredicate>(tensor_id, must_be_true); + SignatureForSymbol signature = {tensor_id, must_be_true}; + auto it = interned_symbol_instances_.find(signature); + if (it == interned_symbol_instances_.end()) { + std::unique_ptr<Predicate> new_pred = + Make<SymbolPredicate>(tensor_id, must_be_true); + Predicate* new_pred_ptr = new_pred.get(); + interned_symbol_instances_.emplace(std::move(signature), + std::move(new_pred)); + return new_pred_ptr; + } else { + return it->second.get(); + } } Predicate* MakeTrue() { return MakeAndPredicate({}); } @@ -242,29 +219,53 @@ class PredicateFactory { private: template <typename PredicateT, typename... Args> - Predicate* Make(Args... args) { - std::unique_ptr<PredicateT> pred( + std::unique_ptr<Predicate> Make(Args&&... args) { + return std::unique_ptr<PredicateT>( new PredicateT(std::forward<Args>(args)...)); - predicate_storage_.emplace_back(std::move(pred)); - return predicate_storage_.back().get(); } Predicate* MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands, bool is_and); - struct PredicatePtrHash { - size_t operator()(const Predicate* pred) const { return pred->hash(); } - }; + // Predicate instances are interned, meaning that there is only a single + // instance of a Predicate object with a given content. This makes checking + // for structural equality super-cheap -- we can just compare pointers. 
+ // + // We intern predicates by maintaining a map from the content of a Predicate + // to the only instance of said predicate we allow to exist in the + // interned_and_or_instances_, interned_not_instances_ and + // interned_symbol_instances_ fields. These maps also double up as storage + // for the owning pointers to predicate instances. - struct PredicatePtrEq { - size_t operator()(const Predicate* a, const Predicate* b) const { - return *a == *b; + using SignatureForAndOr = + std::pair<Predicate::Kind, gtl::ArraySlice<Predicate*>>; + using SignatureForNot = Predicate*; + using SignatureForSymbol = std::pair<SafeTensorId, bool>; + + struct HashSignatureForAndOr { + size_t operator()(const SignatureForAndOr& signature) const { + size_t hash = ::tensorflow::hash<Predicate::Kind>()(signature.first); + for (Predicate* p : signature.second) { + hash = Hash64Combine(hash, ::tensorflow::hash<Predicate*>()(p)); + } + return hash; } }; - using PredicateSet = - gtl::FlatSet<Predicate*, PredicatePtrHash, PredicatePtrEq>; + struct HashSignatureForSymbol { + size_t operator()(const SignatureForSymbol& signature) const { + return Hash64Combine(SafeTensorId::Hasher()(signature.first), + ::tensorflow::hash<bool>()(signature.second)); + } + }; - std::vector<std::unique_ptr<Predicate>> predicate_storage_; + gtl::FlatMap<SignatureForAndOr, std::unique_ptr<Predicate>, + HashSignatureForAndOr> + interned_and_or_instances_; + gtl::FlatMap<SignatureForNot, std::unique_ptr<Predicate>> + interned_not_instances_; + gtl::FlatMap<SignatureForSymbol, std::unique_ptr<Predicate>, + HashSignatureForSymbol> + interned_symbol_instances_; }; // Common code to create AndPredicate or OrPredicate instances. @@ -272,7 +273,7 @@ Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands, bool is_and) { Predicate::Kind pred_kind = is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr; - PredicateSet simplified_ops_set; + gtl::FlatSet<Predicate*> simplified_ops_set; std::vector<Predicate*> simplified_ops; for (Predicate* op : operands) { // Simplify A&A => A and A|A => A. @@ -300,7 +301,7 @@ Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands, } // Simplify "A&~A=>False" and "A|~A=>True". 
- PredicateSet negated_ops; + gtl::FlatSet<Predicate*> negated_ops; for (Predicate* op : simplified_ops) { if (op->kind() == Predicate::Kind::kNot) { negated_ops.insert(dynamic_cast<NotPredicate&>(*op).operand()); @@ -317,8 +318,26 @@ Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands, simplified_ops.begin(), simplified_ops.end(), [](Predicate* a, Predicate* b) { return a->hash() < b->hash(); }); - return is_and ? Make<AndPredicate>(std::move(simplified_ops)) - : Make<OrPredicate>(std::move(simplified_ops)); + auto it = interned_and_or_instances_.find({pred_kind, simplified_ops}); + if (it == interned_and_or_instances_.end()) { + simplified_ops.shrink_to_fit(); + // NB! Because we'll use a non-owning reference to simplified_ops in the + // key for interned_and_or_instances_ we need to be careful to std::move() + // it all the way through. + gtl::ArraySlice<Predicate*> operands_slice = simplified_ops; + std::unique_ptr<Predicate> new_pred = + is_and ? Make<AndPredicate>(std::move(simplified_ops)) + : Make<OrPredicate>(std::move(simplified_ops)); + + Predicate* new_pred_ptr = new_pred.get(); + CHECK(interned_and_or_instances_ + .emplace(SignatureForAndOr(pred_kind, operands_slice), + std::move(new_pred)) + .second); + return new_pred_ptr; + } else { + return it->second.get(); + } } class DeadnessAnalysisImpl : public DeadnessAnalysis { @@ -491,8 +510,9 @@ bool DeadnessAnalysisImpl::HasInputsWithMismatchingDeadness(const Node& node) { // Today we just compare the predicates for equality (with some // canonicalization/simplification happening before) but we could be more - // sophisticated here if need be. - if (pred != nullptr && *pred != *it->second) { + // sophisticated here if need be. Comparing pointers is sufficient because + // we intern Predicate instances by their content. + if (pred != nullptr && pred != it->second) { if (vlog_) { VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name() << ") -> true";