Flip the ordering of test context managers so that parameterized testcases can be correctly processed.

Also renamed one of the `GpuToRemoteCopy` tests to `GpuToRemoteOps` to avoid confusion.

PiperOrigin-RevId: 299439193
Change-Id: I82318f485e85805f24fe6ec1553926fc22b32373
This commit is contained in:
Haoyu Zhang 2020-03-06 14:36:11 -08:00 committed by TensorFlower Gardener
parent abf241bc86
commit 78e0404953

View File

@ -93,9 +93,19 @@ class RemoteExecutionTest(test.TestCase, parameterized.TestCase):
ops.device(None).__enter__()
context._reset_context()
@test_util.run_gpu_only
@test_util.run_in_async_and_sync_mode
@test_util.run_gpu_only
def testGpuToRemoteCopy(self):
"""Tests that the remote copy happens satisfactorily."""
x1 = array_ops.ones([2, 2]).gpu()
with ops.device("/job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME):
x2 = x1._copy() # pylint: disable=protected-access
np.testing.assert_array_equal(x1.numpy(), x2.numpy())
@test_util.run_in_async_and_sync_mode
@test_util.run_gpu_only
def testGpuToRemoteOp(self):
with ops.device("gpu:0"):
x = array_ops.ones([2, 2])
with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME):
@ -222,17 +232,6 @@ class RemoteExecutionTest(test.TestCase, parameterized.TestCase):
self.assertEqual(y.device,
"/job:%s/replica:0/task:0/device:CPU:0" % JOB_NAME)
@test_util.run_gpu_only
@test_util.run_in_async_and_sync_mode
def testGPUToRemoteCopy(self):
"""Tests that the remote copy happens satisfactorily."""
x1 = array_ops.ones([2, 2]).gpu()
with ops.device("/job:remote_device/replica:0/task:1/device:CPU:0"):
x2 = x1._copy() # pylint: disable=protected-access
np.testing.assert_array_equal(x1.numpy(), x2.numpy())
class RemoteExecutionWithoutLazyRemoteInputsCopyTest(RemoteExecutionTest):