Expose some methods for restoring SavedModels

PiperOrigin-RevId: 327083815
Change-Id: I6daad78e9705d691b969012bed7026a649359776
This commit is contained in:
A. Unique TensorFlower 2020-08-17 13:13:32 -07:00 committed by TensorFlower Gardener
parent cfdb6a0ca4
commit 6ae37dc80b
7 changed files with 133 additions and 60 deletions

View File

@@ -47,6 +47,7 @@ cc_library(
# TODO(b/111634734): :lib and :protos_all contain dependencies that
# cannot be built on mobile platforms. Instead, include the appropriate
# tf_lib depending on the build platform.
"@com_google_absl//absl/memory:memory",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
]),
@@ -171,6 +172,7 @@ tf_cc_test(
deps = [
":constants",
":loader",
":reader",
":signature_constants",
":tag_constants",
"//tensorflow/core:lib",

View File

@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/saver.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
@@ -95,16 +96,6 @@ static Status ValidateSavedTensors(const GraphDef& graph_def) {
return Status::OK();
}
// Creates a new Session from `session_options`, hands its ownership to
// `*session`, validates the tensors serialized in `meta_graph_def`, and
// creates the graph in the session. Returns the status of Session::Create.
Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def,
                                const SessionOptions& session_options,
                                std::unique_ptr<Session>* session) {
  // Build a fresh session and transfer ownership to the caller immediately,
  // so the session is cleaned up even if a later step fails.
  Session* raw_session = nullptr;
  TF_RETURN_IF_ERROR(NewSession(session_options, &raw_session));
  session->reset(raw_session);
  // Reject malformed saved tensors before attempting to create the graph.
  TF_RETURN_IF_ERROR(ValidateSavedTensors(meta_graph_def.graph_def()));
  return (*session)->Create(meta_graph_def.graph_def());
}
Tensor CreateStringTensor(const string& value) {
Tensor tensor(DT_STRING, TensorShape({}));
tensor.scalar<tstring>()() = value;
@@ -228,22 +219,18 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir,
nullptr /* outputs */, &run_metadata, session);
}
Status ReadSavedModelDebugInfoIfPresent(
const string& export_dir,
std::unique_ptr<GraphDebugInfo>* debug_info_proto) {
LOG(INFO) << "Reading SavedModel debug info (if present) from: "
<< export_dir;
} // namespace
const string debug_info_pb_path =
io::JoinPath(export_dir, "debug", "saved_model_debug_info.pb");
if (Env::Default()->FileExists(debug_info_pb_path).ok()) {
GraphDebugInfo debug_info;
TF_RETURN_IF_ERROR(
ReadBinaryProto(Env::Default(), debug_info_pb_path, &debug_info));
*debug_info_proto =
absl::make_unique<GraphDebugInfo>(std::move(debug_info));
}
return Status::OK();
SavedModelBundleInterface::~SavedModelBundleInterface() {}
// Initializes `*session` from `session_options` and creates the graph of
// `meta_graph` inside it. The caller owns the resulting session. The saved
// tensors in the graph are validated before Session::Create is invoked.
Status LoadMetagraphIntoSession(const SessionOptions& session_options,
                                const MetaGraphDef& meta_graph,
                                std::unique_ptr<Session>* session) {
  Session* new_session = nullptr;
  TF_RETURN_IF_ERROR(NewSession(session_options, &new_session));
  // Ownership moves to the caller's unique_ptr right away.
  session->reset(new_session);
  // Fail fast on malformed tensors rather than inside graph creation.
  TF_RETURN_IF_ERROR(ValidateSavedTensors(meta_graph.graph_def()));
  return (*session)->Create(meta_graph.graph_def());
}
Status LoadSavedModelInternal(const SessionOptions& session_options,
@@ -251,46 +238,17 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
const string& export_dir,
const std::unordered_set<string>& tags,
SavedModelBundle* const bundle) {
const uint64 read_start_microseconds = Env::Default()->NowMicros();
TF_RETURN_IF_ERROR(ReadMetaGraphDefFromSavedModel(export_dir, tags,
&bundle->meta_graph_def));
TF_RETURN_IF_ERROR(
ReadSavedModelDebugInfoIfPresent(export_dir, &bundle->debug_info));
TF_RETURN_IF_ERROR(LoadMetaGraphIntoSession(
bundle->meta_graph_def, session_options, &bundle->session));
std::vector<AssetFileDef> asset_file_defs;
TF_RETURN_IF_ERROR(
internal::GetAssetFileDefs(bundle->meta_graph_def, &asset_file_defs));
TF_RETURN_IF_ERROR(
RunRestore(run_options, export_dir,
bundle->meta_graph_def.saver_def().restore_op_name(),
bundle->meta_graph_def.saver_def().filename_tensor_name(),
asset_file_defs, bundle->session.get()));
// Record walltime spent in restoring graph from disk, but postpone metric
// increments until graph init finishes.
const uint64 restore_graph_walltime =
GetLatencyMicroseconds(read_start_microseconds);
const uint64 graph_init_start_microseconds = Env::Default()->NowMicros();
string init_op_name;
TF_RETURN_IF_ERROR(
internal::GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name));
TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, bundle->meta_graph_def,
asset_file_defs, bundle->session.get(),
init_op_name));
load_latency_by_stage->GetCell(export_dir, "restore_graph")
->Add(restore_graph_walltime);
// Record wall time spent in init op.
load_latency_by_stage->GetCell(export_dir, "init_graph")
->Add(GetLatencyMicroseconds(graph_init_start_microseconds));
TF_RETURN_IF_ERROR(LoadMetagraphIntoSession(
session_options, bundle->meta_graph_def, &bundle->session));
TF_RETURN_IF_ERROR(RestoreSession(run_options, bundle->meta_graph_def,
export_dir, &bundle->session));
return Status::OK();
}
} // namespace
SavedModelBundleInterface::~SavedModelBundleInterface() {}
Status LoadSavedModel(const SessionOptions& session_options,
const RunOptions& run_options, const string& export_dir,
const std::unordered_set<string>& tags,
@@ -424,6 +382,35 @@ class LiteSessionWrapper : public Session {
};
} // namespace
// Restores variables and assets of `meta_graph` into `*session` from
// `export_dir`, then runs the metagraph's init op. Records per-stage
// restore/init walltime metrics only after both stages succeed.
Status RestoreSession(const RunOptions& run_options,
                      const MetaGraphDef& meta_graph, const string& export_dir,
                      std::unique_ptr<Session>* session) {
  // Stage 1: restore variables from the checkpoint named by the saver def.
  const uint64 restore_start_microseconds = Env::Default()->NowMicros();
  std::vector<AssetFileDef> asset_file_defs;
  TF_RETURN_IF_ERROR(internal::GetAssetFileDefs(meta_graph, &asset_file_defs));
  TF_RETURN_IF_ERROR(RunRestore(run_options, export_dir,
                                meta_graph.saver_def().restore_op_name(),
                                meta_graph.saver_def().filename_tensor_name(),
                                asset_file_defs, session->get()));
  // Walltime spent restoring the graph from disk; the metric increment is
  // postponed until graph init finishes so failed loads are not recorded.
  const uint64 restore_walltime =
      GetLatencyMicroseconds(restore_start_microseconds);

  // Stage 2: run the init op, if the metagraph declares one.
  const uint64 init_start_microseconds = Env::Default()->NowMicros();
  string init_op_name;
  TF_RETURN_IF_ERROR(
      internal::GetInitOp(export_dir, meta_graph, &init_op_name));
  TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, meta_graph,
                               asset_file_defs, session->get(), init_op_name));

  // Both stages succeeded: emit the latency metrics.
  load_latency_by_stage->GetCell(export_dir, "restore_graph")
      ->Add(restore_walltime);
  load_latency_by_stage->GetCell(export_dir, "init_graph")
      ->Add(GetLatencyMicroseconds(init_start_microseconds));
  return Status::OK();
}
Status LoadSavedModel(const SessionOptions& session_options,
const RunOptions& run_options, const string& export_dir,
const std::unordered_set<string>& tags,

View File

@@ -96,6 +96,21 @@ class SavedModelBundleLite : public SavedModelBundleInterface {
protobuf::Map<string, SignatureDef> signatures_;
};
// Restore variables and resources in the SavedModel export dir for the
// indicated metagraph.
// The recommended way to load a saved model is to call LoadSavedModel,
// which provides an already initialized Metagraph, Session, and DebugInfo.
Status RestoreSession(const RunOptions& run_options,
const MetaGraphDef& meta_graph, const string& export_dir,
std::unique_ptr<Session>* session);
// Initialize a session which wraps this metagraph.
// The recommended way to load a saved model is to call LoadSavedModel,
// which provides an already initialized Metagraph, Session, and DebugInfo.
Status LoadMetagraphIntoSession(const SessionOptions& session_options,
const MetaGraphDef& meta_graph,
std::unique_ptr<Session>* session);
/// Loads a SavedModel from the specified export directory. The MetaGraphDef
/// to be loaded is identified by the supplied tags, corresponding exactly to
/// the set of tags used at SavedModel build time. Stores a SavedModel bundle in

View File

@@ -17,6 +17,7 @@ limitations under the License.
#include <unordered_set>
#include "absl/memory/memory.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -86,4 +87,22 @@ Status ReadMetaGraphDefFromSavedModel(const string& export_dir,
return Status::OK();
}
// Loads `debug/saved_model_debug_info.pb` from `export_dir` into
// `*debug_info_proto` when the file exists. A missing file is not an error:
// the function returns OK and leaves `*debug_info_proto` untouched. Only a
// failed parse of an existing file yields a non-OK status.
Status ReadSavedModelDebugInfoIfPresent(
    const string& export_dir,
    std::unique_ptr<GraphDebugInfo>* debug_info_proto) {
  LOG(INFO) << "Reading SavedModel debug info (if present) from: "
            << export_dir;

  const string debug_info_pb_path =
      io::JoinPath(export_dir, "debug", "saved_model_debug_info.pb");
  // Debug info is optional; bail out quietly when the file is absent.
  if (!Env::Default()->FileExists(debug_info_pb_path).ok()) {
    return Status::OK();
  }
  GraphDebugInfo debug_info;
  TF_RETURN_IF_ERROR(
      ReadBinaryProto(Env::Default(), debug_info_pb_path, &debug_info));
  *debug_info_proto = absl::make_unique<GraphDebugInfo>(std::move(debug_info));
  return Status::OK();
}
} // namespace tensorflow

