Fix GitHub Issue #43789: file_io.delete_recursively_v2 is not compatible when called on files on cloud storage.

PiperOrigin-RevId: 336951308
Change-Id: Ieb43a96b1f6c6fc481785cd4c60b1a5c31cb5c1c
A. Unique TensorFlower authored 2020-10-13 14:05:57 -07:00, committed by TensorFlower Gardener
parent 0f53c3fd79
commit 72028307fd
2 changed files with 15 additions and 12 deletions
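For background (not part of the commit itself), the cleanup pattern this change adopts can be sketched as follows: rather than calling file_io.delete_recursively_v2 on a backup subdirectory, which can fail on cloud object stores where a directory may not exist as a real filesystem entry, the backup checkpoint files are located by prefix with file_io.get_matching_files_v2 and deleted individually. The helper name and arguments below are hypothetical; only the file_io calls are real TensorFlow APIs.

import os
from tensorflow.python.lib.io import file_io

def delete_checkpoint_files(checkpoint_prefix, backup_dir):
  # Delete every file that shares the checkpoint prefix, e.g.
  # <prefix>.index, <prefix>.data-00000-of-00001, ...
  for pathname in file_io.get_matching_files_v2(checkpoint_prefix + '*'):
    file_io.delete_recursively_v2(pathname)
  # Also delete the CheckpointManager state file, if present.
  for pathname in file_io.get_matching_files_v2(
      os.path.join(backup_dir, 'checkpoint')):
    file_io.delete_recursively_v2(pathname)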


@@ -205,7 +205,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
           raise

       multi_process_runner.get_barrier().wait()
-      backup_filepath = os.path.join(bar_dir, 'chief', 'checkpoint')
+      backup_filepath = os.path.join(bar_dir, 'checkpoint')
       test_obj.assertTrue(file_io.file_exists_v2(backup_filepath))
       test_obj.assertTrue(file_io.file_exists_v2(saving_filepath))


@@ -73,17 +73,15 @@ class WorkerTrainingState(object):
     # workers need to perform `save()`.
     # But all workers should restore from the same checkpoint_dir as passed in
     # read_checkpoint_manager.
-    self.read_checkpoint_manager = checkpoint_management.CheckpointManager(
-        checkpoint,
-        directory=os.path.join(checkpoint_dir, 'chief'),
-        max_to_keep=1)
-    write_checkpoint_dir = distributed_file_utils.write_dirpath(
+    self.write_checkpoint_dir = distributed_file_utils.write_dirpath(
         checkpoint_dir, self._model.distribute_strategy)
-    if write_checkpoint_dir == checkpoint_dir:
-      self.write_checkpoint_manager = self.read_checkpoint_manager
+    self.write_checkpoint_manager = checkpoint_management.CheckpointManager(
+        checkpoint, directory=self.write_checkpoint_dir, max_to_keep=1)
+    if self.write_checkpoint_dir == checkpoint_dir:
+      self.read_checkpoint_manager = self.write_checkpoint_manager
     else:
-      self.write_checkpoint_manager = checkpoint_management.CheckpointManager(
-          checkpoint, directory=write_checkpoint_dir, max_to_keep=1)
+      self.read_checkpoint_manager = checkpoint_management.CheckpointManager(
+          checkpoint, directory=checkpoint_dir, max_to_keep=1)

   def back_up(self, epoch):
     """Back up the current state of training into a checkpoint file.
@@ -113,8 +111,13 @@ class WorkerTrainingState(object):
     Delete the backup directories which should not exist after `fit()`
     successfully finishes.
     """
-    if self.write_checkpoint_manager is self.read_checkpoint_manager:
-      file_io.delete_recursively_v2(self.write_checkpoint_manager.directory)
+    # pylint: disable=protected-access
+    for pathname in file_io.get_matching_files_v2(
+        self.write_checkpoint_manager._prefix + '*'):
+      file_io.delete_recursively_v2(pathname)
+    for pathname in file_io.get_matching_files_v2(
+        os.path.join(self.write_checkpoint_manager.directory, 'checkpoint')):
+      file_io.delete_recursively_v2(pathname)

   def maybe_load_initial_epoch_from_ckpt(self, initial_epoch, mode):
     """Maybe load initial epoch from ckpt considering possible worker recovery.