Added check for case batch_outputs is not assigned in predict. Raise ValueError.

PiperOrigin-RevId: 318091655
Change-Id: I78832b5e1c41e763a24e1d77ea3a18bec1994c9a
This commit is contained in:
Haifeng Jin 2020-06-24 10:30:12 -07:00 committed by TensorFlower Gardener
parent 833ae39436
commit 1b049d2a22
2 changed files with 16 additions and 0 deletions

View File

@ -1592,6 +1592,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
predict_function = self.make_predict_function()
self._predict_counter.assign(0)
callbacks.on_predict_begin()
batch_outputs = None
for _, iterator in data_handler.enumerate_epochs(): # Single epoch.
with data_handler.catch_stop_iteration():
for step in data_handler.steps():
@ -1610,6 +1611,8 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
outputs, batch_outputs)
end_step = step + data_handler.step_increment
callbacks.on_predict_batch_end(end_step, {'outputs': batch_outputs})
if batch_outputs is None:
raise ValueError('Expect x to be a non-empty array or dataset.')
callbacks.on_predict_end()
all_outputs = nest.map_structure_up_to(batch_outputs, concat, outputs)
return tf_utils.to_numpy_or_python_type(all_outputs)

View File

@ -1647,6 +1647,19 @@ class TestExceptionsAndWarnings(keras_parameterized.TestCase):
):
training_module.Model([inputs], output)
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
def test_predict_error_with_empty_x(self):
inputs = layers_module.Input(shape=(2,))
outputs = layers_module.Dense(4)(inputs)
model = training_module.Model(inputs=inputs, outputs=outputs)
model.compile(loss='mse')
with self.assertRaisesRegexp(
ValueError,
'Expect x to be a non-empty array or dataset.'
):
model.predict(np.array([]))
class LossWeightingTest(keras_parameterized.TestCase):