Update NamedDistribution for lazy creation of the runner in tf.distribute.

Only create the runner the first time it is retrieved. Creating the runner/cluster spec triggers the portpicker code to reserve an unused port, which is not ideal in __init__ because the strategy combinations are created at TF load time, e.g. on "import tensorflow as tf".

In the normal case, the port should only be reserved when the strategy_combination is used.

PiperOrigin-RevId: 337337130
Change-Id: I66835ada8ce7b65470c11fba92948c9caec9d970
This commit is contained in:
Scott Zhu 2020-10-15 10:32:20 -07:00 committed by TensorFlower Gardener
parent 1c9934fd7e
commit c63b59ce64
2 changed files with 20 additions and 12 deletions

View File

@ -281,23 +281,24 @@ class NamedDistribution(object):
self.use_cloud_tpu = use_cloud_tpu
self.has_chief = has_chief
self.num_workers = num_workers
self.use_pool_runner = use_pool_runner
self.no_xla = no_xla
self._runner = None
if _num_total_workers(self.has_chief, self.num_workers) > 1:
cluster_spec = multi_worker_test_base.create_cluster_spec(
has_chief=has_chief,
num_workers=num_workers,
num_ps=0,
has_eval=False)
if use_pool_runner:
# Need to create the strategy in the initializer so that collectives are
# configured before eager context initialization.
self._runner = multi_process_runner.MultiProcessPoolRunner(
cluster_spec, initializer=self._distribution_fn)
@property
def runner(self):
  """Returns the pool runner for this distribution, creating it on first use.

  The runner is built lazily here instead of in `__init__` so that the
  cluster spec (and whatever resources its creation reserves) is only set up
  when the runner is actually requested.

  Returns:
    The cached runner if one has already been created. Otherwise, a newly
    created `MultiProcessPoolRunner` when there is more than one worker in
    total and `use_pool_runner` is set; `None` when those conditions do not
    hold.
  """
  # Compare against None explicitly: a truthiness check would rebuild the
  # cluster spec and the pool on every access if the cached runner object
  # ever evaluated as falsy.
  if self._runner is not None:
    return self._runner

  is_multi_worker = _num_total_workers(self.has_chief, self.num_workers) > 1
  if is_multi_worker and self.use_pool_runner:
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        has_chief=self.has_chief,
        num_workers=self.num_workers,
        num_ps=0,
        has_eval=False)
    # Need to create the strategy in the initializer so that collectives are
    # configured before eager context initialization.
    self._runner = multi_process_runner.MultiProcessPoolRunner(
        cluster_spec, initializer=self._distribution_fn)
  return self._runner
@property

View File

@ -96,6 +96,13 @@ class ClusterCombinationTest(test.TestCase, parameterized.TestCase):
# set to the main process.
self.assertNotEqual(os.getenv("TF_CONFIG"), "")
def test_runner_creation(self):
  """Checks that the runner is absent after construction and built on access."""
  named_distribution = combinations.NamedDistribution(
      "Strategy1",
      lambda: None,
      has_chief=True,
      num_workers=2,
      use_pool_runner=True)
  # Constructing the combination alone must not have created a runner.
  self.assertIsNone(named_distribution._runner)
  # Reading the property is what triggers creation.
  created_runner = named_distribution.runner
  self.assertIsNotNone(created_runner)
# unittest.expectedFailure doesn't work with parameterized test methods, so we
# have to decorate the class instead.