[PJRT] Extract utils for getting MaxParallelism level.

PiperOrigin-RevId: 360060466
Change-Id: I5dcb4ba1d3ef82f2094561afa63138d1ac8055c2
This commit is contained in:
Qiao Zhang 2021-02-28 12:38:15 -08:00 committed by TensorFlower Gardener
parent 90078eb344
commit 59ff97c07b
3 changed files with 16 additions and 13 deletions

View File

@ -187,19 +187,6 @@ class CpuAllocator : public tensorflow::Allocator {
}
};
static int DefaultThreadPoolSize() {
// Google's CI system exposes an environment variable NPROC that describes
// a CPU reservation for tests.
// TODO(phawkins): expose a better thought-out set of knobs to control
// parallelism.
const char* nproc_str = std::getenv("NPROC");
int nproc = 0;
if (nproc_str && absl::SimpleAtoi(nproc_str, &nproc)) {
return std::max(0, nproc);
}
return tensorflow::port::MaxParallelism();
}
PjRtStreamExecutorClient::PjRtStreamExecutorClient(
std::string platform_name, LocalClient* client,
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int task_id,

View File

@ -260,4 +260,17 @@ StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
return parameters_to_donate;
}
int DefaultThreadPoolSize() {
// Google's CI system exposes an environment variable NPROC that describes
// a CPU reservation for tests.
// TODO(phawkins): expose a better thought-out set of knobs to control
// parallelism.
const char* nproc_str = std::getenv("NPROC");
int nproc = 0;
if (nproc_str && absl::SimpleAtoi(nproc_str, &nproc)) {
return std::max(0, nproc);
}
return tensorflow::port::MaxParallelism();
}
} // namespace xla

View File

@ -52,6 +52,9 @@ Status DetermineArgumentLayoutsFromCompileOptions(
StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
const HloModule& module, bool tuple_inputs);
// Return max parallelism level.
int DefaultThreadPoolSize();
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PJRT_UTILS_H_