Fix invalid steps
argument usage test for single execution path.
In multiple execution path code, in eager mode we raised an error but otherwise we just raised a warning message. Updated the test case to check for a warning message for all use cases in single execution path. PiperOrigin-RevId: 259398447
This commit is contained in:
parent
7cc180f107
commit
e5b12c6ce3
@ -1412,7 +1412,7 @@ class TestExceptionsAndWarnings(keras_parameterized.TestCase):
|
|||||||
run_distributed=testing_utils.should_run_distributed())
|
run_distributed=testing_utils.should_run_distributed())
|
||||||
err_msg = 'When passing input data as arrays, do not specify'
|
err_msg = 'When passing input data as arrays, do not specify'
|
||||||
|
|
||||||
if testing_utils.should_run_eagerly():
|
if testing_utils.should_run_eagerly() and not model._run_distributed:
|
||||||
with self.assertRaisesRegex(ValueError, err_msg):
|
with self.assertRaisesRegex(ValueError, err_msg):
|
||||||
model.fit(x=np.zeros((100, 1)), y=np.ones((100, 1)), steps_per_epoch=4)
|
model.fit(x=np.zeros((100, 1)), y=np.ones((100, 1)), steps_per_epoch=4)
|
||||||
|
|
||||||
@ -1423,15 +1423,12 @@ class TestExceptionsAndWarnings(keras_parameterized.TestCase):
|
|||||||
model.predict(np.zeros((100, 1)), steps=4)
|
model.predict(np.zeros((100, 1)), steps=4)
|
||||||
else:
|
else:
|
||||||
with test.mock.patch.object(logging, 'warning') as mock_log:
|
with test.mock.patch.object(logging, 'warning') as mock_log:
|
||||||
model.fit(x=np.zeros((100, 1)), y=np.ones((100, 1)), steps_per_epoch=4)
|
model._standardize_user_data(
|
||||||
self.assertRegexpMatches(str(mock_log.call_args), err_msg)
|
np.zeros((100, 1)),
|
||||||
|
np.ones((100, 1)),
|
||||||
with test.mock.patch.object(logging, 'warning') as mock_log:
|
batch_size=25,
|
||||||
model.evaluate(x=np.zeros((100, 1)), y=np.ones((100, 1)), steps=4)
|
check_steps=True,
|
||||||
self.assertRegexpMatches(str(mock_log.call_args), err_msg)
|
steps=4)
|
||||||
|
|
||||||
with test.mock.patch.object(logging, 'warning') as mock_log:
|
|
||||||
model.predict(np.zeros((100, 1)), steps=4)
|
|
||||||
self.assertRegexpMatches(str(mock_log.call_args), err_msg)
|
self.assertRegexpMatches(str(mock_log.call_args), err_msg)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user