Split out most of k8s_tensorflow into a library and add a way to pass any
environment variables. Add a benchmark_util library that uses an environment variable to decide on a storage location. Change: 147890534
This commit is contained in:
parent
a6421c4dda
commit
b06281ba47
23
tensorflow/tools/dist_test/python/BUILD
Normal file
23
tensorflow/tools/dist_test/python/BUILD
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
# Python tools for running distributed benchmarks.
|
||||||
|
|
||||||
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
|
# Helper library for exporting distributed-benchmark results; depends on the
# TestResults proto and the TensorFlow platform layer (gfile).
py_library(
    name = "benchmark_util_lib",
    srcs = ["benchmark_util.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:platform",
    ],
)

# Unit tests for :benchmark_util_lib.
py_test(
    name = "benchmark_util_test",
    srcs = ["benchmark_util_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":benchmark_util_lib",
        "//tensorflow/python:client_testlib",
    ],
)
|
77
tensorflow/tools/dist_test/python/benchmark_util.py
Normal file
77
tensorflow/tools/dist_test/python/benchmark_util.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
"""Provides helper functions for distributed benchmarks running on Jenkins."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import calendar
|
||||||
|
from collections import namedtuple
|
||||||
|
import os
|
||||||
|
|
||||||
|
from google.protobuf import json_format
|
||||||
|
|
||||||
|
from tensorflow.core.util import test_log_pb2
|
||||||
|
from tensorflow.python.platform import gfile
|
||||||
|
|
||||||
|
|
||||||
|
# Environment variable naming the file where JSON results are written when no
# explicit output_file is passed to store_data_in_json().
_OUTPUT_FILE_ENV_VAR = 'TF_DIST_BENCHMARK_RESULTS_FILE'
# Environment variable holding the benchmark name recorded in TestResults.name.
_TEST_NAME_ENV_VAR = 'TF_DIST_BENCHMARK_NAME'


# Represents a single timing entry where
# - name is a string
# - timing is the latency to track (for e.g. mean time per iter)
# - iters is the number of iterations
TimingEntry = namedtuple(
    'TimingEntry', ['name', 'timing', 'iters'])
|
||||||
|
|
||||||
|
|
||||||
|
def store_data_in_json(timing_entries, start_time, output_file=None):
  """Stores benchmark results in JSON format.

  Args:
    timing_entries: list of TimingEntry objects.
    start_time: (datetime) start time of the test run.
    output_file: if specified, writes benchmark results to output_file.
      If not specified, writes results to the file specified by
      the TF_DIST_BENCHMARK_RESULTS_FILE environment variable.

  Raises:
    ValueError: when neither output_file is passed in nor
      TF_DIST_BENCHMARK_RESULTS_FILE is set.
  """
  # TestResults.start_time is seconds since epoch (UTC).
  test_result = test_log_pb2.TestResults(
      start_time=calendar.timegm(start_time.timetuple()))
  if not output_file:
    if _OUTPUT_FILE_ENV_VAR not in os.environ:
      raise ValueError('Could not determine location to store results at.')
    output_file = os.environ[_OUTPUT_FILE_ENV_VAR]

  with gfile.Open(output_file, 'wb') as jsonfile:
    # Bug fix: the original read os.environ['POD_NAME_PREFIX'] after checking
    # for _TEST_NAME_ENV_VAR, which raised KeyError whenever only
    # TF_DIST_BENCHMARK_NAME was set. Read the variable that was checked.
    if _TEST_NAME_ENV_VAR in os.environ:
      test_result.name = os.environ[_TEST_NAME_ENV_VAR]
    else:
      test_result.name = 'TestBenchmark'

    for timing_entry in timing_entries:
      test_result.entries.entry.add(
          name=timing_entry.name,
          iters=timing_entry.iters,
          wall_time=timing_entry.timing
      )
    json_test_results = json_format.MessageToJson(test_result)
    jsonfile.write(json_test_results)
|
58
tensorflow/tools/dist_test/python/benchmark_util_test.py
Normal file
58
tensorflow/tools/dist_test/python/benchmark_util_test.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for tensorflow.tools.dist_test.python.benchmark_util."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
from tensorflow.python.platform import googletest
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
from tensorflow.tools.dist_test.python import benchmark_util
|
||||||
|
|
||||||
|
|
||||||
|
class BenchmarkUtilTest(googletest.TestCase):
  """Tests for benchmark_util.store_data_in_json."""

  def _load_results(self, output_file):
    """Reads and parses the JSON results file.

    Fix: the original used open(output_file, 'r').read() without ever closing
    the file handle; use a context manager so the handle is released
    deterministically.
    """
    with open(output_file, 'r') as f:
      return json.loads(f.read())

  def testStoreDataWithNoEntries(self):
    output_file = os.path.join(test.get_temp_dir(), 'test_output1.json')
    timing_entries = []
    benchmark_util.store_data_in_json(
        timing_entries, datetime.date(2017, 1, 1), output_file)
    json_output = self._load_results(output_file)
    # assertEqual: assertEquals is a deprecated alias.
    self.assertEqual('TestBenchmark', json_output['name'])
    self.assertEqual(u'1483228800', json_output['startTime'])

  def testStoreDataWithEntries(self):
    output_file = os.path.join(test.get_temp_dir(), 'test_output2.json')
    timing_entries = [
        benchmark_util.TimingEntry('test', 0.1, 1)]
    benchmark_util.store_data_in_json(
        timing_entries, datetime.date(2017, 1, 1), output_file)
    json_output = self._load_results(output_file)

    self.assertEqual(1, len(json_output['entries']['entry']))
    self.assertEqual('test', json_output['entries']['entry'][0]['name'])
    self.assertEqual(0.1, json_output['entries']['entry'][0]['wallTime'])
    self.assertEqual(u'1', json_output['entries']['entry'][0]['iters'])
    self.assertEqual(u'1483228800', json_output['startTime'])
    self.assertEqual('TestBenchmark', json_output['name'])


if __name__ == '__main__':
  googletest.main()
|
21
tensorflow/tools/dist_test/scripts/BUILD
Normal file
21
tensorflow/tools/dist_test/scripts/BUILD
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
# Tools for running distributed benchmarks.
|
||||||
|
|
||||||
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
|
# The CLI entry point stays a plain exported script so it can run on machines
# without a TensorFlow install.
exports_files(["k8s_tensorflow.py"])

# Library that generates Kubernetes YAML configs for distributed TF jobs.
py_library(
    name = "k8s_tensorflow_lib",
    srcs = ["k8s_tensorflow_lib.py"],
    srcs_version = "PY2AND3",
)

# Unit tests for :k8s_tensorflow_lib.
py_test(
    name = "k8s_tensorflow_test",
    srcs = ["k8s_tensorflow_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":k8s_tensorflow_lib",
        "//tensorflow/python:client_testlib",
    ],
)
|
@ -25,6 +25,8 @@ from __future__ import print_function
|
|||||||
import argparse
|
import argparse
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
import k8s_tensorflow_lib
|
||||||
|
|
||||||
# Note: It is intentional that we do not import tensorflow in this script. The
|
# Note: It is intentional that we do not import tensorflow in this script. The
|
||||||
# machine that launches a TensorFlow k8s cluster does not have to have the
|
# machine that launches a TensorFlow k8s cluster does not have to have the
|
||||||
# Python package of TensorFlow installed on it.
|
# Python package of TensorFlow installed on it.
|
||||||
@ -33,125 +35,6 @@ import sys
|
|||||||
DEFAULT_DOCKER_IMAGE = 'tensorflow/tf_grpc_test_server'
|
DEFAULT_DOCKER_IMAGE = 'tensorflow/tf_grpc_test_server'
|
||||||
DEFAULT_PORT = 2222
|
DEFAULT_PORT = 2222
|
||||||
|
|
||||||
# TODO(cais): Consider adding resource requests/limits to the pods.
|
|
||||||
|
|
||||||
# Worker pods will mount host volume /shared, as a convenient way to create
|
|
||||||
# shared storage among workers during local tests.
|
|
||||||
WORKER_RC = (
|
|
||||||
"""apiVersion: v1
|
|
||||||
kind: ReplicationController
|
|
||||||
metadata:
|
|
||||||
name: {name_prefix}-worker{worker_id}
|
|
||||||
spec:
|
|
||||||
replicas: 1
|
|
||||||
template:
|
|
||||||
metadata:
|
|
||||||
labels:
|
|
||||||
tf-worker: "{worker_id}"
|
|
||||||
name-prefix: "{name_prefix}"
|
|
||||||
job: "worker"
|
|
||||||
spec:
|
|
||||||
containers:
|
|
||||||
- name: tf-worker{worker_id}
|
|
||||||
image: {docker_image}
|
|
||||||
args:
|
|
||||||
- --cluster_spec={cluster_spec}
|
|
||||||
- --job_name=worker
|
|
||||||
- --task_id={worker_id}
|
|
||||||
ports:
|
|
||||||
- containerPort: {port}
|
|
||||||
env:
|
|
||||||
- name: POD_NAME_PREFIX
|
|
||||||
value: {name_prefix}
|
|
||||||
volumeMounts: [{volume_mounts}]
|
|
||||||
volumes: [{volumes}]
|
|
||||||
""")
|
|
||||||
WORKER_SVC = (
|
|
||||||
"""apiVersion: v1
|
|
||||||
kind: Service
|
|
||||||
metadata:
|
|
||||||
name: {name_prefix}-worker{worker_id}
|
|
||||||
labels:
|
|
||||||
tf-worker: "{worker_id}"
|
|
||||||
spec:
|
|
||||||
ports:
|
|
||||||
- port: {port}
|
|
||||||
targetPort: {port}
|
|
||||||
selector:
|
|
||||||
tf-worker: "{worker_id}"
|
|
||||||
""")
|
|
||||||
WORKER_LB_SVC = (
|
|
||||||
"""apiVersion: v1
|
|
||||||
kind: Service
|
|
||||||
metadata:
|
|
||||||
name: {name_prefix}-worker{worker_id}
|
|
||||||
labels:
|
|
||||||
tf-worker: "{worker_id}"
|
|
||||||
spec:
|
|
||||||
type: LoadBalancer
|
|
||||||
ports:
|
|
||||||
- port: {port}
|
|
||||||
selector:
|
|
||||||
tf-worker: "{worker_id}"
|
|
||||||
""")
|
|
||||||
PARAM_SERVER_RC = (
|
|
||||||
"""apiVersion: v1
|
|
||||||
kind: ReplicationController
|
|
||||||
metadata:
|
|
||||||
name: {name_prefix}-ps{param_server_id}
|
|
||||||
spec:
|
|
||||||
replicas: 1
|
|
||||||
template:
|
|
||||||
metadata:
|
|
||||||
labels:
|
|
||||||
tf-ps: "{param_server_id}"
|
|
||||||
name-prefix: "{name_prefix}"
|
|
||||||
job: "ps"
|
|
||||||
spec:
|
|
||||||
containers:
|
|
||||||
- name: tf-ps{param_server_id}
|
|
||||||
image: {docker_image}
|
|
||||||
args:
|
|
||||||
- --cluster_spec={cluster_spec}
|
|
||||||
- --job_name=ps
|
|
||||||
- --task_id={param_server_id}
|
|
||||||
ports:
|
|
||||||
- containerPort: {port}
|
|
||||||
env:
|
|
||||||
- name: POD_NAME_PREFIX
|
|
||||||
value: {name_prefix}
|
|
||||||
volumeMounts: [{volume_mounts}]
|
|
||||||
volumes: [{volumes}]
|
|
||||||
""")
|
|
||||||
PARAM_SERVER_SVC = (
|
|
||||||
"""apiVersion: v1
|
|
||||||
kind: Service
|
|
||||||
metadata:
|
|
||||||
name: {name_prefix}-ps{param_server_id}
|
|
||||||
labels:
|
|
||||||
tf-ps: "{param_server_id}"
|
|
||||||
spec:
|
|
||||||
ports:
|
|
||||||
- port: {port}
|
|
||||||
selector:
|
|
||||||
tf-ps: "{param_server_id}"
|
|
||||||
""")
|
|
||||||
PARAM_LB_SVC = ("""apiVersion: v1
|
|
||||||
kind: Service
|
|
||||||
metadata:
|
|
||||||
name: {name_prefix}-ps{param_server_id}
|
|
||||||
labels:
|
|
||||||
tf-ps: "{param_server_id}"
|
|
||||||
spec:
|
|
||||||
type: LoadBalancer
|
|
||||||
ports:
|
|
||||||
- port: {port}
|
|
||||||
selector:
|
|
||||||
tf-ps: "{param_server_id}"
|
|
||||||
""")
|
|
||||||
VOLUME_MOUNTS = '{name: shared, mountPath: /shared}'
|
|
||||||
VOLUMES = '{name: shared, hostPath: {path: /shared}}'
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""Do arg parsing."""
|
"""Do arg parsing."""
|
||||||
@ -204,108 +87,17 @@ def main():
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
# Generate contents of yaml config
|
# Generate contents of yaml config
|
||||||
yaml_config = GenerateConfig(args.num_workers,
|
yaml_config = k8s_tensorflow_lib.GenerateConfig(
|
||||||
args.num_parameter_servers,
|
args.num_workers,
|
||||||
args.grpc_port,
|
args.num_parameter_servers,
|
||||||
args.request_load_balancer,
|
args.grpc_port,
|
||||||
args.docker_image,
|
args.request_load_balancer,
|
||||||
args.name_prefix,
|
args.docker_image,
|
||||||
args.use_shared_volume)
|
args.name_prefix,
|
||||||
|
env_vars=None,
|
||||||
|
use_shared_volume=args.use_shared_volume)
|
||||||
print(yaml_config) # pylint: disable=superfluous-parens
|
print(yaml_config) # pylint: disable=superfluous-parens
|
||||||
|
|
||||||
|
|
||||||
def GenerateConfig(num_workers,
|
|
||||||
num_param_servers,
|
|
||||||
port,
|
|
||||||
request_load_balancer,
|
|
||||||
docker_image,
|
|
||||||
name_prefix,
|
|
||||||
use_shared_volume):
|
|
||||||
"""Generate configuration strings."""
|
|
||||||
config = ''
|
|
||||||
for worker in range(num_workers):
|
|
||||||
config += WORKER_RC.format(
|
|
||||||
port=port,
|
|
||||||
worker_id=worker,
|
|
||||||
docker_image=docker_image,
|
|
||||||
name_prefix=name_prefix,
|
|
||||||
volume_mounts=VOLUME_MOUNTS if use_shared_volume else '',
|
|
||||||
volumes=VOLUMES if use_shared_volume else '',
|
|
||||||
cluster_spec=WorkerClusterSpecString(num_workers,
|
|
||||||
num_param_servers,
|
|
||||||
port,
|
|
||||||
name_prefix))
|
|
||||||
config += '---\n'
|
|
||||||
if request_load_balancer:
|
|
||||||
config += WORKER_LB_SVC.format(port=port,
|
|
||||||
worker_id=worker,
|
|
||||||
name_prefix=name_prefix)
|
|
||||||
else:
|
|
||||||
config += WORKER_SVC.format(port=port,
|
|
||||||
worker_id=worker,
|
|
||||||
name_prefix=name_prefix)
|
|
||||||
config += '---\n'
|
|
||||||
|
|
||||||
for param_server in range(num_param_servers):
|
|
||||||
config += PARAM_SERVER_RC.format(
|
|
||||||
port=port,
|
|
||||||
param_server_id=param_server,
|
|
||||||
docker_image=docker_image,
|
|
||||||
name_prefix=name_prefix,
|
|
||||||
volume_mounts=VOLUME_MOUNTS if use_shared_volume else '',
|
|
||||||
volumes=VOLUMES if use_shared_volume else '',
|
|
||||||
cluster_spec=ParamServerClusterSpecString(num_workers,
|
|
||||||
num_param_servers,
|
|
||||||
port,
|
|
||||||
name_prefix))
|
|
||||||
config += '---\n'
|
|
||||||
if request_load_balancer:
|
|
||||||
config += PARAM_LB_SVC.format(
|
|
||||||
port=port, param_server_id=param_server, name_prefix=name_prefix)
|
|
||||||
else:
|
|
||||||
config += PARAM_SERVER_SVC.format(
|
|
||||||
port=port, param_server_id=param_server, name_prefix=name_prefix)
|
|
||||||
config += '---\n'
|
|
||||||
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
def WorkerClusterSpecString(num_workers,
|
|
||||||
num_param_servers,
|
|
||||||
port,
|
|
||||||
name_prefix):
|
|
||||||
"""Generates worker cluster spec."""
|
|
||||||
return ClusterSpecString(num_workers, num_param_servers, port, name_prefix)
|
|
||||||
|
|
||||||
|
|
||||||
def ParamServerClusterSpecString(num_workers,
|
|
||||||
num_param_servers,
|
|
||||||
port,
|
|
||||||
name_prefix):
|
|
||||||
"""Generates parameter server spec."""
|
|
||||||
return ClusterSpecString(num_workers, num_param_servers, port,
|
|
||||||
name_prefix)
|
|
||||||
|
|
||||||
|
|
||||||
def ClusterSpecString(num_workers,
|
|
||||||
num_param_servers,
|
|
||||||
port,
|
|
||||||
name_prefix):
|
|
||||||
"""Generates general cluster spec."""
|
|
||||||
spec = 'worker|'
|
|
||||||
for worker in range(num_workers):
|
|
||||||
spec += '%s-worker%d:%d' % (name_prefix, worker, port)
|
|
||||||
if worker != num_workers-1:
|
|
||||||
spec += ';'
|
|
||||||
|
|
||||||
spec += ',ps|'
|
|
||||||
for param_server in range(num_param_servers):
|
|
||||||
spec += '%s-ps%d:%d' % (name_prefix, param_server, port)
|
|
||||||
if param_server != num_param_servers-1:
|
|
||||||
spec += ';'
|
|
||||||
|
|
||||||
return spec
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
309
tensorflow/tools/dist_test/scripts/k8s_tensorflow_lib.py
Normal file
309
tensorflow/tools/dist_test/scripts/k8s_tensorflow_lib.py
Normal file
@ -0,0 +1,309 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
"""Generates YAML configuration files for distributed TensorFlow workers.
|
||||||
|
|
||||||
|
The workers will be run in a Kubernetes (k8s) container cluster.
|
||||||
|
"""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
# Note: It is intentional that we do not import tensorflow in this script. The
|
||||||
|
# machine that launches a TensorFlow k8s cluster does not have to have the
|
||||||
|
# Python package of TensorFlow installed on it.
|
||||||
|
|
||||||
|
# TODO(cais): Consider adding resource requests/limits to the pods.
|
||||||
|
|
||||||
|
# Worker pods will mount host volume /shared, as a convenient way to create
|
||||||
|
# shared storage among workers during local tests.
|
||||||
|
WORKER_RC = (
|
||||||
|
"""apiVersion: v1
|
||||||
|
kind: ReplicationController
|
||||||
|
metadata:
|
||||||
|
name: {name_prefix}-worker{worker_id}
|
||||||
|
spec:
|
||||||
|
replicas: 1
|
||||||
|
template:
|
||||||
|
metadata:
|
||||||
|
labels:
|
||||||
|
tf-worker: "{worker_id}"
|
||||||
|
name-prefix: "{name_prefix}"
|
||||||
|
job: "worker"
|
||||||
|
spec:
|
||||||
|
containers:
|
||||||
|
- name: tf-worker{worker_id}
|
||||||
|
image: {docker_image}
|
||||||
|
args: [{args}]
|
||||||
|
ports:
|
||||||
|
- containerPort: {port}
|
||||||
|
env: [{env_vars}]
|
||||||
|
volumeMounts: [{volume_mounts}]
|
||||||
|
volumes: [{volumes}]
|
||||||
|
""")
|
||||||
|
WORKER_SVC = (
|
||||||
|
"""apiVersion: v1
|
||||||
|
kind: Service
|
||||||
|
metadata:
|
||||||
|
name: {name_prefix}-worker{worker_id}
|
||||||
|
labels:
|
||||||
|
tf-worker: "{worker_id}"
|
||||||
|
spec:
|
||||||
|
ports:
|
||||||
|
- port: {port}
|
||||||
|
targetPort: {port}
|
||||||
|
selector:
|
||||||
|
tf-worker: "{worker_id}"
|
||||||
|
""")
|
||||||
|
WORKER_LB_SVC = (
|
||||||
|
"""apiVersion: v1
|
||||||
|
kind: Service
|
||||||
|
metadata:
|
||||||
|
name: {name_prefix}-worker{worker_id}
|
||||||
|
labels:
|
||||||
|
tf-worker: "{worker_id}"
|
||||||
|
spec:
|
||||||
|
type: LoadBalancer
|
||||||
|
ports:
|
||||||
|
- port: {port}
|
||||||
|
selector:
|
||||||
|
tf-worker: "{worker_id}"
|
||||||
|
""")
|
||||||
|
PARAM_SERVER_RC = (
|
||||||
|
"""apiVersion: v1
|
||||||
|
kind: ReplicationController
|
||||||
|
metadata:
|
||||||
|
name: {name_prefix}-ps{param_server_id}
|
||||||
|
spec:
|
||||||
|
replicas: 1
|
||||||
|
template:
|
||||||
|
metadata:
|
||||||
|
labels:
|
||||||
|
tf-ps: "{param_server_id}"
|
||||||
|
name-prefix: "{name_prefix}"
|
||||||
|
job: "ps"
|
||||||
|
spec:
|
||||||
|
containers:
|
||||||
|
- name: tf-ps{param_server_id}
|
||||||
|
image: {docker_image}
|
||||||
|
args: [{args}]
|
||||||
|
ports:
|
||||||
|
- containerPort: {port}
|
||||||
|
env: [{env_vars}]
|
||||||
|
volumeMounts: [{volume_mounts}]
|
||||||
|
volumes: [{volumes}]
|
||||||
|
""")
|
||||||
|
PARAM_SERVER_SVC = (
|
||||||
|
"""apiVersion: v1
|
||||||
|
kind: Service
|
||||||
|
metadata:
|
||||||
|
name: {name_prefix}-ps{param_server_id}
|
||||||
|
labels:
|
||||||
|
tf-ps: "{param_server_id}"
|
||||||
|
spec:
|
||||||
|
ports:
|
||||||
|
- port: {port}
|
||||||
|
selector:
|
||||||
|
tf-ps: "{param_server_id}"
|
||||||
|
""")
|
||||||
|
PARAM_LB_SVC = ("""apiVersion: v1
|
||||||
|
kind: Service
|
||||||
|
metadata:
|
||||||
|
name: {name_prefix}-ps{param_server_id}
|
||||||
|
labels:
|
||||||
|
tf-ps: "{param_server_id}"
|
||||||
|
spec:
|
||||||
|
type: LoadBalancer
|
||||||
|
ports:
|
||||||
|
- port: {port}
|
||||||
|
selector:
|
||||||
|
tf-ps: "{param_server_id}"
|
||||||
|
""")
|
||||||
|
VOLUME_MOUNTS = '{name: shared, mountPath: /shared}'
|
||||||
|
VOLUMES = '{name: shared, hostPath: {path: /shared}}'
|
||||||
|
_ENV_VAR_TEMPLATE = '{name: "%s", value: "%s"}'
|
||||||
|
_ARG_TEMPLATE = '"--%s=%s"'
|
||||||
|
|
||||||
|
|
||||||
|
def GenerateConfig(num_workers,
                   num_param_servers,
                   port,
                   request_load_balancer,
                   docker_image,
                   name_prefix,
                   env_vars=None,
                   use_shared_volume=True,
                   use_cluster_spec=True):
  """Generate configuration strings.

  Args:
    num_workers: number of worker jobs.
    num_param_servers: number of ps server jobs.
    port: GRPC server port.
    request_load_balancer: request worker0 to be exposed on a public IP
      address via an external load balancer.
    docker_image: docker image to use.
    name_prefix: name to prepend to pod job names.
    env_vars: dictionary of environment variables to set.
    use_shared_volume: whether to add hostPath to /shared directory
      to the kubernetes config.
    use_cluster_spec: if true, pass --cluster_spec to worker and ps jobs.
      If false, pass --worker_hosts and --ps_hosts to worker and ps jobs.

  Returns:
    Kubernetes yaml config.
  """
  if env_vars is None:
    env_vars = {}
  # Rendered once; the same env block is injected into every worker/ps pod.
  env_str = ', '.join([_ENV_VAR_TEMPLATE % (name, value)
                       for name, value in env_vars.items()])
  config = ''
  # Args shared by worker and ps containers (cluster spec or host lists).
  common_args = GetCommonArgs(
      num_workers, num_param_servers, port, name_prefix, use_cluster_spec)
  for worker in range(num_workers):
    worker_args = {
        'job_name': 'worker',
        'task_id': worker
    }
    worker_args.update(common_args)
    arg_str = ', '.join([_ARG_TEMPLATE % (name, value)
                         for name, value in worker_args.items()])
    config += WORKER_RC.format(
        port=port,
        worker_id=worker,
        docker_image=docker_image,
        name_prefix=name_prefix,
        volume_mounts=VOLUME_MOUNTS if use_shared_volume else '',
        volumes=VOLUMES if use_shared_volume else '',
        args=arg_str,
        env_vars=env_str)
    # '---' separates documents in a multi-document YAML stream.
    config += '---\n'
    if request_load_balancer:
      config += WORKER_LB_SVC.format(port=port,
                                     worker_id=worker,
                                     name_prefix=name_prefix)
    else:
      config += WORKER_SVC.format(port=port,
                                  worker_id=worker,
                                  name_prefix=name_prefix)
    config += '---\n'

  for param_server in range(num_param_servers):
    ps_args = {
        'job_name': 'ps',
        'task_id': param_server
    }
    ps_args.update(common_args)
    arg_str = ', '.join([_ARG_TEMPLATE % (name, value)
                         for name, value in ps_args.items()])
    config += PARAM_SERVER_RC.format(
        port=port,
        param_server_id=param_server,
        docker_image=docker_image,
        name_prefix=name_prefix,
        volume_mounts=VOLUME_MOUNTS if use_shared_volume else '',
        volumes=VOLUMES if use_shared_volume else '',
        args=arg_str,
        env_vars=env_str)
    config += '---\n'
    if request_load_balancer:
      config += PARAM_LB_SVC.format(
          port=port, param_server_id=param_server, name_prefix=name_prefix)
    else:
      config += PARAM_SERVER_SVC.format(
          port=port, param_server_id=param_server, name_prefix=name_prefix)
    config += '---\n'

  return config
