diff --git a/tensorflow/python/keras/distribute/distributed_training_utils.py b/tensorflow/python/keras/distribute/distributed_training_utils.py index cbd802736e3..6f43f8a5176 100644 --- a/tensorflow/python/keras/distribute/distributed_training_utils.py +++ b/tensorflow/python/keras/distribute/distributed_training_utils.py @@ -175,19 +175,6 @@ def validate_callbacks(input_callbacks, optimizer): """ if input_callbacks: for callback in input_callbacks: - if not isinstance(callback, - (callbacks.TensorBoard, callbacks.ReduceLROnPlateau, - callbacks.LearningRateScheduler, callbacks.CSVLogger, - callbacks.EarlyStopping, callbacks.ModelCheckpoint, - callbacks.TerminateOnNaN, callbacks.ProgbarLogger, - callbacks.History, callbacks.RemoteMonitor)): - logging.warning('Your input callback is not one of the predefined ' - 'Callbacks that supports DistributionStrategy. You ' - 'might encounter an error if you access one of the ' - 'model\'s attributes as part of the callback since ' - 'these attributes are not set. You can access each of ' - 'the individual distributed models using the ' - '`_grouped_model` attribute of your original model.') if isinstance(callback, (callbacks.LearningRateScheduler, callbacks.ReduceLROnPlateau)): diff --git a/tensorflow/python/keras/distribute/distributed_training_utils_test.py b/tensorflow/python/keras/distribute/distributed_training_utils_test.py index 0ea777fd1bb..4adc8b5f451 100644 --- a/tensorflow/python/keras/distribute/distributed_training_utils_test.py +++ b/tensorflow/python/keras/distribute/distributed_training_utils_test.py @@ -57,30 +57,6 @@ class DistributedTrainingUtilsTest(test.TestCase): self.assertEqual(0, mock_warning.call_count) - @test.mock.patch.object(logging, 'warning', autospec=True) - def test_validate_callbacks_custom_callback(self, mock_warning): - - class CustomCallback(callbacks.Callback): - pass - - distributed_training_utils.validate_callbacks([CustomCallback()], - adam.Adam()) - - self.assertEqual(1, mock_warning.call_count) - - call_args, call_kwargs = mock_warning.call_args - - self.assertEqual(('Your input callback is not one of the predefined ' - 'Callbacks that supports DistributionStrategy. You ' - 'might encounter an error if you access one of the ' - 'model\'s attributes as part of the callback since ' - 'these attributes are not set. You can access each of ' - 'the individual distributed models using the ' - '`_grouped_model` attribute of your original model.',), - call_args) - - self.assertEqual(0, len(call_kwargs)) - if __name__ == '__main__': test.main()