Make callback string formatting failure error message more user-friendly.
PiperOrigin-RevId: 284224536 Change-Id: I57b8c2ac506deaaa45365ddb80f564b787abb23f
This commit is contained in:
parent
29021b9e66
commit
a5ac4c72dd
@ -1052,7 +1052,14 @@ class ModelCheckpoint(Callback):
|
||||
# pylint: disable=protected-access
|
||||
if not self.model._in_multi_worker_mode(
|
||||
) or multi_worker_util.should_save_checkpoint():
|
||||
return self.filepath.format(epoch=epoch + 1, **logs)
|
||||
try:
|
||||
# `filepath` may contain placeholders such as `{epoch:02d}` and
|
||||
# `{mape:.2f}`. A mismatch between logged metrics and the path's
|
||||
# placeholders can cause formatting to fail.
|
||||
return self.filepath.format(epoch=epoch + 1, **logs)
|
||||
except KeyError as e:
|
||||
raise KeyError('Failed to format this callback filepath: "{}". '
|
||||
'Reason: {}'.format(self.filepath, e))
|
||||
else:
|
||||
# If this is multi-worker training, and this worker should not
|
||||
# save checkpoint, we use a temp filepath to store a dummy checkpoint, so
|
||||
|
@ -828,6 +828,18 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
|
||||
'filepath for ModelCheckpoint.'):
|
||||
model.fit(train_ds, epochs=1, callbacks=[callback])
|
||||
|
||||
def test_ModelCheckpoint_with_bad_path_placeholders(self):
|
||||
(model, train_ds, callback,
|
||||
filepath) = self._get_dummy_resource_for_model_checkpoint_testing()
|
||||
|
||||
temp_dir = self.get_temp_dir()
|
||||
filepath = os.path.join(temp_dir, 'chkpt_{epoch:02d}_{mape:.2f}.h5')
|
||||
callback = keras.callbacks.ModelCheckpoint(filepath=filepath)
|
||||
|
||||
with self.assertRaisesRegexp(KeyError, 'Failed to format this callback '
|
||||
'filepath.*'):
|
||||
model.fit(train_ds, epochs=1, callbacks=[callback])
|
||||
|
||||
def test_EarlyStopping(self):
|
||||
with self.cached_session():
|
||||
np.random.seed(123)
|
||||
|
Loading…
Reference in New Issue
Block a user