Add a helper function, GetSessionInitializerExportedName(), to return the

exported name of a session initializer function.

PiperOrigin-RevId: 318107585
Change-Id: I9d2b8f85e9c261dd9069106ce5b8592a3db4e160
This commit is contained in:
Kuangyuan Chen 2020-06-24 11:44:21 -07:00 committed by TensorFlower Gardener
parent c204c0f893
commit ce0660bb29
3 changed files with 46 additions and 0 deletions

View File

@ -91,6 +91,16 @@ static LogicalResult Verify(SessionInitializerOp session_initializer) {
return session_initializer.emitOpError()
<< "the initializer function should have no output";
auto exported_names = GetExportedNames(init_func_op);
if (exported_names.empty())
return session_initializer.emitOpError()
<< "the initializer function should be exported";
if (exported_names.size() != 1)
return session_initializer.emitOpError()
<< "the initializer function should have only one exported names";
return success();
}
@ -429,5 +439,16 @@ void SessionInitializerOp::getCanonicalizationPatterns(
results.insert<OptimizeSessionInitializerPattern>(context);
}
llvm::Optional<StringRef> GetSessionInitializerExportedName(ModuleOp op) {
auto session_initializer_op = GetSessionInitializerOp(op);
if (!session_initializer_op) return llvm::None;
SymbolTable symbol_table(op);
auto init_func_op =
symbol_table.lookup<mlir::FuncOp>(session_initializer_op.initializer());
auto exported_names = GetExportedNames(init_func_op);
return exported_names[0];
}
} // namespace tf_saved_model
} // namespace mlir

View File

@ -65,6 +65,9 @@ Type GetBoundInputArgTypeFor(GlobalTensorOp global_tensor);
// otherwise.
SessionInitializerOp GetSessionInitializerOp(mlir::ModuleOp op);
// Returns the exported name for the session initializer function.
llvm::Optional<StringRef> GetSessionInitializerExportedName(mlir::ModuleOp op);
} // namespace tf_saved_model
} // namespace mlir

View File

@ -352,3 +352,25 @@ module attributes {tf_saved_model.semantics} {
return %0 : tensor<1xf32>
}
}
// -----
module attributes {tf_saved_model.semantics} {
// expected-error@+1 {{the initializer function should be exported}}
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
func @init() attributes {sym_visibility = "private"} {
return
}
}
// -----
module attributes {tf_saved_model.semantics} {
// expected-error@+1 {{the initializer function should have only one exported name}}
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
func @init() attributes { tf_saved_model.exported_names = ["a", "b"] } {
return
}
}