Filter heartbeat list to tpu devices, remove ping check.
PiperOrigin-RevId: 237324723
This commit is contained in:
parent
57467ada28
commit
87c56c8fea
@ -50,29 +50,6 @@ def _clone_session(session, graph=None):
|
|||||||
graph=graph if graph else session.graph)
|
graph=graph if graph else session.graph)
|
||||||
|
|
||||||
|
|
||||||
def _make_heartbeat_op(session, device, request_ph):
|
|
||||||
"""Return a heartbeat op or None if heartbeats are not supported by device."""
|
|
||||||
try:
|
|
||||||
# Test if we can connect in a isolated graph + session
|
|
||||||
with ops.Graph().as_default():
|
|
||||||
with _clone_session(session) as temp_session:
|
|
||||||
with ops.device(device):
|
|
||||||
heartbeat_op = tpu_ops.worker_heartbeat('')
|
|
||||||
options = config_pb2.RunOptions(timeout_in_ms=5000)
|
|
||||||
temp_session.run(heartbeat_op, options=options)
|
|
||||||
except errors.InvalidArgumentError as _:
|
|
||||||
logging.warning('Error running heartbeat on %s', device)
|
|
||||||
return None
|
|
||||||
except errors.DeadlineExceededError as _:
|
|
||||||
logging.warning('Timeout connecting to %s when testing heartbeat', device)
|
|
||||||
return None
|
|
||||||
|
|
||||||
# If we successfully connected and pinged the worker, go ahead and construct
|
|
||||||
# the operation.
|
|
||||||
with ops.device(device):
|
|
||||||
return tpu_ops.worker_heartbeat(request_ph)
|
|
||||||
|
|
||||||
|
|
||||||
class WorkerHeartbeatManager(object):
|
class WorkerHeartbeatManager(object):
|
||||||
"""Manages the status/heartbeat monitor for a set of workers."""
|
"""Manages the status/heartbeat monitor for a set of workers."""
|
||||||
|
|
||||||
@ -104,16 +81,11 @@ class WorkerHeartbeatManager(object):
|
|||||||
name='worker_heartbeat_request', dtype=dtypes.string)
|
name='worker_heartbeat_request', dtype=dtypes.string)
|
||||||
|
|
||||||
heartbeat_ops = []
|
heartbeat_ops = []
|
||||||
kept_devices = []
|
|
||||||
for device in devices:
|
for device in devices:
|
||||||
heartbeat_op = _make_heartbeat_op(session, device, request_placeholder)
|
with ops.device(device):
|
||||||
if heartbeat_op is not None:
|
heartbeat_ops.append(tpu_ops.worker_heartbeat(request_placeholder))
|
||||||
kept_devices.append(device)
|
|
||||||
heartbeat_ops.append(heartbeat_op)
|
|
||||||
else:
|
|
||||||
logging.warning('Heartbeat support not available for %s', device)
|
|
||||||
|
|
||||||
return WorkerHeartbeatManager(session, kept_devices, heartbeat_ops,
|
return WorkerHeartbeatManager(session, devices, heartbeat_ops,
|
||||||
request_placeholder)
|
request_placeholder)
|
||||||
|
|
||||||
def num_workers(self):
|
def num_workers(self):
|
||||||
@ -185,11 +157,16 @@ class WorkerHeartbeatManager(object):
|
|||||||
def all_worker_devices(session):
|
def all_worker_devices(session):
|
||||||
"""Return a list of devices for each worker in the system."""
|
"""Return a list of devices for each worker in the system."""
|
||||||
devices = session.list_devices()
|
devices = session.list_devices()
|
||||||
return [
|
|
||||||
device.name
|
devices_that_support_heartbeats = []
|
||||||
for device in devices
|
|
||||||
if ':CPU:' in device.name and 'coordinator' not in device.name
|
for device in devices:
|
||||||
]
|
name = device.name
|
||||||
|
# Pick devices that have a TPU but target the attached CPU
|
||||||
|
if ':TPU:0' in name and 'coordinator' not in name:
|
||||||
|
devices_that_support_heartbeats.append(name.replace('TPU', 'CPU'))
|
||||||
|
|
||||||
|
return devices_that_support_heartbeats
|
||||||
|
|
||||||
|
|
||||||
class WatchdogManager(threading.Thread):
|
class WatchdogManager(threading.Thread):
|
||||||
@ -353,9 +330,15 @@ class GracefulShutdownHook(session_run_hook.SessionRunHook):
|
|||||||
self._session, all_worker_devices(self._session))
|
self._session, all_worker_devices(self._session))
|
||||||
self._heartbeat_supported = self._workers.num_workers() > 0
|
self._heartbeat_supported = self._workers.num_workers() > 0
|
||||||
if self._heartbeat_supported:
|
if self._heartbeat_supported:
|
||||||
self._workers.configure(
|
try:
|
||||||
event_pb2.WorkerHeartbeatRequest(
|
self._workers.configure(
|
||||||
shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR))
|
event_pb2.WorkerHeartbeatRequest(
|
||||||
|
shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR))
|
||||||
|
except errors.InvalidArgumentError:
|
||||||
|
logging.warn(
|
||||||
|
'TPU device does not support heartbeats. Failure '
|
||||||
|
'handling will be disabled.')
|
||||||
|
self._heartbeat_supported = False
|
||||||
else:
|
else:
|
||||||
logging.warn(
|
logging.warn(
|
||||||
'No workers support hearbeats. Failure handling will be disabled.')
|
'No workers support hearbeats. Failure handling will be disabled.')
|
||||||
|
Loading…
Reference in New Issue
Block a user