Addressing comments; Making DNNL 1.2 default
This commit is contained in:
parent
b96c0010fa
commit
accb5cda49
1
.bazelrc
1
.bazelrc
@ -136,6 +136,7 @@ build --host_java_toolchain=//third_party/toolchains/java:tf_java_toolchain
|
|||||||
# environment variable "TF_MKL_ROOT" every time before build.
|
# environment variable "TF_MKL_ROOT" every time before build.
|
||||||
build:mkl --define=build_with_mkl=true --define=enable_mkl=true
|
build:mkl --define=build_with_mkl=true --define=enable_mkl=true
|
||||||
build:mkl --define=tensorflow_mkldnn_contraction_kernel=0
|
build:mkl --define=tensorflow_mkldnn_contraction_kernel=0
|
||||||
|
build:mkl --define=build_with_mkl_dnn_v1_only=true
|
||||||
build:mkl -c opt
|
build:mkl -c opt
|
||||||
|
|
||||||
# This config option is used to enable MKL-DNN open source library only,
|
# This config option is used to enable MKL-DNN open source library only,
|
||||||
|
@ -26,9 +26,6 @@ limitations under the License.
|
|||||||
#define GET_FLAG(bn_flag) static_cast<int>(BN_FLAGS::bn_flag)
|
#define GET_FLAG(bn_flag) static_cast<int>(BN_FLAGS::bn_flag)
|
||||||
#define IS_SET(cflag) (context_.flags & GET_FLAG(cflag))
|
#define IS_SET(cflag) (context_.flags & GET_FLAG(cflag))
|
||||||
|
|
||||||
#define GET_FLAG(bn_flag) static_cast<int>(BN_FLAGS::bn_flag)
|
|
||||||
#define IS_SET(cflag) (context_.flags & GET_FLAG(cflag))
|
|
||||||
|
|
||||||
using mkldnn::batch_normalization_backward;
|
using mkldnn::batch_normalization_backward;
|
||||||
using mkldnn::batch_normalization_forward;
|
using mkldnn::batch_normalization_forward;
|
||||||
using mkldnn::prop_kind;
|
using mkldnn::prop_kind;
|
||||||
|
@ -86,10 +86,11 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
|
|||||||
const int k = src_tf_shape.dim_size(dim_pair[0]);
|
const int k = src_tf_shape.dim_size(dim_pair[0]);
|
||||||
const int channel = weight_tf_shape.dim_size(1 - dim_pair[1]);
|
const int channel = weight_tf_shape.dim_size(1 - dim_pair[1]);
|
||||||
|
|
||||||
OP_REQUIRES(ctx, k == weight_tf_shape.dim_size(dim_pair[1]),
|
OP_REQUIRES(
|
||||||
errors::InvalidArgument("Matrix size-incompatible: In[0]: ",
|
ctx, k == weight_tf_shape.dim_size(dim_pair[1]),
|
||||||
src_tf_shape.DebugString(), ", In[1]: ",
|
errors::InvalidArgument(
|
||||||
weight_tf_shape.DebugString()));
|
"Matrix size-incompatible: In[0]: ", src_tf_shape.DebugString(),
|
||||||
|
", In[1]: ", weight_tf_shape.DebugString()));
|
||||||
OP_REQUIRES(ctx, bias_tensor.shape().dim_size(0) == channel,
|
OP_REQUIRES(ctx, bias_tensor.shape().dim_size(0) == channel,
|
||||||
errors::InvalidArgument(
|
errors::InvalidArgument(
|
||||||
"Must provide as many biases as the channel size: ",
|
"Must provide as many biases as the channel size: ",
|
||||||
@ -200,9 +201,9 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
|
|||||||
// Execute fused matmul op.
|
// Execute fused matmul op.
|
||||||
matmul_prim->Execute(src_data, weight_data, bias_data, dst_data);
|
matmul_prim->Execute(src_data, weight_data, bias_data, dst_data);
|
||||||
} catch (mkldnn::error& e) {
|
} catch (mkldnn::error& e) {
|
||||||
string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
|
string error_msg = "Status: " + std::to_string(e.status) +
|
||||||
string(e.message) + ", in file " + string(__FILE__) +
|
", message: " + string(e.message) + ", in file " +
|
||||||
":" + std::to_string(__LINE__);
|
string(__FILE__) + ":" + std::to_string(__LINE__);
|
||||||
OP_REQUIRES_OK(
|
OP_REQUIRES_OK(
|
||||||
ctx, errors::Aborted("Operation received an exception:", error_msg));
|
ctx, errors::Aborted("Operation received an exception:", error_msg));
|
||||||
}
|
}
|
||||||
|
@ -175,11 +175,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
|
|||||||
tf_http_archive(
|
tf_http_archive(
|
||||||
name = "mkl_dnn_v1",
|
name = "mkl_dnn_v1",
|
||||||
build_file = clean_dep("//third_party/mkl_dnn:mkldnn_v1.BUILD"),
|
build_file = clean_dep("//third_party/mkl_dnn:mkldnn_v1.BUILD"),
|
||||||
sha256 = "27fd9da9720c452852f1226581e7914efcf74e1ff898468fdcbe1813528831ba",
|
sha256 = "30979a09753e8e35d942446c3778c9f0eba543acf2fb0282af8b9c89355d0ddf",
|
||||||
strip_prefix = "mkl-dnn-1.0",
|
strip_prefix = "mkl-dnn-1.2",
|
||||||
urls = [
|
urls = [
|
||||||
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/intel/mkl-dnn/archive/v1.0.tar.gz",
|
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/intel/mkl-dnn/archive/v1.2.tar.gz",
|
||||||
"https://github.com/intel/mkl-dnn/archive/v1.0.tar.gz",
|
"https://github.com/intel/mkl-dnn/archive/v1.2.tar.gz",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
31
third_party/mkl_dnn/mkldnn_v1.BUILD
vendored
31
third_party/mkl_dnn/mkldnn_v1.BUILD
vendored
@ -19,12 +19,13 @@ config_setting(
|
|||||||
)
|
)
|
||||||
|
|
||||||
template_rule(
|
template_rule(
|
||||||
name = "mkldnn_config_h",
|
name = "dnnl_config_h",
|
||||||
src = "include/mkldnn_config.h.in",
|
src = "include/dnnl_config.h.in",
|
||||||
out = "include/mkldnn_config.h",
|
out = "include/dnnl_config.h",
|
||||||
substitutions = {
|
substitutions = {
|
||||||
"#cmakedefine MKLDNN_CPU_RUNTIME MKLDNN_RUNTIME_${MKLDNN_CPU_RUNTIME_CURRENT}": "#define MKLDNN_CPU_RUNTIME MKLDNN_RUNTIME_OMP",
|
"#cmakedefine DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_${DNNL_CPU_THREADING_RUNTIME}": "#define DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_OMP",
|
||||||
"#cmakedefine MKLDNN_GPU_RUNTIME MKLDNN_RUNTIME_${MKLDNN_GPU_RUNTIME}": "#define MKLDNN_GPU_RUNTIME MKLDNN_RUNTIME_NONE",
|
"#cmakedefine DNNL_CPU_RUNTIME DNNL_RUNTIME_${DNNL_CPU_RUNTIME}": "#define DNNL_CPU_RUNTIME DNNL_RUNTIME_OMP",
|
||||||
|
"#cmakedefine DNNL_GPU_RUNTIME DNNL_RUNTIME_${DNNL_GPU_RUNTIME}": "#define DNNL_GPU_RUNTIME DNNL_RUNTIME_NONE",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -37,14 +38,14 @@ template_rule(
|
|||||||
# TODO(agramesh1) Automatically get the version numbers from CMakeLists.txt.
|
# TODO(agramesh1) Automatically get the version numbers from CMakeLists.txt.
|
||||||
|
|
||||||
template_rule(
|
template_rule(
|
||||||
name = "mkldnn_version_h",
|
name = "dnnl_version_h",
|
||||||
src = "include/mkldnn_version.h.in",
|
src = "include/dnnl_version.h.in",
|
||||||
out = "include/mkldnn_version.h",
|
out = "include/dnnl_version.h",
|
||||||
substitutions = {
|
substitutions = {
|
||||||
"@MKLDNN_VERSION_MAJOR@": "1",
|
"@DNNL_VERSION_MAJOR@": "1",
|
||||||
"@MKLDNN_VERSION_MINOR@": "0",
|
"@DNNL_VERSION_MINOR@": "2",
|
||||||
"@MKLDNN_VERSION_PATCH@": "0",
|
"@DNNL_VERSION_PATCH@": "0",
|
||||||
"@MKLDNN_VERSION_HASH@": "N/A",
|
"@DNNL_VERSION_HASH@": "N/A",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -59,8 +60,8 @@ cc_library(
|
|||||||
"src/cpu/**/*.hpp",
|
"src/cpu/**/*.hpp",
|
||||||
"src/cpu/xbyak/*.h",
|
"src/cpu/xbyak/*.h",
|
||||||
]) + if_mkl_v1_open_source_only([
|
]) + if_mkl_v1_open_source_only([
|
||||||
":mkldnn_config_h",
|
":dnnl_config_h",
|
||||||
]) + [":mkldnn_version_h"],
|
]) + [":dnnl_version_h"],
|
||||||
hdrs = glob(["include/*"]),
|
hdrs = glob(["include/*"]),
|
||||||
copts = [
|
copts = [
|
||||||
"-fexceptions",
|
"-fexceptions",
|
||||||
@ -117,7 +118,7 @@ cc_library(
|
|||||||
"src/cpu/**/*.cpp",
|
"src/cpu/**/*.cpp",
|
||||||
"src/cpu/**/*.hpp",
|
"src/cpu/**/*.hpp",
|
||||||
"src/cpu/xbyak/*.h",
|
"src/cpu/xbyak/*.h",
|
||||||
]) + [":mkldnn_config_h"],
|
]) + [":dnnl_config_h"],
|
||||||
hdrs = glob(["include/*"]),
|
hdrs = glob(["include/*"]),
|
||||||
copts = [
|
copts = [
|
||||||
"-fexceptions",
|
"-fexceptions",
|
||||||
|
Loading…
Reference in New Issue
Block a user