Remove irrelevant callbacks warning with distribute strategies.
PiperOrigin-RevId: 244263923
This commit is contained in:
parent
50e00d1acf
commit
5aa52f33d3
@ -175,19 +175,6 @@ def validate_callbacks(input_callbacks, optimizer):
|
|||||||
"""
|
"""
|
||||||
if input_callbacks:
|
if input_callbacks:
|
||||||
for callback in input_callbacks:
|
for callback in input_callbacks:
|
||||||
if not isinstance(callback,
|
|
||||||
(callbacks.TensorBoard, callbacks.ReduceLROnPlateau,
|
|
||||||
callbacks.LearningRateScheduler, callbacks.CSVLogger,
|
|
||||||
callbacks.EarlyStopping, callbacks.ModelCheckpoint,
|
|
||||||
callbacks.TerminateOnNaN, callbacks.ProgbarLogger,
|
|
||||||
callbacks.History, callbacks.RemoteMonitor)):
|
|
||||||
logging.warning('Your input callback is not one of the predefined '
|
|
||||||
'Callbacks that supports DistributionStrategy. You '
|
|
||||||
'might encounter an error if you access one of the '
|
|
||||||
'model\'s attributes as part of the callback since '
|
|
||||||
'these attributes are not set. You can access each of '
|
|
||||||
'the individual distributed models using the '
|
|
||||||
'`_grouped_model` attribute of your original model.')
|
|
||||||
if isinstance(callback, (callbacks.LearningRateScheduler,
|
if isinstance(callback, (callbacks.LearningRateScheduler,
|
||||||
callbacks.ReduceLROnPlateau)):
|
callbacks.ReduceLROnPlateau)):
|
||||||
|
|
||||||
|
@ -57,30 +57,6 @@ class DistributedTrainingUtilsTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(0, mock_warning.call_count)
|
self.assertEqual(0, mock_warning.call_count)
|
||||||
|
|
||||||
@test.mock.patch.object(logging, 'warning', autospec=True)
|
|
||||||
def test_validate_callbacks_custom_callback(self, mock_warning):
|
|
||||||
|
|
||||||
class CustomCallback(callbacks.Callback):
|
|
||||||
pass
|
|
||||||
|
|
||||||
distributed_training_utils.validate_callbacks([CustomCallback()],
|
|
||||||
adam.Adam())
|
|
||||||
|
|
||||||
self.assertEqual(1, mock_warning.call_count)
|
|
||||||
|
|
||||||
call_args, call_kwargs = mock_warning.call_args
|
|
||||||
|
|
||||||
self.assertEqual(('Your input callback is not one of the predefined '
|
|
||||||
'Callbacks that supports DistributionStrategy. You '
|
|
||||||
'might encounter an error if you access one of the '
|
|
||||||
'model\'s attributes as part of the callback since '
|
|
||||||
'these attributes are not set. You can access each of '
|
|
||||||
'the individual distributed models using the '
|
|
||||||
'`_grouped_model` attribute of your original model.',),
|
|
||||||
call_args)
|
|
||||||
|
|
||||||
self.assertEqual(0, len(call_kwargs))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user