Add a helper function, GetSessionInitializerExportedName(), to return the
exported name of a session initializer function. PiperOrigin-RevId: 318107585 Change-Id: I9d2b8f85e9c261dd9069106ce5b8592a3db4e160
This commit is contained in:
parent
c204c0f893
commit
ce0660bb29
@ -91,6 +91,16 @@ static LogicalResult Verify(SessionInitializerOp session_initializer) {
|
||||
return session_initializer.emitOpError()
|
||||
<< "the initializer function should have no output";
|
||||
|
||||
auto exported_names = GetExportedNames(init_func_op);
|
||||
|
||||
if (exported_names.empty())
|
||||
return session_initializer.emitOpError()
|
||||
<< "the initializer function should be exported";
|
||||
|
||||
if (exported_names.size() != 1)
|
||||
return session_initializer.emitOpError()
|
||||
<< "the initializer function should have only one exported names";
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
@ -429,5 +439,16 @@ void SessionInitializerOp::getCanonicalizationPatterns(
|
||||
results.insert<OptimizeSessionInitializerPattern>(context);
|
||||
}
|
||||
|
||||
llvm::Optional<StringRef> GetSessionInitializerExportedName(ModuleOp op) {
|
||||
auto session_initializer_op = GetSessionInitializerOp(op);
|
||||
if (!session_initializer_op) return llvm::None;
|
||||
|
||||
SymbolTable symbol_table(op);
|
||||
auto init_func_op =
|
||||
symbol_table.lookup<mlir::FuncOp>(session_initializer_op.initializer());
|
||||
auto exported_names = GetExportedNames(init_func_op);
|
||||
return exported_names[0];
|
||||
}
|
||||
|
||||
} // namespace tf_saved_model
|
||||
} // namespace mlir
|
||||
|
@ -65,6 +65,9 @@ Type GetBoundInputArgTypeFor(GlobalTensorOp global_tensor);
|
||||
// otherwise.
|
||||
SessionInitializerOp GetSessionInitializerOp(mlir::ModuleOp op);
|
||||
|
||||
// Returns the exported name for the session initializer function.
|
||||
llvm::Optional<StringRef> GetSessionInitializerExportedName(mlir::ModuleOp op);
|
||||
|
||||
} // namespace tf_saved_model
|
||||
} // namespace mlir
|
||||
|
||||
|
@ -352,3 +352,25 @@ module attributes {tf_saved_model.semantics} {
|
||||
return %0 : tensor<1xf32>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// expected-error@+1 {{the initializer function should be exported}}
|
||||
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
|
||||
func @init() attributes {sym_visibility = "private"} {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// expected-error@+1 {{the initializer function should have only one exported name}}
|
||||
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
|
||||
func @init() attributes { tf_saved_model.exported_names = ["a", "b"] } {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user