From e9bda5601f92c4b30c9eb30e920bc50cbe4e8243 Mon Sep 17 00:00:00 2001
From: River Riddle <riverriddle@google.com>
Date: Tue, 20 Aug 2019 18:49:08 -0700
Subject: [PATCH] NFC: Use a DenseSet instead of a DenseMap for
 DialectInterfaceCollection.

The interfaces are looked up by dialect, which can always be retrieved from an interface instance.

PiperOrigin-RevId: 264516023
---
 .../mlir/include/mlir/IR/DialectInterface.h   | 29 +++++++++++++++----
 third_party/mlir/lib/IR/Dialect.cpp           |  4 +--
 2 files changed, 26 insertions(+), 7 deletions(-)

diff --git a/third_party/mlir/include/mlir/IR/DialectInterface.h b/third_party/mlir/include/mlir/IR/DialectInterface.h
index f9151a5cc94..bb9138873e1 100644
--- a/third_party/mlir/include/mlir/IR/DialectInterface.h
+++ b/third_party/mlir/include/mlir/IR/DialectInterface.h
@@ -18,9 +18,8 @@
 #ifndef MLIR_IR_DIALECTINTERFACE_H
 #define MLIR_IR_DIALECTINTERFACE_H
 
-#include "mlir/IR/Dialect.h"
 #include "mlir/Support/STLExtras.h"
-#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
 
 namespace mlir {
 class Dialect;
@@ -82,6 +81,25 @@ namespace detail {
 /// This class is the base class for a collection of instances for a specific
 /// interface kind.
 class DialectInterfaceCollectionBase {
+  /// DenseMap info for dialect interfaces that allows lookup by the dialect.
+  struct InterfaceKeyInfo : public DenseMapInfo<const DialectInterface *> {
+    using DenseMapInfo<const DialectInterface *>::isEqual;
+
+    static unsigned getHashValue(Dialect *key) { return llvm::hash_value(key); }
+    static unsigned getHashValue(const DialectInterface *key) {
+      return getHashValue(key->getDialect());
+    }
+
+    static bool isEqual(Dialect *lhs, const DialectInterface *rhs) {
+      if (rhs == getEmptyKey() || rhs == getTombstoneKey())
+        return false;
+      return lhs == rhs->getDialect();
+    }
+  };
+
+  /// A set of registered dialect interface instances.
+  using InterfaceSetT = DenseSet<const DialectInterface *, InterfaceKeyInfo>;
+
 public:
   DialectInterfaceCollectionBase(MLIRContext *ctx, ClassID *interfaceKind);
   virtual ~DialectInterfaceCollectionBase();
@@ -93,12 +111,13 @@ protected:
 
   /// Get the interface for the given dialect.
   const DialectInterface *getInterfaceFor(Dialect *dialect) const {
-    return interfaces.lookup(dialect);
+    auto it = interfaces.find_as(dialect);
+    return it == interfaces.end() ? nullptr : *it;
   }
 
 private:
-  /// A map of registered dialect interface instances.
-  DenseMap<Dialect *, const DialectInterface *> interfaces;
+  /// A set of registered dialect interface instances.
+  InterfaceSetT interfaces;
 };
 } // namespace detail
 
diff --git a/third_party/mlir/lib/IR/Dialect.cpp b/third_party/mlir/lib/IR/Dialect.cpp
index 8af99e536ae..470940a6326 100644
--- a/third_party/mlir/lib/IR/Dialect.cpp
+++ b/third_party/mlir/lib/IR/Dialect.cpp
@@ -137,7 +137,7 @@ DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
     MLIRContext *ctx, ClassID *interfaceKind) {
   for (auto *dialect : ctx->getRegisteredDialects())
     if (auto *interface = dialect->getRegisteredInterface(interfaceKind))
-      interfaces.try_emplace(dialect, interface);
+      interfaces.insert(interface);
 }
 
 DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() {}
@@ -146,5 +146,5 @@ DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() {}
 /// is not registered.
 const DialectInterface *
 DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
-  return interfaces.lookup(op->getDialect());
+  return getInterfaceFor(op->getDialect());
 }