Adds a TFCONFIGClusterResolver for future interoperability with Distribute Coordinator

PiperOrigin-RevId: 218713679
This commit is contained in:
Frank Chen 2018-10-25 10:51:14 -07:00 committed by TensorFlower Gardener
parent ea8361998d
commit ac2a686c0e
4 changed files with 278 additions and 0 deletions

View File

@ -32,6 +32,7 @@ py_library(
":gce_cluster_resolver_py",
":kubernetes_cluster_resolver_py",
":slurm_cluster_resolver_py",
":tfconfig_cluster_resolver_py",
":tpu_cluster_resolver_py",
"//tensorflow/python:util",
],
@ -56,6 +57,16 @@ py_library(
],
)
py_library(
name = "tfconfig_cluster_resolver_py",
srcs = ["python/training/tfconfig_cluster_resolver.py"],
srcs_version = "PY2AND3",
deps = [
":base_cluster_resolver_py",
"//tensorflow/python:training",
],
)
py_library(
name = "tpu_cluster_resolver_py",
srcs = ["python/training/tpu_cluster_resolver.py"],
@ -116,6 +127,22 @@ tf_py_test(
main = "python/training/gce_cluster_resolver_test.py",
)
tf_py_test(
name = "tfconfig_cluster_resolver_py_test",
size = "small",
srcs = ["python/training/tfconfig_cluster_resolver_test.py"],
additional_deps = [
":tfconfig_cluster_resolver_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
],
grpc_enabled = True,
main = "python/training/tfconfig_cluster_resolver_test.py",
)
tf_py_test(
name = "tpu_cluster_resolver_py_test",
size = "small",

View File

@ -24,4 +24,5 @@ from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import
from tensorflow.contrib.cluster_resolver.python.training.gce_cluster_resolver import GceClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.kubernetes_cluster_resolver import KubernetesClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.slurm_cluster_resolver import SlurmClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.tfconfig_cluster_resolver import TFConfigClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver

View File

@ -0,0 +1,92 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Cluster Resolvers for TF_CONFIG Environment Variables."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver
from tensorflow.python.training.server_lib import ClusterSpec
_TF_CONFIG_ENV = 'TF_CONFIG'
_SESSION_MASTER_KEY = 'session_master'
class TFConfigClusterResolver(ClusterResolver):
"""Implementation of a ClusterResolver which reads the TF_CONFIG EnvVar."""
def _load_tf_config(self):
return json.loads(os.environ.get(_TF_CONFIG_ENV, '{}'))
def cluster_spec(self):
"""Returns a ClusterSpec based on the TF_CONFIG environment variable.
Returns:
A ClusterSpec with information from the TF_CONFIG environment variable.
"""
tf_config = self._load_tf_config()
if 'cluster' not in tf_config:
return ClusterSpec({})
return ClusterSpec(tf_config['cluster'])
def master(self, task_type=None, task_index=0):
"""Returns the master address to use when creating a TensorFlow session.
Args:
task_type: (String, optional) Overrides and sets the task_type of the
master.
task_index: (Integer, optional) Overrides and sets the task id of the
master.
Returns:
The address of the master.
Raises:
RuntimeError: If the task_type or task_id is not specified and the
`TF_CONFIG` environment variable does not contain a task section.
"""
# If `session_master` is set, just use that.
tf_config = self._load_tf_config()
if _SESSION_MASTER_KEY in tf_config:
return tf_config[_SESSION_MASTER_KEY]
if 'rpc_layer' in tf_config:
rpclayer = '%s://' % tf_config['rpc_layer']
else:
rpclayer = ''
# Return an empty string if we are the only job in the ClusterSpec.
cluster_spec = self.cluster_spec()
if (not cluster_spec.jobs or
(len(cluster_spec.jobs) == 1 and
len(cluster_spec.job_tasks(cluster_spec.jobs[0])) == 1)):
return ''
# We try to auto-detect the task type and id, but uses the user-supplied one
# where available
if not task_type:
if 'task' not in tf_config:
raise RuntimeError('You must either specify a `task_type`, or your '
'TF_CONFIG must contain a `task` section.')
task_type = tf_config['task']['type']
task_index = tf_config['task']['index']
return rpclayer + cluster_spec.task_address(task_type, task_index)

View File

@ -0,0 +1,158 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TFCONFIGClusterResolver."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from tensorflow.contrib.cluster_resolver.python.training.tfconfig_cluster_resolver import TFConfigClusterResolver
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
class TFConfigClusterResolverTest(test.TestCase):
def _verifyClusterSpecEquality(self, cluster_spec, expected_proto):
self.assertProtoEquals(expected_proto, cluster_spec.as_cluster_def())
self.assertProtoEquals(
expected_proto, server_lib.ClusterSpec(cluster_spec).as_cluster_def())
self.assertProtoEquals(
expected_proto,
server_lib.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def())
self.assertProtoEquals(
expected_proto,
server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def())
def testNormalClusterSpecRead(self):
os.environ['TF_CONFIG'] = """
{
"cluster": {
"ps": ["ps0:2222", "ps1:2222"],
"worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
},
"task": {
"type": "ps",
"index": 0
}
}
"""
cluster_resolver = TFConfigClusterResolver()
expected_proto = """
job { name: 'ps' tasks { key: 0 value: 'ps0:2222' }
tasks { key: 1 value: 'ps1:2222' } }
job { name: 'worker' tasks { key: 0 value: 'worker0:2222' }
tasks { key: 1 value: 'worker1:2222' }
tasks { key: 2 value: 'worker2:2222' } }
"""
actual_cluster_spec = cluster_resolver.cluster_spec()
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
def testAutomaticMasterRead(self):
os.environ['TF_CONFIG'] = """
{
"cluster": {
"ps": ["ps0:2222", "ps1:2222"],
"worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
},
"task": {
"type": "ps",
"index": 0
}
}
"""
cluster_resolver = TFConfigClusterResolver()
self.assertEqual('ps0:2222', cluster_resolver.master())
def testSpecifiedTaskTypeAndIndexMasterRead(self):
os.environ['TF_CONFIG'] = """
{
"cluster": {
"ps": ["ps0:2222", "ps1:2222"],
"worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
},
"task": {
"type": "ps",
"index": 0
}
}
"""
cluster_resolver = TFConfigClusterResolver()
self.assertEqual('worker1:2222', cluster_resolver.master('worker', 1))
def testSessionMasterRead(self):
os.environ['TF_CONFIG'] = """
{
"cluster": {
"ps": ["ps0:2222", "ps1:2222"],
"worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
},
"session_master": "sessionmaster:2222",
"task": {
"type": "ps",
"index": 0
}
}
"""
cluster_resolver = TFConfigClusterResolver()
self.assertEqual('sessionmaster:2222', cluster_resolver.master())
def testRpcLayerRead(self):
os.environ['TF_CONFIG'] = """
{
"cluster": {
"ps": ["ps0:2222", "ps1:2222"],
"worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
},
"rpc_layer": "grpc",
"task": {
"type": "ps",
"index": 0
}
}
"""
cluster_resolver = TFConfigClusterResolver()
self.assertEqual('grpc://ps0:2222', cluster_resolver.master())
def testZeroItemsInClusterSpecMasterRead(self):
os.environ['TF_CONFIG'] = """
{}
"""
cluster_resolver = TFConfigClusterResolver()
self.assertEqual('', cluster_resolver.master())
def testOneItemInClusterSpecMasterRead(self):
os.environ['TF_CONFIG'] = """
{
"cluster": {
"worker": ["worker0:2222"]
}
}
"""
cluster_resolver = TFConfigClusterResolver()
self.assertEqual('', cluster_resolver.master())
if __name__ == '__main__':
test.main()