Split out most of k8s_tensorflow into a library and add a way to pass any

environment variables. Add benchmark_util library that would use environemnt
variable to decide on a storage location.
Change: 147890534
This commit is contained in:
A. Unique TensorFlower 2017-02-17 15:58:01 -08:00 committed by TensorFlower Gardener
parent a6421c4dda
commit b06281ba47
7 changed files with 631 additions and 219 deletions

View File

@ -0,0 +1,23 @@
# Python tools for running distributed benchmarks.
licenses(["notice"]) # Apache 2.0
py_library(
name = "benchmark_util_lib",
srcs = ["benchmark_util.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/core:protos_all_py",
"//tensorflow/python:platform",
],
)
py_test(
name = "benchmark_util_test",
srcs = ["benchmark_util_test.py"],
srcs_version = "PY2AND3",
deps = [
":benchmark_util_lib",
"//tensorflow/python:client_testlib",
],
)

View File

@ -0,0 +1,77 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides helper functions for distributed benchmarks running on Jenkins."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import calendar
from collections import namedtuple
import os
from google.protobuf import json_format
from tensorflow.core.util import test_log_pb2
from tensorflow.python.platform import gfile
_OUTPUT_FILE_ENV_VAR = 'TF_DIST_BENCHMARK_RESULTS_FILE'
_TEST_NAME_ENV_VAR = 'TF_DIST_BENCHMARK_NAME'
# Represents a single timing entry where
# - name is a string
# - timing is the latency to track (for e.g. mean time per iter)
# - iters is the number of iterations
TimingEntry = namedtuple(
'TimingEntry', ['name', 'timing', 'iters'])
def store_data_in_json(timing_entries, start_time, output_file=None):
"""Stores benchmark results in JSON format.
Args:
timing_entries: list of TimingEntry objects.
start_time: (datetime) start time of the test run.
output_file: if specified, writes benchmark results to output_file.
If not specified, writes results to the file specified by
BENCHMARK_RESULTS_FILE environment variable.
Raises:
ValueError: when neither output_file is passed in nor
BENCHMARK_RESULTS_FILE is set.
"""
test_result = test_log_pb2.TestResults(
start_time=calendar.timegm(start_time.timetuple()))
if not output_file:
if _OUTPUT_FILE_ENV_VAR not in os.environ:
raise ValueError('Could not determine location to store results at.')
output_file = os.environ[_OUTPUT_FILE_ENV_VAR]
with gfile.Open(output_file, 'wb') as jsonfile:
if _TEST_NAME_ENV_VAR in os.environ:
test_result.name = os.environ['POD_NAME_PREFIX']
else:
test_result.name = 'TestBenchmark'
for timing_entry in timing_entries:
test_result.entries.entry.add(
name=timing_entry.name,
iters=timing_entry.iters,
wall_time=timing_entry.timing
)
json_test_results = json_format.MessageToJson(test_result)
jsonfile.write(json_test_results)

View File

@ -0,0 +1,58 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.tools.dist_test.python.benchmark_util."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import datetime
import json
import os
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
from tensorflow.tools.dist_test.python import benchmark_util
class BenchmarkUtilTest(googletest.TestCase):
def testStoreDataWithNoEntries(self):
output_file = os.path.join(test.get_temp_dir(), 'test_output1.json')
timing_entries = []
benchmark_util.store_data_in_json(
timing_entries, datetime.date(2017, 1, 1), output_file)
json_output = json.loads(open(output_file, 'r').read())
self.assertEquals('TestBenchmark', json_output['name'])
self.assertEquals(u'1483228800', json_output['startTime'])
def testStoreDataWithEntries(self):
output_file = os.path.join(test.get_temp_dir(), 'test_output2.json')
timing_entries = [
benchmark_util.TimingEntry('test', 0.1, 1)]
benchmark_util.store_data_in_json(
timing_entries, datetime.date(2017, 1, 1), output_file)
json_output = json.loads(open(output_file, 'r').read())
self.assertEquals(1, len(json_output['entries']['entry']))
self.assertEquals('test', json_output['entries']['entry'][0]['name'])
self.assertEquals(0.1, json_output['entries']['entry'][0]['wallTime'])
self.assertEquals(u'1', json_output['entries']['entry'][0]['iters'])
self.assertEquals(u'1483228800', json_output['startTime'])
self.assertEquals('TestBenchmark', json_output['name'])
if __name__ == '__main__':
googletest.main()

