Replace NumSchedulableCPUs() with MaxParallelism()
PiperOrigin-RevId: 255652572
This commit is contained in:
parent
c1ad328572
commit
870fb5c489
tensorflow/compiler
@ -59,7 +59,7 @@ void zero_buffers(XlaCompiledCpuFunction* computation) {
|
|||||||
|
|
||||||
// Trivial test that runs the generated function to ensure it doesn't crash.
|
// Trivial test that runs the generated function to ensure it doesn't crash.
|
||||||
TEST(TEST_NAME, NoCrash) {
|
TEST(TEST_NAME, NoCrash) {
|
||||||
Eigen::ThreadPool pool(port::NumSchedulableCPUs());
|
Eigen::ThreadPool pool(port::MaxParallelism());
|
||||||
Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());
|
Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());
|
||||||
|
|
||||||
CPP_CLASS computation;
|
CPP_CLASS computation;
|
||||||
@ -73,7 +73,7 @@ TEST(TEST_NAME, NoCrash) {
|
|||||||
void BM_NAME(int iters) {
|
void BM_NAME(int iters) {
|
||||||
testing::StopTiming();
|
testing::StopTiming();
|
||||||
|
|
||||||
Eigen::ThreadPool pool(port::NumSchedulableCPUs());
|
Eigen::ThreadPool pool(port::MaxParallelism());
|
||||||
Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());
|
Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());
|
||||||
|
|
||||||
CPP_CLASS computation;
|
CPP_CLASS computation;
|
||||||
|
@ -137,7 +137,7 @@ Backend::Backend(se::Platform* platform, Compiler* compiler,
|
|||||||
if (platform->id() == se::host::kHostPlatformId) {
|
if (platform->id() == se::host::kHostPlatformId) {
|
||||||
const int num_threads = intra_op_parallelism_threads > 0
|
const int num_threads = intra_op_parallelism_threads > 0
|
||||||
? intra_op_parallelism_threads
|
? intra_op_parallelism_threads
|
||||||
: tensorflow::port::NumSchedulableCPUs();
|
: tensorflow::port::MaxParallelism();
|
||||||
intra_op_thread_pool_.reset(new IntraOpThreadPool(num_threads));
|
intra_op_thread_pool_.reset(new IntraOpThreadPool(num_threads));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -77,7 +77,7 @@ class DefaultCostModel : public ParallelCostModel {
|
|||||||
// TODO(b/29630486) Develop system bandwidth model.
|
// TODO(b/29630486) Develop system bandwidth model.
|
||||||
max_parallelism = std::min<int64>(
|
max_parallelism = std::min<int64>(
|
||||||
max_parallelism_,
|
max_parallelism_,
|
||||||
std::ceil(std::sqrt(tensorflow::port::NumSchedulableCPUs())));
|
std::ceil(std::sqrt(tensorflow::port::MaxParallelism())));
|
||||||
// Use shape size instruction cost and L2 cache size min per-thread cost.
|
// Use shape size instruction cost and L2 cache size min per-thread cost.
|
||||||
instruction_cost = shape_size_(instruction->shape());
|
instruction_cost = shape_size_(instruction->shape());
|
||||||
min_cost_per_thread = 256LL << 10; // 256KB L2 Cache size.
|
min_cost_per_thread = 256LL << 10; // 256KB L2 Cache size.
|
||||||
|
@ -751,7 +751,7 @@ class ShapeUtil {
|
|||||||
// once with the proper empty indexes.
|
// once with the proper empty indexes.
|
||||||
int64 n = -1;
|
int64 n = -1;
|
||||||
std::vector<int64> indexes(base.begin(), base.end());
|
std::vector<int64> indexes(base.begin(), base.end());
|
||||||
const int kNumThreads = tensorflow::port::NumSchedulableCPUs();
|
const int kNumThreads = tensorflow::port::MaxParallelism();
|
||||||
absl::optional<tensorflow::thread::ThreadPool> pool;
|
absl::optional<tensorflow::thread::ThreadPool> pool;
|
||||||
if (parallel) {
|
if (parallel) {
|
||||||
pool.emplace(tensorflow::Env::Default(), "foreach", kNumThreads);
|
pool.emplace(tensorflow::Env::Default(), "foreach", kNumThreads);
|
||||||
|
@ -86,8 +86,7 @@ namespace {
|
|||||||
// Command-line opts to this tool. See main() for descriptions of these
|
// Command-line opts to this tool. See main() for descriptions of these
|
||||||
// fields.
|
// fields.
|
||||||
struct Options {
|
struct Options {
|
||||||
Options()
|
Options() : intra_op_thread_pool_size(tensorflow::port::MaxParallelism()) {}
|
||||||
: intra_op_thread_pool_size(tensorflow::port::NumSchedulableCPUs()) {}
|
|
||||||
|
|
||||||
bool NeedsRealData() const { return !use_fake_data && !compile_only; }
|
bool NeedsRealData() const { return !use_fake_data && !compile_only; }
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user