Cleanup PlatformUtil and MultiPlatformManager. NFC.
* Move implicit platform initialization logic in PlatformUtil to MulitPlatformManager so that latter is the only implicit initialization site. * Remove unused methods. PiperOrigin-RevId: 279422441 Change-Id: I37161feb4b96f839438e95a3b04f118279a77e8b
This commit is contained in:
parent
c36229facd
commit
ecd935d643
@ -63,53 +63,26 @@ string CanonicalPlatformName(const string& platform_name) {
|
||||
return lowercase_platform_name;
|
||||
}
|
||||
|
||||
StatusOr<std::vector<se::Platform*>> GetSupportedPlatforms() {
|
||||
return se::MultiPlatformManager::PlatformsWithFilter(
|
||||
[](const se::Platform* platform) {
|
||||
auto compiler_status = Compiler::GetForPlatform(platform);
|
||||
bool supported = compiler_status.ok();
|
||||
if (!supported) {
|
||||
LOG(INFO) << "platform " << platform->Name() << " present but no "
|
||||
<< "XLA compiler available: "
|
||||
<< compiler_status.status().error_message();
|
||||
}
|
||||
return supported;
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
/* static */ StatusOr<std::vector<se::Platform*>>
|
||||
PlatformUtil::GetSupportedPlatforms() {
|
||||
std::vector<se::Platform*> all_platforms =
|
||||
se::MultiPlatformManager::AllPlatforms();
|
||||
if (all_platforms.empty()) {
|
||||
LOG(WARNING) << "no executor platforms available: platform map is empty";
|
||||
}
|
||||
|
||||
// Gather all platforms which have an XLA compiler.
|
||||
std::vector<se::Platform*> platforms;
|
||||
for (se::Platform* platform : all_platforms) {
|
||||
auto compiler_status = Compiler::GetForPlatform(platform);
|
||||
if (compiler_status.ok()) {
|
||||
if (!platform->Initialized()) {
|
||||
TF_RETURN_IF_ERROR(platform->Initialize({}));
|
||||
}
|
||||
platforms.push_back(platform);
|
||||
} else {
|
||||
LOG(INFO) << "platform " << platform->Name() << " present but no "
|
||||
<< "XLA compiler available: "
|
||||
<< compiler_status.status().error_message();
|
||||
}
|
||||
}
|
||||
return platforms;
|
||||
}
|
||||
|
||||
/* static */ StatusOr<se::Platform*> PlatformUtil::GetSolePlatform() {
|
||||
TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms());
|
||||
if (platforms.empty()) {
|
||||
return NotFound("no platforms found");
|
||||
} else if (platforms.size() == 1) {
|
||||
se::Platform* platform = platforms[0];
|
||||
if (!platform->Initialized()) {
|
||||
TF_RETURN_IF_ERROR(platform->Initialize({}));
|
||||
}
|
||||
return platform;
|
||||
}
|
||||
|
||||
// Multiple platforms present and we can't pick a reasonable default.
|
||||
string platforms_string = absl::StrJoin(
|
||||
platforms, ", ",
|
||||
[](string* out, const se::Platform* p) { out->append(p->Name()); });
|
||||
return InvalidArgument(
|
||||
"must specify platform because more than one platform found: %s",
|
||||
platforms_string);
|
||||
return xla::GetSupportedPlatforms();
|
||||
}
|
||||
|
||||
/* static */ StatusOr<se::Platform*> PlatformUtil::GetDefaultPlatform() {
|
||||
@ -130,9 +103,6 @@ PlatformUtil::GetSupportedPlatforms() {
|
||||
}
|
||||
}
|
||||
if (platform != nullptr) {
|
||||
if (!platform->Initialized()) {
|
||||
TF_RETURN_IF_ERROR(platform->Initialize({}));
|
||||
}
|
||||
return platform;
|
||||
}
|
||||
|
||||
@ -148,47 +118,11 @@ PlatformUtil::GetSupportedPlatforms() {
|
||||
|
||||
/*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatform(
|
||||
const string& platform_name) {
|
||||
string platform_str = CanonicalPlatformName(platform_name);
|
||||
TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms());
|
||||
for (se::Platform* platform : platforms) {
|
||||
if (absl::AsciiStrToLower(platform->Name()) == platform_str) {
|
||||
if (!platform->Initialized()) {
|
||||
TF_RETURN_IF_ERROR(platform->Initialize({}));
|
||||
}
|
||||
return platform;
|
||||
}
|
||||
}
|
||||
return InvalidArgument("platform %s not found", platform_name);
|
||||
}
|
||||
|
||||
/*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatformExceptFor(
|
||||
const string& platform_name) {
|
||||
string platform_str = CanonicalPlatformName(platform_name);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms());
|
||||
std::vector<se::Platform*> matched;
|
||||
for (se::Platform* platform : platforms) {
|
||||
if (absl::AsciiStrToLower(platform->Name()) != platform_name) {
|
||||
matched.push_back(platform);
|
||||
}
|
||||
}
|
||||
if (matched.empty()) {
|
||||
return InvalidArgument("unable to find platform that is not %s",
|
||||
platform_name);
|
||||
}
|
||||
if (matched.size() == 1) {
|
||||
auto platform = matched[0];
|
||||
if (!platform->Initialized()) {
|
||||
TF_RETURN_IF_ERROR(platform->Initialize({}));
|
||||
}
|
||||
return platform;
|
||||
}
|
||||
string matched_string = absl::StrJoin(
|
||||
matched, ", ",
|
||||
[](string* out, const se::Platform* p) { out->append(p->Name()); });
|
||||
return InvalidArgument(
|
||||
"found multiple platforms %s, but expected one platform except for %s",
|
||||
matched_string, platform_name);
|
||||
TF_ASSIGN_OR_RETURN(se::Platform * platform,
|
||||
se::MultiPlatformManager::PlatformWithName(
|
||||
CanonicalPlatformName(platform_name)));
|
||||
TF_RETURN_IF_ERROR(Compiler::GetForPlatform(platform).status());
|
||||
return platform;
|
||||
}
|
||||
|
||||
// Returns whether the device underlying the given StreamExecutor is supported
|
||||
|
||||
@ -44,20 +44,10 @@ class PlatformUtil {
|
||||
// platform. Otherwise returns an error.
|
||||
static StatusOr<se::Platform*> GetDefaultPlatform();
|
||||
|
||||
// Convenience function which returns the sole supported platform. If
|
||||
// exactly one supported platform is present, then this platform is the
|
||||
// default platform. Otherwise returns an error.
|
||||
static StatusOr<se::Platform*> GetSolePlatform();
|
||||
|
||||
// Returns the platform according to the given name. Returns error if there is
|
||||
// no such platform.
|
||||
static StatusOr<se::Platform*> GetPlatform(const string& platform_name);
|
||||
|
||||
// Returns exactly one platform that does not have given name. Returns error
|
||||
// if there is no such platform, or there are multiple such platforms.
|
||||
static StatusOr<se::Platform*> GetPlatformExceptFor(
|
||||
const string& platform_name);
|
||||
|
||||
// Returns a vector of StreamExecutors for the given platform.
|
||||
// If populated, only the devices in allowed_devices will have
|
||||
// their StreamExecutors initialized, otherwise all StreamExecutors will be
|
||||
|
||||
@ -28,10 +28,8 @@ namespace xla {
|
||||
namespace {
|
||||
|
||||
TEST(ShapedBufferTest, ScopedShapeBufferAsShapedBufferB71629047) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto platforms,
|
||||
xla::PlatformUtil::GetSupportedPlatforms());
|
||||
ASSERT_FALSE(platforms.empty());
|
||||
auto* platform = platforms[0];
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto* platform,
|
||||
xla::PlatformUtil::GetDefaultPlatform());
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto executors,
|
||||
xla::PlatformUtil::GetStreamExecutors(platform));
|
||||
xla::se::StreamExecutorMemoryAllocator allocator(platform, executors);
|
||||
|
||||
@ -110,13 +110,8 @@ class LLVMCompilerTest : public ::testing::Test {
|
||||
|
||||
private:
|
||||
Platform *FindPlatform() {
|
||||
for (Platform *platform :
|
||||
PlatformUtil::GetSupportedPlatforms().ConsumeValueOrDie()) {
|
||||
if (platform->Name() == platform_name_) {
|
||||
return platform;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
auto status_or_platform = PlatformUtil::GetPlatform(platform_name_);
|
||||
return status_or_platform.ok() ? status_or_platform.ValueOrDie() : nullptr;
|
||||
}
|
||||
|
||||
string platform_name_;
|
||||
|
||||
@ -45,7 +45,8 @@ class MultiPlatformManagerImpl {
|
||||
const Platform::Id& id, const std::map<string, string>& options)
|
||||
LOCKS_EXCLUDED(mu_);
|
||||
|
||||
std::vector<Platform*> AllPlatforms() LOCKS_EXCLUDED(mu_);
|
||||
port::StatusOr<std::vector<Platform*>> PlatformsWithFilter(
|
||||
const std::function<bool(const Platform*)>& filter) LOCKS_EXCLUDED(mu_);
|
||||
|
||||
using Listener = MultiPlatformManager::Listener;
|
||||
port::Status RegisterListener(std::unique_ptr<Listener> listener)
|
||||
@ -157,13 +158,21 @@ port::Status MultiPlatformManagerImpl::RegisterListener(
|
||||
return port::Status::OK();
|
||||
}
|
||||
|
||||
std::vector<Platform*> MultiPlatformManagerImpl::AllPlatforms() {
|
||||
port::StatusOr<std::vector<Platform*>>
|
||||
MultiPlatformManagerImpl::PlatformsWithFilter(
|
||||
const std::function<bool(const Platform*)>& filter) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
CHECK_EQ(id_map_.size(), name_map_.size());
|
||||
std::vector<Platform*> platforms;
|
||||
platforms.reserve(id_map_.size());
|
||||
for (const auto& entry : id_map_) {
|
||||
platforms.push_back(entry.second);
|
||||
Platform* platform = entry.second;
|
||||
if (filter(platform)) {
|
||||
if (!platform->Initialized()) {
|
||||
SE_RETURN_IF_ERROR(platform->Initialize({}));
|
||||
}
|
||||
platforms.push_back(platform);
|
||||
}
|
||||
}
|
||||
return platforms;
|
||||
}
|
||||
@ -230,8 +239,10 @@ MultiPlatformManager::InitializePlatformWithId(
|
||||
return Impl().RegisterListener(std::move(listener));
|
||||
}
|
||||
|
||||
/*static*/ std::vector<Platform*> MultiPlatformManager::AllPlatforms() {
|
||||
return Impl().AllPlatforms();
|
||||
/*static*/ port::StatusOr<std::vector<Platform*>>
|
||||
MultiPlatformManager::PlatformsWithFilter(
|
||||
const std::function<bool(const Platform*)>& filter) {
|
||||
return Impl().PlatformsWithFilter(filter);
|
||||
}
|
||||
|
||||
} // namespace stream_executor
|
||||
|
||||
@ -116,7 +116,10 @@ class MultiPlatformManager {
|
||||
static port::StatusOr<Platform*> InitializePlatformWithId(
|
||||
const Platform::Id& id, const std::map<string, string>& options);
|
||||
|
||||
static std::vector<Platform*> AllPlatforms();
|
||||
// Retrives the platforms satisfying the given filter, i.e. returns true.
|
||||
// Returned Platforms are always initialized.
|
||||
static port::StatusOr<std::vector<Platform*>> PlatformsWithFilter(
|
||||
const std::function<bool(const Platform*)>& filter);
|
||||
|
||||
// Although the MultiPlatformManager "owns" its platforms, it holds them as
|
||||
// undecorated pointers to prevent races during program exit (between this
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user