From 5cb1025a172c23f2fc60d9396aa62ba8c2d13669 Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Mon, 20 Jul 2020 15:06:53 -0700 Subject: [PATCH] Introduce initialize_platform parameter so that platforms can optionally be returned without forced initialization PiperOrigin-RevId: 322233924 Change-Id: I70a7d44887544d5b3030f4938d8d7fb0efe72bce --- .../stream_executor/multi_platform_manager.cc | 34 +++++++++++++++++-- .../stream_executor/multi_platform_manager.h | 7 ++++ 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/tensorflow/stream_executor/multi_platform_manager.cc b/tensorflow/stream_executor/multi_platform_manager.cc index 64543a8ae4d..6c767d1d66e 100644 --- a/tensorflow/stream_executor/multi_platform_manager.cc +++ b/tensorflow/stream_executor/multi_platform_manager.cc @@ -39,6 +39,14 @@ class MultiPlatformManagerImpl { port::StatusOr PlatformWithId(const Platform::Id& id) TF_LOCKS_EXCLUDED(mu_); + port::StatusOr PlatformWithName(absl::string_view target, + bool initialize_platform) + TF_LOCKS_EXCLUDED(mu_); + + port::StatusOr PlatformWithId(const Platform::Id& id, + bool initialize_platform) + TF_LOCKS_EXCLUDED(mu_); + port::StatusOr InitializePlatformWithName( absl::string_view target, const std::map& options) TF_LOCKS_EXCLUDED(mu_); @@ -104,10 +112,20 @@ port::Status MultiPlatformManagerImpl::RegisterPlatform( port::StatusOr MultiPlatformManagerImpl::PlatformWithName( absl::string_view target) { + return PlatformWithName(target, /*initialize_platform=*/true); +} + +port::StatusOr MultiPlatformManagerImpl::PlatformWithId( + const Platform::Id& id) { + return PlatformWithId(id, /*initialize_platform=*/true); +} + +port::StatusOr MultiPlatformManagerImpl::PlatformWithName( + absl::string_view target, bool initialize_platform) { absl::MutexLock lock(&mu_); SE_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target)); - if (!platform->Initialized()) { + if (initialize_platform && !platform->Initialized()) { SE_RETURN_IF_ERROR(platform->Initialize({})); } @@ -115,11 +133,11 @@ port::StatusOr MultiPlatformManagerImpl::PlatformWithName( } port::StatusOr MultiPlatformManagerImpl::PlatformWithId( - const Platform::Id& id) { + const Platform::Id& id, bool initialize_platform) { absl::MutexLock lock(&mu_); SE_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id)); - if (!platform->Initialized()) { + if (initialize_platform && !platform->Initialized()) { SE_RETURN_IF_ERROR(platform->Initialize({})); } @@ -250,6 +268,16 @@ MultiPlatformManagerImpl& Impl() { return Impl().PlatformWithId(id); } +/*static*/ port::StatusOr MultiPlatformManager::PlatformWithId( + const Platform::Id& id, bool initialize_platform) { + return Impl().PlatformWithId(id, initialize_platform); +} + +/*static*/ port::StatusOr MultiPlatformManager::PlatformWithName( + absl::string_view target, bool initialize_platform) { + return Impl().PlatformWithName(target, initialize_platform); +} + /*static*/ port::StatusOr MultiPlatformManager::InitializePlatformWithName( absl::string_view target, diff --git a/tensorflow/stream_executor/multi_platform_manager.h b/tensorflow/stream_executor/multi_platform_manager.h index 556015de790..fbb6effdf83 100644 --- a/tensorflow/stream_executor/multi_platform_manager.h +++ b/tensorflow/stream_executor/multi_platform_manager.h @@ -100,6 +100,13 @@ class MultiPlatformManager { static port::StatusOr PlatformWithName(absl::string_view target); static port::StatusOr PlatformWithId(const Platform::Id& id); + // Same functions as above, but allows platforms to be returned without + // initialization if initialize_platform == false. + static port::StatusOr PlatformWithName(absl::string_view target, + bool initialize_platform); + static port::StatusOr PlatformWithId(const Platform::Id& id, + bool initialize_platform); + // Retrieves the platform registered with the given platform name (e.g. // "CUDA", "OpenCL", ...) or id (an opaque, comparable value provided by the // Platform's Id() method).