diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc b/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc
index 057e851aba6..15ae95f13cf 100644
--- a/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc
+++ b/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc
@@ -141,7 +141,7 @@ __global__ void lstm_gates(const T* icfo, const T* b, const T* cs_prev,
   // const int gid = batch_id * cell_size * 4 + act_id;
   const int cid = batch_id * cell_size + act_id;
-  Eigen::internal::scalar_sigmoid_op<T> sigmoid_op;
+  Eigen::internal::scalar_logistic_op<T> sigmoid_op;
   Eigen::internal::scalar_tanh_op<T> tanh_op;
   Eigen::scalar_clip_op<T> clip_op;
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index d8ebdeff5d2..870b2f24d14 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -84,13 +84,13 @@ namespace tensorflow {
 // corresponding stream have completed. The following two classes
 // serve this purpose in two different compilation environments.
-class EigenCudaStreamDevice : public ::Eigen::StreamInterface {
+class EigenGpuStreamDevice : public ::Eigen::StreamInterface {
  public:
-  EigenCudaStreamDevice()
+  EigenGpuStreamDevice()
       : scratch_(nullptr), semaphore_(nullptr), context_(nullptr) {
     Eigen::initializeDeviceProp();
   }
-  ~EigenCudaStreamDevice() override {}
+  ~EigenGpuStreamDevice() override {}
   void Reinitialize(OpKernelContext* context, const cudaStream_t* cuda_stream,
                     TfGpuId tf_gpu_id, ::tensorflow::Allocator* alloc,
                     char* scratch) {
@@ -101,7 +101,7 @@ class EigenCudaStreamDevice : public ::Eigen::StreamInterface {
     context_ = context;
     scratch_ = scratch;
     semaphore_ =
-        reinterpret_cast<unsigned int*>(scratch + Eigen::kCudaScratchSize);
+        reinterpret_cast<unsigned int*>(scratch + Eigen::kGpuScratchSize);
     stream_ = cuda_stream;
     allocator_ = alloc;
     PlatformGpuId platform_gpu_id;
@@ -185,7 +185,7 @@ class EigenCudaStreamDevice : public ::Eigen::StreamInterface {
   mutable unsigned int* semaphore_;
   OpKernelContext* context_;
-  TF_DISALLOW_COPY_AND_ASSIGN(EigenCudaStreamDevice);
+  TF_DISALLOW_COPY_AND_ASSIGN(EigenGpuStreamDevice);
 };
 // This factory helps to ensure that different GPU device objects that refer to
@@ -292,7 +292,7 @@ Status BaseGPUDevice::InitScratchBuffers() {
     DCHECK(streams_[i]);
     if (scratch_.size() > i && scratch_[i]) continue;
     size_t scratch_buffer_size =
-        Eigen::kCudaScratchSize + sizeof(unsigned int);
+        Eigen::kGpuScratchSize + sizeof(unsigned int);
     void* scratch_buffer = gpu_allocator_->AllocateRaw(
         Allocator::kAllocatorAlignment, scratch_buffer_size);
     if (scratch_buffer == nullptr) {
@@ -304,7 +304,7 @@ Status BaseGPUDevice::InitScratchBuffers() {
         se::DeviceMemoryBase(scratch_buffer, scratch_buffer_size));
     bool ok = executor_->SynchronousMemZero(
-        &mem, Eigen::kCudaScratchSize + sizeof(unsigned int));
+        &mem, Eigen::kGpuScratchSize + sizeof(unsigned int));
     if (!ok) {
       return errors::FailedPrecondition(
           "Failed to memcopy into scratch buffer for device ",
@@ -692,7 +692,7 @@ class ConcretePerOpGpuDevice : public PerOpGpuDevice {
   const Eigen::GpuDevice& device() const override { return device_; }
  private:
-  EigenCudaStreamDevice stream_device_;
+  EigenGpuStreamDevice stream_device_;
   Eigen::GpuDevice device_;
 };
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 76e5c989fca..0e552092385 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -311,8 +311,8 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
       {"Square", EIGEN_COST(scalar_square_op)},
       {"Tanh", EIGEN_COST(scalar_tanh_op)},
       {"Relu", EIGEN_COST(scalar_max_op)},
-      {"Sigmoid", EIGEN_COST(scalar_sigmoid_op)},
-      {"QuantizedSigmoid", EIGEN_COST(scalar_sigmoid_op)},
+      {"Sigmoid", EIGEN_COST(scalar_logistic_op)},
+      {"QuantizedSigmoid", EIGEN_COST(scalar_logistic_op)},
       {"Sign", EIGEN_COST(scalar_sign_op)},
       {"Sin", EIGEN_COST(scalar_sin_op)},
       {"Tan", EIGEN_COST(scalar_tan_op)},
diff --git a/tensorflow/core/kernels/check_numerics_op_gpu.cu.cc b/tensorflow/core/kernels/check_numerics_op_gpu.cu.cc
index 87bdba14550..f9f10c1b42f 100644
--- a/tensorflow/core/kernels/check_numerics_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/check_numerics_op_gpu.cu.cc
@@ -60,9 +60,9 @@ template <typename T>
 struct CheckNumericsLaunch {
   void Run(const GPUDevice &d, const T *data, int size,
            int abnormal_detected[2]) {
-    const int32 block_size = d.maxCudaThreadsPerBlock();
+    const int32 block_size = d.maxGpuThreadsPerBlock();
     const int32 num_blocks =
-        (d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor()) /
+        (d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor()) /
         block_size;

     CheckNumericsKernel<T><<<num_blocks, block_size, 0, d.stream()>>>(
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
index 66ba827a901..3f7aa0dc399 100644
--- a/tensorflow/core/kernels/cwise_ops.h
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -656,7 +656,7 @@ template <typename T>
 struct erfc : base<T, Eigen::internal::scalar_erfc_op<T>> {};

 template <typename T>
-struct sigmoid : base<T, Eigen::internal::scalar_sigmoid_op<T>> {};
+struct sigmoid : base<T, Eigen::internal::scalar_logistic_op<T>> {};

 template <typename T>
 struct sin : base<T, Eigen::internal::scalar_sin_op<T>> {};
diff --git a/tensorflow/core/kernels/deep_conv2d.cc b/tensorflow/core/kernels/deep_conv2d.cc
index 1aa8c72d667..f9c8f16cb9a 100644
--- a/tensorflow/core/kernels/deep_conv2d.cc
+++ b/tensorflow/core/kernels/deep_conv2d.cc
@@ -500,8 +500,9 @@ class GemmFilterPacker {
   typedef Eigen::internal::const_blas_data_mapper LhsMapper;
   typedef Eigen::internal::gebp_traits Traits;
-  Eigen::internal::gemm_pack_lhs
+  Eigen::internal::gemm_pack_lhs<
+      T, int64, LhsMapper, Traits::mr, Traits::LhsProgress,
+      typename Traits::LhsPacket4Packing, Eigen::RowMajor>
       pack_lhs;

   GemmFilterPacker(const int64 rows, const int64 depth, const T* lhs_input,
diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
index 76afd6f18c2..1398c876625 100644
--- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
@@ -764,7 +764,7 @@ Status LaunchDepthwiseConv2dGPU(OpKernelContext* ctx, const DepthwiseArgs& args,
   const int max_block_count = kKnownFilterWidth < 0 || kKnownFilterHeight < 0 ||
                                       kKnownDepthMultiplier < 0
                                  ? std::numeric_limits<int>::max()
-                                 : device.getNumCudaMultiProcessors();
+                                 : device.getNumGpuMultiProcessors();

   kernel<<>>(args, input, filter, output, num_outputs);
diff --git a/tensorflow/core/kernels/random_op_gpu.cu.cc b/tensorflow/core/kernels/random_op_gpu.cu.cc
index 3393b39faf4..edb2b10e3d6 100644
--- a/tensorflow/core/kernels/random_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/random_op_gpu.cu.cc
@@ -217,9 +217,9 @@ void FillPhiloxRandom<GPUDevice, Distribution>::operator()(
     OpKernelContext*, const GPUDevice& d, random::PhiloxRandom gen,
     typename Distribution::ResultElementType* data, int64 size,
     Distribution dist) {
-  const int32 block_size = d.maxCudaThreadsPerBlock();
+  const int32 block_size = d.maxGpuThreadsPerBlock();
   const int32 num_blocks =
-      (d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor()) /
+      (d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor()) /
       block_size;

   FillPhiloxRandomKernelLaunch
diff --git a/tensorflow/core/util/cuda_kernel_helper_test.cu.cc b/tensorflow/core/util/cuda_kernel_helper_test.cu.cc
index 732ed33ede1..2b035ab0e9c 100644
--- a/tensorflow/core/util/cuda_kernel_helper_test.cu.cc
+++ b/tensorflow/core/util/cuda_kernel_helper_test.cu.cc
@@ -131,7 +131,7 @@ class CudaLaunchConfigTest : public ::testing::Test {
  protected:
   const int bufsize = 1024;
   int* outbuf = nullptr;
-  Eigen::CudaStreamDevice stream;
+  Eigen::GpuStreamDevice stream;
   Eigen::GpuDevice d = Eigen::GpuDevice(&stream);

   virtual void SetUp() {
diff --git a/tensorflow/core/util/cuda_launch_config.h b/tensorflow/core/util/cuda_launch_config.h
index d0d95736d3f..080d4067cec 100644
--- a/tensorflow/core/util/cuda_launch_config.h
+++ b/tensorflow/core/util/cuda_launch_config.h
@@ -128,12 +128,12 @@ inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
   CudaLaunchConfig config;
   const int virtual_thread_count = work_element_count;
   const int physical_thread_count = std::min(
-      d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor(),
+      d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor(),
       virtual_thread_count);
-  const int thread_per_block = std::min(1024, d.maxCudaThreadsPerBlock());
+  const int thread_per_block = std::min(1024, d.maxGpuThreadsPerBlock());
   const int block_count =
       std::min(DivUp(physical_thread_count, thread_per_block),
-               d.getNumCudaMultiProcessors());
+               d.getNumGpuMultiProcessors());

   config.virtual_thread_count = virtual_thread_count;
   config.thread_per_block = thread_per_block;
@@ -184,7 +184,7 @@ inline CudaLaunchConfig GetCudaLaunchConfigFixedBlockSize(
   cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
       &block_count, func, fixed_block_size, dynamic_shared_memory_size);
   CHECK_EQ(err, cudaSuccess);
-  block_count = std::min(block_count * d.getNumCudaMultiProcessors(),
+  block_count = std::min(block_count * d.getNumGpuMultiProcessors(),
                          DivUp(work_element_count, fixed_block_size));
   config.virtual_thread_count = work_element_count;
@@ -213,7 +213,7 @@ inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(int xdim, int ydim,
   int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
   const int physical_thread_count =
-      d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor();
+      d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor();
   const int max_blocks = std::max(physical_thread_count / kThreadsPerBlock, 1);
diff --git a/tensorflow/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h b/tensorflow/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h
index
f71ddbf3220..6461a5e5426 100644 --- a/tensorflow/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h +++ b/tensorflow/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h @@ -12,25 +12,55 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_ -#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_ -#define EIGEN_USE_CUSTOM_THREAD_POOL -#define EIGEN_USE_THREADS +// This is essentially unsupported/CXX11/Eigen/Tensor.h +// TODO(petewarden) - move this to a common location in Eigen itself. // clang-format off -#include +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_ + + +#include "Eigen/Core" + +#if defined(EIGEN_USE_SYCL) +#undef min +#undef max +#undef isnan +#undef isinf +#undef isfinite +#include +#include +#include +#include +#include +#endif +#include #include #include -#include + + + + + +#ifdef _WIN32 +typedef __int16 int16_t; +typedef unsigned __int16 uint16_t; +typedef __int32 int32_t; +typedef unsigned __int32 uint32_t; +typedef __int64 int64_t; +typedef unsigned __int64 uint64_t; +#include +#else +#include +#include +#endif + +#if __cplusplus > 199711 || EIGEN_COMP_MSVC >= 1900 #include -#include -#include // NOLINT(build/c++11) -#include // NOLINT(build/c++11) -#include // NOLINT(build/c++11) -#include +#endif #ifdef _WIN32 #include @@ -40,58 +70,53 @@ limitations under the License. #include #endif +// #if defined(EIGEN_USE_LIBXSMM) +// #include "libxsmm.h" +// #endif -// Because some programs may link Eigen in through other frameworks with -// different flags, we can run into multiple definition issues if we don't have -// a private namespace for our versions. This is a nasty hack, but a similar -// approach is used elsewhere to handle the problem, so it should be stable. -#define Eigen EigenForTFLite +#ifdef EIGEN_USE_THREADS +#include "third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool" +#endif -#include "Eigen/src/Core/util/StaticAssert.h" -#include "unsupported/Eigen/CXX11/Core" -#include "unsupported/Eigen/SpecialFunctions" #include "Eigen/src/Core/util/DisableStupidWarnings.h" -#include "Eigen/Core" +#include "third_party/eigen3/unsupported/Eigen/SpecialFunctions" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/util/CXX11Meta.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/util/MaxSizeVector.h" -// Beware: the order of the include matters to some compilers. For example -// TensorIndexList.h should be included before TensorDimensions.h in order to -// use index lists to encode tensor dimensions when compiling with llvm. -// We're defining this ourselves rather than using the Eigen Tensor header file -// so that we can alter the macro definition of TENSOR_CONTRACTION_DISPATCH to -// reduce binary size. 
+ +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorMacros.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorCostModel.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/ThreadPoolInterface.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceType.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorNonBlockingThreadPool.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceDefault.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceGpu.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceSycl.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDimensionList.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorInitializer.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorUInt128.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h" - -#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorStats.h" - #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h" - +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReductionGpu.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h" - #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMappers.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h" + #undef TENSOR_CONTRACTION_DISPATCH #define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS) \ if (this->m_lhs_inner_dim_contiguous && \ @@ -102,8 +127,9 @@ limitations under the License. 
eigen_assert(false && "Unsupported contraction formats"); \ } + #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h" @@ -125,19 +151,18 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorScan.h" - +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorTrace.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorSycl.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h" - #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/Tensor.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorFixedSize.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h" - -#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReductionCuda.h" - #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h" #include "Eigen/src/Core/util/ReenableStupidWarnings.h" + + #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_ diff --git a/tensorflow/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h b/tensorflow/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h index 5e83b7b846e..f5576fbff70 100644 --- a/tensorflow/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h +++ b/tensorflow/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h @@ -94,7 +94,7 @@ typedef unsigned __int64 uint64_t; #include "unsupported/Eigen/CXX11/src/Tensor/TensorCostModel.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceDefault.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h" -#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceCuda.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceGpu.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceSycl.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorDimensionList.h" @@ -106,10 +106,11 @@ typedef unsigned __int64 uint64_t; #include "unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorBase.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h" -#include "unsupported/Eigen/CXX11/src/Tensor/TensorReductionCuda.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorReductionGpu.h" #include 
"unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h" @@ -128,7 +129,7 @@ typedef unsigned __int64 uint64_t; #include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h" -#include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h" diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h index 6f7031b36d2..00be5a9db83 100644 --- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h @@ -3151,12 +3151,12 @@ inline void LstmCell( // Combined memory state and final output calculation gemmlowp::ScopedProfilingLabel label2("MemoryStateAndFinalOutput"); output_state_map = - input_gate_sm.unaryExpr(Eigen::internal::scalar_sigmoid_op()) * + input_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op()) * new_input_sm.tanh() + - forget_gate_sm.unaryExpr(Eigen::internal::scalar_sigmoid_op()) * + forget_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op()) * prev_state_map; output_activ_map = - output_gate_sm.unaryExpr(Eigen::internal::scalar_sigmoid_op()) * + output_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op()) * output_state_map.tanh(); } @@ -4367,7 +4367,7 @@ inline void Logistic(const RuntimeShape& input_shape, const float* input_data, auto input_map = MapAsVector(input_data, input_shape); auto output_map = MapAsVector(output_data, output_shape); output_map.array() = - input_map.array().unaryExpr(Eigen::internal::scalar_sigmoid_op()); + input_map.array().unaryExpr(Eigen::internal::scalar_logistic_op()); } // Convenience version that allows, for example, generated-code calls to be diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 781c3f0a18a..6c93f977291 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -1570,7 +1570,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase): if not context.executing_eagerly(): self.evaluate(variables.global_variables_initializer()) - self.assertAllEqual([[[[4.0]]]], self.evaluate(y)) + self.assertAllClose([[[[4.0]]]], self.evaluate(y)) # Remove reference cycles in model test_util.dismantle_polymorphic_function(model) diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 7b3e17fbb98..18a0ba6b197 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -134,11 +134,12 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""): tf_http_archive( name = "eigen_archive", build_file = clean_dep("//third_party:eigen.BUILD"), - sha256 = "d956415d784fa4e42b6a2a45c32556d6aec9d0a3d8ef48baee2522ab762556a9", - strip_prefix = "eigen-eigen-fd6845384b86", + patch_file = clean_dep("//third_party:eigen_reshaped.patch"), + sha256 = "d66cec3b54b3dfaa4666c1d49481a7197f93fc078cd53c54e2b4a8893a529c9f", + strip_prefix = "eigen-eigen-b4890dc6bc34", urls = [ - "https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/fd6845384b86.tar.gz", - "https://bitbucket.org/eigen/eigen/get/fd6845384b86.tar.gz", + "https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/b4890dc6bc34.tar.gz", 
+ "https://bitbucket.org/eigen/eigen/get/b4890dc6bc34.tar.gz", ], ) diff --git a/third_party/eigen.BUILD b/third_party/eigen.BUILD index 759f8a9be92..194a2272d54 100644 --- a/third_party/eigen.BUILD +++ b/third_party/eigen.BUILD @@ -65,6 +65,7 @@ cc_library( # code. We use it, but we do not rely on it, as evidenced above. "EIGEN_MPL2_ONLY", "EIGEN_MAX_ALIGN_BYTES=64", + "EIGEN_HAS_TYPE_TRAITS=0", ], includes = ["."], visibility = ["//visibility:public"], diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h index 5ab36649187..ff359cedced 100644 --- a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h +++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h @@ -249,9 +249,7 @@ EIGEN_STRONG_INLINE QInt32& operator/=(QInt32& a, const QInt32 b) { a.value /= b.value; return a; } -EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a) { - return -a.value; -} +EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a) { return -a.value; } // Scaling QInt32 by double. We do the arithmetic in double because // float only has 23 bits of mantissa, so casting QInt32 to float might reduce diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h index e6f4080ae12..8477933e1ba 100644 --- a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h +++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h @@ -15,11 +15,9 @@ namespace internal { // Accumulate the product of 2 QInt8 inputs on 32 bits to prevent // overflows -template<> struct scalar_product_traits -{ - enum { - Defined = 1 - }; +template <> +struct scalar_product_traits { + enum { Defined = 1 }; typedef QInt32 ReturnType; }; @@ -33,11 +31,9 @@ struct scalar_product_traits { // Accumulate the product of QInt8 inputs with QUint8 inputs on 32 bits // to prevent overflows -template<> struct scalar_product_traits -{ - enum { - Defined = 1 - }; +template <> +struct scalar_product_traits { + enum { Defined = 1 }; typedef QInt32 ReturnType; }; @@ -47,14 +43,16 @@ template<> struct scalar_product_traits // signed 8bit integers #ifndef EIGEN_USE_OPTIMIZED_INT8_INT8_MAT_MAT_PRODUCT -template -class gebp_traits -{ -public: +template +class gebp_traits { + public: typedef QInt8 LhsScalar; typedef QInt8 RhsScalar; typedef QInt32 ResScalar; + typedef typename packet_traits::type LhsPacket; + typedef LhsPacket LhsPacket4Packing; + enum { // register block size along the M and N directions // One for the current implementation @@ -68,22 +66,24 @@ public: }; // The signed 8bit Mat-Mat product itself. 
-template -struct gebp_kernel -{ +template +struct gebp_kernel { EIGEN_DONT_INLINE - void operator()(const DataMapper& res, const QInt8* blockA, const QInt8* blockB, - Index rows, Index depth, Index cols, QInt32 alpha, - Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); + void operator()(const DataMapper& res, const QInt8* blockA, + const QInt8* blockB, Index rows, Index depth, Index cols, + QInt32 alpha, Index strideA = -1, Index strideB = -1, + Index offsetA = 0, Index offsetB = 0); }; -template -EIGEN_DONT_INLINE -void gebp_kernel -::operator()(const DataMapper& res, const QInt8* blockA, const QInt8* blockB, - Index rows, Index depth, Index cols, QInt32 alpha, - Index strideA, Index strideB, Index offsetA, Index offsetB) -{ +template +EIGEN_DONT_INLINE void gebp_kernel:: +operator()(const DataMapper& res, const QInt8* blockA, const QInt8* blockB, + Index rows, Index depth, Index cols, QInt32 alpha, Index strideA, + Index strideB, Index offsetA, Index offsetB) { EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE); EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE); @@ -113,18 +113,19 @@ void gebp_kernel -class gebp_traits -{ -public: +template +class gebp_traits { + public: typedef QInt8 LhsScalar; typedef QUInt8 RhsScalar; typedef QInt32 ResScalar; + typedef typename packet_traits::type LhsPacket; + typedef LhsPacket LhsPacket4Packing; + enum { // register block size along the M and N directions // One for the current implementation @@ -138,22 +139,24 @@ public: }; // Mat-Mat product of a signed 8bit lhs with an unsigned 8bit rhs -template -struct gebp_kernel -{ +template +struct gebp_kernel { EIGEN_DONT_INLINE - void operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB, - Index rows, Index depth, Index cols, QInt32 alpha, - Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); + void operator()(const DataMapper& res, const QInt8* blockA, + const QUInt8* blockB, Index rows, Index depth, Index cols, + QInt32 alpha, Index strideA = -1, Index strideB = -1, + Index offsetA = 0, Index offsetB = 0); }; -template -EIGEN_DONT_INLINE -void gebp_kernel -::operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB, - Index rows, Index depth, Index cols, QInt32 alpha, - Index strideA, Index strideB, Index offsetA, Index offsetB) -{ +template +EIGEN_DONT_INLINE void gebp_kernel:: +operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB, + Index rows, Index depth, Index cols, QInt32 alpha, Index strideA, + Index strideB, Index offsetA, Index offsetB) { EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE); EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE); @@ -183,18 +186,19 @@ void gebp_kernel -class gebp_traits -{ -public: +template +class gebp_traits { + public: typedef QUInt8 LhsScalar; typedef QInt8 RhsScalar; typedef QInt32 ResScalar; + typedef typename packet_traits::type LhsPacket; + typedef LhsPacket LhsPacket4Packing; + enum { // register block size along the M and N directions // One for the current implementation @@ -207,24 +211,25 @@ public: }; }; - // Mat-Mat product of an unsigned 8bit lhs with a signed 8bit rhs -template -struct gebp_kernel -{ +template +struct gebp_kernel { EIGEN_DONT_INLINE - void operator()(const DataMapper& res, const QUInt8* blockA, const QInt8* blockB, - Index rows, Index depth, Index cols, QInt32 alpha, - Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); + void operator()(const 
DataMapper& res, const QUInt8* blockA, + const QInt8* blockB, Index rows, Index depth, Index cols, + QInt32 alpha, Index strideA = -1, Index strideB = -1, + Index offsetA = 0, Index offsetB = 0); }; -template -EIGEN_DONT_INLINE -void gebp_kernel -::operator()(const DataMapper& res, const QUInt8* blockA, const QInt8* blockB, - Index rows, Index depth, Index cols, QInt32 alpha, - Index strideA, Index strideB, Index offsetA, Index offsetB) -{ +template +EIGEN_DONT_INLINE void gebp_kernel:: +operator()(const DataMapper& res, const QUInt8* blockA, const QInt8* blockB, + Index rows, Index depth, Index cols, QInt32 alpha, Index strideA, + Index strideB, Index offsetA, Index offsetB) { EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE); EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE); @@ -263,6 +268,9 @@ class gebp_traits { typedef QInt16 RhsScalar; typedef QInt32 ResScalar; + typedef typename packet_traits::type LhsPacket; + typedef LhsPacket LhsPacket4Packing; + enum { // register block size along the M and N directions // One for the current implementation diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h index 66532fb6002..8547dca1b32 100644 --- a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h +++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h @@ -28,6 +28,9 @@ class gebp_traits { typedef QInt16 RhsScalar; typedef QInt32 ResScalar; + typedef typename packet_traits::type LhsPacket; + typedef LhsPacket LhsPacket4Packing; + enum { // Define register blocking scheme. nr = 16, @@ -43,7 +46,7 @@ class gebp_traits { // Used by TensorContractionThreadPool, inputs must have dimensions that are // multiples of 32. 
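// --- Illustrative sketch (not part of the patch) ---------------------------
// The blocking specializations that follow assume contraction dimensions that
// are multiples of 16/32 and round them up, e.g. kc_(((k + 15) / 16) * 16) in
// the QInt16 TensorContractionBlocking below. The rounding is plain integer
// ceil-to-a-multiple; round_up here is an assumed helper for illustration,
// not a TensorFlow or Eigen API.
#include <cassert>
#include <cstdint>

constexpr std::int64_t round_up(std::int64_t n, std::int64_t multiple) {
  // Equivalent to ceil(n / multiple) * multiple for positive arguments.
  return ((n + multiple - 1) / multiple) * multiple;
}

int main() {
  assert(round_up(17, 16) == 32);  // matches ((17 + 15) / 16) * 16
  assert(round_up(32, 32) == 32);  // already-aligned values are unchanged
  assert(round_up(33, 32) == 64);
  return 0;
}
// ----------------------------------------------------------------------------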
template -class TensorContractionBlocking { +class TensorContractionBlocking { public: TensorContractionBlocking(Index k, Index m, Index n, Index num_threads = 1) : kc_(((k + 15) / 16) * 16), @@ -144,7 +147,7 @@ class gemm_blocking_space -struct gemm_pack_lhs { EIGEN_DONT_INLINE void operator()(QInt16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0, @@ -154,12 +157,14 @@ struct gemm_pack_lhs EIGEN_DONT_INLINE void gemm_pack_lhs:: + QInt16, ColMajor, Conjugate, PanelMode>:: operator()(QInt16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) { eigen_assert(stride == 0); eigen_assert(offset == 0); + typedef typename packet_traits::type Packet; + // Use alternate function for weird sizes if (rows % 16 != 0 || depth % 16 != 0) { assert(false && @@ -178,10 +183,10 @@ operator()(QInt16* blockA, const DataMapper& lhs, Index depth, Index rows, // Pack depth in sets of 4 for (Index k = 0; k < depth; k += 4) { // Load vectors - __m256i L_A = lhs.loadPacket(m, k); - __m256i L_B = lhs.loadPacket(m, k + 1); - __m256i L_C = lhs.loadPacket(m, k + 2); - __m256i L_D = lhs.loadPacket(m, k + 3); + __m256i L_A = lhs.template loadPacket(m, k); + __m256i L_B = lhs.template loadPacket(m, k + 1); + __m256i L_C = lhs.template loadPacket(m, k + 2); + __m256i L_D = lhs.template loadPacket(m, k + 3); // Rearrange the inputs as required by the kernel __m256i L_AB0_AB7 = _mm256_unpacklo_epi16(L_A, L_B); @@ -236,13 +241,15 @@ struct gemm_pack_rhs -EIGEN_DONT_INLINE void -gemm_pack_rhs:: +EIGEN_DONT_INLINE void gemm_pack_rhs:: operator()(QInt16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) { eigen_assert(stride == 0); eigen_assert(offset == 0); + typedef typename packet_traits::type Packet; + // Use alternate function for weird sizes if (cols % 16 != 0 || depth % 16 != 0) { assert(false && @@ -277,28 +284,28 @@ operator()(QInt16* blockB, const DataMapper& rhs, Index depth, Index cols, for (Index n = 0; n < cols; n += 16) { // Pack depth in sets of 16 for (Index k = 0; k < depth; k += 16) { - __m256i R_A = rhs.loadPacket(k, n); - __m256i R_B = rhs.loadPacket(k, n + 1); - __m256i R_C = rhs.loadPacket(k, n + 2); - __m256i R_D = rhs.loadPacket(k, n + 3); + __m256i R_A = rhs.template loadPacket(k, n); + __m256i R_B = rhs.template loadPacket(k, n + 1); + __m256i R_C = rhs.template loadPacket(k, n + 2); + __m256i R_D = rhs.template loadPacket(k, n + 3); PACK_STEP; - R_A = rhs.loadPacket(k, n + 4); - R_B = rhs.loadPacket(k, n + 5); - R_C = rhs.loadPacket(k, n + 6); - R_D = rhs.loadPacket(k, n + 7); + R_A = rhs.template loadPacket(k, n + 4); + R_B = rhs.template loadPacket(k, n + 5); + R_C = rhs.template loadPacket(k, n + 6); + R_D = rhs.template loadPacket(k, n + 7); PACK_STEP; - R_A = rhs.loadPacket(k, n + 8); - R_B = rhs.loadPacket(k, n + 9); - R_C = rhs.loadPacket(k, n + 10); - R_D = rhs.loadPacket(k, n + 11); + R_A = rhs.template loadPacket(k, n + 8); + R_B = rhs.template loadPacket(k, n + 9); + R_C = rhs.template loadPacket(k, n + 10); + R_D = rhs.template loadPacket(k, n + 11); PACK_STEP; - R_A = rhs.loadPacket(k, n + 12); - R_B = rhs.loadPacket(k, n + 13); - R_C = rhs.loadPacket(k, n + 14); - R_D = rhs.loadPacket(k, n + 15); + R_A = rhs.template loadPacket(k, n + 12); + R_B = rhs.template loadPacket(k, n + 13); + R_C = rhs.template loadPacket(k, n + 14); + R_D = rhs.template loadPacket(k, n + 15); PACK_STEP; blockB_256 += 12; @@ -476,9 +483,13 @@ operator()(const DataMapper& res, const QInt16* blockA, const QInt16* blockB, 
for (Index j = n; j < n + 16; j++) { LinearMapper r0 = res.getLinearMapper(m, j); LinearMapper r1 = res.getLinearMapper(m + 8, j); - - r0.storePacket(0, _mm256_add_epi32(blockO_256[i++], r0.loadPacket(0))); - r1.storePacket(0, _mm256_add_epi32(blockO_256[i++], r1.loadPacket(0))); + typedef typename packet_traits::type Packet; + r0.template storePacket( + 0, _mm256_add_epi32(blockO_256[i++], + r0.template loadPacket(0))); + r1.template storePacket( + 0, _mm256_add_epi32(blockO_256[i++], + r1.template loadPacket(0))); } // Zero the result block so it can be reused @@ -496,14 +507,16 @@ operator()(const DataMapper& res, const QInt16* blockA, const QInt16* blockB, #ifdef EIGEN_USE_OPTIMIZED_INT8_UINT8_MAT_MAT_PRODUCT // Define quantized traits -template -class gebp_traits -{ -public: +template +class gebp_traits { + public: typedef QInt8 LhsScalar; typedef QUInt8 RhsScalar; typedef QInt32 ResScalar; + typedef typename packet_traits::type LhsPacket; + typedef LhsPacket LhsPacket4Packing; + enum { // Define register blocking scheme. nr = 32, @@ -518,22 +531,28 @@ public: // Specialized blocking for quantized implementations. // Used by TensorContractionThreadPool, inputs must have dimensions that are // multiples of 32. -template -class TensorContractionBlocking, TensorContractionInputMapper, Index, ShardingType> { +template +class TensorContractionBlocking< + ResScalar, + TensorContractionInputMapper< + QInt8, Index, Lhs, LeftTensor, left_nocontract_t, left_contract_t, 32, + left_inner_dim_contiguous, left_inner_dim_reordered, LeftAlignment>, + TensorContractionInputMapper, + Index, ShardingType> { public: - - typedef QInt8 LhsScalar; + typedef QInt8 LhsScalar; typedef QUInt8 RhsScalar; - TensorContractionBlocking(Index k, Index m, Index n, Index num_threads = 1) : - kc_(k), mc_(m), nc_(n) - { + TensorContractionBlocking(Index k, Index m, Index n, Index num_threads = 1) + : kc_(k), mc_(m), nc_(n) { eigen_assert(m % 32 == 0); eigen_assert(k % 32 == 0); if (!k || !m || !n) { @@ -543,8 +562,7 @@ class TensorContractionBlocking class gemm_blocking_space @@ -633,42 +650,60 @@ class gemm_blocking_space +template struct gemm_pack_lhs_any; -template -struct gemm_pack_lhs_any { - EIGEN_DONT_INLINE void operator() - (QInt8* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0, Index offset = 0); +template +struct gemm_pack_lhs_any { + EIGEN_DONT_INLINE void operator()(QInt8* blockA, const DataMapper& lhs, + Index depth, Index rows, Index stride = 0, + Index offset = 0); }; -template +template struct gemm_pack_rhs_any; -template -struct gemm_pack_rhs_any { - EIGEN_DONT_INLINE void operator() - (QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0, Index offset = 0); +template +struct gemm_pack_rhs_any { + EIGEN_DONT_INLINE void operator()(QUInt8* blockB, const DataMapper& rhs, + Index depth, Index cols, Index stride = 0, + Index offset = 0); }; -template +template struct gebp_kernel_any; -template -struct gebp_kernel_any -{ +template +struct gebp_kernel_any { typedef typename DataMapper::LinearMapper LinearMapper; EIGEN_DONT_INLINE - void operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB, - Index rows, Index depth, Index cols, QInt32 alpha, - Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); + void operator()(const DataMapper& res, const QInt8* blockA, + const QUInt8* blockB, Index rows, Index depth, Index cols, + QInt32 alpha, Index strideA = -1, Index strideB = -1, + Index offsetA = 0, Index offsetB = 0); }; 
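// --- Illustrative sketch (not part of the patch) ---------------------------
// Many hunks in this file rewrite calls such as lhs.loadPacket(m, k) as
// lhs.template loadPacket<Packet>(m, k). The newer Eigen data mappers make
// loadPacket a member template taking the packet type explicitly, and when
// the object has a dependent type the 'template' disambiguator is required;
// without it the '<' parses as less-than. Mapper and pack_one below are
// hypothetical stand-ins used only to show the syntax.
#include <cstdint>

struct Mapper {
  template <typename Packet>
  Packet loadPacket(int i, int j) const {  // member template, like the mappers
    return static_cast<Packet>(i + j);     // stand-in for a real SIMD load
  }
};

template <typename DataMapper>
std::int32_t pack_one(const DataMapper& lhs) {
  // 'lhs' has a dependent type, so the explicit template argument needs the
  // 'template' keyword, exactly as in the rewritten calls in this diff.
  return lhs.template loadPacket<std::int32_t>(0, 1);
}

int main() { Mapper m; return pack_one(m) == 1 ? 0 : 1; }
// ----------------------------------------------------------------------------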
// Alternate implementations for any input sizes -template -EIGEN_DONT_INLINE void gemm_pack_lhs_any:: -operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) { +template +EIGEN_DONT_INLINE void gemm_pack_lhs_any:: +operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows, + Index stride, Index offset) { eigen_assert(stride == 0); eigen_assert(offset == 0); + typedef typename packet_traits::type Packet; + // Get vector pointer __m256i* blockA_256 = reinterpret_cast<__m256i*>(blockA); @@ -690,15 +725,15 @@ operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows, Index // Pack depth in sets of 8 for (Index k = 0; k < depth_8; k += 8) { // Load vectors - __m256i L_A = lhs.loadPacket(m, k); - __m256i L_B = lhs.loadPacket(m, k + 1); + __m256i L_A = lhs.template loadPacket(m, k); + __m256i L_B = lhs.template loadPacket(m, k + 1); // Interleave 8-bit elements __m256i L_AB0_AB16 = _mm256_unpacklo_epi8(L_A, L_B); __m256i L_AB8_AB24 = _mm256_unpackhi_epi8(L_A, L_B); - __m256i L_C = lhs.loadPacket(m, k + 2); - __m256i L_D = lhs.loadPacket(m, k + 3); + __m256i L_C = lhs.template loadPacket(m, k + 2); + __m256i L_D = lhs.template loadPacket(m, k + 3); __m256i L_CD0_CD16 = _mm256_unpacklo_epi8(L_C, L_D); __m256i L_CD8_CD24 = _mm256_unpackhi_epi8(L_C, L_D); @@ -719,12 +754,12 @@ operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows, Index _mm256_store_si256(blockA_256++, L_AD16); __m256i L_AD24 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x31); _mm256_store_si256(blockA_256++, L_AD24); - __m256i L_E = lhs.loadPacket(m, k + 4); - __m256i L_F = lhs.loadPacket(m, k + 5); + __m256i L_E = lhs.template loadPacket(m, k + 4); + __m256i L_F = lhs.template loadPacket(m, k + 5); __m256i L_EF0_EF16 = _mm256_unpacklo_epi8(L_E, L_F); __m256i L_EF8_EF24 = _mm256_unpackhi_epi8(L_E, L_F); - __m256i L_G = lhs.loadPacket(m, k + 6); - __m256i L_H = lhs.loadPacket(m, k + 7); + __m256i L_G = lhs.template loadPacket(m, k + 6); + __m256i L_H = lhs.template loadPacket(m, k + 7); __m256i L_GH0_GH16 = _mm256_unpacklo_epi8(L_G, L_H); __m256i L_GH8_GH24 = _mm256_unpackhi_epi8(L_G, L_H); __m256i L_EH0_EH16 = _mm256_unpacklo_epi16(L_EF0_EF16, L_GH0_GH16); @@ -745,76 +780,76 @@ operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows, Index if (depth_8 < depth) { __m256i L_A, L_B, L_C, L_D, L_E, L_F, L_G, L_H; switch (depth - depth_8) { - case 1: - L_A = lhs.loadPacket(m, depth_8); - L_B = _mm256_setzero_si256(); - L_C = _mm256_setzero_si256(); - L_D = _mm256_setzero_si256(); - L_E = _mm256_setzero_si256(); - L_F = _mm256_setzero_si256(); - L_G = _mm256_setzero_si256(); - L_H = _mm256_setzero_si256(); - break; - case 2: - L_A = lhs.loadPacket(m, depth_8); - L_B = lhs.loadPacket(m, depth_8 + 1); - L_C = _mm256_setzero_si256(); - L_D = _mm256_setzero_si256(); - L_E = _mm256_setzero_si256(); - L_F = _mm256_setzero_si256(); - L_G = _mm256_setzero_si256(); - L_H = _mm256_setzero_si256(); - break; - case 3: - L_A = lhs.loadPacket(m, depth_8); - L_B = lhs.loadPacket(m, depth_8 + 1); - L_C = lhs.loadPacket(m, depth_8 + 2); - L_D = _mm256_setzero_si256(); - L_E = _mm256_setzero_si256(); - L_F = _mm256_setzero_si256(); - L_G = _mm256_setzero_si256(); - L_H = _mm256_setzero_si256(); - break; - case 4: - L_A = lhs.loadPacket(m, depth_8); - L_B = lhs.loadPacket(m, depth_8 + 1); - L_C = lhs.loadPacket(m, depth_8 + 2); - L_D = lhs.loadPacket(m, depth_8 + 3); - L_E = _mm256_setzero_si256(); - L_F = _mm256_setzero_si256(); - L_G = 
_mm256_setzero_si256(); - L_H = _mm256_setzero_si256(); - break; - case 5: - L_A = lhs.loadPacket(m, depth_8); - L_B = lhs.loadPacket(m, depth_8 + 1); - L_C = lhs.loadPacket(m, depth_8 + 2); - L_D = lhs.loadPacket(m, depth_8 + 3); - L_E = lhs.loadPacket(m, depth_8 + 4); - L_F = _mm256_setzero_si256(); - L_G = _mm256_setzero_si256(); - L_H = _mm256_setzero_si256(); - break; - case 6: - L_A = lhs.loadPacket(m, depth_8); - L_B = lhs.loadPacket(m, depth_8 + 1); - L_C = lhs.loadPacket(m, depth_8 + 2); - L_D = lhs.loadPacket(m, depth_8 + 3); - L_E = lhs.loadPacket(m, depth_8 + 4); - L_F = lhs.loadPacket(m, depth_8 + 5); - L_G = _mm256_setzero_si256(); - L_H = _mm256_setzero_si256(); - break; - case 7: - L_A = lhs.loadPacket(m, depth_8); - L_B = lhs.loadPacket(m, depth_8 + 1); - L_C = lhs.loadPacket(m, depth_8 + 2); - L_D = lhs.loadPacket(m, depth_8 + 3); - L_E = lhs.loadPacket(m, depth_8 + 4); - L_F = lhs.loadPacket(m, depth_8 + 5); - L_G = lhs.loadPacket(m, depth_8 + 6); - L_H = _mm256_setzero_si256(); - break; + case 1: + L_A = lhs.template loadPacket(m, depth_8); + L_B = _mm256_setzero_si256(); + L_C = _mm256_setzero_si256(); + L_D = _mm256_setzero_si256(); + L_E = _mm256_setzero_si256(); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + break; + case 2: + L_A = lhs.template loadPacket(m, depth_8); + L_B = lhs.template loadPacket(m, depth_8 + 1); + L_C = _mm256_setzero_si256(); + L_D = _mm256_setzero_si256(); + L_E = _mm256_setzero_si256(); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + break; + case 3: + L_A = lhs.template loadPacket(m, depth_8); + L_B = lhs.template loadPacket(m, depth_8 + 1); + L_C = lhs.template loadPacket(m, depth_8 + 2); + L_D = _mm256_setzero_si256(); + L_E = _mm256_setzero_si256(); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + break; + case 4: + L_A = lhs.template loadPacket(m, depth_8); + L_B = lhs.template loadPacket(m, depth_8 + 1); + L_C = lhs.template loadPacket(m, depth_8 + 2); + L_D = lhs.template loadPacket(m, depth_8 + 3); + L_E = _mm256_setzero_si256(); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + break; + case 5: + L_A = lhs.template loadPacket(m, depth_8); + L_B = lhs.template loadPacket(m, depth_8 + 1); + L_C = lhs.template loadPacket(m, depth_8 + 2); + L_D = lhs.template loadPacket(m, depth_8 + 3); + L_E = lhs.template loadPacket(m, depth_8 + 4); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + break; + case 6: + L_A = lhs.template loadPacket(m, depth_8); + L_B = lhs.template loadPacket(m, depth_8 + 1); + L_C = lhs.template loadPacket(m, depth_8 + 2); + L_D = lhs.template loadPacket(m, depth_8 + 3); + L_E = lhs.template loadPacket(m, depth_8 + 4); + L_F = lhs.template loadPacket(m, depth_8 + 5); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + break; + case 7: + L_A = lhs.template loadPacket(m, depth_8); + L_B = lhs.template loadPacket(m, depth_8 + 1); + L_C = lhs.template loadPacket(m, depth_8 + 2); + L_D = lhs.template loadPacket(m, depth_8 + 3); + L_E = lhs.template loadPacket(m, depth_8 + 4); + L_F = lhs.template loadPacket(m, depth_8 + 5); + L_G = lhs.template loadPacket(m, depth_8 + 6); + L_H = _mm256_setzero_si256(); + break; } // Interleave 8-bit elements @@ -875,21 +910,21 @@ operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows, Index __m256i L_G = 
_mm256_setzero_si256(); __m256i L_H = _mm256_setzero_si256(); for (Index m = 0; m < rows - rows_32; m++) { - QInt8* ptr = (QInt8*) &L_A; + QInt8* ptr = (QInt8*)&L_A; ptr[m] = lhs(rows_32 + m, k); - ptr = (QInt8*) &L_B; + ptr = (QInt8*)&L_B; ptr[m] = lhs(rows_32 + m, k + 1); - ptr = (QInt8*) &L_C; + ptr = (QInt8*)&L_C; ptr[m] = lhs(rows_32 + m, k + 2); - ptr = (QInt8*) &L_D; + ptr = (QInt8*)&L_D; ptr[m] = lhs(rows_32 + m, k + 3); - ptr = (QInt8*) &L_E; + ptr = (QInt8*)&L_E; ptr[m] = lhs(rows_32 + m, k + 4); - ptr = (QInt8*) &L_F; + ptr = (QInt8*)&L_F; ptr[m] = lhs(rows_32 + m, k + 5); - ptr = (QInt8*) &L_G; + ptr = (QInt8*)&L_G; ptr[m] = lhs(rows_32 + m, k + 6); - ptr = (QInt8*) &L_H; + ptr = (QInt8*)&L_H; ptr[m] = lhs(rows_32 + m, k + 7); } @@ -939,146 +974,146 @@ operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows, Index __m256i L_A, L_B, L_C, L_D, L_E, L_F, L_G, L_H; QInt8* ptr; switch (depth - depth_8) { - case 1: - L_A = _mm256_setzero_si256(); - L_B = _mm256_setzero_si256(); - L_C = _mm256_setzero_si256(); - L_D = _mm256_setzero_si256(); - L_E = _mm256_setzero_si256(); - L_F = _mm256_setzero_si256(); - L_G = _mm256_setzero_si256(); - L_H = _mm256_setzero_si256(); - for (Index m = 0; m < rows - rows_32; m++) { - QInt8* ptr = (QInt8*) &L_A; - ptr[m] = lhs(rows_32 + m, depth_8); - } - break; - case 2: - L_A = _mm256_setzero_si256(); - L_B = _mm256_setzero_si256(); - L_C = _mm256_setzero_si256(); - L_D = _mm256_setzero_si256(); - L_E = _mm256_setzero_si256(); - L_F = _mm256_setzero_si256(); - L_G = _mm256_setzero_si256(); - L_H = _mm256_setzero_si256(); - for (Index m = 0; m < rows - rows_32; m++) { - ptr = (QInt8*) &L_A; - ptr[m] = lhs(rows_32 + m, depth_8); - ptr = (QInt8*) &L_B; - ptr[m] = lhs(rows_32 + m, depth_8 + 1); - } - break; - case 3: - L_A = _mm256_setzero_si256(); - L_B = _mm256_setzero_si256(); - L_C = _mm256_setzero_si256(); - L_D = _mm256_setzero_si256(); - L_E = _mm256_setzero_si256(); - L_F = _mm256_setzero_si256(); - L_G = _mm256_setzero_si256(); - L_H = _mm256_setzero_si256(); - for (Index m = 0; m < rows - rows_32; m++) { - ptr = (QInt8*) &L_A; - ptr[m] = lhs(rows_32 + m, depth_8); - ptr = (QInt8*) &L_B; - ptr[m] = lhs(rows_32 + m, depth_8 + 1); - ptr = (QInt8*) &L_C; - ptr[m] = lhs(rows_32 + m, depth_8 + 2); - } - break; - case 4: - L_A = _mm256_setzero_si256(); - L_B = _mm256_setzero_si256(); - L_C = _mm256_setzero_si256(); - L_D = _mm256_setzero_si256(); - L_E = _mm256_setzero_si256(); - L_F = _mm256_setzero_si256(); - L_G = _mm256_setzero_si256(); - L_H = _mm256_setzero_si256(); - for (Index m = 0; m < rows - rows_32; m++) { - ptr = (QInt8*) &L_A; - ptr[m] = lhs(rows_32 + m, depth_8); - ptr = (QInt8*) &L_B; - ptr[m] = lhs(rows_32 + m, depth_8 + 1); - ptr = (QInt8*) &L_C; - ptr[m] = lhs(rows_32 + m, depth_8 + 2); - ptr = (QInt8*) &L_D; - ptr[m] = lhs(rows_32 + m, depth_8 + 3); - } - break; - case 5: - L_A = _mm256_setzero_si256(); - L_B = _mm256_setzero_si256(); - L_C = _mm256_setzero_si256(); - L_D = _mm256_setzero_si256(); - L_E = _mm256_setzero_si256(); - L_F = _mm256_setzero_si256(); - L_G = _mm256_setzero_si256(); - L_H = _mm256_setzero_si256(); - for (Index m = 0; m < rows - rows_32; m++) { - ptr = (QInt8*) &L_A; - ptr[m] = lhs(rows_32 + m, depth_8); - ptr = (QInt8*) &L_B; - ptr[m] = lhs(rows_32 + m, depth_8 + 1); - ptr = (QInt8*) &L_C; - ptr[m] = lhs(rows_32 + m, depth_8 + 2); - ptr = (QInt8*) &L_D; - ptr[m] = lhs(rows_32 + m, depth_8 + 3); - ptr = (QInt8*) &L_E; - ptr[m] = lhs(rows_32 + m, depth_8 + 4); - } - break; - case 6: - L_A = 
_mm256_setzero_si256(); - L_B = _mm256_setzero_si256(); - L_C = _mm256_setzero_si256(); - L_D = _mm256_setzero_si256(); - L_E = _mm256_setzero_si256(); - L_F = _mm256_setzero_si256(); - L_G = _mm256_setzero_si256(); - L_H = _mm256_setzero_si256(); - for (Index m = 0; m < rows - rows_32; m++) { - ptr = (QInt8*) &L_A; - ptr[m] = lhs(rows_32 + m, depth_8); - ptr = (QInt8*) &L_B; - ptr[m] = lhs(rows_32 + m, depth_8 + 1); - ptr = (QInt8*) &L_C; - ptr[m] = lhs(rows_32 + m, depth_8 + 2); - ptr = (QInt8*) &L_D; - ptr[m] = lhs(rows_32 + m, depth_8 + 3); - ptr = (QInt8*) &L_E; - ptr[m] = lhs(rows_32 + m, depth_8 + 4); - ptr = (QInt8*) &L_F; - ptr[m] = lhs(rows_32 + m, depth_8 + 5); - } - break; - case 7: - L_A = _mm256_setzero_si256(); - L_B = _mm256_setzero_si256(); - L_C = _mm256_setzero_si256(); - L_D = _mm256_setzero_si256(); - L_E = _mm256_setzero_si256(); - L_F = _mm256_setzero_si256(); - L_G = _mm256_setzero_si256(); - L_H = _mm256_setzero_si256(); - for (Index m = 0; m < rows - rows_32; m++) { - ptr = (QInt8*) &L_A; - ptr[m] = lhs(rows_32 + m, depth_8); - ptr = (QInt8*) &L_B; - ptr[m] = lhs(rows_32 + m, depth_8 + 1); - ptr = (QInt8*) &L_C; - ptr[m] = lhs(rows_32 + m, depth_8 + 2); - ptr = (QInt8*) &L_D; - ptr[m] = lhs(rows_32 + m, depth_8 + 3); - ptr = (QInt8*) &L_E; - ptr[m] = lhs(rows_32 + m, depth_8 + 4); - ptr = (QInt8*) &L_F; - ptr[m] = lhs(rows_32 + m, depth_8 + 5); - ptr = (QInt8*) &L_G; - ptr[m] = lhs(rows_32 + m, depth_8 + 6); - } - break; + case 1: + L_A = _mm256_setzero_si256(); + L_B = _mm256_setzero_si256(); + L_C = _mm256_setzero_si256(); + L_D = _mm256_setzero_si256(); + L_E = _mm256_setzero_si256(); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + for (Index m = 0; m < rows - rows_32; m++) { + QInt8* ptr = (QInt8*)&L_A; + ptr[m] = lhs(rows_32 + m, depth_8); + } + break; + case 2: + L_A = _mm256_setzero_si256(); + L_B = _mm256_setzero_si256(); + L_C = _mm256_setzero_si256(); + L_D = _mm256_setzero_si256(); + L_E = _mm256_setzero_si256(); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + for (Index m = 0; m < rows - rows_32; m++) { + ptr = (QInt8*)&L_A; + ptr[m] = lhs(rows_32 + m, depth_8); + ptr = (QInt8*)&L_B; + ptr[m] = lhs(rows_32 + m, depth_8 + 1); + } + break; + case 3: + L_A = _mm256_setzero_si256(); + L_B = _mm256_setzero_si256(); + L_C = _mm256_setzero_si256(); + L_D = _mm256_setzero_si256(); + L_E = _mm256_setzero_si256(); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + for (Index m = 0; m < rows - rows_32; m++) { + ptr = (QInt8*)&L_A; + ptr[m] = lhs(rows_32 + m, depth_8); + ptr = (QInt8*)&L_B; + ptr[m] = lhs(rows_32 + m, depth_8 + 1); + ptr = (QInt8*)&L_C; + ptr[m] = lhs(rows_32 + m, depth_8 + 2); + } + break; + case 4: + L_A = _mm256_setzero_si256(); + L_B = _mm256_setzero_si256(); + L_C = _mm256_setzero_si256(); + L_D = _mm256_setzero_si256(); + L_E = _mm256_setzero_si256(); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + for (Index m = 0; m < rows - rows_32; m++) { + ptr = (QInt8*)&L_A; + ptr[m] = lhs(rows_32 + m, depth_8); + ptr = (QInt8*)&L_B; + ptr[m] = lhs(rows_32 + m, depth_8 + 1); + ptr = (QInt8*)&L_C; + ptr[m] = lhs(rows_32 + m, depth_8 + 2); + ptr = (QInt8*)&L_D; + ptr[m] = lhs(rows_32 + m, depth_8 + 3); + } + break; + case 5: + L_A = _mm256_setzero_si256(); + L_B = _mm256_setzero_si256(); + L_C = _mm256_setzero_si256(); + L_D = _mm256_setzero_si256(); + L_E = 
_mm256_setzero_si256(); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + for (Index m = 0; m < rows - rows_32; m++) { + ptr = (QInt8*)&L_A; + ptr[m] = lhs(rows_32 + m, depth_8); + ptr = (QInt8*)&L_B; + ptr[m] = lhs(rows_32 + m, depth_8 + 1); + ptr = (QInt8*)&L_C; + ptr[m] = lhs(rows_32 + m, depth_8 + 2); + ptr = (QInt8*)&L_D; + ptr[m] = lhs(rows_32 + m, depth_8 + 3); + ptr = (QInt8*)&L_E; + ptr[m] = lhs(rows_32 + m, depth_8 + 4); + } + break; + case 6: + L_A = _mm256_setzero_si256(); + L_B = _mm256_setzero_si256(); + L_C = _mm256_setzero_si256(); + L_D = _mm256_setzero_si256(); + L_E = _mm256_setzero_si256(); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + for (Index m = 0; m < rows - rows_32; m++) { + ptr = (QInt8*)&L_A; + ptr[m] = lhs(rows_32 + m, depth_8); + ptr = (QInt8*)&L_B; + ptr[m] = lhs(rows_32 + m, depth_8 + 1); + ptr = (QInt8*)&L_C; + ptr[m] = lhs(rows_32 + m, depth_8 + 2); + ptr = (QInt8*)&L_D; + ptr[m] = lhs(rows_32 + m, depth_8 + 3); + ptr = (QInt8*)&L_E; + ptr[m] = lhs(rows_32 + m, depth_8 + 4); + ptr = (QInt8*)&L_F; + ptr[m] = lhs(rows_32 + m, depth_8 + 5); + } + break; + case 7: + L_A = _mm256_setzero_si256(); + L_B = _mm256_setzero_si256(); + L_C = _mm256_setzero_si256(); + L_D = _mm256_setzero_si256(); + L_E = _mm256_setzero_si256(); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + for (Index m = 0; m < rows - rows_32; m++) { + ptr = (QInt8*)&L_A; + ptr[m] = lhs(rows_32 + m, depth_8); + ptr = (QInt8*)&L_B; + ptr[m] = lhs(rows_32 + m, depth_8 + 1); + ptr = (QInt8*)&L_C; + ptr[m] = lhs(rows_32 + m, depth_8 + 2); + ptr = (QInt8*)&L_D; + ptr[m] = lhs(rows_32 + m, depth_8 + 3); + ptr = (QInt8*)&L_E; + ptr[m] = lhs(rows_32 + m, depth_8 + 4); + ptr = (QInt8*)&L_F; + ptr[m] = lhs(rows_32 + m, depth_8 + 5); + ptr = (QInt8*)&L_G; + ptr[m] = lhs(rows_32 + m, depth_8 + 6); + } + break; } // Interleave 8-bit elements @@ -1124,12 +1159,17 @@ operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows, Index } } -template -EIGEN_DONT_INLINE void gemm_pack_rhs_any:: -operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) { +template +EIGEN_DONT_INLINE void gemm_pack_rhs_any:: +operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols, + Index stride, Index offset) { eigen_assert(stride == 0); eigen_assert(offset == 0); + typedef typename packet_traits::type Packet; + // Get vector pointer __m256i* blockB_256 = reinterpret_cast<__m256i*>(blockB); @@ -1158,52 +1198,52 @@ operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols, Index for (Index n = 0; n < cols_32; n += 32) { // Pack depth in sets of 32 for (Index k = 0; k < depth_32; k += 32) { - __m256i R_A = rhs.loadPacket(k, n); - __m256i R_B = rhs.loadPacket(k, n + 1); - __m256i R_C = rhs.loadPacket(k, n + 2); - __m256i R_D = rhs.loadPacket(k, n + 3); + __m256i R_A = rhs.template loadPacket(k, n); + __m256i R_B = rhs.template loadPacket(k, n + 1); + __m256i R_C = rhs.template loadPacket(k, n + 2); + __m256i R_D = rhs.template loadPacket(k, n + 3); PACK_STEP; - R_A = rhs.loadPacket(k, n + 4); - R_B = rhs.loadPacket(k, n + 5); - R_C = rhs.loadPacket(k, n + 6); - R_D = rhs.loadPacket(k, n + 7); + R_A = rhs.template loadPacket(k, n + 4); + R_B = rhs.template loadPacket(k, n + 5); + R_C = rhs.template loadPacket(k, n + 6); + R_D = rhs.template loadPacket(k, n + 7); PACK_STEP; - R_A = rhs.loadPacket(k, 
n + 8); - R_B = rhs.loadPacket(k, n + 9); - R_C = rhs.loadPacket(k, n + 10); - R_D = rhs.loadPacket(k, n + 11); + R_A = rhs.template loadPacket(k, n + 8); + R_B = rhs.template loadPacket(k, n + 9); + R_C = rhs.template loadPacket(k, n + 10); + R_D = rhs.template loadPacket(k, n + 11); PACK_STEP; - R_A = rhs.loadPacket(k, n + 12); - R_B = rhs.loadPacket(k, n + 13); - R_C = rhs.loadPacket(k, n + 14); - R_D = rhs.loadPacket(k, n + 15); + R_A = rhs.template loadPacket(k, n + 12); + R_B = rhs.template loadPacket(k, n + 13); + R_C = rhs.template loadPacket(k, n + 14); + R_D = rhs.template loadPacket(k, n + 15); PACK_STEP; - R_A = rhs.loadPacket(k, n + 16); - R_B = rhs.loadPacket(k, n + 17); - R_C = rhs.loadPacket(k, n + 18); - R_D = rhs.loadPacket(k, n + 19); + R_A = rhs.template loadPacket(k, n + 16); + R_B = rhs.template loadPacket(k, n + 17); + R_C = rhs.template loadPacket(k, n + 18); + R_D = rhs.template loadPacket(k, n + 19); PACK_STEP; - R_A = rhs.loadPacket(k, n + 20); - R_B = rhs.loadPacket(k, n + 21); - R_C = rhs.loadPacket(k, n + 22); - R_D = rhs.loadPacket(k, n + 23); + R_A = rhs.template loadPacket(k, n + 20); + R_B = rhs.template loadPacket(k, n + 21); + R_C = rhs.template loadPacket(k, n + 22); + R_D = rhs.template loadPacket(k, n + 23); PACK_STEP; - R_A = rhs.loadPacket(k, n + 24); - R_B = rhs.loadPacket(k, n + 25); - R_C = rhs.loadPacket(k, n + 26); - R_D = rhs.loadPacket(k, n + 27); + R_A = rhs.template loadPacket(k, n + 24); + R_B = rhs.template loadPacket(k, n + 25); + R_C = rhs.template loadPacket(k, n + 26); + R_D = rhs.template loadPacket(k, n + 27); PACK_STEP; - R_A = rhs.loadPacket(k, n + 28); - R_B = rhs.loadPacket(k, n + 29); - R_C = rhs.loadPacket(k, n + 30); - R_D = rhs.loadPacket(k, n + 31); + R_A = rhs.template loadPacket(k, n + 28); + R_B = rhs.template loadPacket(k, n + 29); + R_C = rhs.template loadPacket(k, n + 30); + R_D = rhs.template loadPacket(k, n + 31); PACK_STEP; blockB_256 += 24; @@ -1216,13 +1256,13 @@ operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols, Index __m256i R_C = _mm256_setzero_si256(); __m256i R_D = _mm256_setzero_si256(); for (Index k = depth_32; k < depth; k++) { - ptr = (QUInt8*) &R_A; + ptr = (QUInt8*)&R_A; ptr[k - depth_32] = rhs(k, n); - ptr = (QUInt8*) &R_B; + ptr = (QUInt8*)&R_B; ptr[k - depth_32] = rhs(k, n + 1); - ptr = (QUInt8*) &R_C; + ptr = (QUInt8*)&R_C; ptr[k - depth_32] = rhs(k, n + 2); - ptr = (QUInt8*) &R_D; + ptr = (QUInt8*)&R_D; ptr[k - depth_32] = rhs(k, n + 3); } PACK_STEP; @@ -1232,13 +1272,13 @@ operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols, Index R_C = _mm256_setzero_si256(); R_D = _mm256_setzero_si256(); for (Index k = depth_32; k < depth; k++) { - ptr = (QUInt8*) &R_A; + ptr = (QUInt8*)&R_A; ptr[k - depth_32] = rhs(k, n + 4); - ptr = (QUInt8*) &R_B; + ptr = (QUInt8*)&R_B; ptr[k - depth_32] = rhs(k, n + 5); - ptr = (QUInt8*) &R_C; + ptr = (QUInt8*)&R_C; ptr[k - depth_32] = rhs(k, n + 6); - ptr = (QUInt8*) &R_D; + ptr = (QUInt8*)&R_D; ptr[k - depth_32] = rhs(k, n + 7); } PACK_STEP; @@ -1248,13 +1288,13 @@ operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols, Index R_C = _mm256_setzero_si256(); R_D = _mm256_setzero_si256(); for (Index k = depth_32; k < depth; k++) { - ptr = (QUInt8*) &R_A; + ptr = (QUInt8*)&R_A; ptr[k - depth_32] = rhs(k, n + 8); - ptr = (QUInt8*) &R_B; + ptr = (QUInt8*)&R_B; ptr[k - depth_32] = rhs(k, n + 9); - ptr = (QUInt8*) &R_C; + ptr = (QUInt8*)&R_C; ptr[k - depth_32] = rhs(k, n + 10); - ptr = (QUInt8*) &R_D; + ptr = 
(QUInt8*)&R_D; ptr[k - depth_32] = rhs(k, n + 11); } PACK_STEP; @@ -1264,13 +1304,13 @@ operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols, Index R_C = _mm256_setzero_si256(); R_D = _mm256_setzero_si256(); for (Index k = depth_32; k < depth; k++) { - ptr = (QUInt8*) &R_A; + ptr = (QUInt8*)&R_A; ptr[k - depth_32] = rhs(k, n + 12); - ptr = (QUInt8*) &R_B; + ptr = (QUInt8*)&R_B; ptr[k - depth_32] = rhs(k, n + 13); - ptr = (QUInt8*) &R_C; + ptr = (QUInt8*)&R_C; ptr[k - depth_32] = rhs(k, n + 14); - ptr = (QUInt8*) &R_D; + ptr = (QUInt8*)&R_D; ptr[k - depth_32] = rhs(k, n + 15); } PACK_STEP; @@ -1280,13 +1320,13 @@ operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols, Index R_C = _mm256_setzero_si256(); R_D = _mm256_setzero_si256(); for (Index k = depth_32; k < depth; k++) { - ptr = (QUInt8*) &R_A; + ptr = (QUInt8*)&R_A; ptr[k - depth_32] = rhs(k, n + 16); - ptr = (QUInt8*) &R_B; + ptr = (QUInt8*)&R_B; ptr[k - depth_32] = rhs(k, n + 17); - ptr = (QUInt8*) &R_C; + ptr = (QUInt8*)&R_C; ptr[k - depth_32] = rhs(k, n + 18); - ptr = (QUInt8*) &R_D; + ptr = (QUInt8*)&R_D; ptr[k - depth_32] = rhs(k, n + 19); } PACK_STEP; @@ -1296,13 +1336,13 @@ operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols, Index R_C = _mm256_setzero_si256(); R_D = _mm256_setzero_si256(); for (Index k = depth_32; k < depth; k++) { - ptr = (QUInt8*) &R_A; + ptr = (QUInt8*)&R_A; ptr[k - depth_32] = rhs(k, n + 20); - ptr = (QUInt8*) &R_B; + ptr = (QUInt8*)&R_B; ptr[k - depth_32] = rhs(k, n + 21); - ptr = (QUInt8*) &R_C; + ptr = (QUInt8*)&R_C; ptr[k - depth_32] = rhs(k, n + 22); - ptr = (QUInt8*) &R_D; + ptr = (QUInt8*)&R_D; ptr[k - depth_32] = rhs(k, n + 23); } PACK_STEP; @@ -1312,13 +1352,13 @@ operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols, Index R_C = _mm256_setzero_si256(); R_D = _mm256_setzero_si256(); for (Index k = depth_32; k < depth; k++) { - ptr = (QUInt8*) &R_A; + ptr = (QUInt8*)&R_A; ptr[k - depth_32] = rhs(k, n + 24); - ptr = (QUInt8*) &R_B; + ptr = (QUInt8*)&R_B; ptr[k - depth_32] = rhs(k, n + 25); - ptr = (QUInt8*) &R_C; + ptr = (QUInt8*)&R_C; ptr[k - depth_32] = rhs(k, n + 26); - ptr = (QUInt8*) &R_D; + ptr = (QUInt8*)&R_D; ptr[k - depth_32] = rhs(k, n + 27); } PACK_STEP; @@ -1328,13 +1368,13 @@ operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols, Index R_C = _mm256_setzero_si256(); R_D = _mm256_setzero_si256(); for (Index k = depth_32; k < depth; k++) { - ptr = (QUInt8*) &R_A; + ptr = (QUInt8*)&R_A; ptr[k - depth_32] = rhs(k, n + 28); - ptr = (QUInt8*) &R_B; + ptr = (QUInt8*)&R_B; ptr[k - depth_32] = rhs(k, n + 29); - ptr = (QUInt8*) &R_C; + ptr = (QUInt8*)&R_C; ptr[k - depth_32] = rhs(k, n + 30); - ptr = (QUInt8*) &R_D; + ptr = (QUInt8*)&R_D; ptr[k - depth_32] = rhs(k, n + 31); } PACK_STEP; @@ -1350,34 +1390,34 @@ operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols, Index Index n; for (n = cols_32; n < cols; n += 4) { switch (cols - n) { - case 1: - R_A = rhs.loadPacket(k, n); - R_B = _mm256_setzero_si256(); - R_C = _mm256_setzero_si256(); - R_D = _mm256_setzero_si256(); - PACK_STEP; - break; - case 2: - R_A = rhs.loadPacket(k, n); - R_B = rhs.loadPacket(k, n + 1); - R_C = _mm256_setzero_si256(); - R_D = _mm256_setzero_si256(); - PACK_STEP; - break; - case 3: - R_A = rhs.loadPacket(k, n); - R_B = rhs.loadPacket(k, n + 1); - R_C = rhs.loadPacket(k, n + 2); - R_D = _mm256_setzero_si256(); - PACK_STEP; - break; - default: - R_A = rhs.loadPacket(k, n); - R_B = rhs.loadPacket(k, n + 
1); - R_C = rhs.loadPacket(k, n + 2); - R_D = rhs.loadPacket(k, n + 3); - PACK_STEP; - break; + case 1: + R_A = rhs.template loadPacket(k, n); + R_B = _mm256_setzero_si256(); + R_C = _mm256_setzero_si256(); + R_D = _mm256_setzero_si256(); + PACK_STEP; + break; + case 2: + R_A = rhs.template loadPacket(k, n); + R_B = rhs.template loadPacket(k, n + 1); + R_C = _mm256_setzero_si256(); + R_D = _mm256_setzero_si256(); + PACK_STEP; + break; + case 3: + R_A = rhs.template loadPacket(k, n); + R_B = rhs.template loadPacket(k, n + 1); + R_C = rhs.template loadPacket(k, n + 2); + R_D = _mm256_setzero_si256(); + PACK_STEP; + break; + default: + R_A = rhs.template loadPacket(k, n); + R_B = rhs.template loadPacket(k, n + 1); + R_C = rhs.template loadPacket(k, n + 2); + R_D = rhs.template loadPacket(k, n + 3); + PACK_STEP; + break; } } @@ -1394,46 +1434,46 @@ operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols, Index __m256i R_C = _mm256_setzero_si256(); __m256i R_D = _mm256_setzero_si256(); switch (cols - n) { - case 1: - for (Index k = depth_32; k < depth; k++) { - ptr = (QUInt8*) &R_A; - ptr[k - depth_32] = rhs(k, n); - } - PACK_STEP; - break; - case 2: - for (Index k = depth_32; k < depth; k++) { - ptr = (QUInt8*) &R_A; - ptr[k - depth_32] = rhs(k, n); - ptr = (QUInt8*) &R_B; - ptr[k - depth_32] = rhs(k, n + 1); - } - PACK_STEP; - break; - case 3: - for (Index k = depth_32; k < depth; k++) { - ptr = (QUInt8*) &R_A; - ptr[k - depth_32] = rhs(k, n); - ptr = (QUInt8*) &R_B; - ptr[k - depth_32] = rhs(k, n + 1); - ptr = (QUInt8*) &R_C; - ptr[k - depth_32] = rhs(k, n + 2); - } - PACK_STEP; - break; - default: - for (Index k = depth_32; k < depth; k++) { - ptr = (QUInt8*) &R_A; - ptr[k - depth_32] = rhs(k, n); - ptr = (QUInt8*) &R_B; - ptr[k - depth_32] = rhs(k, n + 1); - ptr = (QUInt8*) &R_C; - ptr[k - depth_32] = rhs(k, n + 2); - ptr = (QUInt8*) &R_D; - ptr[k - depth_32] = rhs(k, n + 3); - } - PACK_STEP; - break; + case 1: + for (Index k = depth_32; k < depth; k++) { + ptr = (QUInt8*)&R_A; + ptr[k - depth_32] = rhs(k, n); + } + PACK_STEP; + break; + case 2: + for (Index k = depth_32; k < depth; k++) { + ptr = (QUInt8*)&R_A; + ptr[k - depth_32] = rhs(k, n); + ptr = (QUInt8*)&R_B; + ptr[k - depth_32] = rhs(k, n + 1); + } + PACK_STEP; + break; + case 3: + for (Index k = depth_32; k < depth; k++) { + ptr = (QUInt8*)&R_A; + ptr[k - depth_32] = rhs(k, n); + ptr = (QUInt8*)&R_B; + ptr[k - depth_32] = rhs(k, n + 1); + ptr = (QUInt8*)&R_C; + ptr[k - depth_32] = rhs(k, n + 2); + } + PACK_STEP; + break; + default: + for (Index k = depth_32; k < depth; k++) { + ptr = (QUInt8*)&R_A; + ptr[k - depth_32] = rhs(k, n); + ptr = (QUInt8*)&R_B; + ptr[k - depth_32] = rhs(k, n + 1); + ptr = (QUInt8*)&R_C; + ptr[k - depth_32] = rhs(k, n + 2); + ptr = (QUInt8*)&R_D; + ptr[k - depth_32] = rhs(k, n + 3); + } + PACK_STEP; + break; } } } @@ -1441,13 +1481,13 @@ operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols, Index #undef PACK_STEP } -template -EIGEN_DONT_INLINE -void gebp_kernel_any -::operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB, - Index rows, Index depth, Index cols, QInt32 alpha, - Index strideA, Index strideB, Index offsetA, Index offsetB) -{ +template +EIGEN_DONT_INLINE void gebp_kernel_any:: +operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB, + Index rows, Index depth, Index cols, QInt32 alpha, Index strideA, + Index strideB, Index offsetA, Index offsetB) { EIGEN_STATIC_ASSERT(!ConjugateLhs, 
YOU_MADE_A_PROGRAMMING_MISTAKE); EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE); eigen_assert(alpha.value == 1); @@ -1678,17 +1718,21 @@ void gebp_kernel_any::type Packet; + r0.template storePacket( + 0, _mm256_add_epi32(blockO_256[i++], + r0.template loadPacket(0))); + r1.template storePacket( + 0, _mm256_add_epi32(blockO_256[i++], + r1.template loadPacket(0))); + r2.template storePacket( + 0, _mm256_add_epi32(blockO_256[i++], + r2.template loadPacket(0))); + r3.template storePacket( + 0, _mm256_add_epi32(blockO_256[i++], + r3.template loadPacket(0))); } - } - else { + } else { for (Index j = n; j < cols; j++) { for (Index i = m; i < rows; i++) { res(i, j) = blockO[(j - n) * 32 + (i - m)]; @@ -1745,7 +1789,7 @@ void gebp_kernel_any -struct gemm_pack_lhs { EIGEN_DONT_INLINE void operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0, @@ -1755,15 +1799,18 @@ struct gemm_pack_lhs EIGEN_DONT_INLINE void gemm_pack_lhs:: + QInt8, ColMajor, Conjugate, PanelMode>:: operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) { eigen_assert(stride == 0); eigen_assert(offset == 0); + typedef typename packet_traits::type Packet; + // Use alternate function for weird sizes if (rows % 32 != 0 || depth % 32 != 0) { - gemm_pack_lhs_any lhs_pack; + gemm_pack_lhs_any lhs_pack; return lhs_pack(blockA, lhs, depth, rows, stride, offset); } @@ -1775,15 +1822,15 @@ operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows, // Pack depth in sets of 8 for (Index k = 0; k < depth; k += 8) { // Load vectors - __m256i L_A = lhs.loadPacket(m, k); - __m256i L_B = lhs.loadPacket(m, k + 1); + __m256i L_A = lhs.template loadPacket(m, k); + __m256i L_B = lhs.template loadPacket(m, k + 1); // Interleave 8-bit elements __m256i L_AB0_AB16 = _mm256_unpacklo_epi8(L_A, L_B); __m256i L_AB8_AB24 = _mm256_unpackhi_epi8(L_A, L_B); - __m256i L_C = lhs.loadPacket(m, k + 2); - __m256i L_D = lhs.loadPacket(m, k + 3); + __m256i L_C = lhs.template loadPacket(m, k + 2); + __m256i L_D = lhs.template loadPacket(m, k + 3); __m256i L_CD0_CD16 = _mm256_unpacklo_epi8(L_C, L_D); __m256i L_CD8_CD24 = _mm256_unpackhi_epi8(L_C, L_D); @@ -1804,12 +1851,12 @@ operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows, _mm256_store_si256(blockA_256++, L_AD16); __m256i L_AD24 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x31); _mm256_store_si256(blockA_256++, L_AD24); - __m256i L_E = lhs.loadPacket(m, k + 4); - __m256i L_F = lhs.loadPacket(m, k + 5); + __m256i L_E = lhs.template loadPacket(m, k + 4); + __m256i L_F = lhs.template loadPacket(m, k + 5); __m256i L_EF0_EF16 = _mm256_unpacklo_epi8(L_E, L_F); __m256i L_EF8_EF24 = _mm256_unpackhi_epi8(L_E, L_F); - __m256i L_G = lhs.loadPacket(m, k + 6); - __m256i L_H = lhs.loadPacket(m, k + 7); + __m256i L_G = lhs.template loadPacket(m, k + 6); + __m256i L_H = lhs.template loadPacket(m, k + 7); __m256i L_GH0_GH16 = _mm256_unpacklo_epi8(L_G, L_H); __m256i L_GH8_GH24 = _mm256_unpackhi_epi8(L_G, L_H); __m256i L_EH0_EH16 = _mm256_unpacklo_epi16(L_EF0_EF16, L_GH0_GH16); @@ -1868,9 +1915,12 @@ operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols, eigen_assert(stride == 0); eigen_assert(offset == 0); + typedef typename packet_traits::type Packet; + // Use alternate function for weird sizes if (cols % 32 != 0 || depth % 32 != 0) { - gemm_pack_rhs_any rhs_pack; + gemm_pack_rhs_any rhs_pack; return rhs_pack(blockB, rhs, depth, cols, stride, offset); } @@ -1898,52 
+1948,52 @@ operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols, for (Index n = 0; n < cols; n += 32) { // Pack depth in sets of 32 for (Index k = 0; k < depth; k += 32) { - __m256i R_A = rhs.loadPacket(k, n); - __m256i R_B = rhs.loadPacket(k, n + 1); - __m256i R_C = rhs.loadPacket(k, n + 2); - __m256i R_D = rhs.loadPacket(k, n + 3); + __m256i R_A = rhs.template loadPacket(k, n); + __m256i R_B = rhs.template loadPacket(k, n + 1); + __m256i R_C = rhs.template loadPacket(k, n + 2); + __m256i R_D = rhs.template loadPacket(k, n + 3); PACK_STEP; - R_A = rhs.loadPacket(k, n + 4); - R_B = rhs.loadPacket(k, n + 5); - R_C = rhs.loadPacket(k, n + 6); - R_D = rhs.loadPacket(k, n + 7); + R_A = rhs.template loadPacket(k, n + 4); + R_B = rhs.template loadPacket(k, n + 5); + R_C = rhs.template loadPacket(k, n + 6); + R_D = rhs.template loadPacket(k, n + 7); PACK_STEP; - R_A = rhs.loadPacket(k, n + 8); - R_B = rhs.loadPacket(k, n + 9); - R_C = rhs.loadPacket(k, n + 10); - R_D = rhs.loadPacket(k, n + 11); + R_A = rhs.template loadPacket(k, n + 8); + R_B = rhs.template loadPacket(k, n + 9); + R_C = rhs.template loadPacket(k, n + 10); + R_D = rhs.template loadPacket(k, n + 11); PACK_STEP; - R_A = rhs.loadPacket(k, n + 12); - R_B = rhs.loadPacket(k, n + 13); - R_C = rhs.loadPacket(k, n + 14); - R_D = rhs.loadPacket(k, n + 15); + R_A = rhs.template loadPacket(k, n + 12); + R_B = rhs.template loadPacket(k, n + 13); + R_C = rhs.template loadPacket(k, n + 14); + R_D = rhs.template loadPacket(k, n + 15); PACK_STEP; - R_A = rhs.loadPacket(k, n + 16); - R_B = rhs.loadPacket(k, n + 17); - R_C = rhs.loadPacket(k, n + 18); - R_D = rhs.loadPacket(k, n + 19); + R_A = rhs.template loadPacket(k, n + 16); + R_B = rhs.template loadPacket(k, n + 17); + R_C = rhs.template loadPacket(k, n + 18); + R_D = rhs.template loadPacket(k, n + 19); PACK_STEP; - R_A = rhs.loadPacket(k, n + 20); - R_B = rhs.loadPacket(k, n + 21); - R_C = rhs.loadPacket(k, n + 22); - R_D = rhs.loadPacket(k, n + 23); + R_A = rhs.template loadPacket(k, n + 20); + R_B = rhs.template loadPacket(k, n + 21); + R_C = rhs.template loadPacket(k, n + 22); + R_D = rhs.template loadPacket(k, n + 23); PACK_STEP; - R_A = rhs.loadPacket(k, n + 24); - R_B = rhs.loadPacket(k, n + 25); - R_C = rhs.loadPacket(k, n + 26); - R_D = rhs.loadPacket(k, n + 27); + R_A = rhs.template loadPacket(k, n + 24); + R_B = rhs.template loadPacket(k, n + 25); + R_C = rhs.template loadPacket(k, n + 26); + R_D = rhs.template loadPacket(k, n + 27); PACK_STEP; - R_A = rhs.loadPacket(k, n + 28); - R_B = rhs.loadPacket(k, n + 29); - R_C = rhs.loadPacket(k, n + 30); - R_D = rhs.loadPacket(k, n + 31); + R_A = rhs.template loadPacket(k, n + 28); + R_B = rhs.template loadPacket(k, n + 29); + R_C = rhs.template loadPacket(k, n + 30); + R_D = rhs.template loadPacket(k, n + 31); PACK_STEP; blockB_256 += 24; @@ -1953,24 +2003,26 @@ operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols, } // Perform the actual multiplication on packed inputs -template -struct gebp_kernel -{ +template +struct gebp_kernel { typedef typename DataMapper::LinearMapper LinearMapper; EIGEN_DONT_INLINE - void operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB, - Index rows, Index depth, Index cols, QInt32 alpha, - Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); + void operator()(const DataMapper& res, const QInt8* blockA, + const QUInt8* blockB, Index rows, Index depth, Index cols, + QInt32 alpha, Index strideA = -1, Index strideB = -1, + Index 
offsetA = 0, Index offsetB = 0); }; -template -EIGEN_DONT_INLINE -void gebp_kernel -::operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB, - Index rows, Index depth, Index cols, QInt32 alpha, - Index strideA, Index strideB, Index offsetA, Index offsetB) -{ +template +EIGEN_DONT_INLINE void gebp_kernel:: +operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB, + Index rows, Index depth, Index cols, QInt32 alpha, Index strideA, + Index strideB, Index offsetA, Index offsetB) { EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE); EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE); eigen_assert(alpha.value == 1); @@ -1986,8 +2038,10 @@ void gebp_kernel gebp; - return gebp(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); + gebp_kernel_any gebp; + return gebp(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, + offsetA, offsetB); } // Create result block @@ -2205,14 +2259,19 @@ void gebp_kernel::type Packet; + r0.template storePacket( + 0, _mm256_add_epi32(blockO_256[i++], + r0.template loadPacket(0))); + r1.template storePacket( + 0, _mm256_add_epi32(blockO_256[i++], + r1.template loadPacket(0))); + r2.template storePacket( + 0, _mm256_add_epi32(blockO_256[i++], + r2.template loadPacket(0))); + r3.template storePacket( + 0, _mm256_add_epi32(blockO_256[i++], + r3.template loadPacket(0))); } // Zero the result block so it can be reused diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h index 9cd31570231..9e0efae6c9b 100644 --- a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h +++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h @@ -14,15 +14,14 @@ namespace Eigen { namespace internal { - -// AVX2 optimized implementation of the case where the lhs is encoded using signed 8bit +// AVX2 optimized implementation of the case where the lhs is encoded using +// signed 8bit // integers and the rhs using unsigned 8bit integers. 
#ifdef EIGEN_USE_OPTIMIZED_INT8_UINT8_MAT_MAT_PRODUCT -template -class gebp_traits -{ -public: +template +class gebp_traits { + public: typedef QInt8 LhsScalar; typedef QUInt8 RhsScalar; typedef QInt32 ResScalar; @@ -40,22 +39,24 @@ public: }; // Mat-Mat product of a signed 8bit lhs with an unsigned 8bit rhs -template -struct gebp_kernel -{ +template +struct gebp_kernel { EIGEN_DONT_INLINE - void operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB, - Index rows, Index depth, Index cols, QInt32 alpha, - Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); + void operator()(const DataMapper& res, const QInt8* blockA, + const QUInt8* blockB, Index rows, Index depth, Index cols, + QInt32 alpha, Index strideA = -1, Index strideB = -1, + Index offsetA = 0, Index offsetB = 0); }; -template -EIGEN_DONT_INLINE -void gebp_kernel -::operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB, - Index rows, Index depth, Index cols, QInt32 alpha, - Index strideA, Index strideB, Index offsetA, Index offsetB) -{ +template +EIGEN_DONT_INLINE void gebp_kernel:: +operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB, + Index rows, Index depth, Index cols, QInt32 alpha, Index strideA, + Index strideB, Index offsetA, Index offsetB) { EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE); EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE); @@ -85,7 +86,6 @@ void gebp_kernel -struct general_matrix_vector_product -{ -EIGEN_DONT_INLINE static void run( - Index rows, Index cols, - const LhsMapper& lhs, - const RhsMapper& rhs, - QInt32* res, Index resIncr, - QInt8 alpha); +template +struct general_matrix_vector_product { + EIGEN_DONT_INLINE static void run(Index rows, Index cols, + const LhsMapper& lhs, const RhsMapper& rhs, + QInt32* res, Index resIncr, QInt8 alpha); }; -template -EIGEN_DONT_INLINE void general_matrix_vector_product::run( - Index rows, Index cols, - const LhsMapper& lhs, - const RhsMapper& rhs, - QInt32* res, Index resIncr, - QInt8 alpha) -{ +template +EIGEN_DONT_INLINE void general_matrix_vector_product< + Index, QInt8, LhsMapper, ColMajor, ConjugateLhs, QInt8, RhsMapper, + ConjugateRhs, Version>::run(Index rows, Index cols, const LhsMapper& lhs, + const RhsMapper& rhs, QInt32* res, + Index resIncr, QInt8 alpha) { eigen_assert(alpha.value == 1); eigen_assert(resIncr == 1); eigen_assert(rows > 0); @@ -78,26 +76,25 @@ EIGEN_DONT_INLINE void general_matrix_vector_product< } // Mat-Vec product -// The lhs is encoded using 8bit signed integers, the rhs using 8bit unsigned integers -template -struct general_matrix_vector_product -{ -EIGEN_DONT_INLINE static void run( - Index rows, Index cols, - const LhsMapper& lhs, - const RhsMapper& rhs, - QInt32* res, Index resIncr, - QUInt8 alpha); +// The lhs is encoded using 8bit signed integers, the rhs using 8bit unsigned +// integers +template +struct general_matrix_vector_product { + EIGEN_DONT_INLINE static void run(Index rows, Index cols, + const LhsMapper& lhs, const RhsMapper& rhs, + QInt32* res, Index resIncr, QUInt8 alpha); }; -template -EIGEN_DONT_INLINE void general_matrix_vector_product::run( - Index rows, Index cols, - const LhsMapper& lhs, - const RhsMapper& rhs, - QInt32* res, Index resIncr, - QUInt8 alpha) -{ +template +EIGEN_DONT_INLINE void general_matrix_vector_product< + Index, QInt8, LhsMapper, ColMajor, ConjugateLhs, QUInt8, RhsMapper, + ConjugateRhs, Version>::run(Index rows, Index cols, const LhsMapper& lhs, + const RhsMapper& rhs, 
QInt32* res, + Index resIncr, QUInt8 alpha) { eigen_assert(alpha.value == 1); eigen_assert(resIncr == 1); eigen_assert(rows > 0); @@ -110,28 +107,26 @@ EIGEN_DONT_INLINE void general_matrix_vector_product -struct general_matrix_vector_product -{ -EIGEN_DONT_INLINE static void run( - Index rows, Index cols, - const LhsMapper& lhs, - const RhsMapper& rhs, - QInt32* res, Index resIncr, - QInt8 alpha); +// The lhs is encoded using bit unsigned integers, the rhs using 8bit signed +// integers +template +struct general_matrix_vector_product { + EIGEN_DONT_INLINE static void run(Index rows, Index cols, + const LhsMapper& lhs, const RhsMapper& rhs, + QInt32* res, Index resIncr, QInt8 alpha); }; -template -EIGEN_DONT_INLINE void general_matrix_vector_product::run( - Index rows, Index cols, - const LhsMapper& lhs, - const RhsMapper& rhs, - QInt32* res, Index resIncr, - QInt8 alpha) -{ +template +EIGEN_DONT_INLINE void general_matrix_vector_product< + Index, QUInt8, LhsMapper, ColMajor, ConjugateLhs, QInt8, RhsMapper, + ConjugateRhs, Version>::run(Index rows, Index cols, const LhsMapper& lhs, + const RhsMapper& rhs, QInt32* res, + Index resIncr, QInt8 alpha) { eigen_assert(alpha.value == 1); eigen_assert(resIncr == 1); eigen_assert(rows > 0); diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h index 3abd4ee49c2..223ea4d58bf 100644 --- a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h +++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h @@ -8,24 +8,20 @@ #endif -inline int _mm256_extract_epi16_N0(const __m256i X) -{ - return _mm_extract_epi16(_mm256_extractf128_si256(X, 0 >> 3), 0 % 8); +inline int _mm256_extract_epi16_N0(const __m256i X) { + return _mm_extract_epi16(_mm256_extractf128_si256(X, 0 >> 3), 0 % 8); } -inline int _mm256_extract_epi16_N1(const __m256i X) -{ - return _mm_extract_epi16(_mm256_extractf128_si256(X, 1 >> 3), 1 % 8); +inline int _mm256_extract_epi16_N1(const __m256i X) { + return _mm_extract_epi16(_mm256_extractf128_si256(X, 1 >> 3), 1 % 8); } -inline int _mm256_extract_epi8_N0(const __m256i X) -{ - return _mm_extract_epi8(_mm256_extractf128_si256((X), 0 >> 4), 0 % 16); +inline int _mm256_extract_epi8_N0(const __m256i X) { + return _mm_extract_epi8(_mm256_extractf128_si256((X), 0 >> 4), 0 % 16); } -inline int _mm256_extract_epi8_N1(const __m256i X) -{ - return _mm_extract_epi8(_mm256_extractf128_si256((X), 1 >> 4), 1 % 16); +inline int _mm256_extract_epi8_N1(const __m256i X) { + return _mm_extract_epi8(_mm256_extractf128_si256((X), 1 >> 4), 1 % 16); } namespace Eigen { @@ -34,56 +30,56 @@ namespace internal { typedef struct Packet32q8i { __m256i val; operator __m256i() const { return val; } - Packet32q8i(); + Packet32q8i() : val(_mm256_setzero_si256()){}; Packet32q8i(__m256i val) : val(val) {} } Packet32q8i; typedef struct Packet16q16i { __m256i val; operator __m256i() const { return val; } - Packet16q16i(); + Packet16q16i() : val(_mm256_setzero_si256()){}; Packet16q16i(__m256i val) : val(val) {} } Packet16q16i; typedef struct Packet32q8u { __m256i val; operator __m256i() const { return val; } - Packet32q8u(); + Packet32q8u() : val(_mm256_setzero_si256()){}; Packet32q8u(__m256i val) : val(val) {} } Packet32q8u; typedef struct Packet16q8i { __m128i val; operator __m128i() const { return val; } - Packet16q8i(); + Packet16q8i() : val(_mm_setzero_si128()) {} Packet16q8i(__m128i val) : val(val) {} } Packet16q8i; typedef 
struct Packet16q8u { __m128i val; operator __m128i() const { return val; } - Packet16q8u(); + Packet16q8u() : val(_mm_setzero_si128()) {} Packet16q8u(__m128i val) : val(val) {} } Packet16q8u; typedef struct Packet8q16i { __m128i val; operator __m128i() const { return val; } - Packet8q16i(); + Packet8q16i() : val(_mm_setzero_si128()) {} Packet8q16i(__m128i val) : val(val) {} } Packet8q16i; typedef struct Packet8q32i { __m256i val; operator __m256i() const { return val; } - Packet8q32i(); + Packet8q32i() : val(_mm256_setzero_si256()){}; Packet8q32i(__m256i val) : val(val) {} } Packet8q32i; typedef struct Packet4q32i { __m128i val; operator __m128i() const { return val; } - Packet4q32i(); + Packet4q32i() : val(_mm_setzero_si128()) {} Packet4q32i(__m128i val) : val(val) {} } Packet4q32i; @@ -182,25 +178,25 @@ template <> struct unpacket_traits { typedef QInt8 type; typedef Packet16q8i half; - enum { size = 32, alignment=Aligned32 }; + enum { size = 32, alignment = Aligned32 }; }; template <> struct unpacket_traits { typedef QInt16 type; typedef Packet8q16i half; - enum { size = 16, alignment=Aligned32 }; + enum { size = 16, alignment = Aligned32 }; }; template <> struct unpacket_traits { typedef QUInt8 type; typedef Packet16q8u half; - enum { size = 32, alignment=Aligned32 }; + enum { size = 32, alignment = Aligned32 }; }; template <> struct unpacket_traits { typedef QInt32 type; typedef Packet4q32i half; - enum { size = 8, alignment=Aligned32 }; + enum { size = 8, alignment = Aligned32 }; }; // Unaligned load @@ -455,40 +451,47 @@ EIGEN_STRONG_INLINE QUInt8 predux_max(const Packet32q8u& a) { template <> EIGEN_STRONG_INLINE QInt8 predux_min(const Packet32q8i& a) { __m256i tmp = _mm256_min_epi8(a, _mm256_permute2f128_si256(a, a, 1)); - tmp = _mm256_min_epi8(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2))); + tmp = + _mm256_min_epi8(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2))); tmp = _mm256_min_epi8(tmp, _mm256_shuffle_epi32(tmp, 1)); - tmp = _mm256_min_epi8(tmp, _mm256_shufflelo_epi16(tmp, _MM_SHUFFLE(1, 0, 3, 2))); + tmp = _mm256_min_epi8(tmp, + _mm256_shufflelo_epi16(tmp, _MM_SHUFFLE(1, 0, 3, 2))); return std::min(_mm256_extract_epi8_N0(tmp), _mm256_extract_epi8_N1(tmp)); } template <> EIGEN_STRONG_INLINE QInt8 predux_max(const Packet32q8i& a) { __m256i tmp = _mm256_max_epi8(a, _mm256_permute2f128_si256(a, a, 1)); - tmp = _mm256_max_epi8(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2))); + tmp = + _mm256_max_epi8(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2))); tmp = _mm256_max_epi8(tmp, _mm256_shuffle_epi32(tmp, 1)); - tmp = _mm256_max_epi8(tmp, _mm256_shufflelo_epi16(tmp, _MM_SHUFFLE(1, 0, 3, 2))); + tmp = _mm256_max_epi8(tmp, + _mm256_shufflelo_epi16(tmp, _MM_SHUFFLE(1, 0, 3, 2))); return std::max(_mm256_extract_epi8_N0(tmp), _mm256_extract_epi8_N1(tmp)); } // Vectorized scaling of Packet32q8i by float. 
-template<> +template <> struct scalar_product_op : binary_op_base { typedef typename ScalarBinaryOpTraits::ReturnType result_type; #ifndef EIGEN_SCALAR_BINARY_OP_PLUGIN EIGEN_EMPTY_STRUCT_CTOR(scalar_product_op) #else - scalar_product_op() { - EIGEN_SCALAR_BINARY_OP_PLUGIN - } + scalar_product_op() { EIGEN_SCALAR_BINARY_OP_PLUGIN } #endif - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const QInt32& a, const double& b) const { return a * b; } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type + operator()(const QInt32& a, const double& b) const { + return a * b; + } - EIGEN_STRONG_INLINE const Packet8q32i packetOp(const Packet8q32i& a, const double& b) const { + EIGEN_STRONG_INLINE const Packet8q32i packetOp(const Packet8q32i& a, + const double& b) const { __m256d scale = _mm256_set1_pd(b); __m256d a_lo = _mm256_cvtepi32_pd(_mm256_castsi256_si128(a)); __m128i result_lo = _mm256_cvtpd_epi32(_mm256_mul_pd(scale, a_lo)); __m256d a_hi = _mm256_cvtepi32_pd(_mm256_extracti128_si256(a, 1)); __m128i result_hi = _mm256_cvtpd_epi32(_mm256_mul_pd(scale, a_hi)); - return _mm256_insertf128_si256(_mm256_castsi128_si256(result_lo), result_hi, 1); + return _mm256_insertf128_si256(_mm256_castsi128_si256(result_lo), result_hi, + 1); } }; diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h index 2092ce1d4c9..84750c1945a 100644 --- a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h +++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h @@ -127,25 +127,25 @@ template <> struct unpacket_traits { typedef QInt8 type; typedef Packet32q8i half; - enum { size = 64, alignment=Aligned64 }; + enum { size = 64, alignment = Aligned64 }; }; template <> struct unpacket_traits { typedef QInt16 type; typedef Packet16q16i half; - enum { size = 32, alignment=Aligned64 }; + enum { size = 32, alignment = Aligned64 }; }; template <> struct unpacket_traits { typedef QUInt8 type; typedef Packet32q8u half; - enum { size = 64, alignment=Aligned64 }; + enum { size = 64, alignment = Aligned64 }; }; template <> struct unpacket_traits { typedef QInt32 type; typedef Packet8q32i half; - enum { size = 16, alignment=Aligned64 }; + enum { size = 16, alignment = Aligned64 }; }; // Unaligned load @@ -244,7 +244,7 @@ EIGEN_STRONG_INLINE QInt32 pfirst(const Packet16q32i& a) { template <> EIGEN_STRONG_INLINE QUInt8 pfirst(const Packet64q8u& a) { return static_cast( - _mm_extract_epi8(_mm512_extracti32x4_epi32(a.val, 0), 0)); + _mm_extract_epi8(_mm512_extracti32x4_epi32(a.val, 0), 0)); } template <> EIGEN_STRONG_INLINE QInt8 pfirst(const Packet64q8i& a) { @@ -410,9 +410,7 @@ EIGEN_STRONG_INLINE QInt32 predux_min(const Packet16q32i& a) { _mm_min_epi32(_mm_min_epi32(lane0, lane1), _mm_min_epi32(lane2, lane3)); res = _mm_min_epi32(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2))); return pfirst( - _mm_min_epi32( - res, - _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1)))); + _mm_min_epi32(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1)))); } template <> EIGEN_STRONG_INLINE QInt32 predux_max(const Packet16q32i& a) { @@ -424,9 +422,7 @@ EIGEN_STRONG_INLINE QInt32 predux_max(const Packet16q32i& a) { _mm_max_epi32(_mm_max_epi32(lane0, lane1), _mm_max_epi32(lane2, lane3)); res = _mm_max_epi32(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2))); return pfirst( - _mm_max_epi32( - res, - _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1)))); + _mm_max_epi32(res, 
_mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1)))); } template <> EIGEN_STRONG_INLINE QInt16 predux_min(const Packet32q16i& a) { @@ -437,13 +433,10 @@ EIGEN_STRONG_INLINE QInt16 predux_min(const Packet32q16i& a) { Packet4i res = _mm_min_epi16(_mm_min_epi16(lane0, lane1), _mm_min_epi16(lane2, lane3)); res = _mm_min_epi16(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2))); - std::uint32_t w = - pfirst( - _mm_min_epi16(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1)))); - return std::min({ - static_cast(w >> 16), - static_cast(w) - }); + std::uint32_t w = pfirst( + _mm_min_epi16(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1)))); + return std::min( + {static_cast(w >> 16), static_cast(w)}); } template <> EIGEN_STRONG_INLINE QInt16 predux_max(const Packet32q16i& a) { @@ -454,13 +447,10 @@ EIGEN_STRONG_INLINE QInt16 predux_max(const Packet32q16i& a) { Packet4i res = _mm_max_epi16(_mm_max_epi16(lane0, lane1), _mm_max_epi16(lane2, lane3)); res = _mm_max_epi16(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2))); - std::uint32_t w = - pfirst( - _mm_max_epi16(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1)))); - return std::max({ - static_cast(w >> 16), - static_cast(w) - }); + std::uint32_t w = pfirst( + _mm_max_epi16(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1)))); + return std::max( + {static_cast(w >> 16), static_cast(w)}); } template <> EIGEN_STRONG_INLINE QUInt8 predux_min(const Packet64q8u& a) { @@ -471,15 +461,11 @@ EIGEN_STRONG_INLINE QUInt8 predux_min(const Packet64q8u& a) { Packet4i res = _mm_min_epu8(_mm_min_epu8(lane0, lane1), _mm_min_epu8(lane2, lane3)); res = _mm_min_epu8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2))); - std::uint32_t w = - pfirst( - _mm_min_epu8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1)))); - return std::min({ - static_cast(w >> 24), - static_cast(w >> 16), - static_cast(w >> 8), - static_cast(w) - }); + std::uint32_t w = pfirst( + _mm_min_epu8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1)))); + return std::min( + {static_cast(w >> 24), static_cast(w >> 16), + static_cast(w >> 8), static_cast(w)}); } template <> EIGEN_STRONG_INLINE QUInt8 predux_max(const Packet64q8u& a) { @@ -490,15 +476,11 @@ EIGEN_STRONG_INLINE QUInt8 predux_max(const Packet64q8u& a) { Packet4i res = _mm_max_epu8(_mm_max_epu8(lane0, lane1), _mm_max_epu8(lane2, lane3)); res = _mm_max_epu8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2))); - std::uint32_t w = - pfirst( - _mm_max_epu8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1)))); - return std::max({ - static_cast(w >> 24), - static_cast(w >> 16), - static_cast(w >> 8), - static_cast(w) - }); + std::uint32_t w = pfirst( + _mm_max_epu8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1)))); + return std::max( + {static_cast(w >> 24), static_cast(w >> 16), + static_cast(w >> 8), static_cast(w)}); } template <> EIGEN_STRONG_INLINE QInt8 predux_min(const Packet64q8i& a) { @@ -509,15 +491,11 @@ EIGEN_STRONG_INLINE QInt8 predux_min(const Packet64q8i& a) { Packet4i res = _mm_min_epi8(_mm_min_epi8(lane0, lane1), _mm_min_epi8(lane2, lane3)); res = _mm_min_epi8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2))); - std::uint32_t w = - pfirst( - _mm_min_epi8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1)))); - return std::min({ - static_cast(w >> 24), - static_cast(w >> 16), - static_cast(w >> 8), - static_cast(w) - }); + std::uint32_t w = pfirst( + _mm_min_epi8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1)))); + return std::min( + {static_cast(w >> 24), static_cast(w >> 16), + 
static_cast(w >> 8), static_cast(w)}); } template <> EIGEN_STRONG_INLINE QInt8 predux_max(const Packet64q8i& a) { @@ -528,15 +506,11 @@ EIGEN_STRONG_INLINE QInt8 predux_max(const Packet64q8i& a) { Packet4i res = _mm_max_epi8(_mm_max_epi8(lane0, lane1), _mm_max_epi8(lane2, lane3)); res = _mm_max_epi8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2))); - std::uint32_t w = - pfirst( - _mm_max_epi8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1)))); - return std::min({ - static_cast(w >> 24), - static_cast(w >> 16), - static_cast(w >> 8), - static_cast(w) - }); + std::uint32_t w = pfirst( + _mm_max_epi8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1)))); + return std::min( + {static_cast(w >> 24), static_cast(w >> 16), + static_cast(w >> 8), static_cast(w)}); } } // end namespace internal diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h index a09eac67070..d3b02402971 100644 --- a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h +++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h @@ -33,28 +33,23 @@ struct type_casting_traits { }; template <> -EIGEN_STRONG_INLINE Packet32q16i -pcast(const Packet16f& a, const Packet16f& b) { +EIGEN_STRONG_INLINE Packet32q16i pcast(const Packet16f& a, + const Packet16f& b) { Packet16i a_int = _mm512_cvtps_epi32(a); Packet16i b_int = _mm512_cvtps_epi32(b); #ifdef EIGEN_VECTORIZE_AVX512BW return _mm512_packs_epi32(a_int, b_int); #else - Packet8i ab_int16_low = - _mm256_permute4x64_epi64( - _mm256_packs_epi32( - _mm512_castsi512_si256(a_int), - _mm512_castsi512_si256(b_int)), - _MM_SHUFFLE(0, 2, 1, 3)); - Packet8i ab_int16_high = - _mm256_permute4x64_epi64( - _mm256_packs_epi32( - _mm512_extracti32x8_epi32(a_int, 1), - _mm512_extracti32x8_epi32(b_int, 1)), - _MM_SHUFFLE(0, 2, 1, 3)); - return _mm512_inserti32x8( - _mm512_castsi256_si512(ab_int16_low), - ab_int16_high, 1); + Packet8i ab_int16_low = _mm256_permute4x64_epi64( + _mm256_packs_epi32(_mm512_castsi512_si256(a_int), + _mm512_castsi512_si256(b_int)), + _MM_SHUFFLE(0, 2, 1, 3)); + Packet8i ab_int16_high = _mm256_permute4x64_epi64( + _mm256_packs_epi32(_mm512_extracti32x8_epi32(a_int, 1), + _mm512_extracti32x8_epi32(b_int, 1)), + _MM_SHUFFLE(0, 2, 1, 3)); + return _mm512_inserti32x8(_mm512_castsi256_si512(ab_int16_low), ab_int16_high, + 1); #endif } @@ -64,55 +59,41 @@ struct type_casting_traits { }; template <> -EIGEN_STRONG_INLINE Packet64q8i -pcast(const Packet16f& a, - const Packet16f& b, - const Packet16f& c, - const Packet16f& d) { +EIGEN_STRONG_INLINE Packet64q8i pcast(const Packet16f& a, + const Packet16f& b, + const Packet16f& c, + const Packet16f& d) { Packet16i a_int = _mm512_cvtps_epi32(a); Packet16i b_int = _mm512_cvtps_epi32(b); Packet16i c_int = _mm512_cvtps_epi32(c); Packet16i d_int = _mm512_cvtps_epi32(d); #ifdef EIGEN_VECTORIZE_AVX512BW - return _mm512_packs_epi16( - _mm512_packs_epi32(a_int, b_int), - _mm512_packs_epi32(c_int, d_int)); + return _mm512_packs_epi16(_mm512_packs_epi32(a_int, b_int), + _mm512_packs_epi32(c_int, d_int)); #else - Packet8i ab_int16_low = - _mm256_permute4x64_epi64( - _mm256_packs_epi32( - _mm512_castsi512_si256(a_int), - _mm512_castsi512_si256(b_int)), - _MM_SHUFFLE(0, 2, 1, 3)); - Packet8i cd_int16_low = - _mm256_permute4x64_epi64( - _mm256_packs_epi32( - _mm512_castsi512_si256(c_int), - _mm512_castsi512_si256(d_int)), - _MM_SHUFFLE(0, 2, 1, 3)); - Packet8i ab_int16_high = - 
_mm256_permute4x64_epi64( - _mm256_packs_epi32( - _mm512_extracti32x8_epi32(a_int, 1), - _mm512_extracti32x8_epi32(b_int, 1)), - _MM_SHUFFLE(0, 2, 1, 3)); - Packet8i cd_int16_high = - _mm256_permute4x64_epi64( - _mm256_packs_epi32( - _mm512_extracti32x8_epi32(c_int, 1), - _mm512_extracti32x8_epi32(d_int, 1)), - _MM_SHUFFLE(0, 2, 1, 3)); - Packet8i abcd_int8_low = - _mm256_permute4x64_epi64( - _mm256_packs_epi16(ab_int16_low, cd_int16_low), - _MM_SHUFFLE(0, 2, 1, 3)); + Packet8i ab_int16_low = _mm256_permute4x64_epi64( + _mm256_packs_epi32(_mm512_castsi512_si256(a_int), + _mm512_castsi512_si256(b_int)), + _MM_SHUFFLE(0, 2, 1, 3)); + Packet8i cd_int16_low = _mm256_permute4x64_epi64( + _mm256_packs_epi32(_mm512_castsi512_si256(c_int), + _mm512_castsi512_si256(d_int)), + _MM_SHUFFLE(0, 2, 1, 3)); + Packet8i ab_int16_high = _mm256_permute4x64_epi64( + _mm256_packs_epi32(_mm512_extracti32x8_epi32(a_int, 1), + _mm512_extracti32x8_epi32(b_int, 1)), + _MM_SHUFFLE(0, 2, 1, 3)); + Packet8i cd_int16_high = _mm256_permute4x64_epi64( + _mm256_packs_epi32(_mm512_extracti32x8_epi32(c_int, 1), + _mm512_extracti32x8_epi32(d_int, 1)), + _MM_SHUFFLE(0, 2, 1, 3)); + Packet8i abcd_int8_low = _mm256_permute4x64_epi64( + _mm256_packs_epi16(ab_int16_low, cd_int16_low), _MM_SHUFFLE(0, 2, 1, 3)); Packet8i abcd_int8_high = - _mm256_permute4x64_epi64( - _mm256_packs_epi16(ab_int16_high, cd_int16_high), - _MM_SHUFFLE(0, 2, 1, 3)); - return _mm512_inserti32x8( - _mm512_castsi256_si512(abcd_int8_low), - abcd_int8_high, 1); + _mm256_permute4x64_epi64(_mm256_packs_epi16(ab_int16_high, cd_int16_high), + _MM_SHUFFLE(0, 2, 1, 3)); + return _mm512_inserti32x8(_mm512_castsi256_si512(abcd_int8_low), + abcd_int8_high, 1); #endif } @@ -128,10 +109,8 @@ struct type_casting_traits { template <> EIGEN_STRONG_INLINE Packet64q8i -pcast(const Packet16q32i& a, - const Packet16q32i& b, - const Packet16q32i& c, - const Packet16q32i& d) { +pcast(const Packet16q32i& a, const Packet16q32i& b, + const Packet16q32i& c, const Packet16q32i& d) { __m128i a_part = _mm512_cvtsepi32_epi8(a); __m128i b_part = _mm512_cvtsepi32_epi8(b); __m128i c_part = _mm512_cvtsepi32_epi8(c); @@ -145,9 +124,8 @@ pcast(const Packet16q32i& a, } template <> -EIGEN_STRONG_INLINE Packet32q16i -pcast(const Packet16q32i& a, - const Packet16q32i& b) { +EIGEN_STRONG_INLINE Packet32q16i pcast( + const Packet16q32i& a, const Packet16q32i& b) { __m256i a_part = _mm512_cvtsepi32_epi16(a); __m256i b_part = _mm512_cvtsepi32_epi16(b); __m512i converted = diff --git a/third_party/eigen_reshaped.patch b/third_party/eigen_reshaped.patch new file mode 100644 index 00000000000..7acfdcf9fef --- /dev/null +++ b/third_party/eigen_reshaped.patch @@ -0,0 +1,48 @@ +--- a/Eigen/src/Core/util/ReshapedHelper.h (date 1541195478000) ++++ b/Eigen/src/Core/util/ReshapedHelper.h (date 1541195478000) +@@ -39,6 +39,11 @@ + return total/other; + } + ++template ++struct get_compiletime_reshape_order { ++ enum { value = Order == AutoOrder ? 
Flags & RowMajorBit : Order }; ++}; ++ + } + + } // end namespace Eigen +--- a/Eigen/src/plugins/ReshapedMethods.h (date 1541195254000) ++++ b/Eigen/src/plugins/ReshapedMethods.h (date 1541195254000) +@@ -105,13 +105,13 @@ + inline Reshaped::value, + internal::get_compiletime_reshape_size::value, +- (Order==AutoOrder?Flags&RowMajorBit:Order)> ++ internal::get_compiletime_reshape_order::value> + reshaped(NRowsType nRows, NColsType nCols) EIGEN_RESHAPED_METHOD_CONST + { + return Reshaped::value, + internal::get_compiletime_reshape_size::value, +- (Order==AutoOrder?Flags&RowMajorBit:Order)> ++ internal::get_compiletime_reshape_order::value> + (derived(), + internal::get_runtime_reshape_size(nRows,internal::get_runtime_value(nCols),size()), + internal::get_runtime_reshape_size(nCols,internal::get_runtime_value(nRows),size())); +@@ -128,11 +128,13 @@ + + template + EIGEN_DEVICE_FUNC +-inline Reshaped ++inline Reshaped::value> + reshaped() EIGEN_RESHAPED_METHOD_CONST + { + EIGEN_STATIC_ASSERT(Order==RowMajor || Order==ColMajor || Order==AutoOrder, INVALID_TEMPLATE_PARAMETER); +- return Reshaped ++ return Reshaped::value> + (derived(), size(), 1); + } + \ No newline at end of file
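
Note on the eigen_reshaped.patch hunk just above: it replaces the inline ternary (Order==AutoOrder ? Flags&RowMajorBit : Order), which was duplicated in both reshaped() overloads, with a single get_compiletime_reshape_order trait so the resolution of AutoOrder lives in one place. A minimal standalone analogue of that trait follows; the enum constants are illustrative stand-ins rather than Eigen's actual definitions, and the static_asserts only demonstrate the intended behaviour.

    // Illustrative stand-ins for Eigen's storage-order constants.
    enum { ColMajorOrder = 0, RowMajorBit = 0x1, AutoOrder = 2 };

    template <int Flags, int Order>
    struct get_compiletime_reshape_order {
      // AutoOrder resolves to the storage order implied by the expression's
      // flags; an explicitly requested order is passed through unchanged.
      enum { value = Order == AutoOrder ? (Flags & RowMajorBit) : Order };
    };

    // Row-major flags plus AutoOrder resolve to row-major; an explicit
    // column-major request is kept as-is.
    static_assert(get_compiletime_reshape_order<RowMajorBit, AutoOrder>::value == 1, "");
    static_assert(get_compiletime_reshape_order<RowMajorBit, ColMajorOrder>::value == 0, "");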
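
The recurring change across the FixedPoint hunks, rhs.loadPacket(k, n) becoming rhs.template loadPacket(k, n) (together with the matching storePacket calls and a new local Packet typedef taken from packet_traits), reflects a plain C++ parsing rule rather than a new Eigen API: when a member template is called on an object whose type depends on a template parameter (the DataMapper here) and is given an explicit template argument list, the call must be spelled with the template keyword so the compiler parses the following < as a template argument list instead of a less-than comparison. The sketch below illustrates the rule with made-up DemoPacket and DemoMapper types; passing the packet type explicitly, as the added packet_traits typedef suggests the real calls do, is an assumption on my part.

    #include <cstddef>

    struct DemoPacket {
      unsigned char lanes[32];  // stand-in for a 32-lane quantized packet
    };

    struct DemoMapper {
      const unsigned char* data;
      // Member template playing the role of DataMapper::loadPacket.
      template <typename Packet>
      Packet loadPacket(std::size_t i) const {
        Packet p{};
        for (std::size_t k = 0; k < sizeof(p.lanes); ++k) p.lanes[k] = data[i + k];
        return p;
      }
    };

    template <typename MapperT>
    DemoPacket load_first(const MapperT& m) {
      // MapperT is a dependent type, so the disambiguator is required; without
      // "template" the '<' would be parsed as a comparison and the call would
      // not compile.
      return m.template loadPacket<DemoPacket>(0);
    }

The r0.template loadPacket and r0.template storePacket rewrites in the gebp result-writeback loops follow the same rule.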