Remove dependency on the MLIR Global Dialect registry from third_party/tensorflow/compiler/mlir/tensorflow/... (NFC)
PiperOrigin-RevId: 328256230 Change-Id: I180650c53c9bbb790bead9d47ae546a3938387d1
This commit is contained in:
parent
9f069e0255
commit
e2efac4eca
@ -43,6 +43,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/c/tf_status_internal.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
@ -74,15 +75,9 @@ using tensorflow::tracing::TracingTensorHandle;
|
||||
|
||||
namespace {
|
||||
|
||||
static void RegisterDialects() {
|
||||
static bool init_once = []() {
|
||||
mlir::registerDialect<mlir::StandardOpsDialect>();
|
||||
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
|
||||
mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
|
||||
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
|
||||
return true;
|
||||
}();
|
||||
(void)init_once;
|
||||
void RegisterDialects(mlir::MLIRContext& ctx) {
|
||||
mlir::RegisterAllTensorFlowDialects(ctx.getDialectRegistry());
|
||||
ctx.getDialectRegistry().loadAll(&ctx);
|
||||
}
|
||||
|
||||
Status ConvertDataTypeToTensor(tensorflow::DataType dtype, Builder builder,
|
||||
@ -239,6 +234,7 @@ class MlirFunctionContext : public TracingContext {
|
||||
: TracingContext(kMlir),
|
||||
context_(std::make_unique<MLIRContext>()),
|
||||
builder_(context_.get()) {
|
||||
RegisterDialects(*context_);
|
||||
// TODO(aminim) figure out the location story here
|
||||
module_ = ModuleOp::create(builder_.getUnknownLoc());
|
||||
func_ = FuncOp::create(builder_.getUnknownLoc(), name,
|
||||
@ -666,7 +662,6 @@ Status MlirFunctionContext::Finalize(OutputList* outputs,
|
||||
|
||||
extern "C" {
|
||||
TracingContext* MlirTracingFactory(const char* fn_name, TF_Status* s) {
|
||||
RegisterDialects();
|
||||
return new MlirFunctionContext(fn_name);
|
||||
}
|
||||
}
|
||||
|
@ -64,6 +64,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/saved_model/loader_util.h"
|
||||
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
|
||||
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
@ -141,6 +142,12 @@ bool IsResourceOutputShapesAttribute(const AttrValue& attr_value,
|
||||
return false;
|
||||
}
|
||||
|
||||
void LoadImporterDialects(mlir::MLIRContext& context) {
|
||||
// Load dialects involved in the conversion
|
||||
mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry());
|
||||
context.getDialectRegistry().loadAll(&context);
|
||||
}
|
||||
|
||||
// This class is used to generate new MLIR function name strings that are both
|
||||
// unique in the TF function library `flib_` and unique among the name strings
|
||||
// generated by the class object during its lifetime.
|
||||
@ -2136,11 +2143,7 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
|
||||
mlir::MLIRContext* context, const Graph& graph,
|
||||
const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def,
|
||||
const GraphImportConfig& specs, llvm::StringRef func_name) {
|
||||
// Load dialects involved in the conversion
|
||||
context->loadDialect<mlir::StandardOpsDialect>();
|
||||
context->loadDialect<mlir::TF::TensorFlowDialect>();
|
||||
context->loadDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
|
||||
|
||||
LoadImporterDialects(*context);
|
||||
mlir::OwningModuleRef module =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
|
||||
std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
|
||||
@ -3197,6 +3200,7 @@ Status CreateSavedModelIR(
|
||||
StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphImporter::Convert(
|
||||
SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names,
|
||||
mlir::MLIRContext* context, bool add_default_attributes) {
|
||||
LoadImporterDialects(*context);
|
||||
GraphDebugInfo dummy_debug_info;
|
||||
const GraphDebugInfo& debug_info =
|
||||
saved_model->debug_info() ? *saved_model->debug_info() : dummy_debug_info;
|
||||
@ -3276,6 +3280,7 @@ class SavedModelSignatureDefImporter {
|
||||
static StatusOr<mlir::OwningModuleRef> Convert(
|
||||
const SavedModelBundle& bundle, absl::Span<std::string> exported_names,
|
||||
mlir::MLIRContext* context, bool upgrade_legacy) {
|
||||
LoadImporterDialects(*context);
|
||||
SavedModelSignatureDefImporter importer(bundle, exported_names, context);
|
||||
TF_RETURN_IF_ERROR(importer.InitializeGraph(upgrade_legacy));
|
||||
return importer.ConvertSignatures();
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "llvm/Support/MemoryBuffer.h"
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/Translation.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
|
||||
@ -86,6 +87,9 @@ static LogicalResult MlirToGraphdefTranslateFunction(
|
||||
}
|
||||
|
||||
static TranslateFromMLIRRegistration mlir_to_graphdef_translate(
|
||||
"mlir-to-graphdef", MlirToGraphdefTranslateFunction);
|
||||
"mlir-to-graphdef", MlirToGraphdefTranslateFunction,
|
||||
[](DialectRegistry& registry) {
|
||||
mlir::RegisterAllTensorFlowDialects(registry);
|
||||
});
|
||||
|
||||
} // namespace mlir
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/Translation.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
|
||||
|
||||
namespace mlir {
|
||||
@ -67,6 +68,7 @@ static LogicalResult MlirToTfNodeDef(ModuleOp module,
|
||||
// Test only translation to convert a simple MLIR module with a single TF
|
||||
// dialect op to NodeDef.
|
||||
static TranslateFromMLIRRegistration translate_from_mlir_registration(
|
||||
"test-only-mlir-to-tf-nodedef", MlirToTfNodeDef);
|
||||
"test-only-mlir-to-tf-nodedef", MlirToTfNodeDef,
|
||||
mlir::RegisterAllTensorFlowDialects);
|
||||
|
||||
} // namespace mlir
|
||||
|
Loading…
x
Reference in New Issue
Block a user