diff --git a/tensorflow/core/common_runtime/device_factory.cc b/tensorflow/core/common_runtime/device_factory.cc
index aed7a3c9dd7..2872da15c26 100644
--- a/tensorflow/core/common_runtime/device_factory.cc
+++ b/tensorflow/core/common_runtime/device_factory.cc
@@ -116,6 +116,48 @@ Status DeviceFactory::ListAllPhysicalDevices(std::vector<string>* devices) {
   return Status::OK();
 }
 
+Status DeviceFactory::GetAnyDeviceDetails(
+    int device_index, std::unordered_map<string, string>* details) {
+  if (device_index < 0) {
+    return errors::InvalidArgument("Device index out of bounds: ",
+                                   device_index);
+  }
+  const int orig_device_index = device_index;
+
+  // Iterate over devices in the same way as in ListAllPhysicalDevices.
+  auto cpu_factory = GetFactory("CPU");
+  if (!cpu_factory) {
+    return errors::NotFound(
+        "CPU Factory not registered. Did you link in threadpool_device?");
+  }
+
+  std::vector<string> devices;
+  TF_RETURN_IF_ERROR(cpu_factory->ListPhysicalDevices(&devices));
+  if (device_index < devices.size()) {
+    return cpu_factory->GetDeviceDetails(device_index, details);
+  }
+  device_index -= devices.size();
+
+  // Then the rest (including GPU).
+  tf_shared_lock l(*get_device_factory_lock());
+  for (auto& p : device_factories()) {
+    auto factory = p.second.factory.get();
+    if (factory != cpu_factory) {
+      devices.clear();
+      // TODO(b/146009447): Find the factory size without having to allocate a
+      // vector with all the physical devices.
+      TF_RETURN_IF_ERROR(factory->ListPhysicalDevices(&devices));
+      if (device_index < devices.size()) {
+        return factory->GetDeviceDetails(device_index, details);
+      }
+      device_index -= devices.size();
+    }
+  }
+
+  return errors::InvalidArgument("Device index out of bounds: ",
+                                 orig_device_index);
+}
+
 Status DeviceFactory::AddDevices(
     const SessionOptions& options, const string& name_prefix,
     std::vector<std::unique_ptr<Device>>* devices) {
diff --git a/tensorflow/core/common_runtime/device_factory.h b/tensorflow/core/common_runtime/device_factory.h
index e18bb7f4834..c026a188f5e 100644
--- a/tensorflow/core/common_runtime/device_factory.h
+++ b/tensorflow/core/common_runtime/device_factory.h
@@ -55,9 +55,22 @@ class DeviceFactory {
   // CPU is added first.
   static Status ListAllPhysicalDevices(std::vector<string>* devices);
 
+  // Get details for a specific device among all device factories.
+  // 'device_index' indexes into devices from ListAllPhysicalDevices.
+  static Status GetAnyDeviceDetails(
+      int device_index, std::unordered_map<string, string>* details);
+
   // For a specific device factory list all possible physical devices.
   virtual Status ListPhysicalDevices(std::vector<string>* devices) = 0;
 
+  // Get details for a specific device for a specific factory. Subclasses
+  // can store arbitrary device information in the map. 'device_index' indexes
+  // into devices from ListPhysicalDevices.
+  virtual Status GetDeviceDetails(int device_index,
+                                  std::unordered_map<string, string>* details) {
+    return Status::OK();
+  }
+
   // Most clients should call AddDevices() instead.
   virtual Status CreateDevices(
       const SessionOptions& options, const string& name_prefix,
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 04b7f9d6082..57af898ecd2 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -1034,7 +1034,11 @@ Allocator* BaseGPUDevice::GetScopedAllocator(AllocatorAttributes attr,
 const int BaseGPUDeviceFactory::InterconnectMap::kSameDeviceStrength = 1000;
 const int BaseGPUDeviceFactory::InterconnectMap::kStreamExecutorStrength = 1;
 
-Status BaseGPUDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
+Status BaseGPUDeviceFactory::CacheDeviceIds() {
+  if (!cached_device_ids_.empty()) {
+    return Status::OK();
+  }
+
   TF_RETURN_IF_ERROR(ValidateGPUMachineManager());
   se::Platform* gpu_manager = GPUMachineManager();
   if (gpu_manager == nullptr) {
@@ -1047,15 +1051,14 @@ Status BaseGPUDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
   }
 
   std::vector<PlatformGpuId> visible_gpu_order(device_count);
-  int deviceNo = 0;
-  std::generate(visible_gpu_order.begin(), visible_gpu_order.end(),
-                [&deviceNo] { return deviceNo++; });
+  std::iota(visible_gpu_order.begin(), visible_gpu_order.end(), 0);
+  TF_RETURN_IF_ERROR(GetValidDeviceIds(visible_gpu_order, &cached_device_ids_));
+  return Status::OK();
+}
 
-  std::vector<PlatformGpuId> valid_platform_gpu_ids;
-  TF_RETURN_IF_ERROR(
-      GetValidDeviceIds(visible_gpu_order, &valid_platform_gpu_ids));
-
-  for (PlatformGpuId platform_gpu_id : valid_platform_gpu_ids) {
+Status BaseGPUDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
+  TF_RETURN_IF_ERROR(CacheDeviceIds());
+  for (PlatformGpuId platform_gpu_id : cached_device_ids_) {
     const string device_name =
         strings::StrCat("/physical_device:GPU:", platform_gpu_id.value());
     devices->push_back(device_name);
@@ -1064,6 +1067,36 @@ Status BaseGPUDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
   return Status::OK();
 }
 
+Status BaseGPUDeviceFactory::GetDeviceDetails(
+    int device_index, std::unordered_map<string, string>* details) {
+  TF_RETURN_IF_ERROR(CacheDeviceIds());
+
+  if (device_index < 0 || device_index >= cached_device_ids_.size()) {
+    return errors::Internal("Invalid device index: ", device_index);
+  }
+  PlatformGpuId platform_gpu_id = cached_device_ids_[device_index];
+
+  TF_RETURN_IF_ERROR(ValidateGPUMachineManager());
+  se::Platform* gpu_manager = GPUMachineManager();
+  if (gpu_manager == nullptr) {
+    return errors::Internal("Cannot get GPUMachineManager");
+  }
+  auto desc_status = gpu_manager->DescriptionForDevice(platform_gpu_id.value());
+  if (!desc_status.ok()) {
+    return desc_status.status();
+  }
+
+  auto desc = desc_status.ConsumeValueOrDie();
+  (*details)["device_name"] = desc->name();
+#if GOOGLE_CUDA
+  int cc_major, cc_minor;
+  if (desc->cuda_compute_capability(&cc_major, &cc_minor)) {
+    (*details)["compute_capability"] = strings::StrCat(cc_major, ".", cc_minor);
+  }
+#endif  // GOOGLE_CUDA
+  return Status::OK();
+}
+
 Status BaseGPUDeviceFactory::CreateDevices(
     const SessionOptions& options, const string& name_prefix,
     std::vector<std::unique_ptr<Device>>* devices) {
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.h b/tensorflow/core/common_runtime/gpu/gpu_device.h
index 32c7738d916..5609334ce9c 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.h
@@ -311,6 +311,8 @@ class BaseGPUDeviceFactory : public DeviceFactory {
   Status ListPhysicalDevices(std::vector<string>* devices) override;
   Status CreateDevices(const SessionOptions& options, const string& name_prefix,
                        std::vector<std::unique_ptr<Device>>* devices) override;
+  Status GetDeviceDetails(int device_index,
+                          std::unordered_map<string, string>* details) override;
 
   struct InterconnectMap {
     // Name of interconnect technology, if known.
@@ -369,9 +371,20 @@ class BaseGPUDeviceFactory {
   Status GetValidDeviceIds(const std::vector<PlatformGpuId>& visible_gpu_order,
                            std::vector<PlatformGpuId>* ids);
 
+  // Cache the valid device IDs if not already cached. Cached IDs are stored in
+  // field cached_device_ids_. Passes {0, 1, ..., num_devices-1} to
+  // GetValidDeviceIds, so this should only be used in functions where all
+  // devices should be treated as visible, like ListPhysicalDevices.
+  Status CacheDeviceIds();
+
   // visible_gpu_initialized_[platform_gpu_id] is true if visible GPU
   // platform_gpu_id has been initialized by the process.
   std::unordered_map<int, bool> visible_gpu_initialized_;
+
+  // Cached device IDs, as returned by GetValidDeviceIds when every physical
+  // device is visible. Cache should not be used if some devices are not
+  // visible.
+  std::vector<PlatformGpuId> cached_device_ids_;
 };
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
index 1703d926f9f..dae744380e9 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
@@ -464,6 +464,23 @@ TEST_F(GPUDeviceTest, CopyTensorInSameDevice) {
   }
 }
 
+TEST_F(GPUDeviceTest, DeviceDetails) {
+  DeviceFactory* factory = DeviceFactory::GetFactory("GPU");
+  std::vector<string> devices;
+  TF_ASSERT_OK(factory->ListPhysicalDevices(&devices));
+  EXPECT_GE(devices.size(), 1);
+  for (int i = 0; i < devices.size(); i++) {
+    std::unordered_map<string, string> details;
+    TF_ASSERT_OK(factory->GetDeviceDetails(i, &details));
+    EXPECT_NE(details["device_name"], "");
+#if TENSORFLOW_USE_ROCM
+    EXPECT_EQ(details.count("compute_capability"), 0);
+#else
+    EXPECT_NE(details["compute_capability"], "");
+#endif
+  }
+}
+
 class GPUKernelTrackerTest : public ::testing::Test {
  protected:
   void Init(const GPUKernelTracker::Params& params) {
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index aa760583800..1c083ffe294 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -440,6 +440,7 @@ class Context(object):
 
     self._device_lock = threading.Lock()
     self._physical_devices = None
+    self._physical_device_to_index = None
     self._visible_device_list = []
     self._memory_growth_map = None
     self._virtual_device_map = {}
@@ -1226,6 +1227,10 @@ class Context(object):
       self._physical_devices = [
           PhysicalDevice(name=d.decode(),
                          device_type=d.decode().split(":")[1]) for d in devs]
+      self._physical_device_to_index = {
+          p: i for i, p in enumerate(self._physical_devices)
+      }
+
       # Construct the visible device list from all physical devices but ignore
       # XLA devices
       self._visible_device_list = [
@@ -1259,6 +1264,37 @@ class Context(object):
 
     return [d for d in self._physical_devices if d.device_type == device_type]
 
+  def get_device_details(self, device):  # pylint: disable=redefined-outer-name
+    """Returns details about a physical device.
+
+    Args:
+      device: A `tf.config.PhysicalDevice` returned by
+        `tf.config.list_physical_devices` or `tf.config.get_visible_devices`.
+
+    Returns:
+      A dict with string keys.
+    """
+    if not isinstance(device, PhysicalDevice):
+      raise ValueError("device must be a tf.config.PhysicalDevice, but got: "
+                       "%s" % (device,))
+    if (self._physical_device_to_index is None or
+        device not in self._physical_device_to_index):
+      raise ValueError("The PhysicalDevice must be one obtained from "
+                       "calling `tf.config.list_physical_devices`, but got: "
+                       "%s" % (device,))
+    index = self._physical_device_to_index[device]
+    details = pywrap_tfe.TF_GetDeviceDetails(index)
+
+    # Change compute_capability from a string to a tuple
+    if "compute_capability" in details:
+      try:
+        major, minor = details["compute_capability"].split(".")
+        details["compute_capability"] = (int(major), int(minor))
+      except ValueError:
+        raise RuntimeError("Device returned compute capability in an invalid "
+                           "format: %s" % details["compute_capability"])
+    return details
+
   def _import_config(self):
     """Import config if passed in during construction.
 
diff --git a/tensorflow/python/framework/config.py b/tensorflow/python/framework/config.py
index 5361d7290e8..9ff16f2a327 100644
--- a/tensorflow/python/framework/config.py
+++ b/tensorflow/python/framework/config.py
@@ -500,6 +500,51 @@ def set_memory_growth(device, enable):
   context.context().set_memory_growth(device, enable)
 
 
+@tf_export('config.experimental.get_device_details')
+def get_device_details(device):
+  """Returns details about a physical device.
+
+  This API takes in a `tf.config.PhysicalDevice` returned by
+  `tf.config.list_physical_devices`. It returns a dict with string keys
+  containing various details about the device. Each key is only supported by a
+  subset of devices, so you should not assume the returned dict will have any
+  particular key.
+
+  >>> gpu_devices = tf.config.list_physical_devices('GPU')
+  >>> if gpu_devices:
+  ...   details = tf.config.experimental.get_device_details(gpu_devices[0])
+  ...   details.get('device_name', 'Unknown GPU')
+
+  Currently, details are only returned for GPUs. This function returns an
+  empty dict if passed a non-GPU device.
+
+  The returned dict may have the following keys:
+  * `'device_name'`: A human-readable name of the device as a string, e.g.
+    "Titan V". Unlike `tf.config.PhysicalDevice.name`, this will be the same for
+    multiple devices if each device is the same model. Currently only available
+    for GPUs.
+  * `'compute_capability'`: The
+    [compute capability](https://developer.nvidia.com/cuda-gpus) of the device
+    as a tuple of two ints, in the form `(major_version, minor_version)`. Only
+    available for NVIDIA GPUs.
+
+  Note: This is similar to `tf.sysconfig.get_build_info` in that both functions
+  can return information relating to GPUs. However, this function returns
+  run-time information about a specific device (such as a GPU's compute
+  capability), while `tf.sysconfig.get_build_info` returns compile-time
+  information about how TensorFlow was built (such as what version of CUDA
+  TensorFlow was built for).
+
+  Args:
+    device: A `tf.config.PhysicalDevice` returned by
+      `tf.config.list_physical_devices` or `tf.config.get_visible_devices`.
+
+  Returns:
+    A dict with string keys.
+ """ + return context.context().get_device_details(device) + + @tf_export('config.get_logical_device_configuration', 'config.experimental.get_virtual_device_configuration') @deprecation.deprecated_endpoints( diff --git a/tensorflow/python/framework/config_test.py b/tensorflow/python/framework/config_test.py index 3051f1d0623..65845535ea7 100644 --- a/tensorflow/python/framework/config_test.py +++ b/tensorflow/python/framework/config_test.py @@ -482,6 +482,49 @@ class DeviceTest(test.TestCase): a = constant_op.constant(1.0) self.evaluate(a) + @reset_eager + def testDeviceDetails(self): + (cpu,) = config.list_physical_devices('CPU') + details = config.get_device_details(cpu) + self.assertEqual(details, {}) + + if not test_util.is_gpu_available(): + return + + gpus = config.list_physical_devices('GPU') + details = config.get_device_details(gpus[0]) + self.assertIsInstance(details['device_name'], str) + self.assertNotEmpty(details['device_name']) + if test.is_built_with_rocm(): + # AMD GPUs do not have a compute capability + self.assertNotIn('compute_capability', details) + else: + cc = details['compute_capability'] + self.assertIsInstance(cc, tuple) + major, minor = cc + self.assertGreater(major, 0) + self.assertGreaterEqual(minor, 0) + + # Test GPU returned from get_visible_devices + if len(gpus) > 2: + config.set_visible_devices(gpus[1], 'GPU') + (visible_gpu,) = config.get_visible_devices('GPU') + details = config.get_device_details(visible_gpu) + self.assertIsInstance(details['device_name'], str) + + @reset_eager + def testDeviceDetailsErrors(self): + logical_devices = config.list_logical_devices() + with self.assertRaisesRegexp(ValueError, + 'must be a tf.config.PhysicalDevice'): + config.get_device_details(logical_devices[0]) + + phys_dev = context.PhysicalDevice('/physical_device:CPU:100', 'CPU') + with self.assertRaisesRegexp( + ValueError, 'The PhysicalDevice must be one obtained from ' + 'calling `tf.config.list_physical_devices`'): + config.get_device_details(phys_dev) + @test_util.run_gpu_only @reset_eager def testVirtualGpu(self): diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index efcd912f430..2901a63c829 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -272,6 +272,16 @@ static py::object TF_ListPhysicalDevices() { return tensorflow::PyoOrThrow(result); } +static std::unordered_map<string, string> TF_GetDeviceDetails(int index) { + tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus()); + std::unordered_map<string, string> device_details; + tensorflow::Status s = + tensorflow::DeviceFactory::GetAnyDeviceDetails(index, &device_details); + tensorflow::Set_TF_Status_from_Status(status.get(), s); + MaybeRaiseRegisteredFromTFStatus(status.get()); + return device_details; +} + static py::object TFE_ClearScalarCache() { tensorflow::TFE_TensorHandleCache::Get()->Clear(); return py::none(); @@ -812,6 +822,7 @@ PYBIND11_MODULE(_pywrap_tfe, m) { tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); }); m.def("TF_ListPhysicalDevices", &tensorflow::TF_ListPhysicalDevices); + m.def("TF_GetDeviceDetails", &tensorflow::TF_GetDeviceDetails); m.def("TF_DeleteDeviceList", &TF_DeleteDeviceList, py::return_value_policy::reference); m.def("TF_DeviceListCount", &TF_DeviceListCount); diff --git a/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt index 0f3558e844e..7397719e656 100644 --- 
+++ b/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt
@@ -24,6 +24,10 @@ tf_module {
   member_method {
     name: "enable_mlir_graph_optimization"
    argspec: "args=[], varargs=None, keywords=None, defaults=None"
  }
+  member_method {
+    name: "get_device_details"
+    argspec: "args=[\'device\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "get_device_policy"
     argspec: "args=[], varargs=None, keywords=None, defaults=None"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt
index 0f3558e844e..7397719e656 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt
@@ -24,6 +24,10 @@ tf_module {
   member_method {
     name: "enable_mlir_graph_optimization"
    argspec: "args=[], varargs=None, keywords=None, defaults=None"
  }
+  member_method {
+    name: "get_device_details"
+    argspec: "args=[\'device\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "get_device_policy"
     argspec: "args=[], varargs=None, keywords=None, defaults=None"
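
Example usage of the new API, per the docstring and tests above (a sketch assuming a build with this patch; device names and capability values are illustrative):

import tensorflow as tf

# Each key is only supported by a subset of devices, so read details with
# dict.get rather than assuming a key exists.
for gpu in tf.config.list_physical_devices('GPU'):
  details = tf.config.experimental.get_device_details(gpu)
  print(gpu.name, details.get('device_name', 'Unknown GPU'))
  # NVIDIA GPUs also report compute capability as an (int, int) tuple;
  # ROCm (AMD) GPUs omit the key entirely.
  print(details.get('compute_capability'))  # e.g. (7, 0)

# Details are currently only returned for GPUs; other devices are accepted
# but return an empty dict.
(cpu,) = tf.config.list_physical_devices('CPU')
assert tf.config.experimental.get_device_details(cpu) == {}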
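The validation added in context.py only accepts PhysicalDevice objects that came from the context's own cached list, since details are looked up by the device's index in _physical_device_to_index. A sketch of the two failure modes exercised by testDeviceDetailsErrors:

import tensorflow as tf
from tensorflow.python.eager import context

# A LogicalDevice (or any non-PhysicalDevice) fails the isinstance check.
logical = tf.config.list_logical_devices()[0]
try:
  tf.config.experimental.get_device_details(logical)
except ValueError as e:
  print(e)  # device must be a tf.config.PhysicalDevice, ...

# A hand-built PhysicalDevice is rejected as well, because it has no entry
# in the index map populated by list_physical_devices.
fake = context.PhysicalDevice('/physical_device:CPU:100', 'CPU')
try:
  tf.config.experimental.get_device_details(fake)
except ValueError as e:
  print(e)  # The PhysicalDevice must be one obtained from calling ...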
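On the C++ side, DeviceFactory::GetAnyDeviceDetails resolves a global device index in the same order ListAllPhysicalDevices enumerates devices: the CPU factory first, then the remaining factories. A Python illustration of that index walk (not TensorFlow code; the factory objects and their list/get methods are hypothetical stand-ins for the C++ interfaces):

def get_any_device_details(device_index, factories):
  """Resolves a global index against factories listed in CPU-first order."""
  if device_index < 0:
    raise ValueError('Device index out of bounds: %d' % device_index)
  for factory in factories:
    devices = factory.list_physical_devices()
    if device_index < len(devices):
      # The remaining offset is the index within this factory.
      return factory.get_device_details(device_index)
    device_index -= len(devices)
  raise ValueError('Device index out of bounds')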