Add a TF_OVERRIDE_GLOBAL_THREADPOOL switch and turn it on for benchmarks.

* local_device.cc: adds the file-local helper
  OverrideGlobalThreadPoolFromEnvironment(). It reads the boolean environment
  variable TF_OVERRIDE_GLOBAL_THREADPOOL through ReadBoolFromEnvVar (default
  false). The read happens once, in the initializer of a function-local
  static; when the read reports an error, the status message is logged at
  ERROR and the helper yields false. The LocalDevice constructor now calls
  set_use_global_threadpool(false) when the helper returns true, immediately
  before it tests use_global_threadpool_.
* benchmark.py: defines OVERRIDE_GLOBAL_THREADPOOL (the variable's name) and
  gives TensorFlowBenchmark an __init__ that sets that variable to "1" in
  os.environ and then delegates to the parent __init__.
* golden v1/v2 tensorflow.test.-benchmark.pbtxt: record the argspec of the
  new __init__.

NOTE(review): line breaks, blank lines and indentation below were rebuilt
from the hunk line counts and the usual 2-space layout of these files; the
changed text itself is untouched. Two context details could not be
confirmed from the patch alone and should be checked against the target
files before applying: whether "/* static */" sits on its own line in
local_device.cc, and the 'is_instance: ""' context line in the golden
files, which looks truncated.

diff --git a/tensorflow/core/common_runtime/local_device.cc b/tensorflow/core/common_runtime/local_device.cc
index 2a6d6f5a7ae..649d83eebec 100644
--- a/tensorflow/core/common_runtime/local_device.cc
+++ b/tensorflow/core/common_runtime/local_device.cc
@@ -28,8 +28,26 @@ limitations under the License.
 #include "tensorflow/core/platform/numa.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/util/env_var.h"
 
 namespace tensorflow {
+namespace {
+
+bool OverrideGlobalThreadPoolFromEnvironment() {
+  static const bool override_global_threadpool = [] {
+    bool flag;
+    auto status = ReadBoolFromEnvVar("TF_OVERRIDE_GLOBAL_THREADPOOL",
+                                     /*default_val=*/false, &flag);
+    if (!status.ok()) {
+      LOG(ERROR) << "OverrideGlobalThreadPool: " << status.error_message();
+      return false;
+    }
+    return flag;
+  }();
+  return override_global_threadpool;
+}
+
+}  // namespace
 
 /* static */
 bool LocalDevice::use_global_threadpool_ = true;
@@ -107,6 +125,11 @@ LocalDevice::LocalDevice(const SessionOptions& options,
   // could speed up performance and are available on the current CPU.
   port::InfoAboutUnusedCPUFeatures();
   LocalDevice::EigenThreadPoolInfo* tp_info;
+
+  if (OverrideGlobalThreadPoolFromEnvironment()) {
+    set_use_global_threadpool(false);
+  }
+
   if (use_global_threadpool_) {
     mutex_lock l(global_tp_mu_);
     if (options.config.experimental().use_numa_affinity()) {
diff --git a/tensorflow/python/platform/benchmark.py b/tensorflow/python/platform/benchmark.py
index d6773d7b813..428505a7375 100644
--- a/tensorflow/python/platform/benchmark.py
+++ b/tensorflow/python/platform/benchmark.py
@@ -46,6 +46,10 @@ GLOBAL_BENCHMARK_REGISTRY = set()
 # See also tensorflow/core/util/reporter.h TestReporter::kTestReporterEnv.
 TEST_REPORTER_TEST_ENV = "TEST_REPORT_FILE_PREFIX"
 
+# Environment variable that lets the TensorFlow runtime allocate a new
+# threadpool for each benchmark.
+OVERRIDE_GLOBAL_THREADPOOL = "TF_OVERRIDE_GLOBAL_THREADPOOL"
+
 
 def _global_report_benchmark(
     name, iters=None, cpu_time=None, wall_time=None,
@@ -201,6 +205,12 @@ def benchmark_config():
 class TensorFlowBenchmark(Benchmark):
   """Abstract class that provides helpers for TensorFlow benchmarks."""
 
+  def __init__(self):
+    # Allow TensorFlow runtime to allocate a new threadpool with different
+    # number of threads for each new benchmark.
+    os.environ[OVERRIDE_GLOBAL_THREADPOOL] = "1"
+    super(TensorFlowBenchmark, self).__init__()
+
   @classmethod
   def is_abstract(cls):
     # mro: (_BenchmarkRegistrar, Benchmark, TensorFlowBenchmark) means
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.test.-benchmark.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.test.-benchmark.pbtxt
index 6fc489c8604..48f53a85454 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.test.-benchmark.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.test.-benchmark.pbtxt
@@ -5,6 +5,7 @@ tf_class {
   is_instance: ""
   member_method {
     name: "__init__"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "evaluate"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.test.-benchmark.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.test.-benchmark.pbtxt
index 6fc489c8604..48f53a85454 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.test.-benchmark.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.test.-benchmark.pbtxt
@@ -5,6 +5,7 @@ tf_class {
   is_instance: ""
   member_method {
     name: "__init__"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "evaluate"