diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index 9f76435bebc..8bf6e4367c1 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -1536,10 +1536,17 @@ class KerasTPUModel(models.Model): verbose=1, sample_weight=None, steps=None): - assert not self._numpy_to_infeed_manager_list # Ensure empty. + original_numpy_to_infeed_manager_list = [] + if self._numpy_to_infeed_manager_list: + # evaluate call may be executed as callbacks during the training. In this + # case, _numpy_to_infeed_manager_list is not empty, so save it for + # recovery at the end of evaluate call. + original_numpy_to_infeed_manager_list = self._numpy_to_infeed_manager_list + self._numpy_to_infeed_manager_list = [] with _tpu_session_context(): - infeed_managers = [] # Managers to clean up at the end of the fit call. + # Managers to clean up at the end of the evaluate call. + infeed_managers = [] if isinstance(x, dataset_ops.Dataset): # TODO(b/111413240): Support taking a tf.data.Dataset directly. raise ValueError( @@ -1569,7 +1576,8 @@ class KerasTPUModel(models.Model): return super(KerasTPUModel, self).evaluate(x, y, batch_size, verbose, sample_weight, steps) finally: - self._numpy_to_infeed_manager_list = [] + self._numpy_to_infeed_manager_list = ( + original_numpy_to_infeed_manager_list) def _pipeline_fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight,