Return a more intuitive error when trying to fit a model with an empty dataset.
PiperOrigin-RevId: 318721653 Change-Id: I8566cd07909bf6b08ac48061f040cd2295a58b3b
This commit is contained in:
parent
bd006c354f
commit
1fb8f4988d
@ -1023,7 +1023,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||
2. If `model.fit` is wrapped in `tf.function`.
|
||||
|
||||
ValueError: In case of mismatch between the provided input data
|
||||
and what the model expects.
|
||||
and what the model expects or when the input data is empty.
|
||||
"""
|
||||
_keras_api_gauge.get_cell('fit').set(True)
|
||||
# Legacy graph support is contained in `training_v1.Model`.
|
||||
@ -1083,6 +1083,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||
# happen after `callbacks.on_train_begin`.
|
||||
data_handler._initial_epoch = ( # pylint: disable=protected-access
|
||||
self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
|
||||
logs = None
|
||||
for epoch, iterator in data_handler.enumerate_epochs():
|
||||
self.reset_metrics()
|
||||
callbacks.on_epoch_begin(epoch)
|
||||
@ -1101,6 +1102,9 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||
logs = tmp_logs # No error, now safe to assign to logs.
|
||||
end_step = step + data_handler.step_increment
|
||||
callbacks.on_train_batch_end(end_step, logs)
|
||||
|
||||
if logs is None:
|
||||
raise ValueError('Expect x to be a non-empty array or dataset.')
|
||||
epoch_logs = copy.copy(logs)
|
||||
|
||||
# Run validation.
|
||||
|
@ -90,6 +90,14 @@ class TrainingTest(keras_parameterized.TestCase):
|
||||
hist = model.fit(x=np.array([0.]), y=np.array([0.]))
|
||||
self.assertAllClose(hist.history['loss'][0], 10000)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||
def test_fit_on_empty(self):
|
||||
model = sequential.Sequential([layers_module.Dense(1)])
|
||||
model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'Expect x to be a non-empty array or dataset.'):
|
||||
model.fit(x=np.array([]), y=np.array([]))
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
def test_run_eagerly_setting(self):
|
||||
model = sequential.Sequential([layers_module.Dense(1)])
|
||||
|
Loading…
Reference in New Issue
Block a user