diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD index c3ef6b403e0..dc8e15f379b 100644 --- a/third_party/mlir/BUILD +++ b/third_party/mlir/BUILD @@ -77,6 +77,7 @@ cc_library( "include/mlir/IR/Diagnostics.h", "include/mlir/IR/Dialect.h", "include/mlir/IR/DialectHooks.h", + "include/mlir/IR/DialectInterface.h", "include/mlir/IR/DialectSymbolRegistry.def", "include/mlir/IR/Function.h", "include/mlir/IR/FunctionSupport.h", diff --git a/third_party/mlir/include/mlir/IR/Dialect.h b/third_party/mlir/include/mlir/IR/Dialect.h index eef77112a54..683701f3bc4 100644 --- a/third_party/mlir/include/mlir/IR/Dialect.h +++ b/third_party/mlir/include/mlir/IR/Dialect.h @@ -25,6 +25,7 @@ #include "mlir/IR/OperationSupport.h" namespace mlir { +class DialectInterface; class OpBuilder; class Type; @@ -167,6 +168,21 @@ public: return success(); } + //===--------------------------------------------------------------------===// + // Interfaces + //===--------------------------------------------------------------------===// + + /// Lookup an interface for the given ID if one is registered, otherwise + /// nullptr. + const DialectInterface *getRegisteredInterface(ClassID *interfaceID) { + auto it = registeredInterfaces.find(interfaceID); + return it != registeredInterfaces.end() ? it->getSecond().get() : nullptr; + } + template <typename InterfaceT> const InterfaceT *getRegisteredInterface() { + return static_cast<const InterfaceT *>( + getRegisteredInterface(InterfaceT::getInterfaceID())); + } + protected: /// The constructor takes a unique namespace for this dialect as well as the /// context to bind to. @@ -237,6 +253,18 @@ protected: /// Enable support for unregistered types. void allowUnknownTypes(bool allow = true) { unknownTypesAllowed = allow; } + /// Register a dialect interface with this dialect instance. + void addInterface(std::unique_ptr<DialectInterface> interface); + + /// Register a set of dialect interfaces with this dialect instance. + template <typename T, typename T2, typename... Tys> void addInterfaces() { + addInterfaces<T>(); + addInterfaces<T2, Tys...>(); + } + template <typename T> void addInterfaces() { + addInterface(llvm::make_unique<T>(this)); + } + private: // Register a symbol(e.g. type) with its given unique class identifier. void addSymbol(const ClassID *const classID); @@ -263,6 +291,9 @@ private: /// types prefixed with the dialect namespace but not registered with addType. /// These types are represented with OpaqueType. bool unknownTypesAllowed = false; + + /// A collection of registered dialect interfaces. + DenseMap<ClassID *, std::unique_ptr<DialectInterface>> registeredInterfaces; }; using DialectAllocatorFunction = std::function<void(MLIRContext *)>; diff --git a/third_party/mlir/include/mlir/IR/DialectInterface.h b/third_party/mlir/include/mlir/IR/DialectInterface.h new file mode 100644 index 00000000000..f9151a5cc94 --- /dev/null +++ b/third_party/mlir/include/mlir/IR/DialectInterface.h @@ -0,0 +1,129 @@ +//===- DialectInterface.h - IR Dialect Interfaces ---------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef MLIR_IR_DIALECTINTERFACE_H +#define MLIR_IR_DIALECTINTERFACE_H + +#include "mlir/IR/Dialect.h" +#include "mlir/Support/STLExtras.h" +#include "llvm/ADT/DenseMap.h" + +namespace mlir { +class Dialect; +class MLIRContext; +class Operation; + +//===----------------------------------------------------------------------===// +// DialectInterface +//===----------------------------------------------------------------------===// +namespace detail { +/// The base class used for all derived interface types. This class provides +/// utilities necessary for registration. +template <typename ConcreteType, typename BaseT> +class DialectInterfaceBase : public BaseT { +public: + using Base = DialectInterfaceBase<ConcreteType, BaseT>; + + /// Get a unique id for the derived interface type. + static ClassID *getInterfaceID() { return ClassID::getID<ConcreteType>(); } + +protected: + DialectInterfaceBase(Dialect *dialect) : BaseT(dialect, getInterfaceID()) {} +}; +} // end namespace detail + +/// This class represents an interface overridden for a single dialect. +class DialectInterface { +public: + virtual ~DialectInterface(); + + /// The base class used for all derived interface types. This class provides + /// utilities necessary for registration. + template <typename ConcreteType> + using Base = detail::DialectInterfaceBase<ConcreteType, DialectInterface>; + + /// Return the dialect that this interface represents. + Dialect *getDialect() const { return dialect; } + + /// Return the derived interface id. + ClassID *getID() const { return interfaceID; } + +protected: + DialectInterface(Dialect *dialect, ClassID *id) + : dialect(dialect), interfaceID(id) {} + +private: + /// The dialect that represents this interface. + Dialect *dialect; + + /// The unique identifier for the derived interface type. + ClassID *interfaceID; +}; + +//===----------------------------------------------------------------------===// +// DialectInterfaceCollection +//===----------------------------------------------------------------------===// + +namespace detail { +/// This class is the base class for a collection of instances for a specific +/// interface kind. +class DialectInterfaceCollectionBase { +public: + DialectInterfaceCollectionBase(MLIRContext *ctx, ClassID *interfaceKind); + virtual ~DialectInterfaceCollectionBase(); + +protected: + /// Get the interface for the dialect of given operation, or null if one + /// is not registered. + const DialectInterface *getInterfaceFor(Operation *op) const; + + /// Get the interface for the given dialect. + const DialectInterface *getInterfaceFor(Dialect *dialect) const { + return interfaces.lookup(dialect); + } + +private: + /// A map of registered dialect interface instances. + DenseMap<Dialect *, const DialectInterface *> interfaces; +}; +} // namespace detail + +/// A collection of dialect interfaces within a context, for a given concrete +/// interface type. +template <typename InterfaceType> +class DialectInterfaceCollection + : public detail::DialectInterfaceCollectionBase { +public: + using Base = DialectInterfaceCollection<InterfaceType>; + + /// Collect the registered dialect interfaces within the provided context. + DialectInterfaceCollection(MLIRContext *ctx) + : detail::DialectInterfaceCollectionBase( + ctx, InterfaceType::getInterfaceID()) {} + + /// Get the interface for a given object, or null if one is not registered. + /// The object may be a dialect or an operation instance. + template <typename Object> + const InterfaceType *getInterfaceFor(Object *obj) const { + return static_cast<const InterfaceType *>( + detail::DialectInterfaceCollectionBase::getInterfaceFor(obj)); + } +}; + +} // namespace mlir + +#endif diff --git a/third_party/mlir/lib/IR/Dialect.cpp b/third_party/mlir/lib/IR/Dialect.cpp index 1170e06b5a9..8af99e536ae 100644 --- a/third_party/mlir/lib/IR/Dialect.cpp +++ b/third_party/mlir/lib/IR/Dialect.cpp @@ -18,12 +18,19 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectHooks.h" -#include "mlir/IR/Function.h" +#include "mlir/IR/DialectInterface.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/ManagedStatic.h" #include "llvm/Support/Regex.h" + using namespace mlir; +using namespace detail; + +//===----------------------------------------------------------------------===// +// Dialect Registration +//===----------------------------------------------------------------------===// // Registry for all dialect allocation functions. static llvm::ManagedStatic<SmallVector<DialectAllocatorFunction, 8>> @@ -61,6 +68,10 @@ void mlir::registerAllDialects(MLIRContext *context) { } } +//===----------------------------------------------------------------------===// +// Dialect +//===----------------------------------------------------------------------===// + Dialect::Dialect(StringRef name, MLIRContext *context) : name(name), context(context) { assert(isValidNamespace(name) && "invalid dialect namespace"); @@ -107,3 +118,33 @@ bool Dialect::isValidNamespace(StringRef str) { llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$"); return dialectNameRegex.match(str); } + +/// Register a set of dialect interfaces with this dialect instance. +void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) { + auto it = registeredInterfaces.try_emplace(interface->getID(), + std::move(interface)); + (void)it; + assert(it.second && "interface kind has already been registered"); +} + +//===----------------------------------------------------------------------===// +// Dialect Interface +//===----------------------------------------------------------------------===// + +DialectInterface::~DialectInterface() {} + +DialectInterfaceCollectionBase::DialectInterfaceCollectionBase( + MLIRContext *ctx, ClassID *interfaceKind) { + for (auto *dialect : ctx->getRegisteredDialects()) + if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) + interfaces.try_emplace(dialect, interface); +} + +DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() {} + +/// Get the interface for the dialect of given operation, or null if one +/// is not registered. +const DialectInterface * +DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const { + return interfaces.lookup(op->getDialect()); +}