Use distributed_file_utils in the ModelCheckpoint callback so that the temporary file is saved in a subdirectory of the user-provided filepath.

PiperOrigin-RevId: 307471655
Change-Id: I107cc4f1278aa60718cd476f6cd813246bf35c05
This commit is contained in:
Rick Chao 2020-04-20 14:03:06 -07:00 committed by TensorFlower Gardener
parent 7ebbab819e
commit 63f2383118
2 changed files with 24 additions and 26 deletions

View File

@ -27,7 +27,6 @@ import io
import json import json
import os import os
import re import re
import tempfile
import time import time
import numpy as np import numpy as np
@ -1303,36 +1302,24 @@ class ModelCheckpoint(Callback):
def _get_file_path(self, epoch, logs): def _get_file_path(self, epoch, logs):
"""Returns the file path for checkpoint.""" """Returns the file path for checkpoint."""
# pylint: disable=protected-access # pylint: disable=protected-access
if not self.model._in_multi_worker_mode(
) or multi_worker_util.should_save_checkpoint():
try: try:
# `filepath` may contain placeholders such as `{epoch:02d}` and # `filepath` may contain placeholders such as `{epoch:02d}` and
# `{mape:.2f}`. A mismatch between logged metrics and the path's # `{mape:.2f}`. A mismatch between logged metrics and the path's
# placeholders can cause formatting to fail. # placeholders can cause formatting to fail.
return self.filepath.format(epoch=epoch + 1, **logs) file_path = self.filepath.format(epoch=epoch + 1, **logs)
except KeyError as e: except KeyError as e:
raise KeyError('Failed to format this callback filepath: "{}". ' raise KeyError('Failed to format this callback filepath: "{}". '
'Reason: {}'.format(self.filepath, e)) 'Reason: {}'.format(self.filepath, e))
else: self._write_filepath = distributed_file_utils.write_filepath(
# If this is multi-worker training, and this worker should not file_path, self.model.distribute_strategy)
# save checkpoint, we use a temp filepath to store a dummy checkpoint, so return self._write_filepath
# it writes to a file that will be removed at the end of `_save_model()`
# call. This is because the SyncOnReadVariable needs to be synced across
# all the workers in order to be read, and all workers need to initiate
# that.
self._temp_file_dir = tempfile.mkdtemp()
extension = os.path.splitext(self.filepath)[1]
return os.path.join(self._temp_file_dir, 'temp' + extension)
def _maybe_remove_file(self): def _maybe_remove_file(self):
# Remove the checkpoint directory in multi-worker training where this worker # Remove the checkpoint directory in multi-worker training where this worker
# should not checkpoint. It is a dummy directory previously saved for sync # should not checkpoint. It is a dummy directory previously saved for sync
# distributed training. # distributed training.
distributed_file_utils.remove_temp_dir_with_filepath(
if (self.model._in_multi_worker_mode() and # pylint: disable=protected-access self._write_filepath, self.model.distribute_strategy)
not multi_worker_util.should_save_checkpoint()):
file_io.delete_recursively(self._temp_file_dir)
del self._temp_file_dir
def _get_most_recently_modified_file_matching_pattern(self, pattern): def _get_most_recently_modified_file_matching_pattern(self, pattern):
"""Returns the most recently modified filepath matching pattern. """Returns the most recently modified filepath matching pattern.

View File

@ -23,6 +23,7 @@ from absl.testing import parameterized
from tensorflow.python.distribute import collective_all_reduce_strategy as collective_strategy from tensorflow.python.distribute import collective_all_reduce_strategy as collective_strategy
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distributed_file_utils
from tensorflow.python.distribute import multi_process_runner from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base as test_base from tensorflow.python.distribute import multi_worker_test_base as test_base
from tensorflow.python.keras import callbacks from tensorflow.python.keras import callbacks
@ -106,6 +107,16 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
training_state.checkpoint_exists(saving_filepath), training_state.checkpoint_exists(saving_filepath),
test_base.is_chief()) test_base.is_chief())
# If it's chief, the model should be saved (`write_filepath` should
# simply return `saving_filepath`); if not, i.e. for non-chief workers,
# the temporary path generated by `write_filepath` should no longer
# contain the checkpoint that has been deleted.
test_obj.assertEqual(
training_state.checkpoint_exists(
distributed_file_utils.write_filepath(
saving_filepath, model._distribution_strategy)),
test_base.is_chief())
multi_process_runner.run( multi_process_runner.run(
proc_model_checkpoint_saves_on_chief_but_not_otherwise, proc_model_checkpoint_saves_on_chief_but_not_otherwise,
cluster_spec=test_base.create_cluster_spec(num_workers=2), cluster_spec=test_base.create_cluster_spec(num_workers=2),