Adds a TFCONFIGClusterResolver for future interoperability with Distribute Coordinator
PiperOrigin-RevId: 218713679
This commit is contained in:
parent
ea8361998d
commit
ac2a686c0e
@ -32,6 +32,7 @@ py_library(
|
||||
":gce_cluster_resolver_py",
|
||||
":kubernetes_cluster_resolver_py",
|
||||
":slurm_cluster_resolver_py",
|
||||
":tfconfig_cluster_resolver_py",
|
||||
":tpu_cluster_resolver_py",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
@ -56,6 +57,16 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tfconfig_cluster_resolver_py",
|
||||
srcs = ["python/training/tfconfig_cluster_resolver.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":base_cluster_resolver_py",
|
||||
"//tensorflow/python:training",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tpu_cluster_resolver_py",
|
||||
srcs = ["python/training/tpu_cluster_resolver.py"],
|
||||
@ -116,6 +127,22 @@ tf_py_test(
|
||||
main = "python/training/gce_cluster_resolver_test.py",
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "tfconfig_cluster_resolver_py_test",
|
||||
size = "small",
|
||||
srcs = ["python/training/tfconfig_cluster_resolver_test.py"],
|
||||
additional_deps = [
|
||||
":tfconfig_cluster_resolver_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:training",
|
||||
],
|
||||
grpc_enabled = True,
|
||||
main = "python/training/tfconfig_cluster_resolver_test.py",
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "tpu_cluster_resolver_py_test",
|
||||
size = "small",
|
||||
|
@ -24,4 +24,5 @@ from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import
|
||||
from tensorflow.contrib.cluster_resolver.python.training.gce_cluster_resolver import GceClusterResolver
|
||||
from tensorflow.contrib.cluster_resolver.python.training.kubernetes_cluster_resolver import KubernetesClusterResolver
|
||||
from tensorflow.contrib.cluster_resolver.python.training.slurm_cluster_resolver import SlurmClusterResolver
|
||||
from tensorflow.contrib.cluster_resolver.python.training.tfconfig_cluster_resolver import TFConfigClusterResolver
|
||||
from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver
|
||||
|
@ -0,0 +1,92 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Implementation of Cluster Resolvers for TF_CONFIG Environment Variables."""
|
||||
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver
|
||||
from tensorflow.python.training.server_lib import ClusterSpec
|
||||
|
||||
_TF_CONFIG_ENV = 'TF_CONFIG'
|
||||
_SESSION_MASTER_KEY = 'session_master'
|
||||
|
||||
|
||||
class TFConfigClusterResolver(ClusterResolver):
|
||||
"""Implementation of a ClusterResolver which reads the TF_CONFIG EnvVar."""
|
||||
|
||||
def _load_tf_config(self):
|
||||
return json.loads(os.environ.get(_TF_CONFIG_ENV, '{}'))
|
||||
|
||||
def cluster_spec(self):
|
||||
"""Returns a ClusterSpec based on the TF_CONFIG environment variable.
|
||||
|
||||
Returns:
|
||||
A ClusterSpec with information from the TF_CONFIG environment variable.
|
||||
"""
|
||||
tf_config = self._load_tf_config()
|
||||
if 'cluster' not in tf_config:
|
||||
return ClusterSpec({})
|
||||
return ClusterSpec(tf_config['cluster'])
|
||||
|
||||
def master(self, task_type=None, task_index=0):
|
||||
"""Returns the master address to use when creating a TensorFlow session.
|
||||
|
||||
Args:
|
||||
task_type: (String, optional) Overrides and sets the task_type of the
|
||||
master.
|
||||
task_index: (Integer, optional) Overrides and sets the task id of the
|
||||
master.
|
||||
|
||||
Returns:
|
||||
The address of the master.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the task_type or task_id is not specified and the
|
||||
`TF_CONFIG` environment variable does not contain a task section.
|
||||
"""
|
||||
|
||||
# If `session_master` is set, just use that.
|
||||
tf_config = self._load_tf_config()
|
||||
if _SESSION_MASTER_KEY in tf_config:
|
||||
return tf_config[_SESSION_MASTER_KEY]
|
||||
|
||||
if 'rpc_layer' in tf_config:
|
||||
rpclayer = '%s://' % tf_config['rpc_layer']
|
||||
else:
|
||||
rpclayer = ''
|
||||
|
||||
# Return an empty string if we are the only job in the ClusterSpec.
|
||||
cluster_spec = self.cluster_spec()
|
||||
if (not cluster_spec.jobs or
|
||||
(len(cluster_spec.jobs) == 1 and
|
||||
len(cluster_spec.job_tasks(cluster_spec.jobs[0])) == 1)):
|
||||
return ''
|
||||
|
||||
# We try to auto-detect the task type and id, but uses the user-supplied one
|
||||
# where available
|
||||
if not task_type:
|
||||
if 'task' not in tf_config:
|
||||
raise RuntimeError('You must either specify a `task_type`, or your '
|
||||
'TF_CONFIG must contain a `task` section.')
|
||||
task_type = tf_config['task']['type']
|
||||
task_index = tf_config['task']['index']
|
||||
|
||||
return rpclayer + cluster_spec.task_address(task_type, task_index)
|
@ -0,0 +1,158 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for TFCONFIGClusterResolver."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from tensorflow.contrib.cluster_resolver.python.training.tfconfig_cluster_resolver import TFConfigClusterResolver
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import server_lib
|
||||
|
||||
|
||||
class TFConfigClusterResolverTest(test.TestCase):
|
||||
|
||||
def _verifyClusterSpecEquality(self, cluster_spec, expected_proto):
|
||||
self.assertProtoEquals(expected_proto, cluster_spec.as_cluster_def())
|
||||
self.assertProtoEquals(
|
||||
expected_proto, server_lib.ClusterSpec(cluster_spec).as_cluster_def())
|
||||
self.assertProtoEquals(
|
||||
expected_proto,
|
||||
server_lib.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def())
|
||||
self.assertProtoEquals(
|
||||
expected_proto,
|
||||
server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def())
|
||||
|
||||
def testNormalClusterSpecRead(self):
|
||||
os.environ['TF_CONFIG'] = """
|
||||
{
|
||||
"cluster": {
|
||||
"ps": ["ps0:2222", "ps1:2222"],
|
||||
"worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
|
||||
},
|
||||
"task": {
|
||||
"type": "ps",
|
||||
"index": 0
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
cluster_resolver = TFConfigClusterResolver()
|
||||
expected_proto = """
|
||||
job { name: 'ps' tasks { key: 0 value: 'ps0:2222' }
|
||||
tasks { key: 1 value: 'ps1:2222' } }
|
||||
job { name: 'worker' tasks { key: 0 value: 'worker0:2222' }
|
||||
tasks { key: 1 value: 'worker1:2222' }
|
||||
tasks { key: 2 value: 'worker2:2222' } }
|
||||
"""
|
||||
actual_cluster_spec = cluster_resolver.cluster_spec()
|
||||
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
|
||||
|
||||
def testAutomaticMasterRead(self):
|
||||
os.environ['TF_CONFIG'] = """
|
||||
{
|
||||
"cluster": {
|
||||
"ps": ["ps0:2222", "ps1:2222"],
|
||||
"worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
|
||||
},
|
||||
"task": {
|
||||
"type": "ps",
|
||||
"index": 0
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
cluster_resolver = TFConfigClusterResolver()
|
||||
self.assertEqual('ps0:2222', cluster_resolver.master())
|
||||
|
||||
def testSpecifiedTaskTypeAndIndexMasterRead(self):
|
||||
os.environ['TF_CONFIG'] = """
|
||||
{
|
||||
"cluster": {
|
||||
"ps": ["ps0:2222", "ps1:2222"],
|
||||
"worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
|
||||
},
|
||||
"task": {
|
||||
"type": "ps",
|
||||
"index": 0
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
cluster_resolver = TFConfigClusterResolver()
|
||||
self.assertEqual('worker1:2222', cluster_resolver.master('worker', 1))
|
||||
|
||||
def testSessionMasterRead(self):
|
||||
os.environ['TF_CONFIG'] = """
|
||||
{
|
||||
"cluster": {
|
||||
"ps": ["ps0:2222", "ps1:2222"],
|
||||
"worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
|
||||
},
|
||||
"session_master": "sessionmaster:2222",
|
||||
"task": {
|
||||
"type": "ps",
|
||||
"index": 0
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
cluster_resolver = TFConfigClusterResolver()
|
||||
self.assertEqual('sessionmaster:2222', cluster_resolver.master())
|
||||
|
||||
def testRpcLayerRead(self):
|
||||
os.environ['TF_CONFIG'] = """
|
||||
{
|
||||
"cluster": {
|
||||
"ps": ["ps0:2222", "ps1:2222"],
|
||||
"worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
|
||||
},
|
||||
"rpc_layer": "grpc",
|
||||
"task": {
|
||||
"type": "ps",
|
||||
"index": 0
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
cluster_resolver = TFConfigClusterResolver()
|
||||
self.assertEqual('grpc://ps0:2222', cluster_resolver.master())
|
||||
|
||||
def testZeroItemsInClusterSpecMasterRead(self):
|
||||
os.environ['TF_CONFIG'] = """
|
||||
{}
|
||||
"""
|
||||
|
||||
cluster_resolver = TFConfigClusterResolver()
|
||||
self.assertEqual('', cluster_resolver.master())
|
||||
|
||||
def testOneItemInClusterSpecMasterRead(self):
|
||||
os.environ['TF_CONFIG'] = """
|
||||
{
|
||||
"cluster": {
|
||||
"worker": ["worker0:2222"]
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
cluster_resolver = TFConfigClusterResolver()
|
||||
self.assertEqual('', cluster_resolver.master())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
Loading…
Reference in New Issue
Block a user