Introduce initialize_platform parameter so that platforms can optionally be returned without forced initialization
PiperOrigin-RevId: 322233924 Change-Id: I70a7d44887544d5b3030f4938d8d7fb0efe72bce
This commit is contained in:
parent
cf3a2e3c5d
commit
5cb1025a17
@ -39,6 +39,14 @@ class MultiPlatformManagerImpl {
|
||||
port::StatusOr<Platform*> PlatformWithId(const Platform::Id& id)
|
||||
TF_LOCKS_EXCLUDED(mu_);
|
||||
|
||||
port::StatusOr<Platform*> PlatformWithName(absl::string_view target,
|
||||
bool initialize_platform)
|
||||
TF_LOCKS_EXCLUDED(mu_);
|
||||
|
||||
port::StatusOr<Platform*> PlatformWithId(const Platform::Id& id,
|
||||
bool initialize_platform)
|
||||
TF_LOCKS_EXCLUDED(mu_);
|
||||
|
||||
port::StatusOr<Platform*> InitializePlatformWithName(
|
||||
absl::string_view target,
|
||||
const std::map<std::string, std::string>& options) TF_LOCKS_EXCLUDED(mu_);
|
||||
@ -104,10 +112,20 @@ port::Status MultiPlatformManagerImpl::RegisterPlatform(
|
||||
|
||||
port::StatusOr<Platform*> MultiPlatformManagerImpl::PlatformWithName(
|
||||
absl::string_view target) {
|
||||
return PlatformWithName(target, /*initialize_platform=*/true);
|
||||
}
|
||||
|
||||
port::StatusOr<Platform*> MultiPlatformManagerImpl::PlatformWithId(
|
||||
const Platform::Id& id) {
|
||||
return PlatformWithId(id, /*initialize_platform=*/true);
|
||||
}
|
||||
|
||||
port::StatusOr<Platform*> MultiPlatformManagerImpl::PlatformWithName(
|
||||
absl::string_view target, bool initialize_platform) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
|
||||
SE_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target));
|
||||
if (!platform->Initialized()) {
|
||||
if (initialize_platform && !platform->Initialized()) {
|
||||
SE_RETURN_IF_ERROR(platform->Initialize({}));
|
||||
}
|
||||
|
||||
@ -115,11 +133,11 @@ port::StatusOr<Platform*> MultiPlatformManagerImpl::PlatformWithName(
|
||||
}
|
||||
|
||||
port::StatusOr<Platform*> MultiPlatformManagerImpl::PlatformWithId(
|
||||
const Platform::Id& id) {
|
||||
const Platform::Id& id, bool initialize_platform) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
|
||||
SE_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id));
|
||||
if (!platform->Initialized()) {
|
||||
if (initialize_platform && !platform->Initialized()) {
|
||||
SE_RETURN_IF_ERROR(platform->Initialize({}));
|
||||
}
|
||||
|
||||
@ -250,6 +268,16 @@ MultiPlatformManagerImpl& Impl() {
|
||||
return Impl().PlatformWithId(id);
|
||||
}
|
||||
|
||||
/*static*/ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithId(
|
||||
const Platform::Id& id, bool initialize_platform) {
|
||||
return Impl().PlatformWithId(id, initialize_platform);
|
||||
}
|
||||
|
||||
/*static*/ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithName(
|
||||
absl::string_view target, bool initialize_platform) {
|
||||
return Impl().PlatformWithName(target, initialize_platform);
|
||||
}
|
||||
|
||||
/*static*/ port::StatusOr<Platform*>
|
||||
MultiPlatformManager::InitializePlatformWithName(
|
||||
absl::string_view target,
|
||||
|
@ -100,6 +100,13 @@ class MultiPlatformManager {
|
||||
static port::StatusOr<Platform*> PlatformWithName(absl::string_view target);
|
||||
static port::StatusOr<Platform*> PlatformWithId(const Platform::Id& id);
|
||||
|
||||
// Same functions as above, but allows platforms to be returned without
|
||||
// initialization if initialize_platform == false.
|
||||
static port::StatusOr<Platform*> PlatformWithName(absl::string_view target,
|
||||
bool initialize_platform);
|
||||
static port::StatusOr<Platform*> PlatformWithId(const Platform::Id& id,
|
||||
bool initialize_platform);
|
||||
|
||||
// Retrieves the platform registered with the given platform name (e.g.
|
||||
// "CUDA", "OpenCL", ...) or id (an opaque, comparable value provided by the
|
||||
// Platform's Id() method).
|
||||
|
Loading…
x
Reference in New Issue
Block a user