Flip the ordering of test context managers so that parameterized testcases can be correctly processed.
Also renamed one of the `GpuToRemoteCopy` tests to `GpuToRemoteOps` to avoid confusion. PiperOrigin-RevId: 299439193 Change-Id: I82318f485e85805f24fe6ec1553926fc22b32373
This commit is contained in:
parent
abf241bc86
commit
78e0404953
@ -93,9 +93,19 @@ class RemoteExecutionTest(test.TestCase, parameterized.TestCase):
|
||||
ops.device(None).__enter__()
|
||||
context._reset_context()
|
||||
|
||||
@test_util.run_gpu_only
|
||||
@test_util.run_in_async_and_sync_mode
|
||||
@test_util.run_gpu_only
|
||||
def testGpuToRemoteCopy(self):
|
||||
"""Tests that the remote copy happens satisfactorily."""
|
||||
x1 = array_ops.ones([2, 2]).gpu()
|
||||
with ops.device("/job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME):
|
||||
x2 = x1._copy() # pylint: disable=protected-access
|
||||
|
||||
np.testing.assert_array_equal(x1.numpy(), x2.numpy())
|
||||
|
||||
@test_util.run_in_async_and_sync_mode
|
||||
@test_util.run_gpu_only
|
||||
def testGpuToRemoteOp(self):
|
||||
with ops.device("gpu:0"):
|
||||
x = array_ops.ones([2, 2])
|
||||
with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME):
|
||||
@ -222,17 +232,6 @@ class RemoteExecutionTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertEqual(y.device,
|
||||
"/job:%s/replica:0/task:0/device:CPU:0" % JOB_NAME)
|
||||
|
||||
@test_util.run_gpu_only
|
||||
@test_util.run_in_async_and_sync_mode
|
||||
def testGPUToRemoteCopy(self):
|
||||
"""Tests that the remote copy happens satisfactorily."""
|
||||
x1 = array_ops.ones([2, 2]).gpu()
|
||||
|
||||
with ops.device("/job:remote_device/replica:0/task:1/device:CPU:0"):
|
||||
x2 = x1._copy() # pylint: disable=protected-access
|
||||
|
||||
np.testing.assert_array_equal(x1.numpy(), x2.numpy())
|
||||
|
||||
|
||||
class RemoteExecutionWithoutLazyRemoteInputsCopyTest(RemoteExecutionTest):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user