Remove irrelevant callbacks warning with distribute strategies.
PiperOrigin-RevId: 244263923
This commit is contained in:
parent
50e00d1acf
commit
5aa52f33d3
@ -175,19 +175,6 @@ def validate_callbacks(input_callbacks, optimizer):
|
||||
"""
|
||||
if input_callbacks:
|
||||
for callback in input_callbacks:
|
||||
if not isinstance(callback,
|
||||
(callbacks.TensorBoard, callbacks.ReduceLROnPlateau,
|
||||
callbacks.LearningRateScheduler, callbacks.CSVLogger,
|
||||
callbacks.EarlyStopping, callbacks.ModelCheckpoint,
|
||||
callbacks.TerminateOnNaN, callbacks.ProgbarLogger,
|
||||
callbacks.History, callbacks.RemoteMonitor)):
|
||||
logging.warning('Your input callback is not one of the predefined '
|
||||
'Callbacks that supports DistributionStrategy. You '
|
||||
'might encounter an error if you access one of the '
|
||||
'model\'s attributes as part of the callback since '
|
||||
'these attributes are not set. You can access each of '
|
||||
'the individual distributed models using the '
|
||||
'`_grouped_model` attribute of your original model.')
|
||||
if isinstance(callback, (callbacks.LearningRateScheduler,
|
||||
callbacks.ReduceLROnPlateau)):
|
||||
|
||||
|
@ -57,30 +57,6 @@ class DistributedTrainingUtilsTest(test.TestCase):
|
||||
|
||||
self.assertEqual(0, mock_warning.call_count)
|
||||
|
||||
@test.mock.patch.object(logging, 'warning', autospec=True)
|
||||
def test_validate_callbacks_custom_callback(self, mock_warning):
|
||||
|
||||
class CustomCallback(callbacks.Callback):
|
||||
pass
|
||||
|
||||
distributed_training_utils.validate_callbacks([CustomCallback()],
|
||||
adam.Adam())
|
||||
|
||||
self.assertEqual(1, mock_warning.call_count)
|
||||
|
||||
call_args, call_kwargs = mock_warning.call_args
|
||||
|
||||
self.assertEqual(('Your input callback is not one of the predefined '
|
||||
'Callbacks that supports DistributionStrategy. You '
|
||||
'might encounter an error if you access one of the '
|
||||
'model\'s attributes as part of the callback since '
|
||||
'these attributes are not set. You can access each of '
|
||||
'the individual distributed models using the '
|
||||
'`_grouped_model` attribute of your original model.',),
|
||||
call_args)
|
||||
|
||||
self.assertEqual(0, len(call_kwargs))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user