Distributed runtime test cleanup.
(1) Moved most of contrib/eager/python/remote_test.py out of contrib
(2) Fixed asan/msan/tsan flakiness for DynamicClusterTest
(3) Improved coverage of async+sync and lazy_remote_input_copy on+off combinations

PiperOrigin-RevId: 299390689
Change-Id: I9e24c646fca3feedec56b16a2e891d357398272e
parent 0531deb3ad · commit 950b054440
@@ -868,6 +868,43 @@ cuda_py_test(
    ],
)

cuda_py_test(
    name = "remote_execution_test",
    srcs = ["remote_execution_test.py"],
    grpc_enabled = True,
    python_version = "PY3",
    shard_count = 2,
    tags = [
        "no_oss",  # This test launches local server
    ],
    deps = [
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:framework",
        "//tensorflow/python:math_ops",
        "//tensorflow/python/eager:remote",
        "@absl_py//absl/testing:parameterized",
    ],
)

cuda_py_test(
    name = "remote_cluster_test",
    srcs = ["remote_cluster_test.py"],
    grpc_enabled = True,
    python_version = "PY3",
    shard_count = 16,
    tags = [
        "no_oss",  # This test launches local server
    ],
    deps = [
        "//tensorflow/python:array_ops",
        "//tensorflow/python:framework",
        "//tensorflow/python:math_ops",
        "@absl_py//absl/testing:parameterized",
    ],
)

tpu_py_test(
    name = "remote_cloud_tpu_test",
    srcs = ["remote_cloud_tpu_test.py"],
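(The two new cuda_py_test targets register the tests relocated out of contrib/eager. shard_count lets Bazel split each suite across parallel test shards, 16 for the heavier remote_cluster_test, and the no_oss tag keeps these tests out of the OSS suite since, as the inline comments note, they launch local servers.)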
tensorflow/python/eager/remote_cluster_test.py (new file, 511 lines)
@@ -0,0 +1,511 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for remote eager execution."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading

from absl.testing import parameterized
import numpy as np

from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.training import coordinator
from tensorflow.python.training import server_lib

JOB_NAME = "remote_device"


def get_server_def(job_name, local_server_port, remote_server_addresses,
                   task_index):
  """Returns a server def with a single job + multiple tasks."""
  cluster_def = cluster_pb2.ClusterDef()
  job_def = cluster_def.job.add()
  job_def.name = job_name
  job_def.tasks[0] = "localhost:%d" % local_server_port

  for i, remote_server_address in enumerate(remote_server_addresses, start=1):
    job_def.tasks[i] = remote_server_address

  server_def = tensorflow_server_pb2.ServerDef(
      cluster=cluster_def,
      job_name=job_name,
      task_index=task_index,
      protocol="grpc")

  return server_def
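# Usage sketch (illustrative values, not part of the change): with
# local_server_port=0 and one remote address, get_server_def yields a
# single-job cluster in which task 0 is this client and tasks 1..N are the
# remote workers:
#
#   sd = get_server_def(JOB_NAME, local_server_port=0,
#                       remote_server_addresses=["localhost:12345"],
#                       task_index=0)
#   # sd.cluster.job[0].tasks == {0: "localhost:0", 1: "localhost:12345"}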
class DynamicClusterTest(test.TestCase, parameterized.TestCase):

  def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
    super(DynamicClusterTest, self).__init__(methodName)
    self._cached_server1 = server_lib.Server.create_local_server()
    self._cached_server2 = server_lib.Server.create_local_server()
    self._cached_server3 = server_lib.Server.create_local_server()
    self._cached_server4 = server_lib.Server.create_local_server()

    self._cached_server1_target = self._cached_server1.target[len("grpc://"):]
    self._cached_server2_target = self._cached_server2.target[len("grpc://"):]
    self._cached_server3_target = self._cached_server3.target[len("grpc://"):]
    self._cached_server4_target = self._cached_server4.target[len("grpc://"):]

    self.server_def_s1 = get_server_def(
        JOB_NAME,
        local_server_port=0,
        remote_server_addresses=[self._cached_server1_target],
        task_index=0)
    self.server_def_s1_s2 = get_server_def(
        JOB_NAME,
        local_server_port=0,
        remote_server_addresses=[
            self._cached_server1_target, self._cached_server2_target
        ],
        task_index=0)
    self.server_def_s1_s3 = get_server_def(
        JOB_NAME,
        local_server_port=0,
        remote_server_addresses=[
            self._cached_server1_target, self._cached_server3_target
        ],
        task_index=0)
    self.server_def_s4_s3 = get_server_def(
        JOB_NAME,
        local_server_port=0,
        remote_server_addresses=[
            self._cached_server4_target, self._cached_server3_target
        ],
        task_index=0)
    self.server_def_s1_s2_s3 = get_server_def(
        JOB_NAME,
        local_server_port=0,
        remote_server_addresses=[
            self._cached_server1_target, self._cached_server2_target,
            self._cached_server3_target
        ],
        task_index=0)

    self.device_local = "/job:%s/replica:0/task:0/device:CPU:0" % JOB_NAME
    self.device_t1 = "/job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME
    self.device_t2 = "/job:%s/replica:0/task:2/device:CPU:0" % JOB_NAME
    self.device_t3 = "/job:%s/replica:0/task:3/device:CPU:0" % JOB_NAME

  def setUp(self):
    super(DynamicClusterTest, self).setUp()
    local_port = pywrap_tfe.TF_PickUnusedPortOrDie()
    context.set_server_def(
        server_def=get_server_def(
            JOB_NAME,
            local_server_port=local_port,
            remote_server_addresses=[
                self._cached_server1_target, self._cached_server2_target
            ],
            task_index=0))

  def tearDown(self):
    super(DynamicClusterTest, self).tearDown()
    context._reset_context()
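  # Note: context._reset_context() is a private, test-only helper; tearing
  # down the eager context here ensures each test's setUp starts from a fresh
  # server_def rather than whatever cluster state a previous test left behind.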
  @test_util.run_in_async_and_sync_mode
  def testServerAdded(self):
    """Add a server to cluster, and run remote ops on it."""
    with ops.device(self.device_t1):
      x1 = array_ops.ones([2, 2])

    context.update_server_def(server_def=self.server_def_s1_s2_s3)
    with ops.device(self.device_t3):
      x2 = array_ops.ones([2, 2])

    # Test new server accessing resources on old server
    with ops.device(self.device_t3):
      y = math_ops.matmul(x1, x2)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

    # Test old server accessing resources on new server
    with ops.device(self.device_t2):
      y = math_ops.matmul(x1, x2)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

  @test_util.run_in_async_and_sync_mode
  def testServerRemoved(self):
    """Remove a server from cluster, and run ops on cluster."""
    with ops.device(self.device_t1):
      x1 = array_ops.ones([2, 2])
    with ops.device(self.device_t2):
      x2 = array_ops.ones([2, 2])

    with ops.device(self.device_t1):
      y = math_ops.matmul(x1, x2)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

    context.update_server_def(server_def=self.server_def_s1)
    with ops.device(self.device_t1):
      y = math_ops.matmul(x1, x1)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

    # Running ops on removed server s2 throws an exception
    with self.assertRaises(errors.InvalidArgumentError) as cm:
      with ops.device(self.device_t2):
        y = math_ops.matmul(x1, x2)
    self.assertIn("unknown device", cm.exception.message)

    # TODO(haoyuzhang): raise and catch exception when accessing tensors on
    # the removed servers.

  @test_util.run_in_async_and_sync_mode
  def testServerReplaced(self):
    """Replace remote host_port for a task, and run ops on cluster."""
    with ops.device(self.device_t1):
      x1 = array_ops.ones([2, 2])

    context.update_server_def(server_def=self.server_def_s1_s3)
    with ops.device(self.device_t2):
      y = math_ops.matmul(x1, x1)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

  @test_util.run_in_async_and_sync_mode
  def testFunctionServerAdded(self):
    """Add a server to cluster, and run remote function on it."""
    with ops.device(self.device_t1):
      x1 = array_ops.ones([2, 2])

    @def_function.function
    def worker_fn(i):
      return math_ops.matmul(i, i)

    # Forces function tracing and registration
    worker_fn.get_concrete_function(x1)

    context.update_server_def(server_def=self.server_def_s1_s2_s3)
    with ops.device(self.device_t3):
      y = worker_fn(x1)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

    with ops.device(self.device_t3):
      x2 = array_ops.ones([2, 2])
    with ops.device(self.device_t1):
      y = worker_fn(x2)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

  @test_util.run_in_async_and_sync_mode
  def testFunctionServerRemoved(self):
    """Remove a server from cluster, and run ops on cluster."""

    @def_function.function
    def worker_fn(i):
      return math_ops.matmul(i, i)

    with ops.device(self.device_t1):
      x1 = array_ops.ones([2, 2])

    # Forces function tracing and registration
    worker_fn.get_concrete_function(x1)

    context.update_server_def(server_def=self.server_def_s1)

    with ops.device(self.device_t1):
      y = worker_fn(x1)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

    # Running functions on removed server s2 throws an exception
    with self.assertRaises(errors.InvalidArgumentError) as cm:
      with ops.device(self.device_t2):
        y = worker_fn(x1)
    self.assertIn("unknown device", cm.exception.message)

    # TODO(haoyuzhang): raise and catch exception when accessing tensors on
    # the removed servers.

  @test_util.run_in_async_and_sync_mode
  def testFunctionServerRemovedAddedBack(self):
    """Add and remove a server, and run functions on cluster."""
    with ops.device(self.device_t1):
      x1 = array_ops.ones([2, 2])

    @def_function.function
    def worker_fn(i):
      return math_ops.matmul(i, i)

    # Forces function tracing and registration
    worker_fn.get_concrete_function(x1)

    context.update_server_def(server_def=self.server_def_s1_s2_s3)
    with ops.device(self.device_t3):
      y = worker_fn(x1)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

    context.update_server_def(server_def=self.server_def_s1_s2)
    with ops.device(self.device_t2):
      y = worker_fn(x1)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

    context.update_server_def(server_def=self.server_def_s1_s2_s3)
    with ops.device(self.device_t3):
      y = worker_fn(x1)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

  @test_util.run_in_async_and_sync_mode
  def testFunctionServerReplaced(self):
    """Replace remote host_port for a task, and run functions on cluster."""
    with ops.device(self.device_t1):
      x1 = array_ops.ones([2, 2])

    @def_function.function
    def worker_fn(i):
      return math_ops.matmul(i, i)

    # Forces function tracing and registration
    worker_fn.get_concrete_function(x1)

    context.update_server_def(server_def=self.server_def_s1_s3)
    with ops.device(self.device_t2):
      y = worker_fn(x1)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

  @test_util.run_in_async_and_sync_mode
  def testPendingNodesServerReplaced(self):
    """Update cluster when nodes are still pending on remote workers."""
    with ops.device(self.device_local):
      x1 = array_ops.ones([2, 2])

    @def_function.function
    def worker_fn(i):
      return math_ops.matmul(i, i)

    # Forces function tracing and registration
    worker_fn.get_concrete_function(x1)

    # Add enough ops so they are pending when changing the cluster
    num_nodes = 10
    ret = [None] * num_nodes
    for i in range(num_nodes):
      with ops.device(self.device_t1):
        ret[i] = worker_fn(x1)
    # While nodes are still pending on worker s1, replace worker s2 with s3.
    context.update_server_def(server_def=self.server_def_s1_s3)
    with ops.device(self.device_t2):
      y = worker_fn(x1)
    for i in range(num_nodes):
      np.testing.assert_array_equal([[2, 2], [2, 2]], ret[i].numpy())
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

  @test_util.run_in_async_and_sync_mode
  def testMultiThreadPendingNodesServerReplaced(self):
    """Update cluster when other remote function calls are being launched."""
    with ops.device(self.device_local):
      x1 = array_ops.ones([2, 2])

    num_calls = 10
    lock = threading.Lock()

    @def_function.function
    def worker_fn(i):
      return math_ops.matmul(i, i)

    def thread_fn(device, results):
      for i in range(num_calls):
        lock.acquire()
        with ops.device(device):
          y = worker_fn(x1)
        results[i] = y.numpy()
        lock.release()

    def update_server_def_fn():
      for i in range(num_calls):
        lock.acquire()
        context.update_server_def(
            server_def=(self.server_def_s1_s2 if i % 2 == 0
                        else self.server_def_s1_s3))
        lock.release()

    t1_results = [None] * num_calls
    t2_results = [None] * num_calls
    threads = []
    threads.append(threading.Thread(target=thread_fn,
                                    args=(self.device_t1, t1_results)))
    threads.append(threading.Thread(target=thread_fn,
                                    args=(self.device_t2, t2_results)))
    threads.append(threading.Thread(target=update_server_def_fn))
    for t in threads:
      t.start()
    for t in threads:
      t.join()
    for result in t1_results + t2_results:
      np.testing.assert_array_equal([[2, 2], [2, 2]], result)

  @test_util.run_in_async_and_sync_mode
  def testMultiThreadPendingNodesLockFree(self):
    """Update cluster when other remote function calls are being launched."""
    with ops.device(self.device_t1):
      x1 = array_ops.ones([2, 2])

    num_calls = 10
    self._coord = coordinator.Coordinator()

    @def_function.function
    def worker_fn(i):
      return math_ops.matmul(i, i)

    def thread_fn(device, results):
      for i in range(num_calls):
        with self._coord.stop_on_exception():
          with ops.device(device):
            results[i] = worker_fn(x1).numpy()

    def update_server_def_fn():
      for _ in range(30):
        with self._coord.stop_on_exception():
          context.update_server_def(self.server_def_s1_s2)

    t1_results = [None] * num_calls
    t2_results = [None] * num_calls
    threads = []
    threads.append(
        threading.Thread(target=thread_fn, args=(self.device_t1, t1_results)))
    threads.append(
        threading.Thread(target=thread_fn, args=(self.device_t2, t2_results)))
    threads.append(threading.Thread(target=update_server_def_fn))
    for t in threads:
      t.start()
    self._coord.join(threads)
    for result in t1_results + t2_results:
      np.testing.assert_array_equal([[2, 2], [2, 2]], result)

  @test_util.run_in_async_and_sync_mode
  def testDistributedFunctionServerAdded(self):
    """Add a server to cluster, and run distributed function on it."""
    with ops.device(self.device_t1):
      x1 = array_ops.ones([2, 2])

    @def_function.function
    def worker_fn(i):
      with ops.device(self.device_t2):
        mul = math_ops.matmul(i, i)
      return mul - array_ops.zeros_like(mul)

    # Forces function tracing and registration
    worker_fn.get_concrete_function(x1)

    context.update_server_def(server_def=self.server_def_s1_s2_s3)
    with ops.device(self.device_t3):
      y = worker_fn(x1)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

  @test_util.run_in_async_and_sync_mode
  def testDistributedFunctionServerRemovedAddedBack(self):
    """Add then remove a server, and run distributed function on cluster."""
    with ops.device(self.device_local):
      x1 = array_ops.ones([2, 2])

    @def_function.function
    def worker_fn(i):
      with ops.device(self.device_t1):
        mul = math_ops.matmul(i, i)
      return mul - array_ops.zeros_like(mul)

    # Forces function tracing and registration
    worker_fn.get_concrete_function(x1)

    context.update_server_def(server_def=self.server_def_s1)
    with ops.device(self.device_t1):
      y = worker_fn(x1)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

    context.update_server_def(server_def=self.server_def_s1_s2)
    with ops.device(self.device_t2):
      y = worker_fn(x1)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

  @test_util.run_in_async_and_sync_mode
  def testDistributedFunctionBothServersReplaced(self):
    """Tests that replacing servers works correctly.

    We create two servers, t1 and t2. We first replace t2, then we replace t1.

    Among other things, this ensures that both already existing, and
    restarted workers have the context view IDs correctly updated.
    """
    with ops.device(self.device_local):
      x1 = array_ops.ones([2, 2])

    @def_function.function
    def worker_fn(i):
      with ops.device(self.device_t1):
        mul = math_ops.matmul(i, i)
      with ops.device(self.device_t2):
        add = mul + i
      return add - i

    # Forces function tracing and registration
    worker_fn.get_concrete_function(x1)

    # Replace task2
    context.update_server_def(server_def=self.server_def_s1_s3)
    for device in (self.device_t1, self.device_t2):
      with ops.device(device):
        y = worker_fn(x1)
      np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

    # Then replace task1
    context.update_server_def(server_def=self.server_def_s4_s3)
    for device in (self.device_t1, self.device_t2):
      with ops.device(device):
        y = worker_fn(x1)
      np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

  def testCheckAlive(self):
    with self.assertRaisesRegexp(ValueError, "Context is not initialized."):
      context.check_alive("/job:remote_device/task:0")
    context.context().ensure_initialized()

    self.assertTrue(context.check_alive("/job:remote_device/replica:0/task:0"))
    self.assertTrue(context.check_alive("/job:remote_device/replica:0/task:1"))

    with self.assertRaisesRegexp(
        errors.InvalidArgumentError,
        "Client for target /job:remote_device/replica:0/task:10 not found."):
      context.check_alive("/job:remote_device/replica:0/task:10")
class DynamicClusterWithoutLazyRemoteInputsCopyTest(DynamicClusterTest):

  @classmethod
  def setUpClass(cls):
    super(DynamicClusterWithoutLazyRemoteInputsCopyTest, cls).setUpClass()
    context._reset_context()
    context.context().lazy_remote_inputs_copy = False

  @classmethod
  def tearDownClass(cls):
    super(DynamicClusterWithoutLazyRemoteInputsCopyTest, cls).tearDownClass()
    context._reset_context()
    context.context().lazy_remote_inputs_copy = True
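# Subclassing reruns the entire DynamicClusterTest suite with
# lazy_remote_inputs_copy disabled; together with the run_in_async_and_sync_mode
# decorator on each test, this is what produces the async+sync and
# lazy_remote_input_copy on+off coverage described in the commit message.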
if __name__ == "__main__":
  ops.enable_eager_execution()
  test.main()
tensorflow/python/eager/remote_execution_test.py (new file, 254 lines)
@@ -0,0 +1,254 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for remote eager execution."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np

from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.eager import remote
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib

JOB_NAME = "remote_device"
ALT_JOB_NAME = "alt_remote_device"


def get_server_def(job_name, local_server_port, remote_server_addresses,
                   task_index):
  """Returns a server def with a single job + multiple tasks."""
  cluster_def = cluster_pb2.ClusterDef()
  job_def = cluster_def.job.add()
  job_def.name = job_name
  job_def.tasks[0] = "localhost:%d" % local_server_port

  for i, remote_server_address in enumerate(remote_server_addresses, start=1):
    job_def.tasks[i] = remote_server_address

  server_def = tensorflow_server_pb2.ServerDef(
      cluster=cluster_def,
      job_name=job_name,
      task_index=task_index,
      protocol="grpc")

  return server_def


class RemoteExecutionTest(test.TestCase, parameterized.TestCase):

  def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
    super(RemoteExecutionTest, self).__init__(methodName)
    self._cached_server1 = server_lib.Server.create_local_server()
    self._cached_server2 = server_lib.Server.create_local_server()

    self._cached_server1_target = self._cached_server1.target[len("grpc://"):]
    self._cached_server2_target = self._cached_server2.target[len("grpc://"):]

  def setUp(self):
    super(RemoteExecutionTest, self).setUp()
    local_port = pywrap_tfe.TF_PickUnusedPortOrDie()
    context.set_server_def(
        server_def=get_server_def(
            JOB_NAME,
            local_server_port=local_port,
            remote_server_addresses=[
                self._cached_server1_target, self._cached_server2_target
            ],
            task_index=0))

  def tearDown(self):
    super(RemoteExecutionTest, self).tearDown()

    # Clear the current device scope and reset the context to avoid polluting
    # other test cases.
    ops.device(None).__enter__()
    context._reset_context()

  @test_util.run_gpu_only
  @test_util.run_in_async_and_sync_mode
  def testGpuToRemoteCopy(self):
    with ops.device("gpu:0"):
      x = array_ops.ones([2, 2])
    with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME):
      y = math_ops.matmul(x, x)

    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

  @test_util.run_in_async_and_sync_mode
  def testDefunMatmul(self):
    """Basic remote eager execution with defun."""

    mm_defun = function.defun(math_ops.matmul)
    with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME):
      x1 = array_ops.ones([2, 2])
    with ops.device("job:%s/replica:0/task:2/device:CPU:0" % JOB_NAME):
      x2 = array_ops.ones([2, 2])
      y = mm_defun(x1, x2)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

  @test_util.run_in_async_and_sync_mode
  def testSimpleMatmul(self):
    """Basic remote eager execution."""

    with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME):
      x1 = array_ops.ones([2, 2])
    with ops.device("job:%s/replica:0/task:2/device:CPU:0" % JOB_NAME):
      x2 = array_ops.ones([2, 2])
      y = math_ops.matmul(x1, x2)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

  def testEagerPyFuncPlacement(self):
    if not ops.executing_eagerly_outside_functions():
      return

    def f(x):
      return math_ops.square(x)

    with ops.device("/job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME):
      const_op = constant_op.constant(3.0, dtype=dtypes.float32)
      # PyFuncOp should be placed on the localhost's address space.
      py_func_op = script_ops.eager_py_func(
          func=f, inp=[const_op], Tout=dtypes.float32)
      self.assertEqual(py_func_op.device,
                       "/job:%s/replica:0/task:0/device:CPU:0" % JOB_NAME)
      self.assertEqual(self.evaluate(py_func_op), 9.0)

  @test_util.run_in_async_and_sync_mode
  def testSimpleWeightRead(self):
    """Basic remote eager weight read."""

    with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME):
      w = resource_variable_ops.ResourceVariable([[2.0]])
      loss = w * w
    np.testing.assert_array_equal([[4.0]], loss.numpy())

  @test_util.run_in_async_and_sync_mode
  def testTapeWeightRead(self):
    """Remote eager weight read in a tape."""

    with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME):
      w = resource_variable_ops.ResourceVariable([[3.0]])
      with backprop.GradientTape() as tape:
        loss = w * w

      grad = tape.gradient(loss, w)
    np.testing.assert_array_equal([[9.0]], loss.numpy())
    np.testing.assert_array_equal([[6.0]], grad.numpy())

  @test_util.run_in_async_and_sync_mode
  def testServerDefChanged(self):
    """Update server def, and run ops on new cluster."""
    context.set_server_def(
        server_def=get_server_def(
            ALT_JOB_NAME,
            local_server_port=0,
            remote_server_addresses=[
                self._cached_server1_target, self._cached_server2_target
            ],
            task_index=0))

    with ops.device("job:%s/replica:0/task:1/device:CPU:0" % ALT_JOB_NAME):
      x1 = array_ops.ones([2, 2])
      y = math_ops.matmul(x1, x1)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

    # Set the server def back to JOB_NAME
    context.set_server_def(
        server_def=get_server_def(
            JOB_NAME,
            local_server_port=0,
            remote_server_addresses=[
                self._cached_server1_target, self._cached_server2_target
            ],
            task_index=0))

    with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME):
      x1 = array_ops.ones([2, 2])
      y = math_ops.matmul(x1, x1)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

  @test_util.run_in_async_and_sync_mode
  def testConnectToRemoteServer(self):
    """Basic server connection."""
    context._reset_context()
    remote.connect_to_remote_host(self._cached_server1_target)

    with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
      x1 = array_ops.ones([2, 2])
      x2 = array_ops.ones([2, 2])
      y = math_ops.matmul(x1, x2)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
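  # Note: no job name is passed to connect_to_remote_host above, yet the
  # device string uses job:worker, so the helper evidently registers the
  # remote host under a default "worker" job rather than this file's JOB_NAME.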
  @test_util.run_in_async_and_sync_mode
  def testContextDeviceUpdated(self):
    """Tests that the context device is correctly updated."""

    with ops.device("cpu:0"):
      x1 = array_ops.ones([2, 2])
      x2 = array_ops.ones([2, 2])
      y = math_ops.matmul(x1, x2)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

    # `y` is placed on the local CPU as expected.
    self.assertEqual(y.device,
                     "/job:%s/replica:0/task:0/device:CPU:0" % JOB_NAME)

  @test_util.run_gpu_only
  @test_util.run_in_async_and_sync_mode
  def testGPUToRemoteCopy(self):
    """Tests that the remote copy happens satisfactorily."""
    x1 = array_ops.ones([2, 2]).gpu()

    with ops.device("/job:remote_device/replica:0/task:1/device:CPU:0"):
      x2 = x1._copy()  # pylint: disable=protected-access

    np.testing.assert_array_equal(x1.numpy(), x2.numpy())


class RemoteExecutionWithoutLazyRemoteInputsCopyTest(RemoteExecutionTest):

  @classmethod
  def setUpClass(cls):
    super(RemoteExecutionWithoutLazyRemoteInputsCopyTest, cls).setUpClass()
    context._reset_context()
    context.context().lazy_remote_inputs_copy = False

  @classmethod
  def tearDownClass(cls):
    super(RemoteExecutionWithoutLazyRemoteInputsCopyTest, cls).tearDownClass()
    context._reset_context()
    context.context().lazy_remote_inputs_copy = True


if __name__ == "__main__":
  ops.enable_eager_execution()
  test.main()
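(Both new test files depend on the run_in_async_and_sync_mode decorator added to test_util.py in the hunk below.)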
@@ -1017,6 +1017,21 @@ def build_as_function_and_v1_graph(func=None):
  return decorator


def run_in_async_and_sync_mode(f):
  """Execute the test in async mode and sync mode."""

  @parameterized.named_parameters([("Async", True), ("", False)])
  @functools.wraps(f)
  def decorator(self, async_mode, *args, **kwargs):
    if async_mode:
      with context.execution_mode(context.ASYNC):
        f(self, *args, **kwargs)
    else:
      with context.execution_mode(context.SYNC):
        f(self, *args, **kwargs)

  return decorator


def eager_lazy_remote_copy_on_and_off(f):
  """Execute the test method w/o lazy tensor copy for function remote inputs."""
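A usage sketch (class and method names are illustrative): applied to a test method on a class that also mixes in parameterized.TestCase, the decorator registers two cases per method via parameterized.named_parameters, one suffixed "Async", so each body runs once under context.SYNC and once under context.ASYNC:

  class MyRemoteTest(test.TestCase, parameterized.TestCase):

    @test_util.run_in_async_and_sync_mode
    def testMatmul(self):  # runs as both testMatmul and testMatmulAsync
      x = array_ops.ones([2, 2])
      np.testing.assert_array_equal(
          [[2, 2], [2, 2]], math_ops.matmul(x, x).numpy())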