Merge pull request #43795 from Intel-tensorflow:vishakh1/openmp_changes
PiperOrigin-RevId: 337562557 Change-Id: I63932912e0b5e9766e10fd91da0d9b36a8451bda
This commit is contained in:
commit
a45e38dcb1
tensorflow/core/common_runtime
@ -1549,6 +1549,7 @@ cc_library(
|
|||||||
":local_device",
|
":local_device",
|
||||||
":scoped_allocator",
|
":scoped_allocator",
|
||||||
":session_options",
|
":session_options",
|
||||||
|
"@com_google_absl//absl/base",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:graph",
|
"//tensorflow/core:graph",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
@ -16,7 +16,6 @@ limitations under the License.
|
|||||||
#ifdef INTEL_MKL
|
#ifdef INTEL_MKL
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/threadpool_device.h"
|
#include "tensorflow/core/common_runtime/threadpool_device.h"
|
||||||
|
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
#include "tensorflow/core/platform/cpu_info.h"
|
#include "tensorflow/core/platform/cpu_info.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
@ -37,15 +36,6 @@ TEST(MKLThreadPoolDeviceTest, TestOmpDefaults) {
|
|||||||
EXPECT_EQ(omp_get_max_threads(), (port::NumSchedulableCPUs() + ht - 1) / ht);
|
EXPECT_EQ(omp_get_max_threads(), (port::NumSchedulableCPUs() + ht - 1) / ht);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(MKLThreadPoolDeviceTest, TestOmpPreSets) {
|
|
||||||
SessionOptions options;
|
|
||||||
setenv("OMP_NUM_THREADS", "314", 1);
|
|
||||||
|
|
||||||
ThreadPoolDevice* tp = new ThreadPoolDevice(
|
|
||||||
options, "/device:CPU:0", Bytes(256), DeviceLocality(), cpu_allocator());
|
|
||||||
|
|
||||||
EXPECT_EQ(omp_get_max_threads(), 314);
|
|
||||||
}
|
|
||||||
#endif // defined(_OPENMP) && !defined(ENABLE_MKLDNN_THREADPOOL)
|
#endif // defined(_OPENMP) && !defined(ENABLE_MKLDNN_THREADPOOL)
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/threadpool_device.h"
|
#include "tensorflow/core/common_runtime/threadpool_device.h"
|
||||||
|
|
||||||
|
#include "absl/base/call_once.h"
|
||||||
#include "tensorflow/core/common_runtime/local_device.h"
|
#include "tensorflow/core/common_runtime/local_device.h"
|
||||||
#include "tensorflow/core/common_runtime/scoped_allocator.h"
|
#include "tensorflow/core/common_runtime/scoped_allocator.h"
|
||||||
#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
|
#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
|
||||||
@ -55,18 +55,14 @@ ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options,
|
|||||||
if (DisableMKL()) return;
|
if (DisableMKL()) return;
|
||||||
#ifdef _OPENMP
|
#ifdef _OPENMP
|
||||||
const char* user_omp_threads = getenv("OMP_NUM_THREADS");
|
const char* user_omp_threads = getenv("OMP_NUM_THREADS");
|
||||||
|
static absl::once_flag omp_setting_flag;
|
||||||
if (user_omp_threads == nullptr) {
|
if (user_omp_threads == nullptr) {
|
||||||
// OMP_NUM_THREADS controls MKL's intra-op parallelization
|
// OMP_NUM_THREADS controls MKL's intra-op parallelization
|
||||||
// Default to available physical cores
|
// Default to available physical cores
|
||||||
const int mkl_intra_op = port::NumSchedulableCPUs();
|
const int mkl_intra_op = port::NumSchedulableCPUs();
|
||||||
const int ht = port::NumHyperthreadsPerCore();
|
const int ht = port::NumHyperthreadsPerCore();
|
||||||
omp_set_num_threads((mkl_intra_op + ht - 1) / ht);
|
absl::call_once(omp_setting_flag, omp_set_num_threads,
|
||||||
} else {
|
(mkl_intra_op + ht - 1) / ht);
|
||||||
uint64 user_val = 0;
|
|
||||||
if (strings::safe_strtou64(user_omp_threads, &user_val)) {
|
|
||||||
// Superflous but triggers OpenMP loading
|
|
||||||
omp_set_num_threads(user_val);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
#endif // _OPENMP
|
#endif // _OPENMP
|
||||||
#endif // !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)
|
#endif // !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)
|
||||||
|
Loading…
Reference in New Issue
Block a user