diff --git a/tensorflow/python/tpu/async_checkpoint_test.py b/tensorflow/python/tpu/async_checkpoint_test.py index 3506e099e66..ed558be19b1 100644 --- a/tensorflow/python/tpu/async_checkpoint_test.py +++ b/tensorflow/python/tpu/async_checkpoint_test.py @@ -145,6 +145,51 @@ class AsyncCheckpointingTest(test.TestCase): mock_listener.before_save.assert_called() mock_listener.after_save.assert_called() + def testAsyncCheckpointHookWithoutListeners(self): + resolver = tpu_cluster_resolver.TPUClusterResolver( + tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project) + + checkpoint_interval = 5 + keep_checkpoint_max = 10 + config = tpu_config.RunConfig( + master=resolver.master(), + model_dir=os.path.join(FLAGS.model_dir, 'runconfig'), + save_checkpoints_steps=1000, + keep_checkpoint_max=keep_checkpoint_max+1, # off by one + tpu_config=tpu_config.TPUConfig( + iterations_per_loop=checkpoint_interval,)) + + estimator = tpu_estimator.TPUEstimator( + use_tpu=True, + model_fn=model_fn, + config=config, + train_batch_size=32, + eval_batch_size=32, + predict_batch_size=1, + params={}, + ) + + max_steps = 100 + estimator.train( + input_fn=input_fn, + max_steps=max_steps, + hooks=[ + async_checkpoint.AsyncCheckpointSaverHook( + FLAGS.model_dir, + save_steps=checkpoint_interval) + ]) + + current_step = estimator_lib._load_global_step_from_checkpoint_dir( + FLAGS.model_dir) # pylint: disable=protected-access + + # TODO(power) -- identify a better way to count the number of checkpoints. + checkpoints = file_io.get_matching_files( + FLAGS.model_dir + '/model.ckpt*.meta') + checkpoint_count = len(checkpoints) + logging.info('Found %d checkpoints: %s', checkpoint_count, checkpoints) + self.assertLessEqual(checkpoint_count, keep_checkpoint_max) + self.assertEqual(current_step, max_steps) + if __name__ == '__main__': v2_compat.disable_v2_behavior()