Allows for reading program-only saved models in C++.
This enables reading of models produced with the `experimental_program_only` save flag when using the experimental C++ infrastructure. PiperOrigin-RevId: 356289681 Change-Id: I5ff34e00371564d5a0204965d1135104fea9845e
This commit is contained in:
parent
4afbaca02c
commit
e2b5e921b8
@ -114,18 +114,27 @@ Status SavedModelV2Bundle::Load(const std::string& export_dir,
|
||||
TF_RETURN_IF_ERROR(
|
||||
ReadSavedModelDebugInfoIfPresent(export_dir, &bundle->debug_info_));
|
||||
|
||||
// Load the variables checkpoint reader.
|
||||
const std::string variables_prefix = io::JoinPath(
|
||||
export_dir, kSavedModelVariablesDirectory, kSavedModelVariablesFilename);
|
||||
bundle->variable_reader_.reset(
|
||||
new BundleReader(Env::Default(), variables_prefix));
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
bundle->variable_reader_->status(),
|
||||
"Unable to load SavedModel variables checkpoint from ", variables_prefix);
|
||||
const std::string variables_dir =
|
||||
io::JoinPath(export_dir, kSavedModelVariablesDirectory);
|
||||
if (!Env::Default()->FileExists(variables_dir).ok()) {
|
||||
LOG(INFO)
|
||||
<< "No checkpoint found, assuming this is a program-only SavedModel";
|
||||
} else {
|
||||
// Load the variables checkpoint reader.
|
||||
const std::string variables_prefix =
|
||||
io::JoinPath(variables_dir, kSavedModelVariablesFilename);
|
||||
bundle->variable_reader_.reset(
|
||||
new BundleReader(Env::Default(), variables_prefix));
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
bundle->variable_reader_->status(),
|
||||
"Unable to load SavedModel variables checkpoint from ",
|
||||
variables_prefix);
|
||||
|
||||
// Deserialize the object graph proto from the tensor bundle.
|
||||
TF_RETURN_IF_ERROR(ReadCheckpointObjectGraph(
|
||||
bundle->variable_reader_.get(), &bundle->trackable_object_graph_));
|
||||
}
|
||||
|
||||
// Deserialize the object graph proto from the tensor bundle.
|
||||
TF_RETURN_IF_ERROR(ReadCheckpointObjectGraph(
|
||||
bundle->variable_reader_.get(), &bundle->trackable_object_graph_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user