# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of the SessionRunHook for preemptible Cloud TPUs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging as _logging
import os
import threading
import time

from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import session_run_hook


class CloudTPUPreemptedHook(session_run_hook.SessionRunHook):
  """The SessionRunHook for preemptible Cloud TPUs.

  This is an implementation of SessionRunHook for the pre-emptible Google Cloud
  TPU service. When running on GCE, it starts a background thread after the
  session is created; that thread polls the TPU and terminates the process
  once the TPU is reported as no longer recoverable.

  Outside of GCE no polling thread is started and the hook does nothing.
  """

  def __init__(self, cluster):
    """Creates the hook.

    Args:
      cluster: Object describing the TPU to watch. It is passed unchanged to
        the polling thread, which reads its `_cloud_tpu_client` and `_tpu`
        attributes.
    """
    self._cluster = cluster
    # Stays None until a session is created on GCE, so `end` can tell whether
    # there is anything to stop.
    self._tpu_poller = None

  def after_create_session(self, session, coord):
    """Starts the TPU polling thread when running on GCE.

    Args:
      session: The session that has just been created; handed to the poller.
      coord: Unused.
    """
    if tpu_cluster_resolver.is_running_in_gce():
      self._tpu_poller = _TPUPollingThread(self._cluster, session)
      self._tpu_poller.start()

  def end(self, session):
    """Stops the TPU polling thread if one was started.

    Args:
      session: Unused.
    """
    poller = self._tpu_poller
    if poller is None:
      # Not on GCE, or no session was ever created: nothing to stop.
      return
    self._tpu_poller = None
    poller.stop()


class _TPUPollingThread(threading.Thread):
  """A daemon thread that polls the state of a TPU node.

  Every `_interval` seconds the thread asks the Cloud TPU client whether the
  node is still recoverable. As soon as the client reports that it is not, the
  thread logs the node's state and terminates the entire process with
  `os._exit(1)`. It does not attempt to close the session first.
  """

  def __init__(self, cluster, session):
    """Creates the (not yet started) polling thread.

    Args:
      cluster: Object exposing `_cloud_tpu_client` (queried for
        `recoverable()` and `state()`) and `_tpu` (used in the log message).
      session: The session associated with this poller. It is stored but not
        otherwise used by this class.
    """
    super(_TPUPollingThread, self).__init__()

    self.daemon = True
    self._running = True
    self._session_closed = False
    self._cluster = cluster
    self._session = session
    # Seconds to wait between two consecutive polls.
    self._interval = 30
    # Set by `stop()` so that a poller waiting between polls wakes up
    # immediately instead of sleeping out the rest of the interval.
    self._stop_event = threading.Event()

    # Some of the Google API libraries are quite chatty, so disable them.
    for name in ['googleapiclient.discovery', 'oauth2client.client']:
      _logging.getLogger(name).setLevel(_logging.WARNING)

  def stop(self):
    """Asks the polling loop to finish and waits for the thread to exit.

    Safe to call more than once, and safe to call on a thread that was never
    started.
    """
    self._running = False
    self._session_closed = True
    self._stop_event.set()
    # `join()` raises RuntimeError on a thread that was never started, so only
    # wait for a thread that is actually running.
    if self.is_alive():
      self.join()

  def run(self):
    """Polls the TPU until stopped; exits the process if it is unrecoverable."""
    if not tpu_cluster_resolver.is_running_in_gce():
      logging.warning(
          'TPUPollingThread is running in a non-GCE environment, exiting...')
      self._running = False
      return

    while self._running:
      recoverable = self._cluster._cloud_tpu_client.recoverable()  # pylint: disable=protected-access
      if not recoverable:
        logging.warning(
            'TPUPollingThread found TPU %s in state %s',
            self._cluster._tpu, self._cluster._cloud_tpu_client.state())  # pylint: disable=protected-access
        # Hard exit: bypasses normal interpreter shutdown on purpose so the
        # process goes away even if other threads are blocked.
        os._exit(1)  # pylint: disable=protected-access
      # Interruptible sleep: returns early once `stop()` sets the event, and
      # the loop condition then ends the thread.
      self._stop_event.wait(self._interval)