TFE: expose tfe.num_gpus()
PiperOrigin-RevId: 168154345
This commit is contained in:
parent
67a7cbc283
commit
3bce4f9a0d
@ -20,6 +20,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
|
||||
|
||||
@@device
|
||||
@@list_devices
|
||||
@@num_gpus
|
||||
|
||||
@@defun
|
||||
@@implicit_gradients
|
||||
@ -60,6 +61,7 @@ from tensorflow.python.eager import function
|
||||
from tensorflow.python.eager.context import device
|
||||
from tensorflow.python.eager.context import enable_eager_execution
|
||||
from tensorflow.python.eager.context import list_devices
|
||||
from tensorflow.python.eager.context import num_gpus
|
||||
from tensorflow.python.eager.context import run
|
||||
from tensorflow.python.eager.core import enable_tracing
|
||||
from tensorflow.python.eager.execution_callbacks import add_execution_callback
|
||||
|
@ -27,6 +27,10 @@ class TFETest(test.TestCase):
|
||||
# Expect at least one device.
|
||||
self.assertTrue(tfe.list_devices())
|
||||
|
||||
def testNumGPUs(self):
|
||||
devices = tfe.list_devices()
|
||||
self.assertEqual(len(devices) - 1, tfe.num_gpus())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -414,9 +414,18 @@ def enable_eager_execution():
|
||||
|
||||
|
||||
def list_devices():
|
||||
"""List the names of the devices available to the default context.
|
||||
"""List the names of the available devices.
|
||||
|
||||
Returns:
|
||||
Names of the available devices, as a `list`.
|
||||
"""
|
||||
return context().devices()
|
||||
|
||||
|
||||
def num_gpus():
|
||||
"""Get the number of available GPU devices.
|
||||
|
||||
Returns:
|
||||
The number of available GPU devices.
|
||||
"""
|
||||
return context().num_gpus()
|
||||
|
Loading…
Reference in New Issue
Block a user