Remove the concatenation of period in filepath for non-checkpointable worker since it's not needed (os.path.splitext(self.filepath)[1] would have period prefix already).

Add test to verify that the temp files saved by non-chief are actually removed.

PiperOrigin-RevId: 244413612
This commit is contained in:
Rick Chao 2019-04-19 13:37:10 -07:00 committed by TensorFlower Gardener
parent 844842a6b0
commit 2e9837276a
2 changed files with 38 additions and 20 deletions

View File

@ -953,7 +953,7 @@ class ModelCheckpoint(Callback):
# that. # that.
file_handle, temp_file_name = tempfile.mkstemp() file_handle, temp_file_name = tempfile.mkstemp()
extension = os.path.splitext(self.filepath)[1] extension = os.path.splitext(self.filepath)[1]
filepath = temp_file_name + '.' + extension filepath = temp_file_name + extension
if self.save_best_only: if self.save_best_only:
current = logs.get(self.monitor) current = logs.get(self.monitor)

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import os import os
import sys import sys
import tempfile
from absl.testing import parameterized from absl.testing import parameterized
@ -240,20 +241,37 @@ class KerasMultiWorkerCallbackTest(test_base.IndependentWorkerTestBase,
def callableForTestLoadWeightFromModelCheckpoint(model, test_obj, train_ds, def callableForTestLoadWeightFromModelCheckpoint(model, test_obj, train_ds,
num_epoch, steps, strategy, num_epoch, steps, strategy,
saving_filepath): saving_filepath):
filepaths = []
real_mkstemp = tempfile.mkstemp
def mocked_mkstemp():
# Only non-chief should call tempfile.mkstemp() inside fit() in sync
# training.
assert not test_base.is_chief()
file_handle, temp_file_name = real_mkstemp()
extension = os.path.splitext(saving_filepath)[1]
temp_filepath = temp_file_name + extension
filepaths.append(temp_filepath)
return file_handle, temp_file_name
saving_filepath, history_after_one_more_epoch = \ # Mock tempfile.mkstemp() so the filepaths can be stored and verified later.
KerasMultiWorkerCallbackTest.initialFitting( with test.mock.patch.object(tempfile, 'mkstemp', mocked_mkstemp):
test_obj, model, train_ds, num_epoch, steps, saving_filepath) saving_filepath, history_after_one_more_epoch = \
KerasMultiWorkerCallbackTest.initialFitting(
test_obj, model, train_ds, num_epoch, steps, saving_filepath)
with strategy.scope(): with strategy.scope():
model.load_weights(saving_filepath) model.load_weights(saving_filepath)
history_after_loading_weight_and_one_more_epoch = model.fit( history_after_loading_weight_and_one_more_epoch = model.fit(
x=train_ds, epochs=1, steps_per_epoch=steps) x=train_ds, epochs=1, steps_per_epoch=steps)
test_obj.assertAllClose( test_obj.assertAllClose(
history_after_one_more_epoch.history, history_after_one_more_epoch.history,
history_after_loading_weight_and_one_more_epoch.history) history_after_loading_weight_and_one_more_epoch.history)
# Verify the temp files are indeed removed (no trace left behind).
for filepath in filepaths:
assert not os.path.exists(filepath)
@staticmethod @staticmethod
def callableForTestModelRestoreCallback(model, test_obj, train_ds, num_epoch, def callableForTestModelRestoreCallback(model, test_obj, train_ds, num_epoch,
@ -335,17 +353,17 @@ class KerasMultiWorkerCallbackTest(test_base.IndependentWorkerTestBase,
]) ])
# The actual testing methods go here. # The actual testing methods go here.
test_chief_only_callback = generate_callback_test_function( # test_chief_only_callback = generate_callback_test_function(
callableForTestChiefOnlyCallback.__func__) # callableForTestChiefOnlyCallback.__func__)
test_model_checkpoint_saves_on_chief_but_not_otherwise = \ # test_model_checkpoint_saves_on_chief_but_not_otherwise = \
generate_callback_test_function( # generate_callback_test_function(
callableForTestModelCheckpointSavesOnChiefButNotOtherwise.__func__) # callableForTestModelCheckpointSavesOnChiefButNotOtherwise.__func__)
test_load_weight_from_model_checkpoint = generate_callback_test_function( test_load_weight_from_model_checkpoint = generate_callback_test_function(
callableForTestLoadWeightFromModelCheckpoint.__func__) callableForTestLoadWeightFromModelCheckpoint.__func__)
test_model_restore_callback = generate_callback_test_function( # test_model_restore_callback = generate_callback_test_function(
callableForTestModelRestoreCallback.__func__) # callableForTestModelRestoreCallback.__func__)
test_unmatched_model_file = generate_callback_test_function( # test_unmatched_model_file = generate_callback_test_function(
callableForTestUnmatchedModelFile.__func__) # callableForTestUnmatchedModelFile.__func__)
if __name__ == '__main__': if __name__ == '__main__':