From 4ecd2a70dd750b20a61033fe08301745685bf288 Mon Sep 17 00:00:00 2001 From: Sherry Moore Date: Wed, 24 Feb 2016 18:10:47 -0800 Subject: [PATCH] Added unit test for max_to_keep being None. Change: 115516426 --- tensorflow/python/training/saver_test.py | 88 +++++++++++++----------- 1 file changed, 47 insertions(+), 41 deletions(-) diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index ff9ab43ac63..5694a656301 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -37,6 +37,14 @@ from tensorflow.python.framework import function from tensorflow.python.platform import gfile +def _TestDir(test_name): + test_dir = os.path.join(tf.test.get_temp_dir(), test_name) + if os.path.exists(test_dir): + shutil.rmtree(test_dir) + gfile.MakeDirs(test_dir) + return test_dir + + class SaverTest(tf.test.TestCase): def testBasics(self): @@ -349,12 +357,7 @@ class SaveRestoreShardedTest(tf.test.TestCase): class MaxToKeepTest(tf.test.TestCase): def testNonSharded(self): - save_dir = os.path.join(self.get_temp_dir(), "max_to_keep_non_sharded") - try: - gfile.DeleteRecursively(save_dir) - except OSError: - pass # Ignore - gfile.MakeDirs(save_dir) + save_dir = _TestDir("max_to_keep_non_sharded") with self.test_session() as sess: v = tf.Variable(10.0, name="v") @@ -456,12 +459,7 @@ class MaxToKeepTest(tf.test.TestCase): self.assertTrue(gfile.Exists(save._MetaGraphFilename(s1))) def testSharded(self): - save_dir = os.path.join(self.get_temp_dir(), "max_to_keep_sharded") - try: - gfile.DeleteRecursively(save_dir) - except OSError: - pass # Ignore - gfile.MakeDirs(save_dir) + save_dir = _TestDir("max_to_keep_sharded") with tf.Session( target="", @@ -495,17 +493,39 @@ class MaxToKeepTest(tf.test.TestCase): self.assertEqual(2, len(gfile.Glob(s3))) self.assertTrue(gfile.Exists(save._MetaGraphFilename(s3))) + def testNoMaxToKeep(self): + save_dir = _TestDir("no_max_to_keep") + save_dir2 = _TestDir("max_to_keep_0") + + with self.test_session() as sess: + v = tf.Variable(10.0, name="v") + tf.initialize_all_variables().run() + + # Test max_to_keep being None. + save = tf.train.Saver({"v": v}, max_to_keep=None) + self.assertEqual([], save.last_checkpoints) + s1 = save.save(sess, os.path.join(save_dir, "s1")) + self.assertEqual([], save.last_checkpoints) + self.assertTrue(gfile.Exists(s1)) + s2 = save.save(sess, os.path.join(save_dir, "s2")) + self.assertEqual([], save.last_checkpoints) + self.assertTrue(gfile.Exists(s2)) + + # Test max_to_keep being 0. + save2 = tf.train.Saver({"v": v}, max_to_keep=0) + self.assertEqual([], save2.last_checkpoints) + s1 = save2.save(sess, os.path.join(save_dir2, "s1")) + self.assertEqual([], save2.last_checkpoints) + self.assertTrue(gfile.Exists(s1)) + s2 = save2.save(sess, os.path.join(save_dir2, "s2")) + self.assertEqual([], save2.last_checkpoints) + self.assertTrue(gfile.Exists(s2)) + class KeepCheckpointEveryNHoursTest(tf.test.TestCase): def testNonSharded(self): - save_dir = os.path.join(self.get_temp_dir(), - "keep_checkpoint_every_n_hours") - try: - gfile.DeleteRecursively(save_dir) - except OSError: - pass # Ignore - gfile.MakeDirs(save_dir) + save_dir = _TestDir("keep_checkpoint_every_n_hours") with self.test_session() as sess: v = tf.Variable([10.0], name="v") @@ -685,15 +705,8 @@ class LatestCheckpointWithRelativePaths(tf.test.TestCase): class CheckpointStateTest(tf.test.TestCase): - def _TestDir(self, test_name): - test_dir = os.path.join(self.get_temp_dir(), test_name) - if os.path.exists(test_dir): - shutil.rmtree(test_dir) - gfile.MakeDirs(test_dir) - return test_dir - def testAbsPath(self): - save_dir = self._TestDir("abs_paths") + save_dir = _TestDir("abs_paths") abs_path = os.path.join(save_dir, "model-0") ckpt = tf.train.generate_checkpoint_state_proto(save_dir, abs_path) self.assertEqual(ckpt.model_checkpoint_path, abs_path) @@ -712,7 +725,7 @@ class CheckpointStateTest(tf.test.TestCase): self.assertEqual(ckpt.all_model_checkpoint_paths[-1], new_rel_path) def testAllModelCheckpointPaths(self): - save_dir = self._TestDir("all_models_test") + save_dir = _TestDir("all_models_test") abs_path = os.path.join(save_dir, "model-0") for paths in [None, [], ["model-2"]]: ckpt = tf.train.generate_checkpoint_state_proto( @@ -726,7 +739,7 @@ class CheckpointStateTest(tf.test.TestCase): self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path) def testUpdateCheckpointState(self): - save_dir = self._TestDir("update_checkpoint_state") + save_dir = _TestDir("update_checkpoint_state") os.chdir(save_dir) # Make a temporary train directory. train_dir = "train" @@ -746,15 +759,8 @@ class CheckpointStateTest(tf.test.TestCase): class MetaGraphTest(tf.test.TestCase): - def _TestDir(self, test_name): - test_dir = os.path.join(self.get_temp_dir(), test_name) - if os.path.exists(test_dir): - shutil.rmtree(test_dir) - gfile.MakeDirs(test_dir) - return test_dir - def testAddCollectionDef(self): - test_dir = self._TestDir("good_collection") + test_dir = _TestDir("good_collection") filename = os.path.join(test_dir, "metafile") with self.test_session(): # Creates a graph. @@ -819,7 +825,7 @@ class MetaGraphTest(tf.test.TestCase): self.assertEqual(len(meta_graph_def.collection_def), 0) def _testMultiSaverCollectionSave(self): - test_dir = self._TestDir("saver_collection") + test_dir = _TestDir("saver_collection") filename = os.path.join(test_dir, "metafile") saver0_ckpt = os.path.join(test_dir, "saver0.ckpt") saver1_ckpt = os.path.join(test_dir, "saver1.ckpt") @@ -894,7 +900,7 @@ class MetaGraphTest(tf.test.TestCase): self._testMultiSaverCollectionRestore() def testBinaryAndTextFormat(self): - test_dir = self._TestDir("binary_and_text") + test_dir = _TestDir("binary_and_text") filename = os.path.join(test_dir, "metafile") with self.test_session(graph=tf.Graph()): # Creates a graph. @@ -924,7 +930,7 @@ class MetaGraphTest(tf.test.TestCase): tf.train.import_meta_graph(filename) def testSliceVariable(self): - test_dir = self._TestDir("slice_saver") + test_dir = _TestDir("slice_saver") filename = os.path.join(test_dir, "metafile") with self.test_session(): v1 = tf.Variable([20.0], name="v1") @@ -946,7 +952,7 @@ class MetaGraphTest(tf.test.TestCase): self.assertProtoEquals(meta_graph_def, new_meta_graph_def) def _testGraphExtensionSave(self): - test_dir = self._TestDir("graph_extension") + test_dir = _TestDir("graph_extension") filename = os.path.join(test_dir, "metafile") saver0_ckpt = os.path.join(test_dir, "saver0.ckpt") with self.test_session(graph=tf.Graph()) as sess: