Fix GitHub issue #43789: file_io.delete_recursively_v2 does not work when called on files stored on cloud storage.
PiperOrigin-RevId: 336951308
Change-Id: Ieb43a96b1f6c6fc481785cd4c60b1a5c31cb5c1c
commit 72028307fd
parent 0f53c3fd79
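In essence, the change below stops calling file_io.delete_recursively_v2 on the backup directory itself and instead globs the individual checkpoint files and deletes each match. A minimal sketch of that pattern, assuming an illustrative GCS-style prefix that is not taken from the commit:

from tensorflow.python.lib.io import file_io

# Illustrative backup prefix; a local path would behave the same way.
backup_prefix = 'gs://my-bucket/backup/ckpt'

# Old approach (sketch): a single recursive delete of the backup directory,
# which the linked issue reports as failing for backups on cloud storage.
#   file_io.delete_recursively_v2('gs://my-bucket/backup')

# New approach (sketch): enumerate the checkpoint files first, then delete
# each matched path individually.
for pathname in file_io.get_matching_files_v2(backup_prefix + '*'):
  file_io.delete_recursively_v2(pathname)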
@@ -205,7 +205,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
           raise
 
       multi_process_runner.get_barrier().wait()
-      backup_filepath = os.path.join(bar_dir, 'chief', 'checkpoint')
+      backup_filepath = os.path.join(bar_dir, 'checkpoint')
       test_obj.assertTrue(file_io.file_exists_v2(backup_filepath))
       test_obj.assertTrue(file_io.file_exists_v2(saving_filepath))
 
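The test now expects the checkpoint state file directly under the backup directory rather than under a 'chief' subdirectory. A minimal standalone version of that check, where bar_dir stands in for the test's backup directory and the concrete path is illustrative:

import os

from tensorflow.python.lib.io import file_io

# bar_dir stands in for the backup directory used by the test; the path
# here is illustrative only.
bar_dir = '/tmp/backup'

# After the backed-up run, the CheckpointManager state file should sit
# directly under the backup directory, not under a 'chief' subdirectory.
backup_filepath = os.path.join(bar_dir, 'checkpoint')
assert file_io.file_exists_v2(backup_filepath)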
@@ -73,17 +73,15 @@ class WorkerTrainingState(object):
     # workers need to perform `save()`.
     # But all workers should restore from the same checkpoint_dir as passed in
     # read_checkpoint_manager.
-    self.read_checkpoint_manager = checkpoint_management.CheckpointManager(
-        checkpoint,
-        directory=os.path.join(checkpoint_dir, 'chief'),
-        max_to_keep=1)
-    write_checkpoint_dir = distributed_file_utils.write_dirpath(
+    self.write_checkpoint_dir = distributed_file_utils.write_dirpath(
         checkpoint_dir, self._model.distribute_strategy)
-    if write_checkpoint_dir == checkpoint_dir:
-      self.write_checkpoint_manager = self.read_checkpoint_manager
+    self.write_checkpoint_manager = checkpoint_management.CheckpointManager(
+        checkpoint, directory=self.write_checkpoint_dir, max_to_keep=1)
+    if self.write_checkpoint_dir == checkpoint_dir:
+      self.read_checkpoint_manager = self.write_checkpoint_manager
     else:
-      self.write_checkpoint_manager = checkpoint_management.CheckpointManager(
-          checkpoint, directory=write_checkpoint_dir, max_to_keep=1)
+      self.read_checkpoint_manager = checkpoint_management.CheckpointManager(
+          checkpoint, directory=checkpoint_dir, max_to_keep=1)
 
   def back_up(self, epoch):
     """Back up the current state of training into a checkpoint file.
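Restated outside the diff, the constructor now always builds the write-side manager and only aliases the read-side manager when the write directory equals the checkpoint_dir that was passed in. A condensed sketch of that wiring as a free function, with checkpoint, checkpoint_dir, and strategy supplied by the caller:

from tensorflow.python.keras.distribute import distributed_file_utils
from tensorflow.python.training import checkpoint_management


def build_checkpoint_managers(checkpoint, checkpoint_dir, strategy):
  """Sketch of the new manager wiring; not the actual method."""
  # Every worker writes, but possibly into a worker-specific directory
  # chosen by distributed_file_utils.write_dirpath().
  write_checkpoint_dir = distributed_file_utils.write_dirpath(
      checkpoint_dir, strategy)
  write_manager = checkpoint_management.CheckpointManager(
      checkpoint, directory=write_checkpoint_dir, max_to_keep=1)
  if write_checkpoint_dir == checkpoint_dir:
    # Single-worker or chief case: reading and writing share one manager.
    read_manager = write_manager
  else:
    # Other workers still restore from the common checkpoint_dir.
    read_manager = checkpoint_management.CheckpointManager(
        checkpoint, directory=checkpoint_dir, max_to_keep=1)
  return write_manager, read_manager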
@@ -113,8 +111,13 @@ class WorkerTrainingState(object):
     Delete the backup directories which should not exist after `fit()`
     successfully finishes.
     """
-    if self.write_checkpoint_manager is self.read_checkpoint_manager:
-      file_io.delete_recursively_v2(self.write_checkpoint_manager.directory)
+    # pylint: disable=protected-access
+    for pathname in file_io.get_matching_files_v2(
+        self.write_checkpoint_manager._prefix + '*'):
+      file_io.delete_recursively_v2(pathname)
+    for pathname in file_io.get_matching_files_v2(
+        os.path.join(self.write_checkpoint_manager.directory, 'checkpoint')):
+      file_io.delete_recursively_v2(pathname)
 
   def maybe_load_initial_epoch_from_ckpt(self, initial_epoch, mode):
     """Maybe load initial epoch from ckpt considering possible worker recovery.
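For context, with CheckpointManager's default 'ckpt' checkpoint name the two globs in the new delete_backup typically match the checkpoint shards and the manager's state file. The listing below is purely illustrative (bucket, directory, and step number are made up):

from tensorflow.python.lib.io import file_io

# Hypothetical contents of the write directory after back_up():
#
#   gs://my-bucket/backup/checkpoint                    <- manager state file
#   gs://my-bucket/backup/ckpt-1.index                  <- matched by _prefix + '*'
#   gs://my-bucket/backup/ckpt-1.data-00000-of-00001    <- matched by _prefix + '*'
#
# Deleting these objects one by one avoids the recursive delete of the
# directory itself that the linked issue reports as failing on cloud storage.
print(file_io.get_matching_files_v2('gs://my-bucket/backup/ckpt*'))
print(file_io.get_matching_files_v2('gs://my-bucket/backup/checkpoint'))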