View File

@ -0,0 +1,21 @@
# Tools for running distributed benchmarks.
licenses(["notice"]) # Apache 2.0
exports_files(["k8s_tensorflow.py"])
py_library(
name = "k8s_tensorflow_lib",
srcs = ["k8s_tensorflow_lib.py"],
srcs_version = "PY2AND3",
)
py_test(
name = "k8s_tensorflow_test",
srcs = ["k8s_tensorflow_test.py"],
srcs_version = "PY2AND3",
deps = [
":k8s_tensorflow_lib",
"//tensorflow/python:client_testlib",
],
)

View File

@ -25,6 +25,8 @@ from __future__ import print_function
import argparse
import sys
import k8s_tensorflow_lib
# Note: It is intentional that we do not import tensorflow in this script. The
# machine that launches a TensorFlow k8s cluster does not have to have the
# Python package of TensorFlow installed on it.
@ -33,125 +35,6 @@ import sys
DEFAULT_DOCKER_IMAGE = 'tensorflow/tf_grpc_test_server'
DEFAULT_PORT = 2222
# TODO(cais): Consider adding resource requests/limits to the pods.
# Worker pods will mount host volume /shared, as a convenient way to create
# shared storage among workers during local tests.
WORKER_RC = (
"""apiVersion: v1
kind: ReplicationController
metadata:
name: {name_prefix}-worker{worker_id}
spec:
replicas: 1
template:
metadata:
labels:
tf-worker: "{worker_id}"
name-prefix: "{name_prefix}"
job: "worker"
spec:
containers:
- name: tf-worker{worker_id}
image: {docker_image}
args:
- --cluster_spec={cluster_spec}
- --job_name=worker
- --task_id={worker_id}
ports:
- containerPort: {port}
env:
- name: POD_NAME_PREFIX
value: {name_prefix}
volumeMounts: [{volume_mounts}]
volumes: [{volumes}]
""")
WORKER_SVC = (
"""apiVersion: v1
kind: Service
metadata:
name: {name_prefix}-worker{worker_id}
labels:
tf-worker: "{worker_id}"
spec:
ports:
- port: {port}
targetPort: {port}
selector:
tf-worker: "{worker_id}"
""")
WORKER_LB_SVC = (
"""apiVersion: v1
kind: Service
metadata:
name: {name_prefix}-worker{worker_id}
labels:
tf-worker: "{worker_id}"
spec:
type: LoadBalancer
ports:
- port: {port}
selector:
tf-worker: "{worker_id}"
""")
PARAM_SERVER_RC = (
"""apiVersion: v1
kind: ReplicationController
metadata:
name: {name_prefix}-ps{param_server_id}
spec:
replicas: 1
template:
metadata:
labels:
tf-ps: "{param_server_id}"
name-prefix: "{name_prefix}"
job: "ps"
spec:
containers:
- name: tf-ps{param_server_id}
image: {docker_image}
args:
- --cluster_spec={cluster_spec}
- --job_name=ps
- --task_id={param_server_id}
ports:
- containerPort: {port}
env:
- name: POD_NAME_PREFIX
value: {name_prefix}
volumeMounts: [{volume_mounts}]
volumes: [{volumes}]
""")
PARAM_SERVER_SVC = (
"""apiVersion: v1
kind: Service
metadata:
name: {name_prefix}-ps{param_server_id}
labels:
tf-ps: "{param_server_id}"
spec:
ports:
- port: {port}
selector:
tf-ps: "{param_server_id}"
""")
PARAM_LB_SVC = ("""apiVersion: v1
kind: Service
metadata:
name: {name_prefix}-ps{param_server_id}
labels:
tf-ps: "{param_server_id}"
spec:
type: LoadBalancer
ports:
- port: {port}
selector:
tf-ps: "{param_server_id}"
""")
VOLUME_MOUNTS = '{name: shared, mountPath: /shared}'
VOLUMES = '{name: shared, hostPath: {path: /shared}}'
def main():
"""Do arg parsing."""
@ -204,108 +87,17 @@ def main():
sys.exit(1)
# Generate contents of yaml config
yaml_config = GenerateConfig(args.num_workers,
args.num_parameter_servers,
args.grpc_port,
args.request_load_balancer,
args.docker_image,
args.name_prefix,
args.use_shared_volume)
yaml_config = k8s_tensorflow_lib.GenerateConfig(
args.num_workers,
args.num_parameter_servers,
args.grpc_port,
args.request_load_balancer,
args.docker_image,
args.name_prefix,
env_vars=None,
use_shared_volume=args.use_shared_volume)
print(yaml_config) # pylint: disable=superfluous-parens
def GenerateConfig(num_workers,
num_param_servers,
port,
request_load_balancer,
docker_image,
name_prefix,
use_shared_volume):
"""Generate configuration strings."""
config = ''
for worker in range(num_workers):
config += WORKER_RC.format(
port=port,
worker_id=worker,
docker_image=docker_image,
name_prefix=name_prefix,
volume_mounts=VOLUME_MOUNTS if use_shared_volume else '',
volumes=VOLUMES if use_shared_volume else '',
cluster_spec=WorkerClusterSpecString(num_workers,
num_param_servers,
port,
name_prefix))
config += '---\n'
if request_load_balancer:
config += WORKER_LB_SVC.format(port=port,
worker_id=worker,
name_prefix=name_prefix)
else:
config += WORKER_SVC.format(port=port,
worker_id=worker,
name_prefix=name_prefix)
config += '---\n'
for param_server in range(num_param_servers):
config += PARAM_SERVER_RC.format(
port=port,
param_server_id=param_server,
docker_image=docker_image,
name_prefix=name_prefix,
volume_mounts=VOLUME_MOUNTS if use_shared_volume else '',
volumes=VOLUMES if use_shared_volume else '',
cluster_spec=ParamServerClusterSpecString(num_workers,
num_param_servers,
port,
name_prefix))
config += '---\n'
if request_load_balancer:
config += PARAM_LB_SVC.format(
port=port, param_server_id=param_server, name_prefix=name_prefix)
else:
config += PARAM_SERVER_SVC.format(
port=port, param_server_id=param_server, name_prefix=name_prefix)
config += '---\n'
return config
def WorkerClusterSpecString(num_workers,
num_param_servers,
port,
name_prefix):
"""Generates worker cluster spec."""
return ClusterSpecString(num_workers, num_param_servers, port, name_prefix)
def ParamServerClusterSpecString(num_workers,
num_param_servers,
port,
name_prefix):
"""Generates parameter server spec."""
return ClusterSpecString(num_workers, num_param_servers, port,
name_prefix)
def ClusterSpecString(num_workers,
num_param_servers,
port,
name_prefix):
"""Generates general cluster spec."""
spec = 'worker|'
for worker in range(num_workers):
spec += '%s-worker%d:%d' % (name_prefix, worker, port)
if worker != num_workers-1:
spec += ';'
spec += ',ps|'
for param_server in range(num_param_servers):
spec += '%s-ps%d:%d' % (name_prefix, param_server, port)
if param_server != num_param_servers-1:
spec += ';'
return spec
if __name__ == '__main__':
main()

