From 4835f6ec4fe31ddf84656231ff9ce31a22578eab Mon Sep 17 00:00:00 2001
From: Rick Chao
Date: Wed, 10 Feb 2021 17:44:37 -0800
Subject: [PATCH] PSv2/cfit: tf.distribute changes to accompany compile-fit
 support.

1) Single instance of ClusterCoordinator given a Strategy object
2) Circular references of ClusterCoordinator and ParameterServerStrategy
3) Attribute of a Strategy indicating if it is supposed to be used with a
   ClusterCoordinator

PiperOrigin-RevId: 356868615
Change-Id: If19600c0101f40a9e840fe71abb848f386e32735
---
 .../coordinator/cluster_coordinator.py             | 18 +++++++++++++-----
 .../coordinator/fault_tolerance_test.py            |  1 +
 tensorflow/python/distribute/distribute_lib.py     |  3 +++
 .../distribute/parameter_server_strategy_v2.py     |  2 ++
 4 files changed, 19 insertions(+), 5 deletions(-)

diff --git a/tensorflow/python/distribute/coordinator/cluster_coordinator.py b/tensorflow/python/distribute/coordinator/cluster_coordinator.py
index 33641add117..07adc633816 100644
--- a/tensorflow/python/distribute/coordinator/cluster_coordinator.py
+++ b/tensorflow/python/distribute/coordinator/cluster_coordinator.py
@@ -895,7 +895,6 @@ class Cluster(object):
     self.workers = [
         Worker(i, w, self) for i, w in enumerate(worker_device_strings)
     ]
-    self._strategy = strategy
 
   def stop(self):
     """Stop worker, worker preemption threads, and the closure queue."""
@@ -933,13 +932,11 @@ class Cluster(object):
     Returns:
       A `RemoteValue` object.
     """
-    self._strategy.extended._being_scheduled = True  # pylint: disable=protected-access
     closure = Closure(
         function,
         self._closure_queue._cancellation_mgr,  # pylint: disable=protected-access
         args=args,
         kwargs=kwargs)
-    self._strategy.extended._being_scheduled = False  # pylint: disable=protected-access
     self._closure_queue.put(closure)
     return closure.output_remote_value
 
@@ -1008,6 +1005,14 @@ class ClusterCoordinator(object):
   are subject to changes.
   """
 
+  def __new__(cls, strategy):
+    # `ClusterCoordinator` is kept as a single instance to a given `Strategy`.
+    # TODO(rchao): Needs a lock for thread-safety
+    if strategy._cluster_coordinator is None:
+      strategy._cluster_coordinator = super(
+          ClusterCoordinator, cls).__new__(cls)
+    return strategy._cluster_coordinator
+
   def __init__(self, strategy):
     """Initialization of a `ClusterCoordinator` instance.
 
@@ -1026,7 +1031,7 @@ class ClusterCoordinator(object):
           "`tf.distribute.experimental.coordinator.ClusterCoordinator` "
           "currently.")
     self._strategy = strategy
-    self._strategy.extended._used_with_coordinator = True
+    self.strategy.extended._used_with_coordinator = True
     self._cluster = Cluster(strategy)
 
   def __del__(self):
@@ -1103,7 +1108,10 @@ class ClusterCoordinator(object):
     # Slot variables are usually created during function tracing time; thus
     # `schedule` needs to be called within the `strategy.scope()`.
     with self.strategy.scope():
-      return self._cluster.schedule(fn, args=args, kwargs=kwargs)
+      self.strategy.extended._being_scheduled = True  # pylint: disable=protected-access
+      remote_value = self._cluster.schedule(fn, args=args, kwargs=kwargs)
+      self.strategy.extended._being_scheduled = False  # pylint: disable=protected-access
+      return remote_value
 
   def join(self):
     """Blocks until all the scheduled functions have finished execution.
diff --git a/tensorflow/python/distribute/coordinator/fault_tolerance_test.py b/tensorflow/python/distribute/coordinator/fault_tolerance_test.py
index e96970418ed..b4d0f09c281 100644
--- a/tensorflow/python/distribute/coordinator/fault_tolerance_test.py
+++ b/tensorflow/python/distribute/coordinator/fault_tolerance_test.py
@@ -183,6 +183,7 @@ class BaseFaultToleranceTest(object):  # pylint: disable=missing-docstring
 
     # Wait for threads to close.
     self.cluster_coord = None
+    self.strategy = None
     gc.collect()
     time.sleep(1)
 
diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py
index 68b9ee1fc79..3be9900d22a 100644
--- a/tensorflow/python/distribute/distribute_lib.py
+++ b/tensorflow/python/distribute/distribute_lib.py
@@ -798,6 +798,9 @@ class StrategyBase(object):
     self._mean_reduce_helper_fns = {}
     self._reduce_sum_fns = {}
 
+    # Whether this strategy is designed to work with `ClusterCoordinator`.
+    self._should_use_with_coordinator = False
+
   @property
   def extended(self):
     """`tf.distribute.StrategyExtended` with additional methods."""
diff --git a/tensorflow/python/distribute/parameter_server_strategy_v2.py b/tensorflow/python/distribute/parameter_server_strategy_v2.py
index c3e1d3ff8b1..3ca48fe0933 100644
--- a/tensorflow/python/distribute/parameter_server_strategy_v2.py
+++ b/tensorflow/python/distribute/parameter_server_strategy_v2.py
@@ -435,6 +435,7 @@ class ParameterServerStrategyV2(distribute_lib.Strategy):
     self._extended = ParameterServerStrategyV2Extended(self, cluster_resolver,
                                                        variable_partitioner)
     self._verify_args_and_config(cluster_resolver)
+    self._cluster_coordinator = None
     logging.info(
         "`tf.distribute.experimental.ParameterServerStrategy` is initialized "
         "with cluster_spec: %s", cluster_resolver.cluster_spec())
@@ -444,6 +445,7 @@ class ParameterServerStrategyV2(distribute_lib.Strategy):
     super(ParameterServerStrategyV2, self).__init__(self._extended)
     distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
         "ParameterServerStrategy")
+    self._should_use_with_coordinator = True
 
   def _connect_to_cluster(self, coordinator_name):
     if coordinator_name in ["worker", "ps"]:
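
For context, the single-instance behavior introduced in `ClusterCoordinator.__new__` above can be illustrated with a small standalone sketch. The `FakeStrategy` and `FakeCoordinator` names below are hypothetical stand-ins, not TensorFlow APIs; the sketch only mirrors the caching and circular-reference pattern the patch relies on, and which fault_tolerance_test.py now breaks by clearing both `self.cluster_coord` and `self.strategy` before forcing garbage collection.

class FakeStrategy:
  """Hypothetical stand-in for a strategy that can cache a coordinator."""

  def __init__(self):
    # Mirrors the `self._cluster_coordinator = None` initialization added to
    # ParameterServerStrategyV2.__init__ in this patch.
    self._cluster_coordinator = None


class FakeCoordinator:
  """Hypothetical stand-in for a coordinator kept single per strategy."""

  def __new__(cls, strategy):
    # First construction caches the instance on the strategy; any later
    # construction with the same strategy returns the cached object.
    # As noted in the patch's TODO, this is not thread-safe without a lock.
    if strategy._cluster_coordinator is None:
      strategy._cluster_coordinator = super().__new__(cls)
    return strategy._cluster_coordinator

  def __init__(self, strategy):
    # Circular reference: the strategy holds the coordinator and the
    # coordinator holds the strategy.
    self.strategy = strategy


if __name__ == "__main__":
  strategy = FakeStrategy()
  coord_a = FakeCoordinator(strategy)
  coord_b = FakeCoordinator(strategy)
  assert coord_a is coord_b  # same strategy -> same coordinator instance

  # Dropping every external reference (analogous to setting both
  # `self.cluster_coord` and `self.strategy` to None in the test) leaves
  # only the strategy<->coordinator cycle, which gc.collect() can reclaim.
  coord_a = coord_b = strategy = None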