Remove irrelevant callbacks warning with distribute strategies.

PiperOrigin-RevId: 244263923
This commit is contained in:
Priya Gupta 2019-04-18 14:53:50 -07:00 committed by TensorFlower Gardener
parent 50e00d1acf
commit 5aa52f33d3
2 changed files with 0 additions and 37 deletions

View File

@ -175,19 +175,6 @@ def validate_callbacks(input_callbacks, optimizer):
"""
if input_callbacks:
for callback in input_callbacks:
if not isinstance(callback,
(callbacks.TensorBoard, callbacks.ReduceLROnPlateau,
callbacks.LearningRateScheduler, callbacks.CSVLogger,
callbacks.EarlyStopping, callbacks.ModelCheckpoint,
callbacks.TerminateOnNaN, callbacks.ProgbarLogger,
callbacks.History, callbacks.RemoteMonitor)):
logging.warning('Your input callback is not one of the predefined '
'Callbacks that supports DistributionStrategy. You '
'might encounter an error if you access one of the '
'model\'s attributes as part of the callback since '
'these attributes are not set. You can access each of '
'the individual distributed models using the '
'`_grouped_model` attribute of your original model.')
if isinstance(callback, (callbacks.LearningRateScheduler,
callbacks.ReduceLROnPlateau)):

View File

@ -57,30 +57,6 @@ class DistributedTrainingUtilsTest(test.TestCase):
self.assertEqual(0, mock_warning.call_count)
@test.mock.patch.object(logging, 'warning', autospec=True)
def test_validate_callbacks_custom_callback(self, mock_warning):
class CustomCallback(callbacks.Callback):
pass
distributed_training_utils.validate_callbacks([CustomCallback()],
adam.Adam())
self.assertEqual(1, mock_warning.call_count)
call_args, call_kwargs = mock_warning.call_args
self.assertEqual(('Your input callback is not one of the predefined '
'Callbacks that supports DistributionStrategy. You '
'might encounter an error if you access one of the '
'model\'s attributes as part of the callback since '
'these attributes are not set. You can access each of '
'the individual distributed models using the '
'`_grouped_model` attribute of your original model.',),
call_args)
self.assertEqual(0, len(call_kwargs))
if __name__ == '__main__':
test.main()