From bef97131885096a15402e33051053240c3beda6a Mon Sep 17 00:00:00 2001
From: Reed Wanderman-Milne
Date: Thu, 11 Jun 2020 11:39:18 -0700
Subject: [PATCH] Add `get_device_details` API.

This API gets details about physical devices. Right now, only GPUs are
supported, and the only fields are "device_name" and "compute_capability".

The primary motivation is to determine whether mixed precision will run well,
as mixed precision only results in significant speedups on GPUs with compute
capability 7.0 and greater. In general, it is rare that querying device
details is necessary, as TensorFlow runs most ops well on all devices, but
mixed precision is an exception.

PiperOrigin-RevId: 315943445
Change-Id: I077fdc8f87a713ace74037fd2d82eede48740067
---
 .../core/common_runtime/device_factory.cc     | 42 +++++++++++++++
 .../core/common_runtime/device_factory.h      | 13 +++++
 .../core/common_runtime/gpu/gpu_device.cc     | 51 +++++++++++++++----
 .../core/common_runtime/gpu/gpu_device.h      | 13 +++++
 .../common_runtime/gpu/gpu_device_test.cc     | 17 +++++++
 tensorflow/python/eager/context.py            | 36 +++++++++++++
 tensorflow/python/framework/config.py         | 45 ++++++++++++++++
 tensorflow/python/framework/config_test.py    | 43 ++++++++++++++++
 tensorflow/python/tfe_wrapper.cc              | 11 ++++
 .../v1/tensorflow.config.experimental.pbtxt   |  4 ++
 .../v2/tensorflow.config.experimental.pbtxt   |  4 ++
 11 files changed, 270 insertions(+), 9 deletions(-)

diff --git a/tensorflow/core/common_runtime/device_factory.cc b/tensorflow/core/common_runtime/device_factory.cc
index aed7a3c9dd7..2872da15c26 100644
--- a/tensorflow/core/common_runtime/device_factory.cc
+++ b/tensorflow/core/common_runtime/device_factory.cc
@@ -116,6 +116,48 @@ Status DeviceFactory::ListAllPhysicalDevices(std::vector<string>* devices) {
   return Status::OK();
 }
 
+Status DeviceFactory::GetAnyDeviceDetails(
+    int device_index, std::unordered_map<string, string>* details) {
+  if (device_index < 0) {
+    return errors::InvalidArgument("Device index out of bounds: ",
+                                   device_index);
+  }
+  const int orig_device_index = device_index;
+
+  // Iterate over devices in the same way as in ListAllPhysicalDevices.
+  auto cpu_factory = GetFactory("CPU");
+  if (!cpu_factory) {
+    return errors::NotFound(
+        "CPU Factory not registered. Did you link in threadpool_device?");
+  }
+
+  std::vector<string> devices;
+  TF_RETURN_IF_ERROR(cpu_factory->ListPhysicalDevices(&devices));
+  if (device_index < devices.size()) {
+    return cpu_factory->GetDeviceDetails(device_index, details);
+  }
+  device_index -= devices.size();
+
+  // Then the rest (including GPU).
+  tf_shared_lock l(*get_device_factory_lock());
+  for (auto& p : device_factories()) {
+    auto factory = p.second.factory.get();
+    if (factory != cpu_factory) {
+      devices.clear();
+      // TODO(b/146009447): Find the factory size without having to allocate a
+      // vector with all the physical devices.
+      TF_RETURN_IF_ERROR(factory->ListPhysicalDevices(&devices));
+      if (device_index < devices.size()) {
+        return factory->GetDeviceDetails(device_index, details);
+      }
+      device_index -= devices.size();
+    }
+  }
+
+  return errors::InvalidArgument("Device index out of bounds: ",
+                                 orig_device_index);
+}
+
 Status DeviceFactory::AddDevices(
     const SessionOptions& options, const string& name_prefix,
     std::vector<std::unique_ptr<Device>>* devices) {
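For orientation before the remaining files: GetAnyDeviceDetails above resolves
'device_index' in the same order that ListAllPhysicalDevices reports devices
(CPU first, then the remaining factories, including GPU), and the Python
wrapper added later in this patch builds its index map from that same list. A
minimal sketch of the resulting mapping, assuming a TensorFlow build with this
patch applied:

    import tensorflow as tf

    # Index i here is the index the C++ GetAnyDeviceDetails receives: CPU
    # devices come first, then the remaining factories (including GPU).
    for i, device in enumerate(tf.config.list_physical_devices()):
        details = tf.config.experimental.get_device_details(device)
        print(i, device.name, details)  # non-GPU devices yield an empty dict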
diff --git a/tensorflow/core/common_runtime/device_factory.h b/tensorflow/core/common_runtime/device_factory.h
index e18bb7f4834..c026a188f5e 100644
--- a/tensorflow/core/common_runtime/device_factory.h
+++ b/tensorflow/core/common_runtime/device_factory.h
@@ -55,9 +55,22 @@ class DeviceFactory {
   // CPU devices are added first.
   static Status ListAllPhysicalDevices(std::vector<string>* devices);
 
+  // Get details for a specific device among all device factories.
+  // 'device_index' indexes into devices from ListAllPhysicalDevices.
+  static Status GetAnyDeviceDetails(
+      int device_index, std::unordered_map<string, string>* details);
+
   // For a specific device factory list all possible physical devices.
   virtual Status ListPhysicalDevices(std::vector<string>* devices) = 0;
 
+  // Get details for a specific device for a specific factory. Subclasses
+  // can store arbitrary device information in the map. 'device_index' indexes
+  // into devices from ListPhysicalDevices.
+  virtual Status GetDeviceDetails(
+      int device_index, std::unordered_map<string, string>* details) {
+    return Status::OK();
+  }
+
   // Most clients should call AddDevices() instead.
   virtual Status CreateDevices(
       const SessionOptions& options, const string& name_prefix,
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 04b7f9d6082..57af898ecd2 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -1034,7 +1034,11 @@ Allocator* BaseGPUDevice::GetScopedAllocator(AllocatorAttributes attr,
 const int BaseGPUDeviceFactory::InterconnectMap::kSameDeviceStrength = 1000;
 const int BaseGPUDeviceFactory::InterconnectMap::kStreamExecutorStrength = 1;
 
-Status BaseGPUDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
+Status BaseGPUDeviceFactory::CacheDeviceIds() {
+  if (!cached_device_ids_.empty()) {
+    return Status::OK();
+  }
+
   TF_RETURN_IF_ERROR(ValidateGPUMachineManager());
   se::Platform* gpu_manager = GPUMachineManager();
   if (gpu_manager == nullptr) {
@@ -1047,15 +1051,14 @@ Status BaseGPUDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
   }
 
   std::vector<PlatformGpuId> visible_gpu_order(device_count);
-  int deviceNo = 0;
-  std::generate(visible_gpu_order.begin(), visible_gpu_order.end(),
-                [&deviceNo] { return deviceNo++; });
+  std::iota(visible_gpu_order.begin(), visible_gpu_order.end(), 0);
+  TF_RETURN_IF_ERROR(GetValidDeviceIds(visible_gpu_order, &cached_device_ids_));
+  return Status::OK();
+}
 
-  std::vector<PlatformGpuId> valid_platform_gpu_ids;
-  TF_RETURN_IF_ERROR(
-      GetValidDeviceIds(visible_gpu_order, &valid_platform_gpu_ids));
-
-  for (PlatformGpuId platform_gpu_id : valid_platform_gpu_ids) {
+Status BaseGPUDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
+  TF_RETURN_IF_ERROR(CacheDeviceIds());
+  for (PlatformGpuId platform_gpu_id : cached_device_ids_) {
     const string device_name =
         strings::StrCat("/physical_device:GPU:", platform_gpu_id.value());
     devices->push_back(device_name);
@@ -1064,6 +1067,36 @@
   return Status::OK();
 }
 
+Status BaseGPUDeviceFactory::GetDeviceDetails(
+    int device_index, std::unordered_map<string, string>* details) {
+  TF_RETURN_IF_ERROR(CacheDeviceIds());
+
+  if (device_index < 0 || device_index >= cached_device_ids_.size()) {
+    return errors::Internal("Invalid device index: ", device_index);
+  }
+  PlatformGpuId platform_gpu_id = cached_device_ids_[device_index];
+
+  TF_RETURN_IF_ERROR(ValidateGPUMachineManager());
+  se::Platform* gpu_manager = GPUMachineManager();
+  if (gpu_manager == nullptr) {
+    return errors::Internal("Cannot get GPUMachineManager");
+  }
+  auto desc_status =
+      gpu_manager->DescriptionForDevice(platform_gpu_id.value());
+  if (!desc_status.ok()) {
+    return desc_status.status();
+  }
+
+  auto desc = desc_status.ConsumeValueOrDie();
+  (*details)["device_name"] = desc->name();
+#if GOOGLE_CUDA
+  int cc_major, cc_minor;
+  if (desc->cuda_compute_capability(&cc_major, &cc_minor)) {
+    (*details)["compute_capability"] =
+        strings::StrCat(cc_major, ".", cc_minor);
+  }
+#endif  // GOOGLE_CUDA
+  return Status::OK();
+}
+
 Status BaseGPUDeviceFactory::CreateDevices(
     const SessionOptions& options, const string& name_prefix,
     std::vector<std::unique_ptr<Device>>* devices) {
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.h b/tensorflow/core/common_runtime/gpu/gpu_device.h
index 32c7738d916..5609334ce9c 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.h
@@ -311,6 +311,8 @@ class BaseGPUDeviceFactory : public DeviceFactory {
   Status ListPhysicalDevices(std::vector<string>* devices) override;
   Status CreateDevices(const SessionOptions& options, const string& name_prefix,
                        std::vector<std::unique_ptr<Device>>* devices) override;
+  Status GetDeviceDetails(
+      int device_index, std::unordered_map<string, string>* details) override;
 
   struct InterconnectMap {
     // Name of interconnect technology, if known.
@@ -369,9 +371,20 @@ class BaseGPUDeviceFactory : public DeviceFactory {
   Status GetValidDeviceIds(const std::vector<PlatformGpuId>& visible_gpu_order,
                            std::vector<PlatformGpuId>* ids);
 
+  // Cache the valid device IDs if not already cached. Cached IDs are stored
+  // in the field cached_device_ids_. Passes {0, 1, ..., num_devices-1} to
+  // GetValidDeviceIds, so this should only be used in functions where all
+  // devices should be treated as visible, like ListPhysicalDevices.
+  Status CacheDeviceIds();
+
   // visible_gpu_initialized_[platform_gpu_id] is true if visible GPU
   // platform_gpu_id has been initialized by the process.
   std::unordered_map<int, bool> visible_gpu_initialized_;
+
+  // Cached device IDs, as returned by GetValidDeviceIds when every physical
+  // device is visible. The cache should not be used if some devices are not
+  // visible.
+  std::vector<PlatformGpuId> cached_device_ids_;
 };
 
 }  // namespace tensorflow
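On the Python side, the unordered_map built above surfaces as a plain dict. A
short sketch of what a caller might see; the name string and capability in the
comment are illustrative examples, not guaranteed values:

    import tensorflow as tf

    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        details = tf.config.experimental.get_device_details(gpus[0])
        # e.g. {'device_name': 'Tesla V100-SXM2-16GB',
        #       'compute_capability': (7, 0)}
        # ROCm builds omit 'compute_capability', since AMD GPUs do not report
        # a CUDA compute capability (see the #if GOOGLE_CUDA guard above).
        print(details)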
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
index 1703d926f9f..dae744380e9 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
@@ -464,6 +464,23 @@ TEST_F(GPUDeviceTest, CopyTensorInSameDevice) {
   }
 }
 
+TEST_F(GPUDeviceTest, DeviceDetails) {
+  DeviceFactory* factory = DeviceFactory::GetFactory("GPU");
+  std::vector<string> devices;
+  TF_ASSERT_OK(factory->ListPhysicalDevices(&devices));
+  EXPECT_GE(devices.size(), 1);
+  for (int i = 0; i < devices.size(); i++) {
+    std::unordered_map<string, string> details;
+    TF_ASSERT_OK(factory->GetDeviceDetails(i, &details));
+    EXPECT_NE(details["device_name"], "");
+#if TENSORFLOW_USE_ROCM
+    EXPECT_EQ(details.count("compute_capability"), 0);
+#else
+    EXPECT_NE(details["compute_capability"], "");
+#endif
+  }
+}
+
 class GPUKernelTrackerTest : public ::testing::Test {
  protected:
   void Init(const GPUKernelTracker::Params& params) {
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index aa760583800..1c083ffe294 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -440,6 +440,7 @@ class Context(object):
     self._device_lock = threading.Lock()
     self._physical_devices = None
+    self._physical_device_to_index = None
     self._visible_device_list = []
     self._memory_growth_map = None
     self._virtual_device_map = {}
@@ -1226,6 +1227,10 @@ class Context(object):
       self._physical_devices = [
          PhysicalDevice(name=d.decode(), device_type=d.decode().split(":")[1])
          for d in devs]
+      self._physical_device_to_index = {
+          p: i for i, p in enumerate(self._physical_devices)
+      }
+
       # Construct the visible device list from all physical devices but ignore
       # XLA devices
       self._visible_device_list = [
@@ -1259,6 +1264,37 @@ class Context(object):
 
     return [d for d in self._physical_devices if d.device_type == device_type]
 
+  def get_device_details(self, device):  # pylint: disable=redefined-outer-name
+    """Returns details about a physical device.
+
+    Args:
+      device: A `tf.config.PhysicalDevice` returned by
+        `tf.config.list_physical_devices` or `tf.config.get_visible_devices`.
+
+    Returns:
+      A dict with string keys.
+    """
+    if not isinstance(device, PhysicalDevice):
+      raise ValueError("device must be a tf.config.PhysicalDevice, but got: "
+                       "%s" % (device,))
+    if (self._physical_device_to_index is None or
+        device not in self._physical_device_to_index):
+      raise ValueError("The PhysicalDevice must be one obtained from "
+                       "calling `tf.config.list_physical_devices`, but got: "
+                       "%s" % (device,))
+    index = self._physical_device_to_index[device]
+    details = pywrap_tfe.TF_GetDeviceDetails(index)
+
+    # Change compute_capability from a string to a tuple
+    if "compute_capability" in details:
+      try:
+        major, minor = details["compute_capability"].split(".")
+        details["compute_capability"] = (int(major), int(minor))
+      except ValueError:
+        raise RuntimeError("Device returned compute capability in an invalid "
+                           "format: %s" % details["compute_capability"])
+    return details
+
   def _import_config(self):
     """Import config if passed in during construction.
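The compute-capability conversion in get_device_details above is small enough
to mirror standalone. A self-contained sketch of the same parsing, where
parse_compute_capability is a hypothetical helper (not part of this patch) and
"7.5" is a made-up input:

    def parse_compute_capability(cc_str):
        # Mirrors Context.get_device_details: "7.5" -> (7, 5). Anything that
        # does not split into exactly two ints is rejected.
        try:
            major, minor = cc_str.split(".")
            return (int(major), int(minor))
        except ValueError:
            raise RuntimeError("Device returned compute capability in an "
                               "invalid format: %s" % cc_str)

    assert parse_compute_capability("7.5") == (7, 5)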
diff --git a/tensorflow/python/framework/config.py b/tensorflow/python/framework/config.py
index 5361d7290e8..9ff16f2a327 100644
--- a/tensorflow/python/framework/config.py
+++ b/tensorflow/python/framework/config.py
@@ -500,6 +500,51 @@ def set_memory_growth(device, enable):
   context.context().set_memory_growth(device, enable)
 
 
+@tf_export('config.experimental.get_device_details')
+def get_device_details(device):
+  """Returns details about a physical device.
+
+  This API takes in a `tf.config.PhysicalDevice` returned by
+  `tf.config.list_physical_devices`. It returns a dict with string keys
+  containing various details about the device. Each key is only supported by
+  a subset of devices, so you should not assume the returned dict will have
+  any particular key.
+
+  >>> gpu_devices = tf.config.list_physical_devices('GPU')
+  >>> if gpu_devices:
+  ...   details = tf.config.experimental.get_device_details(gpu_devices[0])
+  ...   details.get('device_name', 'Unknown GPU')
+
+  Currently, details are only returned for GPUs. This function returns an
+  empty dict if passed a non-GPU device.
+
+  The returned dict may have the following keys:
+  * `'device_name'`: A human-readable name of the device as a string, e.g.
+    "Titan V". Unlike `tf.config.PhysicalDevice.name`, this will be the same
+    for multiple devices if each device is the same model. Currently only
+    available for GPUs.
+  * `'compute_capability'`: The
+    [compute capability](https://developer.nvidia.com/cuda-gpus) of the device
+    as a tuple of two ints, in the form `(major_version, minor_version)`. Only
+    available for NVIDIA GPUs.
+
+  Note: This is similar to `tf.sysconfig.get_build_info` in that both
+  functions can return information relating to GPUs. However, this function
+  returns run-time information about a specific device (such as a GPU's
+  compute capability), while `tf.sysconfig.get_build_info` returns
+  compile-time information about how TensorFlow was built (such as what
+  version of CUDA TensorFlow was built for).
+
+  Args:
+    device: A `tf.config.PhysicalDevice` returned by
+      `tf.config.list_physical_devices` or `tf.config.get_visible_devices`.
+
+  Returns:
+    A dict with string keys.
+  """
+  return context.context().get_device_details(device)
+
+
 @tf_export('config.get_logical_device_configuration',
            'config.experimental.get_virtual_device_configuration')
 @deprecation.deprecated_endpoints(
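Tying the docstring back to the commit's stated motivation, a hedged usage
sketch follows: the (7, 0) threshold comes from the commit message, and the
set_policy call reflects the experimental Keras mixed precision API of the era
this patch was written in, not an API introduced by this change:

    import tensorflow as tf

    def mixed_precision_is_fast():
        # Per the commit message, mixed precision only yields significant
        # speedups on GPUs with compute capability 7.0 or greater.
        for gpu in tf.config.list_physical_devices('GPU'):
            details = tf.config.experimental.get_device_details(gpu)
            cc = details.get('compute_capability')
            if cc is not None and cc >= (7, 0):
                return True
        return False

    if mixed_precision_is_fast():
        tf.keras.mixed_precision.experimental.set_policy('mixed_float16')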
+ """ + return context.context().get_device_details(device) + + @tf_export('config.get_logical_device_configuration', 'config.experimental.get_virtual_device_configuration') @deprecation.deprecated_endpoints( diff --git a/tensorflow/python/framework/config_test.py b/tensorflow/python/framework/config_test.py index 3051f1d0623..65845535ea7 100644 --- a/tensorflow/python/framework/config_test.py +++ b/tensorflow/python/framework/config_test.py @@ -482,6 +482,49 @@ class DeviceTest(test.TestCase): a = constant_op.constant(1.0) self.evaluate(a) + @reset_eager + def testDeviceDetails(self): + (cpu,) = config.list_physical_devices('CPU') + details = config.get_device_details(cpu) + self.assertEqual(details, {}) + + if not test_util.is_gpu_available(): + return + + gpus = config.list_physical_devices('GPU') + details = config.get_device_details(gpus[0]) + self.assertIsInstance(details['device_name'], str) + self.assertNotEmpty(details['device_name']) + if test.is_built_with_rocm(): + # AMD GPUs do not have a compute capability + self.assertNotIn('compute_capability', details) + else: + cc = details['compute_capability'] + self.assertIsInstance(cc, tuple) + major, minor = cc + self.assertGreater(major, 0) + self.assertGreaterEqual(minor, 0) + + # Test GPU returned from get_visible_devices + if len(gpus) > 2: + config.set_visible_devices(gpus[1], 'GPU') + (visible_gpu,) = config.get_visible_devices('GPU') + details = config.get_device_details(visible_gpu) + self.assertIsInstance(details['device_name'], str) + + @reset_eager + def testDeviceDetailsErrors(self): + logical_devices = config.list_logical_devices() + with self.assertRaisesRegexp(ValueError, + 'must be a tf.config.PhysicalDevice'): + config.get_device_details(logical_devices[0]) + + phys_dev = context.PhysicalDevice('/physical_device:CPU:100', 'CPU') + with self.assertRaisesRegexp( + ValueError, 'The PhysicalDevice must be one obtained from ' + 'calling `tf.config.list_physical_devices`'): + config.get_device_details(phys_dev) + @test_util.run_gpu_only @reset_eager def testVirtualGpu(self): diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index efcd912f430..2901a63c829 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -272,6 +272,16 @@ static py::object TF_ListPhysicalDevices() { return tensorflow::PyoOrThrow(result); } +static std::unordered_map TF_GetDeviceDetails(int index) { + tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus()); + std::unordered_map device_details; + tensorflow::Status s = + tensorflow::DeviceFactory::GetAnyDeviceDetails(index, &device_details); + tensorflow::Set_TF_Status_from_Status(status.get(), s); + MaybeRaiseRegisteredFromTFStatus(status.get()); + return device_details; +} + static py::object TFE_ClearScalarCache() { tensorflow::TFE_TensorHandleCache::Get()->Clear(); return py::none(); @@ -812,6 +822,7 @@ PYBIND11_MODULE(_pywrap_tfe, m) { tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); }); m.def("TF_ListPhysicalDevices", &tensorflow::TF_ListPhysicalDevices); + m.def("TF_GetDeviceDetails", &tensorflow::TF_GetDeviceDetails); m.def("TF_DeleteDeviceList", &TF_DeleteDeviceList, py::return_value_policy::reference); m.def("TF_DeviceListCount", &TF_DeviceListCount); diff --git a/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt index 0f3558e844e..7397719e656 100644 --- 
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt
index 0f3558e844e..7397719e656 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt
@@ -24,6 +24,10 @@ tf_module {
     name: "enable_mlir_graph_optimization"
     argspec: "args=[], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "get_device_details"
+    argspec: "args=[\'device\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "get_device_policy"
     argspec: "args=[], varargs=None, keywords=None, defaults=None"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt
index 0f3558e844e..7397719e656 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt
@@ -24,6 +24,10 @@ tf_module {
     name: "enable_mlir_graph_optimization"
     argspec: "args=[], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "get_device_details"
+    argspec: "args=[\'device\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "get_device_policy"
     argspec: "args=[], varargs=None, keywords=None, defaults=None"