Added unit test for max_to_keep being None.
Change: 115516426
This commit is contained in:
parent
77da168dbc
commit
4ecd2a70dd
@ -37,6 +37,14 @@ from tensorflow.python.framework import function
|
|||||||
from tensorflow.python.platform import gfile
|
from tensorflow.python.platform import gfile
|
||||||
|
|
||||||
|
|
||||||
|
def _TestDir(test_name):
|
||||||
|
test_dir = os.path.join(tf.test.get_temp_dir(), test_name)
|
||||||
|
if os.path.exists(test_dir):
|
||||||
|
shutil.rmtree(test_dir)
|
||||||
|
gfile.MakeDirs(test_dir)
|
||||||
|
return test_dir
|
||||||
|
|
||||||
|
|
||||||
class SaverTest(tf.test.TestCase):
|
class SaverTest(tf.test.TestCase):
|
||||||
|
|
||||||
def testBasics(self):
|
def testBasics(self):
|
||||||
@ -349,12 +357,7 @@ class SaveRestoreShardedTest(tf.test.TestCase):
|
|||||||
class MaxToKeepTest(tf.test.TestCase):
|
class MaxToKeepTest(tf.test.TestCase):
|
||||||
|
|
||||||
def testNonSharded(self):
|
def testNonSharded(self):
|
||||||
save_dir = os.path.join(self.get_temp_dir(), "max_to_keep_non_sharded")
|
save_dir = _TestDir("max_to_keep_non_sharded")
|
||||||
try:
|
|
||||||
gfile.DeleteRecursively(save_dir)
|
|
||||||
except OSError:
|
|
||||||
pass # Ignore
|
|
||||||
gfile.MakeDirs(save_dir)
|
|
||||||
|
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
v = tf.Variable(10.0, name="v")
|
v = tf.Variable(10.0, name="v")
|
||||||
@ -456,12 +459,7 @@ class MaxToKeepTest(tf.test.TestCase):
|
|||||||
self.assertTrue(gfile.Exists(save._MetaGraphFilename(s1)))
|
self.assertTrue(gfile.Exists(save._MetaGraphFilename(s1)))
|
||||||
|
|
||||||
def testSharded(self):
|
def testSharded(self):
|
||||||
save_dir = os.path.join(self.get_temp_dir(), "max_to_keep_sharded")
|
save_dir = _TestDir("max_to_keep_sharded")
|
||||||
try:
|
|
||||||
gfile.DeleteRecursively(save_dir)
|
|
||||||
except OSError:
|
|
||||||
pass # Ignore
|
|
||||||
gfile.MakeDirs(save_dir)
|
|
||||||
|
|
||||||
with tf.Session(
|
with tf.Session(
|
||||||
target="",
|
target="",
|
||||||
@ -495,17 +493,39 @@ class MaxToKeepTest(tf.test.TestCase):
|
|||||||
self.assertEqual(2, len(gfile.Glob(s3)))
|
self.assertEqual(2, len(gfile.Glob(s3)))
|
||||||
self.assertTrue(gfile.Exists(save._MetaGraphFilename(s3)))
|
self.assertTrue(gfile.Exists(save._MetaGraphFilename(s3)))
|
||||||
|
|
||||||
|
def testNoMaxToKeep(self):
|
||||||
|
save_dir = _TestDir("no_max_to_keep")
|
||||||
|
save_dir2 = _TestDir("max_to_keep_0")
|
||||||
|
|
||||||
|
with self.test_session() as sess:
|
||||||
|
v = tf.Variable(10.0, name="v")
|
||||||
|
tf.initialize_all_variables().run()
|
||||||
|
|
||||||
|
# Test max_to_keep being None.
|
||||||
|
save = tf.train.Saver({"v": v}, max_to_keep=None)
|
||||||
|
self.assertEqual([], save.last_checkpoints)
|
||||||
|
s1 = save.save(sess, os.path.join(save_dir, "s1"))
|
||||||
|
self.assertEqual([], save.last_checkpoints)
|
||||||
|
self.assertTrue(gfile.Exists(s1))
|
||||||
|
s2 = save.save(sess, os.path.join(save_dir, "s2"))
|
||||||
|
self.assertEqual([], save.last_checkpoints)
|
||||||
|
self.assertTrue(gfile.Exists(s2))
|
||||||
|
|
||||||
|
# Test max_to_keep being 0.
|
||||||
|
save2 = tf.train.Saver({"v": v}, max_to_keep=0)
|
||||||
|
self.assertEqual([], save2.last_checkpoints)
|
||||||
|
s1 = save2.save(sess, os.path.join(save_dir2, "s1"))
|
||||||
|
self.assertEqual([], save2.last_checkpoints)
|
||||||
|
self.assertTrue(gfile.Exists(s1))
|
||||||
|
s2 = save2.save(sess, os.path.join(save_dir2, "s2"))
|
||||||
|
self.assertEqual([], save2.last_checkpoints)
|
||||||
|
self.assertTrue(gfile.Exists(s2))
|
||||||
|
|
||||||
|
|
||||||
class KeepCheckpointEveryNHoursTest(tf.test.TestCase):
|
class KeepCheckpointEveryNHoursTest(tf.test.TestCase):
|
||||||
|
|
||||||
def testNonSharded(self):
|
def testNonSharded(self):
|
||||||
save_dir = os.path.join(self.get_temp_dir(),
|
save_dir = _TestDir("keep_checkpoint_every_n_hours")
|
||||||
"keep_checkpoint_every_n_hours")
|
|
||||||
try:
|
|
||||||
gfile.DeleteRecursively(save_dir)
|
|
||||||
except OSError:
|
|
||||||
pass # Ignore
|
|
||||||
gfile.MakeDirs(save_dir)
|
|
||||||
|
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
v = tf.Variable([10.0], name="v")
|
v = tf.Variable([10.0], name="v")
|
||||||
@ -685,15 +705,8 @@ class LatestCheckpointWithRelativePaths(tf.test.TestCase):
|
|||||||
|
|
||||||
class CheckpointStateTest(tf.test.TestCase):
|
class CheckpointStateTest(tf.test.TestCase):
|
||||||
|
|
||||||
def _TestDir(self, test_name):
|
|
||||||
test_dir = os.path.join(self.get_temp_dir(), test_name)
|
|
||||||
if os.path.exists(test_dir):
|
|
||||||
shutil.rmtree(test_dir)
|
|
||||||
gfile.MakeDirs(test_dir)
|
|
||||||
return test_dir
|
|
||||||
|
|
||||||
def testAbsPath(self):
|
def testAbsPath(self):
|
||||||
save_dir = self._TestDir("abs_paths")
|
save_dir = _TestDir("abs_paths")
|
||||||
abs_path = os.path.join(save_dir, "model-0")
|
abs_path = os.path.join(save_dir, "model-0")
|
||||||
ckpt = tf.train.generate_checkpoint_state_proto(save_dir, abs_path)
|
ckpt = tf.train.generate_checkpoint_state_proto(save_dir, abs_path)
|
||||||
self.assertEqual(ckpt.model_checkpoint_path, abs_path)
|
self.assertEqual(ckpt.model_checkpoint_path, abs_path)
|
||||||
@ -712,7 +725,7 @@ class CheckpointStateTest(tf.test.TestCase):
|
|||||||
self.assertEqual(ckpt.all_model_checkpoint_paths[-1], new_rel_path)
|
self.assertEqual(ckpt.all_model_checkpoint_paths[-1], new_rel_path)
|
||||||
|
|
||||||
def testAllModelCheckpointPaths(self):
|
def testAllModelCheckpointPaths(self):
|
||||||
save_dir = self._TestDir("all_models_test")
|
save_dir = _TestDir("all_models_test")
|
||||||
abs_path = os.path.join(save_dir, "model-0")
|
abs_path = os.path.join(save_dir, "model-0")
|
||||||
for paths in [None, [], ["model-2"]]:
|
for paths in [None, [], ["model-2"]]:
|
||||||
ckpt = tf.train.generate_checkpoint_state_proto(
|
ckpt = tf.train.generate_checkpoint_state_proto(
|
||||||
@ -726,7 +739,7 @@ class CheckpointStateTest(tf.test.TestCase):
|
|||||||
self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)
|
self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)
|
||||||
|
|
||||||
def testUpdateCheckpointState(self):
|
def testUpdateCheckpointState(self):
|
||||||
save_dir = self._TestDir("update_checkpoint_state")
|
save_dir = _TestDir("update_checkpoint_state")
|
||||||
os.chdir(save_dir)
|
os.chdir(save_dir)
|
||||||
# Make a temporary train directory.
|
# Make a temporary train directory.
|
||||||
train_dir = "train"
|
train_dir = "train"
|
||||||
@ -746,15 +759,8 @@ class CheckpointStateTest(tf.test.TestCase):
|
|||||||
|
|
||||||
class MetaGraphTest(tf.test.TestCase):
|
class MetaGraphTest(tf.test.TestCase):
|
||||||
|
|
||||||
def _TestDir(self, test_name):
|
|
||||||
test_dir = os.path.join(self.get_temp_dir(), test_name)
|
|
||||||
if os.path.exists(test_dir):
|
|
||||||
shutil.rmtree(test_dir)
|
|
||||||
gfile.MakeDirs(test_dir)
|
|
||||||
return test_dir
|
|
||||||
|
|
||||||
def testAddCollectionDef(self):
|
def testAddCollectionDef(self):
|
||||||
test_dir = self._TestDir("good_collection")
|
test_dir = _TestDir("good_collection")
|
||||||
filename = os.path.join(test_dir, "metafile")
|
filename = os.path.join(test_dir, "metafile")
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
# Creates a graph.
|
# Creates a graph.
|
||||||
@ -819,7 +825,7 @@ class MetaGraphTest(tf.test.TestCase):
|
|||||||
self.assertEqual(len(meta_graph_def.collection_def), 0)
|
self.assertEqual(len(meta_graph_def.collection_def), 0)
|
||||||
|
|
||||||
def _testMultiSaverCollectionSave(self):
|
def _testMultiSaverCollectionSave(self):
|
||||||
test_dir = self._TestDir("saver_collection")
|
test_dir = _TestDir("saver_collection")
|
||||||
filename = os.path.join(test_dir, "metafile")
|
filename = os.path.join(test_dir, "metafile")
|
||||||
saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
|
saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
|
||||||
saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")
|
saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")
|
||||||
@ -894,7 +900,7 @@ class MetaGraphTest(tf.test.TestCase):
|
|||||||
self._testMultiSaverCollectionRestore()
|
self._testMultiSaverCollectionRestore()
|
||||||
|
|
||||||
def testBinaryAndTextFormat(self):
|
def testBinaryAndTextFormat(self):
|
||||||
test_dir = self._TestDir("binary_and_text")
|
test_dir = _TestDir("binary_and_text")
|
||||||
filename = os.path.join(test_dir, "metafile")
|
filename = os.path.join(test_dir, "metafile")
|
||||||
with self.test_session(graph=tf.Graph()):
|
with self.test_session(graph=tf.Graph()):
|
||||||
# Creates a graph.
|
# Creates a graph.
|
||||||
@ -924,7 +930,7 @@ class MetaGraphTest(tf.test.TestCase):
|
|||||||
tf.train.import_meta_graph(filename)
|
tf.train.import_meta_graph(filename)
|
||||||
|
|
||||||
def testSliceVariable(self):
|
def testSliceVariable(self):
|
||||||
test_dir = self._TestDir("slice_saver")
|
test_dir = _TestDir("slice_saver")
|
||||||
filename = os.path.join(test_dir, "metafile")
|
filename = os.path.join(test_dir, "metafile")
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
v1 = tf.Variable([20.0], name="v1")
|
v1 = tf.Variable([20.0], name="v1")
|
||||||
@ -946,7 +952,7 @@ class MetaGraphTest(tf.test.TestCase):
|
|||||||
self.assertProtoEquals(meta_graph_def, new_meta_graph_def)
|
self.assertProtoEquals(meta_graph_def, new_meta_graph_def)
|
||||||
|
|
||||||
def _testGraphExtensionSave(self):
|
def _testGraphExtensionSave(self):
|
||||||
test_dir = self._TestDir("graph_extension")
|
test_dir = _TestDir("graph_extension")
|
||||||
filename = os.path.join(test_dir, "metafile")
|
filename = os.path.join(test_dir, "metafile")
|
||||||
saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
|
saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
|
||||||
with self.test_session(graph=tf.Graph()) as sess:
|
with self.test_session(graph=tf.Graph()) as sess:
|
||||||
|
Loading…
Reference in New Issue
Block a user