Refactor TpuPlatformInterface::GetRegisteredPlatform() to make tries_left configurable.

This also bumps the logging down to INFO and doesn't log about waiting
1 second on the last try. This is to allow for disabling retries and
not having excess logging when it's expected to fail.

PiperOrigin-RevId: 332990478
Change-Id: If90a79e62049a45928e38e4e06d11fc5cae7788d
This commit is contained in:
Skye Wanderman-Milne 2020-09-21 19:47:14 -07:00 committed by TensorFlower Gardener
parent 05f4c0cba4
commit 127d50358a
2 changed files with 23 additions and 23 deletions

View File

@ -26,13 +26,8 @@ namespace tpu {
namespace {
TpuPlatformInterface* GetRegisteredPlatformStatic(bool initialize_platform,
int tries_left = 5) {
if (tries_left <= 0) {
LOG(ERROR) << "Unable to find a TPU platform after exhausting all tries. "
"Returning nullptr...";
return nullptr;
}
int tries_left) {
DCHECK_GT(tries_left, 0);
// Prefer TpuPlatform if it's registered.
auto status_or_tpu_platform =
stream_executor::MultiPlatformManager::PlatformWithName(
@ -65,7 +60,8 @@ TpuPlatformInterface* GetRegisteredPlatformStatic(bool initialize_platform,
}
// If we find at least one thing, we return the first thing we see.
if (status_or_other_tpu_platforms.ok()) {
if (status_or_other_tpu_platforms.ok() &&
!status_or_other_tpu_platforms->empty()) {
auto other_tpu_platforms = status_or_other_tpu_platforms.ValueOrDie();
LOG(WARNING) << other_tpu_platforms.size()
<< " TPU platforms registered, selecting "
@ -73,26 +69,26 @@ TpuPlatformInterface* GetRegisteredPlatformStatic(bool initialize_platform,
return static_cast<TpuPlatformInterface*>(other_tpu_platforms[0]);
}
LOG(WARNING)
--tries_left;
if (tries_left <= 0) {
LOG(INFO) << "No TPU platform found.";
return nullptr;
}
LOG(INFO)
<< "No TPU platform registered. Waiting 1 second and trying again... ("
<< (tries_left - 1) << " tries left)";
<< tries_left << " tries left)";
Env::Default()->SleepForMicroseconds(1000000); // 1 second
return GetRegisteredPlatformStatic(initialize_platform, --tries_left);
return GetRegisteredPlatformStatic(initialize_platform, tries_left);
}
} // namespace
/* static */
TpuPlatformInterface* TpuPlatformInterface::GetRegisteredPlatform() {
return GetRegisteredPlatform(/*initialize_platform=*/true);
}
/* static */
TpuPlatformInterface* TpuPlatformInterface::GetRegisteredPlatform(
bool initialize_platform) {
bool initialize_platform, int num_tries) {
static auto* mu = new mutex;
static bool requested_initialize_platform = initialize_platform;
static TpuPlatformInterface* tpu_registered_platform =
GetRegisteredPlatformStatic(initialize_platform);
GetRegisteredPlatformStatic(initialize_platform, num_tries);
mutex_lock lock(*mu);
if (!requested_initialize_platform && initialize_platform) {
@ -100,7 +96,8 @@ TpuPlatformInterface* TpuPlatformInterface::GetRegisteredPlatform(
// initializing the platform, but the next caller wants the platform
// initialized, we will call GetRegisteredPlatformStatic again to initialize
// the platform.
tpu_registered_platform = GetRegisteredPlatformStatic(initialize_platform);
tpu_registered_platform =
GetRegisteredPlatformStatic(initialize_platform, num_tries);
requested_initialize_platform = true;
}

View File

@ -35,10 +35,13 @@ class TpuPlatformInterface : public stream_executor::Platform {
// Returns a TPU platform to be used by TPU ops. If multiple TPU platforms are
// registered, finds the most suitable one. Returns nullptr if no TPU platform
// is registered or an error occurred.
static TpuPlatformInterface* GetRegisteredPlatform();
// Option to not initialize a platform if not necessary.
static TpuPlatformInterface* GetRegisteredPlatform(bool initialize_platform);
//
// 'initialize_platform' can be set to false to not initialize a platform if
// not necessary. 'num_tries' specifies the number of tries if the TPU
// platform isn't initialized yet, with a 1-second delay between each try
// (num_tries == 1 means try once with no retries).
static TpuPlatformInterface* GetRegisteredPlatform(
bool initialize_platform = true, int num_tries = 5);
virtual Status Reset(bool only_tear_down, absl::string_view reason) = 0;