Expose some methods for restoring SavedModels
PiperOrigin-RevId: 327083815 Change-Id: I6daad78e9705d691b969012bed7026a649359776
This commit is contained in:
parent
cfdb6a0ca4
commit
6ae37dc80b
tensorflow/cc/saved_model
@ -47,6 +47,7 @@ cc_library(
|
||||
# TODO(b/111634734): :lib and :protos_all contain dependencies that
|
||||
# cannot be built on mobile platforms. Instead, include the appropriate
|
||||
# tf_lib depending on the build platform.
|
||||
"@com_google_absl//absl/memory:memory",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
]),
|
||||
@ -171,6 +172,7 @@ tf_cc_test(
|
||||
deps = [
|
||||
":constants",
|
||||
":loader",
|
||||
":reader",
|
||||
":signature_constants",
|
||||
":tag_constants",
|
||||
"//tensorflow/core:lib",
|
||||
|
@ -31,6 +31,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
|
||||
#include "tensorflow/core/protobuf/meta_graph.pb.h"
|
||||
#include "tensorflow/core/protobuf/saver.pb.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
@ -95,16 +96,6 @@ static Status ValidateSavedTensors(const GraphDef& graph_def) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def,
|
||||
const SessionOptions& session_options,
|
||||
std::unique_ptr<Session>* session) {
|
||||
Session* session_p = nullptr;
|
||||
TF_RETURN_IF_ERROR(NewSession(session_options, &session_p));
|
||||
session->reset(session_p);
|
||||
TF_RETURN_IF_ERROR(ValidateSavedTensors(meta_graph_def.graph_def()));
|
||||
return (*session)->Create(meta_graph_def.graph_def());
|
||||
}
|
||||
|
||||
Tensor CreateStringTensor(const string& value) {
|
||||
Tensor tensor(DT_STRING, TensorShape({}));
|
||||
tensor.scalar<tstring>()() = value;
|
||||
@ -228,22 +219,18 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir,
|
||||
nullptr /* outputs */, &run_metadata, session);
|
||||
}
|
||||
|
||||
Status ReadSavedModelDebugInfoIfPresent(
|
||||
const string& export_dir,
|
||||
std::unique_ptr<GraphDebugInfo>* debug_info_proto) {
|
||||
LOG(INFO) << "Reading SavedModel debug info (if present) from: "
|
||||
<< export_dir;
|
||||
} // namespace
|
||||
|
||||
const string debug_info_pb_path =
|
||||
io::JoinPath(export_dir, "debug", "saved_model_debug_info.pb");
|
||||
if (Env::Default()->FileExists(debug_info_pb_path).ok()) {
|
||||
GraphDebugInfo debug_info;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ReadBinaryProto(Env::Default(), debug_info_pb_path, &debug_info));
|
||||
*debug_info_proto =
|
||||
absl::make_unique<GraphDebugInfo>(std::move(debug_info));
|
||||
}
|
||||
return Status::OK();
|
||||
SavedModelBundleInterface::~SavedModelBundleInterface() {}
|
||||
|
||||
Status LoadMetagraphIntoSession(const SessionOptions& session_options,
|
||||
const MetaGraphDef& meta_graph,
|
||||
std::unique_ptr<Session>* session) {
|
||||
Session* session_p = nullptr;
|
||||
TF_RETURN_IF_ERROR(NewSession(session_options, &session_p));
|
||||
session->reset(session_p);
|
||||
TF_RETURN_IF_ERROR(ValidateSavedTensors(meta_graph.graph_def()));
|
||||
return (*session)->Create(meta_graph.graph_def());
|
||||
}
|
||||
|
||||
Status LoadSavedModelInternal(const SessionOptions& session_options,
|
||||
@ -251,46 +238,17 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
|
||||
const string& export_dir,
|
||||
const std::unordered_set<string>& tags,
|
||||
SavedModelBundle* const bundle) {
|
||||
const uint64 read_start_microseconds = Env::Default()->NowMicros();
|
||||
TF_RETURN_IF_ERROR(ReadMetaGraphDefFromSavedModel(export_dir, tags,
|
||||
&bundle->meta_graph_def));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ReadSavedModelDebugInfoIfPresent(export_dir, &bundle->debug_info));
|
||||
TF_RETURN_IF_ERROR(LoadMetaGraphIntoSession(
|
||||
bundle->meta_graph_def, session_options, &bundle->session));
|
||||
|
||||
std::vector<AssetFileDef> asset_file_defs;
|
||||
TF_RETURN_IF_ERROR(
|
||||
internal::GetAssetFileDefs(bundle->meta_graph_def, &asset_file_defs));
|
||||
TF_RETURN_IF_ERROR(
|
||||
RunRestore(run_options, export_dir,
|
||||
bundle->meta_graph_def.saver_def().restore_op_name(),
|
||||
bundle->meta_graph_def.saver_def().filename_tensor_name(),
|
||||
asset_file_defs, bundle->session.get()));
|
||||
// Record walltime spent in restoring graph from disk, but postpone metric
|
||||
// increments until graph init finishes.
|
||||
const uint64 restore_graph_walltime =
|
||||
GetLatencyMicroseconds(read_start_microseconds);
|
||||
|
||||
const uint64 graph_init_start_microseconds = Env::Default()->NowMicros();
|
||||
string init_op_name;
|
||||
TF_RETURN_IF_ERROR(
|
||||
internal::GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name));
|
||||
TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, bundle->meta_graph_def,
|
||||
asset_file_defs, bundle->session.get(),
|
||||
init_op_name));
|
||||
load_latency_by_stage->GetCell(export_dir, "restore_graph")
|
||||
->Add(restore_graph_walltime);
|
||||
// Record wall time spent in init op.
|
||||
load_latency_by_stage->GetCell(export_dir, "init_graph")
|
||||
->Add(GetLatencyMicroseconds(graph_init_start_microseconds));
|
||||
TF_RETURN_IF_ERROR(LoadMetagraphIntoSession(
|
||||
session_options, bundle->meta_graph_def, &bundle->session));
|
||||
TF_RETURN_IF_ERROR(RestoreSession(run_options, bundle->meta_graph_def,
|
||||
export_dir, &bundle->session));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
SavedModelBundleInterface::~SavedModelBundleInterface() {}
|
||||
|
||||
Status LoadSavedModel(const SessionOptions& session_options,
|
||||
const RunOptions& run_options, const string& export_dir,
|
||||
const std::unordered_set<string>& tags,
|
||||
@ -424,6 +382,35 @@ class LiteSessionWrapper : public Session {
|
||||
};
|
||||
} // namespace
|
||||
|
||||
Status RestoreSession(const RunOptions& run_options,
|
||||
const MetaGraphDef& meta_graph, const string& export_dir,
|
||||
std::unique_ptr<Session>* session) {
|
||||
const uint64 read_start_microseconds = Env::Default()->NowMicros();
|
||||
std::vector<AssetFileDef> asset_file_defs;
|
||||
TF_RETURN_IF_ERROR(internal::GetAssetFileDefs(meta_graph, &asset_file_defs));
|
||||
TF_RETURN_IF_ERROR(RunRestore(run_options, export_dir,
|
||||
meta_graph.saver_def().restore_op_name(),
|
||||
meta_graph.saver_def().filename_tensor_name(),
|
||||
asset_file_defs, session->get()));
|
||||
// Record walltime spent in restoring graph from disk, but postpone metric
|
||||
// increments until graph init finishes.
|
||||
const uint64 restore_graph_walltime =
|
||||
GetLatencyMicroseconds(read_start_microseconds);
|
||||
|
||||
const uint64 graph_init_start_microseconds = Env::Default()->NowMicros();
|
||||
string init_op_name;
|
||||
TF_RETURN_IF_ERROR(
|
||||
internal::GetInitOp(export_dir, meta_graph, &init_op_name));
|
||||
TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, meta_graph,
|
||||
asset_file_defs, session->get(), init_op_name));
|
||||
load_latency_by_stage->GetCell(export_dir, "restore_graph")
|
||||
->Add(restore_graph_walltime);
|
||||
// Record wall time spent in init op.
|
||||
load_latency_by_stage->GetCell(export_dir, "init_graph")
|
||||
->Add(GetLatencyMicroseconds(graph_init_start_microseconds));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LoadSavedModel(const SessionOptions& session_options,
|
||||
const RunOptions& run_options, const string& export_dir,
|
||||
const std::unordered_set<string>& tags,
|
||||
|
@ -96,6 +96,21 @@ class SavedModelBundleLite : public SavedModelBundleInterface {
|
||||
protobuf::Map<string, SignatureDef> signatures_;
|
||||
};
|
||||
|
||||
// Restore variable and resources in the SavedModel export dir for the
|
||||
// indicated metagraph.
|
||||
// The recommended way to load a saved model is to call LoadSavedModel,
|
||||
// which provides an already initialized Metagraph, Session, and DebugInfo.
|
||||
Status RestoreSession(const RunOptions& run_options,
|
||||
const MetaGraphDef& meta_graph, const string& export_dir,
|
||||
std::unique_ptr<Session>* session);
|
||||
|
||||
// Initialize a session which wraps this metagraph.
|
||||
// The recommended way to load a saved model is to call LoadSavedModel,
|
||||
// which provides an already initialized Metagraph, Session, and DebugInfo.
|
||||
Status LoadMetagraphIntoSession(const SessionOptions& session_options,
|
||||
const MetaGraphDef& meta_graph,
|
||||
std::unique_ptr<Session>* session);
|
||||
|
||||
/// Loads a SavedModel from the specified export directory. The MetaGraphDef
|
||||
/// to be loaded is identified by the supplied tags, corresponding exactly to
|
||||
/// the set of tags used at SavedModel build time. Stores a SavedModel bundle in
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include <unordered_set>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/cc/saved_model/constants.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
@ -86,4 +87,22 @@ Status ReadMetaGraphDefFromSavedModel(const string& export_dir,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ReadSavedModelDebugInfoIfPresent(
|
||||
const string& export_dir,
|
||||
std::unique_ptr<GraphDebugInfo>* debug_info_proto) {
|
||||
LOG(INFO) << "Reading SavedModel debug info (if present) from: "
|
||||
<< export_dir;
|
||||
|
||||
const string debug_info_pb_path =
|
||||
io::JoinPath(export_dir, "debug", "saved_model_debug_info.pb");
|
||||
if (Env::Default()->FileExists(debug_info_pb_path).ok()) {
|
||||
GraphDebugInfo debug_info;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ReadBinaryProto(Env::Default(), debug_info_pb_path, &debug_info));
|
||||
*debug_info_proto =
|
||||
absl::make_unique<GraphDebugInfo>(std::move(debug_info));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include <unordered_set>
|
||||
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
|
||||
#include "tensorflow/core/protobuf/meta_graph.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -34,6 +35,11 @@ Status ReadMetaGraphDefFromSavedModel(const string& export_dir,
|
||||
const std::unordered_set<string>& tags,
|
||||
MetaGraphDef* const meta_graph_def);
|
||||
|
||||
// Store debug info from the SavedModel export dir.
|
||||
Status ReadSavedModelDebugInfoIfPresent(
|
||||
const string& export_dir,
|
||||
std::unique_ptr<GraphDebugInfo>* debug_info_proto);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_SAVED_MODEL_READER_H_
|
||||
|
@ -106,5 +106,11 @@ TEST_F(ReaderTest, InvalidExportPath) {
|
||||
EXPECT_FALSE(st.ok());
|
||||
}
|
||||
|
||||
TEST_F(ReaderTest, ReadSavedModelDebugInfoIfPresent) {
|
||||
const string export_dir = GetDataDependencyFilepath(TestDataSharded());
|
||||
std::unique_ptr<GraphDebugInfo> debug_info_proto;
|
||||
TF_ASSERT_OK(ReadSavedModelDebugInfoIfPresent(export_dir, &debug_info_proto));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/saved_model/loader.h"
|
||||
|
||||
#include "tensorflow/cc/saved_model/constants.h"
|
||||
#include "tensorflow/cc/saved_model/loader.h"
|
||||
#include "tensorflow/cc/saved_model/reader.h"
|
||||
#include "tensorflow/cc/saved_model/signature_constants.h"
|
||||
#include "tensorflow/cc/saved_model/tag_constants.h"
|
||||
#include "tensorflow/core/example/example.pb.h"
|
||||
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/protobuf/meta_graph.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
@ -131,6 +132,43 @@ TEST_F(LoaderTest, TagMatch) {
|
||||
CheckSavedModelBundle(export_dir, bundle);
|
||||
}
|
||||
|
||||
TEST_F(LoaderTest, ReadMetaGraphFromSavedModel) {
|
||||
SavedModelBundle bundle;
|
||||
SessionOptions session_options;
|
||||
RunOptions run_options;
|
||||
|
||||
const string export_dir =
|
||||
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
|
||||
TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
|
||||
{kSavedModelTagServe}, &bundle));
|
||||
MetaGraphDef actual_metagraph;
|
||||
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
|
||||
&actual_metagraph));
|
||||
EXPECT_EQ(actual_metagraph.DebugString(),
|
||||
bundle.meta_graph_def.DebugString());
|
||||
}
|
||||
|
||||
TEST_F(LoaderTest, RestoreSession) {
|
||||
SavedModelBundle bundle;
|
||||
SessionOptions session_options;
|
||||
RunOptions run_options;
|
||||
|
||||
const string export_dir =
|
||||
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
|
||||
TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
|
||||
{kSavedModelTagServe}, &bundle));
|
||||
|
||||
SavedModelBundle actual_bundle;
|
||||
const std::unordered_set<std::string> tags = {kSavedModelTagServe};
|
||||
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, tags,
|
||||
&actual_bundle.meta_graph_def));
|
||||
TF_ASSERT_OK(LoadMetagraphIntoSession(
|
||||
session_options, actual_bundle.meta_graph_def, &actual_bundle.session));
|
||||
TF_ASSERT_OK(RestoreSession(run_options, actual_bundle.meta_graph_def,
|
||||
export_dir, &actual_bundle.session));
|
||||
CheckSavedModelBundle(export_dir, actual_bundle);
|
||||
}
|
||||
|
||||
TEST_F(LoaderTest, NoTagMatch) {
|
||||
SavedModelBundle bundle;
|
||||
RunOptions run_options;
|
||||
|
Loading…
Reference in New Issue
Block a user