diff --git a/tensorflow/python/training/server_lib.py b/tensorflow/python/training/server_lib.py index 2091eca0b9c..29da67a30a5 100644 --- a/tensorflow/python/training/server_lib.py +++ b/tensorflow/python/training/server_lib.py @@ -307,6 +307,12 @@ class ClusterSpec(object): def __ne__(self, other): return self._cluster_spec != other + def __str__(self): + key_values = self.as_dict() + string_items = [ + repr(k) + ": " + repr(key_values[k]) for k in sorted(key_values)] + return "ClusterSpec({" + ", ".join(string_items) + "})" + def as_dict(self): """Returns a dictionary from job names to their tasks. diff --git a/tensorflow/python/training/server_lib_test.py b/tensorflow/python/training/server_lib_test.py index 26aac787ed4..063044f0d05 100644 --- a/tensorflow/python/training/server_lib_test.py +++ b/tensorflow/python/training/server_lib_test.py @@ -421,6 +421,17 @@ class ServerDefTest(test.TestCase): class ClusterSpecTest(test.TestCase): + def testStringConversion(self): + cluster_spec = server_lib.ClusterSpec({ + "ps": ["ps0:1111"], + "worker": ["worker0:3333", "worker1:4444"] + }) + + expected_str = ( + "ClusterSpec({'ps': ['ps0:1111'], 'worker': ['worker0:3333', " + "'worker1:4444']})") + self.assertEqual(expected_str, str(cluster_spec)) + def testProtoDictDefEquivalences(self): cluster_spec = server_lib.ClusterSpec({ "ps": ["ps0:2222", "ps1:2222"],