PSv2: Privatize ClusterCoordinator.cluster attribute since Cluster is not meant for public use.
PiperOrigin-RevId: 348168836
Change-Id: I9abbd50de59861f02139d7472b397b937e1a4e85
parent b5c8d770f5
commit e8f1a46dc1
Files changed under tensorflow/python/distribute/coordinator
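For orientation, here is a minimal, hedged usage sketch (not part of the commit) of the public `ClusterCoordinator` surface that this rename leaves intact: only the internal `cluster` attribute becomes `_cluster`, so code that sticks to `schedule`, `join`, and `done` is unaffected. The `TFConfigClusterResolver` setup and the trivial step function are assumptions for illustration only.

import tensorflow as tf

def run_trivial_steps(num_steps=10):
  """Schedules a few trivial closures and waits for them (illustrative sketch)."""
  # Assumes TF_CONFIG describes a parameter-server cluster; any resolver works.
  resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
  strategy = tf.distribute.experimental.ParameterServerStrategy(resolver)
  coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)

  @tf.function
  def step_fn():
    return tf.constant(1.0)  # a real step would read data and update variables

  for _ in range(num_steps):
    coordinator.schedule(step_fn)  # dispatched through the now-private `_cluster`
  coordinator.join()               # blocks until every scheduled closure finishes
  return coordinator.done()        # True once all closures ran without error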
@@ -994,7 +994,7 @@ class ClusterCoordinator(object):
           "currently.")
     self._strategy = strategy
     self._strategy.extended._used_with_coordinator = True
-    self.cluster = Cluster(strategy)
+    self._cluster = Cluster(strategy)
 
   @property
   def strategy(self):
@@ -1067,7 +1067,7 @@ class ClusterCoordinator(object):
     # Slot variables are usually created during function tracing time; thus
     # `schedule` needs to be called within the `strategy.scope()`.
     with self.strategy.scope():
-      return self.cluster.schedule(fn, args=args, kwargs=kwargs)
+      return self._cluster.schedule(fn, args=args, kwargs=kwargs)
 
   def join(self):
     """Blocks until all the scheduled functions have finished execution.
@@ -1088,7 +1088,7 @@ class ClusterCoordinator(object):
       previously scheduled function since the last time an error was thrown or
       since the beginning of the program.
     """
-    self.cluster.join()
+    self._cluster.join()
 
   def done(self):
     """Returns whether all the scheduled functions have finished execution.
@@ -1106,7 +1106,7 @@ class ClusterCoordinator(object):
       previously scheduled function since the last time an error was thrown or
       since the beginning of the program.
     """
-    return self.cluster.done()
+    return self._cluster.done()
 
   def create_per_worker_dataset(self, dataset_fn):
     """Create dataset on workers by calling `dataset_fn` on worker devices.
@@ -1168,7 +1168,7 @@ class ClusterCoordinator(object):
       iterators (that are on the workers).
     """
     input_workers = input_lib.InputWorkers([
-        (w.device_name, [w.device_name]) for w in self.cluster.workers
+        (w.device_name, [w.device_name]) for w in self._cluster.workers
     ])
 
     return _PerWorkerDistributedDataset(dataset_fn, input_workers, self)
@@ -1191,7 +1191,7 @@ class ClusterCoordinator(object):
       objects.
     """
     results = []
-    for w in self.cluster.workers:
+    for w in self._cluster.workers:
       results.append(w._create_resource(fn, args=args, kwargs=kwargs))  # pylint: disable=protected-access
     return PerWorkerValues(tuple(results))
 
@@ -162,7 +162,7 @@ class ClusterCoordinatorMprTest(test.TestCase):
       if test_join:
         ps_coordinator.join()
       if test_schedule:
-        while ps_coordinator.cluster._closure_queue._error is None:
+        while ps_coordinator._cluster._closure_queue._error is None:
           time.sleep(1)
         ps_coordinator.schedule(worker_fn)
     except errors.UnavailableError:
@@ -575,7 +575,7 @@ class ClusterCoordinatorTest(TestCaseWithErrorReportingThread):
 
   def testDatasetsShuffledDifferently(self):
     # This test requires at least two workers in the cluster.
-    self.assertGreaterEqual(len(self.coordinator.cluster.workers), 2)
+    self.assertGreaterEqual(len(self.coordinator._cluster.workers), 2)
 
     random_seed.set_random_seed(None)
 
@@ -587,12 +587,12 @@ class ClusterCoordinatorTest(TestCaseWithErrorReportingThread):
 
     # Get elements from the first two iterators.
     iterator_1 = distributed_iterator._values[0]
-    iterator_1._rebuild_on(self.coordinator.cluster.workers[0])
+    iterator_1._rebuild_on(self.coordinator._cluster.workers[0])
     iterator_1 = iterator_1.fetch()
     elements_in_iterator_1 = [e.numpy() for e in iterator_1]
 
     iterator_2 = distributed_iterator._values[1]
-    iterator_2._rebuild_on(self.coordinator.cluster.workers[1])
+    iterator_2._rebuild_on(self.coordinator._cluster.workers[1])
     iterator_2 = iterator_2.fetch()
     elements_in_iterator_2 = [e.numpy() for e in iterator_2]
 
@@ -175,7 +175,8 @@ class FaultToleranceTest(test.TestCase):  # pylint: disable=missing-docstring
     model.schedule_training_functions(4)
     # Model does infinite training step, so at this moment, we expect to have 2
     # infinite closures inflight, and 2 closures in the queue.
-    while self.cluster_coord.cluster._closure_queue._inflight_closure_count < 2:
+    while (self.cluster_coord._cluster._closure_queue._inflight_closure_count
+           < 2):
       time.sleep(0.1)
     self.assertFalse(self.cluster_coord.done())
     self._restart(downtime_secs=2, job="worker")
@@ -356,7 +357,8 @@ class FaultToleranceTest(test.TestCase):  # pylint: disable=missing-docstring
 
     # Model does infinite training step, so at this moment, we expect to have 2
     # infinite closures inflight, and 8 closures in the queue.
-    while self.cluster_coord.cluster._closure_queue._inflight_closure_count < 2:
+    while (self.cluster_coord._cluster._closure_queue._inflight_closure_count
+           < 2):
       time.sleep(0.1)
     self.assertFalse(self.cluster_coord.done())
     self._cluster.kill_task("worker", 0)
@@ -380,7 +382,8 @@ class FaultToleranceTest(test.TestCase):  # pylint: disable=missing-docstring
 
     # Model does infinite training step, so at this moment, we expect to have 2
     # infinite closures inflight, and 8 closures in the queue.
-    while self.cluster_coord.cluster._closure_queue._inflight_closure_count < 2:
+    while (self.cluster_coord._cluster._closure_queue._inflight_closure_count
+           < 2):
       time.sleep(0.1)
     self.assertFalse(self.cluster_coord.done())
     self._cluster.kill_task("worker", 0)
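A side note on the fault-tolerance hunks: the longer `_cluster` name pushes the `while` condition past the line-length limit, hence the parenthesized wrap. A hedged alternative (an assumption, not something this commit does) is to hoist the private queue into a small helper so the polling condition stays on one line and gains a deadline:

import time

def wait_for_inflight_closures(cluster_coord, expected=2, timeout_secs=60):
  """Polls the coordinator's private closure queue until `expected` closures are in flight."""
  queue = cluster_coord._cluster._closure_queue  # pylint: disable=protected-access
  deadline = time.time() + timeout_secs  # the deadline is an added safeguard, not in the original tests
  while queue._inflight_closure_count < expected:  # pylint: disable=protected-access
    if time.time() > deadline:
      raise AssertionError("Expected %d in-flight closures before the deadline." % expected)
    time.sleep(0.1)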