From a2f2c7b035fa8023d66146564afcb298fa8a379d Mon Sep 17 00:00:00 2001 From: Jing Li Date: Thu, 1 Nov 2018 21:57:59 -0700 Subject: [PATCH] Replace tf.keras.callbacks.TensorBoard callback with a new callback TensorBoardWithValidation inherited from the former callback, which makes evaluation at the end of specified epochs and export the results to tensorboard. PiperOrigin-RevId: 219749605 --- tensorflow/contrib/tpu/python/tpu/keras_support.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index 9f76435bebc..8bf6e4367c1 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -1536,10 +1536,17 @@ class KerasTPUModel(models.Model): verbose=1, sample_weight=None, steps=None): - assert not self._numpy_to_infeed_manager_list # Ensure empty. + original_numpy_to_infeed_manager_list = [] + if self._numpy_to_infeed_manager_list: + # evaluate call may be executed as callbacks during the training. In this + # case, _numpy_to_infeed_manager_list is not empty, so save it for + # recovery at the end of evaluate call. + original_numpy_to_infeed_manager_list = self._numpy_to_infeed_manager_list + self._numpy_to_infeed_manager_list = [] with _tpu_session_context(): - infeed_managers = [] # Managers to clean up at the end of the fit call. + # Managers to clean up at the end of the evaluate call. + infeed_managers = [] if isinstance(x, dataset_ops.Dataset): # TODO(b/111413240): Support taking a tf.data.Dataset directly. raise ValueError( @@ -1569,7 +1576,8 @@ class KerasTPUModel(models.Model): return super(KerasTPUModel, self).evaluate(x, y, batch_size, verbose, sample_weight, steps) finally: - self._numpy_to_infeed_manager_list = [] + self._numpy_to_infeed_manager_list = ( + original_numpy_to_infeed_manager_list) def _pipeline_fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight,