Added check for case batch_outputs is not assigned in predict. Raise ValueError.
PiperOrigin-RevId: 318091655 Change-Id: I78832b5e1c41e763a24e1d77ea3a18bec1994c9a
This commit is contained in:
parent
833ae39436
commit
1b049d2a22
@ -1592,6 +1592,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||
predict_function = self.make_predict_function()
|
||||
self._predict_counter.assign(0)
|
||||
callbacks.on_predict_begin()
|
||||
batch_outputs = None
|
||||
for _, iterator in data_handler.enumerate_epochs(): # Single epoch.
|
||||
with data_handler.catch_stop_iteration():
|
||||
for step in data_handler.steps():
|
||||
@ -1610,6 +1611,8 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||
outputs, batch_outputs)
|
||||
end_step = step + data_handler.step_increment
|
||||
callbacks.on_predict_batch_end(end_step, {'outputs': batch_outputs})
|
||||
if batch_outputs is None:
|
||||
raise ValueError('Expect x to be a non-empty array or dataset.')
|
||||
callbacks.on_predict_end()
|
||||
all_outputs = nest.map_structure_up_to(batch_outputs, concat, outputs)
|
||||
return tf_utils.to_numpy_or_python_type(all_outputs)
|
||||
|
@ -1647,6 +1647,19 @@ class TestExceptionsAndWarnings(keras_parameterized.TestCase):
|
||||
):
|
||||
training_module.Model([inputs], output)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||
def test_predict_error_with_empty_x(self):
|
||||
inputs = layers_module.Input(shape=(2,))
|
||||
outputs = layers_module.Dense(4)(inputs)
|
||||
model = training_module.Model(inputs=inputs, outputs=outputs)
|
||||
model.compile(loss='mse')
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
'Expect x to be a non-empty array or dataset.'
|
||||
):
|
||||
model.predict(np.array([]))
|
||||
|
||||
|
||||
class LossWeightingTest(keras_parameterized.TestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user