Expose each device's incarnation via Session.list_devices().
PiperOrigin-RevId: 205273020
This commit is contained in:
parent
de6be2cbb1
commit
109ae67a7e
@ -963,6 +963,7 @@ TF_DEVICELIST_METHOD(const char*, TF_DeviceListName, name().c_str(), nullptr);
|
||||
TF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(),
|
||||
nullptr);
|
||||
TF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1);
|
||||
TF_DEVICELIST_METHOD(uint64_t, TF_DeviceListIncarnation, incarnation(), 0);
|
||||
|
||||
#undef TF_DEVICELIST_METHOD
|
||||
|
||||
|
@ -1521,6 +1521,13 @@ TF_CAPI_EXPORT extern const char* TF_DeviceListType(const TF_DeviceList* list,
|
||||
TF_CAPI_EXPORT extern int64_t TF_DeviceListMemoryBytes(
|
||||
const TF_DeviceList* list, int index, TF_Status* status);
|
||||
|
||||
// Retrieve the incarnation number of a given device.
|
||||
//
|
||||
// If index is out of bounds, an error code will be set in the status object,
|
||||
// and 0 will be returned.
|
||||
TF_CAPI_EXPORT extern uint64_t TF_DeviceListIncarnation(
|
||||
const TF_DeviceList* list, int index, TF_Status* status);
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Load plugins containing custom ops and kernels
|
||||
|
||||
|
@ -540,10 +540,11 @@ class _DeviceAttributes(object):
|
||||
(in bytes).
|
||||
"""
|
||||
|
||||
def __init__(self, name, device_type, memory_limit_bytes):
|
||||
def __init__(self, name, device_type, memory_limit_bytes, incarnation):
|
||||
self._name = device.canonical_name(name)
|
||||
self._device_type = device_type
|
||||
self._memory_limit_bytes = memory_limit_bytes
|
||||
self._incarnation = incarnation
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
@ -557,11 +558,16 @@ class _DeviceAttributes(object):
|
||||
def memory_limit_bytes(self):
|
||||
return self._memory_limit_bytes
|
||||
|
||||
@property
|
||||
def incarnation(self):
|
||||
return self._incarnation
|
||||
|
||||
def __repr__(self):
|
||||
return '_DeviceAttributes(%s, %s, %d)' % (
|
||||
return '_DeviceAttributes(%s, %s, %d, %d)' % (
|
||||
self.name,
|
||||
self.device_type,
|
||||
self.memory_limit_bytes,
|
||||
self.incarnation,
|
||||
)
|
||||
|
||||
|
||||
@ -658,7 +664,9 @@ class BaseSession(SessionInterface):
|
||||
name = tf_session.TF_DeviceListName(raw_device_list, i)
|
||||
device_type = tf_session.TF_DeviceListType(raw_device_list, i)
|
||||
memory = tf_session.TF_DeviceListMemoryBytes(raw_device_list, i)
|
||||
device_list.append(_DeviceAttributes(name, device_type, memory))
|
||||
incarnation = tf_session.TF_DeviceListIncarnation(raw_device_list, i)
|
||||
device_list.append(
|
||||
_DeviceAttributes(name, device_type, memory, incarnation))
|
||||
tf_session.TF_DeleteDeviceList(raw_device_list)
|
||||
return device_list
|
||||
|
||||
|
@ -37,6 +37,8 @@ class SessionListDevicesTest(test_util.TensorFlowTestCase):
|
||||
devices = sess.list_devices()
|
||||
self.assertTrue('/job:localhost/replica:0/task:0/device:CPU:0' in set(
|
||||
[d.name for d in devices]), devices)
|
||||
# All valid device incarnations must be non-zero.
|
||||
self.assertTrue(all(d.incarnation != 0 for d in devices))
|
||||
|
||||
def testInvalidDeviceNumber(self):
|
||||
opts = tf_session.TF_NewSessionOptions()
|
||||
@ -54,6 +56,8 @@ class SessionListDevicesTest(test_util.TensorFlowTestCase):
|
||||
devices = sess.list_devices()
|
||||
self.assertTrue('/job:local/replica:0/task:0/device:CPU:0' in set(
|
||||
[d.name for d in devices]), devices)
|
||||
# All valid device incarnations must be non-zero.
|
||||
self.assertTrue(all(d.incarnation != 0 for d in devices))
|
||||
|
||||
def testListDevicesClusterSpecPropagation(self):
|
||||
server1 = server_lib.Server.create_local_server()
|
||||
@ -67,11 +71,13 @@ class SessionListDevicesTest(test_util.TensorFlowTestCase):
|
||||
config = config_pb2.ConfigProto(cluster_def=cluster_def)
|
||||
with session.Session(server1.target, config=config) as sess:
|
||||
devices = sess.list_devices()
|
||||
device_names = set([d.name for d in devices])
|
||||
device_names = set(d.name for d in devices)
|
||||
self.assertTrue(
|
||||
'/job:worker/replica:0/task:0/device:CPU:0' in device_names)
|
||||
self.assertTrue(
|
||||
'/job:worker/replica:0/task:1/device:CPU:0' in device_names)
|
||||
# All valid device incarnations must be non-zero.
|
||||
self.assertTrue(all(d.incarnation != 0 for d in devices))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -35,6 +35,7 @@ from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import common_shapes
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import device as framework_device_lib
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import function
|
||||
@ -104,18 +105,20 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
copy_val)
|
||||
|
||||
def testManyCPUs(self):
|
||||
# TODO(keveman): Implement ListDevices and test for the number of
|
||||
# devices returned by ListDevices.
|
||||
with session.Session(
|
||||
config=config_pb2.ConfigProto(device_count={
|
||||
'CPU': 2
|
||||
})):
|
||||
'CPU': 2, 'GPU': 0
|
||||
})) as sess:
|
||||
inp = constant_op.constant(10.0, name='W1')
|
||||
self.assertAllEqual(inp.eval(), 10.0)
|
||||
|
||||
devices = sess.list_devices()
|
||||
self.assertEqual(2, len(devices))
|
||||
for device in devices:
|
||||
self.assertEqual('CPU', framework_device_lib.DeviceSpec.from_string(
|
||||
device.name).device_type)
|
||||
|
||||
def testPerSessionThreads(self):
|
||||
# TODO(keveman): Implement ListDevices and test for the number of
|
||||
# devices returned by ListDevices.
|
||||
with session.Session(
|
||||
config=config_pb2.ConfigProto(use_per_session_threads=True)):
|
||||
inp = constant_op.constant(10.0, name='W1')
|
||||
@ -1868,19 +1871,21 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testDeviceAttributes(self):
|
||||
attrs = session._DeviceAttributes(
|
||||
'/job:worker/replica:0/task:3/device:CPU:2', 'TYPE', 1337)
|
||||
'/job:worker/replica:0/task:3/device:CPU:2', 'TYPE', 1337, 1000000)
|
||||
self.assertEqual(1337, attrs.memory_limit_bytes)
|
||||
self.assertEqual('/job:worker/replica:0/task:3/device:CPU:2', attrs.name)
|
||||
self.assertEqual('TYPE', attrs.device_type)
|
||||
self.assertEqual(1000000, attrs.incarnation)
|
||||
str_repr = '%s' % attrs
|
||||
self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr)
|
||||
|
||||
def testDeviceAttributesCanonicalization(self):
|
||||
attrs = session._DeviceAttributes('/job:worker/replica:0/task:3/cpu:1',
|
||||
'TYPE', 1337)
|
||||
'TYPE', 1337, 1000000)
|
||||
self.assertEqual(1337, attrs.memory_limit_bytes)
|
||||
self.assertEqual('/job:worker/replica:0/task:3/device:CPU:1', attrs.name)
|
||||
self.assertEqual('TYPE', attrs.device_type)
|
||||
self.assertEqual(1000000, attrs.incarnation)
|
||||
str_repr = '%s' % attrs
|
||||
self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr)
|
||||
|
||||
|
@ -138,6 +138,11 @@ tensorflow::ImportNumpy();
|
||||
$result = PyLong_FromLongLong($1);
|
||||
}
|
||||
|
||||
// Convert TF_DeviceListIncarnation uint64_t output to Python integer
|
||||
%typemap(out) uint64_t {
|
||||
$result = PyLong_FromUnsignedLongLong($1);
|
||||
}
|
||||
|
||||
// We use TF_OperationGetControlInputs_wrapper instead of
|
||||
// TF_OperationGetControlInputs
|
||||
%ignore TF_OperationGetControlInputs;
|
||||
|
Loading…
Reference in New Issue
Block a user