Distribute Coordinator currently assumes TF_CONFIG is the only way to configure a strategy. We now allow a cluster resolver to be passed as an argument when instantiating the strategy, as an alternative to TF_CONFIG; if the user does set TF_CONFIG, it still takes precedence.
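
For illustration only (not part of this change), a minimal sketch of the resolver-based construction path. SimpleClusterResolver and tf.train.ClusterSpec are real APIs under these names; whether the public MultiWorkerMirroredStrategy constructor accepts cluster_resolver depends on the TF version, so treat that call as an assumption:

import tensorflow as tf

# Describe the cluster programmatically instead of serializing it into
# the TF_CONFIG environment variable.
cluster_spec = tf.train.ClusterSpec(
    {"worker": ["host1:2222", "host2:2222"]})
resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(
    cluster_spec, task_type="worker", task_id=0, rpc_layer="grpc")

# The resolver is handed to the strategy directly; inside the
# coordinator, TF_CONFIG still wins when it is set.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
    cluster_resolver=resolver)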

PiperOrigin-RevId: 301207236
Change-Id: Ibe72c91876d81d588a0dc95041b7543571f74efb
Anjali Sridhar 2020-03-16 11:57:16 -07:00 committed by TensorFlower Gardener
parent a749d9dfa8
commit 32aeb9957e
2 changed files with 13 additions and 6 deletions

tensorflow/python/distribute/collective_all_reduce_strategy.py

@@ -166,13 +166,13 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
   def __init__(self,
                container_strategy,
                communication,
                cluster_resolver):
-    cluster_resolver = cluster_resolver or TFConfigClusterResolver()
+    self._cluster_resolver = cluster_resolver or TFConfigClusterResolver()
     distribute_lib.StrategyExtendedV1.__init__(self, container_strategy)
     assert isinstance(
         communication,
         cross_device_ops_lib.CollectiveCommunication)
     self._communication = communication
-    self._initialize_strategy(cluster_resolver)
+    self._initialize_strategy(self._cluster_resolver)
     assert isinstance(self._get_cross_device_ops(),
                       cross_device_ops_lib.CollectiveAllReduce)
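
Storing the resolver on self._cluster_resolver is what lets run_distribute_coordinator (second file, below) read the cluster configuration back off the strategy. When no resolver is supplied, TFConfigClusterResolver preserves the old behavior of reading TF_CONFIG; a toy demonstration of that fallback, runnable on a single machine:

import json
import os

import tensorflow as tf

# Simulate a user-provided TF_CONFIG.
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {"worker": ["host1:2222", "host2:2222"]},
    "task": {"type": "worker", "index": 0},
})

# This is what the strategy falls back to when cluster_resolver is None.
resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
print(resolver.task_type, resolver.task_id)  # -> worker 0
print(resolver.cluster_spec().as_dict())     # -> {'worker': ['host1:2222', 'host2:2222']}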

tensorflow/python/distribute/distribute_coordinator.py

@@ -750,6 +750,9 @@ def run_distribute_coordinator(worker_fn,
       otherwise.
   """
   tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
+  rpc_layer = tf_config.get("rpc_layer", rpc_layer)
+  environment = tf_config.get("environment", None)
   if not cluster_spec:
     cluster_spec = tf_config.get("cluster", {})
     task_env = tf_config.get("task", {})
@@ -758,11 +761,15 @@ def run_distribute_coordinator(worker_fn,
       task_id = int(task_env.get("index", task_id))
 
   if cluster_spec:
-    cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
     # TODO(yuefengz): validate cluster_spec.
-
-  rpc_layer = tf_config.get("rpc_layer", rpc_layer)
-  environment = tf_config.get("environment", None)
+    cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
+  elif hasattr(strategy.extended, "_cluster_resolver"):
+    cluster_resolver = strategy.extended._cluster_resolver  # pylint: disable=protected-access
+    task_type = cluster_resolver.task_type
+    task_id = cluster_resolver.task_id
+    rpc_layer = cluster_resolver.rpc_layer or rpc_layer
+    environment = cluster_resolver.environment
+    cluster_spec = cluster_resolver.cluster_spec()
 
   # Setting the session config is necessary for some strategies such as
   # CollectiveAllReduceStrategy.
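
A hypothetical end-to-end use of the new path: with no TF_CONFIG in the environment, run_distribute_coordinator now falls back to the resolver stored on the strategy. distribute_coordinator is an internal module rather than public API, so the import path and the worker_fn contract below are assumptions tied to this revision:

import tensorflow as tf
from tensorflow.python.distribute import distribute_coordinator

resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(
    tf.train.ClusterSpec({"worker": ["host1:2222", "host2:2222"]}),
    task_type="worker", task_id=0, rpc_layer="grpc")
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
    cluster_resolver=resolver)

def worker_fn(strategy):
  # Per-worker training logic runs under the strategy's scope.
  with strategy.scope():
    pass

# task_type, task_id, rpc_layer, environment and cluster_spec are all
# pulled from strategy.extended._cluster_resolver by the elif branch
# added above.
distribute_coordinator.run_distribute_coordinator(
    worker_fn,
    strategy,
    mode=distribute_coordinator.CoordinatorMode.INDEPENDENT_WORKER)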