Add test for AsyncCheckpointSaverHook without listeners.
PiperOrigin-RevId: 286242467
Change-Id: I9b023f6e979a18978d3cc10cca2706134bb8611e
This commit is contained in:
parent
112e2b3fc7
commit
324f1adc0e
@ -145,6 +145,51 @@ class AsyncCheckpointingTest(test.TestCase):
|
||||
mock_listener.before_save.assert_called()
|
||||
mock_listener.after_save.assert_called()
|
||||
|
||||
def testAsyncCheckpointHookWithoutListeners(self):
  """Trains with an AsyncCheckpointSaverHook that has no listeners.

  Builds a TPUEstimator, trains it for 100 steps with the async hook saving
  into ``FLAGS.model_dir`` every 5 steps, and then checks that:
    * the number of ``model.ckpt*.meta`` files found in ``FLAGS.model_dir``
      does not exceed ``keep_checkpoint_max``;
    * the global step loaded from ``FLAGS.model_dir`` equals ``max_steps``.
  """
  save_every_n_steps = 5
  max_checkpoints_kept = 10
  total_train_steps = 100

  cluster = tpu_cluster_resolver.TPUClusterResolver(
      tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)

  # The RunConfig points at a 'runconfig' sub-directory, while the async hook
  # below writes directly into FLAGS.model_dir.
  master_address = cluster.master()
  run_config_dir = os.path.join(FLAGS.model_dir, 'runconfig')
  tpu_options = tpu_config.TPUConfig(
      iterations_per_loop=save_every_n_steps,)
  run_config = tpu_config.RunConfig(
      master=master_address,
      model_dir=run_config_dir,
      save_checkpoints_steps=1000,
      keep_checkpoint_max=max_checkpoints_kept + 1,  # off by one
      tpu_config=tpu_options)

  tpu_est = tpu_estimator.TPUEstimator(
      use_tpu=True,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=32,
      eval_batch_size=32,
      predict_batch_size=1,
      params={},
  )

  # No `listeners` argument is passed: this is the case under test.
  async_hook = async_checkpoint.AsyncCheckpointSaverHook(
      FLAGS.model_dir, save_steps=save_every_n_steps)
  tpu_est.train(
      input_fn=input_fn,
      max_steps=total_train_steps,
      hooks=[async_hook])

  # pylint: disable=protected-access
  final_step = estimator_lib._load_global_step_from_checkpoint_dir(
      FLAGS.model_dir)
  # pylint: enable=protected-access

  # TODO(power) -- identify a better way to count the number of checkpoints.
  meta_file_pattern = FLAGS.model_dir + '/model.ckpt*.meta'
  found_checkpoints = file_io.get_matching_files(meta_file_pattern)
  num_found = len(found_checkpoints)
  logging.info('Found %d checkpoints: %s', num_found, found_checkpoints)

  self.assertLessEqual(num_found, max_checkpoints_kept)
  self.assertEqual(final_step, total_train_steps)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
v2_compat.disable_v2_behavior()
|
||||
|
Loading…
Reference in New Issue
Block a user