Return a more intuitive error when trying to fit a model with an empty dataset.

PiperOrigin-RevId: 318721653
Change-Id: I8566cd07909bf6b08ac48061f040cd2295a58b3b
This commit is contained in:
Chris Gorgolewski 2020-06-28 13:16:52 -07:00 committed by TensorFlower Gardener
parent bd006c354f
commit 1fb8f4988d
2 changed files with 13 additions and 1 deletions

View File

@ -1023,7 +1023,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
2. If `model.fit` is wrapped in `tf.function`.
ValueError: In case of mismatch between the provided input data
and what the model expects.
and what the model expects or when the input data is empty.
"""
_keras_api_gauge.get_cell('fit').set(True)
# Legacy graph support is contained in `training_v1.Model`.
@ -1083,6 +1083,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
# happen after `callbacks.on_train_begin`.
data_handler._initial_epoch = ( # pylint: disable=protected-access
self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
logs = None
for epoch, iterator in data_handler.enumerate_epochs():
self.reset_metrics()
callbacks.on_epoch_begin(epoch)
@ -1101,6 +1102,9 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
logs = tmp_logs # No error, now safe to assign to logs.
end_step = step + data_handler.step_increment
callbacks.on_train_batch_end(end_step, logs)
if logs is None:
raise ValueError('Expect x to be a non-empty array or dataset.')
epoch_logs = copy.copy(logs)
# Run validation.

View File

@ -90,6 +90,14 @@ class TrainingTest(keras_parameterized.TestCase):
hist = model.fit(x=np.array([0.]), y=np.array([0.]))
self.assertAllClose(hist.history['loss'][0], 10000)
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
def test_fit_on_empty(self):
model = sequential.Sequential([layers_module.Dense(1)])
model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
with self.assertRaisesRegexp(
ValueError, 'Expect x to be a non-empty array or dataset.'):
model.fit(x=np.array([]), y=np.array([]))
@keras_parameterized.run_all_keras_modes
def test_run_eagerly_setting(self):
model = sequential.Sequential([layers_module.Dense(1)])