Make callback string formatting failure error message more user-friendly.

PiperOrigin-RevId: 284224536
Change-Id: I57b8c2ac506deaaa45365ddb80f564b787abb23f
This commit is contained in:
A. Unique TensorFlower 2019-12-06 11:21:35 -08:00 committed by TensorFlower Gardener
parent 29021b9e66
commit a5ac4c72dd
2 changed files with 20 additions and 1 deletions

View File

@ -1052,7 +1052,14 @@ class ModelCheckpoint(Callback):
# pylint: disable=protected-access
if not self.model._in_multi_worker_mode(
) or multi_worker_util.should_save_checkpoint():
return self.filepath.format(epoch=epoch + 1, **logs)
try:
# `filepath` may contain placeholders such as `{epoch:02d}` and
# `{mape:.2f}`. A mismatch between logged metrics and the path's
# placeholders can cause formatting to fail.
return self.filepath.format(epoch=epoch + 1, **logs)
except KeyError as e:
raise KeyError('Failed to format this callback filepath: "{}". '
'Reason: {}'.format(self.filepath, e))
else:
# If this is multi-worker training, and this worker should not
# save checkpoint, we use a temp filepath to store a dummy checkpoint, so

View File

@ -828,6 +828,18 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
'filepath for ModelCheckpoint.'):
model.fit(train_ds, epochs=1, callbacks=[callback])
def test_ModelCheckpoint_with_bad_path_placeholders(self):
(model, train_ds, callback,
filepath) = self._get_dummy_resource_for_model_checkpoint_testing()
temp_dir = self.get_temp_dir()
filepath = os.path.join(temp_dir, 'chkpt_{epoch:02d}_{mape:.2f}.h5')
callback = keras.callbacks.ModelCheckpoint(filepath=filepath)
with self.assertRaisesRegexp(KeyError, 'Failed to format this callback '
'filepath.*'):
model.fit(train_ds, epochs=1, callbacks=[callback])
def test_EarlyStopping(self):
with self.cached_session():
np.random.seed(123)