From e9bda5601f92c4b30c9eb30e920bc50cbe4e8243 Mon Sep 17 00:00:00 2001 From: River Riddle <riverriddle@google.com> Date: Tue, 20 Aug 2019 18:49:08 -0700 Subject: [PATCH] NFC: Use a DenseSet instead of a DenseMap for DialectInterfaceCollection. The interfaces are looked up by dialect, which can always be retrieved from an interface instance. PiperOrigin-RevId: 264516023 --- .../mlir/include/mlir/IR/DialectInterface.h | 29 +++++++++++++++---- third_party/mlir/lib/IR/Dialect.cpp | 4 +-- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/third_party/mlir/include/mlir/IR/DialectInterface.h b/third_party/mlir/include/mlir/IR/DialectInterface.h index f9151a5cc94..bb9138873e1 100644 --- a/third_party/mlir/include/mlir/IR/DialectInterface.h +++ b/third_party/mlir/include/mlir/IR/DialectInterface.h @@ -18,9 +18,8 @@ #ifndef MLIR_IR_DIALECTINTERFACE_H #define MLIR_IR_DIALECTINTERFACE_H -#include "mlir/IR/Dialect.h" #include "mlir/Support/STLExtras.h" -#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" namespace mlir { class Dialect; @@ -82,6 +81,25 @@ namespace detail { /// This class is the base class for a collection of instances for a specific /// interface kind. class DialectInterfaceCollectionBase { + /// DenseMap info for dialect interfaces that allows lookup by the dialect. + struct InterfaceKeyInfo : public DenseMapInfo<const DialectInterface *> { + using DenseMapInfo<const DialectInterface *>::isEqual; + + static unsigned getHashValue(Dialect *key) { return llvm::hash_value(key); } + static unsigned getHashValue(const DialectInterface *key) { + return getHashValue(key->getDialect()); + } + + static bool isEqual(Dialect *lhs, const DialectInterface *rhs) { + if (rhs == getEmptyKey() || rhs == getTombstoneKey()) + return false; + return lhs == rhs->getDialect(); + } + }; + + /// A set of registered dialect interface instances. + using InterfaceSetT = DenseSet<const DialectInterface *, InterfaceKeyInfo>; + public: DialectInterfaceCollectionBase(MLIRContext *ctx, ClassID *interfaceKind); virtual ~DialectInterfaceCollectionBase(); @@ -93,12 +111,13 @@ protected: /// Get the interface for the given dialect. const DialectInterface *getInterfaceFor(Dialect *dialect) const { - return interfaces.lookup(dialect); + auto it = interfaces.find_as(dialect); + return it == interfaces.end() ? nullptr : *it; } private: - /// A map of registered dialect interface instances. - DenseMap<Dialect *, const DialectInterface *> interfaces; + /// A set of registered dialect interface instances. + InterfaceSetT interfaces; }; } // namespace detail diff --git a/third_party/mlir/lib/IR/Dialect.cpp b/third_party/mlir/lib/IR/Dialect.cpp index 8af99e536ae..470940a6326 100644 --- a/third_party/mlir/lib/IR/Dialect.cpp +++ b/third_party/mlir/lib/IR/Dialect.cpp @@ -137,7 +137,7 @@ DialectInterfaceCollectionBase::DialectInterfaceCollectionBase( MLIRContext *ctx, ClassID *interfaceKind) { for (auto *dialect : ctx->getRegisteredDialects()) if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) - interfaces.try_emplace(dialect, interface); + interfaces.insert(interface); } DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() {} @@ -146,5 +146,5 @@ DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() {} /// is not registered. const DialectInterface * DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const { - return interfaces.lookup(op->getDialect()); + return getInterfaceFor(op->getDialect()); }