From a9a82a1f0d575418628282086f03ab66107a9c4c Mon Sep 17 00:00:00 2001 From: Chenkai Kuang Date: Fri, 26 Jun 2020 18:21:04 -0700 Subject: [PATCH] Make MultiProcessRunner rpc_layer default to "grpc" if not specified. PiperOrigin-RevId: 318578551 Change-Id: I59ee41ded51ed8117ec63ef4ed31c8ad4c4fab3d --- tensorflow/python/distribute/multi_process_runner.py | 4 ++-- tensorflow/python/distribute/strategy_combinations.py | 6 +++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/distribute/multi_process_runner.py b/tensorflow/python/distribute/multi_process_runner.py index 89162b50f4b..33451f2f255 100644 --- a/tensorflow/python/distribute/multi_process_runner.py +++ b/tensorflow/python/distribute/multi_process_runner.py @@ -138,7 +138,7 @@ class MultiProcessRunner(object): "worker2.example.com:2222"], "ps": ["ps0.example.com:2222", "ps1.example.com:2222"]} - rpc_layer: RPC layer to use. Default value is 'grpc+loas'. + rpc_layer: RPC layer to use. Default value is 'grpc'. max_run_time: If set, child processes is forced to exit at approximately this many seconds after `start` is called. We achieve this through `signal.alarm()` api. 
Note that this is best effort at Python level @@ -184,7 +184,7 @@ class MultiProcessRunner(object): self._proc_func = proc_func self._cluster_spec = cluster_spec - self._rpc_layer = rpc_layer + self._rpc_layer = rpc_layer or 'grpc' self._max_run_time = max_run_time self._grpc_fail_fast = grpc_fail_fast self._stream_stdout = stream_stdout diff --git a/tensorflow/python/distribute/strategy_combinations.py b/tensorflow/python/distribute/strategy_combinations.py index d66c7acba77..33c6fd17fc5 100644 --- a/tensorflow/python/distribute/strategy_combinations.py +++ b/tensorflow/python/distribute/strategy_combinations.py @@ -109,11 +109,15 @@ def _get_multi_worker_mirrored_creator(required_gpus): def _create_multi_worker_mirrored(): tf_config = cluster_resolver.TFConfigClusterResolver() + master = tf_config.master() + if tf_config.rpc_layer: + # Strip off the rpc_layer suffix. + master = master[len("%s://" % tf_config.rpc_layer):] resolver = cluster_resolver.SimpleClusterResolver( cluster_spec=tf_config.cluster_spec(), task_type=tf_config.task_type, task_id=tf_config.task_id, - master=tf_config.master(), + master=master, environment=tf_config.environment, num_accelerators={"GPU": required_gpus}, rpc_layer=tf_config.rpc_layer or "grpc",