|
||||||
|
|
||||||
|
|
||||||
|
def WorkerClusterSpecString(num_workers, num_param_servers, port, name_prefix):
  """Generates the cluster spec string handed to worker jobs.

  Workers and parameter servers currently share one combined spec, so this is
  a thin alias over ClusterSpecString.
  """
  return ClusterSpecString(num_workers, num_param_servers, port, name_prefix)
|
||||||
|
|
||||||
|
|
||||||
|
def ParamServerClusterSpecString(num_workers, num_param_servers, port,
                                 name_prefix):
  """Generates the cluster spec string handed to parameter-server jobs.

  Identical to the worker spec today; kept as a separate entry point so the
  two job types can diverge later without touching callers.
  """
  return ClusterSpecString(num_workers, num_param_servers, port, name_prefix)
|
||||||
|
|
||||||
|
|
||||||
|
def ClusterSpecString(num_workers,
                      num_param_servers,
                      port,
                      name_prefix):
  """Generates general cluster spec.

  Format: 'worker|host0:port;host1:port,ps|host0:port;...' — jobs are
  comma-separated, tasks within a job are semicolon-separated.
  """
  worker_part = ';'.join(
      '%s-worker%d:%d' % (name_prefix, task, port)
      for task in range(num_workers))
  ps_part = ';'.join(
      '%s-ps%d:%d' % (name_prefix, task, port)
      for task in range(num_param_servers))
  return 'worker|%s,ps|%s' % (worker_part, ps_part)
