diff --git a/third_party/mlir/include/mlir/IR/DialectInterface.h b/third_party/mlir/include/mlir/IR/DialectInterface.h index f9151a5cc94..bb9138873e1 100644 --- a/third_party/mlir/include/mlir/IR/DialectInterface.h +++ b/third_party/mlir/include/mlir/IR/DialectInterface.h @@ -18,9 +18,8 @@ #ifndef MLIR_IR_DIALECTINTERFACE_H #define MLIR_IR_DIALECTINTERFACE_H -#include "mlir/IR/Dialect.h" #include "mlir/Support/STLExtras.h" -#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" namespace mlir { class Dialect; @@ -82,6 +81,25 @@ namespace detail { /// This class is the base class for a collection of instances for a specific /// interface kind. class DialectInterfaceCollectionBase { + /// DenseMap info for dialect interfaces that allows lookup by the dialect. + struct InterfaceKeyInfo : public DenseMapInfo { + using DenseMapInfo::isEqual; + + static unsigned getHashValue(Dialect *key) { return llvm::hash_value(key); } + static unsigned getHashValue(const DialectInterface *key) { + return getHashValue(key->getDialect()); + } + + static bool isEqual(Dialect *lhs, const DialectInterface *rhs) { + if (rhs == getEmptyKey() || rhs == getTombstoneKey()) + return false; + return lhs == rhs->getDialect(); + } + }; + + /// A set of registered dialect interface instances. + using InterfaceSetT = DenseSet; + public: DialectInterfaceCollectionBase(MLIRContext *ctx, ClassID *interfaceKind); virtual ~DialectInterfaceCollectionBase(); @@ -93,12 +111,13 @@ protected: /// Get the interface for the given dialect. const DialectInterface *getInterfaceFor(Dialect *dialect) const { - return interfaces.lookup(dialect); + auto it = interfaces.find_as(dialect); + return it == interfaces.end() ? nullptr : *it; } private: - /// A map of registered dialect interface instances. - DenseMap interfaces; + /// A set of registered dialect interface instances. + InterfaceSetT interfaces; }; } // namespace detail diff --git a/third_party/mlir/lib/IR/Dialect.cpp b/third_party/mlir/lib/IR/Dialect.cpp index 8af99e536ae..470940a6326 100644 --- a/third_party/mlir/lib/IR/Dialect.cpp +++ b/third_party/mlir/lib/IR/Dialect.cpp @@ -137,7 +137,7 @@ DialectInterfaceCollectionBase::DialectInterfaceCollectionBase( MLIRContext *ctx, ClassID *interfaceKind) { for (auto *dialect : ctx->getRegisteredDialects()) if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) - interfaces.try_emplace(dialect, interface); + interfaces.insert(interface); } DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() {} @@ -146,5 +146,5 @@ DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() {} /// is not registered. const DialectInterface * DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const { - return interfaces.lookup(op->getDialect()); + return getInterfaceFor(op->getDialect()); }