View File

@@ -22,6 +22,7 @@ limitations under the License.
#include <unordered_set>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
namespace tensorflow {
@@ -34,6 +35,11 @@ Status ReadMetaGraphDefFromSavedModel(const string& export_dir,
const std::unordered_set<string>& tags,
MetaGraphDef* const meta_graph_def);
// Reads the debug info, if present, from the SavedModel export dir into
// *debug_info_proto; a missing debug-info file is not an error.
Status ReadSavedModelDebugInfoIfPresent(
const string& export_dir,
std::unique_ptr<GraphDebugInfo>* debug_info_proto);
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_READER_H_

View File

@@ -106,5 +106,11 @@ TEST_F(ReaderTest, InvalidExportPath) {
EXPECT_FALSE(st.ok());
}
// Smoke test: reading the optional debug info from a valid export dir
// must succeed regardless of whether the debug-info file is present.
TEST_F(ReaderTest, ReadSavedModelDebugInfoIfPresent) {
  const string export_dir = GetDataDependencyFilepath(TestDataSharded());
  std::unique_ptr<GraphDebugInfo> debug_info;
  TF_ASSERT_OK(ReadSavedModelDebugInfoIfPresent(export_dir, &debug_info));
}
} // namespace
} // namespace tensorflow

View File

@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/cc/saved_model/signature_constants.h"
#include "tensorflow/cc/saved_model/tag_constants.h"
#include "tensorflow/core/example/example.pb.h"
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
namespace tensorflow {
namespace {
@@ -131,6 +132,43 @@ TEST_F(LoaderTest, TagMatch) {
CheckSavedModelBundle(export_dir, bundle);
}
// Verifies that ReadMetaGraphDefFromSavedModel yields the same MetaGraphDef
// that LoadSavedModel places into the bundle for the same tags.
TEST_F(LoaderTest, ReadMetaGraphFromSavedModel) {
  const string export_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);

