diff --git a/.bazelrc b/.bazelrc index 7e0f820b4c2..f2aa3ac447b 100644 --- a/.bazelrc +++ b/.bazelrc @@ -143,6 +143,11 @@ build:mkl --define=tensorflow_mkldnn_contraction_kernel=0 build:mkl --define=build_with_mkl_dnn_v1_only=true build:mkl -c opt +# config to build OneDNN backend with a user specified threadpool. +build:mkl_threadpool --define=build_with_mkl=true --define=enable_mkl=true +build:mkl_threadpool --define=tensorflow_mkldnn_contraction_kernel=0 +build:mkl_threadpool --define=build_with_mkldnn_threadpool=true +build:mkl_threadpool -c opt # This config refers to building with CUDA available. It does not necessarily # mean that we build CUDA op kernels. build:using_cuda --define=using_cuda=true diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index c029de9a4e8..9e89094f4e7 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -48,6 +48,7 @@ load( "//third_party/mkl_dnn:build_defs.bzl", "if_mkl_open_source_only", "if_mkl_v1_open_source_only", + "if_mkldnn_threadpool", ) load( "//third_party/ngraph:build_defs.bzl", @@ -327,6 +328,11 @@ def tf_copts( if_mkl(["-DINTEL_MKL=1", "-DEIGEN_USE_VML"]) + if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) + if_mkl_v1_open_source_only(["-DENABLE_MKLDNN_V1"]) + + if_mkldnn_threadpool([ + "-DENABLE_MKLDNN_THREADPOOL", + "-DENABLE_MKLDNN_V1", + "-DINTEL_MKL_DNN_ONLY", + ]) + if_enable_mkl(["-DENABLE_MKL"]) + if_ngraph(["-DINTEL_NGRAPH=1"]) + if_android_arm(["-mfpu=neon"]) + @@ -348,7 +354,9 @@ def tf_copts( ) def tf_openmp_copts(): - return if_mkl_lnx_x64(["-fopenmp"]) + # TODO(intel-mkl): Remove -fopenmp for threadpool after removing all + # omp pragmas in tensorflow/core. + return if_mkl_lnx_x64(["-fopenmp"]) + if_mkldnn_threadpool(["-fopenmp"]) def tfe_xla_copts(): return select({ diff --git a/third_party/mkl/build_defs.bzl b/third_party/mkl/build_defs.bzl index 4b8fb83eb09..bd0686523bc 100644 --- a/third_party/mkl/build_defs.bzl +++ b/third_party/mkl/build_defs.bzl @@ -107,6 +107,7 @@ def mkl_deps(): return select({ "@org_tensorflow//third_party/mkl_dnn:build_with_mkl_dnn_only": ["@mkl_dnn"], "@org_tensorflow//third_party/mkl_dnn:build_with_mkl_dnn_v1_only": ["@mkl_dnn_v1//:mkl_dnn"], + "@org_tensorflow//third_party/mkl_dnn:build_with_mkldnn_threadpool": ["@mkl_dnn_v1//:mkl_dnn"], "@org_tensorflow//third_party/mkl:build_with_mkl_ml_only": ["@org_tensorflow//third_party/mkl:intel_binary_blob"], "@org_tensorflow//third_party/mkl:build_with_mkl": [ "@org_tensorflow//third_party/mkl:intel_binary_blob", diff --git a/third_party/mkl_dnn/BUILD b/third_party/mkl_dnn/BUILD index 774e5b0e2c0..fe558322916 100644 --- a/third_party/mkl_dnn/BUILD +++ b/third_party/mkl_dnn/BUILD @@ -27,6 +27,15 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "build_with_mkldnn_threadpool", + define_values = { + "build_with_mkl": "true", + "build_with_mkldnn_threadpool": "true", + }, + visibility = ["//visibility:public"], +) + bzl_library( name = "build_defs_bzl", srcs = ["build_defs.bzl"], diff --git a/third_party/mkl_dnn/build_defs.bzl b/third_party/mkl_dnn/build_defs.bzl index af05333c947..bd3b4b94f29 100644 --- a/third_party/mkl_dnn/build_defs.bzl +++ b/third_party/mkl_dnn/build_defs.bzl @@ -29,3 +29,19 @@ def if_mkl_v1_open_source_only(if_true, if_false = []): "@org_tensorflow//third_party/mkl_dnn:build_with_mkl_dnn_v1_only": if_true, "//conditions:default": if_false, }) + +def if_mkldnn_threadpool(if_true, if_false = []): + """Returns `if_true` if MKL-DNN v1.x is used. + + Shorthand for select()'ing on whether we're building with + MKL-DNN v1.x open source library only with user specified threadpool, without depending on MKL binary form. + + Returns a select statement which evaluates to if_true if we're building + with MKL-DNN v1.x open source library only with user specified threadpool. Otherwise, the + select statement evaluates to if_false. + + """ + return select({ + "@org_tensorflow//third_party/mkl_dnn:build_with_mkldnn_threadpool": if_true, + "//conditions:default": if_false, + }) diff --git a/third_party/mkl_dnn/mkldnn_v1.BUILD b/third_party/mkl_dnn/mkldnn_v1.BUILD index 243ec00a60f..313f81c8108 100644 --- a/third_party/mkl_dnn/mkldnn_v1.BUILD +++ b/third_party/mkl_dnn/mkldnn_v1.BUILD @@ -4,6 +4,7 @@ load( "@org_tensorflow//third_party/mkl_dnn:build_defs.bzl", "if_mkl_open_source_only", "if_mkl_v1_open_source_only", + "if_mkldnn_threadpool", ) load( "@org_tensorflow//third_party:common.bzl", @@ -18,15 +19,26 @@ config_setting( }, ) +_DNNL_RUNTIME_OMP = { + "#cmakedefine DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_${DNNL_CPU_THREADING_RUNTIME}": "#define DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_OMP", + "#cmakedefine DNNL_CPU_RUNTIME DNNL_RUNTIME_${DNNL_CPU_RUNTIME}": "#define DNNL_CPU_RUNTIME DNNL_RUNTIME_OMP", + "#cmakedefine DNNL_GPU_RUNTIME DNNL_RUNTIME_${DNNL_GPU_RUNTIME}": "#define DNNL_GPU_RUNTIME DNNL_RUNTIME_NONE", +} + +_DNNL_RUNTIME_THREADPOOL = { + "#cmakedefine DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_${DNNL_CPU_THREADING_RUNTIME}": "#define DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_THREADPOOL", + "#cmakedefine DNNL_CPU_RUNTIME DNNL_RUNTIME_${DNNL_CPU_RUNTIME}": "#define DNNL_CPU_RUNTIME DNNL_RUNTIME_THREADPOOL", + "#cmakedefine DNNL_GPU_RUNTIME DNNL_RUNTIME_${DNNL_GPU_RUNTIME}": "#define DNNL_GPU_RUNTIME DNNL_RUNTIME_NONE", +} + template_rule( name = "dnnl_config_h", src = "include/dnnl_config.h.in", out = "include/dnnl_config.h", - substitutions = { - "#cmakedefine DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_${DNNL_CPU_THREADING_RUNTIME}": "#define DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_OMP", - "#cmakedefine DNNL_CPU_RUNTIME DNNL_RUNTIME_${DNNL_CPU_RUNTIME}": "#define DNNL_CPU_RUNTIME DNNL_RUNTIME_OMP", - "#cmakedefine DNNL_GPU_RUNTIME DNNL_RUNTIME_${DNNL_GPU_RUNTIME}": "#define DNNL_GPU_RUNTIME DNNL_RUNTIME_NONE", - }, + substitutions = if_mkldnn_threadpool( + _DNNL_RUNTIME_THREADPOOL, + if_false = _DNNL_RUNTIME_OMP, + ), ) # Create the file mkldnn_version.h with MKL-DNN version numbers. @@ -59,9 +71,10 @@ cc_library( "src/cpu/**/*.cpp", "src/cpu/**/*.hpp", "src/cpu/xbyak/*.h", - ]) + if_mkl_v1_open_source_only([ + ]) + [ ":dnnl_config_h", - ]) + [":dnnl_version_h"], + ":dnnl_version_h", + ], hdrs = glob(["include/*"]), copts = [ "-fexceptions",