TFE: expose tfe.num_gpus()

PiperOrigin-RevId: 168154345
This commit is contained in:
Shanqing Cai 2017-09-10 07:54:32 -07:00 committed by TensorFlower Gardener
parent 67a7cbc283
commit 3bce4f9a0d
3 changed files with 16 additions and 1 deletions

View File

@ -20,6 +20,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
@@device
@@list_devices
@@num_gpus
@@defun
@@implicit_gradients
@ -60,6 +61,7 @@ from tensorflow.python.eager import function
from tensorflow.python.eager.context import device
from tensorflow.python.eager.context import enable_eager_execution
from tensorflow.python.eager.context import list_devices
from tensorflow.python.eager.context import num_gpus
from tensorflow.python.eager.context import run
from tensorflow.python.eager.core import enable_tracing
from tensorflow.python.eager.execution_callbacks import add_execution_callback

View File

@ -27,6 +27,10 @@ class TFETest(test.TestCase):
# Expect at least one device.
self.assertTrue(tfe.list_devices())
def testNumGPUs(self):
devices = tfe.list_devices()
self.assertEqual(len(devices) - 1, tfe.num_gpus())
if __name__ == "__main__":
test.main()

View File

@ -414,9 +414,18 @@ def enable_eager_execution():
def list_devices():
"""List the names of the devices available to the default context.
"""List the names of the available devices.
Returns:
Names of the available devices, as a `list`.
"""
return context().devices()
def num_gpus():
"""Get the number of available GPU devices.
Returns:
The number of available GPU devices.
"""
return context().num_gpus()