  // Load the model the conventional way ...
  SavedModelBundle bundle;
  SessionOptions session_options;
  RunOptions run_options;
  TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
                              {kSavedModelTagServe}, &bundle));

  // ... then read the metagraph directly; both paths must agree.
  MetaGraphDef read_metagraph;
  TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
                                              &read_metagraph));
  EXPECT_EQ(read_metagraph.DebugString(),
            bundle.meta_graph_def.DebugString());
}
// Rebuilds a bundle step by step with the newly exposed helpers
// (ReadMetaGraphDefFromSavedModel + LoadMetagraphIntoSession +
// RestoreSession) and checks it behaves like one produced by LoadSavedModel.
TEST_F(LoaderTest, RestoreSession) {
  const string export_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);

  // Reference bundle via the one-shot loader.
  SavedModelBundle bundle;
  SessionOptions session_options;
  RunOptions run_options;
  TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
                              {kSavedModelTagServe}, &bundle));

  // Equivalent bundle assembled manually.
  SavedModelBundle actual_bundle;
  const std::unordered_set<std::string> tags = {kSavedModelTagServe};
  TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, tags,
                                              &actual_bundle.meta_graph_def));
  TF_ASSERT_OK(LoadMetagraphIntoSession(
      session_options, actual_bundle.meta_graph_def, &actual_bundle.session));
  TF_ASSERT_OK(RestoreSession(run_options, actual_bundle.meta_graph_def,
                              export_dir, &actual_bundle.session));
  CheckSavedModelBundle(export_dir, actual_bundle);
}
TEST_F(LoaderTest, NoTagMatch) {
SavedModelBundle bundle;
RunOptions run_options;