Expose each device's incarnation via Session.list_devices().
PiperOrigin-RevId: 205273020
This commit is contained in:
parent
de6be2cbb1
commit
109ae67a7e
@ -963,6 +963,7 @@ TF_DEVICELIST_METHOD(const char*, TF_DeviceListName, name().c_str(), nullptr);
|
|||||||
TF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(),
|
TF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(),
|
||||||
nullptr);
|
nullptr);
|
||||||
TF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1);
|
TF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1);
|
||||||
|
TF_DEVICELIST_METHOD(uint64_t, TF_DeviceListIncarnation, incarnation(), 0);
|
||||||
|
|
||||||
#undef TF_DEVICELIST_METHOD
|
#undef TF_DEVICELIST_METHOD
|
||||||
|
|
||||||
|
@ -1521,6 +1521,13 @@ TF_CAPI_EXPORT extern const char* TF_DeviceListType(const TF_DeviceList* list,
|
|||||||
TF_CAPI_EXPORT extern int64_t TF_DeviceListMemoryBytes(
|
TF_CAPI_EXPORT extern int64_t TF_DeviceListMemoryBytes(
|
||||||
const TF_DeviceList* list, int index, TF_Status* status);
|
const TF_DeviceList* list, int index, TF_Status* status);
|
||||||
|
|
||||||
|
// Retrieve the incarnation number of a given device.
|
||||||
|
//
|
||||||
|
// If index is out of bounds, an error code will be set in the status object,
|
||||||
|
// and 0 will be returned.
|
||||||
|
TF_CAPI_EXPORT extern uint64_t TF_DeviceListIncarnation(
|
||||||
|
const TF_DeviceList* list, int index, TF_Status* status);
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
// --------------------------------------------------------------------------
|
||||||
// Load plugins containing custom ops and kernels
|
// Load plugins containing custom ops and kernels
|
||||||
|
|
||||||
|
@ -540,10 +540,11 @@ class _DeviceAttributes(object):
|
|||||||
(in bytes).
|
(in bytes).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, name, device_type, memory_limit_bytes):
|
def __init__(self, name, device_type, memory_limit_bytes, incarnation):
|
||||||
self._name = device.canonical_name(name)
|
self._name = device.canonical_name(name)
|
||||||
self._device_type = device_type
|
self._device_type = device_type
|
||||||
self._memory_limit_bytes = memory_limit_bytes
|
self._memory_limit_bytes = memory_limit_bytes
|
||||||
|
self._incarnation = incarnation
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
@ -557,11 +558,16 @@ class _DeviceAttributes(object):
|
|||||||
def memory_limit_bytes(self):
|
def memory_limit_bytes(self):
|
||||||
return self._memory_limit_bytes
|
return self._memory_limit_bytes
|
||||||
|
|
||||||
|
@property
|
||||||
|
def incarnation(self):
|
||||||
|
return self._incarnation
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '_DeviceAttributes(%s, %s, %d)' % (
|
return '_DeviceAttributes(%s, %s, %d, %d)' % (
|
||||||
self.name,
|
self.name,
|
||||||
self.device_type,
|
self.device_type,
|
||||||
self.memory_limit_bytes,
|
self.memory_limit_bytes,
|
||||||
|
self.incarnation,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -658,7 +664,9 @@ class BaseSession(SessionInterface):
|
|||||||
name = tf_session.TF_DeviceListName(raw_device_list, i)
|
name = tf_session.TF_DeviceListName(raw_device_list, i)
|
||||||
device_type = tf_session.TF_DeviceListType(raw_device_list, i)
|
device_type = tf_session.TF_DeviceListType(raw_device_list, i)
|
||||||
memory = tf_session.TF_DeviceListMemoryBytes(raw_device_list, i)
|
memory = tf_session.TF_DeviceListMemoryBytes(raw_device_list, i)
|
||||||
device_list.append(_DeviceAttributes(name, device_type, memory))
|
incarnation = tf_session.TF_DeviceListIncarnation(raw_device_list, i)
|
||||||
|
device_list.append(
|
||||||
|
_DeviceAttributes(name, device_type, memory, incarnation))
|
||||||
tf_session.TF_DeleteDeviceList(raw_device_list)
|
tf_session.TF_DeleteDeviceList(raw_device_list)
|
||||||
return device_list
|
return device_list
|
||||||
|
|
||||||
|
@ -37,6 +37,8 @@ class SessionListDevicesTest(test_util.TensorFlowTestCase):
|
|||||||
devices = sess.list_devices()
|
devices = sess.list_devices()
|
||||||
self.assertTrue('/job:localhost/replica:0/task:0/device:CPU:0' in set(
|
self.assertTrue('/job:localhost/replica:0/task:0/device:CPU:0' in set(
|
||||||
[d.name for d in devices]), devices)
|
[d.name for d in devices]), devices)
|
||||||
|
# All valid device incarnations must be non-zero.
|
||||||
|
self.assertTrue(all(d.incarnation != 0 for d in devices))
|
||||||
|
|
||||||
def testInvalidDeviceNumber(self):
|
def testInvalidDeviceNumber(self):
|
||||||
opts = tf_session.TF_NewSessionOptions()
|
opts = tf_session.TF_NewSessionOptions()
|
||||||
@ -54,6 +56,8 @@ class SessionListDevicesTest(test_util.TensorFlowTestCase):
|
|||||||
devices = sess.list_devices()
|
devices = sess.list_devices()
|
||||||
self.assertTrue('/job:local/replica:0/task:0/device:CPU:0' in set(
|
self.assertTrue('/job:local/replica:0/task:0/device:CPU:0' in set(
|
||||||
[d.name for d in devices]), devices)
|
[d.name for d in devices]), devices)
|
||||||
|
# All valid device incarnations must be non-zero.
|
||||||
|
self.assertTrue(all(d.incarnation != 0 for d in devices))
|
||||||
|
|
||||||
def testListDevicesClusterSpecPropagation(self):
|
def testListDevicesClusterSpecPropagation(self):
|
||||||
server1 = server_lib.Server.create_local_server()
|
server1 = server_lib.Server.create_local_server()
|
||||||
@ -67,11 +71,13 @@ class SessionListDevicesTest(test_util.TensorFlowTestCase):
|
|||||||
config = config_pb2.ConfigProto(cluster_def=cluster_def)
|
config = config_pb2.ConfigProto(cluster_def=cluster_def)
|
||||||
with session.Session(server1.target, config=config) as sess:
|
with session.Session(server1.target, config=config) as sess:
|
||||||
devices = sess.list_devices()
|
devices = sess.list_devices()
|
||||||
device_names = set([d.name for d in devices])
|
device_names = set(d.name for d in devices)
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
'/job:worker/replica:0/task:0/device:CPU:0' in device_names)
|
'/job:worker/replica:0/task:0/device:CPU:0' in device_names)
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
'/job:worker/replica:0/task:1/device:CPU:0' in device_names)
|
'/job:worker/replica:0/task:1/device:CPU:0' in device_names)
|
||||||
|
# All valid device incarnations must be non-zero.
|
||||||
|
self.assertTrue(all(d.incarnation != 0 for d in devices))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -35,6 +35,7 @@ from tensorflow.core.protobuf import config_pb2
|
|||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.framework import common_shapes
|
from tensorflow.python.framework import common_shapes
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import device as framework_device_lib
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import function
|
from tensorflow.python.framework import function
|
||||||
@ -104,18 +105,20 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||||||
copy_val)
|
copy_val)
|
||||||
|
|
||||||
def testManyCPUs(self):
|
def testManyCPUs(self):
|
||||||
# TODO(keveman): Implement ListDevices and test for the number of
|
|
||||||
# devices returned by ListDevices.
|
|
||||||
with session.Session(
|
with session.Session(
|
||||||
config=config_pb2.ConfigProto(device_count={
|
config=config_pb2.ConfigProto(device_count={
|
||||||
'CPU': 2
|
'CPU': 2, 'GPU': 0
|
||||||
})):
|
})) as sess:
|
||||||
inp = constant_op.constant(10.0, name='W1')
|
inp = constant_op.constant(10.0, name='W1')
|
||||||
self.assertAllEqual(inp.eval(), 10.0)
|
self.assertAllEqual(inp.eval(), 10.0)
|
||||||
|
|
||||||
|
devices = sess.list_devices()
|
||||||
|
self.assertEqual(2, len(devices))
|
||||||
|
for device in devices:
|
||||||
|
self.assertEqual('CPU', framework_device_lib.DeviceSpec.from_string(
|
||||||
|
device.name).device_type)
|
||||||
|
|
||||||
def testPerSessionThreads(self):
|
def testPerSessionThreads(self):
|
||||||
# TODO(keveman): Implement ListDevices and test for the number of
|
|
||||||
# devices returned by ListDevices.
|
|
||||||
with session.Session(
|
with session.Session(
|
||||||
config=config_pb2.ConfigProto(use_per_session_threads=True)):
|
config=config_pb2.ConfigProto(use_per_session_threads=True)):
|
||||||
inp = constant_op.constant(10.0, name='W1')
|
inp = constant_op.constant(10.0, name='W1')
|
||||||
@ -1868,19 +1871,21 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
def testDeviceAttributes(self):
|
def testDeviceAttributes(self):
|
||||||
attrs = session._DeviceAttributes(
|
attrs = session._DeviceAttributes(
|
||||||
'/job:worker/replica:0/task:3/device:CPU:2', 'TYPE', 1337)
|
'/job:worker/replica:0/task:3/device:CPU:2', 'TYPE', 1337, 1000000)
|
||||||
self.assertEqual(1337, attrs.memory_limit_bytes)
|
self.assertEqual(1337, attrs.memory_limit_bytes)
|
||||||
self.assertEqual('/job:worker/replica:0/task:3/device:CPU:2', attrs.name)
|
self.assertEqual('/job:worker/replica:0/task:3/device:CPU:2', attrs.name)
|
||||||
self.assertEqual('TYPE', attrs.device_type)
|
self.assertEqual('TYPE', attrs.device_type)
|
||||||
|
self.assertEqual(1000000, attrs.incarnation)
|
||||||
str_repr = '%s' % attrs
|
str_repr = '%s' % attrs
|
||||||
self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr)
|
self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr)
|
||||||
|
|
||||||
def testDeviceAttributesCanonicalization(self):
|
def testDeviceAttributesCanonicalization(self):
|
||||||
attrs = session._DeviceAttributes('/job:worker/replica:0/task:3/cpu:1',
|
attrs = session._DeviceAttributes('/job:worker/replica:0/task:3/cpu:1',
|
||||||
'TYPE', 1337)
|
'TYPE', 1337, 1000000)
|
||||||
self.assertEqual(1337, attrs.memory_limit_bytes)
|
self.assertEqual(1337, attrs.memory_limit_bytes)
|
||||||
self.assertEqual('/job:worker/replica:0/task:3/device:CPU:1', attrs.name)
|
self.assertEqual('/job:worker/replica:0/task:3/device:CPU:1', attrs.name)
|
||||||
self.assertEqual('TYPE', attrs.device_type)
|
self.assertEqual('TYPE', attrs.device_type)
|
||||||
|
self.assertEqual(1000000, attrs.incarnation)
|
||||||
str_repr = '%s' % attrs
|
str_repr = '%s' % attrs
|
||||||
self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr)
|
self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr)
|
||||||
|
|
||||||
|
@ -138,6 +138,11 @@ tensorflow::ImportNumpy();
|
|||||||
$result = PyLong_FromLongLong($1);
|
$result = PyLong_FromLongLong($1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Convert TF_DeviceListIncarnation uint64_t output to Python integer
|
||||||
|
%typemap(out) uint64_t {
|
||||||
|
$result = PyLong_FromUnsignedLongLong($1);
|
||||||
|
}
|
||||||
|
|
||||||
// We use TF_OperationGetControlInputs_wrapper instead of
|
// We use TF_OperationGetControlInputs_wrapper instead of
|
||||||
// TF_OperationGetControlInputs
|
// TF_OperationGetControlInputs
|
||||||
%ignore TF_OperationGetControlInputs;
|
%ignore TF_OperationGetControlInputs;
|
||||||
|
Loading…
Reference in New Issue
Block a user