Merge pull request #43795 from Intel-tensorflow:vishakh1/openmp_changes

PiperOrigin-RevId: 337562557
Change-Id: I63932912e0b5e9766e10fd91da0d9b36a8451bda
This commit is contained in:
TensorFlower Gardener 2020-10-16 13:22:42 -07:00
commit a45e38dcb1
3 changed files with 5 additions and 18 deletions

View File

@ -1549,6 +1549,7 @@ cc_library(
":local_device", ":local_device",
":scoped_allocator", ":scoped_allocator",
":session_options", ":session_options",
"@com_google_absl//absl/base",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:graph", "//tensorflow/core:graph",
"//tensorflow/core:lib", "//tensorflow/core:lib",

View File

@ -16,7 +16,6 @@ limitations under the License.
#ifdef INTEL_MKL #ifdef INTEL_MKL
#include "tensorflow/core/common_runtime/threadpool_device.h" #include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
@ -37,15 +36,6 @@ TEST(MKLThreadPoolDeviceTest, TestOmpDefaults) {
EXPECT_EQ(omp_get_max_threads(), (port::NumSchedulableCPUs() + ht - 1) / ht); EXPECT_EQ(omp_get_max_threads(), (port::NumSchedulableCPUs() + ht - 1) / ht);
} }
TEST(MKLThreadPoolDeviceTest, TestOmpPreSets) {
SessionOptions options;
setenv("OMP_NUM_THREADS", "314", 1);
ThreadPoolDevice* tp = new ThreadPoolDevice(
options, "/device:CPU:0", Bytes(256), DeviceLocality(), cpu_allocator());
EXPECT_EQ(omp_get_max_threads(), 314);
}
#endif // defined(_OPENMP) && !defined(ENABLE_MKLDNN_THREADPOOL) #endif // defined(_OPENMP) && !defined(ENABLE_MKLDNN_THREADPOOL)
} // namespace tensorflow } // namespace tensorflow

View File

@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/core/common_runtime/threadpool_device.h" #include "tensorflow/core/common_runtime/threadpool_device.h"
#include "absl/base/call_once.h"
#include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/common_runtime/scoped_allocator.h" #include "tensorflow/core/common_runtime/scoped_allocator.h"
#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h" #include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
@ -55,18 +55,14 @@ ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options,
if (DisableMKL()) return; if (DisableMKL()) return;
#ifdef _OPENMP #ifdef _OPENMP
const char* user_omp_threads = getenv("OMP_NUM_THREADS"); const char* user_omp_threads = getenv("OMP_NUM_THREADS");
static absl::once_flag omp_setting_flag;
if (user_omp_threads == nullptr) { if (user_omp_threads == nullptr) {
// OMP_NUM_THREADS controls MKL's intra-op parallelization // OMP_NUM_THREADS controls MKL's intra-op parallelization
// Default to available physical cores // Default to available physical cores
const int mkl_intra_op = port::NumSchedulableCPUs(); const int mkl_intra_op = port::NumSchedulableCPUs();
const int ht = port::NumHyperthreadsPerCore(); const int ht = port::NumHyperthreadsPerCore();
omp_set_num_threads((mkl_intra_op + ht - 1) / ht); absl::call_once(omp_setting_flag, omp_set_num_threads,
} else { (mkl_intra_op + ht - 1) / ht);
uint64 user_val = 0;
if (strings::safe_strtou64(user_omp_threads, &user_val)) {
// Superflous but triggers OpenMP loading
omp_set_num_threads(user_val);
}
} }
#endif // _OPENMP #endif // _OPENMP
#endif // !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL) #endif // !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)