# File: STT-tensorflow/tensorflow/python/training/server_lib_test.py
# Last change: Mark Daoust 18b680216e — "Apply tf1->tf2 name replaces to
# doc-strings and comments in tensorflow. No code changes, only doc-strings
# and comments."
# PiperOrigin-RevId: 244275767 (2019-04-18 17:19:27 -07:00)
# 586 lines, 22 KiB, Python
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.GrpcServer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import input as input_ops
from tensorflow.python.training import queue_runner_impl
from tensorflow.python.training import server_lib
class GrpcServerTest(test.TestCase):
  """Tests for in-process gRPC servers created via `server_lib.Server`."""

  def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
    super(GrpcServerTest, self).__init__(methodName)
    # Share a single local server across all test methods: starting a gRPC
    # server is comparatively expensive.
    self._cached_server = server_lib.Server.create_local_server()

  def testRunStep(self):
    server = self._cached_server

    with session.Session(server.target) as sess:
      c = constant_op.constant([[2, 1]])
      d = constant_op.constant([[1], [2]])
      e = math_ops.matmul(c, d)
      self.assertAllEqual([[4]], sess.run(e))
    # TODO(mrry): Add `server.stop()` and `server.join()` when these work.

  @test_util.run_v1_only("b/120545219")
  def testMultipleSessions(self):
    server = self._cached_server

    c = constant_op.constant([[2, 1]])
    d = constant_op.constant([[1], [2]])
    e = math_ops.matmul(c, d)

    sess_1 = session.Session(server.target)
    sess_2 = session.Session(server.target)
    self.assertAllEqual([[4]], sess_1.run(e))
    self.assertAllEqual([[4]], sess_2.run(e))
    sess_1.close()
    sess_2.close()
    # TODO(mrry): Add `server.stop()` and `server.join()` when these work.

  # Verifies various reset failures.
  @test_util.run_v1_only("b/120545219")
  def testResetFails(self):
    # Creates variable with container name.
    with ops.container("test0"):
      v0 = variables.VariableV1(1.0, name="v0")
    # Creates variable with default container.
    v1 = variables.VariableV1(2.0, name="v1")
    # Verifies resetting the non-existent target returns error.
    with self.assertRaises(errors_impl.NotFoundError):
      session.Session.reset("nonexistent", ["test0"])

    # Verifies resetting with config.
    # Verifies that resetting target with no server times out.
    with self.assertRaises(errors_impl.DeadlineExceededError):
      session.Session.reset(
          "grpc://localhost:0", ["test0"],
          config=config_pb2.ConfigProto(operation_timeout_in_ms=5))

    # Verifies no containers are reset with non-existent container.
    server = self._cached_server
    sess = session.Session(server.target)
    sess.run(variables.global_variables_initializer())
    self.assertAllEqual(1.0, sess.run(v0))
    self.assertAllEqual(2.0, sess.run(v1))
    # No container is reset, but the server is reset.
    session.Session.reset(server.target, ["test1"])
    # Verifies that both variables are still valid.
    sess = session.Session(server.target)
    self.assertAllEqual(1.0, sess.run(v0))
    self.assertAllEqual(2.0, sess.run(v1))

  def _useRPCConfig(self):
    """Return a `tf.compat.v1.ConfigProto` that ensures we use the RPC stack for tests.

    This configuration ensures that we continue to exercise the gRPC
    stack when testing, rather than using the in-process optimization,
    which avoids using gRPC as the transport between a client and
    master in the same process.

    Returns:
      A `tf.compat.v1.ConfigProto`.
    """
    return config_pb2.ConfigProto(rpc_options=config_pb2.RPCOptions(
        use_rpc_for_inprocess_master=True))

  def testLargeConstant(self):
    server = self._cached_server
    with session.Session(server.target, config=self._useRPCConfig()) as sess:
      const_val = np.empty([10000, 3000], dtype=np.float32)
      const_val.fill(0.5)
      c = constant_op.constant(const_val)
      shape_t = array_ops.shape(c)
      self.assertAllEqual([10000, 3000], sess.run(shape_t))

  def testLargeFetch(self):
    server = self._cached_server
    with session.Session(server.target, config=self._useRPCConfig()) as sess:
      c = array_ops.fill([10000, 3000], 0.5)
      expected_val = np.empty([10000, 3000], dtype=np.float32)
      expected_val.fill(0.5)
      self.assertAllEqual(expected_val, sess.run(c))

  def testLargeFeed(self):
    server = self._cached_server
    with session.Session(server.target, config=self._useRPCConfig()) as sess:
      feed_val = np.empty([10000, 3000], dtype=np.float32)
      feed_val.fill(0.5)
      p = array_ops.placeholder(dtypes.float32, shape=[10000, 3000])
      min_t = math_ops.reduce_min(p)
      max_t = math_ops.reduce_max(p)
      min_val, max_val = sess.run([min_t, max_t], feed_dict={p: feed_val})
      self.assertEqual(0.5, min_val)
      self.assertEqual(0.5, max_val)

  @test_util.run_v1_only("b/120545219")
  def testCloseCancelsBlockingOperation(self):
    server = self._cached_server
    sess = session.Session(server.target, config=self._useRPCConfig())

    q = data_flow_ops.FIFOQueue(10, [dtypes.float32])
    enqueue_op = q.enqueue(37.0)
    dequeue_t = q.dequeue()

    sess.run(enqueue_op)
    sess.run(dequeue_t)

    def blocking_dequeue():
      # `assertRaisesRegex` is the non-deprecated spelling of
      # `assertRaisesRegexp` (the alias was removed in Python 3.12).
      with self.assertRaisesRegex(errors_impl.CancelledError,
                                  "Session::Close"):
        sess.run(dequeue_t)

    blocking_thread = self.checkedThread(blocking_dequeue)
    blocking_thread.start()
    time.sleep(0.5)
    sess.close()
    blocking_thread.join()

  def testInteractiveSession(self):
    server = self._cached_server
    # Session creation will warn (in C++) that the place_pruned_graph option
    # is not supported, but it should successfully ignore it.
    sess = session.InteractiveSession(server.target)
    c = constant_op.constant(42.0)
    self.assertEqual(42.0, self.evaluate(c))
    sess.close()

  def testSetConfiguration(self):
    config = config_pb2.ConfigProto(
        gpu_options=config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.1))

    # Configure a server using the default local server options.
    server = server_lib.Server.create_local_server(config=config, start=False)
    self.assertEqual(0.1, server.server_def.default_session_config.gpu_options.
                     per_process_gpu_memory_fraction)

    # Configure a server using an explicit ServerDef with an
    # overridden config.
    cluster_def = server_lib.ClusterSpec({
        "localhost": ["localhost:0"]
    }).as_cluster_def()
    server_def = tensorflow_server_pb2.ServerDef(
        cluster=cluster_def,
        job_name="localhost",
        task_index=0,
        protocol="grpc")
    server = server_lib.Server(server_def, config=config, start=False)
    self.assertEqual(0.1, server.server_def.default_session_config.gpu_options.
                     per_process_gpu_memory_fraction)

  def testInvalidHostname(self):
    with self.assertRaisesRegex(errors_impl.InvalidArgumentError, "port"):
      _ = server_lib.Server(
          {
              "local": ["localhost"]
          }, job_name="local", task_index=0)

  @test_util.run_v1_only("b/120545219")
  def testTimeoutRaisesException(self):
    server = self._cached_server
    q = data_flow_ops.FIFOQueue(1, [dtypes.float32])
    blocking_t = q.dequeue()

    with session.Session(server.target) as sess:
      with self.assertRaises(errors_impl.DeadlineExceededError):
        sess.run(blocking_t, options=config_pb2.RunOptions(timeout_in_ms=1000))

    with session.Session(server.target, config=self._useRPCConfig()) as sess:
      with self.assertRaises(errors_impl.DeadlineExceededError):
        sess.run(blocking_t, options=config_pb2.RunOptions(timeout_in_ms=1000))

  def testTwoServersSamePort(self):
    # Starting a server with the same target as the cached server should fail.
    server = self._cached_server
    with self.assertRaises(errors_impl.UnknownError):
      _ = server_lib.Server(
          {"local_2": [server.target[len("grpc://"):]]})

  def testExtendAfterQueueRunners(self):
    server = self._cached_server
    with session.Session(server.target) as sess:
      input_queue = input_ops.input_producer(constant_op.constant(
          [0.], dtype=dtypes.float32))
      self.assertIsNotNone(input_queue)

      var = variables.VariableV1(1., dtype=dtypes.float32, trainable=False,
                                 name="var")

      sess.run(variables.global_variables_initializer())
      queue_runner_impl.start_queue_runners(sess)
      sess.run(var.assign(3.0))

  @test_util.run_v1_only("b/120545219")
  def testIsolateSessionState(self):
    server = self._cached_server

    init_value = array_ops.placeholder(dtypes.int32)
    v = variables.VariableV1(init_value, validate_shape=False, name="v")

    sharing_config = config_pb2.ConfigProto(isolate_session_state=False)
    sharing_sess_0 = session.Session(server.target, config=sharing_config)
    sharing_sess_1 = session.Session(server.target, config=sharing_config)

    isolate_config = config_pb2.ConfigProto(isolate_session_state=True)
    isolate_sess_0 = session.Session(server.target, config=isolate_config)
    isolate_sess_1 = session.Session(server.target, config=isolate_config)

    # Initially all variables are initialized.
    for sess in [sharing_sess_0, sharing_sess_1,
                 isolate_sess_0, isolate_sess_1]:
      with self.assertRaises(errors_impl.FailedPreconditionError):
        sess.run(v)

    # Shared sessions will see each other's updates, but isolated sessions
    # will not.
    sharing_sess_0.run(v.initializer, feed_dict={init_value: 86})
    self.assertAllEqual(86, sharing_sess_0.run(v))
    self.assertAllEqual(86, sharing_sess_1.run(v))
    with self.assertRaises(errors_impl.FailedPreconditionError):
      isolate_sess_0.run(v)
    with self.assertRaises(errors_impl.FailedPreconditionError):
      isolate_sess_1.run(v)

    # Changing the shape works because `validate_shape` is False.
    sharing_sess_1.run(v.initializer, feed_dict={init_value: [86, 99]})
    self.assertAllEqual([86, 99], sharing_sess_0.run(v))
    self.assertAllEqual([86, 99], sharing_sess_1.run(v))
    with self.assertRaises(errors_impl.FailedPreconditionError):
      isolate_sess_0.run(v)
    with self.assertRaises(errors_impl.FailedPreconditionError):
      isolate_sess_1.run(v)

    # Initializing in an isolated session will only affect the state in that
    # session.
    isolate_sess_0.run(v.initializer, feed_dict={init_value: 37})
    self.assertAllEqual([86, 99], sharing_sess_0.run(v))
    self.assertAllEqual([86, 99], sharing_sess_1.run(v))
    self.assertAllEqual(37, isolate_sess_0.run(v))
    with self.assertRaises(errors_impl.FailedPreconditionError):
      isolate_sess_1.run(v)

    # Isolated sessions can have different shapes for the same variable.
    isolate_sess_1.run(v.initializer, feed_dict={init_value: [19, 86]})
    self.assertAllEqual([86, 99], sharing_sess_0.run(v))
    self.assertAllEqual([86, 99], sharing_sess_1.run(v))
    self.assertAllEqual(37, isolate_sess_0.run(v))
    self.assertAllEqual([19, 86], isolate_sess_1.run(v))

  @test_util.run_v1_only("b/120545219")
  def testShapeChangingIsolateState(self):
    server = self._cached_server
    sharing_config = config_pb2.ConfigProto(isolate_session_state=False)
    isolate_config = config_pb2.ConfigProto(isolate_session_state=True)

    with ops.Graph().as_default():
      w_vector = variables.VariableV1([1, 2, 3], name="w")
      with session.Session(server.target, config=sharing_config) as sess:
        with self.assertRaises(errors_impl.FailedPreconditionError):
          sess.run(w_vector)
        sess.run(w_vector.initializer)
        self.assertAllEqual([1, 2, 3], sess.run(w_vector))

    with ops.Graph().as_default():
      w_vector = variables.VariableV1([4, 5, 6], name="w")
      with session.Session(server.target, config=sharing_config) as sess:
        self.assertAllEqual([1, 2, 3], sess.run(w_vector))
        sess.run(w_vector.initializer)
        self.assertAllEqual([4, 5, 6], sess.run(w_vector))

    with ops.Graph().as_default():
      w_scalar = variables.VariableV1(86, name="w")
      with session.Session(server.target, config=sharing_config) as sess:
        with self.assertRaises(errors_impl.InvalidArgumentError):
          sess.run(w_scalar.initializer)

    with ops.Graph().as_default():
      w_scalar = variables.VariableV1(37, name="w")
      with session.Session(server.target, config=isolate_config) as sess:
        with self.assertRaises(errors_impl.FailedPreconditionError):
          sess.run(w_scalar)
        sess.run(w_scalar.initializer)
        self.assertAllEqual(37, sess.run(w_scalar))
