Replace MKL-DNN v0.x in vanilla TensorFlow and TFRT with DNNL/oneDNN v1.8.
PiperOrigin-RevId: 355889276
Change-Id: I0f741521c4f771915f06a46b5b850cb3e9330221
parent 0446d42b70
commit 978aab50cd
@@ -798,7 +798,7 @@ cc_library(
"//tensorflow:ios": [],
"//tensorflow:linux_ppc64le": [],
"//tensorflow:linux_s390x": [],
"//conditions:default": ["@mkl_dnn//:mkldnn_single_threaded"],
"//conditions:default": ["@mkl_dnn_v1//:dnnl_single_threaded"],
}),
)
@@ -311,29 +311,21 @@ struct Conv2DCustomBackpropInputMatMulFunctor<float> {
void operator()(OpKernelContext* ctx, const T* out_data, const T* filter_data,
const int filter_total_size, const int output_image_size,
const int dims_out_depth, T* im2col_buf) {
// Inputs are in RowMajor order, we "cheat" by swapping the LHS and RHS:
// RowMajor: C = A * B
// ColMajor: C^T = B^T * A^T
// Inputs are in RowMajor order.
// im2col = out_data * filter_data^T
// [ois x fts] = [ois x dod] * [fts x dod]^T
//
// Dimension names:
// out_image_size -> ois
// filter_total_size -> fts
// dims_out_depth -> dod
//
// RowMajor:
// im2col = out_data * filter_data^T
// [ois x fts] = [ois x dod] * [fts x dod]^T
//
// ColMajor:
// im2col^T = filter_data * out_data^T
// [fts x ois] = [fts x dod] * [dod x ois]*

const int m = filter_total_size;
const int n = output_image_size;
const int m = output_image_size;
const int n = filter_total_size;
const int k = dims_out_depth; // contraction dim

const char transposeA = 'T'; // sgemm(A) == filter_data
const char transposeB = 'N'; // sgemm(B) == out_data
const char transposeA = 'N'; // sgemm(A) == out_data
const char transposeB = 'T'; // sgemm(B) == filter_data

const int ldA = dims_out_depth;
const int ldB = dims_out_depth;
@@ -342,17 +334,17 @@ struct Conv2DCustomBackpropInputMatMulFunctor<float> {
const float alpha = 1.0;
const float beta = 0.0;

// mkldnn_sgemm code can't be instrumented with msan.
// dnnl_sgemm code can't be instrumented with msan.
ANNOTATE_MEMORY_IS_INITIALIZED(
im2col_buf, filter_total_size * output_image_size * sizeof(T));

mkldnn_status_t st =
mkldnn_sgemm(&transposeA, &transposeB, &m, &n, &k, &alpha, filter_data,
&ldA, out_data, &ldB, &beta, im2col_buf, &ldC);
dnnl_status_t st =
dnnl_sgemm(transposeA, transposeB, m, n, k, alpha, out_data, ldA,
filter_data, ldB, beta, im2col_buf, ldC);

OP_REQUIRES(
ctx, st == 0,
errors::Internal("Failed to call mkldnn_sgemm. Error code: ", st));
errors::Internal("Failed to call dnnl_sgemm. Error code: ", st));
}
};
#endif
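For context, the v1.x C API drops the pointer-based, Fortran-style argument passing: dnnl_sgemm takes scalars by value and treats its matrices as row-major. A minimal standalone sketch shaped like the im2col GEMM above (not part of this commit; the main() harness, shapes, and values are invented for illustration, and it needs to be linked against oneDNN):

// Standalone sketch (not from the commit): row-major C = A * B^T via
// dnnl_sgemm, shaped like the im2col GEMM above:
//   A = out_data    [m x k]  (ois x dod)
//   B = filter_data [n x k]  (fts x dod)
//   C = im2col_buf  [m x n]  (ois x fts)
#include <cstdio>
#include "dnnl.h"

int main() {
  const int m = 2, n = 3, k = 4;
  float A[m * k] = {1, 2, 3, 4, 5, 6, 7, 8};
  float B[n * k] = {1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0};
  float C[m * n] = {0};

  // 'N' uses A as stored, 'T' transposes B; the leading dimensions are the
  // row strides of the row-major buffers (k for A and B, n for C).
  dnnl_status_t st =
      dnnl_sgemm('N', 'T', m, n, k, 1.0f, A, k, B, k, 0.0f, C, n);
  if (st != dnnl_success) return 1;

  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) printf("%.1f ", C[i * n + j]);
    printf("\n");
  }
  return 0;
}

In the kernel above the leading dimensions play the same role: they are the row strides of out_data, filter_data, and im2col_buf in row-major layout.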
@@ -40,7 +40,7 @@ limitations under the License.
// clang-format on

#if defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL)
#include "mkldnn.h"
#include "dnnl.h"
#endif

#include "tensorflow/core/platform/dynamic_annotations.h"
@@ -125,15 +125,15 @@ struct gemm_pack_colmajor_block<Scalar, IndexType, DataMapper,

template <typename Scalar, typename IndexType, typename OutputMapper,
bool ConjugateLhs = false, bool ConjugateRhs = false>
struct mkldnn_gemm_kernel;
struct dnnl_gemm_kernel;

// mkldnn_gemm_kernel for floats defined as a thin layer on top of mkldnn_sgemm.
// dnnl_gemm_kernel for floats defined as a thin layer on top of dnnl_sgemm.
template <typename IndexType, typename OutputMapper, bool ConjugateLhs,
bool ConjugateRhs>
struct mkldnn_gemm_kernel</*Scalar*/ float, IndexType, OutputMapper,
ConjugateLhs, ConjugateRhs> {
static_assert(!ConjugateLhs, "MKL-DNN kernel doesn't support ConjugateLhs");
static_assert(!ConjugateRhs, "MKL-DNN kernel doesn't support ConjugateRhs");
struct dnnl_gemm_kernel</*Scalar*/ float, IndexType, OutputMapper, ConjugateLhs,
ConjugateRhs> {
static_assert(!ConjugateLhs, "DNNL kernel doesn't support ConjugateLhs");
static_assert(!ConjugateRhs, "DNNL kernel doesn't support ConjugateRhs");

static constexpr int kComputeStrideFromBlockDimensions = -1;
@@ -163,9 +163,16 @@ struct mkldnn_gemm_kernel</*Scalar*/ float, IndexType, OutputMapper,
ldB = ldB == kComputeStrideFromBlockDimensions ? k : ldB;
const int ldC = static_cast<int>(output.stride());

mkldnn_status_t st = mkldnn_sgemm(
&transposeA, &transposeB, &m, &n, &k, &alpha, blockA, &ldA, blockB,
&ldB, &beta, const_cast<ResScalar*>(output.data()), &ldC);
// DNNL takes row-major matrices. Our packed column-major matrices can be
// viewed as a transposed row-major matrix, i.e.,
// C_colmajor = C_rowmajor^T = (A_rowmajor * B_rowmajor)^T
// = B_rowmajor^T * A_rowmajor^T
// = B_colmajor * A_colmajor
// So we can just swap the input matrices A and B for DNNL.
// TODO(penporn): Switch to row-major packing instead.
dnnl_status_t st =
dnnl_sgemm(transposeB, transposeA, n, m, k, alpha, blockB, ldB, blockA,
ldA, beta, const_cast<ResScalar*>(output.data()), ldC);
eigen_assert(st == 0);

#if DYNAMIC_ANNOTATIONS_ENABLED == 1 || defined(MEMORY_SANITIZER)
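The swap in this hunk rests on the identity C_colmajor = (A_rowmajor * B_rowmajor)^T = B_colmajor * A_colmajor: a row-major GEMM over the swapped operands writes exactly the column-major result Eigen expects, which is also why the TODO suggests moving to row-major packing so the swap becomes unnecessary. A small self-contained check of the identity with naive loops (illustrative only, not from this change; all shapes and values are invented):

// Illustrative check (not from the commit): computing B*A with a row-major
// GEMM and reading the output buffer as column-major yields A*B.
#include <cassert>

int main() {
  const int m = 2, k = 3, n = 2;
  // A is m x k, B is k x n, both stored column-major as Eigen packs them.
  float A[m * k] = {1, 4, 2, 5, 3, 6};   // columns of [[1,2,3],[4,5,6]]
  float B[k * n] = {1, 0, 1, 0, 1, 1};   // columns of [[1,0],[0,1],[1,1]]

  // Reference: C_ref = A * B, column-major, m x n.
  float C_ref[m * n];
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i) {
      float acc = 0;
      for (int p = 0; p < k; ++p) acc += A[i + p * m] * B[p + j * k];
      C_ref[i + j * m] = acc;
    }

  // Row-major GEMM with swapped operands: the column-major A buffer read as
  // row-major is a k x m matrix (A^T), the column-major B buffer is an
  // n x k matrix (B^T); compute C_rm = B^T * A^T, n x m, row-major.
  float C_swap[n * m];
  for (int i = 0; i < n; ++i)
    for (int j = 0; j < m; ++j) {
      float acc = 0;
      for (int p = 0; p < k; ++p) acc += B[i * k + p] * A[p * m + j];
      C_swap[i * m + j] = acc;
    }

  // Reading C_swap as a column-major m x n matrix gives exactly A * B.
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i) assert(C_swap[i + j * m] == C_ref[i + j * m]);
  return 0;
}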
@@ -186,8 +193,8 @@ struct mkldnn_gemm_kernel</*Scalar*/ float, IndexType, OutputMapper,
template <typename IndexType, typename OutputMapper, bool ConjugateLhs = false,
bool ConjugateRhs = false>
struct mkldnn_gemm_s8u8s32_kernel {
static_assert(!ConjugateLhs, "MKL-DNN kernel doesn't support ConjugateLhs");
static_assert(!ConjugateRhs, "MKL-DNN kernel doesn't support ConjugateRhs");
static_assert(!ConjugateLhs, "DNNL kernel doesn't support ConjugateLhs");
static_assert(!ConjugateRhs, "DNNL kernel doesn't support ConjugateRhs");

static constexpr int kComputeStrideFromBlockDimensions = -1;
@@ -229,14 +236,20 @@ struct mkldnn_gemm_s8u8s32_kernel {
const auto* B = reinterpret_cast<const uint8_t*>(blockB);
auto* C = reinterpret_cast<int32_t*>(const_cast<ResScalar*>(output.data()));

mkldnn_status_t st =
mkldnn_gemm_s8u8s32(&transposeA, &transposeB, &offsetc, //
&m, &n, &k, //
&alpha, //
A, &ldA, &ao, //
B, &ldB, &bo, //
&beta, //
C, &ldC, &co);
// DNNL takes row-major matrices. Our packed column-major matrices can be
// viewed as a transposed row-major matrix, i.e., C_colmajor = C_rowmajor^T.
// C_colmajor = C_rowmajor^T = (A_rowmajor * B_rowmajor)^T
// = B_rowmajor^T * A_rowmajor^T
// = B_colmajor * A_colmajor
// So we can just swap the input matrices A and B for DNNL.
// TODO(penporn): Switch to row-major packing instead.
dnnl_status_t st = dnnl_gemm_u8s8s32(transposeB, transposeA, offsetc, //
n, m, k, //
alpha, //
B, ldB, bo, //
A, ldA, ao, //
beta, //
C, ldC, &co);
eigen_assert(st == 0);

#if DYNAMIC_ANNOTATIONS_ENABLED == 1 || defined(MEMORY_SANITIZER)
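The quantized path gets the same operand swap, and the swap also lines up with the v1.x API: dnnl_gemm_u8s8s32 expects the uint8 matrix first and the int8 matrix second, so the packed B (uint8) and A (int8) blocks trade places. A standalone sketch with every offset set to zero (illustrative only, not from this commit; shapes and values are invented):

// Standalone sketch (not from the commit): with all zero-point offsets set to
// zero, dnnl_gemm_u8s8s32 computes C = A_u8 * B_s8 in row-major layout with
// int32 accumulation.
#include <cstdint>
#include <cstdio>
#include "dnnl.h"

int main() {
  const int m = 2, n = 2, k = 3;
  uint8_t A[m * k] = {1, 2, 3, 4, 5, 6};   // m x k, row-major
  int8_t B[k * n] = {1, 0, 0, 1, 1, 1};    // k x n, row-major
  int32_t C[m * n] = {0};
  const int32_t co = 0;                    // single offset applied to all of C

  // 'F' requests a fixed (single) C offset; ao/bo are the A/B zero-point
  // offsets, zero here so this reduces to a plain integer GEMM.
  dnnl_status_t st =
      dnnl_gemm_u8s8s32('N', 'N', 'F', m, n, k, 1.0f, A, k, /*ao=*/0, B, n,
                        /*bo=*/0, 0.0f, C, n, &co);
  if (st != dnnl_success) return 1;

  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) printf("%d ", static_cast<int>(C[i * n + j]));
    printf("\n");
  }
  return 0;
}

The kernel above forwards the real ao, bo, and co offsets instead of zeros.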
@@ -294,7 +307,7 @@ class TensorContractionBlocking<float, float, float, StorageIndex,
if (kc_ <= 0 || mc_ <= 0 || nc_ <= 0) return;

// If we are using default Eigen gebp kernel there is no need to adjust the
// block sizes for MKL-DNN.
// block sizes for DNNL.
if (!UseCustomContractionKernels()) return;

// 2. And refine them to work well with mkldnn sgemm.
@@ -332,8 +345,8 @@ class TensorContractionBlocking<Eigen::QInt32, Eigen::QInt8, Eigen::QUInt8,

// Default Eigen block heuristics for `QInt8xQUInt8 -> QInt32` are wrong.
// Mostly because gebp_traits are not correctly defined. But we know that we
// are going to use s8u8s32_gemm from MKL-DNN, so we use float heuristics, and
// adjust them to work well with MKL-DNN.
// are going to use s8u8s32_gemm from DNNL, so we use float heuristics, and
// adjust them to work well with DNNL.
using LhsScalar = Eigen::QInt8;
using RhsScalar = Eigen::QUInt8;
using ResScalar = Eigen::QInt32;
@@ -500,7 +513,7 @@ struct GemmKernelProvider {
template <typename StorageIndex, typename OutputMapper>
struct GemmKernelProvider<float, float, float, StorageIndex, OutputMapper> {
enum { Defined = 1 };
using GemmKernel = mkldnn_gemm_kernel<float, StorageIndex, OutputMapper>;
using GemmKernel = dnnl_gemm_kernel<float, StorageIndex, OutputMapper>;
};

template <typename StorageIndex, typename OutputMapper>
@@ -113,7 +113,7 @@ TEST(EigenMkldnnTest, MkldnnGemm) {
// Compute matmul with mkldnn gemm kernel.
using OutputMapper = blas_data_mapper<Scalar, Index, ColMajor>;
using MkldnnGemmKernel =
mkldnn_gemm_kernel<Scalar, Index, OutputMapper, ColMajor>;
dnnl_gemm_kernel<Scalar, Index, OutputMapper, ColMajor>;

Tensor2d mkldnn_result(m, n);
mkldnn_result.setRandom();
@@ -316,7 +316,7 @@ def compare_results(results_with_ds,
default_tolerance = 1e-3
relaxed_tolerance = 1e-3
else:
default_tolerance = 1e-5
default_tolerance = 4e-5
relaxed_tolerance = 1e-4

def _get_compare_result_tolerance(key):
third_party/mkl_dnn/mkldnn_v1.BUILD
@@ -1,5 +1,9 @@
exports_files(["LICENSE"])

load(
"@org_tensorflow//third_party/mkl:build_defs.bzl",
"if_mkl",
)
load(
"@org_tensorflow//tensorflow:tensorflow.bzl",
"tf_openmp_copts",
@@ -30,34 +34,53 @@ _DNNL_RUNTIME_THREADPOOL = {
"#cmakedefine DNNL_GPU_RUNTIME DNNL_RUNTIME_${DNNL_GPU_RUNTIME}": "#define DNNL_GPU_RUNTIME DNNL_RUNTIME_NONE",
}

_DNNL_RUNTIME_SEQ = {
"#cmakedefine DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_${DNNL_CPU_THREADING_RUNTIME}": "#define DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_SEQ",
"#cmakedefine DNNL_CPU_RUNTIME DNNL_RUNTIME_${DNNL_CPU_RUNTIME}": "#define DNNL_CPU_RUNTIME DNNL_RUNTIME_SEQ",
"#cmakedefine DNNL_GPU_RUNTIME DNNL_RUNTIME_${DNNL_GPU_RUNTIME}": "#define DNNL_GPU_RUNTIME DNNL_RUNTIME_NONE",
}

template_rule(
name = "dnnl_config_h",
src = "include/dnnl_config.h.in",
out = "include/dnnl_config.h",
substitutions = if_mkldnn_threadpool(
_DNNL_RUNTIME_THREADPOOL,
if_false = _DNNL_RUNTIME_OMP,
),
substitutions = select({
"@org_tensorflow//third_party/mkl_dnn:build_with_mkldnn_threadpool": _DNNL_RUNTIME_THREADPOOL,
"@org_tensorflow//third_party/mkl:build_with_mkl": _DNNL_RUNTIME_OMP,
"//conditions:default": _DNNL_RUNTIME_SEQ,
}),
)

# Create the file mkldnn_version.h with MKL-DNN version numbers.
# Currently, the version numbers are hard coded here. If MKL-DNN is upgraded then
_DNNL_VERSION_1_6_4 = {
"@DNNL_VERSION_MAJOR@": "1",
"@DNNL_VERSION_MINOR@": "6",
"@DNNL_VERSION_PATCH@": "4",
"@DNNL_VERSION_HASH@": "N/A",
}

_DNNL_VERSION_1_7 = {
"@DNNL_VERSION_MAJOR@": "1",
"@DNNL_VERSION_MINOR@": "7",
"@DNNL_VERSION_PATCH@": "0",
"@DNNL_VERSION_HASH@": "N/A",
}

# Create the file dnnl_version.h with DNNL version numbers.
# Currently, the version numbers are hard coded here. If DNNL is upgraded then
# the version numbers have to be updated manually. The version numbers can be
# obtained from the PROJECT_VERSION settings in CMakeLists.txt. The variable is
# set to "version_major.version_minor.version_patch". The git hash version can
# be set to NA.
# TODO(agramesh1) Automatically get the version numbers from CMakeLists.txt.

# TODO(agramesh1): Automatically get the version numbers from CMakeLists.txt.
template_rule(
name = "dnnl_version_h",
src = "include/dnnl_version.h.in",
out = "include/dnnl_version.h",
substitutions = {
"@DNNL_VERSION_MAJOR@": "1",
"@DNNL_VERSION_MINOR@": "6",
"@DNNL_VERSION_PATCH@": "4",
"@DNNL_VERSION_HASH@": "N/A",
},
substitutions = select({
"@org_tensorflow//third_party/mkl_dnn:build_with_mkldnn_threadpool": _DNNL_VERSION_1_6_4,
"@org_tensorflow//third_party/mkl:build_with_mkl": _DNNL_VERSION_1_6_4,
"//conditions:default": _DNNL_VERSION_1_7,
}),
)

cc_library(
@@ -101,29 +124,38 @@ cc_library(
)

cc_library(
name = "mkldnn_single_threaded",
name = "dnnl_single_threaded",
srcs = glob([
"src/common/*.cpp",
"src/common/*.hpp",
"src/cpu/*.cpp",
"src/cpu/*.hpp",
"src/cpu/**/*.c",
"src/cpu/**/*.cpp",
"src/cpu/**/*.hpp",
"src/cpu/xbyak/*.h",
]) + [":dnnl_config_h"],
hdrs = glob(["include/*"]),
copts = [
"-fexceptions",
"-DMKLDNN_THR=MKLDNN_THR_SEQ", # Disables threading.
]) + [
":dnnl_config_h",
":dnnl_version_h",
],
copts = ["-fexceptions"],
includes = [
"include",
"src",
"src/common",
"src/cpu",
"src/cpu/gemm",
"src/cpu/gemm/bf16",
"src/cpu/gemm/f32",
"src/cpu/gemm/s8x8s32",
"src/cpu/rnn",
"src/cpu/x64/jit_utils",
"src/cpu/xbyak",
],
textual_hdrs = glob([
"include/*",
"src/common/*.hpp",
"src/cpu/*.hpp",
"src/cpu/**/*.h",
"src/cpu/**/*.hpp",
"src/cpu/xbyak/*.h",
]),
visibility = ["//visibility:public"],
)