Use distributed_file_utils in ModelCheckpoint callback so the temporary file is saved in a subdirectory of user provided filepath.
PiperOrigin-RevId: 307471655 Change-Id: I107cc4f1278aa60718cd476f6cd813246bf35c05
This commit is contained in:
parent
7ebbab819e
commit
63f2383118
|
@ -27,7 +27,6 @@ import io
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import tempfile
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -1303,36 +1302,24 @@ class ModelCheckpoint(Callback):
|
||||||
def _get_file_path(self, epoch, logs):
|
def _get_file_path(self, epoch, logs):
|
||||||
"""Returns the file path for checkpoint."""
|
"""Returns the file path for checkpoint."""
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
if not self.model._in_multi_worker_mode(
|
|
||||||
) or multi_worker_util.should_save_checkpoint():
|
|
||||||
try:
|
try:
|
||||||
# `filepath` may contain placeholders such as `{epoch:02d}` and
|
# `filepath` may contain placeholders such as `{epoch:02d}` and
|
||||||
# `{mape:.2f}`. A mismatch between logged metrics and the path's
|
# `{mape:.2f}`. A mismatch between logged metrics and the path's
|
||||||
# placeholders can cause formatting to fail.
|
# placeholders can cause formatting to fail.
|
||||||
return self.filepath.format(epoch=epoch + 1, **logs)
|
file_path = self.filepath.format(epoch=epoch + 1, **logs)
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
raise KeyError('Failed to format this callback filepath: "{}". '
|
raise KeyError('Failed to format this callback filepath: "{}". '
|
||||||
'Reason: {}'.format(self.filepath, e))
|
'Reason: {}'.format(self.filepath, e))
|
||||||
else:
|
self._write_filepath = distributed_file_utils.write_filepath(
|
||||||
# If this is multi-worker training, and this worker should not
|
file_path, self.model.distribute_strategy)
|
||||||
# save checkpoint, we use a temp filepath to store a dummy checkpoint, so
|
return self._write_filepath
|
||||||
# it writes to a file that will be removed at the end of `_save_model()`
|
|
||||||
# call. This is because the SyncOnReadVariable needs to be synced across
|
|
||||||
# all the workers in order to be read, and all workers need to initiate
|
|
||||||
# that.
|
|
||||||
self._temp_file_dir = tempfile.mkdtemp()
|
|
||||||
extension = os.path.splitext(self.filepath)[1]
|
|
||||||
return os.path.join(self._temp_file_dir, 'temp' + extension)
|
|
||||||
|
|
||||||
def _maybe_remove_file(self):
|
def _maybe_remove_file(self):
|
||||||
# Remove the checkpoint directory in multi-worker training where this worker
|
# Remove the checkpoint directory in multi-worker training where this worker
|
||||||
# should not checkpoint. It is a dummy directory previously saved for sync
|
# should not checkpoint. It is a dummy directory previously saved for sync
|
||||||
# distributed training.
|
# distributed training.
|
||||||
|
distributed_file_utils.remove_temp_dir_with_filepath(
|
||||||
if (self.model._in_multi_worker_mode() and # pylint: disable=protected-access
|
self._write_filepath, self.model.distribute_strategy)
|
||||||
not multi_worker_util.should_save_checkpoint()):
|
|
||||||
file_io.delete_recursively(self._temp_file_dir)
|
|
||||||
del self._temp_file_dir
|
|
||||||
|
|
||||||
def _get_most_recently_modified_file_matching_pattern(self, pattern):
|
def _get_most_recently_modified_file_matching_pattern(self, pattern):
|
||||||
"""Returns the most recently modified filepath matching pattern.
|
"""Returns the most recently modified filepath matching pattern.
|
||||||
|
|
|
@ -23,6 +23,7 @@ from absl.testing import parameterized
|
||||||
|
|
||||||
from tensorflow.python.distribute import collective_all_reduce_strategy as collective_strategy
|
from tensorflow.python.distribute import collective_all_reduce_strategy as collective_strategy
|
||||||
from tensorflow.python.distribute import combinations
|
from tensorflow.python.distribute import combinations
|
||||||
|
from tensorflow.python.distribute import distributed_file_utils
|
||||||
from tensorflow.python.distribute import multi_process_runner
|
from tensorflow.python.distribute import multi_process_runner
|
||||||
from tensorflow.python.distribute import multi_worker_test_base as test_base
|
from tensorflow.python.distribute import multi_worker_test_base as test_base
|
||||||
from tensorflow.python.keras import callbacks
|
from tensorflow.python.keras import callbacks
|
||||||
|
@ -106,6 +107,16 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
|
||||||
training_state.checkpoint_exists(saving_filepath),
|
training_state.checkpoint_exists(saving_filepath),
|
||||||
test_base.is_chief())
|
test_base.is_chief())
|
||||||
|
|
||||||
|
# If it's chief, the model should be saved (`write_filepath` should
|
||||||
|
# simply return `saving_filepath`); if not, i.e. for non-chief workers,
|
||||||
|
# the temporary path generated by `write_filepath` should no longer
|
||||||
|
# contain the checkpoint that has been deleted.
|
||||||
|
test_obj.assertEqual(
|
||||||
|
training_state.checkpoint_exists(
|
||||||
|
distributed_file_utils.write_filepath(
|
||||||
|
saving_filepath, model._distribution_strategy)),
|
||||||
|
test_base.is_chief())
|
||||||
|
|
||||||
multi_process_runner.run(
|
multi_process_runner.run(
|
||||||
proc_model_checkpoint_saves_on_chief_but_not_otherwise,
|
proc_model_checkpoint_saves_on_chief_but_not_otherwise,
|
||||||
cluster_spec=test_base.create_cluster_spec(num_workers=2),
|
cluster_spec=test_base.create_cluster_spec(num_workers=2),
|
||||||
|
|
Loading…
Reference in New Issue