diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index 76dc22e17f7..6f444d46821 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -868,6 +868,43 @@ cuda_py_test(
     ],
 )
 
+cuda_py_test(
+    name = "remote_execution_test",
+    srcs = ["remote_execution_test.py"],
+    grpc_enabled = True,
+    python_version = "PY3",
+    shard_count = 2,
+    tags = [
+        "no_oss",  # This test launches a local server
+    ],
+    deps = [
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python/eager:remote",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
+
+cuda_py_test(
+    name = "remote_cluster_test",
+    srcs = ["remote_cluster_test.py"],
+    grpc_enabled = True,
+    python_version = "PY3",
+    shard_count = 16,
+    tags = [
+        "no_oss",  # This test launches a local server
+    ],
+    deps = [
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:math_ops",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
+
 tpu_py_test(
     name = "remote_cloud_tpu_test",
     srcs = ["remote_cloud_tpu_test.py"],
diff --git a/tensorflow/python/eager/remote_cluster_test.py b/tensorflow/python/eager/remote_cluster_test.py
new file mode 100644
index 00000000000..e26b99a8aa0
--- /dev/null
+++ b/tensorflow/python/eager/remote_cluster_test.py
@@ -0,0 +1,511 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================== +"""Tests for remote eager execution.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import threading + +from absl.testing import parameterized +import numpy as np + +from tensorflow.core.protobuf import cluster_pb2 +from tensorflow.core.protobuf import tensorflow_server_pb2 +from tensorflow.python import pywrap_tfe +from tensorflow.python.eager import context +from tensorflow.python.eager import def_function +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test +from tensorflow.python.training import coordinator +from tensorflow.python.training import server_lib + +JOB_NAME = "remote_device" + + +def get_server_def(job_name, local_server_port, remote_server_addresses, + task_index): + """Returns a server def with a single job + multiple tasks.""" + cluster_def = cluster_pb2.ClusterDef() + job_def = cluster_def.job.add() + job_def.name = job_name + job_def.tasks[0] = "localhost:%d" % local_server_port + + for i, remote_server_address in enumerate(remote_server_addresses, start=1): + job_def.tasks[i] = remote_server_address + + server_def = tensorflow_server_pb2.ServerDef( + cluster=cluster_def, + job_name=job_name, + task_index=task_index, + protocol="grpc") + + return server_def + + +class DynamicClusterTest(test.TestCase, parameterized.TestCase): + + def __init__(self, methodName="runTest"): # pylint: disable=invalid-name + super(DynamicClusterTest, self).__init__(methodName) + self._cached_server1 = server_lib.Server.create_local_server() + self._cached_server2 = server_lib.Server.create_local_server() + self._cached_server3 = server_lib.Server.create_local_server() + self._cached_server4 = server_lib.Server.create_local_server() + + self._cached_server1_target = self._cached_server1.target[len("grpc://"):] + self._cached_server2_target = self._cached_server2.target[len("grpc://"):] + self._cached_server3_target = self._cached_server3.target[len("grpc://"):] + self._cached_server4_target = self._cached_server4.target[len("grpc://"):] + + self.server_def_s1 = get_server_def( + JOB_NAME, + local_server_port=0, + remote_server_addresses=[self._cached_server1_target], + task_index=0) + self.server_def_s1_s2 = get_server_def( + JOB_NAME, + local_server_port=0, + remote_server_addresses=[ + self._cached_server1_target, self._cached_server2_target + ], + task_index=0) + self.server_def_s1_s3 = get_server_def( + JOB_NAME, + local_server_port=0, + remote_server_addresses=[ + self._cached_server1_target, self._cached_server3_target + ], + task_index=0) + self.server_def_s4_s3 = get_server_def( + JOB_NAME, + local_server_port=0, + remote_server_addresses=[ + self._cached_server4_target, self._cached_server3_target + ], + task_index=0) + self.server_def_s1_s2_s3 = get_server_def( + JOB_NAME, + local_server_port=0, + remote_server_addresses=[ + self._cached_server1_target, self._cached_server2_target, + self._cached_server3_target + ], + task_index=0) + + self.device_local = "/job:%s/replica:0/task:0/device:CPU:0" % JOB_NAME + self.device_t1 = "/job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME + self.device_t2 = "/job:%s/replica:0/task:2/device:CPU:0" % JOB_NAME + self.device_t3 = "/job:%s/replica:0/task:3/device:CPU:0" % 
JOB_NAME + + def setUp(self): + super(DynamicClusterTest, self).setUp() + local_port = pywrap_tfe.TF_PickUnusedPortOrDie() + context.set_server_def( + server_def=get_server_def( + JOB_NAME, + local_server_port=local_port, + remote_server_addresses=[ + self._cached_server1_target, self._cached_server2_target + ], + task_index=0)) + + def tearDown(self): + super(DynamicClusterTest, self).tearDown() + context._reset_context() + + @test_util.run_in_async_and_sync_mode + def testServerAdded(self): + """Add a server to cluster, and run remote ops on it.""" + with ops.device(self.device_t1): + x1 = array_ops.ones([2, 2]) + + context.update_server_def(server_def=self.server_def_s1_s2_s3) + with ops.device(self.device_t3): + x2 = array_ops.ones([2, 2]) + + # Test new server accessing resources on old server + with ops.device(self.device_t3): + y = math_ops.matmul(x1, x2) + np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + + # Test old server accessing resources on new server + with ops.device(self.device_t2): + y = math_ops.matmul(x1, x2) + np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + + @test_util.run_in_async_and_sync_mode + def testServerRemoved(self): + """Remove a server from cluster, and run ops on cluster.""" + with ops.device(self.device_t1): + x1 = array_ops.ones([2, 2]) + with ops.device(self.device_t2): + x2 = array_ops.ones([2, 2]) + + with ops.device(self.device_t1): + y = math_ops.matmul(x1, x2) + np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + + context.update_server_def(server_def=self.server_def_s1) + with ops.device(self.device_t1): + y = math_ops.matmul(x1, x1) + np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + + # Running ops on removed server s2 throws an exception + with self.assertRaises(errors.InvalidArgumentError) as cm: + with ops.device(self.device_t2): + y = math_ops.matmul(x1, x2) + self.assertIn("unknown device", cm.exception.message) + + # TODO(haoyuzhang): raise and catch exception when accessing tensors on + # the removed servers. 
+
+  @test_util.run_in_async_and_sync_mode
+  def testServerReplaced(self):
+    """Replace remote host_port for a task, and run ops on cluster."""
+    with ops.device(self.device_t1):
+      x1 = array_ops.ones([2, 2])
+
+    context.update_server_def(server_def=self.server_def_s1_s3)
+    with ops.device(self.device_t2):
+      y = math_ops.matmul(x1, x1)
+    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
+
+  @test_util.run_in_async_and_sync_mode
+  def testFunctionServerAdded(self):
+    """Add a server to cluster, and run remote function on it."""
+    with ops.device(self.device_t1):
+      x1 = array_ops.ones([2, 2])
+
+    @def_function.function
+    def worker_fn(i):
+      return math_ops.matmul(i, i)
+
+    # Forces function tracing and registration
+    worker_fn.get_concrete_function(x1)
+
+    context.update_server_def(server_def=self.server_def_s1_s2_s3)
+    with ops.device(self.device_t3):
+      y = worker_fn(x1)
+    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
+
+    with ops.device(self.device_t3):
+      x2 = array_ops.ones([2, 2])
+    with ops.device(self.device_t1):
+      y = worker_fn(x2)
+    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
+
+  @test_util.run_in_async_and_sync_mode
+  def testFunctionServerRemoved(self):
+    """Remove a server from cluster, and run ops on cluster."""
+
+    @def_function.function
+    def worker_fn(i):
+      return math_ops.matmul(i, i)
+
+    with ops.device(self.device_t1):
+      x1 = array_ops.ones([2, 2])
+
+    # Forces function tracing and registration
+    worker_fn.get_concrete_function(x1)
+
+    context.update_server_def(server_def=self.server_def_s1)
+
+    with ops.device(self.device_t1):
+      y = worker_fn(x1)
+    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
+
+    # Running functions on removed server s2 throws an exception
+    with self.assertRaises(errors.InvalidArgumentError) as cm:
+      with ops.device(self.device_t2):
+        y = worker_fn(x1)
+    self.assertIn("unknown device", cm.exception.message)
+
+    # TODO(haoyuzhang): raise and catch exception when accessing tensors on
+    # the removed servers.
+ + @test_util.run_in_async_and_sync_mode + def testFunctionServerRemovedAddedBack(self): + """Add and remove a server, and run functions on cluster.""" + with ops.device(self.device_t1): + x1 = array_ops.ones([2, 2]) + + @def_function.function + def worker_fn(i): + return math_ops.matmul(i, i) + + # Forces function tracing and registration + worker_fn.get_concrete_function(x1) + + context.update_server_def(server_def=self.server_def_s1_s2_s3) + with ops.device(self.device_t3): + y = worker_fn(x1) + np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + + context.update_server_def(server_def=self.server_def_s1_s2) + with ops.device(self.device_t2): + y = worker_fn(x1) + np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + + context.update_server_def(server_def=self.server_def_s1_s2_s3) + with ops.device(self.device_t3): + y = worker_fn(x1) + np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + + @test_util.run_in_async_and_sync_mode + def testFunctionServerReplaced(self): + """Replace remote host_port for a task, and run functions on cluster.""" + with ops.device(self.device_t1): + x1 = array_ops.ones([2, 2]) + + @def_function.function + def worker_fn(i): + return math_ops.matmul(i, i) + + # Forces function tracing and registration + worker_fn.get_concrete_function(x1) + + context.update_server_def(server_def=self.server_def_s1_s3) + with ops.device(self.device_t2): + y = worker_fn(x1) + np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + + @test_util.run_in_async_and_sync_mode + def testPendingNodesServerReplaced(self): + """Update cluster when nodes are still pending on remote workers.""" + with ops.device(self.device_local): + x1 = array_ops.ones([2, 2]) + + @def_function.function + def worker_fn(i): + return math_ops.matmul(i, i) + + # Forces function tracing and registration + worker_fn.get_concrete_function(x1) + + # Add enough ops so they are pending when changing the cluster + num_nodes = 10 + ret = [None] * num_nodes + for i in range(num_nodes): + with ops.device(self.device_t1): + ret[i] = worker_fn(x1) + # While nodes are still pending on worker s1, replace worker s2 with s3. 
+    context.update_server_def(server_def=self.server_def_s1_s3)
+    with ops.device(self.device_t2):
+      y = worker_fn(x1)
+    for i in range(num_nodes):
+      np.testing.assert_array_equal([[2, 2], [2, 2]], ret[i].numpy())
+    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
+
+  @test_util.run_in_async_and_sync_mode
+  def testMultiThreadPendingNodesServerReplaced(self):
+    """Update cluster when other remote function calls are being launched."""
+    with ops.device(self.device_local):
+      x1 = array_ops.ones([2, 2])
+
+    num_calls = 10
+    lock = threading.Lock()
+
+    @def_function.function
+    def worker_fn(i):
+      return math_ops.matmul(i, i)
+
+    def thread_fn(device, results):
+      for i in range(num_calls):
+        lock.acquire()
+        with ops.device(device):
+          y = worker_fn(x1)
+        results[i] = y.numpy()
+        lock.release()
+
+    def update_server_def_fn():
+      for i in range(num_calls):
+        lock.acquire()
+        context.update_server_def(
+            server_def=(self.server_def_s1_s2 if i %
+                        2 == 0 else self.server_def_s1_s3))
+        lock.release()
+
+    t1_results = [None] * num_calls
+    t2_results = [None] * num_calls
+    threads = []
+    threads.append(threading.Thread(target=thread_fn,
+                                    args=(self.device_t1, t1_results)))
+    threads.append(threading.Thread(target=thread_fn,
+                                    args=(self.device_t2, t2_results)))
+    threads.append(threading.Thread(target=update_server_def_fn))
+    for t in threads:
+      t.start()
+    for t in threads:
+      t.join()
+    for result in t1_results + t2_results:
+      np.testing.assert_array_equal([[2, 2], [2, 2]], result)
+
+  @test_util.run_in_async_and_sync_mode
+  def testMultiThreadPendingNodesLockFree(self):
+    """Update cluster when other remote function calls are being launched."""
+    with ops.device(self.device_t1):
+      x1 = array_ops.ones([2, 2])
+
+    num_calls = 10
+    self._coord = coordinator.Coordinator()
+
+    @def_function.function
+    def worker_fn(i):
+      return math_ops.matmul(i, i)
+
+    def thread_fn(device, results):
+      for i in range(num_calls):
+        with self._coord.stop_on_exception():
+          with ops.device(device):
+            results[i] = worker_fn(x1).numpy()
+
+    def update_server_def_fn():
+      for _ in range(30):
+        with self._coord.stop_on_exception():
+          context.update_server_def(self.server_def_s1_s2)
+
+    t1_results = [None] * num_calls
+    t2_results = [None] * num_calls
+    threads = []
+    threads.append(
+        threading.Thread(target=thread_fn, args=(self.device_t1, t1_results)))
+    threads.append(
+        threading.Thread(target=thread_fn, args=(self.device_t2, t2_results)))
+    threads.append(threading.Thread(target=update_server_def_fn))
+    for t in threads:
+      t.start()
+    self._coord.join(threads)
+    for result in t1_results + t2_results:
+      np.testing.assert_array_equal([[2, 2], [2, 2]], result)
+
+  @test_util.run_in_async_and_sync_mode
+  def testDistributedFunctionServerAdded(self):
+    """Add a server to cluster, and run distributed function on it."""
+    with ops.device(self.device_t1):
+      x1 = array_ops.ones([2, 2])
+
+    @def_function.function
+    def worker_fn(i):
+      with ops.device(self.device_t2):
+        mul = math_ops.matmul(i, i)
+      return mul - array_ops.zeros_like(mul)
+
+    # Forces function tracing and registration
+    worker_fn.get_concrete_function(x1)
+
+    context.update_server_def(server_def=self.server_def_s1_s2_s3)
+    with ops.device(self.device_t3):
+      y = worker_fn(x1)
+    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
+
+  @test_util.run_in_async_and_sync_mode
+  def testDistributedFunctionServerRemovedAddedBack(self):
+    """Remove and re-add a server, and run distributed function on cluster."""
+    with ops.device(self.device_local):
+      x1 =
array_ops.ones([2, 2]) + + @def_function.function + def worker_fn(i): + with ops.device(self.device_t1): + mul = math_ops.matmul(i, i) + return mul - array_ops.zeros_like(mul) + + # Forces function tracing and registration + worker_fn.get_concrete_function(x1) + + context.update_server_def(server_def=self.server_def_s1) + with ops.device(self.device_t1): + y = worker_fn(x1) + np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + + context.update_server_def(server_def=self.server_def_s1_s2) + with ops.device(self.device_t2): + y = worker_fn(x1) + np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + + @test_util.run_in_async_and_sync_mode + def testDistributedFunctionBothServersReplaced(self): + """Tests that replacing servers works correctly. + + We create two servers, t1 and t2. We first replace t2, then we replace t1. + + Among other things, this ensures that both already existing, and + restarted workers have the context view IDs correctly updated. + """ + with ops.device(self.device_local): + x1 = array_ops.ones([2, 2]) + + @def_function.function + def worker_fn(i): + with ops.device(self.device_t1): + mul = math_ops.matmul(i, i) + with ops.device(self.device_t2): + add = mul + i + return add - i + + # Forces function tracing and registration + worker_fn.get_concrete_function(x1) + + # Replace task2 + context.update_server_def(server_def=self.server_def_s1_s3) + for device in (self.device_t1, self.device_t2): + with ops.device(device): + y = worker_fn(x1) + np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + + # Then replace task1 + context.update_server_def(server_def=self.server_def_s4_s3) + for device in (self.device_t1, self.device_t2): + with ops.device(device): + y = worker_fn(x1) + np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + + def testCheckAlive(self): + with self.assertRaisesRegexp(ValueError, "Context is not initialized."): + context.check_alive("/job:remote_device/task:0") + context.context().ensure_initialized() + + self.assertTrue(context.check_alive("/job:remote_device/replica:0/task:0")) + self.assertTrue(context.check_alive("/job:remote_device/replica:0/task:1")) + + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "Client for target /job:remote_device/replica:0/task:10 not found."): + context.check_alive("/job:remote_device/replica:0/task:10") + + +class DynamicClusterWithoutLazyRemoteInputsCopyTest(DynamicClusterTest): + + @classmethod + def setUpClass(cls): + super(DynamicClusterWithoutLazyRemoteInputsCopyTest, cls).setUpClass() + context._reset_context() + context.context().lazy_remote_inputs_copy = False + + @classmethod + def tearDownClass(cls): + super(DynamicClusterWithoutLazyRemoteInputsCopyTest, cls).tearDownClass() + context._reset_context() + context.context().lazy_remote_inputs_copy = True + + +if __name__ == "__main__": + ops.enable_eager_execution() + test.main() diff --git a/tensorflow/python/eager/remote_execution_test.py b/tensorflow/python/eager/remote_execution_test.py new file mode 100644 index 00000000000..4ddd451a40a --- /dev/null +++ b/tensorflow/python/eager/remote_execution_test.py @@ -0,0 +1,254 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for remote eager execution.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.core.protobuf import cluster_pb2 +from tensorflow.core.protobuf import tensorflow_server_pb2 +from tensorflow.python import pywrap_tfe +from tensorflow.python.eager import backprop +from tensorflow.python.eager import context +from tensorflow.python.eager import function +from tensorflow.python.eager import remote +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import script_ops +from tensorflow.python.platform import test +from tensorflow.python.training import server_lib + +JOB_NAME = "remote_device" +ALT_JOB_NAME = "alt_remote_device" + + +def get_server_def(job_name, local_server_port, remote_server_addresses, + task_index): + """Returns a server def with a single job + multiple tasks.""" + cluster_def = cluster_pb2.ClusterDef() + job_def = cluster_def.job.add() + job_def.name = job_name + job_def.tasks[0] = "localhost:%d" % local_server_port + + for i, remote_server_address in enumerate(remote_server_addresses, start=1): + job_def.tasks[i] = remote_server_address + + server_def = tensorflow_server_pb2.ServerDef( + cluster=cluster_def, + job_name=job_name, + task_index=task_index, + protocol="grpc") + + return server_def + + +class RemoteExecutionTest(test.TestCase, parameterized.TestCase): + + def __init__(self, methodName="runTest"): # pylint: disable=invalid-name + super(RemoteExecutionTest, self).__init__(methodName) + self._cached_server1 = server_lib.Server.create_local_server() + self._cached_server2 = server_lib.Server.create_local_server() + + self._cached_server1_target = self._cached_server1.target[len("grpc://"):] + self._cached_server2_target = self._cached_server2.target[len("grpc://"):] + + def setUp(self): + super(RemoteExecutionTest, self).setUp() + local_port = pywrap_tfe.TF_PickUnusedPortOrDie() + context.set_server_def( + server_def=get_server_def( + JOB_NAME, + local_server_port=local_port, + remote_server_addresses=[ + self._cached_server1_target, self._cached_server2_target + ], + task_index=0)) + + def tearDown(self): + super(RemoteExecutionTest, self).tearDown() + + # Clear the current device scope and reset the context to avoid polluting + # other test cases. 
+ ops.device(None).__enter__() + context._reset_context() + + @test_util.run_gpu_only + @test_util.run_in_async_and_sync_mode + def testGpuToRemoteCopy(self): + with ops.device("gpu:0"): + x = array_ops.ones([2, 2]) + with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME): + y = math_ops.matmul(x, x) + + np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + + @test_util.run_in_async_and_sync_mode + def testDefunMatmul(self): + """Basic remote eager execution with defun.""" + + mm_defun = function.defun(math_ops.matmul) + with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME): + x1 = array_ops.ones([2, 2]) + with ops.device("job:%s/replica:0/task:2/device:CPU:0" % JOB_NAME): + x2 = array_ops.ones([2, 2]) + y = mm_defun(x1, x2) + np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + + @test_util.run_in_async_and_sync_mode + def testSimpleMatmul(self): + """Basic remote eager execution.""" + + with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME): + x1 = array_ops.ones([2, 2]) + with ops.device("job:%s/replica:0/task:2/device:CPU:0" % JOB_NAME): + x2 = array_ops.ones([2, 2]) + y = math_ops.matmul(x1, x2) + np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + + def testEagerPyFuncPlacement(self): + if not ops.executing_eagerly_outside_functions(): + return + + def f(x): + return math_ops.square(x) + + with ops.device("/job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME): + const_op = constant_op.constant(3.0, dtype=dtypes.float32) + # PyFuncOp should be placed on the localhost's address space. + py_func_op = script_ops.eager_py_func( + func=f, inp=[const_op], Tout=dtypes.float32) + self.assertEqual(py_func_op.device, + "/job:%s/replica:0/task:0/device:CPU:0" % JOB_NAME) + self.assertEqual(self.evaluate(py_func_op), 9.0) + + @test_util.run_in_async_and_sync_mode + def testSimpleWeightRead(self): + """Basic remote eager weight read.""" + + with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME): + w = resource_variable_ops.ResourceVariable([[2.0]]) + loss = w * w + np.testing.assert_array_equal([[4.0]], loss.numpy()) + + @test_util.run_in_async_and_sync_mode + def testTapeWeightRead(self): + """Remote eager weight read in a tape.""" + + with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME): + w = resource_variable_ops.ResourceVariable([[3.0]]) + with backprop.GradientTape() as tape: + loss = w * w + + grad = tape.gradient(loss, w) + np.testing.assert_array_equal([[9.0]], loss.numpy()) + np.testing.assert_array_equal([[6.0]], grad.numpy()) + + @test_util.run_in_async_and_sync_mode + def testServerDefChanged(self): + """Update server def, and run ops on new cluster.""" + context.set_server_def( + server_def=get_server_def( + ALT_JOB_NAME, + local_server_port=0, + remote_server_addresses=[ + self._cached_server1_target, self._cached_server2_target + ], + task_index=0)) + + with ops.device("job:%s/replica:0/task:1/device:CPU:0" % ALT_JOB_NAME): + x1 = array_ops.ones([2, 2]) + y = math_ops.matmul(x1, x1) + np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + + # Set the server def back to JOB_NAME + context.set_server_def( + server_def=get_server_def( + JOB_NAME, + local_server_port=0, + remote_server_addresses=[ + self._cached_server1_target, self._cached_server2_target + ], + task_index=0)) + + with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME): + x1 = array_ops.ones([2, 2]) + y = math_ops.matmul(x1, x1) + np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + + 
@test_util.run_in_async_and_sync_mode + def testConnectToRemoteServer(self): + """Basic server connection.""" + context._reset_context() + remote.connect_to_remote_host(self._cached_server1_target) + + with ops.device("job:worker/replica:0/task:0/device:CPU:0"): + x1 = array_ops.ones([2, 2]) + x2 = array_ops.ones([2, 2]) + y = math_ops.matmul(x1, x2) + np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + + @test_util.run_in_async_and_sync_mode + def testContextDeviceUpdated(self): + """Tests that the context device is correctly updated.""" + + with ops.device("cpu:0"): + x1 = array_ops.ones([2, 2]) + x2 = array_ops.ones([2, 2]) + y = math_ops.matmul(x1, x2) + np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + + # `y` is placed on the local CPU as expected. + self.assertEqual(y.device, + "/job:%s/replica:0/task:0/device:CPU:0" % JOB_NAME) + + @test_util.run_gpu_only + @test_util.run_in_async_and_sync_mode + def testGPUToRemoteCopy(self): + """Tests that the remote copy happens satisfactorily.""" + x1 = array_ops.ones([2, 2]).gpu() + + with ops.device("/job:remote_device/replica:0/task:1/device:CPU:0"): + x2 = x1._copy() # pylint: disable=protected-access + + np.testing.assert_array_equal(x1.numpy(), x2.numpy()) + + +class RemoteExecutionWithoutLazyRemoteInputsCopyTest(RemoteExecutionTest): + + @classmethod + def setUpClass(cls): + super(RemoteExecutionWithoutLazyRemoteInputsCopyTest, cls).setUpClass() + context._reset_context() + context.context().lazy_remote_inputs_copy = False + + @classmethod + def tearDownClass(cls): + super(RemoteExecutionWithoutLazyRemoteInputsCopyTest, cls).tearDownClass() + context._reset_context() + context.context().lazy_remote_inputs_copy = True + + +if __name__ == "__main__": + ops.enable_eager_execution() + test.main() diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 84651dec152..1f6ac5b654f 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -1017,6 +1017,21 @@ def build_as_function_and_v1_graph(func=None): return decorator +def run_in_async_and_sync_mode(f): + """Execute the test in async mode and sync mode.""" + + @parameterized.named_parameters([("Async", True), ("", False)]) + @functools.wraps(f) + def decorator(self, async_mode, *args, **kwargs): + if async_mode: + with context.execution_mode(context.ASYNC): + f(self, *args, **kwargs) + else: + with context.execution_mode(context.SYNC): + f(self, *args, **kwargs) + return decorator + + def eager_lazy_remote_copy_on_and_off(f): """Execute the test method w/o lazy tensor copy for function remote inputs."""