Display the list of the available platforms in the error message.
PiperOrigin-RevId: 308781122 Change-Id: I80a0cba925eab878ad0f0c366032fbd9523fedbc
This commit is contained in:
parent
428cdeda09
commit
a0d47ceed2
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/ascii.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
@ -64,6 +65,13 @@ class MultiPlatformManagerImpl {
|
||||
port::StatusOr<Platform*> LookupByIdLocked(const Platform::Id& id)
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
// Returns the names of the initialied platforms satisfying the given filter.
|
||||
// By default, it will return all initialized platform names.
|
||||
std::vector<std::string> InitializedPlatformNamesWithFilter(
|
||||
const std::function<bool(const Platform*)>& filter = [](const Platform*) {
|
||||
return true;
|
||||
}) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
absl::Mutex mu_;
|
||||
std::vector<std::unique_ptr<Listener>> listeners_ TF_GUARDED_BY(mu_);
|
||||
absl::flat_hash_map<Platform::Id, Platform*> id_map_ TF_GUARDED_BY(mu_);
|
||||
@ -179,6 +187,23 @@ MultiPlatformManagerImpl::PlatformsWithFilter(
|
||||
return platforms;
|
||||
}
|
||||
|
||||
std::vector<std::string>
|
||||
MultiPlatformManagerImpl::InitializedPlatformNamesWithFilter(
|
||||
const std::function<bool(const Platform*)>& filter) {
|
||||
CHECK_EQ(id_map_.size(), name_map_.size());
|
||||
std::vector<std::string> initialized_platforms_names;
|
||||
initialized_platforms_names.reserve(id_map_.size());
|
||||
for (const auto& entry : id_map_) {
|
||||
Platform* platform = entry.second;
|
||||
if (filter(platform)) {
|
||||
if (platform->Initialized()) {
|
||||
initialized_platforms_names.push_back(platform->Name());
|
||||
}
|
||||
}
|
||||
}
|
||||
return initialized_platforms_names;
|
||||
}
|
||||
|
||||
port::StatusOr<Platform*> MultiPlatformManagerImpl::LookupByNameLocked(
|
||||
absl::string_view target) {
|
||||
auto it = name_map_.find(absl::AsciiStrToLower(target));
|
||||
@ -186,7 +211,8 @@ port::StatusOr<Platform*> MultiPlatformManagerImpl::LookupByNameLocked(
|
||||
return port::Status(
|
||||
port::error::NOT_FOUND,
|
||||
absl::StrCat("Could not find registered platform with name: \"", target,
|
||||
"\""));
|
||||
"\". Available platform names are: ",
|
||||
absl::StrJoin(InitializedPlatformNamesWithFilter(), " ")));
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user