Add custom op registration mechanism for the TF dialect.
PiperOrigin-RevId: 318582669 Change-Id: I9be81ec073cc26b43f3d0f78576d546f48e9d1d1
This commit is contained in:
parent
15f7e2fca2
commit
4c268edb3c
@ -4321,6 +4321,10 @@ struct TFInlinerInterface : public DialectInlinerInterface {
|
||||
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.cc.inc"
|
||||
|
||||
std::vector<TensorFlowDialect::AdditionalOpFunction>
|
||||
*TensorFlowDialect::additional_operation_hooks_ =
|
||||
new std::vector<TensorFlowDialect::AdditionalOpFunction>();
|
||||
|
||||
TensorFlowDialect::TensorFlowDialect(MLIRContext *context)
|
||||
: Dialect(/*name=*/"tf", context) {
|
||||
addOperations<
|
||||
@ -4338,6 +4342,10 @@ TensorFlowDialect::TensorFlowDialect(MLIRContext *context)
|
||||
// Support unknown operations because not all TensorFlow operations are
|
||||
// registered.
|
||||
allowUnknownOperations();
|
||||
|
||||
for (auto hook : *TensorFlowDialect::additional_operation_hooks_) {
|
||||
hook(*this);
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -84,6 +84,32 @@ class TensorFlowDialect : public Dialect {
|
||||
// value with the desired resultant type.
|
||||
Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
|
||||
Location loc) override;
|
||||
|
||||
typedef std::function<void(TensorFlowDialect &dialect)> AdditionalOpFunction;
|
||||
|
||||
// Register an op registration hook which is invoked during construction.
|
||||
//
|
||||
// A hook may use the public addOperations() method to add additional
|
||||
// operations to the dialect. Hooks will only apply to subsequent
|
||||
// instantations of the Dialect/MLIRContext.
|
||||
static void RegisterAdditionalOperationHook(AdditionalOpFunction fn) {
|
||||
additional_operation_hooks_->push_back(std::move(fn));
|
||||
}
|
||||
|
||||
// Re-define publicly the protected addOperations() method from the Dialect
|
||||
// class, usually used in a Dialect constructor. This allows hook
|
||||
// functions to register operations on the TensorFlow dialect using the
|
||||
// same interface.
|
||||
template <typename... Args>
|
||||
void addOperations() {
|
||||
(void)std::initializer_list<int>{
|
||||
0, (addOperation(AbstractOperation::get<Args>(*this)), 0)...};
|
||||
}
|
||||
|
||||
private:
|
||||
// Hook functions which may add additional operations to the dialect.
|
||||
// These are invoked at construction time.
|
||||
static std::vector<AdditionalOpFunction> *additional_operation_hooks_;
|
||||
};
|
||||
|
||||
// TODO(b/131258166): TensorFlow's mutex.h defines a `mutex_lock` macro, whose
|
||||
|
Loading…
Reference in New Issue
Block a user