Make ProgbarLogger non-blocking when verbose=2
PiperOrigin-RevId: 305106724 Change-Id: I192447000c60d9667d7580eef75576e59050b79f
This commit is contained in:
parent
76ac3a41aa
commit
543935b7d0
@ -863,6 +863,7 @@ class ProgbarLogger(Callback):
|
||||
|
||||
def __init__(self, count_mode='samples', stateful_metrics=None):
|
||||
super(ProgbarLogger, self).__init__()
|
||||
self._supports_tf_logs = True
|
||||
if count_mode == 'samples':
|
||||
self.use_steps = False
|
||||
elif count_mode == 'steps':
|
||||
@ -931,8 +932,7 @@ class ProgbarLogger(Callback):
|
||||
self.seen = 0
|
||||
self.progbar = None
|
||||
|
||||
def _batch_update_progbar(self, batch, logs=None):
|
||||
"""Updates the progbar."""
|
||||
def _maybe_init_progbar(self):
|
||||
if self.stateful_metrics is None:
|
||||
if self.model:
|
||||
self.stateful_metrics = (set(m.name for m in self.model.metrics))
|
||||
@ -946,22 +946,33 @@ class ProgbarLogger(Callback):
|
||||
stateful_metrics=self.stateful_metrics,
|
||||
unit_name='step' if self.use_steps else 'sample')
|
||||
|
||||
logs = copy.copy(logs) if logs else {}
|
||||
batch_size = logs.pop('size', 0)
|
||||
num_steps = logs.pop('num_steps', 1) # DistStrat can run >1 steps.
|
||||
logs.pop('batch', None)
|
||||
def _batch_update_progbar(self, batch, logs=None):
|
||||
"""Updates the progbar."""
|
||||
logs = logs or {}
|
||||
self._maybe_init_progbar()
|
||||
if self.use_steps:
|
||||
self.seen = batch + 1 # One-indexed.
|
||||
else:
|
||||
# v1 path only.
|
||||
logs = copy.copy(logs)
|
||||
batch_size = logs.pop('size', 0)
|
||||
num_steps = logs.pop('num_steps', 1)
|
||||
logs.pop('batch', None)
|
||||
add_seen = num_steps * batch_size
|
||||
self.seen += add_seen
|
||||
self.progbar.update(self.seen, list(logs.items()), finalize=False)
|
||||
|
||||
if self.verbose == 1:
|
||||
# Only block async when verbose = 1.
|
||||
logs = tf_utils.to_numpy_or_python_type(logs)
|
||||
self.progbar.update(self.seen, list(logs.items()), finalize=False)
|
||||
|
||||
def _finalize_progbar(self, logs):
|
||||
logs = logs or {}
|
||||
self._maybe_init_progbar()
|
||||
if self.target is None:
|
||||
self.target = self.seen
|
||||
self.progbar.target = self.seen
|
||||
logs = logs or {}
|
||||
logs = tf_utils.to_numpy_or_python_type(logs)
|
||||
self.progbar.update(self.seen, list(logs.items()), finalize=True)
|
||||
|
||||
|
||||
@ -2121,9 +2132,6 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
|
||||
'keras_embedding.ckpt-{}'.format(epoch))
|
||||
self.model.save_weights(embeddings_ckpt)
|
||||
|
||||
def _implements_train_batch_hooks(self):
|
||||
return not (self._start_batch == 0 and self._stop_batch == 0)
|
||||
|
||||
|
||||
@keras_export('keras.callbacks.ReduceLROnPlateau')
|
||||
class ReduceLROnPlateau(Callback):
|
||||
|
@ -873,6 +873,44 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
|
||||
cb_list.on_predict_batch_end(logs)
|
||||
cb_list.on_predict_end(logs)
|
||||
|
||||
def test_ProgbarLogger_verbose_2_nonblocking(self):
|
||||
# Should only cause a sync block on epoch end methods.
|
||||
callback = keras.callbacks.ProgbarLogger(count_mode='steps')
|
||||
self.assertTrue(callback._supports_tf_logs)
|
||||
|
||||
model = keras.Sequential([keras.layers.Dense(1)])
|
||||
cb_list = keras.callbacks.CallbackList([callback],
|
||||
model=model,
|
||||
epochs=1,
|
||||
steps=10,
|
||||
verbose=2)
|
||||
|
||||
with context.eager_mode():
|
||||
tensor = ops.convert_to_tensor(1.)
|
||||
|
||||
def mock_numpy():
|
||||
raise RuntimeError(
|
||||
'If this error is seen, ModelCheckpoint is causing a blocking '
|
||||
'NumPy conversion even when not checkpointing.')
|
||||
|
||||
with test.mock.patch.object(tensor, 'numpy', mock_numpy):
|
||||
logs = {'metric': tensor}
|
||||
|
||||
cb_list.on_train_begin(logs)
|
||||
cb_list.on_epoch_begin(0, logs)
|
||||
cb_list.on_train_batch_begin(0, logs)
|
||||
cb_list.on_train_batch_end(0, logs)
|
||||
|
||||
cb_list.on_test_begin(logs)
|
||||
cb_list.on_test_batch_begin(0, logs)
|
||||
cb_list.on_test_batch_end(0, logs)
|
||||
cb_list.on_test_end(logs)
|
||||
|
||||
with self.assertRaisesRegexp(RuntimeError, 'NumPy conversion'):
|
||||
# on_epoch_end should still block.
|
||||
cb_list.on_epoch_end(0, logs)
|
||||
cb_list.on_train_end(logs)
|
||||
|
||||
def test_EarlyStopping(self):
|
||||
with self.cached_session():
|
||||
np.random.seed(123)
|
||||
|
Loading…
x
Reference in New Issue
Block a user