PSv2/cfit: tf.distribute changes to accompany compile-fit support.
1) Keep a single instance of ClusterCoordinator for a given Strategy object.
2) Introduce circular references between ClusterCoordinator and ParameterServerStrategy.
3) Add an attribute on Strategy indicating whether it is supposed to be used with a ClusterCoordinator.
PiperOrigin-RevId: 356868615
Change-Id: If19600c0101f40a9e840fe71abb848f386e32735
This commit is contained in:
parent
a5ff8a7043
commit
4835f6ec4f
@ -895,7 +895,6 @@ class Cluster(object):
|
||||
self.workers = [
|
||||
Worker(i, w, self) for i, w in enumerate(worker_device_strings)
|
||||
]
|
||||
self._strategy = strategy
|
||||
|
||||
def stop(self):
|
||||
"""Stop worker, worker preemption threads, and the closure queue."""
|
||||
@ -933,13 +932,11 @@ class Cluster(object):
|
||||
Returns:
|
||||
A `RemoteValue` object.
|
||||
"""
|
||||
self._strategy.extended._being_scheduled = True # pylint: disable=protected-access
|
||||
closure = Closure(
|
||||
function,
|
||||
self._closure_queue._cancellation_mgr, # pylint: disable=protected-access
|
||||
args=args,
|
||||
kwargs=kwargs)
|
||||
self._strategy.extended._being_scheduled = False # pylint: disable=protected-access
|
||||
self._closure_queue.put(closure)
|
||||
return closure.output_remote_value
|
||||
|
||||
@ -1008,6 +1005,14 @@ class ClusterCoordinator(object):
|
||||
are subject to changes.
|
||||
"""
|
||||
|
||||
def __new__(cls, strategy):
|
||||
# `ClusterCoordinator` is kept as a single instance to a given `Strategy`.
|
||||
# TODO(rchao): Needs a lock for thread-safety
|
||||
if strategy._cluster_coordinator is None:
|
||||
strategy._cluster_coordinator = super(
|
||||
ClusterCoordinator, cls).__new__(cls)
|
||||
return strategy._cluster_coordinator
|
||||
|
||||
def __init__(self, strategy):
|
||||
"""Initialization of a `ClusterCoordinator` instance.
|
||||
|
||||
@ -1026,7 +1031,7 @@ class ClusterCoordinator(object):
|
||||
"`tf.distribute.experimental.coordinator.ClusterCoordinator` "
|
||||
"currently.")
|
||||
self._strategy = strategy
|
||||
self._strategy.extended._used_with_coordinator = True
|
||||
self.strategy.extended._used_with_coordinator = True
|
||||
self._cluster = Cluster(strategy)
|
||||
|
||||
def __del__(self):
|
||||
@ -1103,7 +1108,10 @@ class ClusterCoordinator(object):
|
||||
# Slot variables are usually created during function tracing time; thus
|
||||
# `schedule` needs to be called within the `strategy.scope()`.
|
||||
with self.strategy.scope():
|
||||
return self._cluster.schedule(fn, args=args, kwargs=kwargs)
|
||||
self.strategy.extended._being_scheduled = True # pylint: disable=protected-access
|
||||
remote_value = self._cluster.schedule(fn, args=args, kwargs=kwargs)
|
||||
self.strategy.extended._being_scheduled = False # pylint: disable=protected-access
|
||||
return remote_value
|
||||
|
||||
def join(self):
|
||||
"""Blocks until all the scheduled functions have finished execution.
|
||||
|
@ -183,6 +183,7 @@ class BaseFaultToleranceTest(object): # pylint: disable=missing-docstring
|
||||
|
||||
# Wait for threads to close.
|
||||
self.cluster_coord = None
|
||||
self.strategy = None
|
||||
gc.collect()
|
||||
time.sleep(1)
|
||||
|
||||
|
@ -798,6 +798,9 @@ class StrategyBase(object):
|
||||
self._mean_reduce_helper_fns = {}
|
||||
self._reduce_sum_fns = {}
|
||||
|
||||
# Whether this strategy is designed to work with `ClusterCoordinator`.
|
||||
self._should_use_with_coordinator = False
|
||||
|
||||
@property
|
||||
def extended(self):
|
||||
"""`tf.distribute.StrategyExtended` with additional methods."""
|
||||
|
@ -435,6 +435,7 @@ class ParameterServerStrategyV2(distribute_lib.Strategy):
|
||||
self._extended = ParameterServerStrategyV2Extended(self, cluster_resolver,
|
||||
variable_partitioner)
|
||||
self._verify_args_and_config(cluster_resolver)
|
||||
self._cluster_coordinator = None
|
||||
logging.info(
|
||||
"`tf.distribute.experimental.ParameterServerStrategy` is initialized "
|
||||
"with cluster_spec: %s", cluster_resolver.cluster_spec())
|
||||
@ -444,6 +445,7 @@ class ParameterServerStrategyV2(distribute_lib.Strategy):
|
||||
super(ParameterServerStrategyV2, self).__init__(self._extended)
|
||||
distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
|
||||
"ParameterServerStrategy")
|
||||
self._should_use_with_coordinator = True
|
||||
|
||||
def _connect_to_cluster(self, coordinator_name):
|
||||
if coordinator_name in ["worker", "ps"]:
|
||||
|
Loading…
Reference in New Issue
Block a user