[PJRT] Extract utils for getting MaxParallelism level.
PiperOrigin-RevId: 360060466 Change-Id: I5dcb4ba1d3ef82f2094561afa63138d1ac8055c2
This commit is contained in:
parent
90078eb344
commit
59ff97c07b
tensorflow/compiler/xla/pjrt
@ -187,19 +187,6 @@ class CpuAllocator : public tensorflow::Allocator {
|
||||
}
|
||||
};
|
||||
|
||||
static int DefaultThreadPoolSize() {
|
||||
// Google's CI system exposes an environment variable NPROC that describes
|
||||
// a CPU reservation for tests.
|
||||
// TODO(phawkins): expose a better thought-out set of knobs to control
|
||||
// parallelism.
|
||||
const char* nproc_str = std::getenv("NPROC");
|
||||
int nproc = 0;
|
||||
if (nproc_str && absl::SimpleAtoi(nproc_str, &nproc)) {
|
||||
return std::max(0, nproc);
|
||||
}
|
||||
return tensorflow::port::MaxParallelism();
|
||||
}
|
||||
|
||||
PjRtStreamExecutorClient::PjRtStreamExecutorClient(
|
||||
std::string platform_name, LocalClient* client,
|
||||
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int task_id,
|
||||
|
@ -260,4 +260,17 @@ StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
|
||||
return parameters_to_donate;
|
||||
}
|
||||
|
||||
int DefaultThreadPoolSize() {
|
||||
// Google's CI system exposes an environment variable NPROC that describes
|
||||
// a CPU reservation for tests.
|
||||
// TODO(phawkins): expose a better thought-out set of knobs to control
|
||||
// parallelism.
|
||||
const char* nproc_str = std::getenv("NPROC");
|
||||
int nproc = 0;
|
||||
if (nproc_str && absl::SimpleAtoi(nproc_str, &nproc)) {
|
||||
return std::max(0, nproc);
|
||||
}
|
||||
return tensorflow::port::MaxParallelism();
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -52,6 +52,9 @@ Status DetermineArgumentLayoutsFromCompileOptions(
|
||||
StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
|
||||
const HloModule& module, bool tuple_inputs);
|
||||
|
||||
// Return max parallelism level.
|
||||
int DefaultThreadPoolSize();
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PJRT_UTILS_H_
|
||||
|
Loading…
Reference in New Issue
Block a user