|
||||||
|
|
||||||
|
|
||||||
|
def GetCommonArgs(num_workers,
                  num_param_servers,
                  port,
                  name_prefix,
                  use_cluster_spec):
  """Get arguments common to both worker and ps jobs.

  Args:
    num_workers: number of workers.
    num_param_servers: number of ps servers.
    port: worker and ps port number.
    name_prefix: prefix to prepend to job names.
    use_cluster_spec: if true, pass --cluster_spec argument.
      If false, parse --worker_hosts and --ps_hosts arguments.

  Returns:
    A dictionary of argument names mapping to argument values.
  """
  if use_cluster_spec:
    # Single combined spec covering both job types.
    return {
        'cluster_spec': WorkerClusterSpecString(
            num_workers, num_param_servers, port, name_prefix),
    }
  # Separate host lists for servers that take --worker_hosts / --ps_hosts.
  return {
      'worker_hosts': WorkerHosts(num_workers, port, name_prefix),
      'ps_hosts': PsHosts(num_param_servers, port, name_prefix),
  }
|
||||||
|
|
||||||
|
|
||||||
|
def WorkerHosts(num_workers, port, name_prefix):
  """Returns a comma-separated list of worker host:port addresses."""
  hosts = []
  for task_index in range(num_workers):
    hosts.append('%s-worker%d:%d' % (name_prefix, task_index, port))
  return ','.join(hosts)
