[PJRT] Extract utils for getting MaxParallelism level.
PiperOrigin-RevId: 360060466
Change-Id: I5dcb4ba1d3ef82f2094561afa63138d1ac8055c2
parent 90078eb344
commit 59ff97c07b
@@ -187,19 +187,6 @@ class CpuAllocator : public tensorflow::Allocator {
   }
 };
 
-static int DefaultThreadPoolSize() {
-  // Google's CI system exposes an environment variable NPROC that describes
-  // a CPU reservation for tests.
-  // TODO(phawkins): expose a better thought-out set of knobs to control
-  // parallelism.
-  const char* nproc_str = std::getenv("NPROC");
-  int nproc = 0;
-  if (nproc_str && absl::SimpleAtoi(nproc_str, &nproc)) {
-    return std::max(0, nproc);
-  }
-  return tensorflow::port::MaxParallelism();
-}
-
 PjRtStreamExecutorClient::PjRtStreamExecutorClient(
     std::string platform_name, LocalClient* client,
     std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int task_id,
@@ -260,4 +260,17 @@ StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
   return parameters_to_donate;
 }
 
+int DefaultThreadPoolSize() {
+  // Google's CI system exposes an environment variable NPROC that describes
+  // a CPU reservation for tests.
+  // TODO(phawkins): expose a better thought-out set of knobs to control
+  // parallelism.
+  const char* nproc_str = std::getenv("NPROC");
+  int nproc = 0;
+  if (nproc_str && absl::SimpleAtoi(nproc_str, &nproc)) {
+    return std::max(0, nproc);
+  }
+  return tensorflow::port::MaxParallelism();
+}
+
 }  // namespace xla
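For context only (not part of the diff): a minimal caller sketch showing how a PJRT client could size a thread pool with the now-shared helper. The function name MakeTransferPool and the pool name "pjrt-transfer-pool" are illustrative assumptions; the tensorflow::thread::ThreadPool(Env*, name, num_threads) constructor and header paths are the standard TensorFlow ones.

// Hypothetical usage sketch (not from this commit): size a thread pool using
// the extracted DefaultThreadPoolSize() helper instead of a local copy.
#include <memory>

#include "tensorflow/compiler/xla/pjrt/utils.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/env.h"

namespace xla {

std::unique_ptr<tensorflow::thread::ThreadPool> MakeTransferPool() {
  // DefaultThreadPoolSize() honors the NPROC CI override and otherwise falls
  // back to tensorflow::port::MaxParallelism().
  return std::make_unique<tensorflow::thread::ThreadPool>(
      tensorflow::Env::Default(), "pjrt-transfer-pool",
      DefaultThreadPoolSize());
}

}  // namespace xla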
@@ -52,6 +52,9 @@ Status DetermineArgumentLayoutsFromCompileOptions(
 StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
     const HloModule& module, bool tuple_inputs);
 
+// Return max parallelism level.
+int DefaultThreadPoolSize();
+
 }  // namespace xla
 
 #endif  // TENSORFLOW_COMPILER_XLA_PJRT_UTILS_H_
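Also for context only: a small hypothetical unit-test sketch for the extracted helper. The test name and the use of POSIX setenv/unsetenv are assumptions; it only exercises the NPROC override described in the helper's comments.

// Hypothetical test sketch (not from this commit): verify the NPROC override.
#include <cstdlib>

#include "tensorflow/compiler/xla/pjrt/utils.h"
#include "tensorflow/core/platform/test.h"

namespace xla {
namespace {

TEST(DefaultThreadPoolSizeTest, HonorsNprocEnvVar) {
  setenv("NPROC", "3", /*overwrite=*/1);  // POSIX; simulates the CI reservation.
  EXPECT_EQ(DefaultThreadPoolSize(), 3);
  unsetenv("NPROC");
  // Without NPROC the helper defers to tensorflow::port::MaxParallelism(),
  // which should report at least one schedulable core.
  EXPECT_GE(DefaultThreadPoolSize(), 1);
}

}  // namespace
}  // namespace xla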