Add unit test + benchmarks for grpc miniclusters created on localhost.
Also added client/net_lib to get access to PickUnusedPortOrDie. Change: 129250139
This commit is contained in:
parent
a0812ee71d
commit
9a137975be
@ -173,6 +173,7 @@ cc_library(
|
||||
"platform/logging.h",
|
||||
"platform/macros.h",
|
||||
"platform/mem.h",
|
||||
"platform/net.h",
|
||||
"platform/mutex.h",
|
||||
"platform/protobuf.h", # TODO(josh11b): make internal
|
||||
"platform/regexp.h",
|
||||
@ -1039,6 +1040,7 @@ filegroup(
|
||||
"platform/macros.h",
|
||||
"platform/mem.h",
|
||||
"platform/mutex.h",
|
||||
"platform/net.h",
|
||||
"platform/platform.h",
|
||||
"platform/protobuf.h",
|
||||
"platform/strong_hash.h",
|
||||
@ -1319,6 +1321,7 @@ tf_cc_tests(
|
||||
"platform/fingerprint_test.cc",
|
||||
"platform/integral_types_test.cc",
|
||||
"platform/logging_test.cc",
|
||||
"platform/net_test.cc",
|
||||
"platform/port_test.cc",
|
||||
],
|
||||
deps = [
|
||||
|
27
tensorflow/core/platform/net.h
Normal file
27
tensorflow/core/platform/net.h
Normal file
@ -0,0 +1,27 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_PLATFORM_NET_H_
|
||||
#define TENSORFLOW_PLATFORM_NET_H_
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
int PickUnusedPortOrDie();
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_PLATFORM_NET_H_
|
34
tensorflow/core/platform/net_test.cc
Normal file
34
tensorflow/core/platform/net_test.cc
Normal file
@ -0,0 +1,34 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/platform/net.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
TEST(Net, PickUnusedPortOrDie) {
|
||||
int port0 = PickUnusedPortOrDie();
|
||||
int port1 = PickUnusedPortOrDie();
|
||||
CHECK_GE(port0, 0);
|
||||
CHECK_LT(port0, 65536);
|
||||
CHECK_GE(port1, 0);
|
||||
CHECK_LT(port1, 65536);
|
||||
CHECK_NE(port0, port1);
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
129
tensorflow/core/platform/posix/net.cc
Normal file
129
tensorflow/core/platform/posix/net.cc
Normal file
@ -0,0 +1,129 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/platform/net.h"
|
||||
|
||||
#include <cstdlib>
|
||||
#include <unordered_set>
|
||||
|
||||
#include <netinet/in.h>
|
||||
#include <sys/socket.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
namespace {
|
||||
bool IsPortAvailable(int* port, bool is_tcp) {
|
||||
const int protocol = is_tcp ? IPPROTO_TCP : 0;
|
||||
const int fd = socket(AF_INET, is_tcp ? SOCK_STREAM : SOCK_DGRAM, protocol);
|
||||
|
||||
struct sockaddr_in addr;
|
||||
socklen_t addr_len = sizeof(addr);
|
||||
int actual_port;
|
||||
|
||||
CHECK_GE(*port, 0);
|
||||
CHECK_LE(*port, 65535);
|
||||
if (fd < 0) {
|
||||
LOG(ERROR) << "socket() failed: " << strerror(errno);
|
||||
return false;
|
||||
}
|
||||
|
||||
// SO_REUSEADDR lets us start up a server immediately after it exists.
|
||||
int one = 1;
|
||||
if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)) < 0) {
|
||||
LOG(ERROR) << "setsockopt() failed: " << strerror(errno);
|
||||
close(fd);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Try binding to port.
|
||||
addr.sin_family = AF_INET;
|
||||
addr.sin_addr.s_addr = INADDR_ANY;
|
||||
addr.sin_port = htons((uint16_t)*port);
|
||||
if (bind(fd, (struct sockaddr*)&addr, sizeof(addr)) < 0) {
|
||||
LOG(WARNING) << "bind(port=" << *port << ") failed: " << strerror(errno);
|
||||
close(fd);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Get the bound port number.
|
||||
if (getsockname(fd, (struct sockaddr*)&addr, &addr_len) < 0) {
|
||||
LOG(WARNING) << "getsockname() failed: " << strerror(errno);
|
||||
close(fd);
|
||||
return false;
|
||||
}
|
||||
CHECK_LE(addr_len, sizeof(addr));
|
||||
actual_port = ntohs(addr.sin_port);
|
||||
CHECK_GT(actual_port, 0);
|
||||
if (*port == 0) {
|
||||
*port = actual_port;
|
||||
} else {
|
||||
CHECK_EQ(*port, actual_port);
|
||||
}
|
||||
close(fd);
|
||||
return true;
|
||||
}
|
||||
|
||||
const int kNumRandomPortsToPick = 100;
|
||||
const int kMaximumTrials = 1000;
|
||||
|
||||
} // namespace
|
||||
|
||||
int PickUnusedPortOrDie() {
|
||||
static std::unordered_set<int> chosen_ports;
|
||||
|
||||
// Type of port to first pick in the next iteration.
|
||||
bool is_tcp = true;
|
||||
int trial = 0;
|
||||
while (true) {
|
||||
int port;
|
||||
trial++;
|
||||
CHECK_LE(trial, kMaximumTrials)
|
||||
<< "Failed to pick an unused port for testing.";
|
||||
if (trial == 1) {
|
||||
port = getpid() % (65536 - 30000) + 30000;
|
||||
} else if (trial <= kNumRandomPortsToPick) {
|
||||
port = rand() % (65536 - 30000) + 30000;
|
||||
} else {
|
||||
port = 0;
|
||||
}
|
||||
|
||||
if (chosen_ports.find(port) != chosen_ports.end()) {
|
||||
continue;
|
||||
}
|
||||
if (!IsPortAvailable(&port, is_tcp)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
CHECK_GT(port, 0);
|
||||
if (!IsPortAvailable(&port, !is_tcp)) {
|
||||
is_tcp = !is_tcp;
|
||||
continue;
|
||||
}
|
||||
|
||||
chosen_ports.insert(port);
|
||||
return port;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
@ -13,16 +13,10 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
#include <cstdlib>
|
||||
#include <unordered_set>
|
||||
|
||||
#include <netinet/in.h>
|
||||
#include <signal.h>
|
||||
#include <sys/socket.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include "tensorflow/core/platform/net.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
@ -84,101 +78,7 @@ std::unique_ptr<SubProcess> CreateSubProcess(const std::vector<string>& argv) {
|
||||
return std::unique_ptr<SubProcess>(new PosixSubProcess(argv));
|
||||
}
|
||||
|
||||
namespace {
|
||||
bool IsPortAvailable(int* port, bool is_tcp) {
|
||||
const int protocol = is_tcp ? IPPROTO_TCP : 0;
|
||||
const int fd = socket(AF_INET, is_tcp ? SOCK_STREAM : SOCK_DGRAM, protocol);
|
||||
|
||||
struct sockaddr_in addr;
|
||||
socklen_t addr_len = sizeof(addr);
|
||||
int actual_port;
|
||||
|
||||
CHECK_GE(*port, 0);
|
||||
CHECK_LE(*port, 65535);
|
||||
if (fd < 0) {
|
||||
LOG(ERROR) << "socket() failed: " << strerror(errno);
|
||||
return false;
|
||||
}
|
||||
|
||||
// SO_REUSEADDR lets us start up a server immediately after it exists.
|
||||
int one = 1;
|
||||
if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)) < 0) {
|
||||
LOG(ERROR) << "setsockopt() failed: " << strerror(errno);
|
||||
close(fd);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Try binding to port.
|
||||
addr.sin_family = AF_INET;
|
||||
addr.sin_addr.s_addr = INADDR_ANY;
|
||||
addr.sin_port = htons((uint16_t)*port);
|
||||
if (bind(fd, (struct sockaddr*)&addr, sizeof(addr)) < 0) {
|
||||
LOG(WARNING) << "bind(port=" << *port << ") failed: " << strerror(errno);
|
||||
close(fd);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Get the bound port number.
|
||||
if (getsockname(fd, (struct sockaddr*)&addr, &addr_len) < 0) {
|
||||
LOG(WARNING) << "getsockname() failed: " << strerror(errno);
|
||||
close(fd);
|
||||
return false;
|
||||
}
|
||||
CHECK_LE(addr_len, sizeof(addr));
|
||||
actual_port = ntohs(addr.sin_port);
|
||||
CHECK_GT(actual_port, 0);
|
||||
if (*port == 0) {
|
||||
*port = actual_port;
|
||||
} else {
|
||||
CHECK_EQ(*port, actual_port);
|
||||
}
|
||||
close(fd);
|
||||
return true;
|
||||
}
|
||||
|
||||
const int kNumRandomPortsToPick = 100;
|
||||
const int kMaximumTrials = 1000;
|
||||
|
||||
} // namespace
|
||||
|
||||
int PickUnusedPortOrDie() {
|
||||
static std::unordered_set<int> chosen_ports;
|
||||
|
||||
// Type of port to first pick in the next iteration.
|
||||
bool is_tcp = true;
|
||||
int trial = 0;
|
||||
while (true) {
|
||||
int port;
|
||||
trial++;
|
||||
CHECK_LE(trial, kMaximumTrials)
|
||||
<< "Failed to pick an unused port for testing.";
|
||||
if (trial == 1) {
|
||||
port = getpid() % (65536 - 30000) + 30000;
|
||||
} else if (trial <= kNumRandomPortsToPick) {
|
||||
port = rand() % (65536 - 30000) + 30000;
|
||||
} else {
|
||||
port = 0;
|
||||
}
|
||||
|
||||
if (chosen_ports.find(port) != chosen_ports.end()) {
|
||||
continue;
|
||||
}
|
||||
if (!IsPortAvailable(&port, is_tcp)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
CHECK_GT(port, 0);
|
||||
if (!IsPortAvailable(&port, !is_tcp)) {
|
||||
is_tcp = !is_tcp;
|
||||
continue;
|
||||
}
|
||||
|
||||
chosen_ports.insert(port);
|
||||
return port;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
int PickUnusedPortOrDie() { return internal::PickUnusedPortOrDie(); }
|
||||
|
||||
string TensorFlowSrcRoot() {
|
||||
// 'bazel test' sets TEST_SRCDIR, and also TEST_WORKSPACE if a new
|
||||
|
@ -1057,6 +1057,27 @@ cuda_py_tests(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "net_lib",
|
||||
testonly = 1,
|
||||
srcs = ["util/net_lib.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":pywrap_tensorflow",
|
||||
],
|
||||
)
|
||||
|
||||
py_tests(
|
||||
name = "net_lib_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"util/net_lib_test.py",
|
||||
],
|
||||
additional_deps = [
|
||||
":net_lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "tf_session_helper",
|
||||
srcs = ["client/tf_session_helper.cc"],
|
||||
@ -1083,6 +1104,7 @@ tf_py_wrap_cc(
|
||||
swig_includes = [
|
||||
"client/device_lib.i",
|
||||
"client/events_writer.i",
|
||||
"client/net_lib.i",
|
||||
"client/quantize_training.i",
|
||||
"client/tf_session.i",
|
||||
"framework/python_op_gen.i",
|
||||
@ -1148,6 +1170,14 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "localhost_cluster_performance_test",
|
||||
size = "medium",
|
||||
srcs = [
|
||||
"training/localhost_cluster_performance_test.py",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "timeline",
|
||||
srcs = ["client/timeline.py"],
|
||||
@ -1283,6 +1313,7 @@ cuda_py_tests(
|
||||
"training/session_manager_test.py",
|
||||
"training/supervisor_test.py",
|
||||
"training/saver_large_variable_test.py",
|
||||
"training/localhost_cluster_performance_test.py",
|
||||
],
|
||||
),
|
||||
additional_deps = [
|
||||
|
30
tensorflow/python/client/net_lib.i
Normal file
30
tensorflow/python/client/net_lib.i
Normal file
@ -0,0 +1,30 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
%include "tensorflow/python/platform/base.i"
|
||||
|
||||
%{
|
||||
#include "tensorflow/core/platform/net.h"
|
||||
%}
|
||||
|
||||
%ignoreall
|
||||
|
||||
%unignore tensorflow;
|
||||
%unignore tensorflow::internal;
|
||||
%unignore tensorflow::internal::PickUnusedPortOrDie;
|
||||
|
||||
%include "tensorflow/core/platform/net.h"
|
||||
|
||||
%unignoreall
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
|
||||
%include "tensorflow/python/client/tf_session.i"
|
||||
%include "tensorflow/python/client/device_lib.i"
|
||||
%include "tensorflow/python/client/net_lib.i"
|
||||
%include "tensorflow/python/client/quantize_training.i"
|
||||
|
||||
%include "tensorflow/python/lib/io/file_io.i"
|
||||
|
133
tensorflow/python/training/localhost_cluster_performance_test.py
Normal file
133
tensorflow/python/training/localhost_cluster_performance_test.py
Normal file
@ -0,0 +1,133 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tests and benchmarks for creating RPC clusters on localhost."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.python.util import net_lib
|
||||
|
||||
|
||||
def create_local_cluster(num_workers, num_ps, protocol="grpc"):
|
||||
"""Create local GRPC servers and return their servers."""
|
||||
worker_ports = [net_lib.pick_unused_port_or_die() for _ in range(num_workers)]
|
||||
ps_ports = [net_lib.pick_unused_port_or_die() for _ in range(num_ps)]
|
||||
cluster_dict = {
|
||||
"worker": ["localhost:%s" % port for port in worker_ports],
|
||||
"ps": ["localhost:%s" % port for port in ps_ports]}
|
||||
cs = tf.train.ClusterSpec(cluster_dict)
|
||||
|
||||
workers = [
|
||||
tf.train.Server(
|
||||
cs, job_name="worker", protocol=protocol, task_index=ix, start=True)
|
||||
for ix in range(num_workers)]
|
||||
ps_servers = [
|
||||
tf.train.Server(
|
||||
cs, job_name="ps", protocol=protocol, task_index=ix, start=True)
|
||||
for ix in range(num_ps)]
|
||||
|
||||
return workers, ps_servers
|
||||
|
||||
|
||||
class CreateLocalClusterTest(tf.test.TestCase):
|
||||
|
||||
def testCreateLocalCluster(self):
|
||||
workers, _ = create_local_cluster(num_workers=2, num_ps=2)
|
||||
worker_sessions = [tf.Session(w.target) for w in workers]
|
||||
with tf.device("/job:ps/task:0"):
|
||||
var0 = tf.Variable(0.0)
|
||||
with tf.device("/job:ps/task:1"):
|
||||
var1 = tf.Variable(1.0)
|
||||
worker_sessions[0].run([var0.initializer, var1.initializer])
|
||||
with tf.device("/job:ps/task:0"):
|
||||
var2 = tf.Variable(2.0)
|
||||
with tf.device("/job:ps/task:1"):
|
||||
var3 = tf.Variable(3.0)
|
||||
worker_sessions[1].run([var2.initializer, var3.initializer])
|
||||
|
||||
# Read values back in the opposite session
|
||||
self.assertAllEqual(0.0, var0.eval(session=worker_sessions[1]))
|
||||
self.assertAllEqual(1.0, var1.eval(session=worker_sessions[1]))
|
||||
self.assertAllEqual(2.0, var2.eval(session=worker_sessions[0]))
|
||||
self.assertAllEqual(3.0, var3.eval(session=worker_sessions[0]))
|
||||
|
||||
|
||||
class CreateLocalClusterBenchmark(tf.test.Benchmark):
|
||||
|
||||
def benchmarkCreateLocalCluster(self):
|
||||
deltas = []
|
||||
iters = 50
|
||||
for _ in range(iters):
|
||||
start_time = time.time()
|
||||
create_local_cluster(num_workers=1, num_ps=10)
|
||||
end_time = time.time()
|
||||
deltas.append(end_time - start_time)
|
||||
|
||||
median_deltas = np.median(deltas)
|
||||
print(
|
||||
"\n\nbenchmark_create_local_cluster_1_worker_10_ps. "
|
||||
"iterations: %d, median wall time: %g\n\n" % (iters, median_deltas))
|
||||
self.report_benchmark(
|
||||
iters=iters,
|
||||
wall_time=median_deltas,
|
||||
name="benchmark_create_local_cluster_1_worker_10_ps")
|
||||
|
||||
|
||||
class PartitionedVariablesBenchmark(tf.test.Benchmark):
|
||||
|
||||
def benchmark_create_1000_partitions_with_100_parameter_servers(self):
|
||||
workers, _ = create_local_cluster(num_workers=1, num_ps=100)
|
||||
worker_sessions = [tf.Session(w.target) for w in workers]
|
||||
worker = worker_sessions[0]
|
||||
partition_sizes = (1, 512, 1024*32, 1024*128)
|
||||
|
||||
partitioned = []
|
||||
|
||||
for partition_size in partition_sizes:
|
||||
# max_shard_bytes is 4, shape is 1000*partition_size float32s which should
|
||||
# partition into 1000 shards, each containing partition_size float32s.
|
||||
print("Building partitioned variable with %d floats per partition"
|
||||
% partition_size)
|
||||
with tf.device(tf.train.replica_device_setter(ps_tasks=100)):
|
||||
partitioned_ix = tf.get_variable(
|
||||
"partitioned_%d" % partition_size,
|
||||
shape=[1000 * partition_size],
|
||||
dtype=tf.float32,
|
||||
# Each partition to have exactly N float32s
|
||||
partitioner=tf.variable_axis_size_partitioner(
|
||||
max_shard_bytes=4 * partition_size))
|
||||
# Concatenates along axis 0
|
||||
partitioned.append(tf.convert_to_tensor(partitioned_ix))
|
||||
|
||||
tf.initialize_all_variables().run(session=worker)
|
||||
|
||||
for ix, partition_size in enumerate(partition_sizes):
|
||||
print("Running benchmark having partitions with %d floats"
|
||||
% partition_size)
|
||||
self.run_op_benchmark(
|
||||
worker,
|
||||
partitioned[ix],
|
||||
name=("read_concat_1000_partitions_from_"
|
||||
"100_parameter_servers_partsize_%d_floats" % partition_size))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
28
tensorflow/python/util/net_lib.py
Normal file
28
tensorflow/python/util/net_lib.py
Normal file
@ -0,0 +1,28 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""A Python interface for creating TensorFlow tests."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import six # pylint: disable=unused-import
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
|
||||
|
||||
def pick_unused_port_or_die():
|
||||
"""Find an unused port on localhost."""
|
||||
return pywrap_tensorflow.PickUnusedPortOrDie()
|
39
tensorflow/python/util/net_lib_test.py
Normal file
39
tensorflow/python/util/net_lib_test.py
Normal file
@ -0,0 +1,39 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tests for the SWIG-wrapped test lib."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.python.util import net_lib
|
||||
|
||||
|
||||
class TestLibTest(tf.test.TestCase):
|
||||
|
||||
def testPickUnusedPortOrDie(self):
|
||||
port0 = net_lib.pick_unused_port_or_die()
|
||||
port1 = net_lib.pick_unused_port_or_die()
|
||||
self.assertGreater(port0, 0)
|
||||
self.assertLess(port0, 65536)
|
||||
self.assertGreater(port1, 0)
|
||||
self.assertLess(port1, 65536)
|
||||
self.assertNotEqual(port0, port1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
Loading…
Reference in New Issue
Block a user