Update NamedDistribution for lazy creation of the runner in tf.distribute.
Only create the runner the first time it is retrieved. Creating the runner/cluster spec triggers the portpicker code to reserve an unused port, which is not ideal in __init__ because the strategy combinations are created at TF load time, e.g. on "import tensorflow as tf". In the normal case, the port should only be reserved when the strategy_combination is actually used. PiperOrigin-RevId: 337337130 Change-Id: I66835ada8ce7b65470c11fba92948c9caec9d970
This commit is contained in:
parent
1c9934fd7e
commit
c63b59ce64
@ -281,23 +281,24 @@ class NamedDistribution(object):
|
||||
self.use_cloud_tpu = use_cloud_tpu
|
||||
self.has_chief = has_chief
|
||||
self.num_workers = num_workers
|
||||
self.use_pool_runner = use_pool_runner
|
||||
self.no_xla = no_xla
|
||||
self._runner = None
|
||||
|
||||
if _num_total_workers(self.has_chief, self.num_workers) > 1:
|
||||
cluster_spec = multi_worker_test_base.create_cluster_spec(
|
||||
has_chief=has_chief,
|
||||
num_workers=num_workers,
|
||||
num_ps=0,
|
||||
has_eval=False)
|
||||
if use_pool_runner:
|
||||
# Need to create the strategy in the initializer so that collectives are
|
||||
# configured before eager context initialization.
|
||||
self._runner = multi_process_runner.MultiProcessPoolRunner(
|
||||
cluster_spec, initializer=self._distribution_fn)
|
||||
|
||||
@property
def runner(self):
  """Returns the pool runner for this combination, building it on first use.

  The runner is only built when the combination spans more than one worker
  (chief included) and `use_pool_runner` is set. In every other case the
  current value of `self._runner` is returned unchanged.

  Construction is deferred to the first access instead of happening eagerly,
  so that the cluster spec is not created until the runner is actually
  requested.
  """
  if self._runner:
    # Already built by an earlier access; reuse it.
    return self._runner

  wants_pool_runner = (
      _num_total_workers(self.has_chief, self.num_workers) > 1 and
      self.use_pool_runner)
  if wants_pool_runner:
    spec = multi_worker_test_base.create_cluster_spec(
        has_chief=self.has_chief,
        num_workers=self.num_workers,
        num_ps=0,
        has_eval=False)
    # Need to create the strategy in the initializer so that collectives are
    # configured before eager context initialization.
    self._runner = multi_process_runner.MultiProcessPoolRunner(
        spec, initializer=self._distribution_fn)

  return self._runner
|
||||
|
||||
@property
|
||||
|
@ -96,6 +96,13 @@ class ClusterCombinationTest(test.TestCase, parameterized.TestCase):
|
||||
# set to the main process.
|
||||
self.assertNotEqual(os.getenv("TF_CONFIG"), "")
|
||||
|
||||
def test_runner_creation(self):
  """Checks that the pool runner is built on first access, not in __init__."""
  named_distribution = combinations.NamedDistribution(
      "Strategy1",
      lambda: None,
      has_chief=True,
      num_workers=2,
      use_pool_runner=True)
  # Nothing should have been built by the constructor.
  self.assertIsNone(named_distribution._runner)
  # Reading the property is what triggers creation of the runner.
  self.assertIsNotNone(named_distribution.runner)
|
||||
|
||||
|
||||
# unittest.expectedFailure doesn't work with parameterized test methods, so we
|
||||
# have to decorate the class instead.
|
||||
|
Loading…
Reference in New Issue
Block a user