Tensorflow/Eigen/Mkldnn integration.

Add mkldnn_pack and mkldnn_gemm templates to be called later from TensorContraction kernel.

PiperOrigin-RevId: 219826801
This commit is contained in:
Eugene Zhulenev 2018-11-02 11:00:09 -07:00 committed by TensorFlower Gardener
parent ac472732bf
commit 8741e20227
3 changed files with 311 additions and 2 deletions

View File

@ -93,6 +93,17 @@ config_setting(
},
)
config_setting(
# Add "--define tensorflow_eigen_mkldnn=1" to your build command to use mkldnn
# sgemm in Eigen tensor contractions (matrix multiplications and convolutions).
# The mkldnn kernels are generated at runtime and use avx/avx2/fma/avx512
# based on cpu status registers (https://en.wikipedia.org/wiki/CPUID).
name = "eigen_mkldnn",
values = {
"define": "tensorflow_eigen_mkldnn=1",
},
)
# Public support libraries ----------------------------------------------------
cc_library(
@ -555,10 +566,20 @@ cc_library(
"eigen_softmax.h",
"eigen_spatial_convolutions.h",
"eigen_volume_patch.h",
],
] + select({
":eigen_mkldnn": ["eigen_mkldnn.h"],
"//conditions:default": [],
}),
defines = select({
":eigen_mkldnn": ["EIGEN_USE_MKLDNN"],
"//conditions:default": [],
}),
deps = [
"//third_party/eigen3",
],
] + select({
":eigen_mkldnn": ["//third_party/intel_mkl_dnn:mkldnn_single_threaded"],
"//conditions:default": [],
}),
)
cc_library(
@ -2402,6 +2423,23 @@ tf_cc_tests(
],
)
# Conditional test target generation is not supported by the "tf_cc_tests" macro
# (can't add 'select' to the srcs field, type 'select' is not iterable).
tf_cc_test(
name = "eigen_mkldnn_test",
size = "small",
srcs = select({
":eigen_mkldnn": ["eigen_mkldnn_test.cc"],
"//conditions:default": [],
}),
tags = ["eigen_mkldnn"],
deps = [
":eigen_helpers",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
cc_library(
name = "eigen_benchmark",
testonly = 1,

View File

@ -0,0 +1,123 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_MKLDNN_H_
#define TENSORFLOW_CORE_KERNELS_EIGEN_MKLDNN_H_
// Support for Mkldnn sgemm kernel in Eigen/Tensor contractions:
//
// 1. Prepare packed Lhs/Rhs blocks from tensor expressions using
// DataMapper (see TensorContractionInputMapper).
// 2. Invoke gemm kernel with packed blocks (replacement for default
// gebp_kernel).
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "third_party/intel_mkl_dnn/include/mkldnn.h"
namespace Eigen {
namespace internal {
template <typename Scalar, typename IndexType, typename DataMapper,
int StorageOrder>
struct mkldnn_gemm_pack;
// mkl_gemm_pack for ColMajor storage order.
template <typename Scalar, typename IndexType, typename DataMapper>
struct mkldnn_gemm_pack<Scalar, IndexType, DataMapper,
/*StorageOrder*/ ColMajor> {
typedef typename internal::packet_traits<Scalar>::type Packet;
typedef typename DataMapper::LinearMapper LinearMapper;
enum { PacketSize = internal::packet_traits<Scalar>::size };
EIGEN_DONT_INLINE
void operator()(Scalar *block, const DataMapper &data_mapper, IndexType rows,
IndexType cols) {
const IndexType unrolled_rows =
(rows / (4 * PacketSize)) * (4 * PacketSize);
const IndexType vectorized_rows = (rows / PacketSize) * PacketSize;
for (IndexType col = 0; col < cols; ++col) {
LinearMapper lm = data_mapper.getLinearMapper(0, col);
// Give compiler a strong possibility to unroll the loop.
for (IndexType i = 0; i < unrolled_rows; i += 4 * PacketSize) {
for (IndexType j = 0; j < 4; ++j) {
const Packet p = lm.loadPacket(i + j * PacketSize);
internal::pstoreu(block + j * PacketSize, p);
}
block += 4 * PacketSize;
}
// Process remaining rows with packets.
for (IndexType i = unrolled_rows; i < vectorized_rows; i += PacketSize) {
const Packet p = lm.loadPacket(i);
internal::pstoreu(block, p);
block += PacketSize;
}
// Finalize with coefficients.
for (IndexType i = vectorized_rows; i < rows; ++i) {
*block = lm(i);
++block;
}
}
}
};
template <typename Scalar, typename IndexType, typename OutputMapper,
bool ConjugateLhs = false, bool ConjugateRhs = false>
struct mkldnn_gemm_kernel;
// mkldnn_gemm_kernel for floats defined as a thin layer on top of mkldnn_sgemm.
template <typename IndexType, typename OutputMapper, bool ConjugateLhs,
bool ConjugateRhs>
struct mkldnn_gemm_kernel</*Scalar*/ float, IndexType, OutputMapper,
ConjugateLhs, ConjugateRhs> {
EIGEN_DONT_INLINE
void operator()(const OutputMapper &output, const float *blockA,
const float *blockB, const IndexType rows,
const IndexType depth, const IndexType cols, float alpha) {
static const int max_index = (std::numeric_limits<int>::max)();
eigen_assert(max_index >= rows);
eigen_assert(max_index >= cols);
eigen_assert(max_index >= depth);
eigen_assert(max_index >= output.stride());
const int m = static_cast<int>(rows);
const int n = static_cast<int>(cols);
const int k = static_cast<int>(depth);
const char transposeA = ConjugateLhs ? 'Y' : 'N';
const char transposeB = ConjugateRhs ? 'Y' : 'N';
const int ldA = ConjugateLhs ? k : m;
const int ldB = ConjugateRhs ? n : k;
const int ldC = static_cast<int>(output.stride());
const float beta = 1.0;
mkldnn_status_t st = mkldnn_sgemm(&transposeA, &transposeB, &m, &n, &k,
&alpha, blockA, &ldA, blockB, &ldB, &beta,
const_cast<float *>(output.data()), &ldC);
eigen_assert(st == 0);
}
};
} // namespace internal
} // namespace Eigen
#endif // TENSORFLOW_CORE_KERNELS_EIGEN_MKLDNN_H_

View File

@ -0,0 +1,148 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/eigen_mkldnn.h"
#include "tensorflow/core/platform/test.h"
namespace Eigen {
namespace internal {
namespace {
template <typename Index, int NumDims>
Eigen::array<Index, NumDims> RandomDims(int min_dim = 1, int max_dim = 20) {
Eigen::array<Index, NumDims> dims;
for (int i = 0; i < NumDims; ++i) {
dims[i] = internal::random<int>(min_dim, max_dim);
}
return dims;
}
} // namespace
using Scalar = float;
using Index = Eigen::Index;
TEST(EigenMkldnnTest, MkldnnPack) {
// Packing with mkldnn_gemm_pack is the same as taking a slice of 2
// dimensional Tensor.
// Mkldnn pack and gemm are used only in Tensor contractions, and it's
// guaranteed that Tensors will have ColMajor layout.
static const int Options = ColMajor;
using DataMapper = blas_data_mapper<Scalar, Index, ColMajor>;
using MkldnnGemmPack = mkldnn_gemm_pack<Scalar, Index, DataMapper, ColMajor>;
using Tensor2d = Tensor<Scalar, 2, Options, Index>;
Eigen::array<Index, 2> dims = RandomDims<Index, 2>(1, 500);
// Create a tensor initialized with random data.
Tensor2d src(dims);
src.setRandom();
// Pick a random slice of src tensor.
Eigen::array<Index, 2> slice_start = RandomDims<Index, 2>(0, 250);
Eigen::array<Index, 2> slice_size = RandomDims<Index, 2>(100, 500);
// Make sure that slice start + size do not overflow tensor dims.
for (int i = 0; i < 2; ++i) {
slice_start[i] = numext::mini(dims[i] - 1, slice_start[i]);
slice_size[i] = numext::mini(slice_size[i], dims[i] - slice_start[i]);
}
// Prepare tensors for packing and slicing results.
Tensor2d pack_dst(slice_size[0], slice_size[1]);
Tensor2d slice_dst(slice_size[0], slice_size[1]);
// Pack memory using mkldnn_gemm_pack.
DataMapper data_mapper(src.data(), dims[0]);
MkldnnGemmPack gemm_pack;
gemm_pack(pack_dst.data(),
data_mapper.getSubMapper(slice_start[0], slice_start[1]),
slice_size[0], slice_size[1]);
// Slice the source tensor.
slice_dst = src.slice(slice_start, slice_size);
// Verify that dst tensors are equal.
EXPECT_EQ(pack_dst.dimensions().TotalSize(),
slice_dst.dimensions().TotalSize());
for (size_t i = 0; i < pack_dst.dimensions().TotalSize(); ++i) {
Scalar packed = pack_dst.coeff(i);
Scalar sliced = slice_dst.coeff(i);
EXPECT_EQ(packed, sliced);
}
}
TEST(EigenMkldnnTest, MkldnnGemm) {
// Mkldnn pack and gemm are used only in Tensor contractions, and it's
// guaranteed that Tensors will have ColMajor layout.
static const int Options = ColMajor;
using Tensor2d = Tensor<Scalar, 2, Options, Index>;
int m = internal::random<int>(1, 100);
int n = internal::random<int>(1, 100);
int k = internal::random<int>(1, 100);
Tensor2d lhs(m, k);
lhs.setRandom();
Tensor2d rhs(k, n);
rhs.setRandom();
// Compute matmul with mkldnn gemm kernel.
using OutputMapper = blas_data_mapper<Scalar, Index, ColMajor>;
using MkldnnGemmKernel =
mkldnn_gemm_kernel<Scalar, Index, OutputMapper, ColMajor>;
Tensor2d mkldnn_result(m, n);
mkldnn_result.setZero();
OutputMapper output_mapper(mkldnn_result.data(), m);
MkldnnGemmKernel gemm_kernel;
gemm_kernel(output_mapper, lhs.data(), rhs.data(), m, k, n, /*alpha=*/1.0);
// Compute matmul with Eigen::Matrix.
using Matrix = Eigen::Matrix<Scalar, Dynamic, Dynamic, ColMajor>;
using MatrixMap = Map<Eigen::Matrix<Scalar, Dynamic, Dynamic, ColMajor>>;
MatrixMap lhs_mat(lhs.data(), m, k);
MatrixMap rhs_mat(rhs.data(), k, n);
Matrix matmul_result(m, n);
matmul_result.setZero();
matmul_result = lhs_mat * rhs_mat;
// Verify that results are equal.
for (Index i = 0; i < m * n; ++i) {
Scalar gemm = mkldnn_result(i);
Scalar matmul = matmul_result(i % m, i / m);
Scalar delta = std::abs(gemm - matmul);
// NOTE(rmlarsen): Compute proper forward error bound.
Scalar sum = Scalar(0.0);
for (int k1 = 0; k1 < k; ++k1) {
sum += std::abs(lhs_mat(i % m, k1) * rhs_mat(k1, i / m));
}
Scalar epsilon = std::numeric_limits<Scalar>::epsilon();
Scalar upper_bound = Scalar(1.01) * epsilon * k * sum;
EXPECT_LE(delta, upper_bound);
}
}
} // namespace internal
} // namespace Eigen