From e5b12c6ce34335ff386101548323cb4801f04296 Mon Sep 17 00:00:00 2001 From: Pavithra Vijay Date: Mon, 22 Jul 2019 13:48:23 -0700 Subject: [PATCH] Fix invalid `steps` argument usage test for single execution path. In multiple execution path code, in eager mode we raised an error but otherwise we just raised a warning message. Updated the test case to check for a warning message for all use cases in single execution path. PiperOrigin-RevId: 259398447 --- tensorflow/python/keras/engine/training_test.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index 9f020221322..aeec0264b92 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -1412,7 +1412,7 @@ class TestExceptionsAndWarnings(keras_parameterized.TestCase): run_distributed=testing_utils.should_run_distributed()) err_msg = 'When passing input data as arrays, do not specify' - if testing_utils.should_run_eagerly(): + if testing_utils.should_run_eagerly() and not model._run_distributed: with self.assertRaisesRegex(ValueError, err_msg): model.fit(x=np.zeros((100, 1)), y=np.ones((100, 1)), steps_per_epoch=4) @@ -1423,15 +1423,12 @@ class TestExceptionsAndWarnings(keras_parameterized.TestCase): model.predict(np.zeros((100, 1)), steps=4) else: with test.mock.patch.object(logging, 'warning') as mock_log: - model.fit(x=np.zeros((100, 1)), y=np.ones((100, 1)), steps_per_epoch=4) - self.assertRegexpMatches(str(mock_log.call_args), err_msg) - - with test.mock.patch.object(logging, 'warning') as mock_log: - model.evaluate(x=np.zeros((100, 1)), y=np.ones((100, 1)), steps=4) - self.assertRegexpMatches(str(mock_log.call_args), err_msg) - - with test.mock.patch.object(logging, 'warning') as mock_log: - model.predict(np.zeros((100, 1)), steps=4) + model._standardize_user_data( + np.zeros((100, 1)), + np.ones((100, 1)), + batch_size=25, + check_steps=True, + steps=4) self.assertRegexpMatches(str(mock_log.call_args), err_msg)