Add support for Dialect interfaces.
Dialect interfaces are virtual apis registered to a specific dialect instance. Dialect interfaces are generally useful for transformation passes, or analyses, that want to opaquely operate on operations within a given dialect. These interfaces generally involve wide coverage over the entire dialect. A dialect interface can be defined by inheriting from the CRTP base class DialectInterfaceBase::Base. This class provides the necessary utilities for registering an interface with the dialect so that it can be looked up later. Dialects overriding an interface may register an instance via 'Dialect::addInterfaces'. This API works very similarly to the respective addOperations/addTypes/etc. This will allow for a transformation/utility to later query the interface from an opaque dialect instance via 'getInterface<T>'. A utility class 'DialectInterfaceCollection' is also provided that will collect all of the dialects that implement a specific interface within a given module. This allows for simplifying the API of interface lookups. PiperOrigin-RevId: 263489015
This commit is contained in:
parent
dcdca11bcb
commit
3f1deea3cb
third_party/mlir
1
third_party/mlir/BUILD
vendored
1
third_party/mlir/BUILD
vendored
@ -77,6 +77,7 @@ cc_library(
|
||||
"include/mlir/IR/Diagnostics.h",
|
||||
"include/mlir/IR/Dialect.h",
|
||||
"include/mlir/IR/DialectHooks.h",
|
||||
"include/mlir/IR/DialectInterface.h",
|
||||
"include/mlir/IR/DialectSymbolRegistry.def",
|
||||
"include/mlir/IR/Function.h",
|
||||
"include/mlir/IR/FunctionSupport.h",
|
||||
|
31
third_party/mlir/include/mlir/IR/Dialect.h
vendored
31
third_party/mlir/include/mlir/IR/Dialect.h
vendored
@ -25,6 +25,7 @@
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
|
||||
namespace mlir {
|
||||
class DialectInterface;
|
||||
class OpBuilder;
|
||||
class Type;
|
||||
|
||||
@ -167,6 +168,21 @@ public:
|
||||
return success();
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Interfaces
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Lookup an interface for the given ID if one is registered, otherwise
|
||||
/// nullptr.
|
||||
const DialectInterface *getRegisteredInterface(ClassID *interfaceID) {
|
||||
auto it = registeredInterfaces.find(interfaceID);
|
||||
return it != registeredInterfaces.end() ? it->getSecond().get() : nullptr;
|
||||
}
|
||||
template <typename InterfaceT> const InterfaceT *getRegisteredInterface() {
|
||||
return static_cast<const InterfaceT *>(
|
||||
getRegisteredInterface(InterfaceT::getInterfaceID()));
|
||||
}
|
||||
|
||||
protected:
|
||||
/// The constructor takes a unique namespace for this dialect as well as the
|
||||
/// context to bind to.
|
||||
@ -237,6 +253,18 @@ protected:
|
||||
/// Enable support for unregistered types.
|
||||
void allowUnknownTypes(bool allow = true) { unknownTypesAllowed = allow; }
|
||||
|
||||
/// Register a dialect interface with this dialect instance.
|
||||
void addInterface(std::unique_ptr<DialectInterface> interface);
|
||||
|
||||
/// Register a set of dialect interfaces with this dialect instance.
|
||||
template <typename T, typename T2, typename... Tys> void addInterfaces() {
|
||||
addInterfaces<T>();
|
||||
addInterfaces<T2, Tys...>();
|
||||
}
|
||||
template <typename T> void addInterfaces() {
|
||||
addInterface(llvm::make_unique<T>(this));
|
||||
}
|
||||
|
||||
private:
|
||||
// Register a symbol(e.g. type) with its given unique class identifier.
|
||||
void addSymbol(const ClassID *const classID);
|
||||
@ -263,6 +291,9 @@ private:
|
||||
/// types prefixed with the dialect namespace but not registered with addType.
|
||||
/// These types are represented with OpaqueType.
|
||||
bool unknownTypesAllowed = false;
|
||||
|
||||
/// A collection of registered dialect interfaces.
|
||||
DenseMap<ClassID *, std::unique_ptr<DialectInterface>> registeredInterfaces;
|
||||
};
|
||||
|
||||
using DialectAllocatorFunction = std::function<void(MLIRContext *)>;
|
||||
|
129
third_party/mlir/include/mlir/IR/DialectInterface.h
vendored
Normal file
129
third_party/mlir/include/mlir/IR/DialectInterface.h
vendored
Normal file
@ -0,0 +1,129 @@
|
||||
//===- DialectInterface.h - IR Dialect Interfaces ---------------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#ifndef MLIR_IR_DIALECTINTERFACE_H
|
||||
#define MLIR_IR_DIALECTINTERFACE_H
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
||||
namespace mlir {
|
||||
class Dialect;
|
||||
class MLIRContext;
|
||||
class Operation;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DialectInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
namespace detail {
|
||||
/// The base class used for all derived interface types. This class provides
|
||||
/// utilities necessary for registration.
|
||||
template <typename ConcreteType, typename BaseT>
|
||||
class DialectInterfaceBase : public BaseT {
|
||||
public:
|
||||
using Base = DialectInterfaceBase<ConcreteType, BaseT>;
|
||||
|
||||
/// Get a unique id for the derived interface type.
|
||||
static ClassID *getInterfaceID() { return ClassID::getID<ConcreteType>(); }
|
||||
|
||||
protected:
|
||||
DialectInterfaceBase(Dialect *dialect) : BaseT(dialect, getInterfaceID()) {}
|
||||
};
|
||||
} // end namespace detail
|
||||
|
||||
/// This class represents an interface overridden for a single dialect.
|
||||
class DialectInterface {
|
||||
public:
|
||||
virtual ~DialectInterface();
|
||||
|
||||
/// The base class used for all derived interface types. This class provides
|
||||
/// utilities necessary for registration.
|
||||
template <typename ConcreteType>
|
||||
using Base = detail::DialectInterfaceBase<ConcreteType, DialectInterface>;
|
||||
|
||||
/// Return the dialect that this interface represents.
|
||||
Dialect *getDialect() const { return dialect; }
|
||||
|
||||
/// Return the derived interface id.
|
||||
ClassID *getID() const { return interfaceID; }
|
||||
|
||||
protected:
|
||||
DialectInterface(Dialect *dialect, ClassID *id)
|
||||
: dialect(dialect), interfaceID(id) {}
|
||||
|
||||
private:
|
||||
/// The dialect that represents this interface.
|
||||
Dialect *dialect;
|
||||
|
||||
/// The unique identifier for the derived interface type.
|
||||
ClassID *interfaceID;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DialectInterfaceCollection
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace detail {
|
||||
/// This class is the base class for a collection of instances for a specific
|
||||
/// interface kind.
|
||||
class DialectInterfaceCollectionBase {
|
||||
public:
|
||||
DialectInterfaceCollectionBase(MLIRContext *ctx, ClassID *interfaceKind);
|
||||
virtual ~DialectInterfaceCollectionBase();
|
||||
|
||||
protected:
|
||||
/// Get the interface for the dialect of given operation, or null if one
|
||||
/// is not registered.
|
||||
const DialectInterface *getInterfaceFor(Operation *op) const;
|
||||
|
||||
/// Get the interface for the given dialect.
|
||||
const DialectInterface *getInterfaceFor(Dialect *dialect) const {
|
||||
return interfaces.lookup(dialect);
|
||||
}
|
||||
|
||||
private:
|
||||
/// A map of registered dialect interface instances.
|
||||
DenseMap<Dialect *, const DialectInterface *> interfaces;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
/// A collection of dialect interfaces within a context, for a given concrete
|
||||
/// interface type.
|
||||
template <typename InterfaceType>
|
||||
class DialectInterfaceCollection
|
||||
: public detail::DialectInterfaceCollectionBase {
|
||||
public:
|
||||
using Base = DialectInterfaceCollection<InterfaceType>;
|
||||
|
||||
/// Collect the registered dialect interfaces within the provided context.
|
||||
DialectInterfaceCollection(MLIRContext *ctx)
|
||||
: detail::DialectInterfaceCollectionBase(
|
||||
ctx, InterfaceType::getInterfaceID()) {}
|
||||
|
||||
/// Get the interface for a given object, or null if one is not registered.
|
||||
/// The object may be a dialect or an operation instance.
|
||||
template <typename Object>
|
||||
const InterfaceType *getInterfaceFor(Object *obj) const {
|
||||
return static_cast<const InterfaceType *>(
|
||||
detail::DialectInterfaceCollectionBase::getInterfaceFor(obj));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
43
third_party/mlir/lib/IR/Dialect.cpp
vendored
43
third_party/mlir/lib/IR/Dialect.cpp
vendored
@ -18,12 +18,19 @@
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/DialectHooks.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/DialectInterface.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/ManagedStatic.h"
|
||||
#include "llvm/Support/Regex.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace detail;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dialect Registration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Registry for all dialect allocation functions.
|
||||
static llvm::ManagedStatic<SmallVector<DialectAllocatorFunction, 8>>
|
||||
@ -61,6 +68,10 @@ void mlir::registerAllDialects(MLIRContext *context) {
|
||||
}
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Dialect::Dialect(StringRef name, MLIRContext *context)
|
||||
: name(name), context(context) {
|
||||
assert(isValidNamespace(name) && "invalid dialect namespace");
|
||||
@ -107,3 +118,33 @@ bool Dialect::isValidNamespace(StringRef str) {
|
||||
llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$");
|
||||
return dialectNameRegex.match(str);
|
||||
}
|
||||
|
||||
/// Register a set of dialect interfaces with this dialect instance.
|
||||
void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
|
||||
auto it = registeredInterfaces.try_emplace(interface->getID(),
|
||||
std::move(interface));
|
||||
(void)it;
|
||||
assert(it.second && "interface kind has already been registered");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dialect Interface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
DialectInterface::~DialectInterface() {}
|
||||
|
||||
DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
|
||||
MLIRContext *ctx, ClassID *interfaceKind) {
|
||||
for (auto *dialect : ctx->getRegisteredDialects())
|
||||
if (auto *interface = dialect->getRegisteredInterface(interfaceKind))
|
||||
interfaces.try_emplace(dialect, interface);
|
||||
}
|
||||
|
||||
DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() {}
|
||||
|
||||
/// Get the interface for the dialect of given operation, or null if one
|
||||
/// is not registered.
|
||||
const DialectInterface *
|
||||
DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
|
||||
return interfaces.lookup(op->getDialect());
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user