Infer local ip in connect_to_cluster.
PiperOrigin-RevId: 339573404 Change-Id: I9d9ad47cddff454115931b70a752a2cfb93d5762
This commit is contained in:
parent
9eab49ca92
commit
38bda5fea7
@ -19,7 +19,6 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import socket
|
|
||||||
|
|
||||||
from absl import logging
|
from absl import logging
|
||||||
|
|
||||||
@ -166,12 +165,9 @@ def connect_to_cluster(cluster_spec_or_resolver,
|
|||||||
local_port = pywrap_tfe.TF_PickUnusedPortOrDie()
|
local_port = pywrap_tfe.TF_PickUnusedPortOrDie()
|
||||||
job_def = cluster_def.job.add()
|
job_def = cluster_def.job.add()
|
||||||
job_def.name = job_name
|
job_def.name = job_name
|
||||||
|
# TODO(fishx): Update this to make sure remote worker has valid ip address
|
||||||
ipstr = _get_local_ip_address(local_port)
|
# to connect with local.
|
||||||
if ipstr:
|
job_def.tasks[0] = "localhost:{}".format(local_port)
|
||||||
job_def.tasks[0] = "{}:{}".format(ipstr, local_port)
|
|
||||||
else:
|
|
||||||
job_def.tasks[0] = "localhost:{}".format(local_port)
|
|
||||||
|
|
||||||
server_def = ServerDef(
|
server_def = ServerDef(
|
||||||
cluster=cluster_def,
|
cluster=cluster_def,
|
||||||
@ -225,29 +221,3 @@ def connect_to_cluster(cluster_spec_or_resolver,
|
|||||||
|
|
||||||
def _strip_prefix(s, prefix):
|
def _strip_prefix(s, prefix):
|
||||||
return s[len(prefix):] if s.startswith(prefix) else s
|
return s[len(prefix):] if s.startswith(prefix) else s
|
||||||
|
|
||||||
|
|
||||||
def _get_local_ip_address(port):
|
|
||||||
"""Returns the first local ip address.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
port: the port used to lookup ip addresses using the socket library.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
a string representing the ip address. If it is an IPv6 address, it will be
|
|
||||||
wrapped by a pair of brackets. Or None if a local ip address cannot be
|
|
||||||
found.
|
|
||||||
"""
|
|
||||||
hostname = socket.gethostname()
|
|
||||||
addrinfo = socket.getaddrinfo(hostname, port)
|
|
||||||
# Use the first ip address.
|
|
||||||
# See the documentation of socket.getaddrinfo here:
|
|
||||||
# https://docs.python.org/3/library/socket.html#socket.getaddrinfo.
|
|
||||||
if not addrinfo or not addrinfo[0][4]:
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
ipstr = addrinfo[0][4][0]
|
|
||||||
if addrinfo[0][0] == socket.AddressFamily.AF_INET6:
|
|
||||||
return "[%s]" % ipstr
|
|
||||||
else:
|
|
||||||
return ipstr
|
|
||||||
|
Loading…
Reference in New Issue
Block a user