From 72028307fdd8b00559ed631a409c9237ff0c24b8 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 13 Oct 2020 14:05:57 -0700
Subject: [PATCH] Fix GitHub Issue #43789: file_io.delete_recursively_v2 is not
 compatible with files on cloud storage.

PiperOrigin-RevId: 336951308
Change-Id: Ieb43a96b1f6c6fc481785cd4c60b1a5c31cb5c1c
---
 .../multi_worker_callback_tf2_test.py         |  2 +-
 .../keras/distribute/worker_training_state.py | 25 +++++++++++--------
 2 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py b/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py
index d051eb5932a..ea4b349b1cf 100644
--- a/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py
+++ b/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py
@@ -205,7 +205,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
         raise

       multi_process_runner.get_barrier().wait()
-      backup_filepath = os.path.join(bar_dir, 'chief', 'checkpoint')
+      backup_filepath = os.path.join(bar_dir, 'checkpoint')
       test_obj.assertTrue(file_io.file_exists_v2(backup_filepath))
       test_obj.assertTrue(file_io.file_exists_v2(saving_filepath))

diff --git a/tensorflow/python/keras/distribute/worker_training_state.py b/tensorflow/python/keras/distribute/worker_training_state.py
index a1e76283f36..6385594e0c0 100644
--- a/tensorflow/python/keras/distribute/worker_training_state.py
+++ b/tensorflow/python/keras/distribute/worker_training_state.py
@@ -73,17 +73,15 @@ class WorkerTrainingState(object):
     # workers need to perform `save()`.
     # But all workers should restore from the same checkpoint_dir as passed in
     # read_checkpoint_manager.
-    self.read_checkpoint_manager = checkpoint_management.CheckpointManager(
-        checkpoint,
-        directory=os.path.join(checkpoint_dir, 'chief'),
-        max_to_keep=1)
-    write_checkpoint_dir = distributed_file_utils.write_dirpath(
+    self.write_checkpoint_dir = distributed_file_utils.write_dirpath(
         checkpoint_dir, self._model.distribute_strategy)
-    if write_checkpoint_dir == checkpoint_dir:
-      self.write_checkpoint_manager = self.read_checkpoint_manager
+    self.write_checkpoint_manager = checkpoint_management.CheckpointManager(
+        checkpoint, directory=self.write_checkpoint_dir, max_to_keep=1)
+    if self.write_checkpoint_dir == checkpoint_dir:
+      self.read_checkpoint_manager = self.write_checkpoint_manager
     else:
-      self.write_checkpoint_manager = checkpoint_management.CheckpointManager(
-          checkpoint, directory=write_checkpoint_dir, max_to_keep=1)
+      self.read_checkpoint_manager = checkpoint_management.CheckpointManager(
+          checkpoint, directory=checkpoint_dir, max_to_keep=1)

   def back_up(self, epoch):
     """Back up the current state of training into a checkpoint file.
@@ -113,8 +111,13 @@ class WorkerTrainingState(object):
     Delete the backup directories which should not exist after `fit()`
     successfully finishes.
     """
-    if self.write_checkpoint_manager is self.read_checkpoint_manager:
-      file_io.delete_recursively_v2(self.write_checkpoint_manager.directory)
+    # pylint: disable=protected-access
+    for pathname in file_io.get_matching_files_v2(
+        self.write_checkpoint_manager._prefix + '*'):
+      file_io.delete_recursively_v2(pathname)
+    for pathname in file_io.get_matching_files_v2(
+        os.path.join(self.write_checkpoint_manager.directory, 'checkpoint')):
+      file_io.delete_recursively_v2(pathname)

   def maybe_load_initial_epoch_from_ckpt(self, initial_epoch, mode):
     """Maybe load initial epoch from ckpt considering possible worker recovery.