View File

@ -0,0 +1,309 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generates YAML configuration files for distributed TensorFlow workers.
The workers will be run in a Kubernetes (k8s) container cluster.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Note: It is intentional that we do not import tensorflow in this script. The
# machine that launches a TensorFlow k8s cluster does not have to have the
# Python package of TensorFlow installed on it.
# TODO(cais): Consider adding resource requests/limits to the pods.
# Worker pods will mount host volume /shared, as a convenient way to create
# shared storage among workers during local tests.
WORKER_RC = (
"""apiVersion: v1
kind: ReplicationController
metadata:
name: {name_prefix}-worker{worker_id}
spec:
replicas: 1
template:
metadata:
labels:
tf-worker: "{worker_id}"
name-prefix: "{name_prefix}"
job: "worker"
spec:
containers:
- name: tf-worker{worker_id}
image: {docker_image}
args: [{args}]
ports:
- containerPort: {port}
env: [{env_vars}]
volumeMounts: [{volume_mounts}]
volumes: [{volumes}]
""")
WORKER_SVC = (
"""apiVersion: v1
kind: Service
metadata:
name: {name_prefix}-worker{worker_id}
labels:
tf-worker: "{worker_id}"
spec:
ports:
- port: {port}
targetPort: {port}
selector:
tf-worker: "{worker_id}"
""")
WORKER_LB_SVC = (
"""apiVersion: v1
kind: Service
metadata:
name: {name_prefix}-worker{worker_id}
labels:
tf-worker: "{worker_id}"
spec:
type: LoadBalancer
ports:
- port: {port}
selector:
tf-worker: "{worker_id}"
""")
PARAM_SERVER_RC = (
"""apiVersion: v1
kind: ReplicationController
metadata:
name: {name_prefix}-ps{param_server_id}
spec:
replicas: 1
template:
metadata:
labels:
tf-ps: "{param_server_id}"
name-prefix: "{name_prefix}"
job: "ps"
spec:
containers:
- name: tf-ps{param_server_id}
image: {docker_image}
args: [{args}]
ports:
- containerPort: {port}
env: [{env_vars}]
volumeMounts: [{volume_mounts}]
volumes: [{volumes}]
""")
PARAM_SERVER_SVC = (
"""apiVersion: v1
kind: Service
metadata:
name: {name_prefix}-ps{param_server_id}
labels:
tf-ps: "{param_server_id}"
spec:
ports:
- port: {port}
selector:
tf-ps: "{param_server_id}"
""")
PARAM_LB_SVC = ("""apiVersion: v1
kind: Service
metadata:
name: {name_prefix}-ps{param_server_id}
labels:
tf-ps: "{param_server_id}"
spec:
type: LoadBalancer
ports:
- port: {port}
selector:
tf-ps: "{param_server_id}"
""")
VOLUME_MOUNTS = '{name: shared, mountPath: /shared}'
VOLUMES = '{name: shared, hostPath: {path: /shared}}'
_ENV_VAR_TEMPLATE = '{name: "%s", value: "%s"}'
_ARG_TEMPLATE = '"--%s=%s"'
def GenerateConfig(num_workers,
num_param_servers,
port,
request_load_balancer,
docker_image,
name_prefix,
env_vars=None,
use_shared_volume=True,
use_cluster_spec=True):
"""Generate configuration strings.
Args:
num_workers: number of worker jobs.
num_param_servers: number of ps server jobs.
port: GRPC server port.
request_load_balancer: request worker0 to be exposed on a public IP
address via an external load balancer.
docker_image: docker image to use.
name_prefix: name to prepend to pod job names.
env_vars: dictionary of environment variables to set.
use_shared_volume: whether to add hostPath to /shared directory
to the kubernetes config.
use_cluster_spec: if true, pass --cluster_spec to worker and ps jobs.
If false, pass --worker_hosts and --ps_hosts to worker and ps jobs.
Returns:
Kubernetes yaml config.
"""
if env_vars is None:
env_vars = {}
env_str = ', '.join([_ENV_VAR_TEMPLATE % (name, value)
for name, value in env_vars.items()])
config = ''
common_args = GetCommonArgs(
num_workers, num_param_servers, port, name_prefix, use_cluster_spec)
for worker in range(num_workers):
worker_args = {
'job_name': 'worker',
'task_id': worker
}
worker_args.update(common_args)
arg_str = ', '.join([_ARG_TEMPLATE % (name, value)
for name, value in worker_args.items()])
config += WORKER_RC.format(
port=port,
worker_id=worker,
docker_image=docker_image,
name_prefix=name_prefix,
volume_mounts=VOLUME_MOUNTS if use_shared_volume else '',
volumes=VOLUMES if use_shared_volume else '',
args=arg_str,
env_vars=env_str)
config += '---\n'
if request_load_balancer:
config += WORKER_LB_SVC.format(port=port,
worker_id=worker,
name_prefix=name_prefix)
else:
config += WORKER_SVC.format(port=port,
worker_id=worker,
name_prefix=name_prefix)
config += '---\n'
for param_server in range(num_param_servers):
ps_args = {
'job_name': 'ps',
'task_id': param_server
}
ps_args.update(common_args)
arg_str = ', '.join([_ARG_TEMPLATE % (name, value)
for name, value in ps_args.items()])
config += PARAM_SERVER_RC.format(
port=port,
param_server_id=param_server,
docker_image=docker_image,
name_prefix=name_prefix,
volume_mounts=VOLUME_MOUNTS if use_shared_volume else '',
volumes=VOLUMES if use_shared_volume else '',
args=arg_str,
env_vars=env_str)
config += '---\n'
if request_load_balancer:
config += PARAM_LB_SVC.format(
port=port, param_server_id=param_server, name_prefix=name_prefix)
else:
config += PARAM_SERVER_SVC.format(
port=port, param_server_id=param_server, name_prefix=name_prefix)
config += '---\n'
return config
def WorkerClusterSpecString(num_workers,
num_param_servers,
port,
name_prefix):
"""Generates worker cluster spec."""
return ClusterSpecString(num_workers, num_param_servers, port, name_prefix)
def ParamServerClusterSpecString(num_workers,
num_param_servers,
port,
name_prefix):
"""Generates parameter server spec."""
return ClusterSpecString(num_workers, num_param_servers, port,
name_prefix)
def ClusterSpecString(num_workers,
num_param_servers,
port,
name_prefix):
"""Generates general cluster spec."""
spec = 'worker|'
for worker in range(num_workers):
spec += '%s-worker%d:%d' % (name_prefix, worker, port)
if worker != num_workers-1:
spec += ';'
spec += ',ps|'
for param_server in range(num_param_servers):
spec += '%s-ps%d:%d' % (name_prefix, param_server, port)
if param_server != num_param_servers-1:
spec += ';'
return spec
def GetCommonArgs(num_workers,
num_param_servers,
port,
name_prefix,
use_cluster_spec):
"""Get arguments common to both worker and ps jobs.
Args:
num_workers: number of workers.
num_param_servers: number of ps servers.
port: worker and ps port number.
name_prefix: prefix to prepend to job names.
use_cluster_spec: if true, pass --cluster_spec argument.
If false, parse --worker_hosts and --ps_hosts arguments.
Returns:
A dictionary of argument names mapping to argument values.
"""
common_args = {}
if use_cluster_spec:
common_args['cluster_spec'] = WorkerClusterSpecString(
num_workers,
num_param_servers,
port,
name_prefix)
else:
common_args['worker_hosts'] = WorkerHosts(num_workers, port, name_prefix)
common_args['ps_hosts'] = PsHosts(num_param_servers, port, name_prefix)
return common_args
def WorkerHosts(num_workers, port, name_prefix):
worker_hosts = ['%s-worker%d:%d' % (name_prefix, i, port)
for i in range(num_workers)]
return ','.join(worker_hosts)
def PsHosts(num_ps, port, name_prefix):
ps_hosts = ['%s-ps%d:%d' % (name_prefix, i, port)
for i in range(num_ps)]
return ','.join(ps_hosts)

