diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 809c527fb84..7e3dadedcf9 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -863,6 +863,7 @@ class ProgbarLogger(Callback): def __init__(self, count_mode='samples', stateful_metrics=None): super(ProgbarLogger, self).__init__() + self._supports_tf_logs = True if count_mode == 'samples': self.use_steps = False elif count_mode == 'steps': @@ -931,8 +932,7 @@ class ProgbarLogger(Callback): self.seen = 0 self.progbar = None - def _batch_update_progbar(self, batch, logs=None): - """Updates the progbar.""" + def _maybe_init_progbar(self): if self.stateful_metrics is None: if self.model: self.stateful_metrics = (set(m.name for m in self.model.metrics)) @@ -946,22 +946,33 @@ class ProgbarLogger(Callback): stateful_metrics=self.stateful_metrics, unit_name='step' if self.use_steps else 'sample') - logs = copy.copy(logs) if logs else {} - batch_size = logs.pop('size', 0) - num_steps = logs.pop('num_steps', 1) # DistStrat can run >1 steps. - logs.pop('batch', None) + def _batch_update_progbar(self, batch, logs=None): + """Updates the progbar.""" + logs = logs or {} + self._maybe_init_progbar() if self.use_steps: self.seen = batch + 1 # One-indexed. else: + # v1 path only. + logs = copy.copy(logs) + batch_size = logs.pop('size', 0) + num_steps = logs.pop('num_steps', 1) + logs.pop('batch', None) add_seen = num_steps * batch_size self.seen += add_seen - self.progbar.update(self.seen, list(logs.items()), finalize=False) + + if self.verbose == 1: + # Only block async when verbose = 1. + logs = tf_utils.to_numpy_or_python_type(logs) + self.progbar.update(self.seen, list(logs.items()), finalize=False) def _finalize_progbar(self, logs): + logs = logs or {} + self._maybe_init_progbar() if self.target is None: self.target = self.seen self.progbar.target = self.seen - logs = logs or {} + logs = tf_utils.to_numpy_or_python_type(logs) self.progbar.update(self.seen, list(logs.items()), finalize=True) @@ -2121,9 +2132,6 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector): 'keras_embedding.ckpt-{}'.format(epoch)) self.model.save_weights(embeddings_ckpt) - def _implements_train_batch_hooks(self): - return not (self._start_batch == 0 and self._stop_batch == 0) - @keras_export('keras.callbacks.ReduceLROnPlateau') class ReduceLROnPlateau(Callback): diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py index b3d16907ada..a2709bcb5de 100644 --- a/tensorflow/python/keras/callbacks_test.py +++ b/tensorflow/python/keras/callbacks_test.py @@ -873,6 +873,44 @@ class KerasCallbacksTest(keras_parameterized.TestCase): cb_list.on_predict_batch_end(logs) cb_list.on_predict_end(logs) + def test_ProgbarLogger_verbose_2_nonblocking(self): + # Should only cause a sync block on epoch end methods. + callback = keras.callbacks.ProgbarLogger(count_mode='steps') + self.assertTrue(callback._supports_tf_logs) + + model = keras.Sequential([keras.layers.Dense(1)]) + cb_list = keras.callbacks.CallbackList([callback], + model=model, + epochs=1, + steps=10, + verbose=2) + + with context.eager_mode(): + tensor = ops.convert_to_tensor(1.) + + def mock_numpy(): + raise RuntimeError( + 'If this error is seen, ModelCheckpoint is causing a blocking ' + 'NumPy conversion even when not checkpointing.') + + with test.mock.patch.object(tensor, 'numpy', mock_numpy): + logs = {'metric': tensor} + + cb_list.on_train_begin(logs) + cb_list.on_epoch_begin(0, logs) + cb_list.on_train_batch_begin(0, logs) + cb_list.on_train_batch_end(0, logs) + + cb_list.on_test_begin(logs) + cb_list.on_test_batch_begin(0, logs) + cb_list.on_test_batch_end(0, logs) + cb_list.on_test_end(logs) + + with self.assertRaisesRegexp(RuntimeError, 'NumPy conversion'): + # on_epoch_end should still block. + cb_list.on_epoch_end(0, logs) + cb_list.on_train_end(logs) + def test_EarlyStopping(self): with self.cached_session(): np.random.seed(123)