diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 1dca7adda78..948468b5cb2 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -27,7 +27,6 @@ import io import json import os import re -import tempfile import time import numpy as np @@ -1303,36 +1302,24 @@ class ModelCheckpoint(Callback): def _get_file_path(self, epoch, logs): """Returns the file path for checkpoint.""" # pylint: disable=protected-access - if not self.model._in_multi_worker_mode( - ) or multi_worker_util.should_save_checkpoint(): - try: - # `filepath` may contain placeholders such as `{epoch:02d}` and - # `{mape:.2f}`. A mismatch between logged metrics and the path's - # placeholders can cause formatting to fail. - return self.filepath.format(epoch=epoch + 1, **logs) - except KeyError as e: - raise KeyError('Failed to format this callback filepath: "{}". ' - 'Reason: {}'.format(self.filepath, e)) - else: - # If this is multi-worker training, and this worker should not - # save checkpoint, we use a temp filepath to store a dummy checkpoint, so - # it writes to a file that will be removed at the end of `_save_model()` - # call. This is because the SyncOnReadVariable needs to be synced across - # all the workers in order to be read, and all workers need to initiate - # that. - self._temp_file_dir = tempfile.mkdtemp() - extension = os.path.splitext(self.filepath)[1] - return os.path.join(self._temp_file_dir, 'temp' + extension) + try: + # `filepath` may contain placeholders such as `{epoch:02d}` and + # `{mape:.2f}`. A mismatch between logged metrics and the path's + # placeholders can cause formatting to fail. + file_path = self.filepath.format(epoch=epoch + 1, **logs) + except KeyError as e: + raise KeyError('Failed to format this callback filepath: "{}". ' + 'Reason: {}'.format(self.filepath, e)) + self._write_filepath = distributed_file_utils.write_filepath( + file_path, self.model.distribute_strategy) + return self._write_filepath def _maybe_remove_file(self): # Remove the checkpoint directory in multi-worker training where this worker # should not checkpoint. It is a dummy directory previously saved for sync # distributed training. - - if (self.model._in_multi_worker_mode() and # pylint: disable=protected-access - not multi_worker_util.should_save_checkpoint()): - file_io.delete_recursively(self._temp_file_dir) - del self._temp_file_dir + distributed_file_utils.remove_temp_dir_with_filepath( + self._write_filepath, self.model.distribute_strategy) def _get_most_recently_modified_file_matching_pattern(self, pattern): """Returns the most recently modified filepath matching pattern. diff --git a/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py b/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py index c99b6db8f4d..7ea385e0b04 100644 --- a/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py +++ b/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py @@ -23,6 +23,7 @@ from absl.testing import parameterized from tensorflow.python.distribute import collective_all_reduce_strategy as collective_strategy from tensorflow.python.distribute import combinations +from tensorflow.python.distribute import distributed_file_utils from tensorflow.python.distribute import multi_process_runner from tensorflow.python.distribute import multi_worker_test_base as test_base from tensorflow.python.keras import callbacks @@ -106,6 +107,16 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase): training_state.checkpoint_exists(saving_filepath), test_base.is_chief()) + # If it's chief, the model should be saved (`write_filepath` should + # simply return `saving_filepath`); if not, i.e. for non-chief workers, + # the temporary path generated by `write_filepath` should no longer + # contain the checkpoint that has been deleted. + test_obj.assertEqual( + training_state.checkpoint_exists( + distributed_file_utils.write_filepath( + saving_filepath, model._distribution_strategy)), + test_base.is_chief()) + multi_process_runner.run( proc_model_checkpoint_saves_on_chief_but_not_otherwise, cluster_spec=test_base.create_cluster_spec(num_workers=2),