Allows for reading program-only saved models in C++.

This enables reading of models produced with the
`experimental_program_only` save flag when using the experimental C++
infrastructure.

PiperOrigin-RevId: 356289681
Change-Id: I5ff34e00371564d5a0204965d1135104fea9845e
This commit is contained in:
Cesar Crusius 2021-02-08 09:58:20 -08:00 committed by TensorFlower Gardener
parent 4afbaca02c
commit e2b5e921b8

View File

@ -114,18 +114,27 @@ Status SavedModelV2Bundle::Load(const std::string& export_dir,
TF_RETURN_IF_ERROR(
ReadSavedModelDebugInfoIfPresent(export_dir, &bundle->debug_info_));
// Load the variables checkpoint reader.
const std::string variables_prefix = io::JoinPath(
export_dir, kSavedModelVariablesDirectory, kSavedModelVariablesFilename);
bundle->variable_reader_.reset(
new BundleReader(Env::Default(), variables_prefix));
TF_RETURN_WITH_CONTEXT_IF_ERROR(
bundle->variable_reader_->status(),
"Unable to load SavedModel variables checkpoint from ", variables_prefix);
const std::string variables_dir =
io::JoinPath(export_dir, kSavedModelVariablesDirectory);
if (!Env::Default()->FileExists(variables_dir).ok()) {
LOG(INFO)
<< "No checkpoint found, assuming this is a program-only SavedModel";
} else {
// Load the variables checkpoint reader.
const std::string variables_prefix =
io::JoinPath(variables_dir, kSavedModelVariablesFilename);
bundle->variable_reader_.reset(
new BundleReader(Env::Default(), variables_prefix));
TF_RETURN_WITH_CONTEXT_IF_ERROR(
bundle->variable_reader_->status(),
"Unable to load SavedModel variables checkpoint from ",
variables_prefix);
// Deserialize the object graph proto from the tensor bundle.
TF_RETURN_IF_ERROR(ReadCheckpointObjectGraph(
bundle->variable_reader_.get(), &bundle->trackable_object_graph_));
}
// Deserialize the object graph proto from the tensor bundle.
TF_RETURN_IF_ERROR(ReadCheckpointObjectGraph(
bundle->variable_reader_.get(), &bundle->trackable_object_graph_));
return Status::OK();
}