PSv2: Privatize ClusterCoordinator.cluster attribute since Cluster is not meant for public use.

PiperOrigin-RevId: 348168836
Change-Id: I9abbd50de59861f02139d7472b397b937e1a4e85
Rick Chao 2020-12-18 02:50:23 -08:00, committed by TensorFlower Gardener
parent b5c8d770f5
commit e8f1a46dc1
4 changed files with 16 additions and 13 deletions
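
For context, here is a minimal usage sketch of how code is expected to interact with the coordinator after this change: the `Cluster` object now sits behind the private `_cluster` attribute, so external code should rely on the public `schedule`, `join`, `done`, and `create_per_worker_dataset` methods that the hunks below keep delegating to it. The cluster setup (a `TFConfigClusterResolver` reading `TF_CONFIG`) and the trivial `step_fn` are illustrative assumptions, not part of this commit.

# Minimal sketch, assuming a parameter-server cluster is already described in
# the TF_CONFIG environment variable and a TF build that ships these
# experimental coordinator APIs (circa TF 2.4); not part of this commit.
import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
strategy = tf.distribute.experimental.ParameterServerStrategy(resolver)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)

@tf.function
def step_fn():
  # Placeholder step; a real step would consume a per-worker dataset iterator.
  return tf.constant(1.0)

result = coordinator.schedule(step_fn)  # public API, unaffected by the rename
coordinator.join()                      # public API, now delegates to self._cluster
print(result.fetch())

# coordinator.cluster.workers    # no longer available after this commit
# coordinator._cluster.workers   # internal/tests only, not a public contract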

View File

@@ -994,7 +994,7 @@ class ClusterCoordinator(object):
           "currently.")
     self._strategy = strategy
     self._strategy.extended._used_with_coordinator = True
-    self.cluster = Cluster(strategy)
+    self._cluster = Cluster(strategy)

   @property
   def strategy(self):
@@ -1067,7 +1067,7 @@ class ClusterCoordinator(object):
     # Slot variables are usually created during function tracing time; thus
     # `schedule` needs to be called within the `strategy.scope()`.
     with self.strategy.scope():
-      return self.cluster.schedule(fn, args=args, kwargs=kwargs)
+      return self._cluster.schedule(fn, args=args, kwargs=kwargs)

   def join(self):
     """Blocks until all the scheduled functions have finished execution.
@@ -1088,7 +1088,7 @@ class ClusterCoordinator(object):
         previously scheduled function since the last time an error was thrown or
         since the beginning of the program.
     """
-    self.cluster.join()
+    self._cluster.join()

   def done(self):
     """Returns whether all the scheduled functions have finished execution.
@@ -1106,7 +1106,7 @@ class ClusterCoordinator(object):
         previously scheduled function since the last time an error was thrown or
         since the beginning of the program.
     """
-    return self.cluster.done()
+    return self._cluster.done()

   def create_per_worker_dataset(self, dataset_fn):
     """Create dataset on workers by calling `dataset_fn` on worker devices.
@@ -1168,7 +1168,7 @@ class ClusterCoordinator(object):
       iterators (that are on the workers).
     """
     input_workers = input_lib.InputWorkers([
-        (w.device_name, [w.device_name]) for w in self.cluster.workers
+        (w.device_name, [w.device_name]) for w in self._cluster.workers
     ])

     return _PerWorkerDistributedDataset(dataset_fn, input_workers, self)
@@ -1191,7 +1191,7 @@ class ClusterCoordinator(object):
       objects.
     """
     results = []
-    for w in self.cluster.workers:
+    for w in self._cluster.workers:
       results.append(w._create_resource(fn, args=args, kwargs=kwargs))  # pylint: disable=protected-access
     return PerWorkerValues(tuple(results))

View File

@@ -162,7 +162,7 @@ class ClusterCoordinatorMprTest(test.TestCase):
       if test_join:
         ps_coordinator.join()
       if test_schedule:
-        while ps_coordinator.cluster._closure_queue._error is None:
+        while ps_coordinator._cluster._closure_queue._error is None:
           time.sleep(1)
         ps_coordinator.schedule(worker_fn)
     except errors.UnavailableError:

View File

@@ -575,7 +575,7 @@ class ClusterCoordinatorTest(TestCaseWithErrorReportingThread):
   def testDatasetsShuffledDifferently(self):
     # This test requires at least two workers in the cluster.
-    self.assertGreaterEqual(len(self.coordinator.cluster.workers), 2)
+    self.assertGreaterEqual(len(self.coordinator._cluster.workers), 2)

     random_seed.set_random_seed(None)
@@ -587,12 +587,12 @@ class ClusterCoordinatorTest(TestCaseWithErrorReportingThread):
     # Get elements from the first two iterators.
     iterator_1 = distributed_iterator._values[0]
-    iterator_1._rebuild_on(self.coordinator.cluster.workers[0])
+    iterator_1._rebuild_on(self.coordinator._cluster.workers[0])
     iterator_1 = iterator_1.fetch()
     elements_in_iterator_1 = [e.numpy() for e in iterator_1]

     iterator_2 = distributed_iterator._values[1]
-    iterator_2._rebuild_on(self.coordinator.cluster.workers[1])
+    iterator_2._rebuild_on(self.coordinator._cluster.workers[1])
     iterator_2 = iterator_2.fetch()
     elements_in_iterator_2 = [e.numpy() for e in iterator_2]

View File

@@ -175,7 +175,8 @@ class FaultToleranceTest(test.TestCase):  # pylint: disable=missing-docstring
     model.schedule_training_functions(4)
     # Model does infinite training step, so at this moment, we expect to have 2
     # infinite closures inflight, and 2 closures in the queue.
-    while self.cluster_coord.cluster._closure_queue._inflight_closure_count < 2:
+    while (self.cluster_coord._cluster._closure_queue._inflight_closure_count
+           < 2):
       time.sleep(0.1)
     self.assertFalse(self.cluster_coord.done())
     self._restart(downtime_secs=2, job="worker")
@@ -356,7 +357,8 @@ class FaultToleranceTest(test.TestCase):  # pylint: disable=missing-docstring
     # Model does infinite training step, so at this moment, we expect to have 2
     # infinite closures inflight, and 8 closures in the queue.
-    while self.cluster_coord.cluster._closure_queue._inflight_closure_count < 2:
+    while (self.cluster_coord._cluster._closure_queue._inflight_closure_count
+           < 2):
       time.sleep(0.1)
     self.assertFalse(self.cluster_coord.done())
     self._cluster.kill_task("worker", 0)
@@ -380,7 +382,8 @@ class FaultToleranceTest(test.TestCase):  # pylint: disable=missing-docstring
     # Model does infinite training step, so at this moment, we expect to have 2
     # infinite closures inflight, and 8 closures in the queue.
-    while self.cluster_coord.cluster._closure_queue._inflight_closure_count < 2:
+    while (self.cluster_coord._cluster._closure_queue._inflight_closure_count
+           < 2):
       time.sleep(0.1)
     self.assertFalse(self.cluster_coord.done())
     self._cluster.kill_task("worker", 0)