From 4c268edb3ce8894302a1357b2e78cd0f1062ff89 Mon Sep 17 00:00:00 2001 From: Russell Power Date: Fri, 26 Jun 2020 19:01:15 -0700 Subject: [PATCH] Add custom op registration mechanism for the TF dialect. PiperOrigin-RevId: 318582669 Change-Id: I9be81ec073cc26b43f3d0f78576d546f48e9d1d1 --- .../compiler/mlir/tensorflow/ir/tf_ops.cc | 8 ++++++ .../compiler/mlir/tensorflow/ir/tf_ops.h | 26 +++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 97903170fba..eb831ab9d1a 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -4321,6 +4321,10 @@ struct TFInlinerInterface : public DialectInlinerInterface { #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.cc.inc" +std::vector + *TensorFlowDialect::additional_operation_hooks_ = + new std::vector(); + TensorFlowDialect::TensorFlowDialect(MLIRContext *context) : Dialect(/*name=*/"tf", context) { addOperations< @@ -4338,6 +4342,10 @@ TensorFlowDialect::TensorFlowDialect(MLIRContext *context) // Support unknown operations because not all TensorFlow operations are // registered. allowUnknownOperations(); + + for (auto hook : *TensorFlowDialect::additional_operation_hooks_) { + hook(*this); + } } namespace { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h index dbc14485cdb..f37b71575f6 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h @@ -84,6 +84,32 @@ class TensorFlowDialect : public Dialect { // value with the desired resultant type. Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) override; + + typedef std::function AdditionalOpFunction; + + // Register an op registration hook which is invoked during construction. + // + // A hook may use the public addOperations() method to add additional + // operations to the dialect. Hooks will only apply to subsequent + // instantations of the Dialect/MLIRContext. + static void RegisterAdditionalOperationHook(AdditionalOpFunction fn) { + additional_operation_hooks_->push_back(std::move(fn)); + } + + // Re-define publicly the protected addOperations() method from the Dialect + // class, usually used in a Dialect constructor. This allows hook + // functions to register operations on the TensorFlow dialect using the + // same interface. + template + void addOperations() { + (void)std::initializer_list{ + 0, (addOperation(AbstractOperation::get(*this)), 0)...}; + } + + private: + // Hook functions which may add additional operations to the dialect. + // These are invoked at construction time. + static std::vector *additional_operation_hooks_; }; // TODO(b/131258166): TensorFlow's mutex.h defines a `mutex_lock` macro, whose