Make ProgbarLogger non-blocking when verbose=2

PiperOrigin-RevId: 305106724
Change-Id: I192447000c60d9667d7580eef75576e59050b79f
This commit is contained in:
Thomas O'Malley 2020-04-06 13:38:48 -07:00 committed by TensorFlower Gardener
parent 76ac3a41aa
commit 543935b7d0
2 changed files with 57 additions and 11 deletions

View File

@ -863,6 +863,7 @@ class ProgbarLogger(Callback):
def __init__(self, count_mode='samples', stateful_metrics=None):
super(ProgbarLogger, self).__init__()
self._supports_tf_logs = True
if count_mode == 'samples':
self.use_steps = False
elif count_mode == 'steps':
@ -931,8 +932,7 @@ class ProgbarLogger(Callback):
self.seen = 0
self.progbar = None
def _batch_update_progbar(self, batch, logs=None):
"""Updates the progbar."""
def _maybe_init_progbar(self):
if self.stateful_metrics is None:
if self.model:
self.stateful_metrics = (set(m.name for m in self.model.metrics))
@ -946,22 +946,33 @@ class ProgbarLogger(Callback):
stateful_metrics=self.stateful_metrics,
unit_name='step' if self.use_steps else 'sample')
logs = copy.copy(logs) if logs else {}
batch_size = logs.pop('size', 0)
num_steps = logs.pop('num_steps', 1) # DistStrat can run >1 steps.
logs.pop('batch', None)
def _batch_update_progbar(self, batch, logs=None):
  """Updates the progbar.

  Args:
    batch: Zero-indexed index of the current batch.
    logs: Batch-level logs dict; may contain Tensors. Converted to NumPy
      (which blocks async eager execution) only when `verbose == 1`, so that
      `verbose=2` training stays non-blocking on batch hooks.
  """
  logs = logs or {}
  self._maybe_init_progbar()
  if self.use_steps:
    self.seen = batch + 1  # One-indexed.
  else:
    # v1 path only: batch size travels inside the logs dict, so copy before
    # popping the bookkeeping keys out.
    logs = copy.copy(logs)
    batch_size = logs.pop('size', 0)
    num_steps = logs.pop('num_steps', 1)  # DistStrat can run >1 steps.
    logs.pop('batch', None)
    add_seen = num_steps * batch_size
    self.seen += add_seen

  if self.verbose == 1:
    # Only block async when verbose = 1.
    logs = tf_utils.to_numpy_or_python_type(logs)
    self.progbar.update(self.seen, list(logs.items()), finalize=False)
def _finalize_progbar(self, logs):
  """Flushes the progbar into its final state.

  This is the one place a blocking Tensor->NumPy sync is acceptable, since
  it runs at epoch/run end rather than per batch.

  Args:
    logs: Logs dict to render; may contain Tensors, may be None.
  """
  logs = logs or {}
  self._maybe_init_progbar()
  if self.target is None:
    # No declared target: infer it from how many steps/samples were seen.
    self.target = self.seen
    self.progbar.target = self.seen
  logs = tf_utils.to_numpy_or_python_type(logs)
  self.progbar.update(self.seen, list(logs.items()), finalize=True)
@ -2121,9 +2132,6 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
'keras_embedding.ckpt-{}'.format(epoch))
self.model.save_weights(embeddings_ckpt)
def _implements_train_batch_hooks(self):
return not (self._start_batch == 0 and self._stop_batch == 0)
@keras_export('keras.callbacks.ReduceLROnPlateau')
class ReduceLROnPlateau(Callback):

View File

@ -873,6 +873,44 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
cb_list.on_predict_batch_end(logs)
cb_list.on_predict_end(logs)
def test_ProgbarLogger_verbose_2_nonblocking(self):
  """ProgbarLogger at verbose=2 must not force a Tensor->NumPy sync on
  batch-level hooks; only the epoch-end path may block."""
  # Should only cause a sync block on epoch end methods.
  callback = keras.callbacks.ProgbarLogger(count_mode='steps')
  self.assertTrue(callback._supports_tf_logs)
  model = keras.Sequential([keras.layers.Dense(1)])
  cb_list = keras.callbacks.CallbackList([callback],
                                         model=model,
                                         epochs=1,
                                         steps=10,
                                         verbose=2)

  with context.eager_mode():
    tensor = ops.convert_to_tensor(1.)

  def mock_numpy():
    # Fixed copy-paste from the ModelCheckpoint test: the component under
    # test here is ProgbarLogger. Keep 'NumPy conversion' in the message --
    # the assertRaisesRegex below matches on it.
    raise RuntimeError(
        'If this error is seen, ProgbarLogger is causing a blocking '
        'NumPy conversion even though verbose=2 should be non-blocking.')

  with test.mock.patch.object(tensor, 'numpy', mock_numpy):
    logs = {'metric': tensor}

    # None of the batch-level or test-loop hooks may touch tensor.numpy().
    cb_list.on_train_begin(logs)
    cb_list.on_epoch_begin(0, logs)
    cb_list.on_train_batch_begin(0, logs)
    cb_list.on_train_batch_end(0, logs)

    cb_list.on_test_begin(logs)
    cb_list.on_test_batch_begin(0, logs)
    cb_list.on_test_batch_end(0, logs)
    cb_list.on_test_end(logs)

    with self.assertRaisesRegex(RuntimeError, 'NumPy conversion'):
      # on_epoch_end should still block.
      cb_list.on_epoch_end(0, logs)
    cb_list.on_train_end(logs)
def test_EarlyStopping(self):
with self.cached_session():
np.random.seed(123)