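Distribute Coordinator previously assumed that TF_CONFIG was the only way to configure a strategy. A cluster resolver can now be passed as an argument when instantiating the strategy, as an alternative to TF_CONFIG. If a cluster spec is provided explicitly or through a user-set TF_CONFIG, it is used; otherwise the coordinator falls back to the strategy's cluster resolver.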
PiperOrigin-RevId: 301207236 Change-Id: Ibe72c91876d81d588a0dc95041b7543571f74efb
This commit is contained in:
parent
a749d9dfa8
commit
32aeb9957e
tensorflow/python/distribute
@ -166,13 +166,13 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
container_strategy,
|
||||
communication,
|
||||
cluster_resolver):
|
||||
cluster_resolver = cluster_resolver or TFConfigClusterResolver()
|
||||
self._cluster_resolver = cluster_resolver or TFConfigClusterResolver()
|
||||
distribute_lib.StrategyExtendedV1.__init__(self, container_strategy)
|
||||
assert isinstance(
|
||||
communication,
|
||||
cross_device_ops_lib.CollectiveCommunication)
|
||||
self._communication = communication
|
||||
self._initialize_strategy(cluster_resolver)
|
||||
self._initialize_strategy(self._cluster_resolver)
|
||||
assert isinstance(self._get_cross_device_ops(),
|
||||
cross_device_ops_lib.CollectiveAllReduce)
|
||||
|
||||
|
@ -750,6 +750,9 @@ def run_distribute_coordinator(worker_fn,
|
||||
otherwise.
|
||||
"""
|
||||
tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
|
||||
rpc_layer = tf_config.get("rpc_layer", rpc_layer)
|
||||
environment = tf_config.get("environment", None)
|
||||
|
||||
if not cluster_spec:
|
||||
cluster_spec = tf_config.get("cluster", {})
|
||||
task_env = tf_config.get("task", {})
|
||||
@ -758,11 +761,15 @@ def run_distribute_coordinator(worker_fn,
|
||||
task_id = int(task_env.get("index", task_id))
|
||||
|
||||
if cluster_spec:
|
||||
cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
|
||||
# TODO(yuefengz): validate cluster_spec.
|
||||
|
||||
rpc_layer = tf_config.get("rpc_layer", rpc_layer)
|
||||
environment = tf_config.get("environment", None)
|
||||
cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
|
||||
elif hasattr(strategy.extended, "_cluster_resolver"):
|
||||
cluster_resolver = strategy.extended._cluster_resolver # pylint: disable=protected-access
|
||||
task_type = cluster_resolver.task_type
|
||||
task_id = cluster_resolver.task_id
|
||||
rpc_layer = cluster_resolver.rpc_layer or rpc_layer
|
||||
environment = cluster_resolver.environment
|
||||
cluster_spec = cluster_resolver.cluster_spec()
|
||||
|
||||
# Setting the session config is necessary for some strategies such as
|
||||
# CollectiveAllReduceStrategy.
|
||||
|
Loading…
Reference in New Issue
Block a user