Distribute Coordinator currently assumes that TF_CONFIG is the only way to configure a strategy. We now allow a cluster resolver to be passed as an argument when instantiating the strategy, as an alternative to TF_CONFIG; however, if the user has set TF_CONFIG, it should still be used instead.
PiperOrigin-RevId: 301207236 Change-Id: Ibe72c91876d81d588a0dc95041b7543571f74efb
This commit is contained in:
parent
a749d9dfa8
commit
32aeb9957e
@ -166,13 +166,13 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
|||||||
container_strategy,
|
container_strategy,
|
||||||
communication,
|
communication,
|
||||||
cluster_resolver):
|
cluster_resolver):
|
||||||
cluster_resolver = cluster_resolver or TFConfigClusterResolver()
|
self._cluster_resolver = cluster_resolver or TFConfigClusterResolver()
|
||||||
distribute_lib.StrategyExtendedV1.__init__(self, container_strategy)
|
distribute_lib.StrategyExtendedV1.__init__(self, container_strategy)
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
communication,
|
communication,
|
||||||
cross_device_ops_lib.CollectiveCommunication)
|
cross_device_ops_lib.CollectiveCommunication)
|
||||||
self._communication = communication
|
self._communication = communication
|
||||||
self._initialize_strategy(cluster_resolver)
|
self._initialize_strategy(self._cluster_resolver)
|
||||||
assert isinstance(self._get_cross_device_ops(),
|
assert isinstance(self._get_cross_device_ops(),
|
||||||
cross_device_ops_lib.CollectiveAllReduce)
|
cross_device_ops_lib.CollectiveAllReduce)
|
||||||
|
|
||||||
|
@ -750,6 +750,9 @@ def run_distribute_coordinator(worker_fn,
|
|||||||
otherwise.
|
otherwise.
|
||||||
"""
|
"""
|
||||||
tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
|
tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
|
||||||
|
rpc_layer = tf_config.get("rpc_layer", rpc_layer)
|
||||||
|
environment = tf_config.get("environment", None)
|
||||||
|
|
||||||
if not cluster_spec:
|
if not cluster_spec:
|
||||||
cluster_spec = tf_config.get("cluster", {})
|
cluster_spec = tf_config.get("cluster", {})
|
||||||
task_env = tf_config.get("task", {})
|
task_env = tf_config.get("task", {})
|
||||||
@ -758,11 +761,15 @@ def run_distribute_coordinator(worker_fn,
|
|||||||
task_id = int(task_env.get("index", task_id))
|
task_id = int(task_env.get("index", task_id))
|
||||||
|
|
||||||
if cluster_spec:
|
if cluster_spec:
|
||||||
cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
|
|
||||||
# TODO(yuefengz): validate cluster_spec.
|
# TODO(yuefengz): validate cluster_spec.
|
||||||
|
cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
|
||||||
rpc_layer = tf_config.get("rpc_layer", rpc_layer)
|
elif hasattr(strategy.extended, "_cluster_resolver"):
|
||||||
environment = tf_config.get("environment", None)
|
cluster_resolver = strategy.extended._cluster_resolver # pylint: disable=protected-access
|
||||||
|
task_type = cluster_resolver.task_type
|
||||||
|
task_id = cluster_resolver.task_id
|
||||||
|
rpc_layer = cluster_resolver.rpc_layer or rpc_layer
|
||||||
|
environment = cluster_resolver.environment
|
||||||
|
cluster_spec = cluster_resolver.cluster_spec()
|
||||||
|
|
||||||
# Setting the session config is necessary for some strategies such as
|
# Setting the session config is necessary for some strategies such as
|
||||||
# CollectiveAllReduceStrategy.
|
# CollectiveAllReduceStrategy.
|
||||||
|
Loading…
Reference in New Issue
Block a user