Make MultiProccessRunner rpc_layer defaults to "grpc" if not specified.
PiperOrigin-RevId: 318578551 Change-Id: I59ee41ded51ed8117ec63ef4ed31c8ad4c4fab3d
This commit is contained in:
parent
c9d13afacd
commit
a9a82a1f0d
@ -138,7 +138,7 @@ class MultiProcessRunner(object):
|
|||||||
"worker2.example.com:2222"],
|
"worker2.example.com:2222"],
|
||||||
"ps": ["ps0.example.com:2222",
|
"ps": ["ps0.example.com:2222",
|
||||||
"ps1.example.com:2222"]}
|
"ps1.example.com:2222"]}
|
||||||
rpc_layer: RPC layer to use. Default value is 'grpc+loas'.
|
rpc_layer: RPC layer to use. Default value is 'grpc'.
|
||||||
max_run_time: If set, child processes is forced to exit at approximately
|
max_run_time: If set, child processes is forced to exit at approximately
|
||||||
this many seconds after `start` is called. We achieve this through
|
this many seconds after `start` is called. We achieve this through
|
||||||
`signal.alarm()` api. Note that this is best effort at Python level
|
`signal.alarm()` api. Note that this is best effort at Python level
|
||||||
@ -184,7 +184,7 @@ class MultiProcessRunner(object):
|
|||||||
|
|
||||||
self._proc_func = proc_func
|
self._proc_func = proc_func
|
||||||
self._cluster_spec = cluster_spec
|
self._cluster_spec = cluster_spec
|
||||||
self._rpc_layer = rpc_layer
|
self._rpc_layer = rpc_layer or 'grpc'
|
||||||
self._max_run_time = max_run_time
|
self._max_run_time = max_run_time
|
||||||
self._grpc_fail_fast = grpc_fail_fast
|
self._grpc_fail_fast = grpc_fail_fast
|
||||||
self._stream_stdout = stream_stdout
|
self._stream_stdout = stream_stdout
|
||||||
|
@ -109,11 +109,15 @@ def _get_multi_worker_mirrored_creator(required_gpus):
|
|||||||
|
|
||||||
def _create_multi_worker_mirrored():
|
def _create_multi_worker_mirrored():
|
||||||
tf_config = cluster_resolver.TFConfigClusterResolver()
|
tf_config = cluster_resolver.TFConfigClusterResolver()
|
||||||
|
master = tf_config.master()
|
||||||
|
if tf_config.rpc_layer:
|
||||||
|
# Strip off the rpc_layer suffix.
|
||||||
|
master = master[len("%s://" % tf_config.rpc_layer):]
|
||||||
resolver = cluster_resolver.SimpleClusterResolver(
|
resolver = cluster_resolver.SimpleClusterResolver(
|
||||||
cluster_spec=tf_config.cluster_spec(),
|
cluster_spec=tf_config.cluster_spec(),
|
||||||
task_type=tf_config.task_type,
|
task_type=tf_config.task_type,
|
||||||
task_id=tf_config.task_id,
|
task_id=tf_config.task_id,
|
||||||
master=tf_config.master(),
|
master=master,
|
||||||
environment=tf_config.environment,
|
environment=tf_config.environment,
|
||||||
num_accelerators={"GPU": required_gpus},
|
num_accelerators={"GPU": required_gpus},
|
||||||
rpc_layer=tf_config.rpc_layer or "grpc",
|
rpc_layer=tf_config.rpc_layer or "grpc",
|
||||||
|
Loading…
Reference in New Issue
Block a user