PSv2: Privatize ClusterCoordinator.cluster attribute since Cluster is not meant for public use.
PiperOrigin-RevId: 348168836 Change-Id: I9abbd50de59861f02139d7472b397b937e1a4e85
This commit is contained in:
parent
b5c8d770f5
commit
e8f1a46dc1
@ -994,7 +994,7 @@ class ClusterCoordinator(object):
|
|||||||
"currently.")
|
"currently.")
|
||||||
self._strategy = strategy
|
self._strategy = strategy
|
||||||
self._strategy.extended._used_with_coordinator = True
|
self._strategy.extended._used_with_coordinator = True
|
||||||
self.cluster = Cluster(strategy)
|
self._cluster = Cluster(strategy)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def strategy(self):
|
def strategy(self):
|
||||||
@ -1067,7 +1067,7 @@ class ClusterCoordinator(object):
|
|||||||
# Slot variables are usually created during function tracing time; thus
|
# Slot variables are usually created during function tracing time; thus
|
||||||
# `schedule` needs to be called within the `strategy.scope()`.
|
# `schedule` needs to be called within the `strategy.scope()`.
|
||||||
with self.strategy.scope():
|
with self.strategy.scope():
|
||||||
return self.cluster.schedule(fn, args=args, kwargs=kwargs)
|
return self._cluster.schedule(fn, args=args, kwargs=kwargs)
|
||||||
|
|
||||||
def join(self):
|
def join(self):
|
||||||
"""Blocks until all the scheduled functions have finished execution.
|
"""Blocks until all the scheduled functions have finished execution.
|
||||||
@ -1088,7 +1088,7 @@ class ClusterCoordinator(object):
|
|||||||
previously scheduled function since the last time an error was thrown or
|
previously scheduled function since the last time an error was thrown or
|
||||||
since the beginning of the program.
|
since the beginning of the program.
|
||||||
"""
|
"""
|
||||||
self.cluster.join()
|
self._cluster.join()
|
||||||
|
|
||||||
def done(self):
|
def done(self):
|
||||||
"""Returns whether all the scheduled functions have finished execution.
|
"""Returns whether all the scheduled functions have finished execution.
|
||||||
@ -1106,7 +1106,7 @@ class ClusterCoordinator(object):
|
|||||||
previously scheduled function since the last time an error was thrown or
|
previously scheduled function since the last time an error was thrown or
|
||||||
since the beginning of the program.
|
since the beginning of the program.
|
||||||
"""
|
"""
|
||||||
return self.cluster.done()
|
return self._cluster.done()
|
||||||
|
|
||||||
def create_per_worker_dataset(self, dataset_fn):
|
def create_per_worker_dataset(self, dataset_fn):
|
||||||
"""Create dataset on workers by calling `dataset_fn` on worker devices.
|
"""Create dataset on workers by calling `dataset_fn` on worker devices.
|
||||||
@ -1168,7 +1168,7 @@ class ClusterCoordinator(object):
|
|||||||
iterators (that are on the workers).
|
iterators (that are on the workers).
|
||||||
"""
|
"""
|
||||||
input_workers = input_lib.InputWorkers([
|
input_workers = input_lib.InputWorkers([
|
||||||
(w.device_name, [w.device_name]) for w in self.cluster.workers
|
(w.device_name, [w.device_name]) for w in self._cluster.workers
|
||||||
])
|
])
|
||||||
|
|
||||||
return _PerWorkerDistributedDataset(dataset_fn, input_workers, self)
|
return _PerWorkerDistributedDataset(dataset_fn, input_workers, self)
|
||||||
@ -1191,7 +1191,7 @@ class ClusterCoordinator(object):
|
|||||||
objects.
|
objects.
|
||||||
"""
|
"""
|
||||||
results = []
|
results = []
|
||||||
for w in self.cluster.workers:
|
for w in self._cluster.workers:
|
||||||
results.append(w._create_resource(fn, args=args, kwargs=kwargs)) # pylint: disable=protected-access
|
results.append(w._create_resource(fn, args=args, kwargs=kwargs)) # pylint: disable=protected-access
|
||||||
return PerWorkerValues(tuple(results))
|
return PerWorkerValues(tuple(results))
|
||||||
|
|
||||||
|
@ -162,7 +162,7 @@ class ClusterCoordinatorMprTest(test.TestCase):
|
|||||||
if test_join:
|
if test_join:
|
||||||
ps_coordinator.join()
|
ps_coordinator.join()
|
||||||
if test_schedule:
|
if test_schedule:
|
||||||
while ps_coordinator.cluster._closure_queue._error is None:
|
while ps_coordinator._cluster._closure_queue._error is None:
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
ps_coordinator.schedule(worker_fn)
|
ps_coordinator.schedule(worker_fn)
|
||||||
except errors.UnavailableError:
|
except errors.UnavailableError:
|
||||||
|
@ -575,7 +575,7 @@ class ClusterCoordinatorTest(TestCaseWithErrorReportingThread):
|
|||||||
|
|
||||||
def testDatasetsShuffledDifferently(self):
|
def testDatasetsShuffledDifferently(self):
|
||||||
# This test requires at least two workers in the cluster.
|
# This test requires at least two workers in the cluster.
|
||||||
self.assertGreaterEqual(len(self.coordinator.cluster.workers), 2)
|
self.assertGreaterEqual(len(self.coordinator._cluster.workers), 2)
|
||||||
|
|
||||||
random_seed.set_random_seed(None)
|
random_seed.set_random_seed(None)
|
||||||
|
|
||||||
@ -587,12 +587,12 @@ class ClusterCoordinatorTest(TestCaseWithErrorReportingThread):
|
|||||||
|
|
||||||
# Get elements from the first two iterators.
|
# Get elements from the first two iterators.
|
||||||
iterator_1 = distributed_iterator._values[0]
|
iterator_1 = distributed_iterator._values[0]
|
||||||
iterator_1._rebuild_on(self.coordinator.cluster.workers[0])
|
iterator_1._rebuild_on(self.coordinator._cluster.workers[0])
|
||||||
iterator_1 = iterator_1.fetch()
|
iterator_1 = iterator_1.fetch()
|
||||||
elements_in_iterator_1 = [e.numpy() for e in iterator_1]
|
elements_in_iterator_1 = [e.numpy() for e in iterator_1]
|
||||||
|
|
||||||
iterator_2 = distributed_iterator._values[1]
|
iterator_2 = distributed_iterator._values[1]
|
||||||
iterator_2._rebuild_on(self.coordinator.cluster.workers[1])
|
iterator_2._rebuild_on(self.coordinator._cluster.workers[1])
|
||||||
iterator_2 = iterator_2.fetch()
|
iterator_2 = iterator_2.fetch()
|
||||||
elements_in_iterator_2 = [e.numpy() for e in iterator_2]
|
elements_in_iterator_2 = [e.numpy() for e in iterator_2]
|
||||||
|
|
||||||
|
@ -175,7 +175,8 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
|
|||||||
model.schedule_training_functions(4)
|
model.schedule_training_functions(4)
|
||||||
# Model does infinite training step, so at this moment, we expect to have 2
|
# Model does infinite training step, so at this moment, we expect to have 2
|
||||||
# infinite closures inflight, and 2 closures in the queue.
|
# infinite closures inflight, and 2 closures in the queue.
|
||||||
while self.cluster_coord.cluster._closure_queue._inflight_closure_count < 2:
|
while (self.cluster_coord._cluster._closure_queue._inflight_closure_count
|
||||||
|
< 2):
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
self.assertFalse(self.cluster_coord.done())
|
self.assertFalse(self.cluster_coord.done())
|
||||||
self._restart(downtime_secs=2, job="worker")
|
self._restart(downtime_secs=2, job="worker")
|
||||||
@ -356,7 +357,8 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
|
|||||||
|
|
||||||
# Model does infinite training step, so at this moment, we expect to have 2
|
# Model does infinite training step, so at this moment, we expect to have 2
|
||||||
# infinite closures inflight, and 8 closures in the queue.
|
# infinite closures inflight, and 8 closures in the queue.
|
||||||
while self.cluster_coord.cluster._closure_queue._inflight_closure_count < 2:
|
while (self.cluster_coord._cluster._closure_queue._inflight_closure_count
|
||||||
|
< 2):
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
self.assertFalse(self.cluster_coord.done())
|
self.assertFalse(self.cluster_coord.done())
|
||||||
self._cluster.kill_task("worker", 0)
|
self._cluster.kill_task("worker", 0)
|
||||||
@ -380,7 +382,8 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
|
|||||||
|
|
||||||
# Model does infinite training step, so at this moment, we expect to have 2
|
# Model does infinite training step, so at this moment, we expect to have 2
|
||||||
# infinite closures inflight, and 8 closures in the queue.
|
# infinite closures inflight, and 8 closures in the queue.
|
||||||
while self.cluster_coord.cluster._closure_queue._inflight_closure_count < 2:
|
while (self.cluster_coord._cluster._closure_queue._inflight_closure_count
|
||||||
|
< 2):
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
self.assertFalse(self.cluster_coord.done())
|
self.assertFalse(self.cluster_coord.done())
|
||||||
self._cluster.kill_task("worker", 0)
|
self._cluster.kill_task("worker", 0)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user