diff --git a/.bazelrc b/.bazelrc index 066b0db10bc..396b84f70b3 100644 --- a/.bazelrc +++ b/.bazelrc @@ -174,6 +174,12 @@ build:mkl_opensource_only --define=tensorflow_mkldnn_contraction_kernel=0 build:mkl_opensource_only --define=build_with_mkl_opensource=true build:mkl_opensource_only -c opt +# Config setting to build with oneDNN for Arm. +build:mkl_aarch64 --define=build_with_mkl_aarch64=true --define=enable_mkl=true +build:mkl_aarch64 --define=tensorflow_mkldnn_contraction_kernel=0 +build:mkl_aarch64 --define=build_with_mkl_opensource=true +build:mkl_aarch64 -c opt + # This config refers to building with CUDA available. It does not necessarily # mean that we build CUDA op kernels. build:using_cuda --define=using_cuda=true diff --git a/configure.py b/configure.py index 5b9fd55b740..e381c8c20db 100644 --- a/configure.py +++ b/configure.py @@ -1485,6 +1485,7 @@ def main(): 'adding "--config=<>" to your build command. See .bazelrc for more ' 'details.') config_info_line('mkl', 'Build with MKL support.') + config_info_line('mkl_aarch64', 'Build with oneDNN support for Aarch64.') config_info_line('monolithic', 'Config for mostly static monolithic build.') config_info_line('ngraph', 'Build with Intel nGraph support.') config_info_line('numa', 'Build with NUMA support.') diff --git a/third_party/mkl/BUILD b/third_party/mkl/BUILD index 66a2bf8ceb9..c1c2c450e34 100644 --- a/third_party/mkl/BUILD +++ b/third_party/mkl/BUILD @@ -21,6 +21,14 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "build_with_mkl_aarch64", + define_values = { + "build_with_mkl_aarch64": "true", + }, + visibility = ["//visibility:public"], +) + config_setting( name = "enable_mkl", define_values = { diff --git a/third_party/mkl/build_defs.bzl b/third_party/mkl/build_defs.bzl index 28bd262e61e..b3efa4d9ca7 100644 --- a/third_party/mkl/build_defs.bzl +++ b/third_party/mkl/build_defs.bzl @@ -91,6 +91,7 @@ def mkl_deps(): """ return select({ "@org_tensorflow//third_party/mkl:build_with_mkl": ["@mkl_dnn_v1//:mkl_dnn"], + "@org_tensorflow//third_party/mkl:build_with_mkl_aarch64": ["@mkl_dnn_v1//:mkl_dnn_aarch64"], "//conditions:default": [], }) diff --git a/third_party/mkl_dnn/mkldnn_v1.BUILD b/third_party/mkl_dnn/mkldnn_v1.BUILD index 0e6acc2fadd..3ac44913f1e 100644 --- a/third_party/mkl_dnn/mkldnn_v1.BUILD +++ b/third_party/mkl_dnn/mkldnn_v1.BUILD @@ -135,3 +135,36 @@ cc_library( ], visibility = ["//visibility:public"], ) + +cc_library( + name = "mkl_dnn_aarch64", + srcs = glob([ + "src/common/*.cpp", + "src/common/*.hpp", + "src/cpu/*.cpp", + "src/cpu/*.hpp", + "src/cpu/rnn/*.cpp", + "src/cpu/rnn/*.hpp", + "src/cpu/matmul/*.cpp", + "src/cpu/matmul/*.hpp", + "src/cpu/gemm/**/*", + ]) + [ + ":dnnl_config_h", + ":dnnl_version_h", + ], + hdrs = glob(["include/*"]), + copts = [ + "-fexceptions", + "-UUSE_MKL", + "-UUSE_CBLAS", + ], + includes = [ + "include", + "src", + "src/common", + "src/cpu", + "src/cpu/gemm", + ], + linkopts = ["-lgomp"], + visibility = ["//visibility:public"], +) diff --git a/third_party/ngraph/ngraph.BUILD b/third_party/ngraph/ngraph.BUILD index dfdd891b228..715148d38f6 100644 --- a/third_party/ngraph/ngraph.BUILD +++ b/third_party/ngraph/ngraph.BUILD @@ -118,6 +118,7 @@ cc_library( ":ngraph_headers", "@eigen_archive//:eigen", "@mkl_dnn_v1//:mkl_dnn", + "@mkl_dnn_v1//:mkl_dnn_aarch64", "@nlohmann_json_lib", "@tbb", ],