Switch to Saver v2 checkpoint in SavedModel.

Change: 136486667
This commit is contained in:
Sukriti Ramesh 2016-10-18 09:11:23 -08:00 committed by TensorFlower Gardener
parent 11cc2f54be
commit 4e118daa4a
23 changed files with 19 additions and 21 deletions

View File

@ -31,6 +31,7 @@ cc_library(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow", "//tensorflow/core:tensorflow",
"//tensorflow/core/util/tensor_bundle:naming",
], ],
) )

View File

@ -36,10 +36,6 @@ constexpr char kSavedModelVariablesDirectory[] = "variables";
// SavedModel variables filename. // SavedModel variables filename.
constexpr char kSavedModelVariablesFilename[] = "variables"; constexpr char kSavedModelVariablesFilename[] = "variables";
// SavedModel sharded variables filename.
constexpr char kSavedModelVariablesShardedFilename[] =
"variables-\?\?\?\?\?-of-\?\?\?\?\?";
// Commonly used tags. // Commonly used tags.
constexpr char kSavedModelTagServe[] = "serve"; constexpr char kSavedModelTagServe[] = "serve";
constexpr char kSavedModelTagTrain[] = "train"; constexpr char kSavedModelTagTrain[] = "train";

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/protobuf/saved_model.pb.h" #include "tensorflow/core/protobuf/saved_model.pb.h"
#include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/tensor_bundle/naming.h"
namespace tensorflow { namespace tensorflow {
namespace { namespace {
@ -87,17 +88,20 @@ Status Restore(const RunOptions& run_options, const string& export_dir,
const StringPiece variable_filename_const_op_name, const StringPiece variable_filename_const_op_name,
Session* session) { Session* session) {
// Find path to variables to be restored in export directory. // Find path to variables to be restored in export directory.
string variables_path = const string variables_directory =
io::JoinPath(export_dir, kSavedModelVariablesDirectory); io::JoinPath(export_dir, kSavedModelVariablesDirectory);
const string unsharded_variables_path = // Check for saver checkpoints in v2 format. Models exported in the checkpoint
io::JoinPath(variables_path, kSavedModelVariablesFilename); // v2 format will have a variables.index file. The corresponding
if (Env::Default()->FileExists(unsharded_variables_path)) { // variables are stored in the variables.data-?????-of-????? files.
variables_path = unsharded_variables_path; const string variables_index_path = io::JoinPath(
} else { variables_directory, MetaFilename(kSavedModelVariablesFilename));
const string sharded_variables_path = if (!Env::Default()->FileExists(variables_index_path)) {
io::JoinPath(variables_path, kSavedModelVariablesShardedFilename); return errors::NotFound(
variables_path = sharded_variables_path; "Checkpoint index file not found in SavedModel directory.");
} }
const string variables_path =
io::JoinPath(variables_directory, kSavedModelVariablesFilename);
// Add variables to the graph. // Add variables to the graph.
Tensor variables_path_tensor(DT_STRING, TensorShape({})); Tensor variables_path_tensor(DT_STRING, TensorShape({}));
variables_path_tensor.scalar<string>()() = variables_path; variables_path_tensor.scalar<string>()() = variables_path;

View File

@ -1,2 +0,0 @@
model_checkpoint_path: "/tmp/saved_model/half_plus_two/variables/variables-?????-of-00001"
all_model_checkpoint_paths: "/tmp/saved_model/half_plus_two/variables/variables-?????-of-00001"

View File

@ -256,7 +256,7 @@ class SavedModelBuilder(object):
saver = tf_saver.Saver( saver = tf_saver.Saver(
variables.all_variables(), variables.all_variables(),
sharded=True, sharded=True,
write_version=saver_pb2.SaverDef.V1) write_version=saver_pb2.SaverDef.V2)
meta_graph_def = saver.export_meta_graph() meta_graph_def = saver.export_meta_graph()
# Tag the meta graph def and add it to the SavedModel. # Tag the meta graph def and add it to the SavedModel.
@ -305,7 +305,7 @@ class SavedModelBuilder(object):
saver = tf_saver.Saver( saver = tf_saver.Saver(
variables.all_variables(), variables.all_variables(),
sharded=True, sharded=True,
write_version=saver_pb2.SaverDef.V1) write_version=saver_pb2.SaverDef.V2)
saver.save(sess, variables_path, write_meta_graph=False) saver.save(sess, variables_path, write_meta_graph=False)
meta_graph_def = saver.export_meta_graph() meta_graph_def = saver.export_meta_graph()

View File

@ -31,4 +31,3 @@ TAG_TRAINING = "train"
VARIABLES_DIRECTORY = "variables" VARIABLES_DIRECTORY = "variables"
VARIABLES_FILENAME = "variables" VARIABLES_FILENAME = "variables"
VARIABLES_FILENAME_SHARDED = VARIABLES_FILENAME + "-?????-of-?????"

View File

@ -124,11 +124,11 @@ def _generate_saved_model_for_half_plus_two(export_dir, as_text=False):
def main(_): def main(_):
export_dir_pb = "/tmp/saved_model/half_plus_two" export_dir_pb = "/tmp/saved_model/v2_half_plus_two_unsharded"
_generate_saved_model_for_half_plus_two(export_dir_pb) _generate_saved_model_for_half_plus_two(export_dir_pb)
print("SavedModel generated at: %s" % export_dir_pb) print("SavedModel generated at: %s" % export_dir_pb)
export_dir_pbtxt = "/tmp/saved_model/half_plus_two_pbtxt" export_dir_pbtxt = "/tmp/saved_model/v2_half_plus_two_pbtxt"
_generate_saved_model_for_half_plus_two(export_dir_pbtxt, as_text=True) _generate_saved_model_for_half_plus_two(export_dir_pbtxt, as_text=True)
print("SavedModel generated at: %s" % export_dir_pbtxt) print("SavedModel generated at: %s" % export_dir_pbtxt)

View File

@ -188,7 +188,7 @@ def load(sess, tags, export_dir):
variables_path = os.path.join( variables_path = os.path.join(
compat.as_bytes(export_dir), compat.as_bytes(export_dir),
compat.as_bytes(constants.VARIABLES_DIRECTORY), compat.as_bytes(constants.VARIABLES_DIRECTORY),
compat.as_bytes(constants.VARIABLES_FILENAME_SHARDED)) compat.as_bytes(constants.VARIABLES_FILENAME))
# Restore the variables using the built saver in the provided session. # Restore the variables using the built saver in the provided session.
saver.restore(sess, variables_path) saver.restore(sess, variables_path)