Fix invalid steps argument usage test for single execution path.

In multiple execution path code, in eager mode we raised an error but otherwise we just raised a warning message. Updated the test case to check for a warning message for all use cases in single execution path.

PiperOrigin-RevId: 259398447
This commit is contained in:
Pavithra Vijay 2019-07-22 13:48:23 -07:00 committed by TensorFlower Gardener
parent 7cc180f107
commit e5b12c6ce3

View File

@ -1412,7 +1412,7 @@ class TestExceptionsAndWarnings(keras_parameterized.TestCase):
run_distributed=testing_utils.should_run_distributed()) run_distributed=testing_utils.should_run_distributed())
err_msg = 'When passing input data as arrays, do not specify' err_msg = 'When passing input data as arrays, do not specify'
if testing_utils.should_run_eagerly(): if testing_utils.should_run_eagerly() and not model._run_distributed:
with self.assertRaisesRegex(ValueError, err_msg): with self.assertRaisesRegex(ValueError, err_msg):
model.fit(x=np.zeros((100, 1)), y=np.ones((100, 1)), steps_per_epoch=4) model.fit(x=np.zeros((100, 1)), y=np.ones((100, 1)), steps_per_epoch=4)
@ -1423,15 +1423,12 @@ class TestExceptionsAndWarnings(keras_parameterized.TestCase):
model.predict(np.zeros((100, 1)), steps=4) model.predict(np.zeros((100, 1)), steps=4)
else: else:
with test.mock.patch.object(logging, 'warning') as mock_log: with test.mock.patch.object(logging, 'warning') as mock_log:
model.fit(x=np.zeros((100, 1)), y=np.ones((100, 1)), steps_per_epoch=4) model._standardize_user_data(
self.assertRegexpMatches(str(mock_log.call_args), err_msg) np.zeros((100, 1)),
np.ones((100, 1)),
with test.mock.patch.object(logging, 'warning') as mock_log: batch_size=25,
model.evaluate(x=np.zeros((100, 1)), y=np.ones((100, 1)), steps=4) check_steps=True,
self.assertRegexpMatches(str(mock_log.call_args), err_msg) steps=4)
with test.mock.patch.object(logging, 'warning') as mock_log:
model.predict(np.zeros((100, 1)), steps=4)
self.assertRegexpMatches(str(mock_log.call_args), err_msg) self.assertRegexpMatches(str(mock_log.call_args), err_msg)