# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for remote execution."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random

from absl.testing import parameterized
import numpy as np
import six

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import SimpleClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import remote
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import server_lib
from tensorflow.python.training.server_lib import ClusterSpec


class SingleWorkerTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(SingleWorkerTest, self).setUp()

    workers, _ = test_util.create_local_cluster(1, 0)
    remote.connect_to_remote_host(workers[0].target)

  def tearDown(self):
    super(SingleWorkerTest, self).tearDown()

    # Clear the current device scope to avoid polluting other test cases.
    ops.device(None).__enter__()
    # Reset the context to avoid polluting other test cases.
    context._reset_context()

  @test_util.eager_lazy_remote_copy_on_and_off
  def testMultiDeviceFunctionBasic(self):

    @def_function.function
    def basic(i):
      with ops.device('/job:localhost/replica:0/task:0/cpu:0'):
        a = constant_op.constant([2]) + i
      with ops.device('/job:worker/replica:0/task:0/cpu:0'):
        b = constant_op.constant([1])

      return a + b

    self.assertAllEqual(basic(constant_op.constant([2])).numpy(), [5])
    self.assertAllEqual(basic(constant_op.constant([1])).numpy(), [4])

  @test_util.eager_lazy_remote_copy_on_and_off
  def testMultiDeviceFunctionVariable(self):
    with ops.device('/job:worker/replica:0/task:0/cpu:0'):
      variable_b = variables.Variable(1)

    @def_function.function
    def with_variable(i):
      return i + variable_b

    self.assertAllEqual(with_variable(constant_op.constant([2])).numpy(), [3])

  @test_util.eager_lazy_remote_copy_on_and_off
  def testMultiDeviceFunctionRemoteOutput(self):
    with ops.device('/job:worker/replica:0/task:0/cpu:0'):
      variable_b = variables.Variable(1)

    @def_function.function
    def remote_output(i):
      with ops.device('/job:worker/replica:0/task:0/cpu:0'):
        c = variable_b + 1
      return c, i + variable_b

    self.assertAllEqual(
        remote_output(constant_op.constant([1]))[0].numpy(), 2)

  def testMultiDeviceFunctionAmbiguousDevice(self):

    @def_function.function
    def ambiguous_device(i):
      with ops.device('cpu:0'):
        return i + constant_op.constant([2])

    with self.assertRaises(errors.InvalidArgumentError) as cm:
      with ops.device('/job:worker/replica:0/task:0/cpu:0'):
        ambiguous_device(constant_op.constant([2])).numpy()

    self.assertIn('the output node must match exactly one device',
                  cm.exception.message)

  def testStreaming(self):
    """A mini stress test for streaming - issuing many RPCs back to back."""
    with ops.device('job:worker/replica:0/task:0/device:CPU:0'):
      x = array_ops.ones([2, 2])
      y = array_ops.zeros([2, 2])
      num_iters = 200
      for _ in range(num_iters):
        y = x + y
        # Ask for y's shape after every 10 additions on average.
        # This exercises waiting for remote shape logic in TensorHandle.
        if random.randint(1, 10) == 1:
          _ = y.shape
    np.testing.assert_array_equal(
        [[num_iters, num_iters], [num_iters, num_iters]], y.numpy())

  def testShapeError_OpByOp(self):
    with ops.device('job:worker/replica:0/task:0/device:CPU:0'):
      x = array_ops.ones([2, 3])
      y = array_ops.zeros([2, 2])
      with self.assertRaises(errors.InvalidArgumentError) as cm:
        math_ops.matmul(x, y)

    self.assertIn('Dimensions must be equal', cm.exception.message)

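  # Note: run op-by-op (above), the shape mismatch surfaces at execution time
  # as an InvalidArgumentError; inside a tf.function (below), shape inference
  # catches it at trace time, so a ValueError is raised instead.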
  @test_util.eager_lazy_remote_copy_on_and_off
  def testShapeError_Function(self):

    @def_function.function
    def matmul_func(x, y):
      return math_ops.matmul(x, y)

    x = array_ops.ones([2, 3])
    y = array_ops.zeros([2, 2])

    with ops.device('job:worker/replica:0/task:0/device:CPU:0'):
      with self.assertRaises(ValueError) as cm:
        matmul_func(x, y)

    if six.PY2:
      self.assertIn('Dimensions must be equal', cm.exception.message)
    else:
      self.assertIn('Dimensions must be equal', cm.exception.args[0])

  def testClientVariable(self):
    var = variables.Variable(initial_value=0)

    @def_function.function
    def func():
      with ops.device('/job:localhost/task:0'):
        read = var.read_value()
      return read + 1

    with ops.device('/job:worker/task:0'):
      self.assertAllEqual(func(), 1)

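  # `functional_ops.remote_call` dispatches a concrete function `f` to an
  # explicit `target` device, with `args` as inputs and `Tout` listing the
  # output dtypes. The test below traces the function locally, then invokes
  # it on the remote worker.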
  @test_util.eager_lazy_remote_copy_on_and_off
  def testRemoteCall(self):

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def _remote_fn(x):
      return constant_op.constant(1) + x

    remote_fn = _remote_fn.get_concrete_function()

    @def_function.function
    def func(x):
      return functional_ops.remote_call(
          args=[x],
          Tout=[dtypes.int32],
          f=remote_fn,
          target='/job:worker/task:0')

    with ops.device('/job:localhost/task:0'):
      self.assertAllEqual(func(constant_op.constant(1)), [2])


class RemoteAsyncTest(test.TestCase):

  def setUp(self):
    super(RemoteAsyncTest, self).setUp()

    workers, _ = test_util.create_local_cluster(1, 0)
    remote.connect_to_remote_host(workers[0].target)

  def tearDown(self):
    super(RemoteAsyncTest, self).tearDown()

    # Reset the context to avoid polluting other test cases.
    context._reset_context()

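  # In asynchronous eager execution, an error raised by a remote op may only
  # surface on a later op (possibly wrapped as an InternalError), so the tests
  # below clear the pending error state with `context.async_clear_error()`
  # before making assertions.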
  def test_out_of_range_with_while_loop(self):

    with ops.device('/job:worker/task:0'):
      dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0])
      dataset = dataset.batch(1, drop_remainder=False)
      iterator = iter(dataset)
      v = variables.Variable(1.0)

    @def_function.function
    def train_step(iterator):
      i = next(iterator)
      v.assign_add(math_ops.reduce_mean(i))

    while True:
      try:
        with ops.device('/job:worker/task:0'):
          train_step(iterator)
      except (errors.OutOfRangeError, errors.InternalError):
        context.async_clear_error()
        break

    self.assertAllEqual(v.numpy(), 4.0)

  def test_out_of_range_with_for_loop(self):

    with ops.device('/job:worker/task:0'):
      dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0])
      dataset = dataset.batch(1, drop_remainder=False)
      iterator = iter(dataset)
      v = variables.Variable(1.0)

    @def_function.function
    def train_step(iterator):
      i = next(iterator)
      v.assign_add(math_ops.reduce_mean(i))

    num_steps = 3
    for i in range(num_steps):
      try:
        with ops.device('/job:worker/task:0'):
          train_step(iterator)
        # Flush pending async ops on the last step so any deferred
        # out-of-range error surfaces inside this try block.
        if i == num_steps - 1:
          context.async_wait()
      except errors.OutOfRangeError:
        context.async_clear_error()
        break

    self.assertAllEqual(v.numpy(), 4.0)

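  # `context.async_scope()` runs the enclosed ops asynchronously; any deferred
  # error is raised when the scope exits, which is why a single try/except
  # wraps the whole loop below.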
  def test_out_of_range_with_async_scope(self):

    with ops.device('/job:worker/task:0'):
      dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0])
      dataset = dataset.batch(1, drop_remainder=False)
      iterator = iter(dataset)
      v = variables.Variable(1.0)

    @def_function.function
    def train_step(iterator):
      i = next(iterator)
      v.assign_add(math_ops.reduce_mean(i))

    num_steps = 3
    try:
      with context.async_scope():
        for _ in range(num_steps):
          with ops.device('/job:worker/task:0'):
            train_step(iterator)
    except errors.OutOfRangeError:
      context.async_clear_error()

    self.assertAllEqual(v.numpy(), 4.0)


class MultiWorkersTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(MultiWorkersTest, self).setUp()

    workers, _ = test_util.create_local_cluster(3, 0)
    remote.connect_to_remote_host(
        [workers[0].target, workers[1].target, workers[2].target])

  def tearDown(self):
    super(MultiWorkersTest, self).tearDown()

    # Clear the current device scope to avoid polluting other test cases.
    ops.device(None).__enter__()
    # Reset the context to avoid polluting other test cases.
    context._reset_context()

  @test_util.eager_lazy_remote_copy_on_and_off
  def testReturnRemoteArgument(self):

    @def_function.function
    def local_func(i):
      return i

    with ops.device('/job:worker/replica:0/task:0'):
      x = constant_op.constant([2, 1])

    with ops.device('/job:worker/replica:0/task:1'):
      self.assertAllEqual(local_func(x), [2, 1])

  @test_util.eager_lazy_remote_copy_on_and_off
  def testMultiDeviceFunctionOnLocalDevice(self):
    with ops.device('/job:worker/replica:0/task:1'):
      variable_b = variables.Variable(1.0)

    @def_function.function
    def remote_function(i):
      with ops.device('/job:worker/replica:0/task:0'):
        a = i + variable_b
      c = a + 1.0
      return c

    self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

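  # `ops.pack_eager_tensors` bundles the per-task resource handles into one
  # handle placed on the local COMPOSITE device; reading the packed handle
  # under a specific device scope selects the component on that device.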
  def testMultiDeviceFunctionWithPackedVariable(self):
    with ops.device('/job:worker/replica:0/task:0/device:CPU:0'):
      var0 = resource_variable_ops.ResourceVariable(1.0)
    with ops.device('/job:worker/replica:0/task:1/device:CPU:0'):
      var1 = resource_variable_ops.ResourceVariable(2.0)

    packed_var = ops.pack_eager_tensors([var0.handle, var1.handle])
    self.assertEqual(packed_var.device,
                     '/job:localhost/replica:0/task:0/device:COMPOSITE:0')
    self.assertEqual(packed_var.backing_device,
                     '/job:localhost/replica:0/task:0/device:COMPOSITE:0')

    @def_function.function
    def add_variables():
      with ops.device('/job:worker/replica:0/task:0/device:CPU:0'):
        read0 = resource_variable_ops.read_variable_op(
            packed_var, dtype=dtypes.float32)
      with ops.device('/job:worker/replica:0/task:1/device:CPU:0'):
        read1 = resource_variable_ops.read_variable_op(
            packed_var, dtype=dtypes.float32)

      return read0 + read1

    # Run the function on a remote device
    with ops.device('/job:worker/replica:0/task:0'):
      self.assertAllEqual(add_variables().numpy(), 3.0)

    # Run the function on a local worker
    self.assertAllEqual(add_variables().numpy(), 3.0)

  @test_util.eager_lazy_remote_copy_on_and_off
  def testMultiDeviceFunctionOnRemoteDeviceWithWait(self):
    with ops.device('/job:worker/replica:0/task:1'):
      variable_b = variables.Variable([1.0])

    @def_function.function
    def remote_function(i):
      x = array_ops.ones([1000, 1000])
      for _ in range(1, 1000):
        x = x * x
      variable_b.assign_add(i)
      a = 1.0 + variable_b
      return a

    @def_function.function
    def remote_function2(i):
      variable_b.assign_add(i)
      a = 1.0 + variable_b
      return a

    # Runs the first function:
    # - on a remote device
    # - needs a remote input
    # - has side effects
    # - runs much slower
    with ops.device('/job:worker/replica:0/task:0'):
      remote_function(constant_op.constant([2.0]))

    # Runs the second function:
    # - on a remote device
    # - has side effects
    # There should be a sync point here and the next function will be executed
    # only after the first function has completed.
    with ops.device('/job:worker/replica:0/task:2'):
      self.assertAllEqual(remote_function2(constant_op.constant([3.0])), [7.0])

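  # The two tests below run the same function under both mirroring policies.
  # Roughly: MIRRORING_NONE always re-fetches remote tensors, while
  # MIRRORING_ALL lets the runtime cache (mirror) remote tensors locally to
  # avoid repeated copies; the results should be identical either way.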
  @test_util.eager_lazy_remote_copy_on_and_off
  def testMultiDeviceFunctionOnRemoteDevice(self):
    with ops.device('/job:worker/replica:0/task:1'):
      variable_b = variables.Variable(1.0)

    @def_function.function
    def remote_function(i):
      with ops.device('/job:worker/replica:0/task:0'):
        a = i + variable_b
      c = a + 1.0
      return c

    context.context().mirroring_policy = context.MIRRORING_NONE

    with ops.device('/job:worker/replica:0/task:0'):
      self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

    if test_util.is_gpu_available():
      with ops.device('/job:worker/replica:0/task:0/device:GPU:0'):
        self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

    context.context().mirroring_policy = context.MIRRORING_ALL

    with ops.device('/job:worker/replica:0/task:0'):
      self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

    if test_util.is_gpu_available():
      with ops.device('/job:worker/replica:0/task:0/device:GPU:0'):
        self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

  @test_util.eager_lazy_remote_copy_on_and_off
  def testMultiDeviceWhileLoopOnRemoteDevice(self):
    with ops.device('/job:worker/replica:0/task:1'):
      variable_b = variables.Variable(1.0)

    @def_function.function
    def remote_function(i):

      def body(i, _):
        with ops.device('/job:worker/replica:0/task:0'):
          a = i + variable_b
        return a + 1.0, 1

      return control_flow_ops.while_loop_v2(lambda _, d: d < 1, body, [i, 0])[0]

    context.context().mirroring_policy = context.MIRRORING_NONE

    with ops.device('/job:worker/replica:0/task:0'):
      self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

    if test_util.is_gpu_available():
      with ops.device('/job:worker/replica:0/task:0/device:GPU:0'):
        self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

    context.context().mirroring_policy = context.MIRRORING_ALL

    with ops.device('/job:worker/replica:0/task:0'):
      self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

    if test_util.is_gpu_available():
      with ops.device('/job:worker/replica:0/task:0/device:GPU:0'):
        self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

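  # A minimal parameter-server pattern: the variables live on task 2, while
  # worker_fn runs on tasks 0 and 1 and updates them remotely.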
  @test_util.eager_lazy_remote_copy_on_and_off
  def testSimpleParameterServer(self):

    with ops.device('/job:worker/task:2/device:CPU:0'):
      v1 = variables.Variable(initial_value=0)
      v2 = variables.Variable(initial_value=10)

    @def_function.function
    def worker_fn():
      v1.assign_add(1)
      v2.assign_sub(2)
      return v1.read_value() + v2.read_value()

    with ops.device('/job:worker/task:0/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 9)

    with ops.device('/job:worker/task:1/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 8)


_GRPC_PREFIX = 'grpc://'


class MultiJobsTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(MultiJobsTest, self).setUp()

    workers, ps = test_util.create_local_cluster(num_workers=2, num_ps=2)
    cluster = {
        'my_worker': [_strip_prefix(t.target, _GRPC_PREFIX) for t in workers],
        'my_ps': [_strip_prefix(t.target, _GRPC_PREFIX) for t in ps],
    }
    self._cluster = server_lib.ClusterSpec(cluster)
    self._cluster_resolver = SimpleClusterResolver(
        cluster_spec=self._cluster, master=ps[0].target)

  def tearDown(self):
    super(MultiJobsTest, self).tearDown()

    # Clear the current device scope to avoid polluting other test cases.
    ops.device(None).__enter__()
    # Reset the context to avoid polluting other test cases.
    context._reset_context()

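  # Same parameter-server pattern as in MultiWorkersTest, but connected via a
  # ClusterSpec that uses custom job names ('my_ps', 'my_worker') rather than
  # the default 'worker' job.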
  @test_util.eager_lazy_remote_copy_on_and_off
  def testSimpleParameterServer(self):
    remote.connect_to_cluster(self._cluster)

    with ops.device('/job:my_ps/task:0/device:CPU:0'):
      v1 = variables.Variable(initial_value=0)
      v2 = variables.Variable(initial_value=10)

    @def_function.function
    def worker_fn():
      v1.assign_add(1)
      v2.assign_sub(2)
      return v1.read_value() + v2.read_value()

    with ops.device('/job:my_worker/task:0/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 9)

    with ops.device('/job:my_worker/task:1/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 8)

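  # `server_lib.ClusterDeviceFilters` restricts which remote devices each task
  # can see. Below, workers may only see ps devices and vice versa, so one ps
  # task can no longer reach the other, and running worker_fn from a ps task
  # fails with an 'unknown device' error.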
  # TODO(b/152224115): Re-enable this test.
  @test_util.eager_lazy_remote_copy_on_and_off
  def DISABLED_testSimpleParameterServerWithDeviceFilters(self):
    cluster_device_filters = server_lib.ClusterDeviceFilters()
    for i in range(2):
      cluster_device_filters.set_device_filters('my_worker', i, ['/job:my_ps'])
      cluster_device_filters.set_device_filters('my_ps', i, ['/job:my_worker'])
    remote.connect_to_cluster(
        self._cluster, cluster_device_filters=cluster_device_filters)

    with ops.device('/job:my_ps/task:0/device:CPU:0'):
      v1 = variables.Variable(initial_value=0)
    with ops.device('/job:my_ps/task:1/device:CPU:0'):
      v2 = variables.Variable(initial_value=10)

    @def_function.function
    def worker_fn():
      v1.assign_add(1)
      v2.assign_sub(2)
      return v1.read_value() + v2.read_value()

    with ops.device('/job:my_worker/task:0/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 9)
    with ops.device('/job:my_worker/task:1/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 8)

    # The following remote calls would fail because the ps nodes cannot see
    # each other due to the device filters.
    with self.assertRaises(errors.InvalidArgumentError) as cm:
      with ops.device('/job:my_ps/task:0/device:CPU:0'):
        worker_fn().numpy()
    self.assertIn('/job:my_ps/replica:0/task:1/device:CPU:0 unknown device',
                  cm.exception.message)

    with self.assertRaises(errors.InvalidArgumentError) as cm:
      with ops.device('/job:my_ps/task:1/device:CPU:0'):
        worker_fn().numpy()
    self.assertIn('/job:my_ps/replica:0/task:0/device:CPU:0 unknown device',
                  cm.exception.message)

    with ops.device('/job:my_worker/task:0/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 7)
    with ops.device('/job:my_worker/task:1/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 6)

    # Explicitly delete variables to avoid triggering errors when they are
    # GC'ed in subsequent tests.
    del v1, v2

  @test_util.eager_lazy_remote_copy_on_and_off
  def testConnectWithClusterResolver(self):
    remote.connect_to_cluster(self._cluster_resolver)

    v1 = variables.Variable(initial_value=0)
    v2 = variables.Variable(initial_value=10)

    @def_function.function
    def worker_fn():
      v1.assign_add(1)
      v2.assign_sub(2)
      return v1.read_value() + v2.read_value()

    with ops.device('/job:my_worker/task:0/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 9)

    with ops.device('/job:my_worker/task:1/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 8)

  @test_util.eager_lazy_remote_copy_on_and_off
  def testConnectToClusterTwiceOk(self):
    remote.connect_to_cluster(self._cluster_resolver)
    remote.connect_to_cluster(self._cluster_resolver)

  @test_util.eager_lazy_remote_copy_on_and_off
  def testConnectToClusterOnMismatchedDevice(self):
    remote.connect_to_cluster(self._cluster_resolver)

    # Enter another device scope.
    ops.device('/job:my_worker/task:0/device:CPU:0').__enter__()

    with self.assertRaises(ValueError):
      remote.connect_to_cluster(self._cluster_resolver)

  @test_util.eager_lazy_remote_copy_on_and_off
  def testConnectToClusterWithLocalMaster(self):
    local_resolver = SimpleClusterResolver(ClusterSpec({}), master='local')
    remote.connect_to_cluster(local_resolver)

  @test_util.eager_lazy_remote_copy_on_and_off
  def testConnectToClusterInGraphModeWillFail(self):
    ops.disable_eager_execution()
    with self.assertRaises(ValueError):
      remote.connect_to_cluster(self._cluster_resolver)
    ops.enable_eager_execution()


def _strip_prefix(s, prefix):
  return s[len(prefix):] if s.startswith(prefix) else s


if __name__ == '__main__':
  test.main()