Add custom op registration mechanism for the TF dialect.

PiperOrigin-RevId: 318582669
Change-Id: I9be81ec073cc26b43f3d0f78576d546f48e9d1d1
This commit is contained in:
Russell Power 2020-06-26 19:01:15 -07:00 committed by TensorFlower Gardener
parent 15f7e2fca2
commit 4c268edb3c
2 changed files with 34 additions and 0 deletions

View File

@ -4321,6 +4321,10 @@ struct TFInlinerInterface : public DialectInlinerInterface {
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.cc.inc" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.cc.inc"
std::vector<TensorFlowDialect::AdditionalOpFunction>
*TensorFlowDialect::additional_operation_hooks_ =
new std::vector<TensorFlowDialect::AdditionalOpFunction>();
TensorFlowDialect::TensorFlowDialect(MLIRContext *context) TensorFlowDialect::TensorFlowDialect(MLIRContext *context)
: Dialect(/*name=*/"tf", context) { : Dialect(/*name=*/"tf", context) {
addOperations< addOperations<
@ -4338,6 +4342,10 @@ TensorFlowDialect::TensorFlowDialect(MLIRContext *context)
// Support unknown operations because not all TensorFlow operations are // Support unknown operations because not all TensorFlow operations are
// registered. // registered.
allowUnknownOperations(); allowUnknownOperations();
for (auto hook : *TensorFlowDialect::additional_operation_hooks_) {
hook(*this);
}
} }
namespace { namespace {

View File

@ -84,6 +84,32 @@ class TensorFlowDialect : public Dialect {
// value with the desired resultant type. // value with the desired resultant type.
Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type, Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
Location loc) override; Location loc) override;
typedef std::function<void(TensorFlowDialect &dialect)> AdditionalOpFunction;
// Register an op registration hook which is invoked during construction.
//
// A hook may use the public addOperations() method to add additional
// operations to the dialect. Hooks will only apply to subsequent
// instantations of the Dialect/MLIRContext.
static void RegisterAdditionalOperationHook(AdditionalOpFunction fn) {
additional_operation_hooks_->push_back(std::move(fn));
}
// Re-define publicly the protected addOperations() method from the Dialect
// class, usually used in a Dialect constructor. This allows hook
// functions to register operations on the TensorFlow dialect using the
// same interface.
template <typename... Args>
void addOperations() {
(void)std::initializer_list<int>{
0, (addOperation(AbstractOperation::get<Args>(*this)), 0)...};
}
private:
// Hook functions which may add additional operations to the dialect.
// These are invoked at construction time.
static std::vector<AdditionalOpFunction> *additional_operation_hooks_;
}; };
// TODO(b/131258166): TensorFlow's mutex.h defines a `mutex_lock` macro, whose // TODO(b/131258166): TensorFlow's mutex.h defines a `mutex_lock` macro, whose