Add custom op registration mechanism for the TF dialect.
PiperOrigin-RevId: 318582669 Change-Id: I9be81ec073cc26b43f3d0f78576d546f48e9d1d1
This commit is contained in:
parent
15f7e2fca2
commit
4c268edb3c
@ -4321,6 +4321,10 @@ struct TFInlinerInterface : public DialectInlinerInterface {
|
|||||||
|
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.cc.inc"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.cc.inc"
|
||||||
|
|
||||||
|
std::vector<TensorFlowDialect::AdditionalOpFunction>
|
||||||
|
*TensorFlowDialect::additional_operation_hooks_ =
|
||||||
|
new std::vector<TensorFlowDialect::AdditionalOpFunction>();
|
||||||
|
|
||||||
TensorFlowDialect::TensorFlowDialect(MLIRContext *context)
|
TensorFlowDialect::TensorFlowDialect(MLIRContext *context)
|
||||||
: Dialect(/*name=*/"tf", context) {
|
: Dialect(/*name=*/"tf", context) {
|
||||||
addOperations<
|
addOperations<
|
||||||
@ -4338,6 +4342,10 @@ TensorFlowDialect::TensorFlowDialect(MLIRContext *context)
|
|||||||
// Support unknown operations because not all TensorFlow operations are
|
// Support unknown operations because not all TensorFlow operations are
|
||||||
// registered.
|
// registered.
|
||||||
allowUnknownOperations();
|
allowUnknownOperations();
|
||||||
|
|
||||||
|
for (auto hook : *TensorFlowDialect::additional_operation_hooks_) {
|
||||||
|
hook(*this);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -84,6 +84,32 @@ class TensorFlowDialect : public Dialect {
|
|||||||
// value with the desired resultant type.
|
// value with the desired resultant type.
|
||||||
Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
|
Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
|
||||||
Location loc) override;
|
Location loc) override;
|
||||||
|
|
||||||
|
typedef std::function<void(TensorFlowDialect &dialect)> AdditionalOpFunction;
|
||||||
|
|
||||||
|
// Register an op registration hook which is invoked during construction.
|
||||||
|
//
|
||||||
|
// A hook may use the public addOperations() method to add additional
|
||||||
|
// operations to the dialect. Hooks will only apply to subsequent
|
||||||
|
// instantations of the Dialect/MLIRContext.
|
||||||
|
static void RegisterAdditionalOperationHook(AdditionalOpFunction fn) {
|
||||||
|
additional_operation_hooks_->push_back(std::move(fn));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-define publicly the protected addOperations() method from the Dialect
|
||||||
|
// class, usually used in a Dialect constructor. This allows hook
|
||||||
|
// functions to register operations on the TensorFlow dialect using the
|
||||||
|
// same interface.
|
||||||
|
template <typename... Args>
|
||||||
|
void addOperations() {
|
||||||
|
(void)std::initializer_list<int>{
|
||||||
|
0, (addOperation(AbstractOperation::get<Args>(*this)), 0)...};
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Hook functions which may add additional operations to the dialect.
|
||||||
|
// These are invoked at construction time.
|
||||||
|
static std::vector<AdditionalOpFunction> *additional_operation_hooks_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO(b/131258166): TensorFlow's mutex.h defines a `mutex_lock` macro, whose
|
// TODO(b/131258166): TensorFlow's mutex.h defines a `mutex_lock` macro, whose
|
||||||
|
Loading…
Reference in New Issue
Block a user