diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index a0ebec4f95e..5355920ced5 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -1592,6 +1592,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): predict_function = self.make_predict_function() self._predict_counter.assign(0) callbacks.on_predict_begin() + batch_outputs = None for _, iterator in data_handler.enumerate_epochs(): # Single epoch. with data_handler.catch_stop_iteration(): for step in data_handler.steps(): @@ -1610,6 +1611,8 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): outputs, batch_outputs) end_step = step + data_handler.step_increment callbacks.on_predict_batch_end(end_step, {'outputs': batch_outputs}) + if batch_outputs is None: + raise ValueError('Expect x to be a non-empty array or dataset.') callbacks.on_predict_end() all_outputs = nest.map_structure_up_to(batch_outputs, concat, outputs) return tf_utils.to_numpy_or_python_type(all_outputs) diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index 5cf15926bfb..ad904ce9aa7 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -1647,6 +1647,19 @@ class TestExceptionsAndWarnings(keras_parameterized.TestCase): ): training_module.Model([inputs], output) + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) + def test_predict_error_with_empty_x(self): + inputs = layers_module.Input(shape=(2,)) + outputs = layers_module.Dense(4)(inputs) + model = training_module.Model(inputs=inputs, outputs=outputs) + model.compile(loss='mse') + + with self.assertRaisesRegexp( + ValueError, + 'Expect x to be a non-empty array or dataset.' + ): + model.predict(np.array([])) + class LossWeightingTest(keras_parameterized.TestCase):