Make MultiProccessRunner rpc_layer defaults to "grpc" if not specified.

PiperOrigin-RevId: 318578551
Change-Id: I59ee41ded51ed8117ec63ef4ed31c8ad4c4fab3d
This commit is contained in:
Chenkai Kuang 2020-06-26 18:21:04 -07:00 committed by TensorFlower Gardener
parent c9d13afacd
commit a9a82a1f0d
2 changed files with 7 additions and 3 deletions

View File

@ -138,7 +138,7 @@ class MultiProcessRunner(object):
"worker2.example.com:2222"],
"ps": ["ps0.example.com:2222",
"ps1.example.com:2222"]}
rpc_layer: RPC layer to use. Default value is 'grpc+loas'.
rpc_layer: RPC layer to use. Default value is 'grpc'.
max_run_time: If set, child processes is forced to exit at approximately
this many seconds after `start` is called. We achieve this through
`signal.alarm()` api. Note that this is best effort at Python level
@ -184,7 +184,7 @@ class MultiProcessRunner(object):
self._proc_func = proc_func
self._cluster_spec = cluster_spec
self._rpc_layer = rpc_layer
self._rpc_layer = rpc_layer or 'grpc'
self._max_run_time = max_run_time
self._grpc_fail_fast = grpc_fail_fast
self._stream_stdout = stream_stdout

View File

@ -109,11 +109,15 @@ def _get_multi_worker_mirrored_creator(required_gpus):
def _create_multi_worker_mirrored():
tf_config = cluster_resolver.TFConfigClusterResolver()
master = tf_config.master()
if tf_config.rpc_layer:
# Strip off the rpc_layer suffix.
master = master[len("%s://" % tf_config.rpc_layer):]
resolver = cluster_resolver.SimpleClusterResolver(
cluster_spec=tf_config.cluster_spec(),
task_type=tf_config.task_type,
task_id=tf_config.task_id,
master=tf_config.master(),
master=master,
environment=tf_config.environment,
num_accelerators={"GPU": required_gpus},
rpc_layer=tf_config.rpc_layer or "grpc",