PSv2: Privatize ClusterCoordinator.cluster attribute since Cluster is not meant for public use.

PiperOrigin-RevId: 348168836
Change-Id: I9abbd50de59861f02139d7472b397b937e1a4e85
Rick Chao 2020-12-18 02:50:23 -08:00 committed by TensorFlower Gardener
parent b5c8d770f5
commit e8f1a46dc1
4 changed files with 16 additions and 13 deletions
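The rename below only touches TensorFlow internals and its own tests; user code is expected to go through the coordinator's public methods rather than the `cluster` attribute. For reference, a minimal sketch of that public surface, assuming a parameter-server cluster that is already running and described by `TF_CONFIG` (the resolver setup here is illustrative, not part of this change):

import tensorflow as tf

# Assumption: TF_CONFIG describes a running cluster with "worker" and
# "ps" jobs; this sketch only shows the coordinator-side calls.
resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
strategy = tf.distribute.experimental.ParameterServerStrategy(resolver)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
    strategy)

@tf.function
def step_fn():
  return tf.constant(1.0)

# Only the public surface: schedule/join/done, never coordinator.cluster.
remote_value = coordinator.schedule(step_fn)
coordinator.join()
assert coordinator.done()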


@@ -994,7 +994,7 @@ class ClusterCoordinator(object):
           "currently.")
     self._strategy = strategy
     self._strategy.extended._used_with_coordinator = True
-    self.cluster = Cluster(strategy)
+    self._cluster = Cluster(strategy)
 
   @property
   def strategy(self):
@@ -1067,7 +1067,7 @@ class ClusterCoordinator(object):
     # Slot variables are usually created during function tracing time; thus
     # `schedule` needs to be called within the `strategy.scope()`.
     with self.strategy.scope():
-      return self.cluster.schedule(fn, args=args, kwargs=kwargs)
+      return self._cluster.schedule(fn, args=args, kwargs=kwargs)
 
   def join(self):
     """Blocks until all the scheduled functions have finished execution.
@@ -1088,7 +1088,7 @@ class ClusterCoordinator(object):
       previously scheduled function since the last time an error was thrown or
       since the beginning of the program.
     """
-    self.cluster.join()
+    self._cluster.join()
 
   def done(self):
     """Returns whether all the scheduled functions have finished execution.
@@ -1106,7 +1106,7 @@ class ClusterCoordinator(object):
       previously scheduled function since the last time an error was thrown or
       since the beginning of the program.
     """
-    return self.cluster.done()
+    return self._cluster.done()
 
   def create_per_worker_dataset(self, dataset_fn):
     """Create dataset on workers by calling `dataset_fn` on worker devices.
@@ -1168,7 +1168,7 @@ class ClusterCoordinator(object):
       iterators (that are on the workers).
     """
     input_workers = input_lib.InputWorkers([
-        (w.device_name, [w.device_name]) for w in self.cluster.workers
+        (w.device_name, [w.device_name]) for w in self._cluster.workers
     ])
 
     return _PerWorkerDistributedDataset(dataset_fn, input_workers, self)
@@ -1191,7 +1191,7 @@ class ClusterCoordinator(object):
       objects.
     """
     results = []
-    for w in self.cluster.workers:
+    for w in self._cluster.workers:
       results.append(w._create_resource(fn, args=args, kwargs=kwargs))  # pylint: disable=protected-access
     return PerWorkerValues(tuple(results))
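`create_per_worker_dataset`, whose body changes above, is the supported way to build one dataset replica per worker. A hedged usage sketch, reusing `coordinator` and the imports from the first example:

def dataset_fn():
  # Runs once on each worker, so every worker owns its own copy.
  return tf.data.Dataset.range(16).shuffle(16).batch(4)

per_worker_dataset = coordinator.create_per_worker_dataset(dataset_fn)
per_worker_iterator = iter(per_worker_dataset)

@tf.function
def train_step(iterator):
  return tf.reduce_sum(next(iterator))

# When the closure runs, the coordinator substitutes the iterator slice
# that lives on whichever worker picked it up.
result = coordinator.schedule(train_step, args=(per_worker_iterator,))
coordinator.join()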


@@ -162,7 +162,7 @@ class ClusterCoordinatorMprTest(test.TestCase):
       if test_join:
         ps_coordinator.join()
       if test_schedule:
-        while ps_coordinator.cluster._closure_queue._error is None:
+        while ps_coordinator._cluster._closure_queue._error is None:
           time.sleep(1)
         ps_coordinator.schedule(worker_fn)
     except errors.UnavailableError:
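This test watches the private `_closure_queue._error` slot so it can react the moment a failure is recorded; in user code the same failure surfaces as an exception raised from `schedule` or `join`. A sketch of that pattern, with `coordinator` and `step_fn` as in the first example (whether a given error is retried internally or raised to the caller depends on its type):

try:
  coordinator.schedule(step_fn)
  coordinator.join()
except tf.errors.UnavailableError:
  # A task became unreachable; in-flight functions may be rescheduled
  # once it rejoins, or the caller can re-schedule here.
  pass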


@@ -575,7 +575,7 @@ class ClusterCoordinatorTest(TestCaseWithErrorReportingThread):
   def testDatasetsShuffledDifferently(self):
     # This test requires at least two workers in the cluster.
-    self.assertGreaterEqual(len(self.coordinator.cluster.workers), 2)
+    self.assertGreaterEqual(len(self.coordinator._cluster.workers), 2)
 
     random_seed.set_random_seed(None)
@@ -587,12 +587,12 @@ class ClusterCoordinatorTest(TestCaseWithErrorReportingThread):
     # Get elements from the first two iterators.
     iterator_1 = distributed_iterator._values[0]
-    iterator_1._rebuild_on(self.coordinator.cluster.workers[0])
+    iterator_1._rebuild_on(self.coordinator._cluster.workers[0])
     iterator_1 = iterator_1.fetch()
     elements_in_iterator_1 = [e.numpy() for e in iterator_1]
 
     iterator_2 = distributed_iterator._values[1]
-    iterator_2._rebuild_on(self.coordinator.cluster.workers[1])
+    iterator_2._rebuild_on(self.coordinator._cluster.workers[1])
     iterator_2 = iterator_2.fetch()
     elements_in_iterator_2 = [e.numpy() for e in iterator_2]
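This test holds because `dataset_fn` runs independently on each worker and no global seed is set, so each worker draws its own shuffle order; the assertion then only needs the two element sequences to differ. The same property can be seen with plain `tf.data` and no cluster at all (illustrative only):

ds_fn = lambda: tf.data.Dataset.range(100).shuffle(100)
order_1 = [int(e) for e in ds_fn()]
order_2 = [int(e) for e in ds_fn()]
# With no seed fixed, two independently built datasets almost surely
# disagree on the order of 100 elements.
assert order_1 != order_2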


@@ -175,7 +175,8 @@ class FaultToleranceTest(test.TestCase):  # pylint: disable=missing-docstring
     model.schedule_training_functions(4)
     # Model does infinite training step, so at this moment, we expect to have 2
     # infinite closures inflight, and 2 closures in the queue.
-    while self.cluster_coord.cluster._closure_queue._inflight_closure_count < 2:
+    while (self.cluster_coord._cluster._closure_queue._inflight_closure_count
+           < 2):
       time.sleep(0.1)
     self.assertFalse(self.cluster_coord.done())
     self._restart(downtime_secs=2, job="worker")
@@ -356,7 +357,8 @@ class FaultToleranceTest(test.TestCase):  # pylint: disable=missing-docstring
     # Model does infinite training step, so at this moment, we expect to have 2
     # infinite closures inflight, and 8 closures in the queue.
-    while self.cluster_coord.cluster._closure_queue._inflight_closure_count < 2:
+    while (self.cluster_coord._cluster._closure_queue._inflight_closure_count
+           < 2):
       time.sleep(0.1)
     self.assertFalse(self.cluster_coord.done())
     self._cluster.kill_task("worker", 0)
@@ -380,7 +382,8 @@ class FaultToleranceTest(test.TestCase):  # pylint: disable=missing-docstring
     # Model does infinite training step, so at this moment, we expect to have 2
     # infinite closures inflight, and 8 closures in the queue.
-    while self.cluster_coord.cluster._closure_queue._inflight_closure_count < 2:
+    while (self.cluster_coord._cluster._closure_queue._inflight_closure_count
+           < 2):
       time.sleep(0.1)
     self.assertFalse(self.cluster_coord.done())
     self._cluster.kill_task("worker", 0)
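All three hunks in this file poll `_closure_queue._inflight_closure_count`, a private counter; outside TensorFlow's own test suite the only supported completion signals are `done()` and `join()`. A generic poll-until helper in the same spirit (hypothetical, not part of the TF API):

import time

def wait_until(predicate, interval=0.1, timeout=60.0):
  """Polls `predicate` until it returns True or `timeout` elapses."""
  deadline = time.time() + timeout
  while not predicate():
    if time.time() >= deadline:
      raise TimeoutError("condition not met within %.1f seconds" % timeout)
    time.sleep(interval)

# For example, block until every scheduled function has finished:
# wait_until(coordinator.done)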