From ce0660bb292b57e2e59c0868550fc4d7208d30e4 Mon Sep 17 00:00:00 2001 From: Kuangyuan Chen Date: Wed, 24 Jun 2020 11:44:21 -0700 Subject: [PATCH] Add a helper function, GetSessionInitializerExportedName(), to return the exported name of a session initializer function. PiperOrigin-RevId: 318107585 Change-Id: I9d2b8f85e9c261dd9069106ce5b8592a3db4e160 --- .../mlir/tensorflow/ir/tf_saved_model.cc | 21 ++++++++++++++++++ .../mlir/tensorflow/ir/tf_saved_model.h | 3 +++ .../tests/tf_saved_model_ops_invalid.mlir | 22 +++++++++++++++++++ 3 files changed, 46 insertions(+) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc index 5a7d81d4c0c..38c0390acca 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc @@ -91,6 +91,16 @@ static LogicalResult Verify(SessionInitializerOp session_initializer) { return session_initializer.emitOpError() << "the initializer function should have no output"; + auto exported_names = GetExportedNames(init_func_op); + + if (exported_names.empty()) + return session_initializer.emitOpError() + << "the initializer function should be exported"; + + if (exported_names.size() != 1) + return session_initializer.emitOpError() + << "the initializer function should have only one exported names"; + return success(); } @@ -429,5 +439,16 @@ void SessionInitializerOp::getCanonicalizationPatterns( results.insert(context); } +llvm::Optional GetSessionInitializerExportedName(ModuleOp op) { + auto session_initializer_op = GetSessionInitializerOp(op); + if (!session_initializer_op) return llvm::None; + + SymbolTable symbol_table(op); + auto init_func_op = + symbol_table.lookup(session_initializer_op.initializer()); + auto exported_names = GetExportedNames(init_func_op); + return exported_names[0]; +} + } // namespace tf_saved_model } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h index b6f8753cc51..056df4d6a43 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h @@ -65,6 +65,9 @@ Type GetBoundInputArgTypeFor(GlobalTensorOp global_tensor); // otherwise. SessionInitializerOp GetSessionInitializerOp(mlir::ModuleOp op); +// Returns the exported name for the session initializer function. +llvm::Optional GetSessionInitializerExportedName(mlir::ModuleOp op); + } // namespace tf_saved_model } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir index 260174b184f..46eea9e508d 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir @@ -352,3 +352,25 @@ module attributes {tf_saved_model.semantics} { return %0 : tensor<1xf32> } } + +// ----- + +module attributes {tf_saved_model.semantics} { + + // expected-error@+1 {{the initializer function should be exported}} + "tf_saved_model.session_initializer"() { initializer = @init } : () -> () + func @init() attributes {sym_visibility = "private"} { + return + } +} + +// ----- + +module attributes {tf_saved_model.semantics} { + + // expected-error@+1 {{the initializer function should have only one exported name}} + "tf_saved_model.session_initializer"() { initializer = @init } : () -> () + func @init() attributes { tf_saved_model.exported_names = ["a", "b"] } { + return + } +}