Cleanup PlatformUtil and MultiPlatformManager. NFC.

* Move implicit platform initialization logic in PlatformUtil to MulitPlatformManager so that latter is the only implicit initialization site.
* Remove unused methods.

PiperOrigin-RevId: 279422441
Change-Id: I37161feb4b96f839438e95a3b04f118279a77e8b
This commit is contained in:
A. Unique TensorFlower 2019-11-08 17:18:45 -08:00 committed by TensorFlower Gardener
parent c36229facd
commit ecd935d643
6 changed files with 44 additions and 113 deletions

View File

@ -63,53 +63,26 @@ string CanonicalPlatformName(const string& platform_name) {
return lowercase_platform_name;
}
StatusOr<std::vector<se::Platform*>> GetSupportedPlatforms() {
return se::MultiPlatformManager::PlatformsWithFilter(
[](const se::Platform* platform) {
auto compiler_status = Compiler::GetForPlatform(platform);
bool supported = compiler_status.ok();
if (!supported) {
LOG(INFO) << "platform " << platform->Name() << " present but no "
<< "XLA compiler available: "
<< compiler_status.status().error_message();
}
return supported;
});
}
} // namespace
/* static */ StatusOr<std::vector<se::Platform*>>
PlatformUtil::GetSupportedPlatforms() {
std::vector<se::Platform*> all_platforms =
se::MultiPlatformManager::AllPlatforms();
if (all_platforms.empty()) {
LOG(WARNING) << "no executor platforms available: platform map is empty";
}
// Gather all platforms which have an XLA compiler.
std::vector<se::Platform*> platforms;
for (se::Platform* platform : all_platforms) {
auto compiler_status = Compiler::GetForPlatform(platform);
if (compiler_status.ok()) {
if (!platform->Initialized()) {
TF_RETURN_IF_ERROR(platform->Initialize({}));
}
platforms.push_back(platform);
} else {
LOG(INFO) << "platform " << platform->Name() << " present but no "
<< "XLA compiler available: "
<< compiler_status.status().error_message();
}
}
return platforms;
}
/* static */ StatusOr<se::Platform*> PlatformUtil::GetSolePlatform() {
TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms());
if (platforms.empty()) {
return NotFound("no platforms found");
} else if (platforms.size() == 1) {
se::Platform* platform = platforms[0];
if (!platform->Initialized()) {
TF_RETURN_IF_ERROR(platform->Initialize({}));
}
return platform;
}
// Multiple platforms present and we can't pick a reasonable default.
string platforms_string = absl::StrJoin(
platforms, ", ",
[](string* out, const se::Platform* p) { out->append(p->Name()); });
return InvalidArgument(
"must specify platform because more than one platform found: %s",
platforms_string);
return xla::GetSupportedPlatforms();
}
/* static */ StatusOr<se::Platform*> PlatformUtil::GetDefaultPlatform() {
@ -130,9 +103,6 @@ PlatformUtil::GetSupportedPlatforms() {
}
}
if (platform != nullptr) {
if (!platform->Initialized()) {
TF_RETURN_IF_ERROR(platform->Initialize({}));
}
return platform;
}
@ -148,47 +118,11 @@ PlatformUtil::GetSupportedPlatforms() {
/*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatform(
const string& platform_name) {
string platform_str = CanonicalPlatformName(platform_name);
TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms());
for (se::Platform* platform : platforms) {
if (absl::AsciiStrToLower(platform->Name()) == platform_str) {
if (!platform->Initialized()) {
TF_RETURN_IF_ERROR(platform->Initialize({}));
}
return platform;
}
}
return InvalidArgument("platform %s not found", platform_name);
}
/*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatformExceptFor(
const string& platform_name) {
string platform_str = CanonicalPlatformName(platform_name);
TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms());
std::vector<se::Platform*> matched;
for (se::Platform* platform : platforms) {
if (absl::AsciiStrToLower(platform->Name()) != platform_name) {
matched.push_back(platform);
}
}
if (matched.empty()) {
return InvalidArgument("unable to find platform that is not %s",
platform_name);
}
if (matched.size() == 1) {
auto platform = matched[0];
if (!platform->Initialized()) {
TF_RETURN_IF_ERROR(platform->Initialize({}));
}
return platform;
}
string matched_string = absl::StrJoin(
matched, ", ",
[](string* out, const se::Platform* p) { out->append(p->Name()); });
return InvalidArgument(
"found multiple platforms %s, but expected one platform except for %s",
matched_string, platform_name);
TF_ASSIGN_OR_RETURN(se::Platform * platform,
se::MultiPlatformManager::PlatformWithName(
CanonicalPlatformName(platform_name)));
TF_RETURN_IF_ERROR(Compiler::GetForPlatform(platform).status());
return platform;
}
// Returns whether the device underlying the given StreamExecutor is supported

View File

@ -44,20 +44,10 @@ class PlatformUtil {
// platform. Otherwise returns an error.
static StatusOr<se::Platform*> GetDefaultPlatform();
// Convenience function which returns the sole supported platform. If
// exactly one supported platform is present, then this platform is the
// default platform. Otherwise returns an error.
static StatusOr<se::Platform*> GetSolePlatform();
// Returns the platform according to the given name. Returns error if there is
// no such platform.
static StatusOr<se::Platform*> GetPlatform(const string& platform_name);
// Returns exactly one platform that does not have given name. Returns error
// if there is no such platform, or there are multiple such platforms.
static StatusOr<se::Platform*> GetPlatformExceptFor(
const string& platform_name);
// Returns a vector of StreamExecutors for the given platform.
// If populated, only the devices in allowed_devices will have
// their StreamExecutors initialized, otherwise all StreamExecutors will be

View File

@ -28,10 +28,8 @@ namespace xla {
namespace {
TEST(ShapedBufferTest, ScopedShapeBufferAsShapedBufferB71629047) {
TF_ASSERT_OK_AND_ASSIGN(auto platforms,
xla::PlatformUtil::GetSupportedPlatforms());
ASSERT_FALSE(platforms.empty());
auto* platform = platforms[0];
TF_ASSERT_OK_AND_ASSIGN(auto* platform,
xla::PlatformUtil::GetDefaultPlatform());
TF_ASSERT_OK_AND_ASSIGN(auto executors,
xla::PlatformUtil::GetStreamExecutors(platform));
xla::se::StreamExecutorMemoryAllocator allocator(platform, executors);

View File

@ -110,13 +110,8 @@ class LLVMCompilerTest : public ::testing::Test {
private:
Platform *FindPlatform() {
for (Platform *platform :
PlatformUtil::GetSupportedPlatforms().ConsumeValueOrDie()) {
if (platform->Name() == platform_name_) {
return platform;
}
}
return nullptr;
auto status_or_platform = PlatformUtil::GetPlatform(platform_name_);
return status_or_platform.ok() ? status_or_platform.ValueOrDie() : nullptr;
}
string platform_name_;

View File

@ -45,7 +45,8 @@ class MultiPlatformManagerImpl {
const Platform::Id& id, const std::map<string, string>& options)
LOCKS_EXCLUDED(mu_);
std::vector<Platform*> AllPlatforms() LOCKS_EXCLUDED(mu_);
port::StatusOr<std::vector<Platform*>> PlatformsWithFilter(
const std::function<bool(const Platform*)>& filter) LOCKS_EXCLUDED(mu_);
using Listener = MultiPlatformManager::Listener;
port::Status RegisterListener(std::unique_ptr<Listener> listener)
@ -157,13 +158,21 @@ port::Status MultiPlatformManagerImpl::RegisterListener(
return port::Status::OK();
}
std::vector<Platform*> MultiPlatformManagerImpl::AllPlatforms() {
port::StatusOr<std::vector<Platform*>>
MultiPlatformManagerImpl::PlatformsWithFilter(
const std::function<bool(const Platform*)>& filter) {
absl::MutexLock lock(&mu_);
CHECK_EQ(id_map_.size(), name_map_.size());
std::vector<Platform*> platforms;
platforms.reserve(id_map_.size());
for (const auto& entry : id_map_) {
platforms.push_back(entry.second);
Platform* platform = entry.second;
if (filter(platform)) {
if (!platform->Initialized()) {
SE_RETURN_IF_ERROR(platform->Initialize({}));
}
platforms.push_back(platform);
}
}
return platforms;
}
@ -230,8 +239,10 @@ MultiPlatformManager::InitializePlatformWithId(
return Impl().RegisterListener(std::move(listener));
}
/*static*/ std::vector<Platform*> MultiPlatformManager::AllPlatforms() {
return Impl().AllPlatforms();
/*static*/ port::StatusOr<std::vector<Platform*>>
MultiPlatformManager::PlatformsWithFilter(
const std::function<bool(const Platform*)>& filter) {
return Impl().PlatformsWithFilter(filter);
}
} // namespace stream_executor

View File

@ -116,7 +116,10 @@ class MultiPlatformManager {
static port::StatusOr<Platform*> InitializePlatformWithId(
const Platform::Id& id, const std::map<string, string>& options);
static std::vector<Platform*> AllPlatforms();
// Retrives the platforms satisfying the given filter, i.e. returns true.
// Returned Platforms are always initialized.
static port::StatusOr<std::vector<Platform*>> PlatformsWithFilter(
const std::function<bool(const Platform*)>& filter);
// Although the MultiPlatformManager "owns" its platforms, it holds them as
// undecorated pointers to prevent races during program exit (between this