Add test for AsyncCheckpointSaverHook without listeners.

PiperOrigin-RevId: 286242467
Change-Id: I9b023f6e979a18978d3cc10cca2706134bb8611e
A. Unique TensorFlower 2019-12-18 12:24:46 -08:00 committed by TensorFlower Gardener
parent 112e2b3fc7
commit 324f1adc0e


@@ -145,6 +145,51 @@ class AsyncCheckpointingTest(test.TestCase):
    mock_listener.before_save.assert_called()
    mock_listener.after_save.assert_called()

  def testAsyncCheckpointHookWithoutListeners(self):
    resolver = tpu_cluster_resolver.TPUClusterResolver(
        tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)

    checkpoint_interval = 5
    keep_checkpoint_max = 10
    config = tpu_config.RunConfig(
        master=resolver.master(),
        model_dir=os.path.join(FLAGS.model_dir, 'runconfig'),
        save_checkpoints_steps=1000,
        keep_checkpoint_max=keep_checkpoint_max + 1,  # off by one
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=checkpoint_interval,))

    estimator = tpu_estimator.TPUEstimator(
        use_tpu=True,
        model_fn=model_fn,
        config=config,
        train_batch_size=32,
        eval_batch_size=32,
        predict_batch_size=1,
        params={},
    )

    max_steps = 100
    estimator.train(
        input_fn=input_fn,
        max_steps=max_steps,
        hooks=[
            async_checkpoint.AsyncCheckpointSaverHook(
                FLAGS.model_dir,
                save_steps=checkpoint_interval)
        ])

    current_step = estimator_lib._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)  # pylint: disable=protected-access

    # TODO(power) -- identify a better way to count the number of checkpoints.
    checkpoints = file_io.get_matching_files(
        FLAGS.model_dir + '/model.ckpt*.meta')
    checkpoint_count = len(checkpoints)
    logging.info('Found %d checkpoints: %s', checkpoint_count, checkpoints)

    self.assertLessEqual(checkpoint_count, keep_checkpoint_max)
    self.assertEqual(current_step, max_steps)


if __name__ == '__main__':
  v2_compat.disable_v2_behavior()
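
The construction exercised by the new test is a hook built from only a checkpoint directory and a save interval, leaving the optional listeners argument at its default. A minimal sketch of that usage, reusing the names from the test above (the estimator, input_fn, and step/interval values are the test's own; any other setup is assumed):

    # Hook with no listeners: only the directory and the save cadence are given.
    hook = async_checkpoint.AsyncCheckpointSaverHook(
        FLAGS.model_dir,
        save_steps=checkpoint_interval)

    # Attach it to training; checkpoints are then written asynchronously
    # every `save_steps` global steps, with no before_save/after_save callbacks.
    estimator.train(input_fn=input_fn, max_steps=max_steps, hooks=[hook])

The test then only checks the externally visible effects: the number of model.ckpt*.meta files on disk stays within keep_checkpoint_max, and the global step restored from the checkpoint directory equals max_steps.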