Added short-curcuitting in Context.get_*_devices if device_type is None

This change also replaces a forloop with a list comprehension in all of
the methods.

PiperOrigin-RevId: 273539713
This commit is contained in:
Sergei Lebedev 2019-10-08 09:40:32 -07:00 committed by TensorFlower Gardener
parent 0477de8448
commit 822b7fff41

View File

@ -1083,13 +1083,10 @@ class Context(object):
""" """
self._initialize_physical_devices() self._initialize_physical_devices()
if device_type is not None: if device_type is None:
return [ return list(self._physical_devices)
d for d in self._physical_devices
if device_type is None or device_type == d.device_type
]
return self._physical_devices return [d for d in self._physical_devices if d.device_type == device_type]
def _import_config(self): def _import_config(self):
"""Import config if passed in during construction. """Import config if passed in during construction.
@ -1140,26 +1137,21 @@ class Context(object):
def list_logical_devices(self, device_type=None): def list_logical_devices(self, device_type=None):
"""Return logical devices.""" """Return logical devices."""
self.ensure_initialized() self.ensure_initialized()
if device_type is None:
return list(self._logical_devices)
devices = [] return [d for d in self._logical_devices if d.device_type == device_type]
for dev in self._logical_devices:
if device_type is not None and device_type != dev.device_type:
continue
devices.append(dev)
return devices
def get_visible_devices(self, device_type=None): def get_visible_devices(self, device_type=None):
"""Get the list of visible devices.""" """Get the list of visible devices."""
self._initialize_physical_devices() self._initialize_physical_devices()
if device_type is None: if device_type is None:
return self._visible_device_list return list(self._visible_device_list)
else:
return [ return [
d for d in self._visible_device_list if d.device_type == device_type d for d in self._visible_device_list if d.device_type == device_type
] ]
def set_visible_devices(self, devices, device_type=None): def set_visible_devices(self, devices, device_type=None):
"""Set the list of visible devices.""" """Set the list of visible devices."""