From a0d47ceed2a3ed1953306e8c82e5358f21f13ce7 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 28 Apr 2020 01:19:52 -0700
Subject: [PATCH] Display the list of the available platforms in the error
 message.

PiperOrigin-RevId: 308781122
Change-Id: I80a0cba925eab878ad0f0c366032fbd9523fedbc
---
 .../stream_executor/multi_platform_manager.cc | 28 ++++++++++++++++++-
 1 file changed, 27 insertions(+), 1 deletion(-)

diff --git a/tensorflow/stream_executor/multi_platform_manager.cc b/tensorflow/stream_executor/multi_platform_manager.cc
index dfee2152165..64543a8ae4d 100644
--- a/tensorflow/stream_executor/multi_platform_manager.cc
+++ b/tensorflow/stream_executor/multi_platform_manager.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include "absl/container/flat_hash_map.h"
 #include "absl/strings/ascii.h"
 #include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
 #include "absl/strings/string_view.h"
 #include "absl/synchronization/mutex.h"
 #include "tensorflow/core/platform/thread_annotations.h"
@@ -64,6 +65,13 @@ class MultiPlatformManagerImpl {
   port::StatusOr<Platform*> LookupByIdLocked(const Platform::Id& id)
       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
 
+  // Returns the names of the initialied platforms satisfying the given filter.
+  // By default, it will return all initialized platform names.
+  std::vector<std::string> InitializedPlatformNamesWithFilter(
+      const std::function<bool(const Platform*)>& filter = [](const Platform*) {
+        return true;
+      }) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
   absl::Mutex mu_;
   std::vector<std::unique_ptr<Listener>> listeners_ TF_GUARDED_BY(mu_);
   absl::flat_hash_map<Platform::Id, Platform*> id_map_ TF_GUARDED_BY(mu_);
@@ -179,6 +187,23 @@ MultiPlatformManagerImpl::PlatformsWithFilter(
   return platforms;
 }
 
+std::vector<std::string>
+MultiPlatformManagerImpl::InitializedPlatformNamesWithFilter(
+    const std::function<bool(const Platform*)>& filter) {
+  CHECK_EQ(id_map_.size(), name_map_.size());
+  std::vector<std::string> initialized_platforms_names;
+  initialized_platforms_names.reserve(id_map_.size());
+  for (const auto& entry : id_map_) {
+    Platform* platform = entry.second;
+    if (filter(platform)) {
+      if (platform->Initialized()) {
+        initialized_platforms_names.push_back(platform->Name());
+      }
+    }
+  }
+  return initialized_platforms_names;
+}
+
 port::StatusOr<Platform*> MultiPlatformManagerImpl::LookupByNameLocked(
     absl::string_view target) {
   auto it = name_map_.find(absl::AsciiStrToLower(target));
@@ -186,7 +211,8 @@ port::StatusOr<Platform*> MultiPlatformManagerImpl::LookupByNameLocked(
     return port::Status(
         port::error::NOT_FOUND,
         absl::StrCat("Could not find registered platform with name: \"", target,
-                     "\""));
+                     "\". Available platform names are: ",
+                     absl::StrJoin(InitializedPlatformNamesWithFilter(), " ")));
   }
   return it->second;
 }