class ServerDefTest(test.TestCase):
  """Tests constructing `ServerDef` protos from `ClusterSpec` inputs."""

  def _verifyClusterRoundTrip(self, cluster_proto):
    """Checks that Proto -> ClusterSpec -> Proto is the identity."""
    round_tripped = server_lib.ClusterSpec(cluster_proto).as_cluster_def()
    self.assertProtoEquals(cluster_proto, round_tripped)

  def testLocalServer(self):
    spec = server_lib.ClusterSpec({"local": ["localhost:2222"]})
    cluster_proto = spec.as_cluster_def()
    server_proto = tensorflow_server_pb2.ServerDef(
        cluster=cluster_proto, job_name="local", task_index=0, protocol="grpc")
    self.assertProtoEquals("""
    cluster {
      job { name: 'local' tasks { key: 0 value: 'localhost:2222' } }
    }
    job_name: 'local' task_index: 0 protocol: 'grpc'
    """, server_proto)
    self._verifyClusterRoundTrip(cluster_proto)

  def testTwoProcesses(self):
    spec = server_lib.ClusterSpec(
        {"local": ["localhost:2222", "localhost:2223"]})
    cluster_proto = spec.as_cluster_def()
    server_proto = tensorflow_server_pb2.ServerDef(
        cluster=cluster_proto, job_name="local", task_index=1, protocol="grpc")
    self.assertProtoEquals("""
    cluster {
      job { name: 'local' tasks { key: 0 value: 'localhost:2222' }
                          tasks { key: 1 value: 'localhost:2223' } }
    }
    job_name: 'local' task_index: 1 protocol: 'grpc'
    """, server_proto)
    self._verifyClusterRoundTrip(cluster_proto)

  def testTwoJobs(self):
    spec = server_lib.ClusterSpec({
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
    })
    cluster_proto = spec.as_cluster_def()
    server_proto = tensorflow_server_pb2.ServerDef(
        cluster=cluster_proto, job_name="worker", task_index=2,
        protocol="grpc")
    self.assertProtoEquals("""
    cluster {
      job { name: 'ps' tasks { key: 0 value: 'ps0:2222' }
                       tasks { key: 1 value: 'ps1:2222' } }
      job { name: 'worker' tasks { key: 0 value: 'worker0:2222' }
                           tasks { key: 1 value: 'worker1:2222' }
                           tasks { key: 2 value: 'worker2:2222' } }
    }
    job_name: 'worker' task_index: 2 protocol: 'grpc'
    """, server_proto)
    self._verifyClusterRoundTrip(cluster_proto)

  def testDenseAndSparseJobs(self):
    # "worker" is a sparse job: only task indices 0 and 2 exist.
    spec = server_lib.ClusterSpec({
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": {
            0: "worker0:2222",
            2: "worker2:2222"
        }
    })
    cluster_proto = spec.as_cluster_def()
    server_proto = tensorflow_server_pb2.ServerDef(
        cluster=cluster_proto, job_name="worker", task_index=2,
        protocol="grpc")
    self.assertProtoEquals("""
    cluster {
      job { name: 'ps' tasks { key: 0 value: 'ps0:2222' }
                       tasks { key: 1 value: 'ps1:2222' } }
      job { name: 'worker' tasks { key: 0 value: 'worker0:2222' }
                           tasks { key: 2 value: 'worker2:2222' } }
    }
    job_name: 'worker' task_index: 2 protocol: 'grpc'
    """, server_proto)
    self._verifyClusterRoundTrip(cluster_proto)
