Expose APIs to list and configure devices

The new APIs allow listing devices as PhysicalDevice prior to the
initialization of the runtime and LogicalDevice once the runtime
is initialized. We also allow configuring virtual devices,
memory limits as well as visibility.

PiperOrigin-RevId: 243330844
This commit is contained in:
Gaurav Jain 2019-04-12 14:15:08 -07:00 committed by TensorFlower Gardener
parent 6c6e798a15
commit 1e96adb967
15 changed files with 745 additions and 3 deletions

View File

@ -30,10 +30,17 @@ namespace tensorflow {
// DeviceFactory implementation for the DEVICE_XLA_CPU device type.
class XlaCpuDeviceFactory : public DeviceFactory {
public:
// Appends the names of the physical devices this factory provides to
// `devices`.
Status ListPhysicalDevices(std::vector<string>* devices) override;
// Device construction entry point; defined out of line.
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) override;
};
// Appends the name of the one XLA CPU physical device (index 0) to `devices`.
// Always returns OK.
Status XlaCpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
  const string device_name =
      absl::StrCat("/physical_device:", DEVICE_XLA_CPU, ":0");
  devices->push_back(device_name);
  return Status::OK();
}
Status XlaCpuDeviceFactory::CreateDevices(
const SessionOptions& session_options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) {

View File

@ -55,10 +55,32 @@ static xla::StatusOr<absl::optional<std::set<int>>> ParseVisibleDeviceList(
// DeviceFactory implementation for the DEVICE_XLA_GPU device type.
class XlaGpuDeviceFactory : public DeviceFactory {
public:
// Appends the names of the physical devices this factory provides to
// `devices`.
Status ListPhysicalDevices(std::vector<string>* devices) override;
// Device construction entry point; defined out of line.
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) override;
};
// Appends one "/physical_device:<DEVICE_XLA_GPU>:<i>" name to `devices` for
// each device the "CUDA" StreamExecutor platform reports as visible.
//
// A missing CUDA platform is not an error: nothing is appended and OK is
// returned, since the machine may simply have no GPU.
Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
  auto platform = se::MultiPlatformManager::PlatformWithName("CUDA");
  if (!platform.ok()) {
    // Treat failures as non-fatal; there might not be a GPU in the machine.
    VLOG(1) << "Failed to list XLA_GPU devices: " << platform.status();
    return Status::OK();
  }

  // A non-positive count just means the loop below does not execute.
  const int device_count = platform.ValueOrDie()->VisibleDeviceCount();
  for (int i = 0; i < device_count; ++i) {
    devices->push_back(
        absl::StrCat("/physical_device:", DEVICE_XLA_GPU, ":", i));
  }
  return Status::OK();
}
Status XlaGpuDeviceFactory::CreateDevices(
const SessionOptions& session_options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) {

View File

@ -32,10 +32,19 @@ constexpr std::array<DataType, 10> kExecAllTypes = {
// DeviceFactory implementation for the DEVICE_XLA_INTERPRETER device type.
class XlaInterpreterDeviceFactory : public DeviceFactory {
public:
// Appends the names of the physical devices this factory provides to
// `devices`.
Status ListPhysicalDevices(std::vector<string>* devices) override;
// Device construction entry point; defined out of line.
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) override;
};
// Appends the name of the one XLA interpreter physical device (index 0) to
// `devices`. Always returns OK.
Status XlaInterpreterDeviceFactory::ListPhysicalDevices(
    std::vector<string>* devices) {
  const string device_name =
      absl::StrCat("/physical_device:", DEVICE_XLA_INTERPRETER, ":0");
  devices->push_back(device_name);
  return Status::OK();
}
Status XlaInterpreterDeviceFactory::CreateDevices(
const SessionOptions& session_options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) {

View File

@ -90,6 +90,32 @@ DeviceFactory* DeviceFactory::GetFactory(const string& device_type) {
return it->second.factory.get();
}
// Collects the physical device names of every registered factory into
// `devices`, CPU entries first.
//
// Returns NotFound when no "CPU" factory is registered or when that factory
// lists no devices; otherwise propagates the first error reported by any
// factory.
Status DeviceFactory::ListAllPhysicalDevices(std::vector<string>* devices) {
  // The CPU factory is mandatory and is queried ahead of all others.
  auto* const cpu_factory = GetFactory("CPU");
  if (cpu_factory == nullptr) {
    return errors::NotFound(
        "CPU Factory not registered. Did you link in threadpool_device?");
  }

  // `devices` may already hold entries, so compare sizes rather than
  // checking for emptiness.
  const size_t size_before_cpu = devices->size();
  TF_RETURN_IF_ERROR(cpu_factory->ListPhysicalDevices(devices));
  if (devices->size() == size_before_cpu) {
    return errors::NotFound("No CPU devices are available in this process");
  }

  // Every remaining factory (GPU included) follows.
  mutex_lock l(*get_device_factory_lock());
  for (auto& entry : device_factories()) {
    auto* const other_factory = entry.second.factory.get();
    if (other_factory == cpu_factory) {
      continue;
    }
    TF_RETURN_IF_ERROR(other_factory->ListPhysicalDevices(devices));
  }
  return Status::OK();
}
Status DeviceFactory::AddDevices(
const SessionOptions& options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) {

View File

@ -49,6 +49,15 @@ class DeviceFactory {
const SessionOptions& options,
const string& name_prefix);
// Iterate through all device factories and build a list of all of the
// possible physical devices. Names are appended to `devices`.
//
// CPU devices are added first. Returns a NotFound error if no CPU factory is
// registered or if the CPU factory lists no devices.
static Status ListAllPhysicalDevices(std::vector<string>* devices);
// For a specific device factory, list all possible physical devices by
// appending their names to `devices`.
virtual Status ListPhysicalDevices(std::vector<string>* devices) = 0;
// Most clients should call AddDevices() instead.
virtual Status CreateDevices(
const SessionOptions& options, const string& name_prefix,

View File

@ -56,6 +56,9 @@ class DeviceSetTest : public ::testing::Test {
class DummyFactory : public DeviceFactory {
public:
// Test stub: lists no physical devices and always succeeds.
Status ListPhysicalDevices(std::vector<string>* /*devices*/) override {
  return Status::OK();
}
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) override {
return Status::OK();

View File

@ -983,6 +983,36 @@ Allocator* BaseGPUDevice::GetScopedAllocator(AllocatorAttributes attr,
const int BaseGPUDeviceFactory::InterconnectMap::kSameDeviceStrength = 1000;
const int BaseGPUDeviceFactory::InterconnectMap::kStreamExecutorStrength = 1;
// Appends a "/physical_device:GPU:<platform id>" name to `devices` for each
// visible GPU that GetValidDeviceIds() accepts.
//
// Returns OK without appending anything when there is no GPU platform or no
// visible device; errors from ValidateGPUMachineManager() and
// GetValidDeviceIds() are propagated.
Status BaseGPUDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
  TF_RETURN_IF_ERROR(ValidateGPUMachineManager());

  se::Platform* const platform = GPUMachineManager();
  if (platform == nullptr) {
    return Status::OK();
  }

  const int num_visible = platform->VisibleDeviceCount();
  if (num_visible <= 0) {
    return Status::OK();
  }

  // Candidate ids are 0 .. num_visible - 1, in that order.
  std::vector<PlatformGpuId> candidate_ids(num_visible);
  int next_id = 0;
  std::generate(candidate_ids.begin(), candidate_ids.end(),
                [&next_id] { return next_id++; });

  std::vector<PlatformGpuId> usable_ids;
  TF_RETURN_IF_ERROR(GetValidDeviceIds(candidate_ids, &usable_ids));

  for (PlatformGpuId gpu_id : usable_ids) {
    devices->push_back(
        strings::StrCat("/physical_device:GPU:", gpu_id.value()));
  }
  return Status::OK();
}
Status BaseGPUDeviceFactory::CreateDevices(
const SessionOptions& options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) {

View File

@ -263,6 +263,7 @@ class GPUKernelTracker {
class BaseGPUDeviceFactory : public DeviceFactory {
public:
Status ListPhysicalDevices(std::vector<string>* devices) override;
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) override;

View File

@ -109,6 +109,12 @@ class GPUCompatibleCPUDevice : public ThreadPoolDevice {
// The associated factory.
class GPUCompatibleCPUDeviceFactory : public DeviceFactory {
public:
// Appends the single physical CPU device name, "/physical_device:CPU:0", to
// `devices`. Always returns OK.
Status ListPhysicalDevices(std::vector<string>* devices) override {
  devices->emplace_back("/physical_device:CPU:0");
  return Status::OK();
}
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) override {
int n = 1;

View File

@ -106,6 +106,9 @@ class FakeDevice : public Device {
class DummyFactory : public DeviceFactory {
public:
// Test stub: lists no physical devices and always succeeds.
Status ListPhysicalDevices(std::vector<string>* /*devices*/) override {
  return Status::OK();
}
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) override {
return Status::OK();

View File

@ -29,6 +29,12 @@ namespace tensorflow {
// TODO(zhifengc/tucker): Figure out the bytes of available RAM.
class ThreadPoolDeviceFactory : public DeviceFactory {
public:
// Appends the single physical CPU device name, "/physical_device:CPU:0", to
// `devices`. Always returns OK.
Status ListPhysicalDevices(std::vector<string>* devices) override {
  devices->emplace_back("/physical_device:CPU:0");
  return Status::OK();
}
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) override {
int num_numa_nodes = port::NUMANumNodes();

View File

@ -214,6 +214,57 @@ class _ContextSwitchStack(threading.local):
self.stack.pop()
class LogicalDevice(
    collections.namedtuple("LogicalDevice", ["name", "device_type"])):
  """A device that has been initialized by the runtime.

  A LogicalDevice corresponds to an initialized instance on a PhysicalDevice
  or to a remote device available in the cluster. Tensors and operations can
  be placed on a specific LogicalDevice by calling `tf.device()` with the
  `name` of the LogicalDevice.

  Fields:
    name: The fully qualified name of the device. Can be used for Op or
      function placement.
    device_type: String declaring the type of device such as "CPU" or "GPU".
  """
class VirtualDeviceConfiguration(
    collections.namedtuple("VirtualDeviceConfiguration", ["memory_limit"])):
  """Settings for one virtual device carved out of a PhysicalDevice.

  Fields:
    memory_limit: (optional) Maximum memory (in MB) to allocate on the virtual
      device. Currently only supported for GPUs.
  """

  def __new__(cls, memory_limit=None):
    # Overridden only to give `memory_limit` a default of None.
    base = super(VirtualDeviceConfiguration, cls)
    return base.__new__(cls, memory_limit)
class PhysicalDevice(
    collections.namedtuple("PhysicalDevice", ["name", "device_type"])):
  """A physical device that is visible on the local machine.

  TensorFlow can utilize various devices such as the CPU or multiple GPUs
  for computation. Before initializing a local device for use, the user can
  customize certain properties of the device such as its visibility or memory
  configuration.

  Once a PhysicalDevice is initialized, one or many LogicalDevice objects are
  created. Use tf.config.set_virtual_device_configuration() to create multiple
  LogicalDevice objects for a PhysicalDevice. This is useful when separation
  between models is needed.

  Fields:
    name: Unique identifier for device.
    device_type: String declaring the type of device such as "CPU" or "GPU".
  """
# TODO(agarwal): rename to EagerContext / EagerRuntime ?
# TODO(agarwal): consider keeping the corresponding Graph here.
class Context(object):
@ -284,6 +335,11 @@ class Context(object):
self._server_def = server_def
self._collective_ops_server_def = None
self._physical_devices = None
self._visible_device_list = []
self._memory_growth_map = None
self._virtual_device_map = {}
# Values set after construction
self._optimizer_jit = None
self._intra_op_parallelism_threads = None
@ -317,6 +373,7 @@ class Context(object):
def _initialize_devices(self):
"""Helper to initialize devices."""
# Store list of devices
self._logical_devices = []
self._context_devices = []
device_list = pywrap_tensorflow.TFE_ContextListDevices(
self._context_handle)
@ -325,6 +382,9 @@ class Context(object):
for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)):
dev_name = pywrap_tensorflow.TF_DeviceListName(device_list, i)
self._context_devices.append(pydev.canonical_name(dev_name))
spec = pydev.DeviceSpec.from_string(dev_name)
self._logical_devices.append(
LogicalDevice(name=dev_name, device_type=spec.device_type))
dev_type = pywrap_tensorflow.TF_DeviceListType(device_list, i)
if dev_type == "GPU":
self._num_gpus += 1
@ -664,8 +724,64 @@ class Context(object):
if nodes is not None:
config.graph_options.rewrite_options.min_graph_nodes = nodes
# Compute device counts
config.device_count["CPU"] = 0
config.device_count["GPU"] = 0
for dev in self.list_physical_devices():
if dev not in self._visible_device_list:
continue
virtual_devices = self._virtual_device_map.get(dev)
if virtual_devices is None:
config.device_count[dev.device_type] += 1
else:
config.device_count[dev.device_type] += len(virtual_devices)
# Configure gpu_options
gpu_options = self._compute_gpu_options()
config.gpu_options.MergeFrom(gpu_options)
return config
def _compute_gpu_options(self):
  """Assemble the GPUOptions proto from the current device settings.

  Returns:
    A config_pb2.GPUOptions carrying the visible device list, the virtual
    device memory limits and the memory growth flag.

  Raises:
    ValueError: If the recorded memory growth settings are not all equal.
  """
  visible_indices = []
  virtual_device_protos = []

  # Indices count every physical GPU, including the ones that are hidden.
  for index, gpu in enumerate(self.list_physical_devices("GPU")):
    if gpu not in self._visible_device_list:
      continue
    visible_indices.append(str(index))

    # Once any device has a virtual configuration, one entry is emitted per
    # visible GPU (with an empty limit list for unconfigured ones).
    if self._virtual_device_map:
      limits = [
          vdev.memory_limit for vdev in self._virtual_device_map.get(gpu, [])
      ]
      virtual_device_protos.append(
          config_pb2.GPUOptions.Experimental.VirtualDevices(
              memory_limit_mb=limits))

  # Only compute growth if virtual devices have not been configured and we
  # have GPUs.
  allow_growth = None
  if not virtual_device_protos and self._memory_growth_map:
    growth_settings = set(self._memory_growth_map.values())
    if len(growth_settings) > 1:
      raise ValueError("Memory growth cannot differ between GPU devices")
    allow_growth = growth_settings.pop()

  experimental = config_pb2.GPUOptions.Experimental(
      virtual_devices=virtual_device_protos)
  return config_pb2.GPUOptions(
      allow_growth=allow_growth,
      visible_device_list=",".join(visible_indices),
      experimental=experimental)
@property
def function_call_options(self):
"""Returns function call options for current thread.
@ -769,6 +885,173 @@ class Context(object):
"""Get the list of post-execution callbacks added to the context."""
return self._post_execution_callbacks
def list_physical_devices(self, device_type=None):
  """List local devices visible to the system.

  This API allows a client to query the devices before they have been
  initialized by the eager runtime. Additionally a user can filter by device
  type, to get only CPUs or GPUs.

  Args:
    device_type: Optional device type to limit results to

  Returns:
    List of PhysicalDevice objects.
  """
  # Discovery is deferred to the first call instead of being done in the
  # constructor, since the backend may not be initialized at that point.
  if self._physical_devices is None:
    discovered = []
    for raw_name in pywrap_tensorflow.TF_ListPhysicalDevices():
      # The device type is the second ":"-separated field of the name.
      discovered.append(
          PhysicalDevice(
              name=raw_name.decode(),
              device_type=raw_name.decode().split(":")[1]))
    self._physical_devices = discovered

    # Everything except XLA devices is visible by default.
    self._visible_device_list = [
        dev for dev in discovered if not dev.device_type.startswith("XLA")
    ]

    # Memory growth is tracked for GPUs only and starts out unset.
    self._memory_growth_map = dict(
        (dev, None) for dev in discovered if dev.device_type == "GPU")

    # Import device settings that may have been passed into the constructor
    self._import_config()

  if device_type is None:
    return self._physical_devices
  return [
      dev for dev in self._physical_devices if dev.device_type == device_type
  ]
def _import_config(self):
"""Import config if passed in during construction.
If Context was created with a ConfigProto such as when calling
tf.compat.v1.enable_eager_execution(), then we need to pull out the
various pieces we might be replacing and import then into our internal
class representation.
"""
if self._config is None:
return
num_cpus = self._config.device_count.get("CPU", 1)
if num_cpus != 1:
cpus = [d for d in self._physical_devices if d.device_type == "CPU"]
if num_cpus == 0:
self.set_visible_devices([], "CPU")
elif num_cpus > 1:
self.set_virtual_device_configuration(
cpus[0], [VirtualDeviceConfiguration() for _ in range(num_cpus)])
gpus = [d for d in self._physical_devices if d.device_type == "GPU"]
gpu_count = self._config.device_count.get("GPU", None)
if gpu_count == 0:
self.set_visible_devices([], "GPU")
elif gpu_count is not None:
# TODO(gjn): Handle importing existing virtual GPU configuration
self.set_visible_devices(gpus[:gpu_count], "GPU")
def list_logical_devices(self, device_type=None):
  """Return the logical devices, optionally restricted to `device_type`.

  Calls ensure_initialized() first, so the context is initialized as a side
  effect of this query.
  """
  self.ensure_initialized()
  return [
      dev for dev in self._logical_devices
      if device_type is None or dev.device_type == device_type
  ]
def get_visible_devices(self, device_type=None):
  """Get the list of visible devices.

  Args:
    device_type: Optional device type to limit results to.

  Returns:
    List of the PhysicalDevice objects currently marked visible.
  """
  # The default visible list is only filled in when the physical devices are
  # lazily discovered. Trigger that here so a query made before any other
  # device API reports the defaults rather than an empty list.
  self.list_physical_devices()

  if device_type is None:
    return self._visible_device_list
  return [
      d for d in self._visible_device_list if d.device_type == device_type
  ]
def set_visible_devices(self, devices, device_type=None):
  """Set the list of visible devices.

  Args:
    devices: A PhysicalDevice or a list of PhysicalDevice objects to make
      visible.
    device_type: Optional device type. When given, only the visibility of
      devices of this type is replaced; other types are left unaltered.

  Raises:
    RuntimeError: If the context has already been initialized.
    ValueError: If a device is unknown or does not match `device_type`.
  """
  if self._context_handle is not None:
    raise RuntimeError("Visible devices must be set at program startup")

  # Discover the physical devices before validating. Otherwise the membership
  # test below runs against None, and a setting made before discovery would
  # be overwritten when the defaults are filled in later.
  self.list_physical_devices()

  if not isinstance(devices, list):
    devices = [devices]

  for d in devices:
    if d not in self._physical_devices:
      raise ValueError("Unrecognized device: %s" % repr(d))

    if device_type is not None and d.device_type != device_type:
      raise ValueError("Unrecognized device: %s" % repr(d))

  if device_type is None:
    self._visible_device_list = []
  else:
    # Keep the visibility of every other device type as it was.
    self._visible_device_list = [
        d for d in self._visible_device_list if d.device_type != device_type
    ]

  self._visible_device_list += devices
def get_memory_growth(self, dev):
  """Get if memory growth is enabled for a PhysicalDevice.

  Args:
    dev: PhysicalDevice to query.

  Returns:
    The recorded memory growth setting for `dev`.

  Raises:
    ValueError: If `dev` is not a known physical device.
  """
  # Make sure the device list and the growth map exist before consulting
  # them; both are populated lazily by list_physical_devices().
  self.list_physical_devices()

  if dev not in self._physical_devices:
    raise ValueError("Unrecognized device: %s" % repr(dev))

  return self._memory_growth_map[dev]
def set_memory_growth(self, dev, enable):
  """Set if memory growth should be enabled for a PhysicalDevice.

  Args:
    dev: PhysicalDevice to configure.
    enable: Whether to enable or disable memory growth.

  Raises:
    RuntimeError: If the context has already been initialized.
    ValueError: If `dev` is unknown or already has virtual devices
      configured.
  """
  if self._context_handle is not None:
    raise RuntimeError("Memory growth must be set at program startup")

  # Make sure the device list and the growth map exist before consulting
  # them; both are populated lazily by list_physical_devices().
  self.list_physical_devices()

  if dev not in self._physical_devices:
    raise ValueError("Unrecognized device: %s" % repr(dev))

  if dev in self._virtual_device_map:
    raise ValueError(
        "Cannot set memory growth on device when virtual devices configured")

  self._memory_growth_map[dev] = enable
def get_virtual_device_configuration(self, dev):
  """Get the virtual device configuration for a PhysicalDevice.

  Args:
    dev: PhysicalDevice to query.

  Returns:
    The list of virtual device configurations recorded for `dev`, or None if
    none has been set.

  Raises:
    ValueError: If `dev` is not a known physical device.
  """
  # The physical device list is populated lazily; make sure it exists before
  # the membership test below.
  self.list_physical_devices()

  if dev not in self._physical_devices:
    raise ValueError("Unrecognized device: %s" % repr(dev))

  return self._virtual_device_map.get(dev)
def set_virtual_device_configuration(self, dev, virtual_devices):
  """Set the virtual device configuration for a PhysicalDevice.

  Args:
    dev: PhysicalDevice to configure. Only "CPU" and "GPU" devices are
      supported.
    virtual_devices: List of virtual device configuration objects, each
      exposing a `memory_limit` attribute. The limit must be None for CPUs
      and must be set for GPUs.

  Raises:
    RuntimeError: If the context has already been initialized.
    ValueError: If `dev` is unknown, is of an unsupported type, or a memory
      limit is set on a CPU or missing on a GPU.
  """
  if self._context_handle is not None:
    raise RuntimeError("Virtual devices must be set at program startup")

  # The physical device list is populated lazily; make sure it exists before
  # the membership test below.
  self.list_physical_devices()

  if dev not in self._physical_devices:
    raise ValueError("Unrecognized device: %s" % repr(dev))

  if dev.device_type == "CPU":
    for vdev in virtual_devices:
      if vdev.memory_limit is not None:
        raise ValueError("Setting memory limit on CPU virtual devices is "
                         "currently not supported")
  elif dev.device_type == "GPU":
    for vdev in virtual_devices:
      if vdev.memory_limit is None:
        raise ValueError(
            "Setting memory limit is required for GPU virtual devices")
  else:
    # `device_type` is a plain string field, not a method.
    raise ValueError("Virtual devices are not supported for %s" %
                     dev.device_type)

  self._virtual_device_map[dev] = virtual_devices
@property
def optimizer_jit(self):
level = self.config.graph_options.optimizer_options.global_jit_level

View File

@ -54,7 +54,7 @@ def set_intra_op_parallelism_threads(num_threads):
def get_inter_op_parallelism_threads():
"""Get number of threads used for parallelism between independent operations.
Determines the number of threads used by independent non-blokcing operations.
Determines the number of threads used by independent non-blocking operations.
0 means the system picks an appropriate number.
Returns:
@ -67,7 +67,7 @@ def get_inter_op_parallelism_threads():
def set_inter_op_parallelism_threads(num_threads):
"""Set number of threads used for parallelism between independent operations.
Determines the number of threads used by independent non-blokcing operations.
Determines the number of threads used by independent non-blocking operations.
0 means the system picks an appropriate number.
Args:
@ -287,3 +287,126 @@ def set_synchronous_execution(enable):
context.context().execution_mode = context.SYNC
else:
context.context().execution_mode = context.ASYNC
def list_physical_devices(device_type=None):
  """Return a list of physical devices visible to the runtime.

  Physical devices are hardware devices that are locally present on the
  current machine. By default all discovered CPU and GPU devices are
  considered visible.

  Args:
    device_type: (optional) Device type to filter by such as "CPU" or "GPU"

  Returns:
    List of PhysicalDevice objects
  """
  ctx = context.context()
  return ctx.list_physical_devices(device_type)
def list_logical_devices(device_type=None):
  """Return a list of logical devices created by the runtime.

  Logical devices may correspond to physical devices or remote devices in the
  cluster. Operations and tensors may be placed on these devices by using the
  `name` of the LogicalDevice in `tf.device(logical_device.name)`.

  Args:
    device_type: (optional) Device type to filter by such as "CPU" or "GPU"

  Returns:
    List of LogicalDevice objects
  """
  ctx = context.context()
  return ctx.list_logical_devices(device_type=device_type)
def get_visible_devices(device_type=None):
  """Get the list of visible physical devices.

  Returns a list of PhysicalDevice objects that are currently marked as
  visible to the runtime. Any visible devices will have LogicalDevices
  assigned to them once the runtime is initialized.

  Args:
    device_type: (optional) Device types to limit query to.

  Returns:
    List of PhysicalDevice objects
  """
  ctx = context.context()
  return ctx.get_visible_devices(device_type)
def set_visible_devices(devices, device_type=None):
  """Set the list of visible devices.

  Sets the list of PhysicalDevices to be marked as visible to the runtime.
  TensorFlow will not allocate memory on a device that is not marked as
  visible and will not be able to place any operations on it, as no
  LogicalDevice will be created on it. By default all discovered devices are
  marked as visible.

  Args:
    devices: PhysicalDevice or list of PhysicalDevice objects to make visible
    device_type: (optional) Device types to limit visibility configuration to.
      Other device types will be left unaltered.
  """
  ctx = context.context()
  ctx.set_visible_devices(devices, device_type)
def get_memory_growth(device):
  """Get if memory growth is enabled for a PhysicalDevice.

  A PhysicalDevice with memory growth set will not allocate all memory on the
  device upfront.

  Args:
    device: PhysicalDevice to query

  Returns:
    Current memory growth setting.
  """
  ctx = context.context()
  return ctx.get_memory_growth(device)
def set_memory_growth(device, enable):
  """Set if memory growth should be enabled for a PhysicalDevice.

  A PhysicalDevice with memory growth set will not allocate all memory on the
  device upfront. Memory growth cannot be configured on a PhysicalDevice with
  virtual devices configured.

  Args:
    device: PhysicalDevice to configure
    enable: Whether to enable or disable memory growth
  """
  ctx = context.context()
  ctx.set_memory_growth(device, enable)
def get_virtual_device_configuration(device):
  """Get the virtual device configuration for a PhysicalDevice.

  Returns the list of VirtualDeviceConfiguration objects previously configured
  by a call to `set_virtual_device_configuration()`.

  Args:
    device: PhysicalDevice to query

  Returns:
    List of VirtualDeviceConfiguration objects, or None if no virtual device
    configuration has been set for this device.
  """
  ctx = context.context()
  return ctx.get_virtual_device_configuration(device)
def set_virtual_device_configuration(device, virtual_devices):
  """Set the virtual device configuration for a PhysicalDevice.

  A PhysicalDevice marked as visible will by default have a single
  LogicalDevice allocated to it once the runtime is configured. Specifying a
  list of VirtualDeviceConfiguration objects allows multiple devices to be
  configured that utilize the same PhysicalDevice.

  Args:
    device: PhysicalDevice to configure
    virtual_devices: List of VirtualDeviceConfiguration objects, one per
      virtual device to create on `device`
  """
  ctx = context.context()
  ctx.set_virtual_device_configuration(device, virtual_devices)

View File

@ -20,6 +20,9 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
@ -324,6 +327,196 @@ class ConfigTest(test.TestCase, parameterized.TestCase):
self.assertIn(compat.as_bytes('GPU'), gpu2)
class DeviceTest(test.TestCase):
# Checks that at least one physical CPU is listed and, when a GPU is
# available, at least one physical GPU as well.
@reset_eager
def testPhysicalDevices(self):
cpus = config.list_physical_devices('CPU')
self.assertGreater(len(cpus), 0)
if test_util.is_gpu_available():
gpus = config.list_physical_devices('GPU')
self.assertGreater(len(gpus), 0)
# Splits the single physical CPU into two virtual devices and checks that ops
# can be placed on CPU:0 and CPU:1 but not on CPU:2.
@reset_eager
def testCpuMultiple(self):
cpus = config.list_physical_devices('CPU')
self.assertEqual(len(cpus), 1)
# The virtual configuration has to be in place before the context is
# initialized below.
config.set_virtual_device_configuration(cpus[0], [
context.VirtualDeviceConfiguration(),
context.VirtualDeviceConfiguration()
])
context.ensure_initialized()
cpus = config.list_logical_devices('CPU')
self.assertEqual(len(cpus), 2)
with ops.device('/device:CPU:0'):
a = constant_op.constant(1.0)
self.evaluate(a)
with ops.device('/device:CPU:1'):
b = constant_op.constant(1.0)
self.evaluate(b)
with self.assertRaisesRegexp(RuntimeError, 'unknown device'):
with ops.device('/device:CPU:2'):
c = constant_op.constant(1.0)
self.evaluate(c)
# Ensure we can place ops on each of the device names
for cpu in cpus:
with ops.device(cpu.name):
d = constant_op.constant(1.0)
self.evaluate(d)
# Restricts visibility to the CPU only and checks that no GPU stays visible
# and that placing an op on GPU:0 then fails.
@test_util.run_gpu_only
@reset_eager
def testGpuNone(self):
gpus = config.list_physical_devices('GPU')
self.assertGreater(len(gpus), 0)
cpus = config.list_physical_devices('CPU')
self.assertEqual(len(cpus), 1)
self.assertEqual(len(config.get_visible_devices('CPU')), 1)
self.assertGreater(len(config.get_visible_devices('GPU')), 0)
# No device_type is passed, so the whole visible list is replaced.
config.set_visible_devices(cpus[0])
self.assertEqual(len(config.get_visible_devices('CPU')), 1)
self.assertEqual(len(config.get_visible_devices('GPU')), 0)
with self.assertRaisesRegexp(RuntimeError, 'unknown device'):
with ops.device('/device:GPU:0'):
a = constant_op.constant(1.0)
self.evaluate(a)
# With two or more physical GPUs, checks that an op can be placed on each
# GPU index and that the index one past the last GPU is rejected.
@reset_eager
def testGpuMultiple(self):
gpus = config.list_physical_devices('GPU')
if len(gpus) < 2:
self.skipTest('Need at least 2 GPUs')
context.ensure_initialized()
for i in range(0, len(gpus)):
with ops.device('/device:GPU:' + str(i)):
a = constant_op.constant(1.0)
self.evaluate(a)
with self.assertRaisesRegexp(RuntimeError, 'unknown device'):
with ops.device('/device:GPU:' + str(len(gpus))):
a = constant_op.constant(1.0)
self.evaluate(a)
@test_util.run_gpu_only
@reset_eager
def testVirtualGpu(self):
  """Splits the last physical GPU in two and checks the logical devices."""
  gpus = config.list_physical_devices('GPU')
  self.assertNotEqual(len(gpus), 0)

  self.assertIsNone(config.get_virtual_device_configuration(gpus[-1]))
  config.set_virtual_device_configuration(gpus[-1], [
      context.VirtualDeviceConfiguration(memory_limit=10),
      context.VirtualDeviceConfiguration(memory_limit=10)
  ])
  self.assertEqual(len(config.get_virtual_device_configuration(gpus[-1])), 2)

  # One physical GPU became two logical devices, so the logical count is
  # exactly one higher than the physical count. assertEqual is required
  # here: assertTrue would treat the second argument as a failure message
  # and pass for any non-zero length.
  logical_gpus = config.list_logical_devices('GPU')
  self.assertEqual(len(logical_gpus), len(gpus) + 1)

  for i in range(0, len(logical_gpus)):
    with ops.device('/device:GPU:' + str(i)):
      a = constant_op.constant(1.0)
      self.evaluate(a)

  with self.assertRaisesRegexp(RuntimeError, 'unknown device'):
    with ops.device('/device:GPU:' + str(len(logical_gpus))):
      a = constant_op.constant(1.0)
      self.evaluate(a)
# Checks the interaction between memory growth and virtual devices: GPU
# virtual devices need a memory limit, and memory growth can no longer be set
# on a device once virtual devices are configured on it.
@test_util.run_gpu_only
@reset_eager
def testGpuInvalidConfig(self):
gpus = config.list_physical_devices('GPU')
self.assertNotEqual(len(gpus), 0)
for gpu in gpus:
config.set_memory_growth(gpu, True)
c = context.context().config
self.assertTrue(c.gpu_options.allow_growth)
# Virtual GPU devices without a memory limit are rejected.
with self.assertRaisesRegexp(ValueError, 'memory limit'):
config.set_virtual_device_configuration(gpus[-1], [
context.VirtualDeviceConfiguration(),
context.VirtualDeviceConfiguration()
])
# The rejected call must not have stored anything.
self.assertIsNone(config.get_virtual_device_configuration(gpus[-1]))
config.set_virtual_device_configuration(gpus[-1], [
context.VirtualDeviceConfiguration(memory_limit=10),
context.VirtualDeviceConfiguration(memory_limit=10)
])
c = context.context().config
self.assertFalse(c.gpu_options.allow_growth)
with self.assertRaisesRegexp(ValueError, 'virtual devices'):
config.set_memory_growth(gpus[-1], False)
# Checks that logical GPUs can be listed with names both before and after a
# server definition is installed through context.set_server_def().
@test_util.run_gpu_only
@reset_eager
def testRemote(self):
gpus = config.list_logical_devices('GPU')
self.assertNotEqual(len(gpus), 0)
context.ensure_initialized()
gpus = config.list_logical_devices('GPU')
self.assertNotEqual(len(gpus), 0)
for gpu in gpus:
self.assertIsNotNone(gpu.name)
context.ensure_initialized()
# Build a single-task cluster definition for a job named 'test'.
job_name = 'test'
cluster_def = cluster_pb2.ClusterDef()
job_def = cluster_def.job.add()
job_def.name = job_name
job_def.tasks[0] = 'localhost:0'
server_def = tensorflow_server_pb2.ServerDef(
cluster=cluster_def, job_name=job_name, task_index=0, protocol='grpc')
context.set_server_def(server_def)
gpus = config.list_logical_devices('GPU')
for gpu in gpus:
self.assertIsNotNone(gpu.name)
# Checks that settings supplied through a ConfigProto assigned to the
# context's _config are reflected in the computed config.
# _physical_devices is reset to None after each case so that the next case
# re-runs the lazy device discovery.
@test_util.run_gpu_only
@reset_eager
def testV1Compatibility(self):
# Ensure we set 1 CPU by default
context.context()._config = config_pb2.ConfigProto()
new_config = context.context().config
self.assertEqual(new_config.device_count['CPU'], 1)
context.context()._physical_devices = None
# Ensure CPU is split
context.context()._config = config_pb2.ConfigProto(device_count={'CPU': 2},)
new_config = context.context().config
self.assertEqual(new_config.device_count['CPU'], 2)
context.context()._physical_devices = None
# Ensure an empty visible device list is expanded to every GPU index
context.context()._config = config_pb2.ConfigProto(
gpu_options=config_pb2.GPUOptions(visible_device_list='',),)
gpus = config.list_physical_devices('GPU')
new_config = context.context().config
self.assertEqual(new_config.gpu_options.visible_device_list,
','.join(str(i) for i in range(len(gpus))))
context.context()._physical_devices = None
if __name__ == '__main__':
ops.enable_eager_execution()
test.main()

View File

@ -82,6 +82,7 @@ limitations under the License.
%rename("%s") TFE_Py_RegisterVSpace;
%rename("%s") TFE_Py_EncodeArg;
%rename("%s") TFE_EnableCollectiveOps;
%rename("%s") TF_ListPhysicalDevices;
%rename("%s") TF_PickUnusedPortOrDie;
%rename("%s") TFE_MonitoringSetGauge;
%rename("%s") TFE_MonitoringAddCounter;
@ -90,8 +91,28 @@ limitations under the License.
%{
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/core/common_runtime/device_factory.h"
// Returns a new Python list of bytes objects, one per physical device name
// reported by DeviceFactory::ListAllPhysicalDevices().
//
// `status` always receives the listing status. When listing fails, None is
// returned. When a Python allocation fails, nullptr is returned with the
// Python error indicator set by the failing C-API call.
static PyObject* TF_ListPhysicalDevices(TF_Status* status) {
  std::vector<string> devices;
  tensorflow::Status s = tensorflow::DeviceFactory::ListAllPhysicalDevices(&devices);
  tensorflow::Set_TF_Status_from_Status(status, s);
  if (!s.ok()) {
    Py_RETURN_NONE;
  }
  PyObject* result = PyList_New(devices.size());
  if (result == nullptr) {
    return nullptr;
  }
  Py_ssize_t i = 0;
  for (const auto& dev : devices) {
    PyObject* dev_obj = PyBytes_FromStringAndSize(dev.data(), dev.size());
    if (dev_obj == nullptr) {
      // Releases the partially filled list; unset slots are NULL and are
      // skipped by the list deallocator.
      Py_DECREF(result);
      return nullptr;
    }
    // PyList_SetItem takes over the reference to dev_obj.
    PyList_SetItem(result, i, dev_obj);
    ++i;
  }
  return result;
}
%}
static PyObject* TF_ListPhysicalDevices(TF_Status* status);
%typemap(in) (const void* proto) {
char* c_string;