Switch to Saver v2 checkpoint in SavedModel.
Change: 136486667
This commit is contained in:
parent
11cc2f54be
commit
4e118daa4a
@ -31,6 +31,7 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:tensorflow",
|
||||
"//tensorflow/core/util/tensor_bundle:naming",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -36,10 +36,6 @@ constexpr char kSavedModelVariablesDirectory[] = "variables";
|
||||
// SavedModel variables filename.
|
||||
constexpr char kSavedModelVariablesFilename[] = "variables";
|
||||
|
||||
// SavedModel sharded variables filename.
|
||||
constexpr char kSavedModelVariablesShardedFilename[] =
|
||||
"variables-\?\?\?\?\?-of-\?\?\?\?\?";
|
||||
|
||||
// Commonly used tags.
|
||||
constexpr char kSavedModelTagServe[] = "serve";
|
||||
constexpr char kSavedModelTagTrain[] = "train";
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/protobuf/saved_model.pb.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
#include "tensorflow/core/util/tensor_bundle/naming.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
@ -87,17 +88,20 @@ Status Restore(const RunOptions& run_options, const string& export_dir,
|
||||
const StringPiece variable_filename_const_op_name,
|
||||
Session* session) {
|
||||
// Find path to variables to be restored in export directory.
|
||||
string variables_path =
|
||||
const string variables_directory =
|
||||
io::JoinPath(export_dir, kSavedModelVariablesDirectory);
|
||||
const string unsharded_variables_path =
|
||||
io::JoinPath(variables_path, kSavedModelVariablesFilename);
|
||||
if (Env::Default()->FileExists(unsharded_variables_path)) {
|
||||
variables_path = unsharded_variables_path;
|
||||
} else {
|
||||
const string sharded_variables_path =
|
||||
io::JoinPath(variables_path, kSavedModelVariablesShardedFilename);
|
||||
variables_path = sharded_variables_path;
|
||||
// Check for saver checkpoints in v2 format. Models exported in the checkpoint
|
||||
// v2 format will have a variables.index file. The corresponding
|
||||
// variables are stored in the variables.data-?????-of-????? files.
|
||||
const string variables_index_path = io::JoinPath(
|
||||
variables_directory, MetaFilename(kSavedModelVariablesFilename));
|
||||
if (!Env::Default()->FileExists(variables_index_path)) {
|
||||
return errors::NotFound(
|
||||
"Checkpoint index file not found in SavedModel directory.");
|
||||
}
|
||||
const string variables_path =
|
||||
io::JoinPath(variables_directory, kSavedModelVariablesFilename);
|
||||
|
||||
// Add variables to the graph.
|
||||
Tensor variables_path_tensor(DT_STRING, TensorShape({}));
|
||||
variables_path_tensor.scalar<string>()() = variables_path;
|
||||
|
Binary file not shown.
Binary file not shown.
BIN
tensorflow/cc/saved_model/testdata/half_plus_two/variables/variables.data-00000-of-00001
vendored
Normal file
BIN
tensorflow/cc/saved_model/testdata/half_plus_two/variables/variables.data-00000-of-00001
vendored
Normal file
Binary file not shown.
BIN
tensorflow/cc/saved_model/testdata/half_plus_two/variables/variables.index
vendored
Normal file
BIN
tensorflow/cc/saved_model/testdata/half_plus_two/variables/variables.index
vendored
Normal file
Binary file not shown.
Binary file not shown.
BIN
tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/variables/variables.data-00000-of-00001
vendored
Normal file
BIN
tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/variables/variables.data-00000-of-00001
vendored
Normal file
Binary file not shown.
BIN
tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/variables/variables.index
vendored
Normal file
BIN
tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/variables/variables.index
vendored
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/variables.data-00000-of-00001
vendored
Normal file
BIN
tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/variables.data-00000-of-00001
vendored
Normal file
Binary file not shown.
BIN
tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/variables.index
vendored
Normal file
BIN
tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/variables.index
vendored
Normal file
Binary file not shown.
Binary file not shown.
@ -1,2 +0,0 @@
|
||||
model_checkpoint_path: "/tmp/saved_model/half_plus_two/variables/variables-?????-of-00001"
|
||||
all_model_checkpoint_paths: "/tmp/saved_model/half_plus_two/variables/variables-?????-of-00001"
|
Binary file not shown.
Binary file not shown.
BIN
tensorflow/contrib/session_bundle/testdata/saved_model_half_plus_two/variables/variables.index
vendored
Normal file
BIN
tensorflow/contrib/session_bundle/testdata/saved_model_half_plus_two/variables/variables.index
vendored
Normal file
Binary file not shown.
@ -256,7 +256,7 @@ class SavedModelBuilder(object):
|
||||
saver = tf_saver.Saver(
|
||||
variables.all_variables(),
|
||||
sharded=True,
|
||||
write_version=saver_pb2.SaverDef.V1)
|
||||
write_version=saver_pb2.SaverDef.V2)
|
||||
meta_graph_def = saver.export_meta_graph()
|
||||
|
||||
# Tag the meta graph def and add it to the SavedModel.
|
||||
@ -305,7 +305,7 @@ class SavedModelBuilder(object):
|
||||
saver = tf_saver.Saver(
|
||||
variables.all_variables(),
|
||||
sharded=True,
|
||||
write_version=saver_pb2.SaverDef.V1)
|
||||
write_version=saver_pb2.SaverDef.V2)
|
||||
saver.save(sess, variables_path, write_meta_graph=False)
|
||||
meta_graph_def = saver.export_meta_graph()
|
||||
|
||||
|
@ -31,4 +31,3 @@ TAG_TRAINING = "train"
|
||||
|
||||
VARIABLES_DIRECTORY = "variables"
|
||||
VARIABLES_FILENAME = "variables"
|
||||
VARIABLES_FILENAME_SHARDED = VARIABLES_FILENAME + "-?????-of-?????"
|
||||
|
@ -124,11 +124,11 @@ def _generate_saved_model_for_half_plus_two(export_dir, as_text=False):
|
||||
|
||||
|
||||
def main(_):
|
||||
export_dir_pb = "/tmp/saved_model/half_plus_two"
|
||||
export_dir_pb = "/tmp/saved_model/v2_half_plus_two_unsharded"
|
||||
_generate_saved_model_for_half_plus_two(export_dir_pb)
|
||||
print("SavedModel generated at: %s" % export_dir_pb)
|
||||
|
||||
export_dir_pbtxt = "/tmp/saved_model/half_plus_two_pbtxt"
|
||||
export_dir_pbtxt = "/tmp/saved_model/v2_half_plus_two_pbtxt"
|
||||
_generate_saved_model_for_half_plus_two(export_dir_pbtxt, as_text=True)
|
||||
print("SavedModel generated at: %s" % export_dir_pbtxt)
|
||||
|
||||
|
@ -188,7 +188,7 @@ def load(sess, tags, export_dir):
|
||||
variables_path = os.path.join(
|
||||
compat.as_bytes(export_dir),
|
||||
compat.as_bytes(constants.VARIABLES_DIRECTORY),
|
||||
compat.as_bytes(constants.VARIABLES_FILENAME_SHARDED))
|
||||
compat.as_bytes(constants.VARIABLES_FILENAME))
|
||||
|
||||
# Restore the variables using the built saver in the provided session.
|
||||
saver.restore(sess, variables_path)
|
||||
|
Loading…
Reference in New Issue
Block a user