From a0d47ceed2a3ed1953306e8c82e5358f21f13ce7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Tue, 28 Apr 2020 01:19:52 -0700 Subject: [PATCH] Display the list of the available platforms in the error message. PiperOrigin-RevId: 308781122 Change-Id: I80a0cba925eab878ad0f0c366032fbd9523fedbc --- .../stream_executor/multi_platform_manager.cc | 28 ++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/tensorflow/stream_executor/multi_platform_manager.cc b/tensorflow/stream_executor/multi_platform_manager.cc index dfee2152165..64543a8ae4d 100644 --- a/tensorflow/stream_executor/multi_platform_manager.cc +++ b/tensorflow/stream_executor/multi_platform_manager.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/ascii.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -64,6 +65,13 @@ class MultiPlatformManagerImpl { port::StatusOr<Platform*> LookupByIdLocked(const Platform::Id& id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Returns the names of the initialied platforms satisfying the given filter. + // By default, it will return all initialized platform names. + std::vector<std::string> InitializedPlatformNamesWithFilter( + const std::function<bool(const Platform*)>& filter = [](const Platform*) { + return true; + }) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + absl::Mutex mu_; std::vector<std::unique_ptr<Listener>> listeners_ TF_GUARDED_BY(mu_); absl::flat_hash_map<Platform::Id, Platform*> id_map_ TF_GUARDED_BY(mu_); @@ -179,6 +187,23 @@ MultiPlatformManagerImpl::PlatformsWithFilter( return platforms; } +std::vector<std::string> +MultiPlatformManagerImpl::InitializedPlatformNamesWithFilter( + const std::function<bool(const Platform*)>& filter) { + CHECK_EQ(id_map_.size(), name_map_.size()); + std::vector<std::string> initialized_platforms_names; + initialized_platforms_names.reserve(id_map_.size()); + for (const auto& entry : id_map_) { + Platform* platform = entry.second; + if (filter(platform)) { + if (platform->Initialized()) { + initialized_platforms_names.push_back(platform->Name()); + } + } + } + return initialized_platforms_names; +} + port::StatusOr<Platform*> MultiPlatformManagerImpl::LookupByNameLocked( absl::string_view target) { auto it = name_map_.find(absl::AsciiStrToLower(target)); @@ -186,7 +211,8 @@ port::StatusOr<Platform*> MultiPlatformManagerImpl::LookupByNameLocked( return port::Status( port::error::NOT_FOUND, absl::StrCat("Could not find registered platform with name: \"", target, - "\"")); + "\". Available platform names are: ", + absl::StrJoin(InitializedPlatformNamesWithFilter(), " "))); } return it->second; }