Display the list of the available platforms in the error message.

PiperOrigin-RevId: 308781122
Change-Id: I80a0cba925eab878ad0f0c366032fbd9523fedbc
This commit is contained in:
A. Unique TensorFlower 2020-04-28 01:19:52 -07:00 committed by TensorFlower Gardener
parent 428cdeda09
commit a0d47ceed2

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
@ -64,6 +65,13 @@ class MultiPlatformManagerImpl {
port::StatusOr<Platform*> LookupByIdLocked(const Platform::Id& id)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Returns the names of the initialied platforms satisfying the given filter.
// By default, it will return all initialized platform names.
std::vector<std::string> InitializedPlatformNamesWithFilter(
const std::function<bool(const Platform*)>& filter = [](const Platform*) {
return true;
}) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
absl::Mutex mu_;
std::vector<std::unique_ptr<Listener>> listeners_ TF_GUARDED_BY(mu_);
absl::flat_hash_map<Platform::Id, Platform*> id_map_ TF_GUARDED_BY(mu_);
@ -179,6 +187,23 @@ MultiPlatformManagerImpl::PlatformsWithFilter(
return platforms;
}
std::vector<std::string>
MultiPlatformManagerImpl::InitializedPlatformNamesWithFilter(
const std::function<bool(const Platform*)>& filter) {
CHECK_EQ(id_map_.size(), name_map_.size());
std::vector<std::string> initialized_platforms_names;
initialized_platforms_names.reserve(id_map_.size());
for (const auto& entry : id_map_) {
Platform* platform = entry.second;
if (filter(platform)) {
if (platform->Initialized()) {
initialized_platforms_names.push_back(platform->Name());
}
}
}
return initialized_platforms_names;
}
port::StatusOr<Platform*> MultiPlatformManagerImpl::LookupByNameLocked(
absl::string_view target) {
auto it = name_map_.find(absl::AsciiStrToLower(target));
@ -186,7 +211,8 @@ port::StatusOr<Platform*> MultiPlatformManagerImpl::LookupByNameLocked(
return port::Status(
port::error::NOT_FOUND,
absl::StrCat("Could not find registered platform with name: \"", target,
"\""));
"\". Available platform names are: ",
absl::StrJoin(InitializedPlatformNamesWithFilter(), " ")));
}
return it->second;
}