TFE: expose tfe.num_gpus()
PiperOrigin-RevId: 168154345
This commit is contained in:
parent
67a7cbc283
commit
3bce4f9a0d
@ -20,6 +20,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
|
|||||||
|
|
||||||
@@device
|
@@device
|
||||||
@@list_devices
|
@@list_devices
|
||||||
|
@@num_gpus
|
||||||
|
|
||||||
@@defun
|
@@defun
|
||||||
@@implicit_gradients
|
@@implicit_gradients
|
||||||
@ -60,6 +61,7 @@ from tensorflow.python.eager import function
|
|||||||
from tensorflow.python.eager.context import device
|
from tensorflow.python.eager.context import device
|
||||||
from tensorflow.python.eager.context import enable_eager_execution
|
from tensorflow.python.eager.context import enable_eager_execution
|
||||||
from tensorflow.python.eager.context import list_devices
|
from tensorflow.python.eager.context import list_devices
|
||||||
|
from tensorflow.python.eager.context import num_gpus
|
||||||
from tensorflow.python.eager.context import run
|
from tensorflow.python.eager.context import run
|
||||||
from tensorflow.python.eager.core import enable_tracing
|
from tensorflow.python.eager.core import enable_tracing
|
||||||
from tensorflow.python.eager.execution_callbacks import add_execution_callback
|
from tensorflow.python.eager.execution_callbacks import add_execution_callback
|
||||||
|
@ -27,6 +27,10 @@ class TFETest(test.TestCase):
|
|||||||
# Expect at least one device.
|
# Expect at least one device.
|
||||||
self.assertTrue(tfe.list_devices())
|
self.assertTrue(tfe.list_devices())
|
||||||
|
|
||||||
|
def testNumGPUs(self):
|
||||||
|
devices = tfe.list_devices()
|
||||||
|
self.assertEqual(len(devices) - 1, tfe.num_gpus())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -414,9 +414,18 @@ def enable_eager_execution():
|
|||||||
|
|
||||||
|
|
||||||
def list_devices():
|
def list_devices():
|
||||||
"""List the names of the devices available to the default context.
|
"""List the names of the available devices.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Names of the available devices, as a `list`.
|
Names of the available devices, as a `list`.
|
||||||
"""
|
"""
|
||||||
return context().devices()
|
return context().devices()
|
||||||
|
|
||||||
|
|
||||||
|
def num_gpus():
|
||||||
|
"""Get the number of available GPU devices.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The number of available GPU devices.
|
||||||
|
"""
|
||||||
|
return context().num_gpus()
|
||||||
|
Loading…
Reference in New Issue
Block a user