diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index b855fe5436b..2fbc7d75ec6 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -94,13 +94,13 @@ config_setting(
 )
 
 config_setting(
-    # Add "--define tensorflow_eigen_mkldnn=1" to your build command to use mkldnn
-    # sgemm in Eigen tensor contractions (matrix multiplications and convolutions).
-    # The mkldnn kernels are generated at runtime and use avx/avx2/fma/avx512
-    # based on cpu status registers (https://en.wikipedia.org/wiki/CPUID).
-    name = "eigen_mkldnn",
+    # Add "--define tensorflow_mkldnn_contraction_kernel=1" to your build command to use mkldnn
+    # sgemm in Eigen tensor contractions (matrix multiplications and convolutions). The mkldnn
+    # kernels are generated at runtime and use avx/avx2/fma/avx512 based on cpu status registers
+    # (https://en.wikipedia.org/wiki/CPUID).
+    name = "mkldnn_contraction_kernel",
     values = {
-        "define": "tensorflow_eigen_mkldnn=1",
+        "define": "tensorflow_mkldnn_contraction_kernel=1",
     },
 )
 
@@ -554,6 +554,40 @@ cc_library(
     ],
 )
 
+# Depending on the build configuration this target provides a custom kernel for
+# Eigen tensor contractions (the small matrix multiplication kernel used to
+# multiply together blocks of the original tensors).
+#
+# 0) The default contraction kernel is Eigen::internal::gebp_kernel.
+#
+# 1) --define tensorflow_mkldnn_contraction_kernel=1
+#    Use the mkldnn single-threaded sgemm. The mkldnn kernels are generated at runtime
+#    and use avx/avx2/fma/avx512 based on cpu status registers (https://en.wikipedia.org/wiki/CPUID).
+#
+# If you use `tensor.contract(other_tensor)` in your code, you must include an
+# additional header to get the benefit of the custom contraction kernel:
+#
+#   #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
+#   #include "third_party/tensorflow/core/kernels/eigen_contraction_kernel.h"
+#   #endif
+cc_library(
+    name = "eigen_contraction_kernel",
+    hdrs = ["eigen_contraction_kernel.h"],
+    defines = select({
+        ":mkldnn_contraction_kernel": [
+            "TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL",
+            "TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL",
+        ],
+        "//conditions:default": [],
+    }),
+    deps = [
+        "//third_party/eigen3",
+    ] + select({
+        ":mkldnn_contraction_kernel": ["//third_party/intel_mkl_dnn:mkldnn_single_threaded"],
+        "//conditions:default": [],
+    }),
+)
+
 cc_library(
     name = "eigen_helpers",
     hdrs = [
@@ -566,20 +600,11 @@ cc_library(
         "eigen_softmax.h",
         "eigen_spatial_convolutions.h",
         "eigen_volume_patch.h",
-    ] + select({
-        ":eigen_mkldnn": ["eigen_mkldnn.h"],
-        "//conditions:default": [],
-    }),
-    defines = select({
-        ":eigen_mkldnn": ["EIGEN_USE_MKLDNN"],
-        "//conditions:default": [],
-    }),
+    ],
     deps = [
+        ":eigen_contraction_kernel",
        "//third_party/eigen3",
-    ] + select({
-        ":eigen_mkldnn": ["//third_party/intel_mkl_dnn:mkldnn_single_threaded"],
-        "//conditions:default": [],
-    }),
+    ],
 )
 
 cc_library(
@@ -2434,15 +2459,15 @@ tf_cc_tests(
 # Conditional test target generation is not supported by the "tf_cc_tests" macro
 # (can't add 'select' to the srcs field, type 'select' is not iterable).
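To see the documented contract end to end, here is a minimal consumer sketch (hypothetical function and shapes; the include guard and build define come straight from the BUILD comments above):

    #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

    // Only targets that depend on :eigen_contraction_kernel receive the
    // define, so the include must stay behind the guard.
    #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
    #include "tensorflow/core/kernels/eigen_contraction_kernel.h"
    #endif

    // Contracts a [m, k] tensor with a [k, n] tensor along the inner
    // dimension. Built with --define tensorflow_mkldnn_contraction_kernel=1
    // this dispatches to the mkldnn-based kernel; otherwise it uses the
    // default Eigen::internal::gebp_kernel.
    Eigen::Tensor<float, 2> MatMul(const Eigen::Tensor<float, 2>& lhs,
                                   const Eigen::Tensor<float, 2>& rhs) {
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dims = {
          Eigen::IndexPair<Eigen::DenseIndex>(1, 0)};
      return lhs.contract(rhs, dims);
    }
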
tf_cc_test(
-    name = "eigen_mkldnn_test",
+    name = "eigen_mkldnn_contraction_kernel_test",
     size = "small",
     srcs = select({
-        ":eigen_mkldnn": ["eigen_mkldnn_test.cc"],
+        ":mkldnn_contraction_kernel": ["eigen_mkldnn_contraction_kernel_test.cc"],
         "//conditions:default": [],
     }),
-    tags = ["eigen_mkldnn"],
+    tags = ["mkldnn_contraction_kernel"],
     deps = [
-        ":eigen_helpers",
+        ":eigen_contraction_kernel",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
     ],
@@ -3058,7 +3083,7 @@ tf_kernel_library(
     # *impl.h are excluded by default from the CPU build, add explicitly.
     hdrs = ["batch_matmul_op_impl.h"],
     prefix = "batch_matmul_op",
-    deps = MATH_DEPS + if_mkl_ml([
+    deps = MATH_DEPS + [":eigen_contraction_kernel"] + if_mkl_ml([
         "//third_party/mkl:intel_binary_blob",
     ]),
 )
@@ -3132,11 +3157,10 @@ tf_kernel_library(
         "//conditions:default": [],
     }),
     deps = MATH_DEPS + [
+        ":eigen_contraction_kernel",
         ":gpu_util_hdrs",
     ] + select({
-        ":xsmm": [
-            "@libxsmm_archive//:xsmm_avx",
-        ],
+        ":xsmm": ["@libxsmm_archive//:xsmm_avx"],
         "//conditions:default": [],
     }) + mkl_deps() + if_cuda([
         "//tensorflow/core/platform/default/build_config:cublas_plugin",
@@ -3539,6 +3563,7 @@ tf_kernel_library(
         ":bounds_check",
         ":conv_2d",
         ":conv_3d",
+        ":eigen_contraction_kernel",
         ":image_resizer_state",
         ":fill_functor",
         ":ops_util",
@@ -3629,6 +3654,7 @@ cc_library(
 NN_DEPS = [
     ":bounds_check",
     ":conv_2d",
+    ":eigen_contraction_kernel",
     ":fused_batch_norm_util_gpu",
     ":ops_util",
     ":pooling_ops",
diff --git a/tensorflow/core/kernels/batch_matmul_op_impl.h b/tensorflow/core/kernels/batch_matmul_op_impl.h
index 766713a338c..43539ac908f 100644
--- a/tensorflow/core/kernels/batch_matmul_op_impl.h
+++ b/tensorflow/core/kernels/batch_matmul_op_impl.h
@@ -34,6 +34,10 @@ limitations under the License.
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/work_sharder.h"
 
+#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
+#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
+#endif
+
 #if GOOGLE_CUDA
 #include "tensorflow/core/platform/stream_executor.h"
 #endif  // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc
index 9e86a16b66d..bc30da40991 100644
--- a/tensorflow/core/kernels/conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc
@@ -44,6 +44,10 @@ limitations under the License.
 #include "tensorflow/core/util/use_cudnn.h"
 #include "tensorflow/core/util/work_sharder.h"
 
+#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
+#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
+#endif
+
 #if GOOGLE_CUDA
 #include "tensorflow/core/kernels/conv_ops_gpu.h"
 #include "tensorflow/core/platform/stream_executor.h"
diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc
index 43bb5ea56c9..e06af15f2fc 100644
--- a/tensorflow/core/kernels/conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_input_ops.cc
@@ -43,6 +43,10 @@ limitations under the License.
 #include "tensorflow/core/util/use_cudnn.h"
 #include "tensorflow/core/util/work_sharder.h"
 
+#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
+#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
+#endif
+
 #if GOOGLE_CUDA
 #include "tensorflow/core/kernels/conv_ops_gpu.h"
 #include "tensorflow/core/platform/stream_executor.h"
diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc
index bab91f5e861..f62c60d255d 100644
--- a/tensorflow/core/kernels/conv_grad_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc
@@ -35,6 +35,10 @@ limitations under the License.
 #include "tensorflow/core/util/use_cudnn.h"
 #include "tensorflow/core/util/work_sharder.h"
 
+#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
+#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
+#endif
+
 #if GOOGLE_CUDA
 #include "tensorflow/core/platform/stream_executor.h"
 using stream_executor::dnn::DimIndex;
diff --git a/tensorflow/core/kernels/eigen_contraction_kernel.h b/tensorflow/core/kernels/eigen_contraction_kernel.h
new file mode 100644
index 00000000000..92d29e39958
--- /dev/null
+++ b/tensorflow/core/kernels/eigen_contraction_kernel.h
@@ -0,0 +1,234 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_CONTRACTION_KERNEL_H_
+#define TENSORFLOW_CORE_KERNELS_EIGEN_CONTRACTION_KERNEL_H_
+
+// Depending on the build configuration this header provides a custom kernel
+// for Eigen tensor contractions (the small matrix multiplication kernel used
+// to multiply together blocks of the original tensors).
+//
+// 1) --define tensorflow_mkldnn_contraction_kernel=1
+//    Use the mkldnn single-threaded sgemm. The mkldnn kernels are generated
+//    at runtime and use avx/avx2/fma/avx512 based on cpu status registers
+//    (https://en.wikipedia.org/wiki/CPUID).
+//
+// If you use `tensor.contract(other_tensor)` in your code, you must include
+// this header to get the benefit of the custom contraction kernel:
+//
+//   #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
+//   #include "tensorflow/core/kernels/eigen_contraction_kernel.h"
+//   #endif
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+// The mkldnn header is only available when the target depends on
+// //third_party/intel_mkl_dnn:mkldnn_single_threaded, which is selected by
+// the same build define, so keep this include behind the guard.
+#if defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL)
+#include "third_party/intel_mkl_dnn/include/mkldnn.h"
+#endif
+
+namespace Eigen {
+namespace internal {
+
+// Enabled by build option: "--define tensorflow_mkldnn_contraction_kernel=1"
+#if defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL)
+
+template <typename Scalar, typename IndexType, typename DataMapper,
+          int StorageOrder>
+struct mkldnn_gemm_pack;
+
+// mkldnn_gemm_pack for ColMajor storage order.
+template <typename Scalar, typename IndexType, typename DataMapper>
+struct mkldnn_gemm_pack<Scalar, IndexType, DataMapper, ColMajor> {
+  typedef typename internal::packet_traits<Scalar>::type Packet;
+  typedef typename DataMapper::LinearMapper LinearMapper;
+
+  enum { PacketSize = internal::packet_traits<Scalar>::size };
+
+  EIGEN_DONT_INLINE
+  void operator()(Scalar* block, const DataMapper& data_mapper, IndexType rows,
+                  IndexType cols) {
+    const IndexType unrolled_rows =
+        (rows / (4 * PacketSize)) * (4 * PacketSize);
+    const IndexType vectorized_rows = (rows / PacketSize) * PacketSize;
+
+    for (IndexType col = 0; col < cols; ++col) {
+      LinearMapper lm = data_mapper.getLinearMapper(0, col);
+
+      // Give the compiler a strong hint to unroll the loop.
+      for (IndexType i = 0; i < unrolled_rows; i += 4 * PacketSize) {
+        for (IndexType j = 0; j < 4; ++j) {
+          const Packet p =
+              lm.template loadPacket<Packet>(i + j * PacketSize);
+          internal::pstoreu(block + j * PacketSize, p);
+        }
+        block += 4 * PacketSize;
+      }
+
+      // Process remaining rows with packets.
+      for (IndexType i = unrolled_rows; i < vectorized_rows; i += PacketSize) {
+        const Packet p = lm.template loadPacket<Packet>(i);
+        internal::pstoreu(block, p);
+        block += PacketSize;
+      }
+
+      // Finalize with coefficients.
+      for (IndexType i = vectorized_rows; i < rows; ++i) {
+        *block = lm(i);
+        ++block;
+      }
+    }
+  }
+};
+
+template <typename Scalar, typename IndexType, typename OutputMapper,
+          bool ConjugateLhs = false, bool ConjugateRhs = false>
+struct mkldnn_gemm_kernel;
+
+// mkldnn_gemm_kernel for floats, defined as a thin layer on top of
+// mkldnn_sgemm.
+template <typename IndexType, typename OutputMapper, bool ConjugateLhs,
+          bool ConjugateRhs>
+struct mkldnn_gemm_kernel<float, IndexType, OutputMapper, ConjugateLhs,
+                          ConjugateRhs> {
+  EIGEN_DONT_INLINE
+  void operator()(const OutputMapper& output, const float* blockA,
+                  const float* blockB, const IndexType rows,
+                  const IndexType depth, const IndexType cols, float alpha) {
+    static const int max_index = (std::numeric_limits<int>::max)();
+
+    eigen_assert(max_index >= rows);
+    eigen_assert(max_index >= cols);
+    eigen_assert(max_index >= depth);
+    eigen_assert(max_index >= output.stride());
+
+    const int m = static_cast<int>(rows);
+    const int n = static_cast<int>(cols);
+    const int k = static_cast<int>(depth);
+
+    const char transposeA = ConjugateLhs ? 'Y' : 'N';
+    const char transposeB = ConjugateRhs ? 'Y' : 'N';
+
+    const int ldA = ConjugateLhs ? k : m;
+    const int ldB = ConjugateRhs ? n : k;
+    const int ldC = static_cast<int>(output.stride());
+
+    const float beta = 1.0;
+
+    mkldnn_status_t st = mkldnn_sgemm(&transposeA, &transposeB, &m, &n, &k,
+                                      &alpha, blockA, &ldA, blockB, &ldB,
+                                      &beta, const_cast<float*>(output.data()),
+                                      &ldC);
+    eigen_assert(st == 0);
+  }
+};
+
+// For mkldnn_sgemm having the right dimensions (especially for small matrices)
+// is more important than fitting all of the working set in L1/L2 caches.
+// TODO(ezhulenev): Do better heuristics.
+template <typename StorageIndex, int sharding_type>
+class TensorContractionBlocking<float, float, float, StorageIndex,
+                                sharding_type> {
+  // For now mkldnn has only mkldnn_sgemm (gemm for floats).
+  using Scalar = float;
+
+  // Adjust the block sizes to work well with mkldnn kernels.
+
+  // Multiply the default choice of block size along the M and N dimensions.
+  // TODO(ezhulenev): Explore if this can work in general (kScaleM=2.0 worked
+  // well in some models).
+  static constexpr float kScaleM = 1.5;
+  static constexpr float kScaleN = 1.0;
+
+  // Mkldnn Avx/Avx2/Avx512 unroll factors are: 8/16/48.
+  static const StorageIndex kUnrollM = 48;
+
+  // Mkldnn Avx/Avx2/Avx512 unroll factors are: 6/6/8.
+  static const StorageIndex kUnrollN = 24;
+
+ public:
+  TensorContractionBlocking(StorageIndex k, StorageIndex m, StorageIndex n,
+                            StorageIndex num_threads = 1)
+      : kc_(k), mc_(m), nc_(n) {
+    // 1. Compute block sizes using default Eigen heuristics.
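+    //    (This calls Eigen's computeProductBlockingSizes from
+    //    GeneralBlockPanelKernel.h, which picks kc/mc/nc so that the packed
+    //    blocks fit the CPU cache hierarchy; ShardByCol vs. ShardByRow only
+    //    swaps the roles of the M and N dimensions in that heuristic.)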
+    if (sharding_type == ShardByCol) {
+      computeProductBlockingSizes<Scalar, Scalar, 1>(kc_, mc_, nc_,
+                                                     num_threads);
+    } else {
+      computeProductBlockingSizes<Scalar, Scalar, 1>(kc_, nc_, mc_,
+                                                     num_threads);
+    }
+
+    // 2. And refine them to work well with mkldnn sgemm.
+    mc_ = (std::min)(
+        m, Eigen::divup(static_cast<StorageIndex>(mc_ * kScaleM), kUnrollM) *
+               kUnrollM);
+    nc_ = (std::min)(
+        n, Eigen::divup(static_cast<StorageIndex>(nc_ * kScaleN), kUnrollN) *
+               kUnrollN);
+
+    // We split the K dimension into roughly equal slices.
+    StorageIndex target_k_slices =
+        (std::max)(StorageIndex(1), Eigen::divup(k, kc_));
+    StorageIndex packet_size = 8;
+    StorageIndex target_bk =
+        Eigen::divup(k / target_k_slices, packet_size) * packet_size;
+    kc_ = (std::min)(k, target_bk);
+  }
+
+  EIGEN_ALWAYS_INLINE StorageIndex kc() const { return kc_; }
+  EIGEN_ALWAYS_INLINE StorageIndex mc() const { return mc_; }
+  EIGEN_ALWAYS_INLINE StorageIndex nc() const { return nc_; }
+
+ private:
+  StorageIndex kc_;
+  StorageIndex mc_;
+  StorageIndex nc_;
+};
+
+template <typename StorageIndex, typename OutputMapper, typename LhsMapper,
+          typename RhsMapper>
+struct TensorContractionKernel<float, float, float, StorageIndex,
+                               OutputMapper, LhsMapper, RhsMapper> {
+  // For now mkldnn has only mkldnn_sgemm (gemm for floats).
+  using Scalar = float;
+  using Traits = typename internal::gebp_traits<Scalar, Scalar>;
+
+  using LhsPacker = mkldnn_gemm_pack<Scalar, StorageIndex,
+                                     typename LhsMapper::SubMapper, ColMajor>;
+  using RhsPacker = mkldnn_gemm_pack<Scalar, StorageIndex,
+                                     typename RhsMapper::SubMapper, ColMajor>;
+  using GemmKernel = mkldnn_gemm_kernel<Scalar, StorageIndex, OutputMapper>;
+
+  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void packLhs(
+      Scalar* lhsBlock, const typename LhsMapper::SubMapper& data_mapper,
+      const StorageIndex depth, const StorageIndex rows) {
+    LhsPacker()(lhsBlock, data_mapper, rows, depth);
+  }
+
+  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void packRhs(
+      Scalar* rhsBlock, const typename RhsMapper::SubMapper& data_mapper,
+      const StorageIndex depth, const StorageIndex cols) {
+    RhsPacker()(rhsBlock, data_mapper, depth, cols);
+  }
+
+  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void invoke(
+      const OutputMapper& output_mapper, const Scalar* lhsBlock,
+      const Scalar* rhsBlock, const StorageIndex rows, const StorageIndex depth,
+      const StorageIndex cols, const Scalar alpha) {
+    GemmKernel()(output_mapper, lhsBlock, rhsBlock, rows, depth, cols, alpha);
+  }
+};
+
+#endif  // defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL)
+
+}  // namespace internal
+}  // namespace Eigen
+
+#endif  // TENSORFLOW_CORE_KERNELS_EIGEN_CONTRACTION_KERNEL_H_
diff --git a/tensorflow/core/kernels/eigen_cuboid_convolution.h b/tensorflow/core/kernels/eigen_cuboid_convolution.h
index a98850cf4b3..72ac3c0f073 100644
--- a/tensorflow/core/kernels/eigen_cuboid_convolution.h
+++ b/tensorflow/core/kernels/eigen_cuboid_convolution.h
@@ -19,6 +19,10 @@ limitations under the License.
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/kernels/eigen_volume_patch.h"
 
+#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
+#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
+#endif
+
 namespace Eigen {
 namespace internal {
 
@@ -51,11 +55,10 @@ namespace internal {
 // col - index of the extracted patch (in code: patchIndex)
 //       patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions)
 //
-template
+template
 class TensorContractionInputMapper<
     Scalar_, Index, Side,
     TensorEvaluator
 m_impl;
 };
 
-template
+template
 class TensorContractionSubMapper<
     Scalar, Index, Side,
     TensorEvaluator
+template
 struct gemm_pack_rhs<
     Scalar, Index,
     TensorContractionSubMapper<
@@ -1172,11 +1174,10 @@ struct gemm_pack_rhs<
 // Template specialization for packet_size = 2. We must special-case packet
 // blocks with nr > packet_size, e.g. PacketBlock.
-template
+template
 struct gemm_pack_rhs<
     Scalar, Index,
     TensorContractionSubMapper<
@@ -1353,11 +1354,10 @@ struct gemm_pack_rhs<
 };
 
 // Special case for non-vectorized types such as float16 (packet_size = 1).
-template
+template
 struct gemm_pack_rhs<
     Scalar, Index,
     TensorContractionSubMapper<
@@ -1427,6 +1427,153 @@ struct gemm_pack_rhs<
   }
 };
 
+#if defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL)
+// Arrange a block of the right input matrix (in our case it's always a
+// "virtual matrix" constructed from extracted volume patches) in contiguous
+// memory.
+//
+// Mkldnn doesn't require Lhs/Rhs blocks to be packed in any specific format,
+// so this is basically the same as taking a slice of the matrix. Knowing the
+// properties of the original patch op, we can do it more efficiently than the
+// default mkldnn_gemm_pack.
+template <typename Scalar, typename StorageIndex, typename NewDimension,
+          DenseIndex Planes, DenseIndex Rows, DenseIndex Cols,
+          typename ArgType, typename Device, typename nocontract_t,
+          typename contract_t, int packet_size, bool inner_dim_contiguous,
+          bool inner_dim_reordered, int Alignment>
+struct mkldnn_gemm_pack<
+    Scalar, StorageIndex,
+    TensorContractionSubMapper<
+        Scalar, StorageIndex, Rhs,
+        TensorEvaluator<const TensorReshapingOp<
+                            NewDimension, const TensorVolumePatchOp<
+                                              Planes, Rows, Cols, ArgType> >,
+                        Device>,
+        nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+        inner_dim_reordered, Alignment>,
+    ColMajor> {
+  typedef TensorContractionSubMapper<
+      Scalar, StorageIndex, Rhs,
+      TensorEvaluator<const TensorReshapingOp<
+                          NewDimension, const TensorVolumePatchOp<
+                                            Planes, Rows, Cols, ArgType> >,
+                      Device>,
+      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+      inner_dim_reordered, Alignment>
+      SubMapper;
+
+  typedef SubMapper DataMapper;
+  typedef typename packet_traits<Scalar>::type Packet;
+
+  EIGEN_DONT_INLINE
+  void operator()(Scalar* block, const DataMapper& rhs, StorageIndex rows,
+                  StorageIndex cols) {
+    const bool standard_patches = !rhs.nonStandardPatches();
+
+    const StorageIndex vectorized_rows = (rows / packet_size) * packet_size;
+
+    if (standard_patches && rhs.patchDepth() % packet_size == 0) {
+      // If patches are standard and the patch depth is a multiple of the
+      // packet size, then we can guarantee that a single packet does not span
+      // multiple patch rows or columns, and we can read it directly from the
+      // TensorVolumePatchOp argument.
+
+      // Give vectorized_rows the name used in all other gemm_pack_rhs above.
+      const Index peeled_k = vectorized_rows;
+
+      const Index start_col = rhs.colOffset();
+      const Index max_col = rhs.maxCol(peeled_k);
+
+      for (StorageIndex col = 0; col < cols; ++col) {
+        SubMapper lm = rhs.getLinearMapper(0, col);
+
+        Index k = 0;
+        for (Index c = start_col; c < max_col; ++c) {
+          eigen_assert(k <= peeled_k);
+
+          const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
+          const Index max_row = rhs.maxRow(peeled_k, c);
+          const bool pad_col = lm.padCol(c);
+
+          for (Index r = start_row; r < max_row; ++r) {
+            eigen_assert(k <= peeled_k);
+
+            const Index start_plane =
+                ((c == start_col) && (r == start_row)) ? rhs.planeOffset() : 0;
+            const Index max_plane = rhs.maxPlane(peeled_k, c, r);
+            const bool pad_row = pad_col || lm.padRow(r);
+
+            for (Index p = start_plane; p < max_plane; ++p) {
+              eigen_assert(k <= peeled_k);
+
+              const Index start_depth =
+                  ((c == start_col) && (r == start_row) && (p == start_plane))
+                      ? rhs.depthOffset()
+                      : 0;
+              const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
+              eigen_assert((max_depth - start_depth) % packet_size == 0);
+
+              const bool pad = pad_col || pad_row || lm.padPlane(p);
+              const Index base_idx = lm.baseIndex(p, r, c);
+
+              for (Index d = start_depth; d < max_depth; d += packet_size) {
+                eigen_assert(k < peeled_k);
+                const Packet packet = pad ? pset1<Packet>(Scalar(0))
+                                          : rhs.packetNoPadding(d, base_idx);
+                internal::pstoreu(block, packet);
+                block += packet_size;
+                k += packet_size;
+              }
+            }
+          }
+        }
+
+        // The loop above should fill peeled_k elements.
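+        // (k only advances in packet_size steps, and maxDepth is clamped by
+        // the remaining peeled_k - k, so the nested loops above stop at
+        // exactly peeled_k.)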
+        eigen_assert(peeled_k == k);
+
+        // Fill remaining elements using loadCoeffStandard.
+        for (; k < rows; ++k) {
+          *block = lm.loadCoeffStandard(k);
+          ++block;
+        }
+      }
+
+    } else if (standard_patches) {
+      // A single packet can span multiple patch rows or columns, so we have
+      // to go through the slower path, which falls back to building a packet
+      // from coefficients.
+
+      for (StorageIndex col = 0; col < cols; ++col) {
+        SubMapper lm = rhs.getLinearMapper(0, col);
+
+        for (StorageIndex i = 0; i < vectorized_rows; i += packet_size) {
+          const Packet p = lm.loadPacketStandard(i);
+          internal::pstoreu(block, p);
+          block += packet_size;
+        }
+
+        // Finalize with coefficients.
+        for (StorageIndex i = vectorized_rows; i < rows; ++i) {
+          *block = lm.loadCoeffStandard(i);
+          ++block;
+        }
+      }
+
+    } else {
+      // With non-standard patches we don't do any vectorized loads.
+      // TODO(ezhulenev): It doesn't look like we should completely give up
+      // on packets. Make this code path faster!
+      for (StorageIndex col = 0; col < cols; ++col) {
+        SubMapper lm = rhs.getLinearMapper(0, col);
+        for (StorageIndex i = 0; i < rows; ++i) {
+          *block = lm(i);
+          ++block;
+        }
+      }
+    }
+  }
+};
+#endif  // defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL)
+
 }  // namespace internal
 
 /** CuboidConvolution
@@ -1478,9 +1625,8 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
     const DSizes::Index, 2>, const Kernel> > > >::type
 CuboidConvolution(const Input& input, const Kernel& kernel,
-                  const DenseIndex stridePlanes = 1,
-                  const DenseIndex strideRows = 1,
-                  const DenseIndex strideCols = 1,
+                  const Index stridePlanes = 1, const Index strideRows = 1,
+                  const Index strideCols = 1,
                   const PaddingType padding_type = PADDING_SAME) {
   typedef typename internal::traits<Input>::Index TensorIndex;
   TensorRef::Scalar,
diff --git a/tensorflow/core/kernels/eigen_mkldnn.h b/tensorflow/core/kernels/eigen_mkldnn.h
deleted file mode 100644
index 5235431f5f3..00000000000
--- a/tensorflow/core/kernels/eigen_mkldnn.h
+++ /dev/null
@@ -1,123 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_MKLDNN_H_
-#define TENSORFLOW_CORE_KERNELS_EIGEN_MKLDNN_H_
-
-// Support for Mkldnn sgemm kernel in Eigen/Tensor contractions:
-//
-// 1. Prepare packed Lhs/Rhs blocks from tensor expressions using
-//    DataMapper (see TensorContractionInputMapper).
-// 2. Invoke gemm kernel with packed blocks (replacement for default
-//    gebp_kernel).
-
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "third_party/intel_mkl_dnn/include/mkldnn.h"
-
-namespace Eigen {
-namespace internal {
-
-template
-struct mkldnn_gemm_pack;
-
-// mkl_gemm_pack for ColMajor storage order.
-template
-struct mkldnn_gemm_pack {
-  typedef typename internal::packet_traits::type Packet;
-  typedef typename DataMapper::LinearMapper LinearMapper;
-
-  enum { PacketSize = internal::packet_traits::size };
-
-  EIGEN_DONT_INLINE
-  void operator()(Scalar *block, const DataMapper &data_mapper, IndexType rows,
-                  IndexType cols) {
-    const IndexType unrolled_rows =
-        (rows / (4 * PacketSize)) * (4 * PacketSize);
-    const IndexType vectorized_rows = (rows / PacketSize) * PacketSize;
-
-    for (IndexType col = 0; col < cols; ++col) {
-      LinearMapper lm = data_mapper.getLinearMapper(0, col);
-
-      // Give compiler a strong possibility to unroll the loop.
-      for (IndexType i = 0; i < unrolled_rows; i += 4 * PacketSize) {
-        for (IndexType j = 0; j < 4; ++j) {
-          const Packet p = lm.loadPacket(i + j * PacketSize);
-          internal::pstoreu(block + j * PacketSize, p);
-        }
-        block += 4 * PacketSize;
-      }
-
-      // Process remaining rows with packets.
-      for (IndexType i = unrolled_rows; i < vectorized_rows; i += PacketSize) {
-        const Packet p = lm.loadPacket(i);
-        internal::pstoreu(block, p);
-        block += PacketSize;
-      }
-
-      // Finalize with coefficients.
-      for (IndexType i = vectorized_rows; i < rows; ++i) {
-        *block = lm(i);
-        ++block;
-      }
-    }
-  }
-};
-
-template
-struct mkldnn_gemm_kernel;
-
-// mkldnn_gemm_kernel for floats defined as a thin layer on top of mkldnn_sgemm.
-template
-struct mkldnn_gemm_kernel {
-  EIGEN_DONT_INLINE
-  void operator()(const OutputMapper &output, const float *blockA,
-                  const float *blockB, const IndexType rows,
-                  const IndexType depth, const IndexType cols, float alpha) {
-    static const int max_index = (std::numeric_limits::max)();
-
-    eigen_assert(max_index >= rows);
-    eigen_assert(max_index >= cols);
-    eigen_assert(max_index >= depth);
-    eigen_assert(max_index >= output.stride());
-
-    const int m = static_cast(rows);
-    const int n = static_cast(cols);
-    const int k = static_cast(depth);
-
-    const char transposeA = ConjugateLhs ? 'Y' : 'N';
-    const char transposeB = ConjugateRhs ? 'Y' : 'N';
-
-    const int ldA = ConjugateLhs ? k : m;
-    const int ldB = ConjugateRhs ? n : k;
-    const int ldC = static_cast(output.stride());
-
-    const float beta = 1.0;
-
-    mkldnn_status_t st = mkldnn_sgemm(&transposeA, &transposeB, &m, &n, &k,
-                                      &alpha, blockA, &ldA, blockB, &ldB, &beta,
-                                      const_cast(output.data()), &ldC);
-    eigen_assert(st == 0);
-  }
-};
-
-}  // namespace internal
-}  // namespace Eigen
-
-#endif  // TENSORFLOW_CORE_KERNELS_EIGEN_MKLDNN_H_
diff --git a/tensorflow/core/kernels/eigen_mkldnn_test.cc b/tensorflow/core/kernels/eigen_mkldnn_contraction_kernel_test.cc
similarity index 98%
rename from tensorflow/core/kernels/eigen_mkldnn_test.cc
rename to tensorflow/core/kernels/eigen_mkldnn_contraction_kernel_test.cc
index 051ab28f792..da4a61d1bda 100644
--- a/tensorflow/core/kernels/eigen_mkldnn_test.cc
+++ b/tensorflow/core/kernels/eigen_mkldnn_contraction_kernel_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/kernels/eigen_mkldnn.h"
+#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace Eigen {
diff --git a/tensorflow/core/kernels/eigen_spatial_convolutions.h b/tensorflow/core/kernels/eigen_spatial_convolutions.h
index a08c7064d58..c9f07516e84 100644
--- a/tensorflow/core/kernels/eigen_spatial_convolutions.h
+++ b/tensorflow/core/kernels/eigen_spatial_convolutions.h
@@ -18,6 +18,10 @@ limitations under the License.
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
+#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
+#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
+#endif
+
 namespace Eigen {
 namespace internal {
 
@@ -52,8 +56,8 @@ namespace internal {
 //
 // TODO(ezhulenev): Consolidate this part of the code with the image patch
 // extraction code since they are both very similar.
-template
 class TensorContractionInputMapper<
@@ -511,8 +515,8 @@ class TensorContractionInputMapper<
   const TensorEvaluator
 m_impl;
 };
 
-template
 class TensorContractionSubMapper<
@@ -770,8 +774,8 @@ class TensorContractionSubMapper<
 // *) nr - number of registers along the 'n' dimension.
 //    See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix
 //    Multiplication" paper.
-template
@@ -931,8 +935,8 @@ struct gemm_pack_rhs<
 // Template specialization for packet_size = 2. We must special-case packet
 // blocks with nr > packet_size, e.g. PacketBlock.
-template
 struct gemm_pack_rhs<
@@ -1097,8 +1101,8 @@ struct gemm_pack_rhs<
 };
 
 // Special case for non-vectorized types such as float16.
-template
 struct gemm_pack_rhs<
@@ -1170,6 +1174,141 @@ struct gemm_pack_rhs<
   }
 };
 
+#if defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL)
+// Arrange a block of the right input matrix (in our case it's always a
+// "virtual matrix" constructed from extracted image patches) in contiguous
+// memory.
+//
+// Mkldnn doesn't require Lhs/Rhs blocks to be packed in any specific format,
+// so this is basically the same as taking a slice of the matrix. Knowing the
+// properties of the original patch op, we can do it more efficiently than the
+// default mkldnn_gemm_pack.
+template <typename Scalar, typename StorageIndex, typename NewDimension,
+          DenseIndex Rows, DenseIndex Cols, typename ArgType, typename Device,
+          typename nocontract_t, typename contract_t, int packet_size,
+          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
+struct mkldnn_gemm_pack<
+    Scalar, StorageIndex,
+    TensorContractionSubMapper<
+        Scalar, StorageIndex, Rhs,
+        TensorEvaluator<
+            const TensorReshapingOp<
+                NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
+            Device>,
+        nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+        inner_dim_reordered, Alignment>,
+    ColMajor> {
+  typedef TensorContractionSubMapper<
+      Scalar, StorageIndex, Rhs,
+      TensorEvaluator<
+          const TensorReshapingOp<
+              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
+          Device>,
+      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+      inner_dim_reordered, Alignment>
+      SubMapper;
+
+  typedef SubMapper DataMapper;
+  typedef typename packet_traits<Scalar>::type Packet;
+
+  EIGEN_DONT_INLINE
+  void operator()(Scalar* block, const DataMapper& rhs, StorageIndex rows,
+                  StorageIndex cols) {
+    const bool standard_patches = !rhs.nonStandardPatches();
+
+    const StorageIndex vectorized_rows = (rows / packet_size) * packet_size;
+
+    if (standard_patches && rhs.patchDepth() % packet_size == 0) {
+      // If patches are standard and the patch depth is a multiple of the
+      // packet size, then we can guarantee that a single packet does not span
+      // multiple patch rows or columns, and we can read it directly from the
+      // TensorImagePatchOp argument.
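+      // (For example, for floats with AVX a packet holds 8 values, so any
+      // patch depth that is a multiple of 8 takes this path; the exact packet
+      // width depends on the instruction set the build targets.)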
+
+      // Give vectorized_rows the name used in all other gemm_pack_rhs above.
+      const Index peeled_k = vectorized_rows;
+
+      const Index start_col = rhs.colOffset();
+      const Index max_col = rhs.maxCol(peeled_k);
+
+      for (StorageIndex col = 0; col < cols; ++col) {
+        SubMapper lm = rhs.getLinearMapper(0, col);
+
+        Index k = 0;
+        for (Index c = start_col; c < max_col; ++c) {
+          eigen_assert(k <= peeled_k);
+
+          const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
+          const Index max_row = rhs.maxRow(peeled_k, c);
+          const bool pad_col = lm.padCol(c);
+
+          for (Index r = start_row; r < max_row; ++r) {
+            eigen_assert(k <= peeled_k);
+
+            const Index start_depth =
+                ((c == start_col) && (r == start_row)) ? rhs.depthOffset() : 0;
+            const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
+            eigen_assert((max_depth - start_depth) % packet_size == 0);
+
+            const bool pad = pad_col || lm.padRow(r);
+            const Index base_idx = lm.baseIndex(r, c);
+
+            for (Index d = start_depth; d < max_depth; d += packet_size) {
+              eigen_assert(k < peeled_k);
+              const Packet p = pad ? pset1<Packet>(Scalar(0))
+                                   : rhs.packetNoPadding(d, base_idx);
+              internal::pstoreu(block, p);
+              block += packet_size;
+              k += packet_size;
+            }
+          }
+        }
+
+        // The loop above should fill peeled_k elements.
+        eigen_assert(peeled_k == k);
+
+        // Fill remaining elements using loadCoeffStandard.
+        for (; k < rows; ++k) {
+          *block = lm.loadCoeffStandard(k);
+          ++block;
+        }
+      }
+    } else if (standard_patches) {
+      // A single packet can span multiple patch rows or columns, so we have
+      // to go through the slower path, which falls back to building a packet
+      // from coefficients.
+
+      for (StorageIndex col = 0; col < cols; ++col) {
+        SubMapper lm = rhs.getLinearMapper(0, col);
+
+        for (StorageIndex i = 0; i < vectorized_rows; i += packet_size) {
+          const Packet p = lm.loadPacketStandard(i);
+          internal::pstoreu(block, p);
+          block += packet_size;
+        }
+
+        // Finalize with coefficients.
+        for (StorageIndex i = vectorized_rows; i < rows; ++i) {
+          *block = lm.loadCoeffStandard(i);
+          ++block;
+        }
+      }
+
+    } else {
+      // With non-standard patches we don't do any vectorized loads.
+      // TODO(ezhulenev): It doesn't look like we should completely give up
+      // on packets. Make this code path faster!
+      for (StorageIndex col = 0; col < cols; ++col) {
+        SubMapper lm = rhs.getLinearMapper(0, col);
+        for (StorageIndex i = 0; i < rows; ++i) {
+          *block = lm(i);
+          ++block;
+        }
+      }
+    }
+  }
+};
+#endif  // defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL)
+
 }  // end namespace internal
 
 /** SpatialConvolution
@@ -1226,11 +1365,10 @@ EIGEN_DEVICE_FUNC
     const DSizes::Index, 2>, const Kernel> > > >::type
 SpatialConvolution(const Input& input, const Kernel& kernel,
-                   const DenseIndex row_stride = 1,
-                   const DenseIndex col_stride = 1,
+                   const Index row_stride = 1, const Index col_stride = 1,
                    const PaddingType padding_type = PADDING_SAME,
-                   const DenseIndex row_in_stride = 1,
-                   const DenseIndex col_in_stride = 1) {
+                   const Index row_in_stride = 1,
+                   const Index col_in_stride = 1) {
   typedef typename internal::traits<Input>::Index TensorIndex;
   TensorRef::Scalar,
 internal::traits::NumDimensions,
@@ -1260,9 +1398,9 @@ EIGEN_DEVICE_FUNC
   const TensorIndex kernelCols = isColMajor ?
 kern.dimensions()[3] : kern.dimensions()[0];
 
-  const DenseIndex kernelRowsEff =
+  const Index kernelRowsEff =
       kernelRows + (kernelRows - 1) * (row_in_stride - 1);
-  const DenseIndex kernelColsEff =
+  const Index kernelColsEff =
       kernelCols + (kernelCols - 1) * (col_in_stride - 1);
 
   array<IndexPair<TensorIndex>, 1> contract_dims;
diff --git a/tensorflow/core/kernels/gemm_functors.h b/tensorflow/core/kernels/gemm_functors.h
index 1c808440851..97e077c0960 100644
--- a/tensorflow/core/kernels/gemm_functors.h
+++ b/tensorflow/core/kernels/gemm_functors.h
@@ -36,6 +36,10 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_types.h"
 
+#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
+#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
+#endif
+
 // Apple provides an optimized BLAS library that is better than Eigen for their
 // devices, so use that if possible.
 #if defined(__APPLE__) && defined(USE_GEMM_FOR_CONV)
diff --git a/tensorflow/core/kernels/lrn_op.cc b/tensorflow/core/kernels/lrn_op.cc
index b4252eb0444..f405ca3c58c 100644
--- a/tensorflow/core/kernels/lrn_op.cc
+++ b/tensorflow/core/kernels/lrn_op.cc
@@ -26,6 +26,10 @@ limitations under the License.
 #include "tensorflow/core/kernels/ops_util.h"
 #include "tensorflow/core/lib/core/errors.h"
 
+#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
+#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
+#endif
+
 #if !defined(IS_MOBILE_PLATFORM)
 #include "tensorflow/core/util/work_sharder.h"
 #endif
diff --git a/tensorflow/core/kernels/matmul_op.h b/tensorflow/core/kernels/matmul_op.h
index 4b74a64025a..48769f3fe5d 100644
--- a/tensorflow/core/kernels/matmul_op.h
+++ b/tensorflow/core/kernels/matmul_op.h
@@ -21,6 +21,10 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/lib/hash/hash.h"
 
+#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
+#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
+#endif
+
 namespace tensorflow {
 namespace functor {
 
diff --git a/tensorflow/core/kernels/sparse_matmul_op.cc b/tensorflow/core/kernels/sparse_matmul_op.cc
index 866c5dcd521..2ea7a1ed3b9 100644
--- a/tensorflow/core/kernels/sparse_matmul_op.cc
+++ b/tensorflow/core/kernels/sparse_matmul_op.cc
@@ -44,6 +44,10 @@ limitations under the License.
 #include "include/libxsmm_spmdm.h"
 #endif
 
+#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
+#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
+#endif
+
 namespace tensorflow {
 namespace {
 
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index c527fad59f7..6b990d3a926 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -2888,7 +2888,7 @@ cuda_py_test(
         "//tensorflow/python:nn_grad",
         "//tensorflow/python:nn_ops",
     ],
-    shard_count = 20,
+    shard_count = 30,
 )
 
 cuda_py_test(
diff --git a/tensorflow/python/kernel_tests/conv_ops_3d_test.py b/tensorflow/python/kernel_tests/conv_ops_3d_test.py
index 57b09dc167f..c4a9cdcf8e0 100644
--- a/tensorflow/python/kernel_tests/conv_ops_3d_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_3d_test.py
@@ -638,6 +638,30 @@ class Conv3DTest(test.TestCase):
         padding="SAME",
         test_input=False)
 
+  # Test the fast path in gemm_pack_rhs/mkldnn_gemm_pack, when the channel
+  # dimension is a multiple of the packet size.
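+  # (With SSE/AVX a float packet holds 4 or 8 values, so the in_depth=8 used
+  # below keeps the channel dimension packet-aligned and exercises that path.)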
+  def testInputGradientValidPaddingStrideOneFastPath(self):
+    self.ConstructAndTestGradient(
+        batch=2,
+        input_shape=(3, 5, 4),
+        filter_shape=(2, 2, 2),
+        in_depth=8,
+        out_depth=2,
+        stride=1,
+        padding="VALID",
+        test_input=True)
+
+  def testFilterGradientValidPaddingStrideOneFastPath(self):
+    self.ConstructAndTestGradient(
+        batch=2,
+        input_shape=(4, 6, 5),
+        filter_shape=(2, 2, 2),
+        in_depth=8,
+        out_depth=2,
+        stride=1,
+        padding="VALID",
+        test_input=False)
+
   # Testing for backprops
   def _RunAndVerifyBackprop(self, input_sizes, filter_sizes, output_sizes,
                             strides, dilations, padding, data_format, use_gpu,
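For completeness, a sketch in the spirit of the renamed eigen_mkldnn_contraction_kernel_test.cc: contract two tensors with the custom kernel compiled in and compare every coefficient against a naive reference. The real test body is not part of this diff, so the test name, sizes, and tolerance here are assumptions:

    #include "tensorflow/core/kernels/eigen_contraction_kernel.h"
    #include "tensorflow/core/platform/test.h"

    namespace Eigen {

    TEST(EigenMkldnnTest, MkldnnGemm) {
      const Index m = 32, k = 16, n = 24;
      Tensor<float, 2> lhs(m, k);
      Tensor<float, 2> rhs(k, n);
      lhs.setRandom();
      rhs.setRandom();

      // Dispatches to mkldnn_sgemm when the header above defines the custom
      // TensorContractionKernel; otherwise runs the default gebp_kernel.
      const array<IndexPair<Index>, 1> dims = {IndexPair<Index>(1, 0)};
      const Tensor<float, 2> res = lhs.contract(rhs, dims);

      // Naive triple-loop reference.
      for (Index i = 0; i < m; ++i) {
        for (Index j = 0; j < n; ++j) {
          float expected = 0.0f;
          for (Index l = 0; l < k; ++l) expected += lhs(i, l) * rhs(l, j);
          EXPECT_NEAR(res(i, j), expected, 1e-4);
        }
      }
    }

    }  // namespace Eigen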