Gets functional ops tests to pass in tf2 on gpu
PiperOrigin-RevId: 236394682
This commit is contained in:
parent
02a9fdc083
commit
ffe13572ce
@ -475,6 +475,7 @@ class FunctionalOpsTest(test.TestCase):
|
||||
mul = self.evaluate(remote_op)
|
||||
self.assertEqual(mul, [6])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testRemoteFunctionCPUGPU(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
@ -499,6 +500,7 @@ class FunctionalOpsTest(test.TestCase):
|
||||
mul = self.evaluate(remote_op)
|
||||
self.assertEqual(mul, 9.0)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testRemoteFunctionGPUCPU(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
@ -523,6 +525,7 @@ class FunctionalOpsTest(test.TestCase):
|
||||
mul = self.evaluate(remote_op)
|
||||
self.assertEqual(mul, 9.0)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testRemoteFunctionGPUCPUStrings(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
@ -981,6 +984,7 @@ class PartitionedCallTest(test.TestCase):
|
||||
constant_op.constant(2.)], f=Body)
|
||||
self.assertEqual(output.eval(), 12.)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBasicMultiDeviceGPU(self):
|
||||
if not test_util.is_gpu_available():
|
||||
return
|
||||
@ -1061,6 +1065,7 @@ class PartitionedCallTest(test.TestCase):
|
||||
value = self.evaluate(v.read_value())
|
||||
self.assertEqual(value, 2.0)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testFunctionWithResourcesOnDifferentDevices(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPUs available.")
|
||||
|
Loading…
Reference in New Issue
Block a user