diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 363e48f74d1..d7489cf420c 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -871,10 +871,10 @@ py_library(
         ":one_device_strategy",
         ":test_util",
         ":tpu_strategy",
-        "//tensorflow/python:config",
         "//tensorflow/python:platform",
         "//tensorflow/python:tf2",
-        "//tensorflow/python:util",
+        "//tensorflow/python:tf_export",
+        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
         "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:remote",
diff --git a/tensorflow/python/distribute/combinations.py b/tensorflow/python/distribute/combinations.py
index 861a2f0490e..a2e3610b7c5 100644
--- a/tensorflow/python/distribute/combinations.py
+++ b/tensorflow/python/distribute/combinations.py
@@ -257,7 +257,7 @@ class NamedDistribution(object):
                use_cloud_tpu=False,
                has_chief=False,
                num_workers=1,
-               use_pool_runner=False,
+               pool_runner_fn=None,
                no_xla=False):
     """Initialize NamedDistribution.
 
@@ -269,8 +269,8 @@
       use_cloud_tpu: Whether the strategy requires cloud TPU.
       has_chief: Whether the strategy requires a chief worker.
       num_workers: The number of workers that the strategy requires.
-      use_pool_runner: Whether to use a pool runner so that workers are re-used
-        each time.
+      pool_runner_fn: An optional callable that returns a MultiProcessPoolRunner
+        to run the test.
       no_xla: Whether to skip in XLA tests.
     """
     object.__init__(self)
@@ -281,25 +281,14 @@
     self.use_cloud_tpu = use_cloud_tpu
     self.has_chief = has_chief
     self.num_workers = num_workers
-    self.use_pool_runner = use_pool_runner
+    self._pool_runner_fn = pool_runner_fn
     self.no_xla = no_xla
-    self._runner = None
 
   @property
   def runner(self):
-    if not self._runner:
-      if (_num_total_workers(self.has_chief, self.num_workers) > 1 and
-          self.use_pool_runner):
-        # Need to create the strategy in the initializer so that collectives are
-        # configured before eager context initialization.
-        cluster_spec = multi_worker_test_base.create_cluster_spec(
-            has_chief=self.has_chief,
-            num_workers=self.num_workers,
-            num_ps=0,
-            has_eval=False)
-        self._runner = multi_process_runner.MultiProcessPoolRunner(
-            cluster_spec, initializer=self._distribution_fn)
-    return self._runner
+    if self._pool_runner_fn is not None:
+      return self._pool_runner_fn()
+    return None
 
   @property
   def strategy(self):
diff --git a/tensorflow/python/distribute/combinations_test.py b/tensorflow/python/distribute/combinations_test.py
index e9897a45805..02ddcbef632 100644
--- a/tensorflow/python/distribute/combinations_test.py
+++ b/tensorflow/python/distribute/combinations_test.py
@@ -96,13 +96,6 @@ class ClusterCombinationTest(test.TestCase, parameterized.TestCase):
     # set to the main process.
     self.assertNotEqual(os.getenv("TF_CONFIG"), "")
 
-  def test_runner_creation(self):
-    cmb = combinations.NamedDistribution(
-        "Strategy1", lambda: None, has_chief=True, num_workers=2,
-        use_pool_runner=True)
-    self.assertIsNone(cmb._runner)
-    self.assertIsNotNone(cmb.runner)
-
 
 # unittest.expectedFailure doesn't work with parameterized test methods, so we
 # have to decorate the class instead.
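With this change, NamedDistribution no longer builds a MultiProcessPoolRunner itself: its runner property just invokes the pool_runner_fn callable supplied at construction time, or returns None when no callable was given, and it no longer caches the result (any memoization is expected to live inside the callable, as the _deferred_pool_runner helper below does). A minimal sketch of the new contract, assuming the patch above is applied; the "TwoWorkerCPU" name, the lambda distribution_fn, and fake_pool_runner are illustrative placeholders, not part of the patch:

from tensorflow.python.distribute import combinations


def fake_pool_runner():
  # Placeholder for a factory that would normally return a
  # multi_process_runner.MultiProcessPoolRunner (see strategy_combinations.py).
  return "fake-runner"


dist = combinations.NamedDistribution(
    "TwoWorkerCPU",    # hypothetical strategy name
    lambda: None,      # placeholder distribution_fn
    has_chief=True,
    num_workers=1,
    pool_runner_fn=fake_pool_runner,
)
assert dist.runner == "fake-runner"  # runner simply delegates to pool_runner_fn
assert combinations.NamedDistribution("NoRunner", lambda: None).runner is None
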
diff --git a/tensorflow/python/distribute/strategy_combinations.py b/tensorflow/python/distribute/strategy_combinations.py
index 2a7afabf166..a9a66bc491c 100644
--- a/tensorflow/python/distribute/strategy_combinations.py
+++ b/tensorflow/python/distribute/strategy_combinations.py
@@ -26,6 +26,7 @@ from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import mirrored_strategy as mirrored_lib
 from tensorflow.python.distribute import multi_process_runner
+from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute import one_device_strategy as one_device_lib
 from tensorflow.python.distribute import test_util
 from tensorflow.python.distribute import tpu_strategy as tpu_lib
@@ -158,6 +159,50 @@ def _get_multi_worker_mirrored_creator(required_gpus):
   return _create_multi_worker_mirrored
 
 
+def _deferred_pool_runner(has_chief, num_workers, initializer=None):
+  """Returns a callable that returns the pool runner.
+
+  It creates the pool runner only upon first invocation. This avoids creating it
+  when this file is imported.
+
+  Args:
+    has_chief: whether there should be a chief.
+    num_workers: the number of workers excluding the chief.
+    initializer: initializer of each process.
+
+  Returns:
+    A callable that returns the runner.
+  """
+
+  container = []
+
+  def get_or_create():
+    if not container:
+      cluster_spec = multi_worker_test_base.create_cluster_spec(
+          has_chief=has_chief,
+          num_workers=num_workers,
+          num_ps=0,
+          has_eval=False)
+      runner = multi_process_runner.MultiProcessPoolRunner(
+          cluster_spec, initializer=initializer)
+      container.append(runner)
+    return container[0]
+
+  return get_or_create
+
+
+# We need to create the strategy in the initializer to start the server before
+# any test runs.
+_two_worker_pool = _deferred_pool_runner(
+    has_chief=True,
+    num_workers=1,
+    initializer=_get_multi_worker_mirrored_creator(required_gpus=0))
+_four_worker_pool = _deferred_pool_runner(
+    has_chief=True,
+    num_workers=3,
+    initializer=_get_multi_worker_mirrored_creator(required_gpus=0))
+
+
 # pylint: disable=g-long-lambda
 default_strategy = combinations.NamedDistribution(
     "Default",
@@ -230,7 +275,7 @@ multi_worker_mirrored_2x1_cpu = combinations.NamedDistribution(
     _get_multi_worker_mirrored_creator(required_gpus=0),
     has_chief=True,
     num_workers=1,
-    use_pool_runner=True,
+    pool_runner_fn=_two_worker_pool,
     no_xla=True,
 )
 # chief + 1 worker, with 1 GPU each.
@@ -240,7 +285,7 @@ multi_worker_mirrored_2x1_gpu = combinations.NamedDistribution(
     has_chief=True,
     num_workers=1,
     required_gpus=1,
-    use_pool_runner=True,
+    pool_runner_fn=_two_worker_pool,
     no_xla=True,
 )
 # chief + 1 worker, with 2 GPU each.
@@ -250,7 +295,7 @@ multi_worker_mirrored_2x2_gpu = combinations.NamedDistribution(
     has_chief=True,
     num_workers=1,
     required_gpus=2,
-    use_pool_runner=True,
+    pool_runner_fn=_two_worker_pool,
     no_xla=True,
 )
 # chief + 3 workers, with CPU.
@@ -259,7 +304,7 @@ multi_worker_mirrored_4x1_cpu = combinations.NamedDistribution(
     _get_multi_worker_mirrored_creator(required_gpus=0),
     has_chief=True,
     num_workers=3,
-    use_pool_runner=True,
+    pool_runner_fn=_four_worker_pool,
     no_xla=True,
 )
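The _deferred_pool_runner helper added here is a memoizing factory: a closure that stashes the runner in a one-element list and only builds it on the first call, so merely importing strategy_combinations.py does not create the cluster spec or the MultiProcessPoolRunner. A stripped-down, TensorFlow-free sketch of the same pattern; the names deferred and make_pool are illustrative, not from the patch:

def deferred(factory):
  """Returns a zero-argument callable that builds factory() once and caches it."""
  cache = []

  def get_or_create():
    if not cache:
      cache.append(factory())
    return cache[0]

  return get_or_create


make_pool = deferred(lambda: object())  # nothing is constructed yet
assert make_pool() is make_pool()       # built on the first call, then reused

Because _two_worker_pool and _four_worker_pool are such callables rather than runner instances, passing them as pool_runner_fn defers creating the cluster spec and the pool runner until a test actually asks NamedDistribution for its runner.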