Split most of k8s_tensorflow out into a library and add a way to pass arbitrary environment variables to the generated pods. Add a benchmark_util library that uses an environment variable to decide on a storage location for benchmark results. Change: 147890534
This commit is contained in:
parent
a6421c4dda
commit
b06281ba47
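Reviewer note: a minimal sketch of how the two new pieces are meant to fit together. The worker counts, image, results path, and benchmark name below are illustrative, not values taken from this change. The launcher uses k8s_tensorflow_lib to render pod configs with environment variables injected; inside each pod, benchmark_util reads TF_DIST_BENCHMARK_RESULTS_FILE to decide where to write results.

# Launcher side: generate kubectl-ready YAML with benchmark env vars injected.
from tensorflow.tools.dist_test.scripts import k8s_tensorflow_lib

config = k8s_tensorflow_lib.GenerateConfig(
    num_workers=2,
    num_param_servers=1,
    port=2222,
    request_load_balancer=False,
    docker_image='tensorflow/tf_grpc_test_server',
    name_prefix='bench',
    env_vars={
        'TF_DIST_BENCHMARK_RESULTS_FILE': '/shared/results.json',
        'TF_DIST_BENCHMARK_NAME': 'dist_benchmark',
    },
    use_shared_volume=True)
print(config)

Inside the pods, benchmark_util.store_data_in_json (added below) falls back to TF_DIST_BENCHMARK_RESULTS_FILE when no output path is passed explicitly.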
tensorflow/tools/dist_test/python/BUILD (new file, 23 lines)
@@ -0,0 +1,23 @@
# Python tools for running distributed benchmarks.

licenses(["notice"])  # Apache 2.0

py_library(
    name = "benchmark_util_lib",
    srcs = ["benchmark_util.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:platform",
    ],
)

py_test(
    name = "benchmark_util_test",
    srcs = ["benchmark_util_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":benchmark_util_lib",
        "//tensorflow/python:client_testlib",
    ],
)
tensorflow/tools/dist_test/python/benchmark_util.py (new file, 77 lines)
@@ -0,0 +1,77 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Provides helper functions for distributed benchmarks running on Jenkins."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import calendar
from collections import namedtuple
import os

from google.protobuf import json_format

from tensorflow.core.util import test_log_pb2
from tensorflow.python.platform import gfile


_OUTPUT_FILE_ENV_VAR = 'TF_DIST_BENCHMARK_RESULTS_FILE'
_TEST_NAME_ENV_VAR = 'TF_DIST_BENCHMARK_NAME'

# Represents a single timing entry where
# - name is a string
# - timing is the latency to track (e.g. mean time per iteration)
# - iters is the number of iterations
TimingEntry = namedtuple(
    'TimingEntry', ['name', 'timing', 'iters'])


def store_data_in_json(timing_entries, start_time, output_file=None):
  """Stores benchmark results in JSON format.

  Args:
    timing_entries: list of TimingEntry objects.
    start_time: (datetime) start time of the test run.
    output_file: if specified, writes benchmark results to output_file.
      If not specified, writes results to the file specified by the
      TF_DIST_BENCHMARK_RESULTS_FILE environment variable.

  Raises:
    ValueError: when neither output_file is passed in nor
      TF_DIST_BENCHMARK_RESULTS_FILE is set.
  """
  test_result = test_log_pb2.TestResults(
      start_time=calendar.timegm(start_time.timetuple()))
  if not output_file:
    if _OUTPUT_FILE_ENV_VAR not in os.environ:
      raise ValueError('Could not determine a location to store results.')
    output_file = os.environ[_OUTPUT_FILE_ENV_VAR]

  with gfile.Open(output_file, 'wb') as jsonfile:
    if _TEST_NAME_ENV_VAR in os.environ:
      test_result.name = os.environ[_TEST_NAME_ENV_VAR]
    else:
      test_result.name = 'TestBenchmark'

    for timing_entry in timing_entries:
      test_result.entries.entry.add(
          name=timing_entry.name,
          iters=timing_entry.iters,
          wall_time=timing_entry.timing
      )
    json_test_results = json_format.MessageToJson(test_result)
    jsonfile.write(json_test_results)
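A hedged usage sketch of the helper above; the path and timing values are made up, and the JSON shape follows the expectations in the unit test that comes next.

import datetime
import os

from tensorflow.tools.dist_test.python import benchmark_util

# Either pass output_file explicitly, or let the helper fall back to the
# TF_DIST_BENCHMARK_RESULTS_FILE environment variable, as below.
os.environ['TF_DIST_BENCHMARK_RESULTS_FILE'] = '/tmp/results.json'
os.environ['TF_DIST_BENCHMARK_NAME'] = 'my_benchmark'

entries = [benchmark_util.TimingEntry(name='train_step', timing=0.1, iters=100)]
benchmark_util.store_data_in_json(entries, datetime.date(2017, 1, 1))

# /tmp/results.json now holds the JSON form of a TestResults proto, roughly:
# {"name": "my_benchmark", "startTime": "1483228800",
#  "entries": {"entry": [{"name": "train_step", "iters": "100", "wallTime": 0.1}]}}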
tensorflow/tools/dist_test/python/benchmark_util_test.py (new file, 58 lines)
@@ -0,0 +1,58 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.tools.dist_test.python.benchmark_util."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import datetime
import json
import os

from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
from tensorflow.tools.dist_test.python import benchmark_util


class BenchmarkUtilTest(googletest.TestCase):

  def testStoreDataWithNoEntries(self):
    output_file = os.path.join(test.get_temp_dir(), 'test_output1.json')
    timing_entries = []
    benchmark_util.store_data_in_json(
        timing_entries, datetime.date(2017, 1, 1), output_file)
    json_output = json.loads(open(output_file, 'r').read())
    self.assertEquals('TestBenchmark', json_output['name'])
    self.assertEquals(u'1483228800', json_output['startTime'])

  def testStoreDataWithEntries(self):
    output_file = os.path.join(test.get_temp_dir(), 'test_output2.json')
    timing_entries = [
        benchmark_util.TimingEntry('test', 0.1, 1)]
    benchmark_util.store_data_in_json(
        timing_entries, datetime.date(2017, 1, 1), output_file)
    json_output = json.loads(open(output_file, 'r').read())

    self.assertEquals(1, len(json_output['entries']['entry']))
    self.assertEquals('test', json_output['entries']['entry'][0]['name'])
    self.assertEquals(0.1, json_output['entries']['entry'][0]['wallTime'])
    self.assertEquals(u'1', json_output['entries']['entry'][0]['iters'])
    self.assertEquals(u'1483228800', json_output['startTime'])
    self.assertEquals('TestBenchmark', json_output['name'])


if __name__ == '__main__':
  googletest.main()
tensorflow/tools/dist_test/scripts/BUILD (new file, 21 lines)
@@ -0,0 +1,21 @@
# Tools for running distributed benchmarks.

licenses(["notice"])  # Apache 2.0

exports_files(["k8s_tensorflow.py"])

py_library(
    name = "k8s_tensorflow_lib",
    srcs = ["k8s_tensorflow_lib.py"],
    srcs_version = "PY2AND3",
)

py_test(
    name = "k8s_tensorflow_test",
    srcs = ["k8s_tensorflow_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":k8s_tensorflow_lib",
        "//tensorflow/python:client_testlib",
    ],
)
tensorflow/tools/dist_test/scripts/k8s_tensorflow.py (modified)
@@ -25,6 +25,8 @@ from __future__ import print_function
import argparse
import sys

import k8s_tensorflow_lib

# Note: It is intentional that we do not import tensorflow in this script. The
# machine that launches a TensorFlow k8s cluster does not have to have the
# Python package of TensorFlow installed on it.
@@ -33,125 +35,6 @@ import sys
DEFAULT_DOCKER_IMAGE = 'tensorflow/tf_grpc_test_server'
DEFAULT_PORT = 2222

# TODO(cais): Consider adding resource requests/limits to the pods.

# Worker pods will mount host volume /shared, as a convenient way to create
# shared storage among workers during local tests.
WORKER_RC = (
    """apiVersion: v1
kind: ReplicationController
metadata:
  name: {name_prefix}-worker{worker_id}
spec:
  replicas: 1
  template:
    metadata:
      labels:
        tf-worker: "{worker_id}"
        name-prefix: "{name_prefix}"
        job: "worker"
    spec:
      containers:
      - name: tf-worker{worker_id}
        image: {docker_image}
        args:
        - --cluster_spec={cluster_spec}
        - --job_name=worker
        - --task_id={worker_id}
        ports:
        - containerPort: {port}
        env:
        - name: POD_NAME_PREFIX
          value: {name_prefix}
        volumeMounts: [{volume_mounts}]
      volumes: [{volumes}]
""")
WORKER_SVC = (
    """apiVersion: v1
kind: Service
metadata:
  name: {name_prefix}-worker{worker_id}
  labels:
    tf-worker: "{worker_id}"
spec:
  ports:
  - port: {port}
    targetPort: {port}
  selector:
    tf-worker: "{worker_id}"
""")
WORKER_LB_SVC = (
    """apiVersion: v1
kind: Service
metadata:
  name: {name_prefix}-worker{worker_id}
  labels:
    tf-worker: "{worker_id}"
spec:
  type: LoadBalancer
  ports:
  - port: {port}
  selector:
    tf-worker: "{worker_id}"
""")
PARAM_SERVER_RC = (
    """apiVersion: v1
kind: ReplicationController
metadata:
  name: {name_prefix}-ps{param_server_id}
spec:
  replicas: 1
  template:
    metadata:
      labels:
        tf-ps: "{param_server_id}"
        name-prefix: "{name_prefix}"
        job: "ps"
    spec:
      containers:
      - name: tf-ps{param_server_id}
        image: {docker_image}
        args:
        - --cluster_spec={cluster_spec}
        - --job_name=ps
        - --task_id={param_server_id}
        ports:
        - containerPort: {port}
        env:
        - name: POD_NAME_PREFIX
          value: {name_prefix}
        volumeMounts: [{volume_mounts}]
      volumes: [{volumes}]
""")
PARAM_SERVER_SVC = (
    """apiVersion: v1
kind: Service
metadata:
  name: {name_prefix}-ps{param_server_id}
  labels:
    tf-ps: "{param_server_id}"
spec:
  ports:
  - port: {port}
  selector:
    tf-ps: "{param_server_id}"
""")
PARAM_LB_SVC = ("""apiVersion: v1
kind: Service
metadata:
  name: {name_prefix}-ps{param_server_id}
  labels:
    tf-ps: "{param_server_id}"
spec:
  type: LoadBalancer
  ports:
  - port: {port}
  selector:
    tf-ps: "{param_server_id}"
""")
VOLUME_MOUNTS = '{name: shared, mountPath: /shared}'
VOLUMES = '{name: shared, hostPath: {path: /shared}}'


def main():
  """Do arg parsing."""
@@ -204,108 +87,17 @@ def main():
    sys.exit(1)

  # Generate contents of yaml config
  yaml_config = GenerateConfig(args.num_workers,
                               args.num_parameter_servers,
                               args.grpc_port,
                               args.request_load_balancer,
                               args.docker_image,
                               args.name_prefix,
                               args.use_shared_volume)
  yaml_config = k8s_tensorflow_lib.GenerateConfig(
      args.num_workers,
      args.num_parameter_servers,
      args.grpc_port,
      args.request_load_balancer,
      args.docker_image,
      args.name_prefix,
      env_vars=None,
      use_shared_volume=args.use_shared_volume)
  print(yaml_config)  # pylint: disable=superfluous-parens


def GenerateConfig(num_workers,
                   num_param_servers,
                   port,
                   request_load_balancer,
                   docker_image,
                   name_prefix,
                   use_shared_volume):
  """Generate configuration strings."""
  config = ''
  for worker in range(num_workers):
    config += WORKER_RC.format(
        port=port,
        worker_id=worker,
        docker_image=docker_image,
        name_prefix=name_prefix,
        volume_mounts=VOLUME_MOUNTS if use_shared_volume else '',
        volumes=VOLUMES if use_shared_volume else '',
        cluster_spec=WorkerClusterSpecString(num_workers,
                                             num_param_servers,
                                             port,
                                             name_prefix))
    config += '---\n'
    if request_load_balancer:
      config += WORKER_LB_SVC.format(port=port,
                                     worker_id=worker,
                                     name_prefix=name_prefix)
    else:
      config += WORKER_SVC.format(port=port,
                                  worker_id=worker,
                                  name_prefix=name_prefix)
    config += '---\n'

  for param_server in range(num_param_servers):
    config += PARAM_SERVER_RC.format(
        port=port,
        param_server_id=param_server,
        docker_image=docker_image,
        name_prefix=name_prefix,
        volume_mounts=VOLUME_MOUNTS if use_shared_volume else '',
        volumes=VOLUMES if use_shared_volume else '',
        cluster_spec=ParamServerClusterSpecString(num_workers,
                                                  num_param_servers,
                                                  port,
                                                  name_prefix))
    config += '---\n'
    if request_load_balancer:
      config += PARAM_LB_SVC.format(
          port=port, param_server_id=param_server, name_prefix=name_prefix)
    else:
      config += PARAM_SERVER_SVC.format(
          port=port, param_server_id=param_server, name_prefix=name_prefix)
    config += '---\n'

  return config


def WorkerClusterSpecString(num_workers,
                            num_param_servers,
                            port,
                            name_prefix):
  """Generates worker cluster spec."""
  return ClusterSpecString(num_workers, num_param_servers, port, name_prefix)


def ParamServerClusterSpecString(num_workers,
                                 num_param_servers,
                                 port,
                                 name_prefix):
  """Generates parameter server spec."""
  return ClusterSpecString(num_workers, num_param_servers, port,
                           name_prefix)


def ClusterSpecString(num_workers,
                      num_param_servers,
                      port,
                      name_prefix):
  """Generates general cluster spec."""
  spec = 'worker|'
  for worker in range(num_workers):
    spec += '%s-worker%d:%d' % (name_prefix, worker, port)
    if worker != num_workers-1:
      spec += ';'

  spec += ',ps|'
  for param_server in range(num_param_servers):
    spec += '%s-ps%d:%d' % (name_prefix, param_server, port)
    if param_server != num_param_servers-1:
      spec += ';'

  return spec


if __name__ == '__main__':
  main()
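The refactored CLI still passes env_vars=None, so environment variables are currently reachable only through the library. A possible follow-up, sketched here purely as an assumption (the --env_vars flag and its NAME=VALUE,NAME=VALUE format are not part of this change), would be to expose the new parameter on the command line:

# Hypothetical CLI extension, not part of this commit.
import argparse

import k8s_tensorflow_lib

parser = argparse.ArgumentParser()
parser.add_argument('--env_vars', type=str, default='',
                    help='Comma-separated NAME=VALUE pairs to set in each pod.')
args = parser.parse_args(['--env_vars=TF_DIST_BENCHMARK_NAME=my_benchmark'])

# Parse NAME=VALUE pairs into the dict expected by GenerateConfig's env_vars.
env_vars = dict(kv.split('=', 1) for kv in args.env_vars.split(',') if kv)
yaml_config = k8s_tensorflow_lib.GenerateConfig(
    num_workers=2,
    num_param_servers=1,
    port=2222,
    request_load_balancer=False,
    docker_image='tensorflow/tf_grpc_test_server',
    name_prefix='bench',
    env_vars=env_vars,
    use_shared_volume=False)
print(yaml_config)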
tensorflow/tools/dist_test/scripts/k8s_tensorflow_lib.py (new file, 309 lines)
@@ -0,0 +1,309 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Generates YAML configuration files for distributed TensorFlow workers.

The workers will be run in a Kubernetes (k8s) container cluster.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Note: It is intentional that we do not import tensorflow in this script. The
# machine that launches a TensorFlow k8s cluster does not have to have the
# Python package of TensorFlow installed on it.

# TODO(cais): Consider adding resource requests/limits to the pods.

# Worker pods will mount host volume /shared, as a convenient way to create
# shared storage among workers during local tests.
WORKER_RC = (
    """apiVersion: v1
kind: ReplicationController
metadata:
  name: {name_prefix}-worker{worker_id}
spec:
  replicas: 1
  template:
    metadata:
      labels:
        tf-worker: "{worker_id}"
        name-prefix: "{name_prefix}"
        job: "worker"
    spec:
      containers:
      - name: tf-worker{worker_id}
        image: {docker_image}
        args: [{args}]
        ports:
        - containerPort: {port}
        env: [{env_vars}]
        volumeMounts: [{volume_mounts}]
      volumes: [{volumes}]
""")
WORKER_SVC = (
    """apiVersion: v1
kind: Service
metadata:
  name: {name_prefix}-worker{worker_id}
  labels:
    tf-worker: "{worker_id}"
spec:
  ports:
  - port: {port}
    targetPort: {port}
  selector:
    tf-worker: "{worker_id}"
""")
WORKER_LB_SVC = (
    """apiVersion: v1
kind: Service
metadata:
  name: {name_prefix}-worker{worker_id}
  labels:
    tf-worker: "{worker_id}"
spec:
  type: LoadBalancer
  ports:
  - port: {port}
  selector:
    tf-worker: "{worker_id}"
""")
PARAM_SERVER_RC = (
    """apiVersion: v1
kind: ReplicationController
metadata:
  name: {name_prefix}-ps{param_server_id}
spec:
  replicas: 1
  template:
    metadata:
      labels:
        tf-ps: "{param_server_id}"
        name-prefix: "{name_prefix}"
        job: "ps"
    spec:
      containers:
      - name: tf-ps{param_server_id}
        image: {docker_image}
        args: [{args}]
        ports:
        - containerPort: {port}
        env: [{env_vars}]
        volumeMounts: [{volume_mounts}]
      volumes: [{volumes}]
""")
PARAM_SERVER_SVC = (
    """apiVersion: v1
kind: Service
metadata:
  name: {name_prefix}-ps{param_server_id}
  labels:
    tf-ps: "{param_server_id}"
spec:
  ports:
  - port: {port}
  selector:
    tf-ps: "{param_server_id}"
""")
PARAM_LB_SVC = ("""apiVersion: v1
kind: Service
metadata:
  name: {name_prefix}-ps{param_server_id}
  labels:
    tf-ps: "{param_server_id}"
spec:
  type: LoadBalancer
  ports:
  - port: {port}
  selector:
    tf-ps: "{param_server_id}"
""")
VOLUME_MOUNTS = '{name: shared, mountPath: /shared}'
VOLUMES = '{name: shared, hostPath: {path: /shared}}'
_ENV_VAR_TEMPLATE = '{name: "%s", value: "%s"}'
_ARG_TEMPLATE = '"--%s=%s"'

def GenerateConfig(num_workers,
                   num_param_servers,
                   port,
                   request_load_balancer,
                   docker_image,
                   name_prefix,
                   env_vars=None,
                   use_shared_volume=True,
                   use_cluster_spec=True):
  """Generate configuration strings.

  Args:
    num_workers: number of worker jobs.
    num_param_servers: number of ps server jobs.
    port: GRPC server port.
    request_load_balancer: request worker0 to be exposed on a public IP
      address via an external load balancer.
    docker_image: docker image to use.
    name_prefix: name to prepend to pod job names.
    env_vars: dictionary of environment variables to set.
    use_shared_volume: whether to add hostPath to /shared directory
      to the kubernetes config.
    use_cluster_spec: if true, pass --cluster_spec to worker and ps jobs.
      If false, pass --worker_hosts and --ps_hosts to worker and ps jobs.

  Returns:
    Kubernetes yaml config.
  """
  if env_vars is None:
    env_vars = {}
  env_str = ', '.join([_ENV_VAR_TEMPLATE % (name, value)
                       for name, value in env_vars.items()])
  config = ''
  common_args = GetCommonArgs(
      num_workers, num_param_servers, port, name_prefix, use_cluster_spec)
  for worker in range(num_workers):
    worker_args = {
        'job_name': 'worker',
        'task_id': worker
    }
    worker_args.update(common_args)
    arg_str = ', '.join([_ARG_TEMPLATE % (name, value)
                         for name, value in worker_args.items()])
    config += WORKER_RC.format(
        port=port,
        worker_id=worker,
        docker_image=docker_image,
        name_prefix=name_prefix,
        volume_mounts=VOLUME_MOUNTS if use_shared_volume else '',
        volumes=VOLUMES if use_shared_volume else '',
        args=arg_str,
        env_vars=env_str)
    config += '---\n'
    if request_load_balancer:
      config += WORKER_LB_SVC.format(port=port,
                                     worker_id=worker,
                                     name_prefix=name_prefix)
    else:
      config += WORKER_SVC.format(port=port,
                                  worker_id=worker,
                                  name_prefix=name_prefix)
    config += '---\n'

  for param_server in range(num_param_servers):
    ps_args = {
        'job_name': 'ps',
        'task_id': param_server
    }
    ps_args.update(common_args)
    arg_str = ', '.join([_ARG_TEMPLATE % (name, value)
                         for name, value in ps_args.items()])
    config += PARAM_SERVER_RC.format(
        port=port,
        param_server_id=param_server,
        docker_image=docker_image,
        name_prefix=name_prefix,
        volume_mounts=VOLUME_MOUNTS if use_shared_volume else '',
        volumes=VOLUMES if use_shared_volume else '',
        args=arg_str,
        env_vars=env_str)
    config += '---\n'
    if request_load_balancer:
      config += PARAM_LB_SVC.format(
          port=port, param_server_id=param_server, name_prefix=name_prefix)
    else:
      config += PARAM_SERVER_SVC.format(
          port=port, param_server_id=param_server, name_prefix=name_prefix)
    config += '---\n'

  return config


def WorkerClusterSpecString(num_workers,
                            num_param_servers,
                            port,
                            name_prefix):
  """Generates worker cluster spec."""
  return ClusterSpecString(num_workers, num_param_servers, port, name_prefix)


def ParamServerClusterSpecString(num_workers,
                                 num_param_servers,
                                 port,
                                 name_prefix):
  """Generates parameter server spec."""
  return ClusterSpecString(num_workers, num_param_servers, port,
                           name_prefix)


def ClusterSpecString(num_workers,
                      num_param_servers,
                      port,
                      name_prefix):
  """Generates general cluster spec."""
  spec = 'worker|'
  for worker in range(num_workers):
    spec += '%s-worker%d:%d' % (name_prefix, worker, port)
    if worker != num_workers-1:
      spec += ';'

  spec += ',ps|'
  for param_server in range(num_param_servers):
    spec += '%s-ps%d:%d' % (name_prefix, param_server, port)
    if param_server != num_param_servers-1:
      spec += ';'

  return spec

def GetCommonArgs(num_workers,
                  num_param_servers,
                  port,
                  name_prefix,
                  use_cluster_spec):
  """Get arguments common to both worker and ps jobs.

  Args:
    num_workers: number of workers.
    num_param_servers: number of ps servers.
    port: worker and ps port number.
    name_prefix: prefix to prepend to job names.
    use_cluster_spec: if true, pass the --cluster_spec argument.
      If false, pass the --worker_hosts and --ps_hosts arguments.

  Returns:
    A dictionary of argument names mapping to argument values.
  """
  common_args = {}
  if use_cluster_spec:
    common_args['cluster_spec'] = WorkerClusterSpecString(
        num_workers,
        num_param_servers,
        port,
        name_prefix)
  else:
    common_args['worker_hosts'] = WorkerHosts(num_workers, port, name_prefix)
    common_args['ps_hosts'] = PsHosts(num_param_servers, port, name_prefix)
  return common_args


def WorkerHosts(num_workers, port, name_prefix):
  """Returns a comma-separated list of worker host:port addresses."""
  worker_hosts = ['%s-worker%d:%d' % (name_prefix, i, port)
                  for i in range(num_workers)]
  return ','.join(worker_hosts)


def PsHosts(num_ps, port, name_prefix):
  """Returns a comma-separated list of ps host:port addresses."""
  ps_hosts = ['%s-ps%d:%d' % (name_prefix, i, port)
              for i in range(num_ps)]
  return ','.join(ps_hosts)
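To make the two addressing modes concrete, a small sketch whose values deliberately mirror the unit tests in the next file: use_cluster_spec=True renders a single --cluster_spec argument in the 'worker|host:port;...,ps|host:port;...' format produced by ClusterSpecString, use_cluster_spec=False renders --worker_hosts/--ps_hosts lists instead, and each env_vars entry becomes a {name: ..., value: ...} item in the pod's env list.

from tensorflow.tools.dist_test.scripts import k8s_tensorflow_lib

config = k8s_tensorflow_lib.GenerateConfig(
    num_workers=1, num_param_servers=1, port=5000,
    request_load_balancer=False, docker_image='test_image',
    name_prefix='abc', use_shared_volume=False,
    env_vars={'TF_DIST_BENCHMARK_NAME': 'bench'})
# Single cluster_spec argument plus the injected env var:
assert '"--cluster_spec=worker|abc-worker0:5000,ps|abc-ps0:5000"' in config
assert '{name: "TF_DIST_BENCHMARK_NAME", value: "bench"}' in config

config = k8s_tensorflow_lib.GenerateConfig(
    num_workers=1, num_param_servers=1, port=5000,
    request_load_balancer=False, docker_image='test_image',
    name_prefix='abc', use_shared_volume=False,
    use_cluster_spec=False)
# Explicit host lists instead of a cluster spec:
assert '"--worker_hosts=abc-worker0:5000"' in config
assert '"--ps_hosts=abc-ps0:5000"' in config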
tensorflow/tools/dist_test/scripts/k8s_tensorflow_test.py (new file, 132 lines)
@@ -0,0 +1,132 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.tools.dist_test.scripts.k8s_tensorflow_lib."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.platform import googletest
from tensorflow.tools.dist_test.scripts import k8s_tensorflow_lib


class K8sTensorflowTest(googletest.TestCase):

  def testGenerateConfig_LoadBalancer(self):
    # Use loadbalancer
    config = k8s_tensorflow_lib.GenerateConfig(
        num_workers=1,
        num_param_servers=1,
        port=5000,
        request_load_balancer=True,
        docker_image='test_image',
        name_prefix='abc',
        use_shared_volume=False)
    self.assertTrue('LoadBalancer' in config)

    # Don't use loadbalancer
    config = k8s_tensorflow_lib.GenerateConfig(
        num_workers=1,
        num_param_servers=1,
        port=5000,
        request_load_balancer=False,
        docker_image='test_image',
        name_prefix='abc',
        use_shared_volume=False)
    self.assertFalse('LoadBalancer' in config)

  def testGenerateConfig_SharedVolume(self):
    # Use shared directory
    config = k8s_tensorflow_lib.GenerateConfig(
        num_workers=1,
        num_param_servers=1,
        port=5000,
        request_load_balancer=False,
        docker_image='test_image',
        name_prefix='abc',
        use_shared_volume=True)
    self.assertTrue('/shared' in config)

    # Don't use shared directory
    config = k8s_tensorflow_lib.GenerateConfig(
        num_workers=1,
        num_param_servers=1,
        port=5000,
        request_load_balancer=False,
        docker_image='test_image',
        name_prefix='abc',
        use_shared_volume=False)
    self.assertFalse('/shared' in config)

  def testEnvVar(self):
    # Use loadbalancer
    config = k8s_tensorflow_lib.GenerateConfig(
        num_workers=1,
        num_param_servers=1,
        port=5000,
        request_load_balancer=True,
        docker_image='test_image',
        name_prefix='abc',
        use_shared_volume=False,
        env_vars={'test1': 'test1_value', 'test2': 'test2_value'})
    self.assertTrue('{name: "test1", value: "test1_value"}' in config)
    self.assertTrue('{name: "test2", value: "test2_value"}' in config)

  def testClusterSpec(self):
    # Use cluster_spec
    config = k8s_tensorflow_lib.GenerateConfig(
        num_workers=1,
        num_param_servers=1,
        port=5000,
        request_load_balancer=True,
        docker_image='test_image',
        name_prefix='abc',
        use_shared_volume=False,
        use_cluster_spec=True)
    self.assertFalse('worker_hosts' in config)
    self.assertFalse('ps_hosts' in config)
    self.assertTrue(
        '"--cluster_spec=worker|abc-worker0:5000,ps|abc-ps0:5000"' in config)

    # Don't use cluster_spec
    config = k8s_tensorflow_lib.GenerateConfig(
        num_workers=1,
        num_param_servers=1,
        port=5000,
        request_load_balancer=True,
        docker_image='test_image',
        name_prefix='abc',
        use_shared_volume=False,
        use_cluster_spec=False)
    self.assertFalse('cluster_spec' in config)
    self.assertTrue('"--worker_hosts=abc-worker0:5000"' in config)
    self.assertTrue('"--ps_hosts=abc-ps0:5000"' in config)

  def testWorkerHosts(self):
    self.assertEquals(
        'test_prefix-worker0:1234',
        k8s_tensorflow_lib.WorkerHosts(1, 1234, 'test_prefix'))
    self.assertEquals(
        'test_prefix-worker0:1234,test_prefix-worker1:1234',
        k8s_tensorflow_lib.WorkerHosts(2, 1234, 'test_prefix'))

  def testPsHosts(self):
    self.assertEquals(
        'test_prefix-ps0:1234,test_prefix-ps1:1234',
        k8s_tensorflow_lib.PsHosts(2, 1234, 'test_prefix'))


if __name__ == '__main__':
  googletest.main()