Distributed runtime test cleanup.

(1) Moved most of contrib/eager/python/remote_test.py out of contrib
(2) Fixed asan/msan/tsan flakiness for DynamicClusterTest
(3) Improved coverage of the async/sync x lazy_remote_inputs_copy on/off combinations (sketched after the commit metadata below)

PiperOrigin-RevId: 299390689
Change-Id: I9e24c646fca3feedec56b16a2e891d357398272e
Haoyu Zhang 2020-03-06 10:55:53 -08:00 committed by TensorFlower Gardener
parent 0531deb3ad
commit 950b054440
4 changed files with 817 additions and 0 deletions
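
As a reading aid for (3): the coverage comes from composing two mechanisms that appear in the diffs below. The new test_util.run_in_async_and_sync_mode decorator parameterizes every test method into an async and a sync variant, and the *WithoutLazyRemoteInputsCopyTest subclasses (e.g. DynamicClusterWithoutLazyRemoteInputsCopyTest) re-run the entire suite with context.context().lazy_remote_inputs_copy disabled, giving the 2x2 combination. Below is a minimal, self-contained sketch of that pattern; LazyCopyOnTest, LazyCopyOffTest, and testMatmul are illustrative names, not part of the commit.

from absl.testing import parameterized

from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


class LazyCopyOnTest(test.TestCase, parameterized.TestCase):
  """Runs each test in async and sync mode, with lazy remote input copy on."""

  @test_util.run_in_async_and_sync_mode
  def testMatmul(self):
    x = array_ops.ones([2, 2])
    self.assertAllEqual([[2, 2], [2, 2]], math_ops.matmul(x, x).numpy())


class LazyCopyOffTest(LazyCopyOnTest):
  """Re-runs the suite above with lazy remote input copy disabled."""

  @classmethod
  def setUpClass(cls):
    super(LazyCopyOffTest, cls).setUpClass()
    context._reset_context()
    context.context().lazy_remote_inputs_copy = False

  @classmethod
  def tearDownClass(cls):
    super(LazyCopyOffTest, cls).tearDownClass()
    context._reset_context()
    context.context().lazy_remote_inputs_copy = True


if __name__ == "__main__":
  ops.enable_eager_execution()
  test.main()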

View File: tensorflow/python/eager/BUILD

@@ -868,6 +868,43 @@ cuda_py_test(
],
)
cuda_py_test(
name = "remote_execution_test",
srcs = ["remote_execution_test.py"],
grpc_enabled = True,
python_version = "PY3",
shard_count = 2,
tags = [
"no_oss", # This test launches local server
],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
"//tensorflow/python:constant_op",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python/eager:remote",
"@absl_py//absl/testing:parameterized",
],
)
cuda_py_test(
name = "remote_cluster_test",
srcs = ["remote_cluster_test.py"],
grpc_enabled = True,
python_version = "PY3",
shard_count = 16,
tags = [
"no_oss", # This test launches local server
],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"@absl_py//absl/testing:parameterized",
],
)
tpu_py_test(
name = "remote_cloud_tpu_test",
srcs = ["remote_cloud_tpu_test.py"],

View File: tensorflow/python/eager/remote_cluster_test.py

@@ -0,0 +1,511 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for remote eager execution."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
from absl.testing import parameterized
import numpy as np
from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.training import coordinator
from tensorflow.python.training import server_lib
JOB_NAME = "remote_device"
def get_server_def(job_name, local_server_port, remote_server_addresses,
task_index):
"""Returns a server def with a single job + multiple tasks."""
cluster_def = cluster_pb2.ClusterDef()
job_def = cluster_def.job.add()
job_def.name = job_name
job_def.tasks[0] = "localhost:%d" % local_server_port
for i, remote_server_address in enumerate(remote_server_addresses, start=1):
job_def.tasks[i] = remote_server_address
server_def = tensorflow_server_pb2.ServerDef(
cluster=cluster_def,
job_name=job_name,
task_index=task_index,
protocol="grpc")
return server_def
class DynamicClusterTest(test.TestCase, parameterized.TestCase):
def __init__(self, methodName="runTest"): # pylint: disable=invalid-name
super(DynamicClusterTest, self).__init__(methodName)
self._cached_server1 = server_lib.Server.create_local_server()
self._cached_server2 = server_lib.Server.create_local_server()
self._cached_server3 = server_lib.Server.create_local_server()
self._cached_server4 = server_lib.Server.create_local_server()
self._cached_server1_target = self._cached_server1.target[len("grpc://"):]
self._cached_server2_target = self._cached_server2.target[len("grpc://"):]
self._cached_server3_target = self._cached_server3.target[len("grpc://"):]
self._cached_server4_target = self._cached_server4.target[len("grpc://"):]
self.server_def_s1 = get_server_def(
JOB_NAME,
local_server_port=0,
remote_server_addresses=[self._cached_server1_target],
task_index=0)
self.server_def_s1_s2 = get_server_def(
JOB_NAME,
local_server_port=0,
remote_server_addresses=[
self._cached_server1_target, self._cached_server2_target
],
task_index=0)
self.server_def_s1_s3 = get_server_def(
JOB_NAME,
local_server_port=0,
remote_server_addresses=[
self._cached_server1_target, self._cached_server3_target
],
task_index=0)
self.server_def_s4_s3 = get_server_def(
JOB_NAME,
local_server_port=0,
remote_server_addresses=[
self._cached_server4_target, self._cached_server3_target
],
task_index=0)
self.server_def_s1_s2_s3 = get_server_def(
JOB_NAME,
local_server_port=0,
remote_server_addresses=[
self._cached_server1_target, self._cached_server2_target,
self._cached_server3_target
],
task_index=0)
self.device_local = "/job:%s/replica:0/task:0/device:CPU:0" % JOB_NAME
self.device_t1 = "/job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME
self.device_t2 = "/job:%s/replica:0/task:2/device:CPU:0" % JOB_NAME
self.device_t3 = "/job:%s/replica:0/task:3/device:CPU:0" % JOB_NAME
def setUp(self):
super(DynamicClusterTest, self).setUp()
local_port = pywrap_tfe.TF_PickUnusedPortOrDie()
context.set_server_def(
server_def=get_server_def(
JOB_NAME,
local_server_port=local_port,
remote_server_addresses=[
self._cached_server1_target, self._cached_server2_target
],
task_index=0))
def tearDown(self):
super(DynamicClusterTest, self).tearDown()
context._reset_context()
@test_util.run_in_async_and_sync_mode
def testServerAdded(self):
"""Add a server to cluster, and run remote ops on it."""
with ops.device(self.device_t1):
x1 = array_ops.ones([2, 2])
context.update_server_def(server_def=self.server_def_s1_s2_s3)
with ops.device(self.device_t3):
x2 = array_ops.ones([2, 2])
# Test new server accessing resources on old server
with ops.device(self.device_t3):
y = math_ops.matmul(x1, x2)
np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
# Test old server accessing resources on new server
with ops.device(self.device_t2):
y = math_ops.matmul(x1, x2)
np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
@test_util.run_in_async_and_sync_mode
def testServerRemoved(self):
"""Remove a server from cluster, and run ops on cluster."""
with ops.device(self.device_t1):
x1 = array_ops.ones([2, 2])
with ops.device(self.device_t2):
x2 = array_ops.ones([2, 2])
with ops.device(self.device_t1):
y = math_ops.matmul(x1, x2)
np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
context.update_server_def(server_def=self.server_def_s1)
with ops.device(self.device_t1):
y = math_ops.matmul(x1, x1)
np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
# Running ops on removed server s2 throws an exception
with self.assertRaises(errors.InvalidArgumentError) as cm:
with ops.device(self.device_t2):
y = math_ops.matmul(x1, x2)
self.assertIn("unknown device", cm.exception.message)
# TODO(haoyuzhang): raise and catch exception when accessing tensors on
# the removed servers.
@test_util.run_in_async_and_sync_mode
def testServerReplaced(self):
"""Replace remote host_port for a task, and run ops on cluster."""
with ops.device(self.device_t1):
x1 = array_ops.ones([2, 2])
context.update_server_def(server_def=self.server_def_s1_s3)
with ops.device(self.device_t2):
y = math_ops.matmul(x1, x1)
np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
@test_util.run_in_async_and_sync_mode
def testFunctionServerAdded(self):
"""Add a server to cluster, and run remote function on it."""
with ops.device(self.device_t1):
x1 = array_ops.ones([2, 2])
@def_function.function
def worker_fn(i):
return math_ops.matmul(i, i)
# Forces function tracing and registration
worker_fn.get_concrete_function(x1)
context.update_server_def(server_def=self.server_def_s1_s2_s3)
with ops.device(self.device_t3):
y = worker_fn(x1)
np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
with ops.device(self.device_t3):
x2 = array_ops.ones([2, 2])
with ops.device(self.device_t1):
y = worker_fn(x2)
np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
@test_util.run_in_async_and_sync_mode
def testFunctionServerRemoved(self):
"""Remove a server from cluster, and run ops on cluster."""
@def_function.function
def worker_fn(i):
return math_ops.matmul(i, i)
with ops.device(self.device_t1):
x1 = array_ops.ones([2, 2])
# Forces function tracing and registration
worker_fn.get_concrete_function(x1)
context.update_server_def(server_def=self.server_def_s1)
with ops.device(self.device_t1):
y = worker_fn(x1)
np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
# Running functions on removed server s2 throws an exception
with self.assertRaises(errors.InvalidArgumentError) as cm:
with ops.device(self.device_t2):
y = worker_fn(x1)
self.assertIn(" unknown device", cm.exception.message)
# TODO(haoyuzhang): raise and catch exception when accessing tensors on
# the removed servers.
@test_util.run_in_async_and_sync_mode
def testFunctionServerRemovedAddedBack(self):
"""Add and remove a server, and run functions on cluster."""
with ops.device(self.device_t1):
x1 = array_ops.ones([2, 2])
@def_function.function
def worker_fn(i):
return math_ops.matmul(i, i)
# Forces function tracing and registration
worker_fn.get_concrete_function(x1)
context.update_server_def(server_def=self.server_def_s1_s2_s3)
with ops.device(self.device_t3):
y = worker_fn(x1)
np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
context.update_server_def(server_def=self.server_def_s1_s2)
with ops.device(self.device_t2):
y = worker_fn(x1)
np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
context.update_server_def(server_def=self.server_def_s1_s2_s3)
with ops.device(self.device_t3):
y = worker_fn(x1)
np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
@test_util.run_in_async_and_sync_mode
def testFunctionServerReplaced(self):
"""Replace remote host_port for a task, and run functions on cluster."""
with ops.device(self.device_t1):
x1 = array_ops.ones([2, 2])
@def_function.function
def worker_fn(i):
return math_ops.matmul(i, i)
# Forces function tracing and registration
worker_fn.get_concrete_function(x1)
context.update_server_def(server_def=self.server_def_s1_s3)
with ops.device(self.device_t2):
y = worker_fn(x1)
np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
@test_util.run_in_async_and_sync_mode
def testPendingNodesServerReplaced(self):
"""Update cluster when nodes are still pending on remote workers."""
with ops.device(self.device_local):
x1 = array_ops.ones([2, 2])
@def_function.function
def worker_fn(i):
return math_ops.matmul(i, i)
# Forces function tracing and registration
worker_fn.get_concrete_function(x1)
# Add enough ops so they are pending when changing the cluster
num_nodes = 10
ret = [None] * num_nodes
for i in range(num_nodes):
with ops.device(self.device_t1):
ret[i] = worker_fn(x1)
# While nodes are still pending on worker s1, replace worker s2 with s3.
context.update_server_def(server_def=self.server_def_s1_s3)
with ops.device(self.device_t2):
y = worker_fn(x1)
for i in range(num_nodes):
np.testing.assert_array_equal([[2, 2], [2, 2]], ret[i].numpy())
np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
@test_util.run_in_async_and_sync_mode
def testMultiThreadPendingNodesServerReplaced(self):
"""Update cluster when other remote function calls are being launched."""
with ops.device(self.device_local):
x1 = array_ops.ones([2, 2])
num_calls = 10
lock = threading.Lock()
@def_function.function
def worker_fn(i):
return math_ops.matmul(i, i)
def thread_fn(device, results):
for i in range(num_calls):
lock.acquire()
with ops.device(device):
y = worker_fn(x1)
results[i] = y.numpy()
lock.release()
def update_server_def_fn():
for i in range(num_calls):
lock.acquire()
context.update_server_def(
server_def=(self.server_def_s1_s2 if i %
2 == 0 else self.server_def_s1_s3))
lock.release()
t1_results = [None] * num_calls
t2_results = [None] * num_calls
threads = []
threads.append(threading.Thread(target=thread_fn,
args=(self.device_t1, t1_results)))
threads.append(threading.Thread(target=thread_fn,
args=(self.device_t2, t2_results)))
threads.append(threading.Thread(target=update_server_def_fn))
for t in threads:
t.start()
for t in threads:
t.join()
for result in t1_results + t2_results:
np.testing.assert_array_equal([[2, 2], [2, 2]], result)
@test_util.run_in_async_and_sync_mode
def testMultiThreadPendingNodesLockFree(self):
"""Update cluster when other remote function calls are being launched."""
with ops.device(self.device_t1):
x1 = array_ops.ones([2, 2])
num_calls = 10
self._coord = coordinator.Coordinator()
@def_function.function
def worker_fn(i):
return math_ops.matmul(i, i)
def thread_fn(device, results):
for i in range(num_calls):
with self._coord.stop_on_exception():
with ops.device(device):
results[i] = worker_fn(x1).numpy()
def update_server_def_fn():
for _ in range(30):
with self._coord.stop_on_exception():
context.update_server_def(self.server_def_s1_s2)
t1_results = [None] * num_calls
t2_results = [None] * num_calls
threads = []
threads.append(
threading.Thread(target=thread_fn, args=(self.device_t1, t1_results)))
threads.append(
threading.Thread(target=thread_fn, args=(self.device_t2, t2_results)))
threads.append(threading.Thread(target=update_server_def_fn))
for t in threads:
t.start()
self._coord.join(threads)
for result in t1_results + t2_results:
np.testing.assert_array_equal([[2, 2], [2, 2]], result)
@test_util.run_in_async_and_sync_mode
def testDistributedFunctionServerAdded(self):
"""Add a server to cluster, and run distributed function on it."""
with ops.device(self.device_t1):
x1 = array_ops.ones([2, 2])
@def_function.function
def worker_fn(i):
with ops.device(self.device_t2):
mul = math_ops.matmul(i, i)
return mul - array_ops.zeros_like(mul)
# Forces function tracing and registration
worker_fn.get_concrete_function(x1)
context.update_server_def(server_def=self.server_def_s1_s2_s3)
with ops.device(self.device_t3):
y = worker_fn(x1)
np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
@test_util.run_in_async_and_sync_mode
def testDistributedFunctionServerRemovedAddedBack(self):
"""Add then remove a server, and run distributed function on cluster."""
with ops.device(self.device_local):
x1 = array_ops.ones([2, 2])
@def_function.function
def worker_fn(i):
with ops.device(self.device_t1):
mul = math_ops.matmul(i, i)
return mul - array_ops.zeros_like(mul)
# Forces function tracing and registration
worker_fn.get_concrete_function(x1)
context.update_server_def(server_def=self.server_def_s1)
with ops.device(self.device_t1):
y = worker_fn(x1)
np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
context.update_server_def(server_def=self.server_def_s1_s2)
with ops.device(self.device_t2):
y = worker_fn(x1)
np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
@test_util.run_in_async_and_sync_mode
def testDistributedFunctionBothServersReplaced(self):
"""Tests that replacing servers works correctly.
We create two servers, t1 and t2. We first replace t2, then we replace t1.
Among other things, this ensures that both already existing, and
restarted workers have the context view IDs correctly updated.
"""
with ops.device(self.device_local):
x1 = array_ops.ones([2, 2])
@def_function.function
def worker_fn(i):
with ops.device(self.device_t1):
mul = math_ops.matmul(i, i)
with ops.device(self.device_t2):
add = mul + i
return add - i
# Forces function tracing and registration
worker_fn.get_concrete_function(x1)
# Replace task2
context.update_server_def(server_def=self.server_def_s1_s3)
for device in (self.device_t1, self.device_t2):
with ops.device(device):
y = worker_fn(x1)
np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
# Then replace task1
context.update_server_def(server_def=self.server_def_s4_s3)
for device in (self.device_t1, self.device_t2):
with ops.device(device):
y = worker_fn(x1)
np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
def testCheckAlive(self):
with self.assertRaisesRegexp(ValueError, "Context is not initialized."):
context.check_alive("/job:remote_device/task:0")
context.context().ensure_initialized()
self.assertTrue(context.check_alive("/job:remote_device/replica:0/task:0"))
self.assertTrue(context.check_alive("/job:remote_device/replica:0/task:1"))
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
"Client for target /job:remote_device/replica:0/task:10 not found."):
context.check_alive("/job:remote_device/replica:0/task:10")
class DynamicClusterWithoutLazyRemoteInputsCopyTest(DynamicClusterTest):
@classmethod
def setUpClass(cls):
super(DynamicClusterWithoutLazyRemoteInputsCopyTest, cls).setUpClass()
context._reset_context()
context.context().lazy_remote_inputs_copy = False
@classmethod
def tearDownClass(cls):
super(DynamicClusterWithoutLazyRemoteInputsCopyTest, cls).tearDownClass()
context._reset_context()
context.context().lazy_remote_inputs_copy = True
if __name__ == "__main__":
ops.enable_eager_execution()
test.main()

View File: tensorflow/python/eager/remote_execution_test.py

@@ -0,0 +1,254 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for remote eager execution."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.eager import remote
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
JOB_NAME = "remote_device"
ALT_JOB_NAME = "alt_remote_device"
def get_server_def(job_name, local_server_port, remote_server_addresses,
task_index):
"""Returns a server def with a single job + multiple tasks."""
cluster_def = cluster_pb2.ClusterDef()
job_def = cluster_def.job.add()
job_def.name = job_name
job_def.tasks[0] = "localhost:%d" % local_server_port
for i, remote_server_address in enumerate(remote_server_addresses, start=1):
job_def.tasks[i] = remote_server_address
server_def = tensorflow_server_pb2.ServerDef(
cluster=cluster_def,
job_name=job_name,
task_index=task_index,
protocol="grpc")
return server_def
class RemoteExecutionTest(test.TestCase, parameterized.TestCase):
def __init__(self, methodName="runTest"): # pylint: disable=invalid-name
super(RemoteExecutionTest, self).__init__(methodName)
self._cached_server1 = server_lib.Server.create_local_server()
self._cached_server2 = server_lib.Server.create_local_server()
self._cached_server1_target = self._cached_server1.target[len("grpc://"):]
self._cached_server2_target = self._cached_server2.target[len("grpc://"):]
def setUp(self):
super(RemoteExecutionTest, self).setUp()
local_port = pywrap_tfe.TF_PickUnusedPortOrDie()
context.set_server_def(
server_def=get_server_def(
JOB_NAME,
local_server_port=local_port,
remote_server_addresses=[
self._cached_server1_target, self._cached_server2_target
],
task_index=0))
def tearDown(self):
super(RemoteExecutionTest, self).tearDown()
# Clear the current device scope and reset the context to avoid polluting
# other test cases.
ops.device(None).__enter__()
context._reset_context()
@test_util.run_gpu_only
@test_util.run_in_async_and_sync_mode
def testGpuToRemoteCopy(self):
with ops.device("gpu:0"):
x = array_ops.ones([2, 2])
with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME):
y = math_ops.matmul(x, x)
np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
@test_util.run_in_async_and_sync_mode
def testDefunMatmul(self):
"""Basic remote eager execution with defun."""
mm_defun = function.defun(math_ops.matmul)
with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME):
x1 = array_ops.ones([2, 2])
with ops.device("job:%s/replica:0/task:2/device:CPU:0" % JOB_NAME):
x2 = array_ops.ones([2, 2])
y = mm_defun(x1, x2)
np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
@test_util.run_in_async_and_sync_mode
def testSimpleMatmul(self):
"""Basic remote eager execution."""
with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME):
x1 = array_ops.ones([2, 2])
with ops.device("job:%s/replica:0/task:2/device:CPU:0" % JOB_NAME):
x2 = array_ops.ones([2, 2])
y = math_ops.matmul(x1, x2)
np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
def testEagerPyFuncPlacement(self):
if not ops.executing_eagerly_outside_functions():
return
def f(x):
return math_ops.square(x)
with ops.device("/job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME):
const_op = constant_op.constant(3.0, dtype=dtypes.float32)
# PyFuncOp should be placed on the localhost's address space.
py_func_op = script_ops.eager_py_func(
func=f, inp=[const_op], Tout=dtypes.float32)
self.assertEqual(py_func_op.device,
"/job:%s/replica:0/task:0/device:CPU:0" % JOB_NAME)
self.assertEqual(self.evaluate(py_func_op), 9.0)
@test_util.run_in_async_and_sync_mode
def testSimpleWeightRead(self):
"""Basic remote eager weight read."""
with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME):
w = resource_variable_ops.ResourceVariable([[2.0]])
loss = w * w
np.testing.assert_array_equal([[4.0]], loss.numpy())
@test_util.run_in_async_and_sync_mode
def testTapeWeightRead(self):
"""Remote eager weight read in a tape."""
with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME):
w = resource_variable_ops.ResourceVariable([[3.0]])
with backprop.GradientTape() as tape:
loss = w * w
grad = tape.gradient(loss, w)
np.testing.assert_array_equal([[9.0]], loss.numpy())
np.testing.assert_array_equal([[6.0]], grad.numpy())
@test_util.run_in_async_and_sync_mode
def testServerDefChanged(self):
"""Update server def, and run ops on new cluster."""
context.set_server_def(
server_def=get_server_def(
ALT_JOB_NAME,
local_server_port=0,
remote_server_addresses=[
self._cached_server1_target, self._cached_server2_target
],
task_index=0))
with ops.device("job:%s/replica:0/task:1/device:CPU:0" % ALT_JOB_NAME):
x1 = array_ops.ones([2, 2])
y = math_ops.matmul(x1, x1)
np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
# Set the server def back to JOB_NAME
context.set_server_def(
server_def=get_server_def(
JOB_NAME,
local_server_port=0,
remote_server_addresses=[
self._cached_server1_target, self._cached_server2_target
],
task_index=0))
with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME):
x1 = array_ops.ones([2, 2])
y = math_ops.matmul(x1, x1)
np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
@test_util.run_in_async_and_sync_mode
def testConnectToRemoteServer(self):
"""Basic server connection."""
context._reset_context()
remote.connect_to_remote_host(self._cached_server1_target)
with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
x1 = array_ops.ones([2, 2])
x2 = array_ops.ones([2, 2])
y = math_ops.matmul(x1, x2)
np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
@test_util.run_in_async_and_sync_mode
def testContextDeviceUpdated(self):
"""Tests that the context device is correctly updated."""
with ops.device("cpu:0"):
x1 = array_ops.ones([2, 2])
x2 = array_ops.ones([2, 2])
y = math_ops.matmul(x1, x2)
np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
# `y` is placed on the local CPU as expected.
self.assertEqual(y.device,
"/job:%s/replica:0/task:0/device:CPU:0" % JOB_NAME)
@test_util.run_gpu_only
@test_util.run_in_async_and_sync_mode
def testGPUToRemoteCopy(self):
"""Tests that the remote copy happens satisfactorily."""
x1 = array_ops.ones([2, 2]).gpu()
with ops.device("/job:remote_device/replica:0/task:1/device:CPU:0"):
x2 = x1._copy() # pylint: disable=protected-access
np.testing.assert_array_equal(x1.numpy(), x2.numpy())
class RemoteExecutionWithoutLazyRemoteInputsCopyTest(RemoteExecutionTest):
@classmethod
def setUpClass(cls):
super(RemoteExecutionWithoutLazyRemoteInputsCopyTest, cls).setUpClass()
context._reset_context()
context.context().lazy_remote_inputs_copy = False
@classmethod
def tearDownClass(cls):
super(RemoteExecutionWithoutLazyRemoteInputsCopyTest, cls).tearDownClass()
context._reset_context()
context.context().lazy_remote_inputs_copy = True
if __name__ == "__main__":
ops.enable_eager_execution()
test.main()

View File: tensorflow/python/framework/test_util.py

@@ -1017,6 +1017,21 @@ def build_as_function_and_v1_graph(func=None):
return decorator
def run_in_async_and_sync_mode(f):
"""Execute the test in async mode and sync mode."""
@parameterized.named_parameters([("Async", True), ("", False)])
@functools.wraps(f)
def decorator(self, async_mode, *args, **kwargs):
if async_mode:
with context.execution_mode(context.ASYNC):
f(self, *args, **kwargs)
else:
with context.execution_mode(context.SYNC):
f(self, *args, **kwargs)
return decorator
def eager_lazy_remote_copy_on_and_off(f):
"""Execute the test method w/o lazy tensor copy for function remote inputs."""