PSv2/cfit: tf.distribute changes to accompany compile-fit support.

1) A single `ClusterCoordinator` instance for a given `Strategy` object (see the sketch below this list).
2) Circular references between `ClusterCoordinator` and `ParameterServerStrategy`.
3) An attribute on `Strategy` indicating whether it is meant to be used with a `ClusterCoordinator`.
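
A minimal, self-contained sketch of the pattern behind items 1) and 3) follows. ToyStrategy and ToyCoordinator are hypothetical stand-ins for illustration only, not the real tf.distribute classes; the real ClusterCoordinator additionally manages workers and a closure queue, and the Keras compile-fit consumer of these hooks is not part of this change.

class ToyStrategy:

  def __init__(self):
    # Mirrors ParameterServerStrategyV2.__init__ after this change: the
    # strategy starts with no coordinator and declares that it is designed
    # to be used with one.
    self._cluster_coordinator = None
    self._should_use_with_coordinator = True


class ToyCoordinator:

  def __new__(cls, strategy):
    # Mirrors ClusterCoordinator.__new__: only one coordinator instance is
    # ever created for a given strategy object; later constructions return
    # the cached instance.
    if strategy._cluster_coordinator is None:
      strategy._cluster_coordinator = super().__new__(cls)
    return strategy._cluster_coordinator

  def __init__(self, strategy):
    self._strategy = strategy


if __name__ == "__main__":
  strategy = ToyStrategy()
  coord_a = ToyCoordinator(strategy)
  coord_b = ToyCoordinator(strategy)
  # Both constructions yield the same object, so framework code can create a
  # coordinator on demand without duplicating state.
  assert coord_a is coord_b
  # The flag lets framework code decide whether to route execution through a
  # coordinator at all.
  assert strategy._should_use_with_coordinator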

PiperOrigin-RevId: 356868615
Change-Id: If19600c0101f40a9e840fe71abb848f386e32735
Rick Chao 2021-02-10 17:44:37 -08:00 committed by TensorFlower Gardener
parent a5ff8a7043
commit 4835f6ec4f
4 changed files with 19 additions and 5 deletions


@@ -895,7 +895,6 @@ class Cluster(object):
     self.workers = [
         Worker(i, w, self) for i, w in enumerate(worker_device_strings)
     ]
-    self._strategy = strategy
 
   def stop(self):
     """Stop worker, worker preemption threads, and the closure queue."""
@@ -933,13 +932,11 @@ class Cluster(object):
     Returns:
       A `RemoteValue` object.
     """
-    self._strategy.extended._being_scheduled = True  # pylint: disable=protected-access
     closure = Closure(
         function,
         self._closure_queue._cancellation_mgr,  # pylint: disable=protected-access
         args=args,
         kwargs=kwargs)
-    self._strategy.extended._being_scheduled = False  # pylint: disable=protected-access
     self._closure_queue.put(closure)
     return closure.output_remote_value
@@ -1008,6 +1005,14 @@ class ClusterCoordinator(object):
   are subject to changes.
   """
 
+  def __new__(cls, strategy):
+    # `ClusterCoordinator` is kept as a single instance to a given `Strategy`.
+    # TODO(rchao): Needs a lock for thread-safety
+    if strategy._cluster_coordinator is None:
+      strategy._cluster_coordinator = super(
+          ClusterCoordinator, cls).__new__(cls)
+    return strategy._cluster_coordinator
+
   def __init__(self, strategy):
     """Initialization of a `ClusterCoordinator` instance.
@@ -1026,7 +1031,7 @@ class ClusterCoordinator(object):
           "`tf.distribute.experimental.coordinator.ClusterCoordinator` "
           "currently.")
     self._strategy = strategy
-    self._strategy.extended._used_with_coordinator = True
+    self.strategy.extended._used_with_coordinator = True
     self._cluster = Cluster(strategy)
 
   def __del__(self):
@@ -1103,7 +1108,10 @@ class ClusterCoordinator(object):
     # Slot variables are usually created during function tracing time; thus
     # `schedule` needs to be called within the `strategy.scope()`.
     with self.strategy.scope():
-      return self._cluster.schedule(fn, args=args, kwargs=kwargs)
+      self.strategy.extended._being_scheduled = True  # pylint: disable=protected-access
+      remote_value = self._cluster.schedule(fn, args=args, kwargs=kwargs)
+      self.strategy.extended._being_scheduled = False  # pylint: disable=protected-access
+      return remote_value
 
   def join(self):
     """Blocks until all the scheduled functions have finished execution.


@@ -183,6 +183,7 @@ class BaseFaultToleranceTest(object):  # pylint: disable=missing-docstring
     # Wait for threads to close.
     self.cluster_coord = None
     self.strategy = None
+    gc.collect()
     time.sleep(1)


@@ -798,6 +798,9 @@ class StrategyBase(object):
     self._mean_reduce_helper_fns = {}
     self._reduce_sum_fns = {}
 
+    # Whether this strategy is designed to work with `ClusterCoordinator`.
+    self._should_use_with_coordinator = False
+
   @property
   def extended(self):
     """`tf.distribute.StrategyExtended` with additional methods."""


@@ -435,6 +435,7 @@ class ParameterServerStrategyV2(distribute_lib.Strategy):
     self._extended = ParameterServerStrategyV2Extended(self, cluster_resolver,
                                                        variable_partitioner)
     self._verify_args_and_config(cluster_resolver)
+    self._cluster_coordinator = None
     logging.info(
         "`tf.distribute.experimental.ParameterServerStrategy` is initialized "
         "with cluster_spec: %s", cluster_resolver.cluster_spec())
@@ -444,6 +445,7 @@ class ParameterServerStrategyV2(distribute_lib.Strategy):
     super(ParameterServerStrategyV2, self).__init__(self._extended)
     distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
         "ParameterServerStrategy")
+    self._should_use_with_coordinator = True
 
   def _connect_to_cluster(self, coordinator_name):
     if coordinator_name in ["worker", "ps"]: