Refactor TpuPlatformInterface::GetRegisteredPlatform() to make tries_left configurable.
This also bumps the logging down to INFO and doesn't log about waiting 1 second on the last try. This is to allow for disabling retries and not having excess logging when it's expected to fail. PiperOrigin-RevId: 332990478 Change-Id: If90a79e62049a45928e38e4e06d11fc5cae7788d
This commit is contained in:
parent
05f4c0cba4
commit
127d50358a
tensorflow/stream_executor/tpu
@ -26,13 +26,8 @@ namespace tpu {
|
||||
|
||||
namespace {
|
||||
TpuPlatformInterface* GetRegisteredPlatformStatic(bool initialize_platform,
|
||||
int tries_left = 5) {
|
||||
if (tries_left <= 0) {
|
||||
LOG(ERROR) << "Unable to find a TPU platform after exhausting all tries. "
|
||||
"Returning nullptr...";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int tries_left) {
|
||||
DCHECK_GT(tries_left, 0);
|
||||
// Prefer TpuPlatform if it's registered.
|
||||
auto status_or_tpu_platform =
|
||||
stream_executor::MultiPlatformManager::PlatformWithName(
|
||||
@ -65,7 +60,8 @@ TpuPlatformInterface* GetRegisteredPlatformStatic(bool initialize_platform,
|
||||
}
|
||||
|
||||
// If we find at least one thing, we return the first thing we see.
|
||||
if (status_or_other_tpu_platforms.ok()) {
|
||||
if (status_or_other_tpu_platforms.ok() &&
|
||||
!status_or_other_tpu_platforms->empty()) {
|
||||
auto other_tpu_platforms = status_or_other_tpu_platforms.ValueOrDie();
|
||||
LOG(WARNING) << other_tpu_platforms.size()
|
||||
<< " TPU platforms registered, selecting "
|
||||
@ -73,26 +69,26 @@ TpuPlatformInterface* GetRegisteredPlatformStatic(bool initialize_platform,
|
||||
return static_cast<TpuPlatformInterface*>(other_tpu_platforms[0]);
|
||||
}
|
||||
|
||||
LOG(WARNING)
|
||||
--tries_left;
|
||||
if (tries_left <= 0) {
|
||||
LOG(INFO) << "No TPU platform found.";
|
||||
return nullptr;
|
||||
}
|
||||
LOG(INFO)
|
||||
<< "No TPU platform registered. Waiting 1 second and trying again... ("
|
||||
<< (tries_left - 1) << " tries left)";
|
||||
<< tries_left << " tries left)";
|
||||
Env::Default()->SleepForMicroseconds(1000000); // 1 second
|
||||
return GetRegisteredPlatformStatic(initialize_platform, --tries_left);
|
||||
return GetRegisteredPlatformStatic(initialize_platform, tries_left);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
/* static */
|
||||
TpuPlatformInterface* TpuPlatformInterface::GetRegisteredPlatform() {
|
||||
return GetRegisteredPlatform(/*initialize_platform=*/true);
|
||||
}
|
||||
|
||||
/* static */
|
||||
TpuPlatformInterface* TpuPlatformInterface::GetRegisteredPlatform(
|
||||
bool initialize_platform) {
|
||||
bool initialize_platform, int num_tries) {
|
||||
static auto* mu = new mutex;
|
||||
static bool requested_initialize_platform = initialize_platform;
|
||||
static TpuPlatformInterface* tpu_registered_platform =
|
||||
GetRegisteredPlatformStatic(initialize_platform);
|
||||
GetRegisteredPlatformStatic(initialize_platform, num_tries);
|
||||
|
||||
mutex_lock lock(*mu);
|
||||
if (!requested_initialize_platform && initialize_platform) {
|
||||
@ -100,7 +96,8 @@ TpuPlatformInterface* TpuPlatformInterface::GetRegisteredPlatform(
|
||||
// initializing the platform, but the next caller wants the platform
|
||||
// initialized, we will call GetRegisteredPlatformStatic again to initialize
|
||||
// the platform.
|
||||
tpu_registered_platform = GetRegisteredPlatformStatic(initialize_platform);
|
||||
tpu_registered_platform =
|
||||
GetRegisteredPlatformStatic(initialize_platform, num_tries);
|
||||
requested_initialize_platform = true;
|
||||
}
|
||||
|
||||
|
@ -35,10 +35,13 @@ class TpuPlatformInterface : public stream_executor::Platform {
|
||||
// Returns a TPU platform to be used by TPU ops. If multiple TPU platforms are
|
||||
// registered, finds the most suitable one. Returns nullptr if no TPU platform
|
||||
// is registered or an error occurred.
|
||||
static TpuPlatformInterface* GetRegisteredPlatform();
|
||||
|
||||
// Option to not initialize a platform if not necessary.
|
||||
static TpuPlatformInterface* GetRegisteredPlatform(bool initialize_platform);
|
||||
//
|
||||
// 'initialize_platform' can be set to false to not initialize a platform if
|
||||
// not necessary. 'num_tries' specifies the number of tries if the TPU
|
||||
// platform isn't initialized yet, with a 1-second delay between each try
|
||||
// (num_tries == 1 means try once with no retries).
|
||||
static TpuPlatformInterface* GetRegisteredPlatform(
|
||||
bool initialize_platform = true, int num_tries = 5);
|
||||
|
||||
virtual Status Reset(bool only_tear_down, absl::string_view reason) = 0;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user