From 87c56c8fea2dcc36099cd0ffad8ea178877c8e25 Mon Sep 17 00:00:00 2001 From: Michael Banfield Date: Thu, 7 Mar 2019 14:25:21 -0800 Subject: [PATCH] Filter heartbeat list to tpu devices, remove ping check. PiperOrigin-RevId: 237324723 --- tensorflow/python/tpu/session_support.py | 61 +++++++++--------------- 1 file changed, 22 insertions(+), 39 deletions(-) diff --git a/tensorflow/python/tpu/session_support.py b/tensorflow/python/tpu/session_support.py index 3df7185be08..0cca8aeb55b 100644 --- a/tensorflow/python/tpu/session_support.py +++ b/tensorflow/python/tpu/session_support.py @@ -50,29 +50,6 @@ def _clone_session(session, graph=None): graph=graph if graph else session.graph) -def _make_heartbeat_op(session, device, request_ph): - """Return a heartbeat op or None if heartbeats are not supported by device.""" - try: - # Test if we can connect in a isolated graph + session - with ops.Graph().as_default(): - with _clone_session(session) as temp_session: - with ops.device(device): - heartbeat_op = tpu_ops.worker_heartbeat('') - options = config_pb2.RunOptions(timeout_in_ms=5000) - temp_session.run(heartbeat_op, options=options) - except errors.InvalidArgumentError as _: - logging.warning('Error running heartbeat on %s', device) - return None - except errors.DeadlineExceededError as _: - logging.warning('Timeout connecting to %s when testing heartbeat', device) - return None - - # If we successfully connected and pinged the worker, go ahead and construct - # the operation. - with ops.device(device): - return tpu_ops.worker_heartbeat(request_ph) - - class WorkerHeartbeatManager(object): """Manages the status/heartbeat monitor for a set of workers.""" @@ -104,16 +81,11 @@ class WorkerHeartbeatManager(object): name='worker_heartbeat_request', dtype=dtypes.string) heartbeat_ops = [] - kept_devices = [] for device in devices: - heartbeat_op = _make_heartbeat_op(session, device, request_placeholder) - if heartbeat_op is not None: - kept_devices.append(device) - heartbeat_ops.append(heartbeat_op) - else: - logging.warning('Heartbeat support not available for %s', device) + with ops.device(device): + heartbeat_ops.append(tpu_ops.worker_heartbeat(request_placeholder)) - return WorkerHeartbeatManager(session, kept_devices, heartbeat_ops, + return WorkerHeartbeatManager(session, devices, heartbeat_ops, request_placeholder) def num_workers(self): @@ -185,11 +157,16 @@ class WorkerHeartbeatManager(object): def all_worker_devices(session): """Return a list of devices for each worker in the system.""" devices = session.list_devices() - return [ - device.name - for device in devices - if ':CPU:' in device.name and 'coordinator' not in device.name - ] + + devices_that_support_heartbeats = [] + + for device in devices: + name = device.name + # Pick devices that have a TPU but target the attached CPU + if ':TPU:0' in name and 'coordinator' not in name: + devices_that_support_heartbeats.append(name.replace('TPU', 'CPU')) + + return devices_that_support_heartbeats class WatchdogManager(threading.Thread): @@ -353,9 +330,15 @@ class GracefulShutdownHook(session_run_hook.SessionRunHook): self._session, all_worker_devices(self._session)) self._heartbeat_supported = self._workers.num_workers() > 0 if self._heartbeat_supported: - self._workers.configure( - event_pb2.WorkerHeartbeatRequest( - shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR)) + try: + self._workers.configure( + event_pb2.WorkerHeartbeatRequest( + shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR)) + except errors.InvalidArgumentError: + logging.warn( + 'TPU device does not support heartbeats. Failure ' + 'handling will be disabled.') + self._heartbeat_supported = False else: logging.warn( 'No workers support hearbeats. Failure handling will be disabled.')