Use mkldnn sgemm in Eigen contractions.
PiperOrigin-RevId: 220873819
This commit is contained in:
parent 1802b5ef23
commit 2204f80a81

Changed files:
tensorflow/core/kernels:
  BUILD
  batch_matmul_op_impl.h
  conv_grad_filter_ops.cc
  conv_grad_input_ops.cc
  conv_grad_ops_3d.cc
  eigen_contraction_kernel.h
  eigen_cuboid_convolution.h
  eigen_mkldnn.h
  eigen_mkldnn_contraction_kernel_test.cc
  eigen_spatial_convolutions.h
  gemm_functors.h
  lrn_op.cc
  matmul_op.h
  sparse_matmul_op.cc
tensorflow/python/kernel_tests
@@ -94,13 +94,13 @@ config_setting(
)

config_setting(
    # Add "--define tensorflow_eigen_mkldnn=1" to your build command to use mkldnn
    # sgemm in Eigen tensor contractions (matrix multiplications and convolutions).
    # The mkldnn kernels are generated at runtime and use avx/avx2/fma/avx512
    # based on cpu status registers (https://en.wikipedia.org/wiki/CPUID).
    name = "eigen_mkldnn",
    # Add "--define tensorflow_mkldnn_contraction_kernel=1" to your build command to use mkldnn
    # sgemm in Eigen tensor contractions (matrix multiplications and convolutions). The mkldnn
    # kernels are generated at runtime and use avx/avx2/fma/avx512 based on cpu status registers
    # (https://en.wikipedia.org/wiki/CPUID).
    name = "mkldnn_contraction_kernel",
    values = {
        "define": "tensorflow_eigen_mkldnn=1",
        "define": "tensorflow_mkldnn_contraction_kernel=1",
    },
)

@@ -554,6 +554,40 @@ cc_library(
    ],
)

# Depending on the build configuration, this target provides a custom kernel for Eigen
# tensor contractions (a small matrix multiplication kernel used to multiply together
# blocks of the original tensors).
#
# 0) Default contraction kernel is Eigen::internal::gebp_kernel.
#
# 1) --define tensorflow_mkldnn_contraction_kernel=1
#    Use the mkldnn single-threaded sgemm. The mkldnn kernels are generated at runtime and
#    use avx/avx2/fma/avx512 based on cpu status registers (https://en.wikipedia.org/wiki/CPUID).
#
# If you use `tensor.contract(other_tensor)` in your code, you must include an additional header
# to get the benefit of the custom contraction kernel:
#
#   #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#   #include "third_party/tensorflow/core/kernels/eigen_contraction_kernel.h"
#   #endif
cc_library(
    name = "eigen_contraction_kernel",
    hdrs = ["eigen_contraction_kernel.h"],
    defines = select({
        ":mkldnn_contraction_kernel": [
            "TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL",
            "TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL",
        ],
        "//conditions:default": [],
    }),
    deps = [
        "//third_party/eigen3",
    ] + select({
        ":mkldnn_contraction_kernel": ["//third_party/intel_mkl_dnn:mkldnn_single_threaded"],
        "//conditions:default": [],
    }),
)
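
For illustration only: a minimal sketch (not part of this commit) of the include pattern described in the comment above. The function name and index type are hypothetical; the contraction itself is ordinary Eigen Tensor code.

// Minimal sketch, assuming the target depends on ":eigen_contraction_kernel".
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

Eigen::Tensor<float, 2> MatMul(const Eigen::Tensor<float, 2>& lhs,
                               const Eigen::Tensor<float, 2>& rhs) {
  // Contract the second dimension of lhs with the first dimension of rhs,
  // i.e. an ordinary matrix product. With the define enabled, the contraction
  // dispatches to the mkldnn-backed TensorContractionKernel; otherwise it
  // falls back to Eigen's default gebp_kernel.
  Eigen::array<Eigen::IndexPair<int>, 1> dims = {Eigen::IndexPair<int>(1, 0)};
  return lhs.contract(rhs, dims);
}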

cc_library(
    name = "eigen_helpers",
    hdrs = [
@@ -566,20 +600,11 @@ cc_library(
        "eigen_softmax.h",
        "eigen_spatial_convolutions.h",
        "eigen_volume_patch.h",
    ] + select({
        ":eigen_mkldnn": ["eigen_mkldnn.h"],
        "//conditions:default": [],
    }),
    defines = select({
        ":eigen_mkldnn": ["EIGEN_USE_MKLDNN"],
        "//conditions:default": [],
    }),
    ],
    deps = [
        ":eigen_contraction_kernel",
        "//third_party/eigen3",
    ] + select({
        ":eigen_mkldnn": ["//third_party/intel_mkl_dnn:mkldnn_single_threaded"],
        "//conditions:default": [],
    }),
    ],
)

cc_library(
@@ -2434,15 +2459,15 @@ tf_cc_tests(
# Conditional test target generation is not supported by the "tf_cc_tests" macro
# (can't add 'select' to the srcs field, type 'select' is not iterable).
tf_cc_test(
    name = "eigen_mkldnn_test",
    name = "eigen_mkldnn_contraction_kernel_test",
    size = "small",
    srcs = select({
        ":eigen_mkldnn": ["eigen_mkldnn_test.cc"],
        ":mkldnn_contraction_kernel": ["eigen_mkldnn_contraction_kernel_test.cc"],
        "//conditions:default": [],
    }),
    tags = ["eigen_mkldnn"],
    tags = ["mkldnn_contraction_kernel"],
    deps = [
        ":eigen_helpers",
        ":eigen_contraction_kernel",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
@@ -3058,7 +3083,7 @@ tf_kernel_library(
    # <prefix>*impl.h are excluded by default from the CPU build, add explicitly.
    hdrs = ["batch_matmul_op_impl.h"],
    prefix = "batch_matmul_op",
    deps = MATH_DEPS + if_mkl_ml([
    deps = MATH_DEPS + [":eigen_contraction_kernel"] + if_mkl_ml([
        "//third_party/mkl:intel_binary_blob",
    ]),
)
@@ -3132,11 +3157,10 @@ tf_kernel_library(
        "//conditions:default": [],
    }),
    deps = MATH_DEPS + [
        ":eigen_contraction_kernel",
        ":gpu_util_hdrs",
    ] + select({
        ":xsmm": [
            "@libxsmm_archive//:xsmm_avx",
        ],
        ":xsmm": ["@libxsmm_archive//:xsmm_avx"],
        "//conditions:default": [],
    }) + mkl_deps() + if_cuda([
        "//tensorflow/core/platform/default/build_config:cublas_plugin",
@@ -3539,6 +3563,7 @@ tf_kernel_library(
        ":bounds_check",
        ":conv_2d",
        ":conv_3d",
        ":eigen_contraction_kernel",
        ":image_resizer_state",
        ":fill_functor",
        ":ops_util",
@@ -3629,6 +3654,7 @@ cc_library(
NN_DEPS = [
    ":bounds_check",
    ":conv_2d",
    ":eigen_contraction_kernel",
    ":fused_batch_norm_util_gpu",
    ":ops_util",
    ":pooling_ops",
@@ -34,6 +34,10 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA

@@ -44,6 +44,10 @@ limitations under the License.
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#if GOOGLE_CUDA
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"

@@ -43,6 +43,10 @@ limitations under the License.
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#if GOOGLE_CUDA
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"

@@ -35,6 +35,10 @@ limitations under the License.
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
using stream_executor::dnn::DimIndex;
tensorflow/core/kernels/eigen_contraction_kernel.h (new file, 234 lines)
@@ -0,0 +1,234 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_CONTRACTION_KERNEL_H_
#define TENSORFLOW_CORE_KERNELS_EIGEN_CONTRACTION_KERNEL_H_

// Depending on the build configuration, this header provides a custom kernel
// for Eigen tensor contractions (a small matrix multiplication kernel used to
// multiply together blocks of the original tensors).
//
// 1) --define tensorflow_mkldnn_contraction_kernel=1
//    Use the mkldnn single-threaded sgemm. The mkldnn kernels are generated at
//    runtime and use avx/avx2/fma/avx512 based on cpu status registers
//    (https://en.wikipedia.org/wiki/CPUID).
//
// If you use `tensor.contract(other_tensor)` in your code, you must include
// this header to get the benefit of the custom contraction kernel:
//
//   #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
//   #include "tensorflow/core/kernels/eigen_contraction_kernel.h"
//   #endif

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "third_party/intel_mkl_dnn/include/mkldnn.h"

namespace Eigen {
namespace internal {

// Enabled by build option: "--define tensorflow_mkldnn_contraction_kernel=1"
#if defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL)

template <typename Scalar, typename IndexType, typename DataMapper,
          int StorageOrder>
struct mkldnn_gemm_pack;

// mkl_gemm_pack for ColMajor storage order.
template <typename Scalar, typename IndexType, typename DataMapper>
struct mkldnn_gemm_pack<Scalar, IndexType, DataMapper,
                        /*StorageOrder*/ ColMajor> {
  typedef typename internal::packet_traits<Scalar>::type Packet;
  typedef typename DataMapper::LinearMapper LinearMapper;

  enum { PacketSize = internal::packet_traits<Scalar>::size };

  EIGEN_DONT_INLINE
  void operator()(Scalar* block, const DataMapper& data_mapper, IndexType rows,
                  IndexType cols) {
    const IndexType unrolled_rows =
        (rows / (4 * PacketSize)) * (4 * PacketSize);
    const IndexType vectorized_rows = (rows / PacketSize) * PacketSize;

    for (IndexType col = 0; col < cols; ++col) {
      LinearMapper lm = data_mapper.getLinearMapper(0, col);

      // Give the compiler a strong hint to unroll the loop.
      for (IndexType i = 0; i < unrolled_rows; i += 4 * PacketSize) {
        for (IndexType j = 0; j < 4; ++j) {
          const Packet p = lm.template loadPacket<Packet>(i + j * PacketSize);
          internal::pstoreu(block + j * PacketSize, p);
        }
        block += 4 * PacketSize;
      }

      // Process remaining rows with packets.
      for (IndexType i = unrolled_rows; i < vectorized_rows; i += PacketSize) {
        const Packet p = lm.template loadPacket<Packet>(i);
        internal::pstoreu(block, p);
        block += PacketSize;
      }

      // Finalize with coefficients.
      for (IndexType i = vectorized_rows; i < rows; ++i) {
        *block = lm(i);
        ++block;
      }
    }
  }
};

template <typename Scalar, typename IndexType, typename OutputMapper,
          bool ConjugateLhs = false, bool ConjugateRhs = false>
struct mkldnn_gemm_kernel;

// mkldnn_gemm_kernel for floats defined as a thin layer on top of mkldnn_sgemm.
template <typename IndexType, typename OutputMapper, bool ConjugateLhs,
          bool ConjugateRhs>
struct mkldnn_gemm_kernel</*Scalar*/ float, IndexType, OutputMapper,
                          ConjugateLhs, ConjugateRhs> {
  EIGEN_DONT_INLINE
  void operator()(const OutputMapper& output, const float* blockA,
                  const float* blockB, const IndexType rows,
                  const IndexType depth, const IndexType cols, float alpha) {
    static const int max_index = (std::numeric_limits<int>::max)();

    eigen_assert(max_index >= rows);
    eigen_assert(max_index >= cols);
    eigen_assert(max_index >= depth);
    eigen_assert(max_index >= output.stride());

    const int m = static_cast<int>(rows);
    const int n = static_cast<int>(cols);
    const int k = static_cast<int>(depth);

    const char transposeA = ConjugateLhs ? 'Y' : 'N';
    const char transposeB = ConjugateRhs ? 'Y' : 'N';

    const int ldA = ConjugateLhs ? k : m;
    const int ldB = ConjugateRhs ? n : k;
    const int ldC = static_cast<int>(output.stride());

    const float beta = 1.0;

    mkldnn_status_t st = mkldnn_sgemm(&transposeA, &transposeB, &m, &n, &k,
                                      &alpha, blockA, &ldA, blockB, &ldB, &beta,
                                      const_cast<float*>(output.data()), &ldC);
    eigen_assert(st == 0);
  }
};

// For mkldnn_sgemm, having the right dimensions (especially for small matrices)
// is more important than fitting all the working set in the L1/L2 caches.
// TODO(ezhulenev): Do better heuristics.
template <typename StorageIndex, int sharding_type>
class TensorContractionBlocking<float, float, float, StorageIndex,
                                sharding_type> {
  // For now mkldnn has only mkldnn_sgemm (gemm for floats).
  using Scalar = float;

  // Adjust the block sizes to work well with mkldnn kernels.

  // Multiply the default choice of block size along the M and N dimensions.
  // TODO(ezhulenev): Explore if this can work in general (kScaleM=2.0 worked
  // well in some models).
  static const float kScaleM = 1.5;
  static const float kScaleN = 1.0;

  // Mkldnn Avx/Avx2/Avx512 unroll factors are: 8/16/48.
  static const StorageIndex kUnrollM = 48;

  // Mkldnn Avx/Avx2/Avx512 unroll factors are: 6/6/8.
  static const StorageIndex kUnrollN = 24;

 public:
  TensorContractionBlocking(StorageIndex k, StorageIndex m, StorageIndex n,
                            StorageIndex num_threads = 1)
      : kc_(k), mc_(m), nc_(n) {
    // 1. Compute block sizes using default Eigen heuristics.
    if (sharding_type == ShardByCol) {
      computeProductBlockingSizes<Scalar, Scalar, 1>(kc_, mc_, nc_,
                                                     num_threads);
    } else {
      computeProductBlockingSizes<Scalar, Scalar, 1>(kc_, nc_, mc_,
                                                     num_threads);
    }

    // 2. Refine them to work well with mkldnn sgemm.
    mc_ = (std::min)(
        m, Eigen::divup(static_cast<StorageIndex>(mc_ * kScaleM), kUnrollM) *
               kUnrollM);
    nc_ = (std::min)(
        n, Eigen::divup(static_cast<StorageIndex>(nc_ * kScaleN), kUnrollN) *
               kUnrollN);

    // We split the K dimension into roughly equal slices.
    StorageIndex target_k_slices =
        (std::max)(StorageIndex(1), Eigen::divup(k, kc_));
    StorageIndex packet_size = 8;
    StorageIndex target_bk =
        Eigen::divup(k / target_k_slices, packet_size) * packet_size;
    kc_ = (std::min)(k, target_bk);
  }

  EIGEN_ALWAYS_INLINE StorageIndex kc() const { return kc_; }
  EIGEN_ALWAYS_INLINE StorageIndex mc() const { return mc_; }
  EIGEN_ALWAYS_INLINE StorageIndex nc() const { return nc_; }

 private:
  StorageIndex kc_;
  StorageIndex mc_;
  StorageIndex nc_;
};

template <typename StorageIndex, typename OutputMapper, typename LhsMapper,
          typename RhsMapper>
struct TensorContractionKernel<float, float, float, StorageIndex, OutputMapper,
                               LhsMapper, RhsMapper> {
  // For now mkldnn has only mkldnn_sgemm (gemm for floats).
  using Scalar = float;
  using Traits = typename internal::gebp_traits<Scalar, Scalar>;

  using LhsPacker = mkldnn_gemm_pack<Scalar, StorageIndex,
                                     typename LhsMapper::SubMapper, ColMajor>;
  using RhsPacker = mkldnn_gemm_pack<Scalar, StorageIndex,
                                     typename RhsMapper::SubMapper, ColMajor>;
  using GemmKernel = mkldnn_gemm_kernel<Scalar, StorageIndex, OutputMapper>;

  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void packLhs(
      Scalar* lhsBlock, const typename LhsMapper::SubMapper& data_mapper,
      const StorageIndex depth, const StorageIndex rows) {
    LhsPacker()(lhsBlock, data_mapper, rows, depth);
  }

  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void packRhs(
      Scalar* rhsBlock, const typename RhsMapper::SubMapper& data_mapper,
      const StorageIndex depth, const StorageIndex cols) {
    RhsPacker()(rhsBlock, data_mapper, depth, cols);
  }

  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void invoke(
      const OutputMapper& output_mapper, const Scalar* lhsBlock,
      const Scalar* rhsBlock, const StorageIndex rows, const StorageIndex depth,
      const StorageIndex cols, const Scalar alpha) {
    GemmKernel()(output_mapper, lhsBlock, rhsBlock, rows, depth, cols, alpha);
  }
};

#endif  // defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL)

}  // namespace internal
}  // namespace Eigen

#endif  // TENSORFLOW_CORE_KERNELS_EIGEN_CONTRACTION_KERNEL_H_
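
To make the block-size refinement in TensorContractionBlocking concrete, here is a small, self-contained sketch (not part of the commit). The divup helper mirrors Eigen::divup, and the default block size of 64 is purely illustrative; the constants kScaleM = 1.5 and kUnrollM = 48 come from the code above.

#include <algorithm>
#include <cstdio>

// Ceiling integer division, mirroring Eigen::divup.
static long divup(long a, long b) { return (a + b - 1) / b; }

int main() {
  const long m = 1000;         // full M dimension of the contraction
  const long default_mc = 64;  // block size from Eigen's default heuristics (illustrative)
  const double kScaleM = 1.5;
  const long kUnrollM = 48;

  // Scale the default block, then round up to a multiple of the mkldnn unroll
  // factor, but never exceed the full dimension.
  const long scaled = static_cast<long>(default_mc * kScaleM);      // 96
  const long mc = std::min(m, divup(scaled, kUnrollM) * kUnrollM);  // 96
  std::printf("refined mc = %ld\n", mc);
  return 0;
}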

@@ -19,6 +19,10 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/eigen_volume_patch.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

namespace Eigen {

namespace internal {
@@ -51,11 +55,10 @@ namespace internal {
// col - index of the extracted patch (in code: patchIndex)
//       patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions)
//
template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
          DenseIndex Cols, typename ArgType, typename Device, typename Scalar_,
          typename Index, typename nocontract_t, typename contract_t, int Side,
          int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered,
          int Alignment>
template <typename NewDimension, Index Planes, Index Rows, Index Cols,
          typename ArgType, typename Device, typename Scalar_, typename Index,
          typename nocontract_t, typename contract_t, int Side, int packet_size,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionInputMapper<
    Scalar_, Index, Side,
    TensorEvaluator<const TensorReshapingOp<NewDimension,
@@ -681,11 +684,10 @@ class TensorContractionInputMapper<
  const TensorEvaluator<ArgType, Device> m_impl;
};

template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
          DenseIndex Cols, typename ArgType, typename Device, typename Scalar,
          typename Index, typename nocontract_t, typename contract_t, int Side,
          int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered,
          int Alignment>
template <typename NewDimension, Index Planes, Index Rows, Index Cols,
          typename ArgType, typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, int Side, int packet_size,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionSubMapper<
    Scalar, Index, Side,
    TensorEvaluator<const TensorReshapingOp<NewDimension,
@@ -993,11 +995,11 @@ class TensorContractionSubMapper<
// *) nr - number of registers along the 'n' dimension.
//    See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix
//    Multiplication" paper.
template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
          DenseIndex Cols, typename ArgType, typename Device, typename Scalar,
          typename Index, typename nocontract_t, typename contract_t,
          int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered,
          int Alignment, int nr>
template <typename NewDimension, Index Planes, Index Rows, Index Cols,
          typename ArgType, typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, int packet_size,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
          int nr>
struct gemm_pack_rhs<
    Scalar, Index,
    TensorContractionSubMapper<
@@ -1172,11 +1174,10 @@ struct gemm_pack_rhs<

// Template specialization for packet_size = 2. We must special-case packet
// blocks with nr > packet_size, e.g. PacketBlock<Packet2d, 4>.
template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
          DenseIndex Cols, typename ArgType, typename Device, typename Scalar,
          typename Index, typename nocontract_t, typename contract_t,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
          int nr>
template <typename NewDimension, Index Planes, Index Rows, Index Cols,
          typename ArgType, typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, bool inner_dim_contiguous,
          bool inner_dim_reordered, int Alignment, int nr>
struct gemm_pack_rhs<
    Scalar, Index,
    TensorContractionSubMapper<
@@ -1353,11 +1354,10 @@ struct gemm_pack_rhs<
};

// Special case for non-vectorized types such as float16 (packet_size = 1).
template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
          DenseIndex Cols, typename ArgType, typename Device, typename Scalar,
          typename Index, typename nocontract_t, typename contract_t,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
          int nr>
template <typename NewDimension, Index Planes, Index Rows, Index Cols,
          typename ArgType, typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, bool inner_dim_contiguous,
          bool inner_dim_reordered, int Alignment, int nr>
struct gemm_pack_rhs<
    Scalar, Index,
    TensorContractionSubMapper<
@@ -1427,6 +1427,153 @@ struct gemm_pack_rhs<
  }
};

#if defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL)
// Arrange a block of the right input matrix (in our case it's always a "virtual
// matrix" constructed from extracted volume patches) in contiguous memory.
//
// Mkldnn doesn't require Lhs/Rhs blocks to be packed in any specific format, so
// this is basically the same as taking a slice of the matrix. Knowing
// properties of the original patch op we can do it more efficiently than the
// default mkldnn_gemm_pack.
template <typename NewDimension, Index Planes, Index Rows, Index Cols,
          typename ArgType, typename Device, typename Scalar,
          typename StorageIndex, typename nocontract_t, typename contract_t,
          int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered,
          int Alignment>
struct mkldnn_gemm_pack<
    Scalar, StorageIndex,
    TensorContractionSubMapper<
        Scalar, StorageIndex, Rhs,
        TensorEvaluator<const TensorReshapingOp<
                            NewDimension, const TensorVolumePatchOp<
                                              Planes, Rows, Cols, ArgType> >,
                        Device>,
        nocontract_t, contract_t, packet_size, inner_dim_contiguous,
        inner_dim_reordered, Alignment>,
    ColMajor> {
  typedef TensorContractionSubMapper<
      Scalar, StorageIndex, Rhs,
      TensorEvaluator<const TensorReshapingOp<
                          NewDimension, const TensorVolumePatchOp<
                                            Planes, Rows, Cols, ArgType> >,
                      Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      SubMapper;

  typedef SubMapper DataMapper;
  typedef typename packet_traits<Scalar>::type Packet;

  EIGEN_DONT_INLINE
  void operator()(Scalar* block, const DataMapper& rhs, StorageIndex rows,
                  StorageIndex cols) {
    const bool standard_patches = !rhs.nonStandardPatches();

    const StorageIndex vectorized_rows = (rows / packet_size) * packet_size;

    if (standard_patches && rhs.patchDepth() % packet_size == 0) {
      // If patches are standard and patch depth is a multiple of the packet
      // size, then we can guarantee that a single packet does not span across
      // multiple patch rows or columns, and we can read it directly from the
      // TensorPatchOp argument.

      // Give vectorized_rows the name used in all other gemm_pack_rhs above.
      const Index peeled_k = vectorized_rows;

      const Index start_col = rhs.colOffset();
      const Index max_col = rhs.maxCol(peeled_k);

      for (StorageIndex col = 0; col < cols; ++col) {
        SubMapper lm = rhs.getLinearMapper(0, col);

        Index k = 0;
        for (Index c = start_col; c < max_col; ++c) {
          eigen_assert(k <= peeled_k);

          const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
          const Index max_row = rhs.maxRow(peeled_k, c);
          const bool pad_col = lm.padCol(c);

          for (Index r = start_row; r < max_row; ++r) {
            eigen_assert(k <= peeled_k);

            const Index start_plane =
                ((c == start_col) && (r == start_row)) ? rhs.planeOffset() : 0;
            const Index max_plane = rhs.maxPlane(peeled_k, c, r);
            const bool pad_row = pad_col || lm.padRow(r);

            for (Index p = start_plane; p < max_plane; ++p) {
              eigen_assert(k <= peeled_k);

              const Index start_depth =
                  ((c == start_col) && (r == start_row) && (p == start_plane))
                      ? rhs.depthOffset()
                      : 0;
              const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
              eigen_assert((max_depth - start_depth) % packet_size == 0);

              const bool pad = pad_col || pad_row || lm.padPlane(p);
              const Index base_idx = lm.baseIndex(p, r, c);

              for (Index d = start_depth; d < max_depth; d += packet_size) {
                eigen_assert(k < peeled_k);
                const Packet packet = pad ? pset1<Packet>(Scalar(0))
                                          : rhs.packetNoPadding(d, base_idx);
                internal::pstoreu(block, packet);
                block += packet_size;
                k += packet_size;
              }
            }
          }
        }

        // The loop above should fill peeled_k elements.
        eigen_assert(peeled_k == k);

        // Fill remaining elements using loadCoeffStandard.
        for (; k < rows; ++k) {
          *block = lm.loadCoeffStandard(k);
          ++block;
        }
      }

    } else if (standard_patches) {
      // A single packet can span across multiple patch rows or columns, so we
      // have to go through the slower path, which will fall back to building a
      // packet from coefficients.

      for (StorageIndex col = 0; col < cols; ++col) {
        SubMapper lm = rhs.getLinearMapper(0, col);

        for (StorageIndex i = 0; i < vectorized_rows; i += packet_size) {
          const Packet p = lm.loadPacketStandard(i);
          internal::pstoreu(block, p);
          block += packet_size;
        }

        // Finalize with coefficients.
        for (StorageIndex i = vectorized_rows; i < rows; ++i) {
          *block = lm.loadCoeffStandard(i);
          ++block;
        }
      }

    } else {
      // With non-standard patches we don't do any vectorized loads.
      // TODO(ezhulenev): It doesn't look like we should completely give up
      // on packets. Make this code path faster!
      for (StorageIndex col = 0; col < cols; ++col) {
        SubMapper lm = rhs.getLinearMapper(0, col);
        for (StorageIndex i = 0; i < rows; ++i) {
          *block = lm(i);
          ++block;
        }
      }
    }
  }
};
#endif  // defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL)
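
A rough sketch of how the packing specialization above chooses between its three code paths; the enum and helper below are hypothetical and only restate the branch conditions from the operator() above.

// Hypothetical illustration of the path selection in mkldnn_gemm_pack for
// patch-based right-hand sides; standard_patches, patch_depth and packet_size
// mirror the quantities used by the real code.
enum class PackPath { kDirectPacketReads, kStandardPacketLoads, kScalarFallback };

PackPath ChoosePackPath(bool standard_patches, int patch_depth, int packet_size) {
  if (standard_patches && patch_depth % packet_size == 0) {
    // A packet never straddles patch rows/columns/planes, so packets can be
    // read straight from the underlying patch op (packetNoPadding).
    return PackPath::kDirectPacketReads;
  }
  if (standard_patches) {
    // Packets may cross patch boundaries; use loadPacketStandard, which can
    // assemble a packet from individual coefficients when needed.
    return PackPath::kStandardPacketLoads;
  }
  // Non-standard patches: copy coefficient by coefficient.
  return PackPath::kScalarFallback;
}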

}  // namespace internal

/** CuboidConvolution
@@ -1478,9 +1625,8 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
        const DSizes<typename internal::traits<Input>::Index, 2>,
        const Kernel> > > >::type
CuboidConvolution(const Input& input, const Kernel& kernel,
                  const DenseIndex stridePlanes = 1,
                  const DenseIndex strideRows = 1,
                  const DenseIndex strideCols = 1,
                  const Index stridePlanes = 1, const Index strideRows = 1,
                  const Index strideCols = 1,
                  const PaddingType padding_type = PADDING_SAME) {
  typedef typename internal::traits<Input>::Index TensorIndex;
  TensorRef<Tensor<typename internal::traits<Input>::Scalar,
@@ -1,123 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_MKLDNN_H_
#define TENSORFLOW_CORE_KERNELS_EIGEN_MKLDNN_H_

// Support for Mkldnn sgemm kernel in Eigen/Tensor contractions:
//
// 1. Prepare packed Lhs/Rhs blocks from tensor expressions using
//    DataMapper (see TensorContractionInputMapper).
// 2. Invoke gemm kernel with packed blocks (replacement for default
//    gebp_kernel).

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "third_party/intel_mkl_dnn/include/mkldnn.h"

namespace Eigen {
namespace internal {

template <typename Scalar, typename IndexType, typename DataMapper,
          int StorageOrder>
struct mkldnn_gemm_pack;

// mkl_gemm_pack for ColMajor storage order.
template <typename Scalar, typename IndexType, typename DataMapper>
struct mkldnn_gemm_pack<Scalar, IndexType, DataMapper,
                        /*StorageOrder*/ ColMajor> {
  typedef typename internal::packet_traits<Scalar>::type Packet;
  typedef typename DataMapper::LinearMapper LinearMapper;

  enum { PacketSize = internal::packet_traits<Scalar>::size };

  EIGEN_DONT_INLINE
  void operator()(Scalar *block, const DataMapper &data_mapper, IndexType rows,
                  IndexType cols) {
    const IndexType unrolled_rows =
        (rows / (4 * PacketSize)) * (4 * PacketSize);
    const IndexType vectorized_rows = (rows / PacketSize) * PacketSize;

    for (IndexType col = 0; col < cols; ++col) {
      LinearMapper lm = data_mapper.getLinearMapper(0, col);

      // Give compiler a strong possibility to unroll the loop.
      for (IndexType i = 0; i < unrolled_rows; i += 4 * PacketSize) {
        for (IndexType j = 0; j < 4; ++j) {
          const Packet p = lm.loadPacket(i + j * PacketSize);
          internal::pstoreu(block + j * PacketSize, p);
        }
        block += 4 * PacketSize;
      }

      // Process remaining rows with packets.
      for (IndexType i = unrolled_rows; i < vectorized_rows; i += PacketSize) {
        const Packet p = lm.loadPacket(i);
        internal::pstoreu(block, p);
        block += PacketSize;
      }

      // Finalize with coefficients.
      for (IndexType i = vectorized_rows; i < rows; ++i) {
        *block = lm(i);
        ++block;
      }
    }
  }
};

template <typename Scalar, typename IndexType, typename OutputMapper,
          bool ConjugateLhs = false, bool ConjugateRhs = false>
struct mkldnn_gemm_kernel;

// mkldnn_gemm_kernel for floats defined as a thin layer on top of mkldnn_sgemm.
template <typename IndexType, typename OutputMapper, bool ConjugateLhs,
          bool ConjugateRhs>
struct mkldnn_gemm_kernel</*Scalar*/ float, IndexType, OutputMapper,
                          ConjugateLhs, ConjugateRhs> {
  EIGEN_DONT_INLINE
  void operator()(const OutputMapper &output, const float *blockA,
                  const float *blockB, const IndexType rows,
                  const IndexType depth, const IndexType cols, float alpha) {
    static const int max_index = (std::numeric_limits<int>::max)();

    eigen_assert(max_index >= rows);
    eigen_assert(max_index >= cols);
    eigen_assert(max_index >= depth);
    eigen_assert(max_index >= output.stride());

    const int m = static_cast<int>(rows);
    const int n = static_cast<int>(cols);
    const int k = static_cast<int>(depth);

    const char transposeA = ConjugateLhs ? 'Y' : 'N';
    const char transposeB = ConjugateRhs ? 'Y' : 'N';

    const int ldA = ConjugateLhs ? k : m;
    const int ldB = ConjugateRhs ? n : k;
    const int ldC = static_cast<int>(output.stride());

    const float beta = 1.0;

    mkldnn_status_t st = mkldnn_sgemm(&transposeA, &transposeB, &m, &n, &k,
                                      &alpha, blockA, &ldA, blockB, &ldB, &beta,
                                      const_cast<float *>(output.data()), &ldC);
    eigen_assert(st == 0);
  }
};

}  // namespace internal
}  // namespace Eigen

#endif  // TENSORFLOW_CORE_KERNELS_EIGEN_MKLDNN_H_
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/eigen_mkldnn.h"
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#include "tensorflow/core/platform/test.h"

namespace Eigen {
@@ -18,6 +18,10 @@ limitations under the License.

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

namespace Eigen {

namespace internal {
@@ -52,8 +56,8 @@ namespace internal {
//
// TODO(ezhulenev): Consolidate this part of the code with the image patch
// extraction code since they are both very similar.
template <typename NewDimension, DenseIndex Rows, DenseIndex Cols,
          typename ArgType, typename Device, typename Scalar_, typename Index,
template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
          typename Device, typename Scalar_, typename Index,
          typename nocontract_t, typename contract_t, int Side, int packet_size,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionInputMapper<
@@ -511,8 +515,8 @@ class TensorContractionInputMapper<
  const TensorEvaluator<ArgType, Device> m_impl;
};

template <typename NewDimension, DenseIndex Rows, DenseIndex Cols,
          typename ArgType, typename Device, typename Scalar, typename Index,
template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
          typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, int Side, int packet_size,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionSubMapper<
@@ -770,8 +774,8 @@ class TensorContractionSubMapper<
// *) nr - number of registers along the 'n' dimension.
//    See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix
//    Multiplication" paper.
template <typename NewDimension, DenseIndex Rows, DenseIndex Cols,
          typename ArgType, typename Device, typename Scalar, typename Index,
template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
          typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, int packet_size,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
          int nr>
@@ -931,8 +935,8 @@ struct gemm_pack_rhs<

// Template specialization for packet_size = 2. We must special-case packet
// blocks with nr > packet_size, e.g. PacketBlock<Packet2d, 4>.
template <typename NewDimension, DenseIndex Rows, DenseIndex Cols,
          typename ArgType, typename Device, typename Scalar, typename Index,
template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
          typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, bool inner_dim_contiguous,
          bool inner_dim_reordered, int Alignment, int nr>
struct gemm_pack_rhs<
@@ -1097,8 +1101,8 @@ struct gemm_pack_rhs<
};

// Special case for non-vectorized types such as float16.
template <typename NewDimension, DenseIndex Rows, DenseIndex Cols,
          typename ArgType, typename Device, typename Scalar, typename Index,
template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
          typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, bool inner_dim_contiguous,
          bool inner_dim_reordered, int Alignment, int nr>
struct gemm_pack_rhs<
@@ -1170,6 +1174,141 @@ struct gemm_pack_rhs<
  }
};

#if defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL)
// Arrange a block of the right input matrix (in our case it's always a
// "virtual matrix" constructed from extracted image patches) in contiguous
// memory.
//
// Mkldnn doesn't require Lhs/Rhs blocks to be packed in any specific format, so
// this is basically the same as taking a slice of the matrix. Knowing
// properties of the original patch op we can do it more efficiently than the
// default mkldnn_gemm_pack.
template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
          typename Device, typename Scalar, typename StorageIndex,
          typename nocontract_t, typename contract_t, int packet_size,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
struct mkldnn_gemm_pack<
    Scalar, StorageIndex,
    TensorContractionSubMapper<
        Scalar, StorageIndex, Rhs,
        TensorEvaluator<
            const TensorReshapingOp<
                NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
            Device>,
        nocontract_t, contract_t, packet_size, inner_dim_contiguous,
        inner_dim_reordered, Alignment>,
    ColMajor> {
  typedef TensorContractionSubMapper<
      Scalar, StorageIndex, Rhs,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      SubMapper;

  typedef SubMapper DataMapper;
  typedef typename packet_traits<Scalar>::type Packet;

  EIGEN_DONT_INLINE
  void operator()(Scalar* block, const DataMapper& rhs, StorageIndex rows,
                  StorageIndex cols) {
    const bool standard_patches = !rhs.nonStandardPatches();

    const StorageIndex vectorized_rows = (rows / packet_size) * packet_size;

    if (standard_patches && rhs.patchDepth() % packet_size == 0) {
      // If patches are standard and patch depth is a multiple of the packet
      // size, then we can guarantee that a single packet does not span across
      // multiple patch rows or columns, and we can read it directly from the
      // TensorPatchOp argument.

      // Give vectorized_rows the name used in all other gemm_pack_rhs above.
      const Index peeled_k = vectorized_rows;

      const Index start_col = rhs.colOffset();
      const Index max_col = rhs.maxCol(peeled_k);

      for (StorageIndex col = 0; col < cols; ++col) {
        SubMapper lm = rhs.getLinearMapper(0, col);

        Index k = 0;
        for (Index c = start_col; c < max_col; ++c) {
          eigen_assert(k <= peeled_k);

          const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
          const Index max_row = rhs.maxRow(peeled_k, c);
          const bool pad_col = lm.padCol(c);

          for (Index r = start_row; r < max_row; ++r) {
            eigen_assert(k <= peeled_k);

            const Index start_depth =
                ((c == start_col) && (r == start_row)) ? rhs.depthOffset() : 0;
            const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
            eigen_assert((max_depth - start_depth) % packet_size == 0);

            const bool pad = pad_col || lm.padRow(r);
            const Index base_idx = lm.baseIndex(r, c);

            for (Index d = start_depth; d < max_depth; d += packet_size) {
              eigen_assert(k < peeled_k);
              const Packet p = pad ? pset1<Packet>(Scalar(0))
                                   : rhs.packetNoPadding(d, base_idx);
              internal::pstoreu(block, p);
              block += packet_size;
              k += packet_size;
            }
          }
        }

        // The loop above should fill peeled_k elements.
        eigen_assert(peeled_k == k);

        // Fill remaining elements using loadCoeffStandard.
        for (; k < rows; ++k) {
          *block = lm.loadCoeffStandard(k);
          ++block;
        }
      }
    } else if (standard_patches) {
      // A single packet can span across multiple patch rows or columns, so we
      // have to go through the slower path, which will fall back to building a
      // packet from coefficients.

      for (StorageIndex col = 0; col < cols; ++col) {
        SubMapper lm = rhs.getLinearMapper(0, col);

        for (StorageIndex i = 0; i < vectorized_rows; i += packet_size) {
          const Packet p = lm.loadPacketStandard(i);
          internal::pstoreu(block, p);
          block += packet_size;
        }

        // Finalize with coefficients.
        for (StorageIndex i = vectorized_rows; i < rows; ++i) {
          *block = lm.loadCoeffStandard(i);
          ++block;
        }
      }

    } else {
      // With non-standard patches we don't do any vectorized loads.
      // TODO(ezhulenev): It doesn't look like we should completely give up
      // on packets. Make this code path faster!
      for (StorageIndex col = 0; col < cols; ++col) {
        SubMapper lm = rhs.getLinearMapper(0, col);
        for (StorageIndex i = 0; i < rows; ++i) {
          *block = lm(i);
          ++block;
        }
      }
    }
  }
};
#endif  // defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL)

}  // end namespace internal

/** SpatialConvolution
@@ -1226,11 +1365,10 @@ EIGEN_DEVICE_FUNC
        const DSizes<typename internal::traits<Input>::Index, 2>,
        const Kernel> > > >::type
SpatialConvolution(const Input& input, const Kernel& kernel,
                   const DenseIndex row_stride = 1,
                   const DenseIndex col_stride = 1,
                   const Index row_stride = 1, const Index col_stride = 1,
                   const PaddingType padding_type = PADDING_SAME,
                   const DenseIndex row_in_stride = 1,
                   const DenseIndex col_in_stride = 1) {
                   const Index row_in_stride = 1,
                   const Index col_in_stride = 1) {
  typedef typename internal::traits<Input>::Index TensorIndex;
  TensorRef<Tensor<typename internal::traits<Input>::Scalar,
                   internal::traits<Input>::NumDimensions,
@@ -1260,9 +1398,9 @@ EIGEN_DEVICE_FUNC
  const TensorIndex kernelCols =
      isColMajor ? kern.dimensions()[3] : kern.dimensions()[0];

  const DenseIndex kernelRowsEff =
  const Index kernelRowsEff =
      kernelRows + (kernelRows - 1) * (row_in_stride - 1);
  const DenseIndex kernelColsEff =
  const Index kernelColsEff =
      kernelCols + (kernelCols - 1) * (col_in_stride - 1);

  array<IndexPair<TensorIndex>, 1> contract_dims;
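
For orientation, a minimal, hypothetical usage sketch of the SpatialConvolution helper whose signature changes above (not part of this commit). The ColMajor dimension ordering in the comments is an assumption for this example.

// Assumed layout (ColMajor): input is (in_depth, input_rows, input_cols, batch)
// and kernel is (out_depth, in_depth, kernel_rows, kernel_cols).
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/eigen_spatial_convolutions.h"

Eigen::Tensor<float, 4> Conv2D(const Eigen::Tensor<float, 4>& input,
                               const Eigen::Tensor<float, 4>& kernel) {
  // Stride 1 in both spatial dimensions, SAME padding. Internally this lowers
  // to image-patch extraction plus a tensor contraction, which is where the
  // mkldnn-backed contraction kernel (when enabled) takes effect.
  return Eigen::SpatialConvolution(input, kernel, /*row_stride=*/1,
                                   /*col_stride=*/1, Eigen::PADDING_SAME);
}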

@@ -36,6 +36,10 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

// Apple provides an optimized BLAS library that is better than Eigen for their
// devices, so use that if possible.
#if defined(__APPLE__) && defined(USE_GEMM_FOR_CONV)

@@ -26,6 +26,10 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/util/work_sharder.h"
#endif

@@ -21,6 +21,10 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/hash/hash.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

namespace tensorflow {
namespace functor {

@@ -44,6 +44,10 @@ limitations under the License.
#include "include/libxsmm_spmdm.h"
#endif

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

namespace tensorflow {
namespace {
@@ -2888,7 +2888,7 @@ cuda_py_test(
        "//tensorflow/python:nn_grad",
        "//tensorflow/python:nn_ops",
    ],
    shard_count = 20,
    shard_count = 30,
)

cuda_py_test(

@@ -638,6 +638,30 @@ class Conv3DTest(test.TestCase):
        padding="SAME",
        test_input=False)

  # Test the fast path in gemm_pack_rhs/mkldnn_gemm_pack, used when the channel
  # dimension is a multiple of the packet size.
  def testInputGradientValidPaddingStrideOneFastPath(self):
    self.ConstructAndTestGradient(
        batch=2,
        input_shape=(3, 5, 4),
        filter_shape=(2, 2, 2),
        in_depth=8,
        out_depth=2,
        stride=1,
        padding="VALID",
        test_input=True)

  def testFilterGradientValidPaddingStrideOneFastPath(self):
    self.ConstructAndTestGradient(
        batch=2,
        input_shape=(4, 6, 5),
        filter_shape=(2, 2, 2),
        in_depth=8,
        out_depth=2,
        stride=1,
        padding="VALID",
        test_input=False)

  # Testing for backprops
  def _RunAndVerifyBackprop(self, input_sizes, filter_sizes, output_sizes,
                            strides, dilations, padding, data_format, use_gpu,