Introduce initialize_platform parameter so that platforms can optionally be returned without forced initialization

PiperOrigin-RevId: 322233924
Change-Id: I70a7d44887544d5b3030f4938d8d7fb0efe72bce
This commit is contained in:
Frank Chen 2020-07-20 15:06:53 -07:00 committed by TensorFlower Gardener
parent cf3a2e3c5d
commit 5cb1025a17
2 changed files with 38 additions and 3 deletions

View File

@ -39,6 +39,14 @@ class MultiPlatformManagerImpl {
port::StatusOr<Platform*> PlatformWithId(const Platform::Id& id)
TF_LOCKS_EXCLUDED(mu_);
port::StatusOr<Platform*> PlatformWithName(absl::string_view target,
bool initialize_platform)
TF_LOCKS_EXCLUDED(mu_);
port::StatusOr<Platform*> PlatformWithId(const Platform::Id& id,
bool initialize_platform)
TF_LOCKS_EXCLUDED(mu_);
port::StatusOr<Platform*> InitializePlatformWithName(
absl::string_view target,
const std::map<std::string, std::string>& options) TF_LOCKS_EXCLUDED(mu_);
@ -104,10 +112,20 @@ port::Status MultiPlatformManagerImpl::RegisterPlatform(
port::StatusOr<Platform*> MultiPlatformManagerImpl::PlatformWithName(
absl::string_view target) {
return PlatformWithName(target, /*initialize_platform=*/true);
}
port::StatusOr<Platform*> MultiPlatformManagerImpl::PlatformWithId(
const Platform::Id& id) {
return PlatformWithId(id, /*initialize_platform=*/true);
}
port::StatusOr<Platform*> MultiPlatformManagerImpl::PlatformWithName(
absl::string_view target, bool initialize_platform) {
absl::MutexLock lock(&mu_);
SE_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target));
if (!platform->Initialized()) {
if (initialize_platform && !platform->Initialized()) {
SE_RETURN_IF_ERROR(platform->Initialize({}));
}
@ -115,11 +133,11 @@ port::StatusOr<Platform*> MultiPlatformManagerImpl::PlatformWithName(
}
port::StatusOr<Platform*> MultiPlatformManagerImpl::PlatformWithId(
const Platform::Id& id) {
const Platform::Id& id, bool initialize_platform) {
absl::MutexLock lock(&mu_);
SE_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id));
if (!platform->Initialized()) {
if (initialize_platform && !platform->Initialized()) {
SE_RETURN_IF_ERROR(platform->Initialize({}));
}
@ -250,6 +268,16 @@ MultiPlatformManagerImpl& Impl() {
return Impl().PlatformWithId(id);
}
/*static*/ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithId(
const Platform::Id& id, bool initialize_platform) {
return Impl().PlatformWithId(id, initialize_platform);
}
/*static*/ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithName(
absl::string_view target, bool initialize_platform) {
return Impl().PlatformWithName(target, initialize_platform);
}
/*static*/ port::StatusOr<Platform*>
MultiPlatformManager::InitializePlatformWithName(
absl::string_view target,

View File

@ -100,6 +100,13 @@ class MultiPlatformManager {
static port::StatusOr<Platform*> PlatformWithName(absl::string_view target);
static port::StatusOr<Platform*> PlatformWithId(const Platform::Id& id);
// Same functions as above, but allows platforms to be returned without
// initialization if initialize_platform == false.
static port::StatusOr<Platform*> PlatformWithName(absl::string_view target,
bool initialize_platform);
static port::StatusOr<Platform*> PlatformWithId(const Platform::Id& id,
bool initialize_platform);
// Retrieves the platform registered with the given platform name (e.g.
// "CUDA", "OpenCL", ...) or id (an opaque, comparable value provided by the
// Platform's Id() method).