|
||||||
|
|
||||||
|
|
||||||
|
def PsHosts(num_ps, port, name_prefix):
  """Returns a comma-separated list of ps host:port addresses."""
  hosts = []
  for task_index in range(num_ps):
    hosts.append('%s-ps%d:%d' % (name_prefix, task_index, port))
  return ','.join(hosts)
|
132
tensorflow/tools/dist_test/scripts/k8s_tensorflow_test.py
Normal file
132
tensorflow/tools/dist_test/scripts/k8s_tensorflow_test.py
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for tensorflow.tools.dist_test.scripts.k8s_tensorflow_lib."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.platform import googletest
|
||||||
|
from tensorflow.tools.dist_test.scripts import k8s_tensorflow_lib
|
||||||
|
|
||||||
|
|
||||||
|
class K8sTensorflowTest(googletest.TestCase):
  """Tests for k8s_tensorflow_lib config generation and host-list helpers.

  Modernized to use assertIn/assertNotIn instead of assertTrue(x in y) for
  informative failure messages, and assertEqual instead of the deprecated
  assertEquals alias.
  """

  def testGenerateConfig_LoadBalancer(self):
    # Use loadbalancer
    config = k8s_tensorflow_lib.GenerateConfig(
        num_workers=1,
        num_param_servers=1,
        port=5000,
        request_load_balancer=True,
        docker_image='test_image',
        name_prefix='abc',
        use_shared_volume=False)
    self.assertIn('LoadBalancer', config)

    # Don't use loadbalancer
    config = k8s_tensorflow_lib.GenerateConfig(
        num_workers=1,
        num_param_servers=1,
        port=5000,
        request_load_balancer=False,
        docker_image='test_image',
        name_prefix='abc',
        use_shared_volume=False)
    self.assertNotIn('LoadBalancer', config)

  def testGenerateConfig_SharedVolume(self):
    # Use shared directory
    config = k8s_tensorflow_lib.GenerateConfig(
        num_workers=1,
        num_param_servers=1,
        port=5000,
        request_load_balancer=False,
        docker_image='test_image',
        name_prefix='abc',
        use_shared_volume=True)
    self.assertIn('/shared', config)

    # Don't use shared directory
    config = k8s_tensorflow_lib.GenerateConfig(
        num_workers=1,
        num_param_servers=1,
        port=5000,
        request_load_balancer=False,
        docker_image='test_image',
        name_prefix='abc',
        use_shared_volume=False)
    self.assertNotIn('/shared', config)

  def testEnvVar(self):
    # Use loadbalancer
    config = k8s_tensorflow_lib.GenerateConfig(
        num_workers=1,
        num_param_servers=1,
        port=5000,
        request_load_balancer=True,
        docker_image='test_image',
        name_prefix='abc',
        use_shared_volume=False,
        env_vars={'test1': 'test1_value', 'test2': 'test2_value'})
    self.assertIn('{name: "test1", value: "test1_value"}', config)
    self.assertIn('{name: "test2", value: "test2_value"}', config)

  def testClusterSpec(self):
    # Use cluster_spec
    config = k8s_tensorflow_lib.GenerateConfig(
        num_workers=1,
        num_param_servers=1,
        port=5000,
        request_load_balancer=True,
        docker_image='test_image',
        name_prefix='abc',
        use_shared_volume=False,
        use_cluster_spec=True)
    self.assertNotIn('worker_hosts', config)
    self.assertNotIn('ps_hosts', config)
    self.assertIn(
        '"--cluster_spec=worker|abc-worker0:5000,ps|abc-ps0:5000"', config)

    # Don't use cluster_spec
    config = k8s_tensorflow_lib.GenerateConfig(
        num_workers=1,
        num_param_servers=1,
        port=5000,
        request_load_balancer=True,
        docker_image='test_image',
        name_prefix='abc',
        use_shared_volume=False,
        use_cluster_spec=False)
    self.assertNotIn('cluster_spec', config)
    self.assertIn('"--worker_hosts=abc-worker0:5000"', config)
    self.assertIn('"--ps_hosts=abc-ps0:5000"', config)

  def testWorkerHosts(self):
    self.assertEqual(
        'test_prefix-worker0:1234',
        k8s_tensorflow_lib.WorkerHosts(1, 1234, 'test_prefix'))
    self.assertEqual(
        'test_prefix-worker0:1234,test_prefix-worker1:1234',
        k8s_tensorflow_lib.WorkerHosts(2, 1234, 'test_prefix'))

  def testPsHosts(self):
    self.assertEqual(
        'test_prefix-ps0:1234,test_prefix-ps1:1234',
        k8s_tensorflow_lib.PsHosts(2, 1234, 'test_prefix'))


if __name__ == '__main__':
  googletest.main()
|
Loading…
x
Reference in New Issue
Block a user