diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index ade95ca67c5..b5bf839a89b 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -20,6 +20,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@device @@list_devices +@@num_gpus @@defun @@implicit_gradients @@ -60,6 +61,7 @@ from tensorflow.python.eager import function from tensorflow.python.eager.context import device from tensorflow.python.eager.context import enable_eager_execution from tensorflow.python.eager.context import list_devices +from tensorflow.python.eager.context import num_gpus from tensorflow.python.eager.context import run from tensorflow.python.eager.core import enable_tracing from tensorflow.python.eager.execution_callbacks import add_execution_callback diff --git a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py index a59fcb8d6da..2a9d7589d3a 100644 --- a/tensorflow/contrib/eager/python/tfe_test.py +++ b/tensorflow/contrib/eager/python/tfe_test.py @@ -27,6 +27,10 @@ class TFETest(test.TestCase): # Expect at least one device. self.assertTrue(tfe.list_devices()) + def testNumGPUs(self): + devices = tfe.list_devices() + self.assertEqual(len(devices) - 1, tfe.num_gpus()) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 8496a02947f..79374d2bb55 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -414,9 +414,18 @@ def enable_eager_execution(): def list_devices(): - """List the names of the devices available to the default context. + """List the names of the available devices. Returns: Names of the available devices, as a `list`. """ return context().devices() + + +def num_gpus(): + """Get the number of available GPU devices. + + Returns: + The number of available GPU devices. + """ + return context().num_gpus()