diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index 9f020221322..aeec0264b92 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -1412,7 +1412,7 @@ class TestExceptionsAndWarnings(keras_parameterized.TestCase): run_distributed=testing_utils.should_run_distributed()) err_msg = 'When passing input data as arrays, do not specify' - if testing_utils.should_run_eagerly(): + if testing_utils.should_run_eagerly() and not model._run_distributed: with self.assertRaisesRegex(ValueError, err_msg): model.fit(x=np.zeros((100, 1)), y=np.ones((100, 1)), steps_per_epoch=4) @@ -1423,15 +1423,12 @@ class TestExceptionsAndWarnings(keras_parameterized.TestCase): model.predict(np.zeros((100, 1)), steps=4) else: with test.mock.patch.object(logging, 'warning') as mock_log: - model.fit(x=np.zeros((100, 1)), y=np.ones((100, 1)), steps_per_epoch=4) - self.assertRegexpMatches(str(mock_log.call_args), err_msg) - - with test.mock.patch.object(logging, 'warning') as mock_log: - model.evaluate(x=np.zeros((100, 1)), y=np.ones((100, 1)), steps=4) - self.assertRegexpMatches(str(mock_log.call_args), err_msg) - - with test.mock.patch.object(logging, 'warning') as mock_log: - model.predict(np.zeros((100, 1)), steps=4) + model._standardize_user_data( + np.zeros((100, 1)), + np.ones((100, 1)), + batch_size=25, + check_steps=True, + steps=4) self.assertRegexpMatches(str(mock_log.call_args), err_msg)