Filter heartbeat list to tpu devices, remove ping check.

PiperOrigin-RevId: 237324723
This commit is contained in:
Michael Banfield 2019-03-07 14:25:21 -08:00 committed by TensorFlower Gardener
parent 57467ada28
commit 87c56c8fea

View File

@ -50,29 +50,6 @@ def _clone_session(session, graph=None):
graph=graph if graph else session.graph) graph=graph if graph else session.graph)
def _make_heartbeat_op(session, device, request_ph):
"""Return a heartbeat op or None if heartbeats are not supported by device."""
try:
# Test if we can connect in a isolated graph + session
with ops.Graph().as_default():
with _clone_session(session) as temp_session:
with ops.device(device):
heartbeat_op = tpu_ops.worker_heartbeat('')
options = config_pb2.RunOptions(timeout_in_ms=5000)
temp_session.run(heartbeat_op, options=options)
except errors.InvalidArgumentError as _:
logging.warning('Error running heartbeat on %s', device)
return None
except errors.DeadlineExceededError as _:
logging.warning('Timeout connecting to %s when testing heartbeat', device)
return None
# If we successfully connected and pinged the worker, go ahead and construct
# the operation.
with ops.device(device):
return tpu_ops.worker_heartbeat(request_ph)
class WorkerHeartbeatManager(object): class WorkerHeartbeatManager(object):
"""Manages the status/heartbeat monitor for a set of workers.""" """Manages the status/heartbeat monitor for a set of workers."""
@ -104,16 +81,11 @@ class WorkerHeartbeatManager(object):
name='worker_heartbeat_request', dtype=dtypes.string) name='worker_heartbeat_request', dtype=dtypes.string)
heartbeat_ops = [] heartbeat_ops = []
kept_devices = []
for device in devices: for device in devices:
heartbeat_op = _make_heartbeat_op(session, device, request_placeholder) with ops.device(device):
if heartbeat_op is not None: heartbeat_ops.append(tpu_ops.worker_heartbeat(request_placeholder))
kept_devices.append(device)
heartbeat_ops.append(heartbeat_op)
else:
logging.warning('Heartbeat support not available for %s', device)
return WorkerHeartbeatManager(session, kept_devices, heartbeat_ops, return WorkerHeartbeatManager(session, devices, heartbeat_ops,
request_placeholder) request_placeholder)
def num_workers(self): def num_workers(self):
@ -185,11 +157,16 @@ class WorkerHeartbeatManager(object):
def all_worker_devices(session): def all_worker_devices(session):
"""Return a list of devices for each worker in the system.""" """Return a list of devices for each worker in the system."""
devices = session.list_devices() devices = session.list_devices()
return [
device.name devices_that_support_heartbeats = []
for device in devices
if ':CPU:' in device.name and 'coordinator' not in device.name for device in devices:
] name = device.name
# Pick devices that have a TPU but target the attached CPU
if ':TPU:0' in name and 'coordinator' not in name:
devices_that_support_heartbeats.append(name.replace('TPU', 'CPU'))
return devices_that_support_heartbeats
class WatchdogManager(threading.Thread): class WatchdogManager(threading.Thread):
@ -353,9 +330,15 @@ class GracefulShutdownHook(session_run_hook.SessionRunHook):
self._session, all_worker_devices(self._session)) self._session, all_worker_devices(self._session))
self._heartbeat_supported = self._workers.num_workers() > 0 self._heartbeat_supported = self._workers.num_workers() > 0
if self._heartbeat_supported: if self._heartbeat_supported:
self._workers.configure( try:
event_pb2.WorkerHeartbeatRequest( self._workers.configure(
shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR)) event_pb2.WorkerHeartbeatRequest(
shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR))
except errors.InvalidArgumentError:
logging.warn(
'TPU device does not support heartbeats. Failure '
'handling will be disabled.')
self._heartbeat_supported = False
else: else:
logging.warn( logging.warn(
'No workers support hearbeats. Failure handling will be disabled.') 'No workers support hearbeats. Failure handling will be disabled.')