View File

@ -0,0 +1,132 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.tools.dist_test.scripts.k8s_tensorflow_lib."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.platform import googletest
from tensorflow.tools.dist_test.scripts import k8s_tensorflow_lib
class K8sTensorflowTest(googletest.TestCase):
def testGenerateConfig_LoadBalancer(self):
# Use loadbalancer
config = k8s_tensorflow_lib.GenerateConfig(
num_workers=1,
num_param_servers=1,
port=5000,
request_load_balancer=True,
docker_image='test_image',
name_prefix='abc',
use_shared_volume=False)
self.assertTrue('LoadBalancer' in config)
# Don't use loadbalancer
config = k8s_tensorflow_lib.GenerateConfig(
num_workers=1,
num_param_servers=1,
port=5000,
request_load_balancer=False,
docker_image='test_image',
name_prefix='abc',
use_shared_volume=False)
self.assertFalse('LoadBalancer' in config)
def testGenerateConfig_SharedVolume(self):
# Use shared directory
config = k8s_tensorflow_lib.GenerateConfig(
num_workers=1,
num_param_servers=1,
port=5000,
request_load_balancer=False,
docker_image='test_image',
name_prefix='abc',
use_shared_volume=True)
self.assertTrue('/shared' in config)
# Don't use shared directory
config = k8s_tensorflow_lib.GenerateConfig(
num_workers=1,
num_param_servers=1,
port=5000,
request_load_balancer=False,
docker_image='test_image',
name_prefix='abc',
use_shared_volume=False)
self.assertFalse('/shared' in config)
def testEnvVar(self):
# Use loadbalancer
config = k8s_tensorflow_lib.GenerateConfig(
num_workers=1,
num_param_servers=1,
port=5000,
request_load_balancer=True,
docker_image='test_image',
name_prefix='abc',
use_shared_volume=False,
env_vars={'test1': 'test1_value', 'test2': 'test2_value'})
self.assertTrue('{name: "test1", value: "test1_value"}' in config)
self.assertTrue('{name: "test2", value: "test2_value"}' in config)
def testClusterSpec(self):
# Use cluster_spec
config = k8s_tensorflow_lib.GenerateConfig(
num_workers=1,
num_param_servers=1,
port=5000,
request_load_balancer=True,
docker_image='test_image',
name_prefix='abc',
use_shared_volume=False,
use_cluster_spec=True)
self.assertFalse('worker_hosts' in config)
self.assertFalse('ps_hosts' in config)
self.assertTrue(
'"--cluster_spec=worker|abc-worker0:5000,ps|abc-ps0:5000"' in config)
# Don't use cluster_spec
config = k8s_tensorflow_lib.GenerateConfig(
num_workers=1,
num_param_servers=1,
port=5000,
request_load_balancer=True,
docker_image='test_image',
name_prefix='abc',
use_shared_volume=False,
use_cluster_spec=False)
self.assertFalse('cluster_spec' in config)
self.assertTrue('"--worker_hosts=abc-worker0:5000"' in config)
self.assertTrue('"--ps_hosts=abc-ps0:5000"' in config)
def testWorkerHosts(self):
self.assertEquals(
'test_prefix-worker0:1234',
k8s_tensorflow_lib.WorkerHosts(1, 1234, 'test_prefix'))
self.assertEquals(
'test_prefix-worker0:1234,test_prefix-worker1:1234',
k8s_tensorflow_lib.WorkerHosts(2, 1234, 'test_prefix'))
def testPsHosts(self):
self.assertEquals(
'test_prefix-ps0:1234,test_prefix-ps1:1234',
k8s_tensorflow_lib.PsHosts(2, 1234, 'test_prefix'))
if __name__ == '__main__':
googletest.main()