From 127d50358ab1dad3b673ac8ac8c2ab87982ed5fd Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Mon, 21 Sep 2020 19:47:14 -0700 Subject: [PATCH] Refactor TpuPlatformInterface::GetRegisteredPlatform() to make tries_left configurable. This also bumps the logging down to INFO and doesn't log about waiting 1 second on the last try. This is to allow for disabling retries and not having excess logging when it's expected to fail. PiperOrigin-RevId: 332990478 Change-Id: If90a79e62049a45928e38e4e06d11fc5cae7788d --- .../tpu/tpu_platform_interface.cc | 35 +++++++++---------- .../tpu/tpu_platform_interface.h | 11 +++--- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/tensorflow/stream_executor/tpu/tpu_platform_interface.cc b/tensorflow/stream_executor/tpu/tpu_platform_interface.cc index 9b8b9cd8ed5..5330dd49d2c 100644 --- a/tensorflow/stream_executor/tpu/tpu_platform_interface.cc +++ b/tensorflow/stream_executor/tpu/tpu_platform_interface.cc @@ -26,13 +26,8 @@ namespace tpu { namespace { TpuPlatformInterface* GetRegisteredPlatformStatic(bool initialize_platform, - int tries_left = 5) { - if (tries_left <= 0) { - LOG(ERROR) << "Unable to find a TPU platform after exhausting all tries. " - "Returning nullptr..."; - return nullptr; - } - + int tries_left) { + DCHECK_GT(tries_left, 0); // Prefer TpuPlatform if it's registered. auto status_or_tpu_platform = stream_executor::MultiPlatformManager::PlatformWithName( @@ -65,7 +60,8 @@ TpuPlatformInterface* GetRegisteredPlatformStatic(bool initialize_platform, } // If we find at least one thing, we return the first thing we see. 
- if (status_or_other_tpu_platforms.ok()) { + if (status_or_other_tpu_platforms.ok() && + !status_or_other_tpu_platforms->empty()) { auto other_tpu_platforms = status_or_other_tpu_platforms.ValueOrDie(); LOG(WARNING) << other_tpu_platforms.size() << " TPU platforms registered, selecting " @@ -73,26 +69,26 @@ TpuPlatformInterface* GetRegisteredPlatformStatic(bool initialize_platform, return static_cast<TpuPlatformInterface*>(other_tpu_platforms[0]); } - LOG(WARNING) + --tries_left; + if (tries_left <= 0) { + LOG(INFO) << "No TPU platform found."; + return nullptr; + } + LOG(INFO) << "No TPU platform registered. Waiting 1 second and trying again... (" - << (tries_left - 1) << " tries left)"; + << tries_left << " tries left)"; Env::Default()->SleepForMicroseconds(1000000); // 1 second - return GetRegisteredPlatformStatic(initialize_platform, --tries_left); + return GetRegisteredPlatformStatic(initialize_platform, tries_left); } } // namespace -/* static */ -TpuPlatformInterface* TpuPlatformInterface::GetRegisteredPlatform() { - return GetRegisteredPlatform(/*initialize_platform=*/true); -} - /* static */ TpuPlatformInterface* TpuPlatformInterface::GetRegisteredPlatform( - bool initialize_platform) { + bool initialize_platform, int num_tries) { static auto* mu = new mutex; static bool requested_initialize_platform = initialize_platform; static TpuPlatformInterface* tpu_registered_platform = - GetRegisteredPlatformStatic(initialize_platform); + GetRegisteredPlatformStatic(initialize_platform, num_tries); mutex_lock lock(*mu); if (!requested_initialize_platform && initialize_platform) { @@ -100,7 +96,8 @@ TpuPlatformInterface* TpuPlatformInterface::GetRegisteredPlatform( // initializing the platform, but the next caller wants the platform // initialized, we will call GetRegisteredPlatformStatic again to initialize // the platform. 
- tpu_registered_platform = GetRegisteredPlatformStatic(initialize_platform); + tpu_registered_platform = + GetRegisteredPlatformStatic(initialize_platform, num_tries); requested_initialize_platform = true; } diff --git a/tensorflow/stream_executor/tpu/tpu_platform_interface.h b/tensorflow/stream_executor/tpu/tpu_platform_interface.h index 1aa30581d29..fee9d92b42d 100644 --- a/tensorflow/stream_executor/tpu/tpu_platform_interface.h +++ b/tensorflow/stream_executor/tpu/tpu_platform_interface.h @@ -35,10 +35,13 @@ class TpuPlatformInterface : public stream_executor::Platform { // Returns a TPU platform to be used by TPU ops. If multiple TPU platforms are // registered, finds the most suitable one. Returns nullptr if no TPU platform // is registered or an error occurred. - static TpuPlatformInterface* GetRegisteredPlatform(); - - // Option to not initialize a platform if not necessary. - static TpuPlatformInterface* GetRegisteredPlatform(bool initialize_platform); + // + // 'initialize_platform' can be set to false to not initialize a platform if + // not necessary. 'num_tries' specifies the number of tries if the TPU + // platform isn't initialized yet, with a 1-second delay between each try + // (num_tries == 1 means try once with no retries). + static TpuPlatformInterface* GetRegisteredPlatform( + bool initialize_platform = true, int num_tries = 5); virtual Status Reset(bool only_tear_down, absl::string_view reason) = 0;