class ClusterSpecTest(test.TestCase):
  """Tests for `tf.train.ClusterSpec` conversions, accessors, and equality."""

  def testStringConversion(self):
    cluster_spec = server_lib.ClusterSpec({
        "ps": ["ps0:1111"],
        "worker": ["worker0:3333", "worker1:4444"]
    })
    expected_str = (
        "ClusterSpec({'ps': ['ps0:1111'], 'worker': ['worker0:3333', "
        "'worker1:4444']})")
    self.assertEqual(expected_str, str(cluster_spec))

  def testProtoDictDefEquivalences(self):
    cluster_spec = server_lib.ClusterSpec({
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
    })

    expected_proto = """
    job { name: 'ps' tasks { key: 0 value: 'ps0:2222' }
                     tasks { key: 1 value: 'ps1:2222' } }
    job { name: 'worker' tasks { key: 0 value: 'worker0:2222' }
                         tasks { key: 1 value: 'worker1:2222' }
                         tasks { key: 2 value: 'worker2:2222' } }
    """

    # The same cluster definition should be recoverable whether the
    # ClusterSpec is built from another spec, a proto, or a dict.
    self.assertProtoEquals(expected_proto, cluster_spec.as_cluster_def())
    self.assertProtoEquals(
        expected_proto,
        server_lib.ClusterSpec(cluster_spec).as_cluster_def())
    self.assertProtoEquals(
        expected_proto,
        server_lib.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def())
    self.assertProtoEquals(
        expected_proto,
        server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def())

  def testProtoDictDefEquivalencesWithZeroWorker(self):
    cluster_spec = server_lib.ClusterSpec({
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": []
    })

    expected_proto = """
    job { name: 'ps' tasks { key: 0 value: 'ps0:2222' }
                     tasks { key: 1 value: 'ps1:2222' } }
    job { name: 'worker' }
    """

    self.assertProtoEquals(expected_proto, cluster_spec.as_cluster_def())
    self.assertProtoEquals(
        expected_proto, server_lib.ClusterSpec(cluster_spec).as_cluster_def())
    self.assertProtoEquals(
        expected_proto,
        server_lib.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def())
    self.assertProtoEquals(
        expected_proto,
        server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def())

  def testClusterSpecAccessors(self):
    original_dict = {
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"],
        "sparse": {
            0: "sparse0:2222",
            3: "sparse3:2222"
        }
    }
    cluster_spec = server_lib.ClusterSpec(original_dict)

    self.assertEqual(original_dict, cluster_spec.as_dict())

    self.assertEqual(2, cluster_spec.num_tasks("ps"))
    self.assertEqual(3, cluster_spec.num_tasks("worker"))
    self.assertEqual(2, cluster_spec.num_tasks("sparse"))
    with self.assertRaises(ValueError):
      cluster_spec.num_tasks("unknown")

    self.assertEqual("ps0:2222", cluster_spec.task_address("ps", 0))
    self.assertEqual("sparse0:2222", cluster_spec.task_address("sparse", 0))
    with self.assertRaises(ValueError):
      cluster_spec.task_address("unknown", 0)
    with self.assertRaises(ValueError):
      cluster_spec.task_address("sparse", 2)

    self.assertEqual([0, 1], cluster_spec.task_indices("ps"))
    self.assertEqual([0, 1, 2], cluster_spec.task_indices("worker"))
    self.assertEqual([0, 3], cluster_spec.task_indices("sparse"))
    with self.assertRaises(ValueError):
      cluster_spec.task_indices("unknown")

    # NOTE(mrry): `ClusterSpec.job_tasks()` is not recommended for use
    # with sparse jobs.
    self.assertEqual(["ps0:2222", "ps1:2222"], cluster_spec.job_tasks("ps"))
    self.assertEqual(["worker0:2222", "worker1:2222", "worker2:2222"],
                     cluster_spec.job_tasks("worker"))
    self.assertEqual(["sparse0:2222", None, None, "sparse3:2222"],
                     cluster_spec.job_tasks("sparse"))
    with self.assertRaises(ValueError):
      cluster_spec.job_tasks("unknown")

  def testEmptyClusterSpecIsFalse(self):
    self.assertFalse(server_lib.ClusterSpec({}))

  def testNonEmptyClusterSpecIsTrue(self):
    self.assertTrue(server_lib.ClusterSpec({"job": ["host:port"]}))

  def testEq(self):
    # `assertEqual` replaces the deprecated `assertEquals` alias
    # (removed in Python 3.12).
    self.assertEqual(server_lib.ClusterSpec({}), server_lib.ClusterSpec({}))
    self.assertEqual(
        server_lib.ClusterSpec({
            "job": ["host:2222"]
        }),
        server_lib.ClusterSpec({
            "job": ["host:2222"]
        }))
    # A sparse job with a single task 0 is equal to the dense equivalent.
    self.assertEqual(
        server_lib.ClusterSpec({
            "job": {
                0: "host:2222"
            }
        }), server_lib.ClusterSpec({
            "job": ["host:2222"]
        }))

  def testNe(self):
    # `assertNotEqual` replaces the deprecated `assertNotEquals` alias
    # (removed in Python 3.12).
    self.assertNotEqual(
        server_lib.ClusterSpec({}),
        server_lib.ClusterSpec({
            "job": ["host:2223"]
        }))
    self.assertNotEqual(
        server_lib.ClusterSpec({
            "job1": ["host:2222"]
        }),
        server_lib.ClusterSpec({
            "job2": ["host:2222"]
        }))
    self.assertNotEqual(
        server_lib.ClusterSpec({
            "job": ["host:2222"]
        }),
        server_lib.ClusterSpec({
            "job": ["host:2223"]
        }))
    # Task order within a job is significant.
    self.assertNotEqual(
        server_lib.ClusterSpec({
            "job": ["host:2222", "host:2223"]
        }),
        server_lib.ClusterSpec({
            "job": ["host:2223", "host:2222"]
        }))
if __name__ == "__main__":
  # Run all test cases in this module under the TensorFlow test runner.
  test.main()