Cleanup PlatformUtil and MultiPlatformManager. NFC.
* Move implicit platform initialization logic in PlatformUtil to MulitPlatformManager so that latter is the only implicit initialization site. * Remove unused methods. PiperOrigin-RevId: 279422441 Change-Id: I37161feb4b96f839438e95a3b04f118279a77e8b
This commit is contained in:
parent
c36229facd
commit
ecd935d643
@ -63,53 +63,26 @@ string CanonicalPlatformName(const string& platform_name) {
|
|||||||
return lowercase_platform_name;
|
return lowercase_platform_name;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
StatusOr<std::vector<se::Platform*>> GetSupportedPlatforms() {
|
||||||
|
return se::MultiPlatformManager::PlatformsWithFilter(
|
||||||
|
[](const se::Platform* platform) {
|
||||||
|
auto compiler_status = Compiler::GetForPlatform(platform);
|
||||||
|
bool supported = compiler_status.ok();
|
||||||
|
if (!supported) {
|
||||||
|
LOG(INFO) << "platform " << platform->Name() << " present but no "
|
||||||
|
<< "XLA compiler available: "
|
||||||
|
<< compiler_status.status().error_message();
|
||||||
|
}
|
||||||
|
return supported;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
/* static */ StatusOr<std::vector<se::Platform*>>
|
/* static */ StatusOr<std::vector<se::Platform*>>
|
||||||
PlatformUtil::GetSupportedPlatforms() {
|
PlatformUtil::GetSupportedPlatforms() {
|
||||||
std::vector<se::Platform*> all_platforms =
|
|
||||||
se::MultiPlatformManager::AllPlatforms();
|
|
||||||
if (all_platforms.empty()) {
|
|
||||||
LOG(WARNING) << "no executor platforms available: platform map is empty";
|
|
||||||
}
|
|
||||||
|
|
||||||
// Gather all platforms which have an XLA compiler.
|
// Gather all platforms which have an XLA compiler.
|
||||||
std::vector<se::Platform*> platforms;
|
return xla::GetSupportedPlatforms();
|
||||||
for (se::Platform* platform : all_platforms) {
|
|
||||||
auto compiler_status = Compiler::GetForPlatform(platform);
|
|
||||||
if (compiler_status.ok()) {
|
|
||||||
if (!platform->Initialized()) {
|
|
||||||
TF_RETURN_IF_ERROR(platform->Initialize({}));
|
|
||||||
}
|
|
||||||
platforms.push_back(platform);
|
|
||||||
} else {
|
|
||||||
LOG(INFO) << "platform " << platform->Name() << " present but no "
|
|
||||||
<< "XLA compiler available: "
|
|
||||||
<< compiler_status.status().error_message();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return platforms;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* static */ StatusOr<se::Platform*> PlatformUtil::GetSolePlatform() {
|
|
||||||
TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms());
|
|
||||||
if (platforms.empty()) {
|
|
||||||
return NotFound("no platforms found");
|
|
||||||
} else if (platforms.size() == 1) {
|
|
||||||
se::Platform* platform = platforms[0];
|
|
||||||
if (!platform->Initialized()) {
|
|
||||||
TF_RETURN_IF_ERROR(platform->Initialize({}));
|
|
||||||
}
|
|
||||||
return platform;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Multiple platforms present and we can't pick a reasonable default.
|
|
||||||
string platforms_string = absl::StrJoin(
|
|
||||||
platforms, ", ",
|
|
||||||
[](string* out, const se::Platform* p) { out->append(p->Name()); });
|
|
||||||
return InvalidArgument(
|
|
||||||
"must specify platform because more than one platform found: %s",
|
|
||||||
platforms_string);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* static */ StatusOr<se::Platform*> PlatformUtil::GetDefaultPlatform() {
|
/* static */ StatusOr<se::Platform*> PlatformUtil::GetDefaultPlatform() {
|
||||||
@ -130,9 +103,6 @@ PlatformUtil::GetSupportedPlatforms() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (platform != nullptr) {
|
if (platform != nullptr) {
|
||||||
if (!platform->Initialized()) {
|
|
||||||
TF_RETURN_IF_ERROR(platform->Initialize({}));
|
|
||||||
}
|
|
||||||
return platform;
|
return platform;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -148,47 +118,11 @@ PlatformUtil::GetSupportedPlatforms() {
|
|||||||
|
|
||||||
/*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatform(
|
/*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatform(
|
||||||
const string& platform_name) {
|
const string& platform_name) {
|
||||||
string platform_str = CanonicalPlatformName(platform_name);
|
TF_ASSIGN_OR_RETURN(se::Platform * platform,
|
||||||
TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms());
|
se::MultiPlatformManager::PlatformWithName(
|
||||||
for (se::Platform* platform : platforms) {
|
CanonicalPlatformName(platform_name)));
|
||||||
if (absl::AsciiStrToLower(platform->Name()) == platform_str) {
|
TF_RETURN_IF_ERROR(Compiler::GetForPlatform(platform).status());
|
||||||
if (!platform->Initialized()) {
|
return platform;
|
||||||
TF_RETURN_IF_ERROR(platform->Initialize({}));
|
|
||||||
}
|
|
||||||
return platform;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return InvalidArgument("platform %s not found", platform_name);
|
|
||||||
}
|
|
||||||
|
|
||||||
/*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatformExceptFor(
|
|
||||||
const string& platform_name) {
|
|
||||||
string platform_str = CanonicalPlatformName(platform_name);
|
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms());
|
|
||||||
std::vector<se::Platform*> matched;
|
|
||||||
for (se::Platform* platform : platforms) {
|
|
||||||
if (absl::AsciiStrToLower(platform->Name()) != platform_name) {
|
|
||||||
matched.push_back(platform);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (matched.empty()) {
|
|
||||||
return InvalidArgument("unable to find platform that is not %s",
|
|
||||||
platform_name);
|
|
||||||
}
|
|
||||||
if (matched.size() == 1) {
|
|
||||||
auto platform = matched[0];
|
|
||||||
if (!platform->Initialized()) {
|
|
||||||
TF_RETURN_IF_ERROR(platform->Initialize({}));
|
|
||||||
}
|
|
||||||
return platform;
|
|
||||||
}
|
|
||||||
string matched_string = absl::StrJoin(
|
|
||||||
matched, ", ",
|
|
||||||
[](string* out, const se::Platform* p) { out->append(p->Name()); });
|
|
||||||
return InvalidArgument(
|
|
||||||
"found multiple platforms %s, but expected one platform except for %s",
|
|
||||||
matched_string, platform_name);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns whether the device underlying the given StreamExecutor is supported
|
// Returns whether the device underlying the given StreamExecutor is supported
|
||||||
|
|||||||
@ -44,20 +44,10 @@ class PlatformUtil {
|
|||||||
// platform. Otherwise returns an error.
|
// platform. Otherwise returns an error.
|
||||||
static StatusOr<se::Platform*> GetDefaultPlatform();
|
static StatusOr<se::Platform*> GetDefaultPlatform();
|
||||||
|
|
||||||
// Convenience function which returns the sole supported platform. If
|
|
||||||
// exactly one supported platform is present, then this platform is the
|
|
||||||
// default platform. Otherwise returns an error.
|
|
||||||
static StatusOr<se::Platform*> GetSolePlatform();
|
|
||||||
|
|
||||||
// Returns the platform according to the given name. Returns error if there is
|
// Returns the platform according to the given name. Returns error if there is
|
||||||
// no such platform.
|
// no such platform.
|
||||||
static StatusOr<se::Platform*> GetPlatform(const string& platform_name);
|
static StatusOr<se::Platform*> GetPlatform(const string& platform_name);
|
||||||
|
|
||||||
// Returns exactly one platform that does not have given name. Returns error
|
|
||||||
// if there is no such platform, or there are multiple such platforms.
|
|
||||||
static StatusOr<se::Platform*> GetPlatformExceptFor(
|
|
||||||
const string& platform_name);
|
|
||||||
|
|
||||||
// Returns a vector of StreamExecutors for the given platform.
|
// Returns a vector of StreamExecutors for the given platform.
|
||||||
// If populated, only the devices in allowed_devices will have
|
// If populated, only the devices in allowed_devices will have
|
||||||
// their StreamExecutors initialized, otherwise all StreamExecutors will be
|
// their StreamExecutors initialized, otherwise all StreamExecutors will be
|
||||||
|
|||||||
@ -28,10 +28,8 @@ namespace xla {
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
TEST(ShapedBufferTest, ScopedShapeBufferAsShapedBufferB71629047) {
|
TEST(ShapedBufferTest, ScopedShapeBufferAsShapedBufferB71629047) {
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto platforms,
|
TF_ASSERT_OK_AND_ASSIGN(auto* platform,
|
||||||
xla::PlatformUtil::GetSupportedPlatforms());
|
xla::PlatformUtil::GetDefaultPlatform());
|
||||||
ASSERT_FALSE(platforms.empty());
|
|
||||||
auto* platform = platforms[0];
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto executors,
|
TF_ASSERT_OK_AND_ASSIGN(auto executors,
|
||||||
xla::PlatformUtil::GetStreamExecutors(platform));
|
xla::PlatformUtil::GetStreamExecutors(platform));
|
||||||
xla::se::StreamExecutorMemoryAllocator allocator(platform, executors);
|
xla::se::StreamExecutorMemoryAllocator allocator(platform, executors);
|
||||||
|
|||||||
@ -110,13 +110,8 @@ class LLVMCompilerTest : public ::testing::Test {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
Platform *FindPlatform() {
|
Platform *FindPlatform() {
|
||||||
for (Platform *platform :
|
auto status_or_platform = PlatformUtil::GetPlatform(platform_name_);
|
||||||
PlatformUtil::GetSupportedPlatforms().ConsumeValueOrDie()) {
|
return status_or_platform.ok() ? status_or_platform.ValueOrDie() : nullptr;
|
||||||
if (platform->Name() == platform_name_) {
|
|
||||||
return platform;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
string platform_name_;
|
string platform_name_;
|
||||||
|
|||||||
@ -45,7 +45,8 @@ class MultiPlatformManagerImpl {
|
|||||||
const Platform::Id& id, const std::map<string, string>& options)
|
const Platform::Id& id, const std::map<string, string>& options)
|
||||||
LOCKS_EXCLUDED(mu_);
|
LOCKS_EXCLUDED(mu_);
|
||||||
|
|
||||||
std::vector<Platform*> AllPlatforms() LOCKS_EXCLUDED(mu_);
|
port::StatusOr<std::vector<Platform*>> PlatformsWithFilter(
|
||||||
|
const std::function<bool(const Platform*)>& filter) LOCKS_EXCLUDED(mu_);
|
||||||
|
|
||||||
using Listener = MultiPlatformManager::Listener;
|
using Listener = MultiPlatformManager::Listener;
|
||||||
port::Status RegisterListener(std::unique_ptr<Listener> listener)
|
port::Status RegisterListener(std::unique_ptr<Listener> listener)
|
||||||
@ -157,13 +158,21 @@ port::Status MultiPlatformManagerImpl::RegisterListener(
|
|||||||
return port::Status::OK();
|
return port::Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<Platform*> MultiPlatformManagerImpl::AllPlatforms() {
|
port::StatusOr<std::vector<Platform*>>
|
||||||
|
MultiPlatformManagerImpl::PlatformsWithFilter(
|
||||||
|
const std::function<bool(const Platform*)>& filter) {
|
||||||
absl::MutexLock lock(&mu_);
|
absl::MutexLock lock(&mu_);
|
||||||
CHECK_EQ(id_map_.size(), name_map_.size());
|
CHECK_EQ(id_map_.size(), name_map_.size());
|
||||||
std::vector<Platform*> platforms;
|
std::vector<Platform*> platforms;
|
||||||
platforms.reserve(id_map_.size());
|
platforms.reserve(id_map_.size());
|
||||||
for (const auto& entry : id_map_) {
|
for (const auto& entry : id_map_) {
|
||||||
platforms.push_back(entry.second);
|
Platform* platform = entry.second;
|
||||||
|
if (filter(platform)) {
|
||||||
|
if (!platform->Initialized()) {
|
||||||
|
SE_RETURN_IF_ERROR(platform->Initialize({}));
|
||||||
|
}
|
||||||
|
platforms.push_back(platform);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return platforms;
|
return platforms;
|
||||||
}
|
}
|
||||||
@ -230,8 +239,10 @@ MultiPlatformManager::InitializePlatformWithId(
|
|||||||
return Impl().RegisterListener(std::move(listener));
|
return Impl().RegisterListener(std::move(listener));
|
||||||
}
|
}
|
||||||
|
|
||||||
/*static*/ std::vector<Platform*> MultiPlatformManager::AllPlatforms() {
|
/*static*/ port::StatusOr<std::vector<Platform*>>
|
||||||
return Impl().AllPlatforms();
|
MultiPlatformManager::PlatformsWithFilter(
|
||||||
|
const std::function<bool(const Platform*)>& filter) {
|
||||||
|
return Impl().PlatformsWithFilter(filter);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace stream_executor
|
} // namespace stream_executor
|
||||||
|
|||||||
@ -116,7 +116,10 @@ class MultiPlatformManager {
|
|||||||
static port::StatusOr<Platform*> InitializePlatformWithId(
|
static port::StatusOr<Platform*> InitializePlatformWithId(
|
||||||
const Platform::Id& id, const std::map<string, string>& options);
|
const Platform::Id& id, const std::map<string, string>& options);
|
||||||
|
|
||||||
static std::vector<Platform*> AllPlatforms();
|
// Retrives the platforms satisfying the given filter, i.e. returns true.
|
||||||
|
// Returned Platforms are always initialized.
|
||||||
|
static port::StatusOr<std::vector<Platform*>> PlatformsWithFilter(
|
||||||
|
const std::function<bool(const Platform*)>& filter);
|
||||||
|
|
||||||
// Although the MultiPlatformManager "owns" its platforms, it holds them as
|
// Although the MultiPlatformManager "owns" its platforms, it holds them as
|
||||||
// undecorated pointers to prevent races during program exit (between this
|
// undecorated pointers to prevent races during program exit (between this
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user