Use distributed_file_utils in the ModelCheckpoint callback so the temporary file is saved in a subdirectory of the user-provided filepath.
PiperOrigin-RevId: 307471655 Change-Id: I107cc4f1278aa60718cd476f6cd813246bf35c05
This commit is contained in:
parent
7ebbab819e
commit
63f2383118
@ -27,7 +27,6 @@ import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
@ -1303,36 +1302,24 @@ class ModelCheckpoint(Callback):
|
||||
def _get_file_path(self, epoch, logs):
  """Returns the file path this worker should write the checkpoint to.

  `self.filepath` is formatted with the epoch number and the logged metrics,
  and the result is passed to `distributed_file_utils.write_filepath` along
  with the model's distribution strategy. The chief gets the formatted path
  back; a non-chief worker gets a temporary path located under the
  user-provided filepath. The resolved path is stored in
  `self._write_filepath` so that `_maybe_remove_file()` can clean it up.

  Args:
    epoch: Epoch index; the `{epoch}` placeholder receives `epoch + 1`.
    logs: Dict of metric values used to fill the remaining placeholders.

  Returns:
    The path the checkpoint should be written to on this worker.

  Raises:
    KeyError: If `self.filepath` has a placeholder with no matching key in
      `logs`.
  """
  try:
    # `filepath` may contain placeholders such as `{epoch:02d}` and
    # `{mape:.2f}`. A mismatch between logged metrics and the path's
    # placeholders can cause formatting to fail.
    file_path = self.filepath.format(epoch=epoch + 1, **logs)
  except KeyError as e:
    raise KeyError('Failed to format this callback filepath: "{}". '
                   'Reason: {}'.format(self.filepath, e))
  # Every worker has to take part in saving so that variables synced on read
  # can be read, so each worker needs a path of its own to write to.
  self._write_filepath = distributed_file_utils.write_filepath(
      file_path, self.model.distribute_strategy)
  return self._write_filepath
|
||||
|
||||
def _maybe_remove_file(self):
  """Removes the temporary checkpoint location written by this worker.

  Hands the path stored by `_get_file_path()` in `self._write_filepath`,
  together with the model's distribution strategy, to
  `distributed_file_utils.remove_temp_dir_with_filepath`. That helper, not
  this method, decides whether anything is deleted.
  """
  # Remove the checkpoint directory in multi-worker training where this worker
  # should not checkpoint. It is a dummy directory previously saved for sync
  # distributed training.
  distributed_file_utils.remove_temp_dir_with_filepath(
      self._write_filepath, self.model.distribute_strategy)
|
||||
|
||||
def _get_most_recently_modified_file_matching_pattern(self, pattern):
|
||||
"""Returns the most recently modified filepath matching pattern.
|
||||
|
@ -23,6 +23,7 @@ from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.distribute import collective_all_reduce_strategy as collective_strategy
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import distributed_file_utils
|
||||
from tensorflow.python.distribute import multi_process_runner
|
||||
from tensorflow.python.distribute import multi_worker_test_base as test_base
|
||||
from tensorflow.python.keras import callbacks
|
||||
@ -106,6 +107,16 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
|
||||
training_state.checkpoint_exists(saving_filepath),
|
||||
test_base.is_chief())
|
||||
|
||||
# If it's chief, the model should be saved (`write_filepath` should
|
||||
# simply return `saving_filepath`); if not, i.e. for non-chief workers,
|
||||
# the temporary path generated by `write_filepath` should no longer
|
||||
# contain the checkpoint that has been deleted.
|
||||
test_obj.assertEqual(
|
||||
training_state.checkpoint_exists(
|
||||
distributed_file_utils.write_filepath(
|
||||
saving_filepath, model._distribution_strategy)),
|
||||
test_base.is_chief())
|
||||
|
||||
multi_process_runner.run(
|
||||
proc_model_checkpoint_saves_on_chief_but_not_otherwise,
|
||||
cluster_spec=test_base.create_cluster_spec(num_workers=2),
|
||||
|
Loading…
Reference in New Issue
Block a user