Switch to Saver v2 checkpoint in SavedModel.

Change: 136486667
This commit is contained in:
Sukriti Ramesh 2016-10-18 09:11:23 -08:00 committed by TensorFlower Gardener
parent 11cc2f54be
commit 4e118daa4a
23 changed files with 19 additions and 21 deletions

View File

@ -31,6 +31,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
"//tensorflow/core/util/tensor_bundle:naming",
],
)

View File

@ -36,10 +36,6 @@ constexpr char kSavedModelVariablesDirectory[] = "variables";
// SavedModel variables filename.
constexpr char kSavedModelVariablesFilename[] = "variables";
// SavedModel sharded variables filename.
constexpr char kSavedModelVariablesShardedFilename[] =
"variables-\?\?\?\?\?-of-\?\?\?\?\?";
// Commonly used tags.
constexpr char kSavedModelTagServe[] = "serve";
constexpr char kSavedModelTagTrain[] = "train";

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/protobuf/saved_model.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/tensor_bundle/naming.h"
namespace tensorflow {
namespace {
@ -87,17 +88,20 @@ Status Restore(const RunOptions& run_options, const string& export_dir,
const StringPiece variable_filename_const_op_name,
Session* session) {
// Find path to variables to be restored in export directory.
string variables_path =
const string variables_directory =
io::JoinPath(export_dir, kSavedModelVariablesDirectory);
const string unsharded_variables_path =
io::JoinPath(variables_path, kSavedModelVariablesFilename);
if (Env::Default()->FileExists(unsharded_variables_path)) {
variables_path = unsharded_variables_path;
} else {
const string sharded_variables_path =
io::JoinPath(variables_path, kSavedModelVariablesShardedFilename);
variables_path = sharded_variables_path;
// Check for saver checkpoints in v2 format. Models exported in the checkpoint
// v2 format will have a variables.index file. The corresponding
// variables are stored in the variables.data-?????-of-????? files.
const string variables_index_path = io::JoinPath(
variables_directory, MetaFilename(kSavedModelVariablesFilename));
if (!Env::Default()->FileExists(variables_index_path)) {
return errors::NotFound(
"Checkpoint index file not found in SavedModel directory.");
}
const string variables_path =
io::JoinPath(variables_directory, kSavedModelVariablesFilename);
// Add variables to the graph.
Tensor variables_path_tensor(DT_STRING, TensorShape({}));
variables_path_tensor.scalar<string>()() = variables_path;

View File

@ -1,2 +0,0 @@
model_checkpoint_path: "/tmp/saved_model/half_plus_two/variables/variables-?????-of-00001"
all_model_checkpoint_paths: "/tmp/saved_model/half_plus_two/variables/variables-?????-of-00001"

View File

@ -256,7 +256,7 @@ class SavedModelBuilder(object):
saver = tf_saver.Saver(
variables.all_variables(),
sharded=True,
write_version=saver_pb2.SaverDef.V1)
write_version=saver_pb2.SaverDef.V2)
meta_graph_def = saver.export_meta_graph()
# Tag the meta graph def and add it to the SavedModel.
@ -305,7 +305,7 @@ class SavedModelBuilder(object):
saver = tf_saver.Saver(
variables.all_variables(),
sharded=True,
write_version=saver_pb2.SaverDef.V1)
write_version=saver_pb2.SaverDef.V2)
saver.save(sess, variables_path, write_meta_graph=False)
meta_graph_def = saver.export_meta_graph()

View File

@ -31,4 +31,3 @@ TAG_TRAINING = "train"
VARIABLES_DIRECTORY = "variables"
VARIABLES_FILENAME = "variables"
VARIABLES_FILENAME_SHARDED = VARIABLES_FILENAME + "-?????-of-?????"

View File

@ -124,11 +124,11 @@ def _generate_saved_model_for_half_plus_two(export_dir, as_text=False):
def main(_):
export_dir_pb = "/tmp/saved_model/half_plus_two"
export_dir_pb = "/tmp/saved_model/v2_half_plus_two_unsharded"
_generate_saved_model_for_half_plus_two(export_dir_pb)
print("SavedModel generated at: %s" % export_dir_pb)
export_dir_pbtxt = "/tmp/saved_model/half_plus_two_pbtxt"
export_dir_pbtxt = "/tmp/saved_model/v2_half_plus_two_pbtxt"
_generate_saved_model_for_half_plus_two(export_dir_pbtxt, as_text=True)
print("SavedModel generated at: %s" % export_dir_pbtxt)

View File

@ -188,7 +188,7 @@ def load(sess, tags, export_dir):
variables_path = os.path.join(
compat.as_bytes(export_dir),
compat.as_bytes(constants.VARIABLES_DIRECTORY),
compat.as_bytes(constants.VARIABLES_FILENAME_SHARDED))
compat.as_bytes(constants.VARIABLES_FILENAME))
# Restore the variables using the built saver in the provided session.
saver.restore(sess, variables_path)