Expose each device's incarnation via Session.list_devices().

PiperOrigin-RevId: 205273020
This commit is contained in:
Derek Murray 2018-07-19 11:27:39 -07:00 committed by TensorFlower Gardener
parent de6be2cbb1
commit 109ae67a7e
6 changed files with 44 additions and 12 deletions

View File

@ -963,6 +963,7 @@ TF_DEVICELIST_METHOD(const char*, TF_DeviceListName, name().c_str(), nullptr);
TF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(),
nullptr);
TF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1);
TF_DEVICELIST_METHOD(uint64_t, TF_DeviceListIncarnation, incarnation(), 0);
#undef TF_DEVICELIST_METHOD

View File

@ -1521,6 +1521,13 @@ TF_CAPI_EXPORT extern const char* TF_DeviceListType(const TF_DeviceList* list,
TF_CAPI_EXPORT extern int64_t TF_DeviceListMemoryBytes(
const TF_DeviceList* list, int index, TF_Status* status);
// Retrieve the incarnation number of a given device.
//
// If index is out of bounds, an error code will be set in the status object,
// and 0 will be returned.
TF_CAPI_EXPORT extern uint64_t TF_DeviceListIncarnation(
const TF_DeviceList* list, int index, TF_Status* status);
// --------------------------------------------------------------------------
// Load plugins containing custom ops and kernels

View File

@ -540,10 +540,11 @@ class _DeviceAttributes(object):
(in bytes).
"""
def __init__(self, name, device_type, memory_limit_bytes):
def __init__(self, name, device_type, memory_limit_bytes, incarnation):
self._name = device.canonical_name(name)
self._device_type = device_type
self._memory_limit_bytes = memory_limit_bytes
self._incarnation = incarnation
@property
def name(self):
@ -557,11 +558,16 @@ class _DeviceAttributes(object):
def memory_limit_bytes(self):
return self._memory_limit_bytes
@property
def incarnation(self):
return self._incarnation
def __repr__(self):
return '_DeviceAttributes(%s, %s, %d)' % (
return '_DeviceAttributes(%s, %s, %d, %d)' % (
self.name,
self.device_type,
self.memory_limit_bytes,
self.incarnation,
)
@ -658,7 +664,9 @@ class BaseSession(SessionInterface):
name = tf_session.TF_DeviceListName(raw_device_list, i)
device_type = tf_session.TF_DeviceListType(raw_device_list, i)
memory = tf_session.TF_DeviceListMemoryBytes(raw_device_list, i)
device_list.append(_DeviceAttributes(name, device_type, memory))
incarnation = tf_session.TF_DeviceListIncarnation(raw_device_list, i)
device_list.append(
_DeviceAttributes(name, device_type, memory, incarnation))
tf_session.TF_DeleteDeviceList(raw_device_list)
return device_list

View File

@ -37,6 +37,8 @@ class SessionListDevicesTest(test_util.TensorFlowTestCase):
devices = sess.list_devices()
self.assertTrue('/job:localhost/replica:0/task:0/device:CPU:0' in set(
[d.name for d in devices]), devices)
# All valid device incarnations must be non-zero.
self.assertTrue(all(d.incarnation != 0 for d in devices))
def testInvalidDeviceNumber(self):
opts = tf_session.TF_NewSessionOptions()
@ -54,6 +56,8 @@ class SessionListDevicesTest(test_util.TensorFlowTestCase):
devices = sess.list_devices()
self.assertTrue('/job:local/replica:0/task:0/device:CPU:0' in set(
[d.name for d in devices]), devices)
# All valid device incarnations must be non-zero.
self.assertTrue(all(d.incarnation != 0 for d in devices))
def testListDevicesClusterSpecPropagation(self):
server1 = server_lib.Server.create_local_server()
@ -67,11 +71,13 @@ class SessionListDevicesTest(test_util.TensorFlowTestCase):
config = config_pb2.ConfigProto(cluster_def=cluster_def)
with session.Session(server1.target, config=config) as sess:
devices = sess.list_devices()
device_names = set([d.name for d in devices])
device_names = set(d.name for d in devices)
self.assertTrue(
'/job:worker/replica:0/task:0/device:CPU:0' in device_names)
self.assertTrue(
'/job:worker/replica:0/task:1/device:CPU:0' in device_names)
# All valid device incarnations must be non-zero.
self.assertTrue(all(d.incarnation != 0 for d in devices))
if __name__ == '__main__':

View File

@ -35,6 +35,7 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as framework_device_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
@ -104,18 +105,20 @@ class SessionTest(test_util.TensorFlowTestCase):
copy_val)
def testManyCPUs(self):
# TODO(keveman): Implement ListDevices and test for the number of
# devices returned by ListDevices.
with session.Session(
config=config_pb2.ConfigProto(device_count={
'CPU': 2
})):
'CPU': 2, 'GPU': 0
})) as sess:
inp = constant_op.constant(10.0, name='W1')
self.assertAllEqual(inp.eval(), 10.0)
devices = sess.list_devices()
self.assertEqual(2, len(devices))
for device in devices:
self.assertEqual('CPU', framework_device_lib.DeviceSpec.from_string(
device.name).device_type)
def testPerSessionThreads(self):
# TODO(keveman): Implement ListDevices and test for the number of
# devices returned by ListDevices.
with session.Session(
config=config_pb2.ConfigProto(use_per_session_threads=True)):
inp = constant_op.constant(10.0, name='W1')
@ -1868,19 +1871,21 @@ class SessionTest(test_util.TensorFlowTestCase):
def testDeviceAttributes(self):
attrs = session._DeviceAttributes(
'/job:worker/replica:0/task:3/device:CPU:2', 'TYPE', 1337)
'/job:worker/replica:0/task:3/device:CPU:2', 'TYPE', 1337, 1000000)
self.assertEqual(1337, attrs.memory_limit_bytes)
self.assertEqual('/job:worker/replica:0/task:3/device:CPU:2', attrs.name)
self.assertEqual('TYPE', attrs.device_type)
self.assertEqual(1000000, attrs.incarnation)
str_repr = '%s' % attrs
self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr)
def testDeviceAttributesCanonicalization(self):
attrs = session._DeviceAttributes('/job:worker/replica:0/task:3/cpu:1',
'TYPE', 1337)
'TYPE', 1337, 1000000)
self.assertEqual(1337, attrs.memory_limit_bytes)
self.assertEqual('/job:worker/replica:0/task:3/device:CPU:1', attrs.name)
self.assertEqual('TYPE', attrs.device_type)
self.assertEqual(1000000, attrs.incarnation)
str_repr = '%s' % attrs
self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr)

View File

@ -138,6 +138,11 @@ tensorflow::ImportNumpy();
$result = PyLong_FromLongLong($1);
}
// Convert TF_DeviceListIncarnation uint64_t output to Python integer
%typemap(out) uint64_t {
$result = PyLong_FromUnsignedLongLong($1);
}
// We use TF_OperationGetControlInputs_wrapper instead of
// TF_OperationGetControlInputs
%ignore TF_OperationGetControlInputs;