Tensorflow/Eigen/Mkldnn integration.
Add mkldnn_pack and mkldnn_gemm templates to be called later from TensorContraction kernel. PiperOrigin-RevId: 219826801
This commit is contained in:
parent
ac472732bf
commit
8741e20227
@ -93,6 +93,17 @@ config_setting(
|
||||
},
|
||||
)
|
||||
|
||||
config_setting(
|
||||
# Add "--define tensorflow_eigen_mkldnn=1" to your build command to use mkldnn
|
||||
# sgemm in Eigen tensor contractions (matrix multiplications and convolutions).
|
||||
# The mkldnn kernels are generated at runtime and use avx/avx2/fma/avx512
|
||||
# based on cpu status registers (https://en.wikipedia.org/wiki/CPUID).
|
||||
name = "eigen_mkldnn",
|
||||
values = {
|
||||
"define": "tensorflow_eigen_mkldnn=1",
|
||||
},
|
||||
)
|
||||
|
||||
# Public support libraries ----------------------------------------------------
|
||||
|
||||
cc_library(
|
||||
@ -555,10 +566,20 @@ cc_library(
|
||||
"eigen_softmax.h",
|
||||
"eigen_spatial_convolutions.h",
|
||||
"eigen_volume_patch.h",
|
||||
],
|
||||
] + select({
|
||||
":eigen_mkldnn": ["eigen_mkldnn.h"],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
defines = select({
|
||||
":eigen_mkldnn": ["EIGEN_USE_MKLDNN"],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
deps = [
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
] + select({
|
||||
":eigen_mkldnn": ["//third_party/intel_mkl_dnn:mkldnn_single_threaded"],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
@ -2402,6 +2423,23 @@ tf_cc_tests(
|
||||
],
|
||||
)
|
||||
|
||||
# Conditional test target generation is not supported by the "tf_cc_tests" macro
|
||||
# (can't add 'select' to the srcs field, type 'select' is not iterable).
|
||||
tf_cc_test(
|
||||
name = "eigen_mkldnn_test",
|
||||
size = "small",
|
||||
srcs = select({
|
||||
":eigen_mkldnn": ["eigen_mkldnn_test.cc"],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
tags = ["eigen_mkldnn"],
|
||||
deps = [
|
||||
":eigen_helpers",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "eigen_benchmark",
|
||||
testonly = 1,
|
||||
|
123
tensorflow/core/kernels/eigen_mkldnn.h
Normal file
123
tensorflow/core/kernels/eigen_mkldnn.h
Normal file
@ -0,0 +1,123 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_MKLDNN_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_EIGEN_MKLDNN_H_
|
||||
|
||||
// Support for Mkldnn sgemm kernel in Eigen/Tensor contractions:
|
||||
//
|
||||
// 1. Prepare packed Lhs/Rhs blocks from tensor expressions using
|
||||
// DataMapper (see TensorContractionInputMapper).
|
||||
// 2. Invoke gemm kernel with packed blocks (replacement for default
|
||||
// gebp_kernel).
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "third_party/intel_mkl_dnn/include/mkldnn.h"
|
||||
|
||||
namespace Eigen {
|
||||
namespace internal {
|
||||
|
||||
template <typename Scalar, typename IndexType, typename DataMapper,
|
||||
int StorageOrder>
|
||||
struct mkldnn_gemm_pack;
|
||||
|
||||
// mkl_gemm_pack for ColMajor storage order.
|
||||
template <typename Scalar, typename IndexType, typename DataMapper>
|
||||
struct mkldnn_gemm_pack<Scalar, IndexType, DataMapper,
|
||||
/*StorageOrder*/ ColMajor> {
|
||||
typedef typename internal::packet_traits<Scalar>::type Packet;
|
||||
typedef typename DataMapper::LinearMapper LinearMapper;
|
||||
|
||||
enum { PacketSize = internal::packet_traits<Scalar>::size };
|
||||
|
||||
EIGEN_DONT_INLINE
|
||||
void operator()(Scalar *block, const DataMapper &data_mapper, IndexType rows,
|
||||
IndexType cols) {
|
||||
const IndexType unrolled_rows =
|
||||
(rows / (4 * PacketSize)) * (4 * PacketSize);
|
||||
const IndexType vectorized_rows = (rows / PacketSize) * PacketSize;
|
||||
|
||||
for (IndexType col = 0; col < cols; ++col) {
|
||||
LinearMapper lm = data_mapper.getLinearMapper(0, col);
|
||||
|
||||
// Give compiler a strong possibility to unroll the loop.
|
||||
for (IndexType i = 0; i < unrolled_rows; i += 4 * PacketSize) {
|
||||
for (IndexType j = 0; j < 4; ++j) {
|
||||
const Packet p = lm.loadPacket(i + j * PacketSize);
|
||||
internal::pstoreu(block + j * PacketSize, p);
|
||||
}
|
||||
block += 4 * PacketSize;
|
||||
}
|
||||
|
||||
// Process remaining rows with packets.
|
||||
for (IndexType i = unrolled_rows; i < vectorized_rows; i += PacketSize) {
|
||||
const Packet p = lm.loadPacket(i);
|
||||
internal::pstoreu(block, p);
|
||||
block += PacketSize;
|
||||
}
|
||||
|
||||
// Finalize with coefficients.
|
||||
for (IndexType i = vectorized_rows; i < rows; ++i) {
|
||||
*block = lm(i);
|
||||
++block;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Scalar, typename IndexType, typename OutputMapper,
|
||||
bool ConjugateLhs = false, bool ConjugateRhs = false>
|
||||
struct mkldnn_gemm_kernel;
|
||||
|
||||
// mkldnn_gemm_kernel for floats defined as a thin layer on top of mkldnn_sgemm.
|
||||
template <typename IndexType, typename OutputMapper, bool ConjugateLhs,
|
||||
bool ConjugateRhs>
|
||||
struct mkldnn_gemm_kernel</*Scalar*/ float, IndexType, OutputMapper,
|
||||
ConjugateLhs, ConjugateRhs> {
|
||||
EIGEN_DONT_INLINE
|
||||
void operator()(const OutputMapper &output, const float *blockA,
|
||||
const float *blockB, const IndexType rows,
|
||||
const IndexType depth, const IndexType cols, float alpha) {
|
||||
static const int max_index = (std::numeric_limits<int>::max)();
|
||||
|
||||
eigen_assert(max_index >= rows);
|
||||
eigen_assert(max_index >= cols);
|
||||
eigen_assert(max_index >= depth);
|
||||
eigen_assert(max_index >= output.stride());
|
||||
|
||||
const int m = static_cast<int>(rows);
|
||||
const int n = static_cast<int>(cols);
|
||||
const int k = static_cast<int>(depth);
|
||||
|
||||
const char transposeA = ConjugateLhs ? 'Y' : 'N';
|
||||
const char transposeB = ConjugateRhs ? 'Y' : 'N';
|
||||
|
||||
const int ldA = ConjugateLhs ? k : m;
|
||||
const int ldB = ConjugateRhs ? n : k;
|
||||
const int ldC = static_cast<int>(output.stride());
|
||||
|
||||
const float beta = 1.0;
|
||||
|
||||
mkldnn_status_t st = mkldnn_sgemm(&transposeA, &transposeB, &m, &n, &k,
|
||||
&alpha, blockA, &ldA, blockB, &ldB, &beta,
|
||||
const_cast<float *>(output.data()), &ldC);
|
||||
eigen_assert(st == 0);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
} // namespace Eigen
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_EIGEN_MKLDNN_H_
|
148
tensorflow/core/kernels/eigen_mkldnn_test.cc
Normal file
148
tensorflow/core/kernels/eigen_mkldnn_test.cc
Normal file
@ -0,0 +1,148 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/eigen_mkldnn.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace Eigen {
|
||||
namespace internal {
|
||||
|
||||
namespace {
|
||||
template <typename Index, int NumDims>
|
||||
Eigen::array<Index, NumDims> RandomDims(int min_dim = 1, int max_dim = 20) {
|
||||
Eigen::array<Index, NumDims> dims;
|
||||
for (int i = 0; i < NumDims; ++i) {
|
||||
dims[i] = internal::random<int>(min_dim, max_dim);
|
||||
}
|
||||
return dims;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
using Scalar = float;
|
||||
using Index = Eigen::Index;
|
||||
|
||||
TEST(EigenMkldnnTest, MkldnnPack) {
|
||||
// Packing with mkldnn_gemm_pack is the same as taking a slice of 2
|
||||
// dimensional Tensor.
|
||||
|
||||
// Mkldnn pack and gemm are used only in Tensor contractions, and it's
|
||||
// guaranteed that Tensors will have ColMajor layout.
|
||||
static const int Options = ColMajor;
|
||||
|
||||
using DataMapper = blas_data_mapper<Scalar, Index, ColMajor>;
|
||||
using MkldnnGemmPack = mkldnn_gemm_pack<Scalar, Index, DataMapper, ColMajor>;
|
||||
using Tensor2d = Tensor<Scalar, 2, Options, Index>;
|
||||
|
||||
Eigen::array<Index, 2> dims = RandomDims<Index, 2>(1, 500);
|
||||
|
||||
// Create a tensor initialized with random data.
|
||||
Tensor2d src(dims);
|
||||
src.setRandom();
|
||||
|
||||
// Pick a random slice of src tensor.
|
||||
Eigen::array<Index, 2> slice_start = RandomDims<Index, 2>(0, 250);
|
||||
Eigen::array<Index, 2> slice_size = RandomDims<Index, 2>(100, 500);
|
||||
|
||||
// Make sure that slice start + size do not overflow tensor dims.
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
slice_start[i] = numext::mini(dims[i] - 1, slice_start[i]);
|
||||
slice_size[i] = numext::mini(slice_size[i], dims[i] - slice_start[i]);
|
||||
}
|
||||
|
||||
// Prepare tensors for packing and slicing results.
|
||||
Tensor2d pack_dst(slice_size[0], slice_size[1]);
|
||||
Tensor2d slice_dst(slice_size[0], slice_size[1]);
|
||||
|
||||
// Pack memory using mkldnn_gemm_pack.
|
||||
DataMapper data_mapper(src.data(), dims[0]);
|
||||
MkldnnGemmPack gemm_pack;
|
||||
gemm_pack(pack_dst.data(),
|
||||
data_mapper.getSubMapper(slice_start[0], slice_start[1]),
|
||||
slice_size[0], slice_size[1]);
|
||||
|
||||
// Slice the source tensor.
|
||||
slice_dst = src.slice(slice_start, slice_size);
|
||||
|
||||
// Verify that dst tensors are equal.
|
||||
EXPECT_EQ(pack_dst.dimensions().TotalSize(),
|
||||
slice_dst.dimensions().TotalSize());
|
||||
for (size_t i = 0; i < pack_dst.dimensions().TotalSize(); ++i) {
|
||||
Scalar packed = pack_dst.coeff(i);
|
||||
Scalar sliced = slice_dst.coeff(i);
|
||||
EXPECT_EQ(packed, sliced);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(EigenMkldnnTest, MkldnnGemm) {
|
||||
// Mkldnn pack and gemm are used only in Tensor contractions, and it's
|
||||
// guaranteed that Tensors will have ColMajor layout.
|
||||
static const int Options = ColMajor;
|
||||
|
||||
using Tensor2d = Tensor<Scalar, 2, Options, Index>;
|
||||
|
||||
int m = internal::random<int>(1, 100);
|
||||
int n = internal::random<int>(1, 100);
|
||||
int k = internal::random<int>(1, 100);
|
||||
|
||||
Tensor2d lhs(m, k);
|
||||
lhs.setRandom();
|
||||
|
||||
Tensor2d rhs(k, n);
|
||||
rhs.setRandom();
|
||||
|
||||
// Compute matmul with mkldnn gemm kernel.
|
||||
using OutputMapper = blas_data_mapper<Scalar, Index, ColMajor>;
|
||||
using MkldnnGemmKernel =
|
||||
mkldnn_gemm_kernel<Scalar, Index, OutputMapper, ColMajor>;
|
||||
|
||||
Tensor2d mkldnn_result(m, n);
|
||||
mkldnn_result.setZero();
|
||||
OutputMapper output_mapper(mkldnn_result.data(), m);
|
||||
|
||||
MkldnnGemmKernel gemm_kernel;
|
||||
gemm_kernel(output_mapper, lhs.data(), rhs.data(), m, k, n, /*alpha=*/1.0);
|
||||
|
||||
// Compute matmul with Eigen::Matrix.
|
||||
using Matrix = Eigen::Matrix<Scalar, Dynamic, Dynamic, ColMajor>;
|
||||
using MatrixMap = Map<Eigen::Matrix<Scalar, Dynamic, Dynamic, ColMajor>>;
|
||||
|
||||
MatrixMap lhs_mat(lhs.data(), m, k);
|
||||
MatrixMap rhs_mat(rhs.data(), k, n);
|
||||
|
||||
Matrix matmul_result(m, n);
|
||||
matmul_result.setZero();
|
||||
matmul_result = lhs_mat * rhs_mat;
|
||||
|
||||
// Verify that results are equal.
|
||||
for (Index i = 0; i < m * n; ++i) {
|
||||
Scalar gemm = mkldnn_result(i);
|
||||
Scalar matmul = matmul_result(i % m, i / m);
|
||||
|
||||
Scalar delta = std::abs(gemm - matmul);
|
||||
|
||||
// NOTE(rmlarsen): Compute proper forward error bound.
|
||||
Scalar sum = Scalar(0.0);
|
||||
for (int k1 = 0; k1 < k; ++k1) {
|
||||
sum += std::abs(lhs_mat(i % m, k1) * rhs_mat(k1, i / m));
|
||||
}
|
||||
Scalar epsilon = std::numeric_limits<Scalar>::epsilon();
|
||||
Scalar upper_bound = Scalar(1.01) * epsilon * k * sum;
|
||||
|
||||
EXPECT_LE(delta, upper_bound);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
} // namespace Eigen
|
Loading…
x
Reference in New Issue
Block a user