Add get_device_details
API.
This API can get details about physical devices. Right now, only GPUs are supported, and the only fields are "name" and "compute_capability". The primary motivation is to determine whether mixed precision will run well, as it only results in significant speedups on GPUs with compute capability 7.0 and greater. In general, it's rare that querying device details is necessary, as TensorFlow runs most ops well on all devices, but mixed precision is an exception. PiperOrigin-RevId: 315943445 Change-Id: I077fdc8f87a713ace74037fd2d82eede48740067
This commit is contained in:
parent
ee67c83f72
commit
bef9713188
@ -116,6 +116,48 @@ Status DeviceFactory::ListAllPhysicalDevices(std::vector<string>* devices) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Looks up details for the physical device at global position `device_index`,
// where positions follow the same order used by ListAllPhysicalDevices: the
// CPU factory's devices come first, followed by the devices of every other
// registered factory. On success the owning factory's GetDeviceDetails fills
// `details`. Returns InvalidArgument if `device_index` is negative or past the
// last device, and NotFound if no CPU factory is registered.
Status DeviceFactory::GetAnyDeviceDetails(
    int device_index, std::unordered_map<string, string>* details) {
  if (device_index < 0) {
    return errors::InvalidArgument("Device index out of bounds: ",
                                   device_index);
  }

  // The CPU factory is always consulted first.
  DeviceFactory* const cpu_factory = GetFactory("CPU");
  if (cpu_factory == nullptr) {
    return errors::NotFound(
        "CPU Factory not registered. Did you link in threadpool_device?");
  }

  // `remaining` is the index relative to the factory currently being
  // examined; `device_index` itself is kept intact for the error message.
  int remaining = device_index;
  std::vector<string> factory_devices;
  TF_RETURN_IF_ERROR(cpu_factory->ListPhysicalDevices(&factory_devices));
  if (remaining < factory_devices.size()) {
    return cpu_factory->GetDeviceDetails(remaining, details);
  }
  remaining -= factory_devices.size();

  // Every other factory (including GPU) follows, in registry order.
  tf_shared_lock l(*get_device_factory_lock());
  for (auto& entry : device_factories()) {
    auto* const factory = entry.second.factory.get();
    if (factory == cpu_factory) {
      continue;  // Already handled above.
    }
    factory_devices.clear();
    // TODO(b/146009447): Find the factory size without having to allocate a
    // vector with all the physical devices.
    TF_RETURN_IF_ERROR(factory->ListPhysicalDevices(&factory_devices));
    if (remaining < factory_devices.size()) {
      return factory->GetDeviceDetails(remaining, details);
    }
    remaining -= factory_devices.size();
  }

  return errors::InvalidArgument("Device index out of bounds: ", device_index);
}
|
||||||
|
|
||||||
Status DeviceFactory::AddDevices(
|
Status DeviceFactory::AddDevices(
|
||||||
const SessionOptions& options, const string& name_prefix,
|
const SessionOptions& options, const string& name_prefix,
|
||||||
std::vector<std::unique_ptr<Device>>* devices) {
|
std::vector<std::unique_ptr<Device>>* devices) {
|
||||||
|
@ -55,9 +55,22 @@ class DeviceFactory {
|
|||||||
// CPU is are added first.
|
// CPU is are added first.
|
||||||
static Status ListAllPhysicalDevices(std::vector<string>* devices);
|
static Status ListAllPhysicalDevices(std::vector<string>* devices);
|
||||||
|
|
||||||
|
// Get details for a specific device among all device factories.
|
||||||
|
// 'device_index' indexes into devices from ListAllPhysicalDevices.
|
||||||
|
static Status GetAnyDeviceDetails(
|
||||||
|
int device_index, std::unordered_map<string, string>* details);
|
||||||
|
|
||||||
// For a specific device factory list all possible physical devices.
|
// For a specific device factory list all possible physical devices.
|
||||||
virtual Status ListPhysicalDevices(std::vector<string>* devices) = 0;
|
virtual Status ListPhysicalDevices(std::vector<string>* devices) = 0;
|
||||||
|
|
||||||
|
// Get details for a specific device for a specific factory. Subclasses
|
||||||
|
// can store arbitrary device information in the map. 'device_index' indexes
|
||||||
|
// into devices from ListPhysicalDevices.
|
||||||
|
virtual Status GetDeviceDetails(int device_index,
|
||||||
|
std::unordered_map<string, string>* details) {
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
// Most clients should call AddDevices() instead.
|
// Most clients should call AddDevices() instead.
|
||||||
virtual Status CreateDevices(
|
virtual Status CreateDevices(
|
||||||
const SessionOptions& options, const string& name_prefix,
|
const SessionOptions& options, const string& name_prefix,
|
||||||
|
@ -1034,7 +1034,11 @@ Allocator* BaseGPUDevice::GetScopedAllocator(AllocatorAttributes attr,
|
|||||||
const int BaseGPUDeviceFactory::InterconnectMap::kSameDeviceStrength = 1000;
|
const int BaseGPUDeviceFactory::InterconnectMap::kSameDeviceStrength = 1000;
|
||||||
const int BaseGPUDeviceFactory::InterconnectMap::kStreamExecutorStrength = 1;
|
const int BaseGPUDeviceFactory::InterconnectMap::kStreamExecutorStrength = 1;
|
||||||
|
|
||||||
Status BaseGPUDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
|
Status BaseGPUDeviceFactory::CacheDeviceIds() {
|
||||||
|
if (!cached_device_ids_.empty()) {
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(ValidateGPUMachineManager());
|
TF_RETURN_IF_ERROR(ValidateGPUMachineManager());
|
||||||
se::Platform* gpu_manager = GPUMachineManager();
|
se::Platform* gpu_manager = GPUMachineManager();
|
||||||
if (gpu_manager == nullptr) {
|
if (gpu_manager == nullptr) {
|
||||||
@ -1047,15 +1051,14 @@ Status BaseGPUDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::vector<PlatformGpuId> visible_gpu_order(device_count);
|
std::vector<PlatformGpuId> visible_gpu_order(device_count);
|
||||||
int deviceNo = 0;
|
std::iota(visible_gpu_order.begin(), visible_gpu_order.end(), 0);
|
||||||
std::generate(visible_gpu_order.begin(), visible_gpu_order.end(),
|
TF_RETURN_IF_ERROR(GetValidDeviceIds(visible_gpu_order, &cached_device_ids_));
|
||||||
[&deviceNo] { return deviceNo++; });
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<PlatformGpuId> valid_platform_gpu_ids;
|
Status BaseGPUDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(CacheDeviceIds());
|
||||||
GetValidDeviceIds(visible_gpu_order, &valid_platform_gpu_ids));
|
for (PlatformGpuId platform_gpu_id : cached_device_ids_) {
|
||||||
|
|
||||||
for (PlatformGpuId platform_gpu_id : valid_platform_gpu_ids) {
|
|
||||||
const string device_name =
|
const string device_name =
|
||||||
strings::StrCat("/physical_device:GPU:", platform_gpu_id.value());
|
strings::StrCat("/physical_device:GPU:", platform_gpu_id.value());
|
||||||
devices->push_back(device_name);
|
devices->push_back(device_name);
|
||||||
@ -1064,6 +1067,36 @@ Status BaseGPUDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Fills `details` with information about the GPU at position `device_index`
// in the cached list of valid device IDs (the same list that
// ListPhysicalDevices iterates over). Always sets "device_name"; in CUDA
// builds also sets "compute_capability" as "<major>.<minor>" when the device
// description reports one. Returns an Internal error if the index is out of
// range or the GPU machine manager is unavailable, and forwards any error from
// caching the IDs, validating the manager, or fetching the description.
Status BaseGPUDeviceFactory::GetDeviceDetails(
    int device_index, std::unordered_map<string, string>* details) {
  TF_RETURN_IF_ERROR(CacheDeviceIds());

  // Valid indices are [0, size). An index equal to size() must be rejected
  // too, otherwise the lookup below would read one past the end of the vector.
  if (device_index < 0 ||
      static_cast<size_t>(device_index) >= cached_device_ids_.size()) {
    return errors::Internal("Invalid device index: ", device_index);
  }
  PlatformGpuId platform_gpu_id = cached_device_ids_[device_index];

  TF_RETURN_IF_ERROR(ValidateGPUMachineManager());
  se::Platform* gpu_manager = GPUMachineManager();
  if (gpu_manager == nullptr) {
    return errors::Internal("Cannot get GPUMachineManager");
  }
  auto desc_status = gpu_manager->DescriptionForDevice(platform_gpu_id.value());
  if (!desc_status.ok()) {
    return desc_status.status();
  }

  auto desc = desc_status.ConsumeValueOrDie();
  (*details)["device_name"] = desc->name();
#if GOOGLE_CUDA
  // The key is only added when the description reports a compute capability.
  int cc_major, cc_minor;
  if (desc->cuda_compute_capability(&cc_major, &cc_minor)) {
    (*details)["compute_capability"] = strings::StrCat(cc_major, ".", cc_minor);
  }
#endif  // GOOGLE_CUDA
  return Status::OK();
}
|
||||||
|
|
||||||
Status BaseGPUDeviceFactory::CreateDevices(
|
Status BaseGPUDeviceFactory::CreateDevices(
|
||||||
const SessionOptions& options, const string& name_prefix,
|
const SessionOptions& options, const string& name_prefix,
|
||||||
std::vector<std::unique_ptr<Device>>* devices) {
|
std::vector<std::unique_ptr<Device>>* devices) {
|
||||||
|
@ -311,6 +311,8 @@ class BaseGPUDeviceFactory : public DeviceFactory {
|
|||||||
Status ListPhysicalDevices(std::vector<string>* devices) override;
|
Status ListPhysicalDevices(std::vector<string>* devices) override;
|
||||||
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
|
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
|
||||||
std::vector<std::unique_ptr<Device>>* devices) override;
|
std::vector<std::unique_ptr<Device>>* devices) override;
|
||||||
|
Status GetDeviceDetails(int device_index,
|
||||||
|
std::unordered_map<string, string>* details) override;
|
||||||
|
|
||||||
struct InterconnectMap {
|
struct InterconnectMap {
|
||||||
// Name of interconnect technology, if known.
|
// Name of interconnect technology, if known.
|
||||||
@ -369,9 +371,20 @@ class BaseGPUDeviceFactory : public DeviceFactory {
|
|||||||
Status GetValidDeviceIds(const std::vector<PlatformGpuId>& visible_gpu_order,
|
Status GetValidDeviceIds(const std::vector<PlatformGpuId>& visible_gpu_order,
|
||||||
std::vector<PlatformGpuId>* ids);
|
std::vector<PlatformGpuId>* ids);
|
||||||
|
|
||||||
|
// Cache the valid device IDs if not already cached. Cached IDs are stored in
|
||||||
|
// field cached_device_ids_. Passes {0, 1, ..., num_devices-1} to
|
||||||
|
// GetValidDeviceIds, so this should only be used in functions where all
|
||||||
|
// devices should be treated as visible, like ListPhysicalDevices.
|
||||||
|
Status CacheDeviceIds();
|
||||||
|
|
||||||
// visible_gpu_initialized_[platform_gpu_id] is true if visible GPU
|
// visible_gpu_initialized_[platform_gpu_id] is true if visible GPU
|
||||||
// platform_gpu_id has been initialized by the process.
|
// platform_gpu_id has been initialized by the process.
|
||||||
std::unordered_map<int, bool> visible_gpu_initialized_;
|
std::unordered_map<int, bool> visible_gpu_initialized_;
|
||||||
|
|
||||||
|
// Cached device IDs, as returned by GetValidDeviceIds when every physical
|
||||||
|
// device is visible. Cache should not be used if some devices are not
|
||||||
|
// visible.
|
||||||
|
std::vector<PlatformGpuId> cached_device_ids_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -464,6 +464,23 @@ TEST_F(GPUDeviceTest, CopyTensorInSameDevice) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Checks that every physical GPU listed by the GPU factory reports a
// non-empty "device_name", and that "compute_capability" is absent on ROCm
// builds and non-empty otherwise.
TEST_F(GPUDeviceTest, DeviceDetails) {
  DeviceFactory* gpu_factory = DeviceFactory::GetFactory("GPU");
  std::vector<string> physical_gpus;
  TF_ASSERT_OK(gpu_factory->ListPhysicalDevices(&physical_gpus));
  EXPECT_GE(physical_gpus.size(), 1);

  for (int gpu_index = 0; gpu_index < physical_gpus.size(); ++gpu_index) {
    std::unordered_map<string, string> gpu_details;
    TF_ASSERT_OK(gpu_factory->GetDeviceDetails(gpu_index, &gpu_details));
    EXPECT_NE(gpu_details["device_name"], "");
#if TENSORFLOW_USE_ROCM
    EXPECT_EQ(gpu_details.count("compute_capability"), 0);
#else
    EXPECT_NE(gpu_details["compute_capability"], "");
#endif
  }
}
|
||||||
|
|
||||||
class GPUKernelTrackerTest : public ::testing::Test {
|
class GPUKernelTrackerTest : public ::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
void Init(const GPUKernelTracker::Params& params) {
|
void Init(const GPUKernelTracker::Params& params) {
|
||||||
|
@ -440,6 +440,7 @@ class Context(object):
|
|||||||
|
|
||||||
self._device_lock = threading.Lock()
|
self._device_lock = threading.Lock()
|
||||||
self._physical_devices = None
|
self._physical_devices = None
|
||||||
|
self._physical_device_to_index = None
|
||||||
self._visible_device_list = []
|
self._visible_device_list = []
|
||||||
self._memory_growth_map = None
|
self._memory_growth_map = None
|
||||||
self._virtual_device_map = {}
|
self._virtual_device_map = {}
|
||||||
@ -1226,6 +1227,10 @@ class Context(object):
|
|||||||
self._physical_devices = [
|
self._physical_devices = [
|
||||||
PhysicalDevice(name=d.decode(),
|
PhysicalDevice(name=d.decode(),
|
||||||
device_type=d.decode().split(":")[1]) for d in devs]
|
device_type=d.decode().split(":")[1]) for d in devs]
|
||||||
|
self._physical_device_to_index = {
|
||||||
|
p: i for i, p in enumerate(self._physical_devices)
|
||||||
|
}
|
||||||
|
|
||||||
# Construct the visible device list from all physical devices but ignore
|
# Construct the visible device list from all physical devices but ignore
|
||||||
# XLA devices
|
# XLA devices
|
||||||
self._visible_device_list = [
|
self._visible_device_list = [
|
||||||
@ -1259,6 +1264,37 @@ class Context(object):
|
|||||||
|
|
||||||
return [d for d in self._physical_devices if d.device_type == device_type]
|
return [d for d in self._physical_devices if d.device_type == device_type]
|
||||||
|
|
||||||
|
def get_device_details(self, device):  # pylint: disable=redefined-outer-name
  """Returns details about a physical device.

  Args:
    device: A `tf.config.PhysicalDevice` returned by
      `tf.config.list_physical_devices` or `tf.config.get_visible_devices`.

  Returns:
    A dict with string keys. If the runtime reports a "compute_capability"
    entry, its "<major>.<minor>" string is converted to a tuple of two ints.

  Raises:
    ValueError: If `device` is not a `PhysicalDevice`, or is not one of the
      devices recorded in this context's physical-device index.
    RuntimeError: If the reported compute capability cannot be parsed as
      "<major>.<minor>".
  """
  if not isinstance(device, PhysicalDevice):
    raise ValueError("device must be a tf.config.PhysicalDevice, but got: "
                     "%s" % (device,))
  if (self._physical_device_to_index is None or
      device not in self._physical_device_to_index):
    raise ValueError("The PhysicalDevice must be one obtained from "
                     "calling `tf.config.list_physical_devices`, but got: "
                     "%s" % (device,))
  index = self._physical_device_to_index[device]
  details = pywrap_tfe.TF_GetDeviceDetails(index)

  # Change compute_capability from a string to a tuple
  if "compute_capability" in details:
    try:
      major, minor = details["compute_capability"].split(".")
      details["compute_capability"] = (int(major), int(minor))
    except ValueError:
      # Raised both by a wrong number of "."-separated parts and by
      # non-integer parts.
      raise RuntimeError("Device returned compute capability in an invalid "
                         "format: %s" % details["compute_capability"])
  return details
|
||||||
|
|
||||||
def _import_config(self):
|
def _import_config(self):
|
||||||
"""Import config if passed in during construction.
|
"""Import config if passed in during construction.
|
||||||
|
|
||||||
|
@ -500,6 +500,51 @@ def set_memory_growth(device, enable):
|
|||||||
context.context().set_memory_growth(device, enable)
|
context.context().set_memory_growth(device, enable)
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export('config.experimental.get_device_details')
def get_device_details(device):
  """Returns details about a physical device.

  Pass in a `tf.config.PhysicalDevice` obtained from
  `tf.config.list_physical_devices`. The result is a dict with string keys
  holding various details about the device. Each key is only supported by a
  subset of devices, so do not assume any particular key is present in the
  returned dict.

  >>> gpu_devices = tf.config.list_physical_devices('GPU')
  >>> if gpu_devices:
  ...   details = tf.config.experimental.get_device_details(gpu_devices[0])
  ...   details.get('device_name', 'Unknown GPU')

  Currently, details are only returned for GPUs. An empty dict is returned
  when a non-GPU device is passed.

  The returned dict may have the following keys:
  * `'device_name'`: A human-readable name of the device as a string, e.g.
    "Titan V". Unlike `tf.config.PhysicalDevice.name`, this is the same for
    multiple devices if each device is the same model. Currently only
    available for GPUs.
  * `'compute_capability'`: The
    [compute capability](https://developer.nvidia.com/cuda-gpus) of the device
    as a tuple of two ints, in the form `(major_version, minor_version)`. Only
    available for NVIDIA GPUs.

  Note: This is similar to `tf.sysconfig.get_build_info` in that both
  functions can return information relating to GPUs. However, this function
  returns run-time information about a specific device (such as a GPU's
  compute capability), while `tf.sysconfig.get_build_info` returns
  compile-time information about how TensorFlow was built (such as what
  version of CUDA TensorFlow was built for).

  Args:
    device: A `tf.config.PhysicalDevice` returned by
      `tf.config.list_physical_devices` or `tf.config.get_visible_devices`.

  Returns:
    A dict with string keys.
  """
  current_context = context.context()
  return current_context.get_device_details(device)
|
||||||
|
|
||||||
|
|
||||||
@tf_export('config.get_logical_device_configuration',
|
@tf_export('config.get_logical_device_configuration',
|
||||||
'config.experimental.get_virtual_device_configuration')
|
'config.experimental.get_virtual_device_configuration')
|
||||||
@deprecation.deprecated_endpoints(
|
@deprecation.deprecated_endpoints(
|
||||||
|
@ -482,6 +482,49 @@ class DeviceTest(test.TestCase):
|
|||||||
a = constant_op.constant(1.0)
|
a = constant_op.constant(1.0)
|
||||||
self.evaluate(a)
|
self.evaluate(a)
|
||||||
|
|
||||||
|
@reset_eager
def testDeviceDetails(self):
  """Checks `get_device_details` for the CPU and, when present, for GPUs."""
  # The CPU reports no details.
  (cpu,) = config.list_physical_devices('CPU')
  details = config.get_device_details(cpu)
  self.assertEqual(details, {})

  if not test_util.is_gpu_available():
    return

  gpus = config.list_physical_devices('GPU')
  details = config.get_device_details(gpus[0])
  self.assertIsInstance(details['device_name'], str)
  self.assertNotEmpty(details['device_name'])
  if test.is_built_with_rocm():
    # AMD GPUs do not have a compute capability
    self.assertNotIn('compute_capability', details)
  else:
    cc = details['compute_capability']
    self.assertIsInstance(cc, tuple)
    major, minor = cc
    self.assertGreater(major, 0)
    self.assertGreaterEqual(minor, 0)

  # Test GPU returned from get_visible_devices. Only gpus[1] is needed, so two
  # GPUs are sufficient for this check to run.
  if len(gpus) >= 2:
    config.set_visible_devices(gpus[1], 'GPU')
    (visible_gpu,) = config.get_visible_devices('GPU')
    details = config.get_device_details(visible_gpu)
    self.assertIsInstance(details['device_name'], str)
|
|
||||||
|
@reset_eager
def testDeviceDetailsErrors(self):
  """Checks that `get_device_details` rejects invalid device arguments."""
  # A logical device is not a PhysicalDevice.
  all_logical = config.list_logical_devices()
  with self.assertRaisesRegexp(ValueError,
                               'must be a tf.config.PhysicalDevice'):
    config.get_device_details(all_logical[0])

  # A PhysicalDevice constructed by hand, rather than obtained from
  # list_physical_devices, is rejected as well.
  handmade_device = context.PhysicalDevice('/physical_device:CPU:100', 'CPU')
  expected_message = ('The PhysicalDevice must be one obtained from '
                      'calling `tf.config.list_physical_devices`')
  with self.assertRaisesRegexp(ValueError, expected_message):
    config.get_device_details(handmade_device)
|
||||||
|
|
||||||
@test_util.run_gpu_only
|
@test_util.run_gpu_only
|
||||||
@reset_eager
|
@reset_eager
|
||||||
def testVirtualGpu(self):
|
def testVirtualGpu(self):
|
||||||
|
@ -272,6 +272,16 @@ static py::object TF_ListPhysicalDevices() {
|
|||||||
return tensorflow::PyoOrThrow(result);
|
return tensorflow::PyoOrThrow(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static std::unordered_map<string, string> TF_GetDeviceDetails(int index) {
|
||||||
|
tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus());
|
||||||
|
std::unordered_map<string, string> device_details;
|
||||||
|
tensorflow::Status s =
|
||||||
|
tensorflow::DeviceFactory::GetAnyDeviceDetails(index, &device_details);
|
||||||
|
tensorflow::Set_TF_Status_from_Status(status.get(), s);
|
||||||
|
MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||||
|
return device_details;
|
||||||
|
}
|
||||||
|
|
||||||
static py::object TFE_ClearScalarCache() {
|
static py::object TFE_ClearScalarCache() {
|
||||||
tensorflow::TFE_TensorHandleCache::Get()->Clear();
|
tensorflow::TFE_TensorHandleCache::Get()->Clear();
|
||||||
return py::none();
|
return py::none();
|
||||||
@ -812,6 +822,7 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
|||||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||||
});
|
});
|
||||||
m.def("TF_ListPhysicalDevices", &tensorflow::TF_ListPhysicalDevices);
|
m.def("TF_ListPhysicalDevices", &tensorflow::TF_ListPhysicalDevices);
|
||||||
|
m.def("TF_GetDeviceDetails", &tensorflow::TF_GetDeviceDetails);
|
||||||
m.def("TF_DeleteDeviceList", &TF_DeleteDeviceList,
|
m.def("TF_DeleteDeviceList", &TF_DeleteDeviceList,
|
||||||
py::return_value_policy::reference);
|
py::return_value_policy::reference);
|
||||||
m.def("TF_DeviceListCount", &TF_DeviceListCount);
|
m.def("TF_DeviceListCount", &TF_DeviceListCount);
|
||||||
|
@ -24,6 +24,10 @@ tf_module {
|
|||||||
name: "enable_mlir_graph_optimization"
|
name: "enable_mlir_graph_optimization"
|
||||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "get_device_details"
|
||||||
|
argspec: "args=[\'device\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "get_device_policy"
|
name: "get_device_policy"
|
||||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||||
|
@ -24,6 +24,10 @@ tf_module {
|
|||||||
name: "enable_mlir_graph_optimization"
|
name: "enable_mlir_graph_optimization"
|
||||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "get_device_details"
|
||||||
|
argspec: "args=[\'device\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "get_device_policy"
|
name: "get_device_policy"
|
||||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||||
|
Loading…
Reference in New Issue
Block a user