Filter heartbeat list to tpu devices, remove ping check.
PiperOrigin-RevId: 237324723
This commit is contained in:
parent
57467ada28
commit
87c56c8fea
@ -50,29 +50,6 @@ def _clone_session(session, graph=None):
|
||||
graph=graph if graph else session.graph)
|
||||
|
||||
|
||||
def _make_heartbeat_op(session, device, request_ph):
|
||||
"""Return a heartbeat op or None if heartbeats are not supported by device."""
|
||||
try:
|
||||
# Test if we can connect in a isolated graph + session
|
||||
with ops.Graph().as_default():
|
||||
with _clone_session(session) as temp_session:
|
||||
with ops.device(device):
|
||||
heartbeat_op = tpu_ops.worker_heartbeat('')
|
||||
options = config_pb2.RunOptions(timeout_in_ms=5000)
|
||||
temp_session.run(heartbeat_op, options=options)
|
||||
except errors.InvalidArgumentError as _:
|
||||
logging.warning('Error running heartbeat on %s', device)
|
||||
return None
|
||||
except errors.DeadlineExceededError as _:
|
||||
logging.warning('Timeout connecting to %s when testing heartbeat', device)
|
||||
return None
|
||||
|
||||
# If we successfully connected and pinged the worker, go ahead and construct
|
||||
# the operation.
|
||||
with ops.device(device):
|
||||
return tpu_ops.worker_heartbeat(request_ph)
|
||||
|
||||
|
||||
class WorkerHeartbeatManager(object):
|
||||
"""Manages the status/heartbeat monitor for a set of workers."""
|
||||
|
||||
@ -104,16 +81,11 @@ class WorkerHeartbeatManager(object):
|
||||
name='worker_heartbeat_request', dtype=dtypes.string)
|
||||
|
||||
heartbeat_ops = []
|
||||
kept_devices = []
|
||||
for device in devices:
|
||||
heartbeat_op = _make_heartbeat_op(session, device, request_placeholder)
|
||||
if heartbeat_op is not None:
|
||||
kept_devices.append(device)
|
||||
heartbeat_ops.append(heartbeat_op)
|
||||
else:
|
||||
logging.warning('Heartbeat support not available for %s', device)
|
||||
with ops.device(device):
|
||||
heartbeat_ops.append(tpu_ops.worker_heartbeat(request_placeholder))
|
||||
|
||||
return WorkerHeartbeatManager(session, kept_devices, heartbeat_ops,
|
||||
return WorkerHeartbeatManager(session, devices, heartbeat_ops,
|
||||
request_placeholder)
|
||||
|
||||
def num_workers(self):
|
||||
@ -185,11 +157,16 @@ class WorkerHeartbeatManager(object):
|
||||
def all_worker_devices(session):
|
||||
"""Return a list of devices for each worker in the system."""
|
||||
devices = session.list_devices()
|
||||
return [
|
||||
device.name
|
||||
for device in devices
|
||||
if ':CPU:' in device.name and 'coordinator' not in device.name
|
||||
]
|
||||
|
||||
devices_that_support_heartbeats = []
|
||||
|
||||
for device in devices:
|
||||
name = device.name
|
||||
# Pick devices that have a TPU but target the attached CPU
|
||||
if ':TPU:0' in name and 'coordinator' not in name:
|
||||
devices_that_support_heartbeats.append(name.replace('TPU', 'CPU'))
|
||||
|
||||
return devices_that_support_heartbeats
|
||||
|
||||
|
||||
class WatchdogManager(threading.Thread):
|
||||
@ -353,9 +330,15 @@ class GracefulShutdownHook(session_run_hook.SessionRunHook):
|
||||
self._session, all_worker_devices(self._session))
|
||||
self._heartbeat_supported = self._workers.num_workers() > 0
|
||||
if self._heartbeat_supported:
|
||||
self._workers.configure(
|
||||
event_pb2.WorkerHeartbeatRequest(
|
||||
shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR))
|
||||
try:
|
||||
self._workers.configure(
|
||||
event_pb2.WorkerHeartbeatRequest(
|
||||
shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR))
|
||||
except errors.InvalidArgumentError:
|
||||
logging.warn(
|
||||
'TPU device does not support heartbeats. Failure '
|
||||
'handling will be disabled.')
|
||||
self._heartbeat_supported = False
|
||||
else:
|
||||
logging.warn(
|
||||
'No workers support hearbeats. Failure handling will be disabled.')
|
||||
|
Loading…
Reference in New Issue
Block a user