Display the list of the available platforms in the error message.
PiperOrigin-RevId: 308781122 Change-Id: I80a0cba925eab878ad0f0c366032fbd9523fedbc
This commit is contained in:
parent
428cdeda09
commit
a0d47ceed2
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
#include "absl/container/flat_hash_map.h"
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/strings/ascii.h"
|
#include "absl/strings/ascii.h"
|
||||||
#include "absl/strings/str_format.h"
|
#include "absl/strings/str_format.h"
|
||||||
|
#include "absl/strings/str_join.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "absl/synchronization/mutex.h"
|
#include "absl/synchronization/mutex.h"
|
||||||
#include "tensorflow/core/platform/thread_annotations.h"
|
#include "tensorflow/core/platform/thread_annotations.h"
|
||||||
@ -64,6 +65,13 @@ class MultiPlatformManagerImpl {
|
|||||||
port::StatusOr<Platform*> LookupByIdLocked(const Platform::Id& id)
|
port::StatusOr<Platform*> LookupByIdLocked(const Platform::Id& id)
|
||||||
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||||
|
|
||||||
|
// Returns the names of the initialied platforms satisfying the given filter.
|
||||||
|
// By default, it will return all initialized platform names.
|
||||||
|
std::vector<std::string> InitializedPlatformNamesWithFilter(
|
||||||
|
const std::function<bool(const Platform*)>& filter = [](const Platform*) {
|
||||||
|
return true;
|
||||||
|
}) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||||
|
|
||||||
absl::Mutex mu_;
|
absl::Mutex mu_;
|
||||||
std::vector<std::unique_ptr<Listener>> listeners_ TF_GUARDED_BY(mu_);
|
std::vector<std::unique_ptr<Listener>> listeners_ TF_GUARDED_BY(mu_);
|
||||||
absl::flat_hash_map<Platform::Id, Platform*> id_map_ TF_GUARDED_BY(mu_);
|
absl::flat_hash_map<Platform::Id, Platform*> id_map_ TF_GUARDED_BY(mu_);
|
||||||
@ -179,6 +187,23 @@ MultiPlatformManagerImpl::PlatformsWithFilter(
|
|||||||
return platforms;
|
return platforms;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<std::string>
|
||||||
|
MultiPlatformManagerImpl::InitializedPlatformNamesWithFilter(
|
||||||
|
const std::function<bool(const Platform*)>& filter) {
|
||||||
|
CHECK_EQ(id_map_.size(), name_map_.size());
|
||||||
|
std::vector<std::string> initialized_platforms_names;
|
||||||
|
initialized_platforms_names.reserve(id_map_.size());
|
||||||
|
for (const auto& entry : id_map_) {
|
||||||
|
Platform* platform = entry.second;
|
||||||
|
if (filter(platform)) {
|
||||||
|
if (platform->Initialized()) {
|
||||||
|
initialized_platforms_names.push_back(platform->Name());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return initialized_platforms_names;
|
||||||
|
}
|
||||||
|
|
||||||
port::StatusOr<Platform*> MultiPlatformManagerImpl::LookupByNameLocked(
|
port::StatusOr<Platform*> MultiPlatformManagerImpl::LookupByNameLocked(
|
||||||
absl::string_view target) {
|
absl::string_view target) {
|
||||||
auto it = name_map_.find(absl::AsciiStrToLower(target));
|
auto it = name_map_.find(absl::AsciiStrToLower(target));
|
||||||
@ -186,7 +211,8 @@ port::StatusOr<Platform*> MultiPlatformManagerImpl::LookupByNameLocked(
|
|||||||
return port::Status(
|
return port::Status(
|
||||||
port::error::NOT_FOUND,
|
port::error::NOT_FOUND,
|
||||||
absl::StrCat("Could not find registered platform with name: \"", target,
|
absl::StrCat("Could not find registered platform with name: \"", target,
|
||||||
"\""));
|
"\". Available platform names are: ",
|
||||||
|
absl::StrJoin(InitializedPlatformNamesWithFilter(), " ")));
|
||||||
}
|
}
|
||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user