Upgrade the version of Eigen to commit b4890dc6bc34.

PiperOrigin-RevId: 220359861
Author: A. Unique TensorFlower (2018-11-06 15:04:55 -08:00)
Committed by: TensorFlower Gardener
Parent: bfb4bda0ff
Commit: cf02d61a83
25 changed files with 988 additions and 896 deletions


@@ -141,7 +141,7 @@ __global__ void lstm_gates(const T* icfo, const T* b, const T* cs_prev,
   //
   const int gid = batch_id * cell_size * 4 + act_id;
   const int cid = batch_id * cell_size + act_id;
-  Eigen::internal::scalar_sigmoid_op<T> sigmoid_op;
+  Eigen::internal::scalar_logistic_op<T> sigmoid_op;
   Eigen::internal::scalar_tanh_op<T> tanh_op;
   Eigen::scalar_clip_op<T> clip_op;
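For reference, a minimal sketch (not part of this change) of the functor rename: in this Eigen revision scalar_sigmoid_op is gone and call sites switch to scalar_logistic_op, which is used identically. This assumes the functor is reachable through Eigen's Core functors, as the call sites above suggest.

#include <Eigen/Core>

// Hypothetical helper for illustration only; only the functor name changes.
template <typename T>
T Sigmoid(T x) {
  Eigen::internal::scalar_logistic_op<T> op;  // was Eigen::internal::scalar_sigmoid_op<T>
  return op(x);
}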


@@ -84,13 +84,13 @@ namespace tensorflow {
 // corresponding stream have completed. The following two classes
 // serve this purpose in two different compilation environments.
-class EigenCudaStreamDevice : public ::Eigen::StreamInterface {
+class EigenGpuStreamDevice : public ::Eigen::StreamInterface {
  public:
-  EigenCudaStreamDevice()
+  EigenGpuStreamDevice()
       : scratch_(nullptr), semaphore_(nullptr), context_(nullptr) {
     Eigen::initializeDeviceProp();
   }
-  ~EigenCudaStreamDevice() override {}
+  ~EigenGpuStreamDevice() override {}
   void Reinitialize(OpKernelContext* context, const cudaStream_t* cuda_stream,
                     TfGpuId tf_gpu_id, ::tensorflow::Allocator* alloc,
                     char* scratch) {
@@ -101,7 +101,7 @@ class EigenCudaStreamDevice : public ::Eigen::StreamInterface {
     context_ = context;
     scratch_ = scratch;
     semaphore_ =
-        reinterpret_cast<unsigned int*>(scratch + Eigen::kCudaScratchSize);
+        reinterpret_cast<unsigned int*>(scratch + Eigen::kGpuScratchSize);
     stream_ = cuda_stream;
     allocator_ = alloc;
     PlatformGpuId platform_gpu_id;
@@ -185,7 +185,7 @@ class EigenCudaStreamDevice : public ::Eigen::StreamInterface {
   mutable unsigned int* semaphore_;
   OpKernelContext* context_;

-  TF_DISALLOW_COPY_AND_ASSIGN(EigenCudaStreamDevice);
+  TF_DISALLOW_COPY_AND_ASSIGN(EigenGpuStreamDevice);
 };

 // This factory helps to ensure that different GPU device objects that refer to
@@ -292,7 +292,7 @@ Status BaseGPUDevice::InitScratchBuffers() {
     DCHECK(streams_[i]);
     if (scratch_.size() > i && scratch_[i]) continue;
     size_t scratch_buffer_size =
-        Eigen::kCudaScratchSize + sizeof(unsigned int);
+        Eigen::kGpuScratchSize + sizeof(unsigned int);
     void* scratch_buffer = gpu_allocator_->AllocateRaw(
         Allocator::kAllocatorAlignment, scratch_buffer_size);
     if (scratch_buffer == nullptr) {
@@ -304,7 +304,7 @@ Status BaseGPUDevice::InitScratchBuffers() {
         se::DeviceMemoryBase(scratch_buffer, scratch_buffer_size));
     bool ok = executor_->SynchronousMemZero(
-        &mem, Eigen::kCudaScratchSize + sizeof(unsigned int));
+        &mem, Eigen::kGpuScratchSize + sizeof(unsigned int));
     if (!ok) {
       return errors::FailedPrecondition(
           "Failed to memcopy into scratch buffer for device ",
@@ -692,7 +692,7 @@ class ConcretePerOpGpuDevice : public PerOpGpuDevice {
   const Eigen::GpuDevice& device() const override { return device_; }

  private:
-  EigenCudaStreamDevice stream_device_;
+  EigenGpuStreamDevice stream_device_;
   Eigen::GpuDevice device_;
 };
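For reference, a minimal sketch (not part of this change) of the renamed scratch constant: Eigen's kCudaScratchSize becomes kGpuScratchSize, and TensorFlow sizes the per-stream buffer as that scratch area plus one trailing unsigned int used as a semaphore. Assumes a CUDA-enabled build with Eigen's GPU headers on the include path.

#define EIGEN_USE_GPU
#include <cstddef>
#include <unsupported/Eigen/CXX11/Tensor>

// Size of the scratch buffer allocated per GPU stream, as in
// BaseGPUDevice::InitScratchBuffers() after this change.
inline size_t GpuScratchBufferBytes() {
  return Eigen::kGpuScratchSize + sizeof(unsigned int);
}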


@@ -311,8 +311,8 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
       {"Square", EIGEN_COST(scalar_square_op<float>)},
       {"Tanh", EIGEN_COST(scalar_tanh_op<float>)},
       {"Relu", EIGEN_COST(scalar_max_op<float>)},
-      {"Sigmoid", EIGEN_COST(scalar_sigmoid_op<float>)},
-      {"QuantizedSigmoid", EIGEN_COST(scalar_sigmoid_op<float>)},
+      {"Sigmoid", EIGEN_COST(scalar_logistic_op<float>)},
+      {"QuantizedSigmoid", EIGEN_COST(scalar_logistic_op<float>)},
       {"Sign", EIGEN_COST(scalar_sign_op<float>)},
       {"Sin", EIGEN_COST(scalar_sin_op<float>)},
       {"Tan", EIGEN_COST(scalar_tan_op<float>)},


@@ -60,9 +60,9 @@ template <typename T>
 struct CheckNumericsLaunch {
   void Run(const GPUDevice &d, const T *data, int size,
            int abnormal_detected[2]) {
-    const int32 block_size = d.maxCudaThreadsPerBlock();
+    const int32 block_size = d.maxGpuThreadsPerBlock();
     const int32 num_blocks =
-        (d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor()) /
+        (d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor()) /
         block_size;

     CheckNumericsKernel<T><<<num_blocks, block_size, 0, d.stream()>>>(


@@ -656,7 +656,7 @@ template <typename T>
 struct erfc : base<T, Eigen::internal::scalar_erfc_op<T>> {};

 template <typename T>
-struct sigmoid : base<T, Eigen::internal::scalar_sigmoid_op<T>> {};
+struct sigmoid : base<T, Eigen::internal::scalar_logistic_op<T>> {};

 template <typename T>
 struct sin : base<T, Eigen::internal::scalar_sin_op<T>> {};


@@ -500,8 +500,9 @@ class GemmFilterPacker {
   typedef Eigen::internal::const_blas_data_mapper<T, int64, Eigen::RowMajor>
       LhsMapper;
   typedef Eigen::internal::gebp_traits<T, T> Traits;
-  Eigen::internal::gemm_pack_lhs<T, int64, LhsMapper, Traits::mr,
-                                 Traits::LhsProgress, Eigen::RowMajor>
+  Eigen::internal::gemm_pack_lhs<
+      T, int64, LhsMapper, Traits::mr, Traits::LhsProgress,
+      typename Traits::LhsPacket4Packing, Eigen::RowMajor>
       pack_lhs;

   GemmFilterPacker(const int64 rows, const int64 depth, const T* lhs_input,
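For reference, a minimal sketch (not from this change) of the new gemm_pack_lhs signature: the packing packet type now sits between LhsProgress and the storage order, and callers typically take it from gebp_traits as LhsPacket4Packing. Names and index types below are hypothetical.

#include <Eigen/Core>

// Hypothetical float packer declaration mirroring the hunk above.
void DeclareFloatPacker() {
  typedef Eigen::internal::gebp_traits<float, float> Traits;
  typedef Eigen::internal::const_blas_data_mapper<float, Eigen::Index,
                                                  Eigen::RowMajor> LhsMapper;
  Eigen::internal::gemm_pack_lhs<
      float, Eigen::Index, LhsMapper, Traits::mr, Traits::LhsProgress,
      Traits::LhsPacket4Packing, Eigen::RowMajor>
      pack_lhs;
  (void)pack_lhs;  // packing itself omitted in this sketch
}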


@@ -764,7 +764,7 @@ Status LaunchDepthwiseConv2dGPU(OpKernelContext* ctx, const DepthwiseArgs& args,
   const int max_block_count = kKnownFilterWidth < 0 || kKnownFilterHeight < 0 ||
                                       kKnownDepthMultiplier < 0
                                   ? std::numeric_limits<int>::max()
-                                  : device.getNumCudaMultiProcessors();
+                                  : device.getNumGpuMultiProcessors();
   kernel<<<std::min(max_block_count, config.block_count),
            config.thread_per_block, 0, device.stream()>>>(args, input, filter,
                                                           output, num_outputs);


@@ -217,9 +217,9 @@ void FillPhiloxRandom<GPUDevice, Distribution>::operator()(
     OpKernelContext*, const GPUDevice& d, random::PhiloxRandom gen,
     typename Distribution::ResultElementType* data, int64 size,
     Distribution dist) {
-  const int32 block_size = d.maxCudaThreadsPerBlock();
+  const int32 block_size = d.maxGpuThreadsPerBlock();
   const int32 num_blocks =
-      (d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor()) /
+      (d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor()) /
       block_size;

   FillPhiloxRandomKernelLaunch<Distribution>


@@ -131,7 +131,7 @@ class CudaLaunchConfigTest : public ::testing::Test {
  protected:
   const int bufsize = 1024;
   int* outbuf = nullptr;
-  Eigen::CudaStreamDevice stream;
+  Eigen::GpuStreamDevice stream;
   Eigen::GpuDevice d = Eigen::GpuDevice(&stream);

   virtual void SetUp() {
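For reference, a minimal sketch (not from this change): Eigen::CudaStreamDevice is renamed Eigen::GpuStreamDevice in this Eigen revision, and constructing a GpuDevice from it is unchanged. Assumes a CUDA-enabled build with Eigen's GPU headers.

#define EIGEN_USE_GPU
#include <unsupported/Eigen/CXX11/Tensor>

// The stream wrapper holds the CUDA stream (default stream here); GpuDevice
// evaluates Eigen expressions on that stream.
Eigen::GpuStreamDevice stream;
Eigen::GpuDevice device(&stream);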


@@ -128,12 +128,12 @@ inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
   CudaLaunchConfig config;
   const int virtual_thread_count = work_element_count;
   const int physical_thread_count = std::min(
-      d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor(),
+      d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor(),
       virtual_thread_count);
-  const int thread_per_block = std::min(1024, d.maxCudaThreadsPerBlock());
+  const int thread_per_block = std::min(1024, d.maxGpuThreadsPerBlock());
   const int block_count =
       std::min(DivUp(physical_thread_count, thread_per_block),
-               d.getNumCudaMultiProcessors());
+               d.getNumGpuMultiProcessors());

   config.virtual_thread_count = virtual_thread_count;
   config.thread_per_block = thread_per_block;
@@ -184,7 +184,7 @@ inline CudaLaunchConfig GetCudaLaunchConfigFixedBlockSize(
   cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
       &block_count, func, fixed_block_size, dynamic_shared_memory_size);
   CHECK_EQ(err, cudaSuccess);
-  block_count = std::min(block_count * d.getNumCudaMultiProcessors(),
+  block_count = std::min(block_count * d.getNumGpuMultiProcessors(),
                          DivUp(work_element_count, fixed_block_size));
   config.virtual_thread_count = work_element_count;
@@ -213,7 +213,7 @@ inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(int xdim, int ydim,
   int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
   const int physical_thread_count =
-      d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor();
+      d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor();
   const int max_blocks = std::max(physical_thread_count / kThreadsPerBlock, 1);
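For reference, a minimal sketch (not from this change) of the launch-config arithmetic with the renamed Eigen::GpuDevice queries, mirroring GetCudaLaunchConfig above. Assumes a CUDA-enabled build; DivUp is the usual ceiling division.

#define EIGEN_USE_GPU
#include <algorithm>
#include <unsupported/Eigen/CXX11/Tensor>

inline int DivUp(int a, int b) { return (a + b - 1) / b; }

// Cap the block count by both the work size and the device's occupancy.
inline int BlockCountFor(int work_element_count, const Eigen::GpuDevice& d) {
  const int physical_threads = std::min(
      d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor(),
      work_element_count);
  const int threads_per_block = std::min(1024, d.maxGpuThreadsPerBlock());
  return std::min(DivUp(physical_threads, threads_per_block),
                  d.getNumGpuMultiProcessors());
}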


@@ -12,25 +12,55 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_
-#define EIGEN_USE_CUSTOM_THREAD_POOL
-#define EIGEN_USE_THREADS
+
+// This is essentially unsupported/CXX11/Eigen/Tensor.h
+// TODO(petewarden) - move this to a common location in Eigen itself.

 // clang-format off

+#include <stdint.h>
-#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_
-#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_
+
+#include "Eigen/Core"
+
+#if defined(EIGEN_USE_SYCL)
+#undef min
+#undef max
+#undef isnan
+#undef isinf
+#undef isfinite
+#include <CL/sycl.hpp>
+#include <iostream>
+#include <map>
+#include <memory>
+#include <utility>
+#endif
+#include <cmath>
 #include <cstddef>
 #include <cstring>
-#include <cmath>
+
+#ifdef _WIN32
+typedef __int16 int16_t;
+typedef unsigned __int16 uint16_t;
+typedef __int32 int32_t;
+typedef unsigned __int32 uint32_t;
+typedef __int64 int64_t;
+typedef unsigned __int64 uint64_t;
+#include <windows.h>
+#else
+#include <stdint.h>
+#include <unistd.h>
+#endif
+
+#if __cplusplus > 199711 || EIGEN_COMP_MSVC >= 1900
 #include <random>
-#include <atomic>
-#include <condition_variable>  // NOLINT(build/c++11)
-#include <mutex>  // NOLINT(build/c++11)
-#include <thread>  // NOLINT(build/c++11)
-#include <functional>
+#endif

 #ifdef _WIN32
 #include <windows.h>
@@ -40,58 +70,53 @@ limitations under the License.
 #include <time.h>
 #endif

-// Because some programs may link Eigen in through other frameworks with
-// different flags, we can run into multiple definition issues if we don't have
-// a private namespace for our versions. This is a nasty hack, but a similar
-// approach is used elsewhere to handle the problem, so it should be stable.
-#define Eigen EigenForTFLite
-#include "Eigen/src/Core/util/StaticAssert.h"
-#include "unsupported/Eigen/CXX11/Core"
-#include "unsupported/Eigen/SpecialFunctions"
+// #if defined(EIGEN_USE_LIBXSMM)
+// #include "libxsmm.h"
+// #endif
+
+#ifdef EIGEN_USE_THREADS
+#include "third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool"
+#endif

 #include "Eigen/src/Core/util/DisableStupidWarnings.h"
-#include "Eigen/Core"
+#include "third_party/eigen3/unsupported/Eigen/SpecialFunctions"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/util/CXX11Meta.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/util/MaxSizeVector.h"

-// Beware: the order of the include matters to some compilers. For example
-// TensorIndexList.h should be included before TensorDimensions.h in order to
-// use index lists to encode tensor dimensions when compiling with llvm.
-// We're defining this ourselves rather than using the Eigen Tensor header file
-// so that we can alter the macro definition of TENSOR_CONTRACTION_DISPATCH to
-// reduce binary size.
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorMacros.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorCostModel.h"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/ThreadPoolInterface.h"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceType.h"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorNonBlockingThreadPool.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceDefault.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceGpu.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceSycl.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDimensionList.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorInitializer.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorUInt128.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorStats.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReductionGpu.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMappers.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h"
 #undef TENSOR_CONTRACTION_DISPATCH
 #define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS) \
   if (this->m_lhs_inner_dim_contiguous && \
@@ -102,8 +127,9 @@ limitations under the License.
     eigen_assert(false && "Unsupported contraction formats"); \
   }

 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h"
@@ -125,19 +151,18 @@ limitations under the License.
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorScan.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorTrace.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorSycl.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/Tensor.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorFixedSize.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReductionCuda.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h"

 #include "Eigen/src/Core/util/ReenableStupidWarnings.h"

 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_


@@ -94,7 +94,7 @@ typedef unsigned __int64 uint64_t;
 #include "unsupported/Eigen/CXX11/src/Tensor/TensorCostModel.h"
 #include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceDefault.h"
 #include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h"
-#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceCuda.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceGpu.h"
 #include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceSycl.h"
 #include "unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h"
 #include "unsupported/Eigen/CXX11/src/Tensor/TensorDimensionList.h"
@@ -106,10 +106,11 @@ typedef unsigned __int64 uint64_t;
 #include "unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h"
 #include "unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h"
 #include "unsupported/Eigen/CXX11/src/Tensor/TensorBase.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h"
 #include "unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h"
 #include "unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h"
 #include "unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h"
-#include "unsupported/Eigen/CXX11/src/Tensor/TensorReductionCuda.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorReductionGpu.h"
 #include "unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h"
 #include "unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h"
 #include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h"
@@ -128,7 +129,7 @@ typedef unsigned __int64 uint64_t;
 #include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h"
-#include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h"
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h"
 #include "unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h"
 #include "unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h"
 #include "unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h"


@@ -3151,12 +3151,12 @@ inline void LstmCell(
   // Combined memory state and final output calculation
   gemmlowp::ScopedProfilingLabel label2("MemoryStateAndFinalOutput");
   output_state_map =
-      input_gate_sm.unaryExpr(Eigen::internal::scalar_sigmoid_op<float>()) *
+      input_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
           new_input_sm.tanh() +
-      forget_gate_sm.unaryExpr(Eigen::internal::scalar_sigmoid_op<float>()) *
+      forget_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
           prev_state_map;
   output_activ_map =
-      output_gate_sm.unaryExpr(Eigen::internal::scalar_sigmoid_op<float>()) *
+      output_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
       output_state_map.tanh();
 }

@@ -4367,7 +4367,7 @@ inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
   auto input_map = MapAsVector(input_data, input_shape);
   auto output_map = MapAsVector(output_data, output_shape);
   output_map.array() =
-      input_map.array().unaryExpr(Eigen::internal::scalar_sigmoid_op<float>());
+      input_map.array().unaryExpr(Eigen::internal::scalar_logistic_op<float>());
 }

 // Convenience version that allows, for example, generated-code calls to be
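For reference, a minimal sketch (not from this change): the renamed functor plugs into Eigen array expressions through unaryExpr exactly as before, only the name changes. Assumes Eigen/Core provides the functor in this revision, as the call sites above do.

#include <Eigen/Core>

// Element-wise logistic (sigmoid) over a float array, as in Logistic() above.
Eigen::ArrayXf ApplyLogistic(const Eigen::ArrayXf& in) {
  return in.unaryExpr(Eigen::internal::scalar_logistic_op<float>());
}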


@@ -1570,7 +1570,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
       if not context.executing_eagerly():
         self.evaluate(variables.global_variables_initializer())
-      self.assertAllEqual([[[[4.0]]]], self.evaluate(y))
+      self.assertAllClose([[[[4.0]]]], self.evaluate(y))

     # Remove reference cycles in model
     test_util.dismantle_polymorphic_function(model)


@@ -134,11 +134,12 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
     tf_http_archive(
         name = "eigen_archive",
         build_file = clean_dep("//third_party:eigen.BUILD"),
-        sha256 = "d956415d784fa4e42b6a2a45c32556d6aec9d0a3d8ef48baee2522ab762556a9",
-        strip_prefix = "eigen-eigen-fd6845384b86",
+        patch_file = clean_dep("//third_party:eigen_reshaped.patch"),
+        sha256 = "d66cec3b54b3dfaa4666c1d49481a7197f93fc078cd53c54e2b4a8893a529c9f",
+        strip_prefix = "eigen-eigen-b4890dc6bc34",
         urls = [
-            "https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/fd6845384b86.tar.gz",
-            "https://bitbucket.org/eigen/eigen/get/fd6845384b86.tar.gz",
+            "https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/b4890dc6bc34.tar.gz",
+            "https://bitbucket.org/eigen/eigen/get/b4890dc6bc34.tar.gz",
         ],
     )


@@ -65,6 +65,7 @@ cc_library(
         # code. We use it, but we do not rely on it, as evidenced above.
         "EIGEN_MPL2_ONLY",
         "EIGEN_MAX_ALIGN_BYTES=64",
+        "EIGEN_HAS_TYPE_TRAITS=0",
     ],
     includes = ["."],
     visibility = ["//visibility:public"],


@@ -249,9 +249,7 @@ EIGEN_STRONG_INLINE QInt32& operator/=(QInt32& a, const QInt32 b) {
   a.value /= b.value;
   return a;
 }
-EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a) {
-  return -a.value;
-}
+EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a) { return -a.value; }

 // Scaling QInt32 by double. We do the arithmetic in double because
 // float only has 23 bits of mantissa, so casting QInt32 to float might reduce
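A worked example of that comment (not part of the change): float stores 23 mantissa bits (24 significant bits with the implicit leading one), so integers above 2^24 stop being exactly representable, while double's 53-bit significand covers any 32-bit integer exactly.

#include <cstdint>

// 2^24 + 1 = 16777217 rounds to 16777216.0f, so scaling a QInt32 through
// float could silently lose the low bit; double preserves it.
static_assert(static_cast<float>(16777217) == 16777216.0f,
              "float loses precision at 2^24 + 1");
static_assert(static_cast<double>(16777217) == 16777217.0,
              "double is exact for 32-bit integers");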


@@ -15,11 +15,9 @@ namespace internal {
 // Accumulate the product of 2 QInt8 inputs on 32 bits to prevent
 // overflows
-template<> struct scalar_product_traits<QInt8, QInt8>
-{
-  enum {
-    Defined = 1
-  };
+template <>
+struct scalar_product_traits<QInt8, QInt8> {
+  enum { Defined = 1 };
   typedef QInt32 ReturnType;
 };
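A worked example of the overflow reasoning (not part of the change): the widest int8 product is (-128) * (-128) = 16384, and a GEMM sums `depth` such products, so already at depth 4 the worst case is 65536, past the int16 range; accumulating in 32 bits keeps the sum exact for any practical depth.

#include <cstdint>

// Accumulating int8 x int8 products in 32 bits, as the traits above request.
int32_t DotInt8(const int8_t* a, const int8_t* b, int depth) {
  int32_t acc = 0;  // QInt32-style accumulator
  for (int k = 0; k < depth; ++k) {
    acc += static_cast<int32_t>(a[k]) * static_cast<int32_t>(b[k]);
  }
  return acc;
}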
@@ -33,11 +31,9 @@ struct scalar_product_traits<QInt16, QInt16> {
 // Accumulate the product of QInt8 inputs with QUint8 inputs on 32 bits
 // to prevent overflows
-template<> struct scalar_product_traits<QInt8, QUInt8>
-{
-  enum {
-    Defined = 1
-  };
+template <>
+struct scalar_product_traits<QInt8, QUInt8> {
+  enum { Defined = 1 };
   typedef QInt32 ReturnType;
 };

@@ -48,13 +44,15 @@ template<> struct scalar_product_traits<QInt8, QUInt8>
 #ifndef EIGEN_USE_OPTIMIZED_INT8_INT8_MAT_MAT_PRODUCT
 template <bool _ConjLhs, bool _ConjRhs>
-class gebp_traits<QInt8, QInt8, _ConjLhs, _ConjRhs>
-{
+class gebp_traits<QInt8, QInt8, _ConjLhs, _ConjRhs> {
  public:
   typedef QInt8 LhsScalar;
   typedef QInt8 RhsScalar;
   typedef QInt32 ResScalar;
+  typedef typename packet_traits<LhsScalar>::type LhsPacket;
+  typedef LhsPacket LhsPacket4Packing;

   enum {
     // register block size along the M and N directions
     // One for the current implementation
@@ -68,22 +66,24 @@ public:
 };

 // The signed 8bit Mat-Mat product itself.
-template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
-struct gebp_kernel<QInt8, QInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
-{
+template <typename Index, typename DataMapper, int mr, int nr,
+          bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel<QInt8, QInt8, Index, DataMapper, mr, nr, ConjugateLhs,
+                   ConjugateRhs> {
   EIGEN_DONT_INLINE
-  void operator()(const DataMapper& res, const QInt8* blockA, const QInt8* blockB,
-                  Index rows, Index depth, Index cols, QInt32 alpha,
-                  Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+  void operator()(const DataMapper& res, const QInt8* blockA,
+                  const QInt8* blockB, Index rows, Index depth, Index cols,
+                  QInt32 alpha, Index strideA = -1, Index strideB = -1,
+                  Index offsetA = 0, Index offsetB = 0);
 };

-template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
-EIGEN_DONT_INLINE
-void gebp_kernel<QInt8, QInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
-::operator()(const DataMapper& res, const QInt8* blockA, const QInt8* blockB,
-             Index rows, Index depth, Index cols, QInt32 alpha,
-             Index strideA, Index strideB, Index offsetA, Index offsetB)
-{
+template <typename Index, typename DataMapper, int mr, int nr,
+          bool ConjugateLhs, bool ConjugateRhs>
+EIGEN_DONT_INLINE void gebp_kernel<QInt8, QInt8, Index, DataMapper, mr, nr,
+                                   ConjugateLhs, ConjugateRhs>::
+operator()(const DataMapper& res, const QInt8* blockA, const QInt8* blockB,
+           Index rows, Index depth, Index cols, QInt32 alpha, Index strideA,
+           Index strideB, Index offsetA, Index offsetB) {
   EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
   EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
@@ -113,18 +113,19 @@ void gebp_kernel<QInt8, QInt8, Index, DataMapper, mr, nr, ConjugateLhs, Conjugat
 }
 #endif

 // This definition tackle the case where the lhs is encoded using signed 8bit
 // integers and the rhs using unsigned 8bit integers.
 #ifndef EIGEN_USE_OPTIMIZED_INT8_UINT8_MAT_MAT_PRODUCT
 template <bool _ConjLhs, bool _ConjRhs>
-class gebp_traits<QInt8, QUInt8, _ConjLhs, _ConjRhs>
-{
+class gebp_traits<QInt8, QUInt8, _ConjLhs, _ConjRhs> {
  public:
   typedef QInt8 LhsScalar;
   typedef QUInt8 RhsScalar;
   typedef QInt32 ResScalar;
+  typedef typename packet_traits<LhsScalar>::type LhsPacket;
+  typedef LhsPacket LhsPacket4Packing;

   enum {
     // register block size along the M and N directions
     // One for the current implementation
@@ -138,22 +139,24 @@ public:
 };

 // Mat-Mat product of a signed 8bit lhs with an unsigned 8bit rhs
-template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
-struct gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
-{
+template <typename Index, typename DataMapper, int mr, int nr,
+          bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs,
+                   ConjugateRhs> {
   EIGEN_DONT_INLINE
-  void operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB,
-                  Index rows, Index depth, Index cols, QInt32 alpha,
-                  Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+  void operator()(const DataMapper& res, const QInt8* blockA,
+                  const QUInt8* blockB, Index rows, Index depth, Index cols,
+                  QInt32 alpha, Index strideA = -1, Index strideB = -1,
+                  Index offsetA = 0, Index offsetB = 0);
 };

-template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
-EIGEN_DONT_INLINE
-void gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
-::operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB,
-             Index rows, Index depth, Index cols, QInt32 alpha,
-             Index strideA, Index strideB, Index offsetA, Index offsetB)
-{
+template <typename Index, typename DataMapper, int mr, int nr,
+          bool ConjugateLhs, bool ConjugateRhs>
+EIGEN_DONT_INLINE void gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr,
+                                   ConjugateLhs, ConjugateRhs>::
+operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB,
+           Index rows, Index depth, Index cols, QInt32 alpha, Index strideA,
+           Index strideB, Index offsetA, Index offsetB) {
   EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
   EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
@@ -183,18 +186,19 @@ void gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, Conjuga
 }
 #endif

 // This definition tackle the case where the khs is encoded using unsigned 8bit
 // integers and the rhs using signed 8bit integers.
 #ifndef EIGEN_USE_OPTIMIZED_UINT8_INT8_MAT_MAT_PRODUCT
 template <bool _ConjLhs, bool _ConjRhs>
-class gebp_traits<QUInt8, QInt8, _ConjLhs, _ConjRhs>
-{
+class gebp_traits<QUInt8, QInt8, _ConjLhs, _ConjRhs> {
  public:
   typedef QUInt8 LhsScalar;
   typedef QInt8 RhsScalar;
   typedef QInt32 ResScalar;
+  typedef typename packet_traits<LhsScalar>::type LhsPacket;
+  typedef LhsPacket LhsPacket4Packing;

   enum {
     // register block size along the M and N directions
     // One for the current implementation
@@ -207,24 +211,25 @@ public:
 };
 };

 // Mat-Mat product of an unsigned 8bit lhs with a signed 8bit rhs
-template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
-struct gebp_kernel<QUInt8, QInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
-{
+template <typename Index, typename DataMapper, int mr, int nr,
+          bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel<QUInt8, QInt8, Index, DataMapper, mr, nr, ConjugateLhs,
+                   ConjugateRhs> {
   EIGEN_DONT_INLINE
-  void operator()(const DataMapper& res, const QUInt8* blockA, const QInt8* blockB,
-                  Index rows, Index depth, Index cols, QInt32 alpha,
-                  Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+  void operator()(const DataMapper& res, const QUInt8* blockA,
+                  const QInt8* blockB, Index rows, Index depth, Index cols,
+                  QInt32 alpha, Index strideA = -1, Index strideB = -1,
+                  Index offsetA = 0, Index offsetB = 0);
 };

-template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
-EIGEN_DONT_INLINE
-void gebp_kernel<QUInt8, QInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
-::operator()(const DataMapper& res, const QUInt8* blockA, const QInt8* blockB,
-             Index rows, Index depth, Index cols, QInt32 alpha,
-             Index strideA, Index strideB, Index offsetA, Index offsetB)
-{
+template <typename Index, typename DataMapper, int mr, int nr,
+          bool ConjugateLhs, bool ConjugateRhs>
+EIGEN_DONT_INLINE void gebp_kernel<QUInt8, QInt8, Index, DataMapper, mr, nr,
+                                   ConjugateLhs, ConjugateRhs>::
+operator()(const DataMapper& res, const QUInt8* blockA, const QInt8* blockB,
+           Index rows, Index depth, Index cols, QInt32 alpha, Index strideA,
+           Index strideB, Index offsetA, Index offsetB) {
   EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
   EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
@@ -263,6 +268,9 @@ class gebp_traits<QInt16, QInt16, _ConjLhs, _ConjRhs> {
   typedef QInt16 RhsScalar;
   typedef QInt32 ResScalar;
+  typedef typename packet_traits<LhsScalar>::type LhsPacket;
+  typedef LhsPacket LhsPacket4Packing;

   enum {
     // register block size along the M and N directions
     // One for the current implementation


@@ -28,6 +28,9 @@ class gebp_traits<QInt16, QInt16, _ConjLhs, _ConjRhs> {
   typedef QInt16 RhsScalar;
   typedef QInt32 ResScalar;
+  typedef typename packet_traits<LhsScalar>::type LhsPacket;
+  typedef LhsPacket LhsPacket4Packing;

   enum {
     // Define register blocking scheme.
     nr = 16,
@@ -43,7 +46,7 @@ class gebp_traits<QInt16, QInt16, _ConjLhs, _ConjRhs> {
 // Used by TensorContractionThreadPool, inputs must have dimensions that are
 // multiples of 32.
 template <typename Index, int ShardingType>
-class TensorContractionBlocking<QInt16, QInt16, Index, ShardingType> {
+class TensorContractionBlocking<QInt16, QInt16, QInt16, Index, ShardingType> {
  public:
   TensorContractionBlocking(Index k, Index m, Index n, Index num_threads = 1)
       : kc_(((k + 15) / 16) * 16),
@@ -144,7 +147,7 @@ class gemm_blocking_space<ColMajor, QInt16, QInt16, MaxRows, MaxCols, MaxDepth,
 template <typename Index, typename DataMapper, int Pack1, int Pack2,
           bool Conjugate, bool PanelMode>
-struct gemm_pack_lhs<QInt16, Index, DataMapper, Pack1, Pack2, ColMajor,
+struct gemm_pack_lhs<QInt16, Index, DataMapper, Pack1, Pack2, QInt16, ColMajor,
                      Conjugate, PanelMode> {
   EIGEN_DONT_INLINE void operator()(QInt16* blockA, const DataMapper& lhs,
                                     Index depth, Index rows, Index stride = 0,
@@ -154,12 +157,14 @@ struct gemm_pack_lhs<QInt16, Index, DataMapper, Pack1, Pack2, ColMajor,
 template <typename Index, typename DataMapper, int Pack1, int Pack2,
           bool Conjugate, bool PanelMode>
 EIGEN_DONT_INLINE void gemm_pack_lhs<QInt16, Index, DataMapper, Pack1, Pack2,
-                                     ColMajor, Conjugate, PanelMode>::
+                                     QInt16, ColMajor, Conjugate, PanelMode>::
 operator()(QInt16* blockA, const DataMapper& lhs, Index depth, Index rows,
            Index stride, Index offset) {
   eigen_assert(stride == 0);
   eigen_assert(offset == 0);

+  typedef typename packet_traits<QInt16>::type Packet;
+
   // Use alternate function for weird sizes
   if (rows % 16 != 0 || depth % 16 != 0) {
     assert(false &&
@@ -178,10 +183,10 @@ operator()(QInt16* blockA, const DataMapper& lhs, Index depth, Index rows,
       // Pack depth in sets of 4
       for (Index k = 0; k < depth; k += 4) {
         // Load vectors
-        __m256i L_A = lhs.loadPacket(m, k);
-        __m256i L_B = lhs.loadPacket(m, k + 1);
-        __m256i L_C = lhs.loadPacket(m, k + 2);
-        __m256i L_D = lhs.loadPacket(m, k + 3);
+        __m256i L_A = lhs.template loadPacket<Packet>(m, k);
+        __m256i L_B = lhs.template loadPacket<Packet>(m, k + 1);
+        __m256i L_C = lhs.template loadPacket<Packet>(m, k + 2);
+        __m256i L_D = lhs.template loadPacket<Packet>(m, k + 3);

         // Rearrange the inputs as required by the kernel
         __m256i L_AB0_AB7 = _mm256_unpacklo_epi16(L_A, L_B);
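For reference, a minimal sketch (not from this change) of the loadPacket migration: in this Eigen revision the data mappers take the packet type as an explicit template argument, so call sites inside dependent code need the `.template` spelling used above. The Mapper and Packet names below are placeholders for whatever mapper is in scope.

// Sketch only: 'mapper' stands for an Eigen contraction data mapper as used
// in the packing code above; Packet is its packet type.
template <typename Packet, typename Mapper, typename Index>
Packet LoadOne(const Mapper& mapper, Index i, Index j) {
  // Old Eigen:  mapper.loadPacket(i, j);
  // This Eigen: the packet type is an explicit template argument.
  return mapper.template loadPacket<Packet>(i, j);
}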
@@ -236,13 +241,15 @@ struct gemm_pack_rhs<QInt16, Index, DataMapper, nr, ColMajor, Conjugate,
 template <typename Index, typename DataMapper, int nr, bool Conjugate,
           bool PanelMode>
-EIGEN_DONT_INLINE void
-gemm_pack_rhs<QInt16, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>::
+EIGEN_DONT_INLINE void gemm_pack_rhs<QInt16, Index, DataMapper, nr, ColMajor,
+                                     Conjugate, PanelMode>::
 operator()(QInt16* blockB, const DataMapper& rhs, Index depth, Index cols,
            Index stride, Index offset) {
   eigen_assert(stride == 0);
   eigen_assert(offset == 0);

+  typedef typename packet_traits<QInt16>::type Packet;
+
   // Use alternate function for weird sizes
   if (cols % 16 != 0 || depth % 16 != 0) {
     assert(false &&
@@ -277,28 +284,28 @@ operator()(QInt16* blockB, const DataMapper& rhs, Index depth, Index cols,
   for (Index n = 0; n < cols; n += 16) {
     // Pack depth in sets of 16
     for (Index k = 0; k < depth; k += 16) {
-      __m256i R_A = rhs.loadPacket(k, n);
-      __m256i R_B = rhs.loadPacket(k, n + 1);
-      __m256i R_C = rhs.loadPacket(k, n + 2);
-      __m256i R_D = rhs.loadPacket(k, n + 3);
+      __m256i R_A = rhs.template loadPacket<Packet>(k, n);
+      __m256i R_B = rhs.template loadPacket<Packet>(k, n + 1);
+      __m256i R_C = rhs.template loadPacket<Packet>(k, n + 2);
+      __m256i R_D = rhs.template loadPacket<Packet>(k, n + 3);
       PACK_STEP;

-      R_A = rhs.loadPacket(k, n + 4);
-      R_B = rhs.loadPacket(k, n + 5);
-      R_C = rhs.loadPacket(k, n + 6);
-      R_D = rhs.loadPacket(k, n + 7);
+      R_A = rhs.template loadPacket<Packet>(k, n + 4);
+      R_B = rhs.template loadPacket<Packet>(k, n + 5);
+      R_C = rhs.template loadPacket<Packet>(k, n + 6);
+      R_D = rhs.template loadPacket<Packet>(k, n + 7);
       PACK_STEP;

-      R_A = rhs.loadPacket(k, n + 8);
-      R_B = rhs.loadPacket(k, n + 9);
-      R_C = rhs.loadPacket(k, n + 10);
-      R_D = rhs.loadPacket(k, n + 11);
+      R_A = rhs.template loadPacket<Packet>(k, n + 8);
+      R_B = rhs.template loadPacket<Packet>(k, n + 9);
+      R_C = rhs.template loadPacket<Packet>(k, n + 10);
+      R_D = rhs.template loadPacket<Packet>(k, n + 11);
       PACK_STEP;

-      R_A = rhs.loadPacket(k, n + 12);
-      R_B = rhs.loadPacket(k, n + 13);
-      R_C = rhs.loadPacket(k, n + 14);
-      R_D = rhs.loadPacket(k, n + 15);
+      R_A = rhs.template loadPacket<Packet>(k, n + 12);
+      R_B = rhs.template loadPacket<Packet>(k, n + 13);
+      R_C = rhs.template loadPacket<Packet>(k, n + 14);
+      R_D = rhs.template loadPacket<Packet>(k, n + 15);
       PACK_STEP;

       blockB_256 += 12;
@@ -476,9 +483,13 @@ operator()(const DataMapper& res, const QInt16* blockA, const QInt16* blockB,
       for (Index j = n; j < n + 16; j++) {
         LinearMapper r0 = res.getLinearMapper(m, j);
         LinearMapper r1 = res.getLinearMapper(m + 8, j);
+        typedef typename packet_traits<QInt32>::type Packet;

-        r0.storePacket(0, _mm256_add_epi32(blockO_256[i++], r0.loadPacket(0)));
-        r1.storePacket(0, _mm256_add_epi32(blockO_256[i++], r1.loadPacket(0)));
+        r0.template storePacket<Packet>(
+            0, _mm256_add_epi32(blockO_256[i++],
+                                r0.template loadPacket<Packet>(0)));
+        r1.template storePacket<Packet>(
+            0, _mm256_add_epi32(blockO_256[i++],
+                                r1.template loadPacket<Packet>(0)));
       }

       // Zero the result block so it can be reused
@@ -497,13 +508,15 @@ operator()(const DataMapper& res, const QInt16* blockA, const QInt16* blockB,
 // Define quantized traits
 template <bool _ConjLhs, bool _ConjRhs>
-class gebp_traits<QInt8, QUInt8, _ConjLhs, _ConjRhs>
-{
+class gebp_traits<QInt8, QUInt8, _ConjLhs, _ConjRhs> {
  public:
   typedef QInt8 LhsScalar;
   typedef QUInt8 RhsScalar;
   typedef QInt32 ResScalar;
+  typedef typename packet_traits<LhsScalar>::type LhsPacket;
+  typedef LhsPacket LhsPacket4Packing;

   enum {
     // Define register blocking scheme.
     nr = 32,
@@ -518,22 +531,28 @@ public:
 // Specialized blocking for quantized implementations.
 // Used by TensorContractionThreadPool, inputs must have dimensions that are
 // multiples of 32.
-template<typename Index,
-         typename LeftTensor,
-         typename left_nocontract_t, typename left_contract_t,
-         bool left_inner_dim_contiguous, bool left_inner_dim_reordered, int LeftAlignment,
-         typename RightTensor,
-         typename right_nocontract_t, typename right_contract_t,
-         bool right_inner_dim_contiguous, bool right_inner_dim_reordered, int RightAlignment, int ShardingType>
-class TensorContractionBlocking<TensorContractionInputMapper<QInt8, Index, Lhs, LeftTensor, left_nocontract_t, left_contract_t, 32, left_inner_dim_contiguous, left_inner_dim_reordered, LeftAlignment>, TensorContractionInputMapper<QUInt8, Index, Rhs, RightTensor, right_nocontract_t, right_contract_t, 32, right_inner_dim_contiguous, right_inner_dim_reordered, RightAlignment>, Index, ShardingType> {
+template <typename ResScalar, typename Index, typename LeftTensor,
+          typename left_nocontract_t, typename left_contract_t,
+          bool left_inner_dim_contiguous, bool left_inner_dim_reordered,
+          int LeftAlignment, typename RightTensor, typename right_nocontract_t,
+          typename right_contract_t, bool right_inner_dim_contiguous,
+          bool right_inner_dim_reordered, int RightAlignment, int ShardingType>
+class TensorContractionBlocking<
+    ResScalar,
+    TensorContractionInputMapper<
+        QInt8, Index, Lhs, LeftTensor, left_nocontract_t, left_contract_t, 32,
+        left_inner_dim_contiguous, left_inner_dim_reordered, LeftAlignment>,
+    TensorContractionInputMapper<QUInt8, Index, Rhs, RightTensor,
+                                 right_nocontract_t, right_contract_t, 32,
+                                 right_inner_dim_contiguous,
+                                 right_inner_dim_reordered, RightAlignment>,
+    Index, ShardingType> {
  public:
   typedef QInt8 LhsScalar;
   typedef QUInt8 RhsScalar;
-  TensorContractionBlocking(Index k, Index m, Index n, Index num_threads = 1) :
-      kc_(k), mc_(m), nc_(n)
-  {
+  TensorContractionBlocking(Index k, Index m, Index n, Index num_threads = 1)
+      : kc_(k), mc_(m), nc_(n) {
     eigen_assert(m % 32 == 0);
     eigen_assert(k % 32 == 0);
     if (!k || !m || !n) {
@@ -543,8 +562,7 @@ class TensorContractionBlocking<TensorContractionInputMapper<QInt8, Index, Lhs,
     if (ShardingType == ShardByCol) {
       eigen_assert(n % 32 == 0);
       nc_ = (((n / num_threads) + 31) / 32) * 32;
-    }
-    else {
+    } else {
       eigen_assert(n % 32 == 0 || n == 1);
       // Special case to avoid breaking the unimplemented matrix-vector case
       if (n == 1) {
@@ -599,7 +617,6 @@ class gemm_blocking_space<ColMajor, QInt8, QInt8, MaxRows, MaxCols, MaxDepth,
   }
 };
-
 template <int MaxRows, int MaxCols, int MaxDepth, int KcFactor>
 class gemm_blocking_space<ColMajor, QInt8, QUInt8, MaxRows, MaxCols, MaxDepth,
                           KcFactor, false>
@@ -633,42 +650,60 @@ class gemm_blocking_space<ColMajor, QInt8, QUInt8, MaxRows, MaxCols, MaxDepth,
 };

 // Alternate templates for any input sizes
-template<typename Scalar, typename Index, typename DataMapper, int Pack1, int Pack2, int StorageOrder, bool Conjugate = false, bool PanelMode = false>
+template <typename Scalar, typename Index, typename DataMapper, int Pack1,
+          int Pack2, int StorageOrder, bool Conjugate = false,
+          bool PanelMode = false>
 struct gemm_pack_lhs_any;
-template <typename Index, typename DataMapper, int Pack1, int Pack2, bool Conjugate, bool PanelMode>
-struct gemm_pack_lhs_any<QInt8, Index, DataMapper, Pack1, Pack2, ColMajor, Conjugate, PanelMode> {
-  EIGEN_DONT_INLINE void operator()
-  (QInt8* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0, Index offset = 0);
+template <typename Index, typename DataMapper, int Pack1, int Pack2,
+          bool Conjugate, bool PanelMode>
+struct gemm_pack_lhs_any<QInt8, Index, DataMapper, Pack1, Pack2, ColMajor,
+                         Conjugate, PanelMode> {
+  EIGEN_DONT_INLINE void operator()(QInt8* blockA, const DataMapper& lhs,
+                                    Index depth, Index rows, Index stride = 0,
+                                    Index offset = 0);
 };

-template<typename Scalar, typename Index, typename DataMapper, int nr, int StorageOrder, bool Conjugate = false, bool PanelMode=false>
+template <typename Scalar, typename Index, typename DataMapper, int nr,
+          int StorageOrder, bool Conjugate = false, bool PanelMode = false>
 struct gemm_pack_rhs_any;
-template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
-struct gemm_pack_rhs_any<QUInt8, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> {
-  EIGEN_DONT_INLINE void operator()
-  (QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0, Index offset = 0);
+template <typename Index, typename DataMapper, int nr, bool Conjugate,
+          bool PanelMode>
+struct gemm_pack_rhs_any<QUInt8, Index, DataMapper, nr, ColMajor, Conjugate,
+                         PanelMode> {
+  EIGEN_DONT_INLINE void operator()(QUInt8* blockB, const DataMapper& rhs,
+                                    Index depth, Index cols, Index stride = 0,
+                                    Index offset = 0);
 };

-template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs=false, bool ConjugateRhs=false>
+template <typename LhsScalar, typename RhsScalar, typename Index,
+          typename DataMapper, int mr, int nr, bool ConjugateLhs = false,
+          bool ConjugateRhs = false>
 struct gebp_kernel_any;
-template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
-struct gebp_kernel_any<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
-{
+template <typename Index, typename DataMapper, int mr, int nr,
+          bool ConjugateLhs, bool ConjugateRhs>
+struct gebp_kernel_any<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs,
+                       ConjugateRhs> {
   typedef typename DataMapper::LinearMapper LinearMapper;

   EIGEN_DONT_INLINE
-  void operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB,
-                  Index rows, Index depth, Index cols, QInt32 alpha,
-                  Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
+  void operator()(const DataMapper& res, const QInt8* blockA,
+                  const QUInt8* blockB, Index rows, Index depth, Index cols,
+                  QInt32 alpha, Index strideA = -1, Index strideB = -1,
+                  Index offsetA = 0, Index offsetB = 0);
 };

 // Alternate implementations for any input sizes
-template <typename Index, typename DataMapper, int Pack1, int Pack2, bool Conjugate, bool PanelMode>
-EIGEN_DONT_INLINE void gemm_pack_lhs_any<QInt8, Index, DataMapper, Pack1, Pack2, ColMajor, Conjugate, PanelMode>::
-operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) {
+template <typename Index, typename DataMapper, int Pack1, int Pack2,
+          bool Conjugate, bool PanelMode>
+EIGEN_DONT_INLINE void gemm_pack_lhs_any<QInt8, Index, DataMapper, Pack1, Pack2,
+                                         ColMajor, Conjugate, PanelMode>::
+operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows,
+           Index stride, Index offset) {
   eigen_assert(stride == 0);
   eigen_assert(offset == 0);

+  typedef typename packet_traits<QInt8>::type Packet;
+
   // Get vector pointer
   __m256i* blockA_256 = reinterpret_cast<__m256i*>(blockA);
@@ -690,15 +725,15 @@ operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows, Index
     // Pack depth in sets of 8
     for (Index k = 0; k < depth_8; k += 8) {
       // Load vectors
-      __m256i L_A = lhs.loadPacket(m, k);
-      __m256i L_B = lhs.loadPacket(m, k + 1);
+      __m256i L_A = lhs.template loadPacket<Packet>(m, k);
+      __m256i L_B = lhs.template loadPacket<Packet>(m, k + 1);

       // Interleave 8-bit elements
       __m256i L_AB0_AB16 = _mm256_unpacklo_epi8(L_A, L_B);
       __m256i L_AB8_AB24 = _mm256_unpackhi_epi8(L_A, L_B);

-      __m256i L_C = lhs.loadPacket(m, k + 2);
-      __m256i L_D = lhs.loadPacket(m, k + 3);
+      __m256i L_C = lhs.template loadPacket<Packet>(m, k + 2);
+      __m256i L_D = lhs.template loadPacket<Packet>(m, k + 3);
       __m256i L_CD0_CD16 = _mm256_unpacklo_epi8(L_C, L_D);
       __m256i L_CD8_CD24 = _mm256_unpackhi_epi8(L_C, L_D);
@@ -719,12 +754,12 @@ operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows, Index
       _mm256_store_si256(blockA_256++, L_AD16);
       __m256i L_AD24 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x31);
       _mm256_store_si256(blockA_256++, L_AD24);
-      __m256i L_E = lhs.loadPacket(m, k + 4);
-      __m256i L_F = lhs.loadPacket(m, k + 5);
+      __m256i L_E = lhs.template loadPacket<Packet>(m, k + 4);
+      __m256i L_F = lhs.template loadPacket<Packet>(m, k + 5);
       __m256i L_EF0_EF16 = _mm256_unpacklo_epi8(L_E, L_F);
       __m256i L_EF8_EF24 = _mm256_unpackhi_epi8(L_E, L_F);
-      __m256i L_G = lhs.loadPacket(m, k + 6);
-      __m256i L_H = lhs.loadPacket(m, k + 7);
+      __m256i L_G = lhs.template loadPacket<Packet>(m, k + 6);
+      __m256i L_H = lhs.template loadPacket<Packet>(m, k + 7);
       __m256i L_GH0_GH16 = _mm256_unpacklo_epi8(L_G, L_H);
       __m256i L_GH8_GH24 = _mm256_unpackhi_epi8(L_G, L_H);
       __m256i L_EH0_EH16 = _mm256_unpacklo_epi16(L_EF0_EF16, L_GH0_GH16);
@@ -746,7 +781,7 @@ operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows, Index
       __m256i L_A, L_B, L_C, L_D, L_E, L_F, L_G, L_H;
       switch (depth - depth_8) {
         case 1:
-          L_A = lhs.loadPacket(m, depth_8);
+          L_A = lhs.template loadPacket<Packet>(m, depth_8);
           L_B = _mm256_setzero_si256();
           L_C = _mm256_setzero_si256();
           L_D = _mm256_setzero_si256();
@@ -756,8 +791,8 @@ operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows, Index
L_H = _mm256_setzero_si256(); L_H = _mm256_setzero_si256();
break; break;
case 2: case 2:
L_A = lhs.loadPacket(m, depth_8); L_A = lhs.template loadPacket<Packet>(m, depth_8);
L_B = lhs.loadPacket(m, depth_8 + 1); L_B = lhs.template loadPacket<Packet>(m, depth_8 + 1);
L_C = _mm256_setzero_si256(); L_C = _mm256_setzero_si256();
L_D = _mm256_setzero_si256(); L_D = _mm256_setzero_si256();
L_E = _mm256_setzero_si256(); L_E = _mm256_setzero_si256();
@ -766,9 +801,9 @@ operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows, Index
L_H = _mm256_setzero_si256(); L_H = _mm256_setzero_si256();
break; break;
case 3: case 3:
L_A = lhs.loadPacket(m, depth_8); L_A = lhs.template loadPacket<Packet>(m, depth_8);
L_B = lhs.loadPacket(m, depth_8 + 1); L_B = lhs.template loadPacket<Packet>(m, depth_8 + 1);
L_C = lhs.loadPacket(m, depth_8 + 2); L_C = lhs.template loadPacket<Packet>(m, depth_8 + 2);
L_D = _mm256_setzero_si256(); L_D = _mm256_setzero_si256();
L_E = _mm256_setzero_si256(); L_E = _mm256_setzero_si256();
L_F = _mm256_setzero_si256(); L_F = _mm256_setzero_si256();
@ -776,43 +811,43 @@ operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows, Index
L_H = _mm256_setzero_si256(); L_H = _mm256_setzero_si256();
break; break;
case 4: case 4:
L_A = lhs.loadPacket(m, depth_8); L_A = lhs.template loadPacket<Packet>(m, depth_8);
L_B = lhs.loadPacket(m, depth_8 + 1); L_B = lhs.template loadPacket<Packet>(m, depth_8 + 1);
L_C = lhs.loadPacket(m, depth_8 + 2); L_C = lhs.template loadPacket<Packet>(m, depth_8 + 2);
L_D = lhs.loadPacket(m, depth_8 + 3); L_D = lhs.template loadPacket<Packet>(m, depth_8 + 3);
L_E = _mm256_setzero_si256(); L_E = _mm256_setzero_si256();
L_F = _mm256_setzero_si256(); L_F = _mm256_setzero_si256();
L_G = _mm256_setzero_si256(); L_G = _mm256_setzero_si256();
L_H = _mm256_setzero_si256(); L_H = _mm256_setzero_si256();
break; break;
case 5: case 5:
L_A = lhs.loadPacket(m, depth_8); L_A = lhs.template loadPacket<Packet>(m, depth_8);
L_B = lhs.loadPacket(m, depth_8 + 1); L_B = lhs.template loadPacket<Packet>(m, depth_8 + 1);
L_C = lhs.loadPacket(m, depth_8 + 2); L_C = lhs.template loadPacket<Packet>(m, depth_8 + 2);
L_D = lhs.loadPacket(m, depth_8 + 3); L_D = lhs.template loadPacket<Packet>(m, depth_8 + 3);
L_E = lhs.loadPacket(m, depth_8 + 4); L_E = lhs.template loadPacket<Packet>(m, depth_8 + 4);
L_F = _mm256_setzero_si256(); L_F = _mm256_setzero_si256();
L_G = _mm256_setzero_si256(); L_G = _mm256_setzero_si256();
L_H = _mm256_setzero_si256(); L_H = _mm256_setzero_si256();
break; break;
case 6: case 6:
L_A = lhs.loadPacket(m, depth_8); L_A = lhs.template loadPacket<Packet>(m, depth_8);
L_B = lhs.loadPacket(m, depth_8 + 1); L_B = lhs.template loadPacket<Packet>(m, depth_8 + 1);
L_C = lhs.loadPacket(m, depth_8 + 2); L_C = lhs.template loadPacket<Packet>(m, depth_8 + 2);
L_D = lhs.loadPacket(m, depth_8 + 3); L_D = lhs.template loadPacket<Packet>(m, depth_8 + 3);
L_E = lhs.loadPacket(m, depth_8 + 4); L_E = lhs.template loadPacket<Packet>(m, depth_8 + 4);
L_F = lhs.loadPacket(m, depth_8 + 5); L_F = lhs.template loadPacket<Packet>(m, depth_8 + 5);
L_G = _mm256_setzero_si256(); L_G = _mm256_setzero_si256();
L_H = _mm256_setzero_si256(); L_H = _mm256_setzero_si256();
break; break;
case 7: case 7:
L_A = lhs.loadPacket(m, depth_8); L_A = lhs.template loadPacket<Packet>(m, depth_8);
L_B = lhs.loadPacket(m, depth_8 + 1); L_B = lhs.template loadPacket<Packet>(m, depth_8 + 1);
L_C = lhs.loadPacket(m, depth_8 + 2); L_C = lhs.template loadPacket<Packet>(m, depth_8 + 2);
L_D = lhs.loadPacket(m, depth_8 + 3); L_D = lhs.template loadPacket<Packet>(m, depth_8 + 3);
L_E = lhs.loadPacket(m, depth_8 + 4); L_E = lhs.template loadPacket<Packet>(m, depth_8 + 4);
L_F = lhs.loadPacket(m, depth_8 + 5); L_F = lhs.template loadPacket<Packet>(m, depth_8 + 5);
L_G = lhs.loadPacket(m, depth_8 + 6); L_G = lhs.template loadPacket<Packet>(m, depth_8 + 6);
L_H = _mm256_setzero_si256(); L_H = _mm256_setzero_si256();
break; break;
} }
@ -1124,12 +1159,17 @@ operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows, Index
} }
} }
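// The interleaving above is built from _mm256_unpacklo/unpackhi_epi8. A small
// stand-alone demo (assumes AVX2; not part of this commit) of what one
// unpacklo does: it interleaves the low 8 bytes of each 128-bit lane.
#include <immintrin.h>
#include <cstdint>
#include <cstdio>

int main() {
  alignas(32) std::int8_t a[32], b[32], out[32];
  for (int i = 0; i < 32; ++i) {
    a[i] = static_cast<std::int8_t>(i);
    b[i] = static_cast<std::int8_t>(64 + i);
  }
  __m256i va = _mm256_load_si256(reinterpret_cast<const __m256i*>(a));
  __m256i vb = _mm256_load_si256(reinterpret_cast<const __m256i*>(b));
  __m256i lo = _mm256_unpacklo_epi8(va, vb);  // a0,b0,a1,b1,... per 128-bit lane
  _mm256_store_si256(reinterpret_cast<__m256i*>(out), lo);
  for (int i = 0; i < 32; ++i) std::printf("%d ", out[i]);  // 0 64 1 65 ... 7 71 16 80 ...
  std::printf("\n");
  return 0;
}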
template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode> template <typename Index, typename DataMapper, int nr, bool Conjugate,
EIGEN_DONT_INLINE void gemm_pack_rhs_any<QUInt8, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>:: bool PanelMode>
operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) { EIGEN_DONT_INLINE void gemm_pack_rhs_any<QUInt8, Index, DataMapper, nr,
ColMajor, Conjugate, PanelMode>::
operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols,
Index stride, Index offset) {
eigen_assert(stride == 0); eigen_assert(stride == 0);
eigen_assert(offset == 0); eigen_assert(offset == 0);
typedef typename packet_traits<QUInt8>::type Packet;
// Get vector pointer // Get vector pointer
__m256i* blockB_256 = reinterpret_cast<__m256i*>(blockB); __m256i* blockB_256 = reinterpret_cast<__m256i*>(blockB);
@ -1158,52 +1198,52 @@ operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols, Index
for (Index n = 0; n < cols_32; n += 32) { for (Index n = 0; n < cols_32; n += 32) {
// Pack depth in sets of 32 // Pack depth in sets of 32
for (Index k = 0; k < depth_32; k += 32) { for (Index k = 0; k < depth_32; k += 32) {
__m256i R_A = rhs.loadPacket(k, n); __m256i R_A = rhs.template loadPacket<Packet>(k, n);
__m256i R_B = rhs.loadPacket(k, n + 1); __m256i R_B = rhs.template loadPacket<Packet>(k, n + 1);
__m256i R_C = rhs.loadPacket(k, n + 2); __m256i R_C = rhs.template loadPacket<Packet>(k, n + 2);
__m256i R_D = rhs.loadPacket(k, n + 3); __m256i R_D = rhs.template loadPacket<Packet>(k, n + 3);
PACK_STEP; PACK_STEP;
R_A = rhs.loadPacket(k, n + 4); R_A = rhs.template loadPacket<Packet>(k, n + 4);
R_B = rhs.loadPacket(k, n + 5); R_B = rhs.template loadPacket<Packet>(k, n + 5);
R_C = rhs.loadPacket(k, n + 6); R_C = rhs.template loadPacket<Packet>(k, n + 6);
R_D = rhs.loadPacket(k, n + 7); R_D = rhs.template loadPacket<Packet>(k, n + 7);
PACK_STEP; PACK_STEP;
R_A = rhs.loadPacket(k, n + 8); R_A = rhs.template loadPacket<Packet>(k, n + 8);
R_B = rhs.loadPacket(k, n + 9); R_B = rhs.template loadPacket<Packet>(k, n + 9);
R_C = rhs.loadPacket(k, n + 10); R_C = rhs.template loadPacket<Packet>(k, n + 10);
R_D = rhs.loadPacket(k, n + 11); R_D = rhs.template loadPacket<Packet>(k, n + 11);
PACK_STEP; PACK_STEP;
R_A = rhs.loadPacket(k, n + 12); R_A = rhs.template loadPacket<Packet>(k, n + 12);
R_B = rhs.loadPacket(k, n + 13); R_B = rhs.template loadPacket<Packet>(k, n + 13);
R_C = rhs.loadPacket(k, n + 14); R_C = rhs.template loadPacket<Packet>(k, n + 14);
R_D = rhs.loadPacket(k, n + 15); R_D = rhs.template loadPacket<Packet>(k, n + 15);
PACK_STEP; PACK_STEP;
R_A = rhs.loadPacket(k, n + 16); R_A = rhs.template loadPacket<Packet>(k, n + 16);
R_B = rhs.loadPacket(k, n + 17); R_B = rhs.template loadPacket<Packet>(k, n + 17);
R_C = rhs.loadPacket(k, n + 18); R_C = rhs.template loadPacket<Packet>(k, n + 18);
R_D = rhs.loadPacket(k, n + 19); R_D = rhs.template loadPacket<Packet>(k, n + 19);
PACK_STEP; PACK_STEP;
R_A = rhs.loadPacket(k, n + 20); R_A = rhs.template loadPacket<Packet>(k, n + 20);
R_B = rhs.loadPacket(k, n + 21); R_B = rhs.template loadPacket<Packet>(k, n + 21);
R_C = rhs.loadPacket(k, n + 22); R_C = rhs.template loadPacket<Packet>(k, n + 22);
R_D = rhs.loadPacket(k, n + 23); R_D = rhs.template loadPacket<Packet>(k, n + 23);
PACK_STEP; PACK_STEP;
R_A = rhs.loadPacket(k, n + 24); R_A = rhs.template loadPacket<Packet>(k, n + 24);
R_B = rhs.loadPacket(k, n + 25); R_B = rhs.template loadPacket<Packet>(k, n + 25);
R_C = rhs.loadPacket(k, n + 26); R_C = rhs.template loadPacket<Packet>(k, n + 26);
R_D = rhs.loadPacket(k, n + 27); R_D = rhs.template loadPacket<Packet>(k, n + 27);
PACK_STEP; PACK_STEP;
R_A = rhs.loadPacket(k, n + 28); R_A = rhs.template loadPacket<Packet>(k, n + 28);
R_B = rhs.loadPacket(k, n + 29); R_B = rhs.template loadPacket<Packet>(k, n + 29);
R_C = rhs.loadPacket(k, n + 30); R_C = rhs.template loadPacket<Packet>(k, n + 30);
R_D = rhs.loadPacket(k, n + 31); R_D = rhs.template loadPacket<Packet>(k, n + 31);
PACK_STEP; PACK_STEP;
blockB_256 += 24; blockB_256 += 24;
@ -1351,31 +1391,31 @@ operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols, Index
for (n = cols_32; n < cols; n += 4) { for (n = cols_32; n < cols; n += 4) {
switch (cols - n) { switch (cols - n) {
case 1: case 1:
R_A = rhs.loadPacket(k, n); R_A = rhs.template loadPacket<Packet>(k, n);
R_B = _mm256_setzero_si256(); R_B = _mm256_setzero_si256();
R_C = _mm256_setzero_si256(); R_C = _mm256_setzero_si256();
R_D = _mm256_setzero_si256(); R_D = _mm256_setzero_si256();
PACK_STEP; PACK_STEP;
break; break;
case 2: case 2:
R_A = rhs.loadPacket(k, n); R_A = rhs.template loadPacket<Packet>(k, n);
R_B = rhs.loadPacket(k, n + 1); R_B = rhs.template loadPacket<Packet>(k, n + 1);
R_C = _mm256_setzero_si256(); R_C = _mm256_setzero_si256();
R_D = _mm256_setzero_si256(); R_D = _mm256_setzero_si256();
PACK_STEP; PACK_STEP;
break; break;
case 3: case 3:
R_A = rhs.loadPacket(k, n); R_A = rhs.template loadPacket<Packet>(k, n);
R_B = rhs.loadPacket(k, n + 1); R_B = rhs.template loadPacket<Packet>(k, n + 1);
R_C = rhs.loadPacket(k, n + 2); R_C = rhs.template loadPacket<Packet>(k, n + 2);
R_D = _mm256_setzero_si256(); R_D = _mm256_setzero_si256();
PACK_STEP; PACK_STEP;
break; break;
default: default:
R_A = rhs.loadPacket(k, n); R_A = rhs.template loadPacket<Packet>(k, n);
R_B = rhs.loadPacket(k, n + 1); R_B = rhs.template loadPacket<Packet>(k, n + 1);
R_C = rhs.loadPacket(k, n + 2); R_C = rhs.template loadPacket<Packet>(k, n + 2);
R_D = rhs.loadPacket(k, n + 3); R_D = rhs.template loadPacket<Packet>(k, n + 3);
PACK_STEP; PACK_STEP;
break; break;
} }
@ -1441,13 +1481,13 @@ operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols, Index
#undef PACK_STEP #undef PACK_STEP
} }
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> template <typename Index, typename DataMapper, int mr, int nr,
EIGEN_DONT_INLINE bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel_any<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> EIGEN_DONT_INLINE void gebp_kernel_any<QInt8, QUInt8, Index, DataMapper, mr, nr,
::operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB, ConjugateLhs, ConjugateRhs>::
Index rows, Index depth, Index cols, QInt32 alpha, operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB,
Index strideA, Index strideB, Index offsetA, Index offsetB) Index rows, Index depth, Index cols, QInt32 alpha, Index strideA,
{ Index strideB, Index offsetA, Index offsetB) {
EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE); EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE); EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
eigen_assert(alpha.value == 1); eigen_assert(alpha.value == 1);
@ -1678,17 +1718,21 @@ void gebp_kernel_any<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, Con
LinearMapper r1 = res.getLinearMapper(m + 8, j); LinearMapper r1 = res.getLinearMapper(m + 8, j);
LinearMapper r2 = res.getLinearMapper(m + 16, j); LinearMapper r2 = res.getLinearMapper(m + 16, j);
LinearMapper r3 = res.getLinearMapper(m + 24, j); LinearMapper r3 = res.getLinearMapper(m + 24, j);
r0.storePacket( typedef typename packet_traits<QInt32>::type Packet;
0, _mm256_add_epi32(blockO_256[i++], r0.loadPacket(0))); r0.template storePacket<Packet>(
r1.storePacket( 0, _mm256_add_epi32(blockO_256[i++],
0, _mm256_add_epi32(blockO_256[i++], r1.loadPacket(0))); r0.template loadPacket<Packet>(0)));
r2.storePacket( r1.template storePacket<Packet>(
0, _mm256_add_epi32(blockO_256[i++], r2.loadPacket(0))); 0, _mm256_add_epi32(blockO_256[i++],
r3.storePacket( r1.template loadPacket<Packet>(0)));
0, _mm256_add_epi32(blockO_256[i++], r3.loadPacket(0))); r2.template storePacket<Packet>(
0, _mm256_add_epi32(blockO_256[i++],
r2.template loadPacket<Packet>(0)));
r3.template storePacket<Packet>(
0, _mm256_add_epi32(blockO_256[i++],
r3.template loadPacket<Packet>(0)));
} }
} } else {
else {
for (Index j = n; j < cols; j++) { for (Index j = n; j < cols; j++) {
for (Index i = m; i < rows; i++) { for (Index i = m; i < rows; i++) {
res(i, j) = blockO[(j - n) * 32 + (i - m)]; res(i, j) = blockO[(j - n) * 32 + (i - m)];
@ -1745,7 +1789,7 @@ void gebp_kernel_any<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, Con
// madd both perform an adjacent addition in the kernel. // madd both perform an adjacent addition in the kernel.
template <typename Index, typename DataMapper, int Pack1, int Pack2, template <typename Index, typename DataMapper, int Pack1, int Pack2,
bool Conjugate, bool PanelMode> bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<QInt8, Index, DataMapper, Pack1, Pack2, ColMajor, struct gemm_pack_lhs<QInt8, Index, DataMapper, Pack1, Pack2, QInt8, ColMajor,
Conjugate, PanelMode> { Conjugate, PanelMode> {
EIGEN_DONT_INLINE void operator()(QInt8* blockA, const DataMapper& lhs, EIGEN_DONT_INLINE void operator()(QInt8* blockA, const DataMapper& lhs,
Index depth, Index rows, Index stride = 0, Index depth, Index rows, Index stride = 0,
@ -1755,15 +1799,18 @@ struct gemm_pack_lhs<QInt8, Index, DataMapper, Pack1, Pack2, ColMajor,
template <typename Index, typename DataMapper, int Pack1, int Pack2, template <typename Index, typename DataMapper, int Pack1, int Pack2,
bool Conjugate, bool PanelMode> bool Conjugate, bool PanelMode>
EIGEN_DONT_INLINE void gemm_pack_lhs<QInt8, Index, DataMapper, Pack1, Pack2, EIGEN_DONT_INLINE void gemm_pack_lhs<QInt8, Index, DataMapper, Pack1, Pack2,
ColMajor, Conjugate, PanelMode>:: QInt8, ColMajor, Conjugate, PanelMode>::
operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows, operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows,
Index stride, Index offset) { Index stride, Index offset) {
eigen_assert(stride == 0); eigen_assert(stride == 0);
eigen_assert(offset == 0); eigen_assert(offset == 0);
typedef typename packet_traits<QInt8>::type Packet;
// Use alternate function for weird sizes // Use alternate function for weird sizes
if (rows % 32 != 0 || depth % 32 != 0) { if (rows % 32 != 0 || depth % 32 != 0) {
gemm_pack_lhs_any<QInt8, Index, DataMapper, Pack1, Pack2, ColMajor, Conjugate, PanelMode> lhs_pack; gemm_pack_lhs_any<QInt8, Index, DataMapper, Pack1, Pack2, ColMajor,
Conjugate, PanelMode> lhs_pack;
return lhs_pack(blockA, lhs, depth, rows, stride, offset); return lhs_pack(blockA, lhs, depth, rows, stride, offset);
} }
@ -1775,15 +1822,15 @@ operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows,
// Pack depth in sets of 8 // Pack depth in sets of 8
for (Index k = 0; k < depth; k += 8) { for (Index k = 0; k < depth; k += 8) {
// Load vectors // Load vectors
__m256i L_A = lhs.loadPacket(m, k); __m256i L_A = lhs.template loadPacket<Packet>(m, k);
__m256i L_B = lhs.loadPacket(m, k + 1); __m256i L_B = lhs.template loadPacket<Packet>(m, k + 1);
// Interleave 8-bit elements // Interleave 8-bit elements
__m256i L_AB0_AB16 = _mm256_unpacklo_epi8(L_A, L_B); __m256i L_AB0_AB16 = _mm256_unpacklo_epi8(L_A, L_B);
__m256i L_AB8_AB24 = _mm256_unpackhi_epi8(L_A, L_B); __m256i L_AB8_AB24 = _mm256_unpackhi_epi8(L_A, L_B);
__m256i L_C = lhs.loadPacket(m, k + 2); __m256i L_C = lhs.template loadPacket<Packet>(m, k + 2);
__m256i L_D = lhs.loadPacket(m, k + 3); __m256i L_D = lhs.template loadPacket<Packet>(m, k + 3);
__m256i L_CD0_CD16 = _mm256_unpacklo_epi8(L_C, L_D); __m256i L_CD0_CD16 = _mm256_unpacklo_epi8(L_C, L_D);
__m256i L_CD8_CD24 = _mm256_unpackhi_epi8(L_C, L_D); __m256i L_CD8_CD24 = _mm256_unpackhi_epi8(L_C, L_D);
@ -1804,12 +1851,12 @@ operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows,
_mm256_store_si256(blockA_256++, L_AD16); _mm256_store_si256(blockA_256++, L_AD16);
__m256i L_AD24 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x31); __m256i L_AD24 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x31);
_mm256_store_si256(blockA_256++, L_AD24); _mm256_store_si256(blockA_256++, L_AD24);
__m256i L_E = lhs.loadPacket(m, k + 4); __m256i L_E = lhs.template loadPacket<Packet>(m, k + 4);
__m256i L_F = lhs.loadPacket(m, k + 5); __m256i L_F = lhs.template loadPacket<Packet>(m, k + 5);
__m256i L_EF0_EF16 = _mm256_unpacklo_epi8(L_E, L_F); __m256i L_EF0_EF16 = _mm256_unpacklo_epi8(L_E, L_F);
__m256i L_EF8_EF24 = _mm256_unpackhi_epi8(L_E, L_F); __m256i L_EF8_EF24 = _mm256_unpackhi_epi8(L_E, L_F);
__m256i L_G = lhs.loadPacket(m, k + 6); __m256i L_G = lhs.template loadPacket<Packet>(m, k + 6);
__m256i L_H = lhs.loadPacket(m, k + 7); __m256i L_H = lhs.template loadPacket<Packet>(m, k + 7);
__m256i L_GH0_GH16 = _mm256_unpacklo_epi8(L_G, L_H); __m256i L_GH0_GH16 = _mm256_unpacklo_epi8(L_G, L_H);
__m256i L_GH8_GH24 = _mm256_unpackhi_epi8(L_G, L_H); __m256i L_GH8_GH24 = _mm256_unpackhi_epi8(L_G, L_H);
__m256i L_EH0_EH16 = _mm256_unpacklo_epi16(L_EF0_EF16, L_GH0_GH16); __m256i L_EH0_EH16 = _mm256_unpacklo_epi16(L_EF0_EF16, L_GH0_GH16);
@ -1868,9 +1915,12 @@ operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols,
eigen_assert(stride == 0); eigen_assert(stride == 0);
eigen_assert(offset == 0); eigen_assert(offset == 0);
typedef typename packet_traits<QUInt8>::type Packet;
// Use alternate function for weird sizes // Use alternate function for weird sizes
if (cols % 32 != 0 || depth % 32 != 0) { if (cols % 32 != 0 || depth % 32 != 0) {
gemm_pack_rhs_any<QUInt8, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> rhs_pack; gemm_pack_rhs_any<QUInt8, Index, DataMapper, nr, ColMajor, Conjugate,
PanelMode> rhs_pack;
return rhs_pack(blockB, rhs, depth, cols, stride, offset); return rhs_pack(blockB, rhs, depth, cols, stride, offset);
} }
@ -1898,52 +1948,52 @@ operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols,
for (Index n = 0; n < cols; n += 32) { for (Index n = 0; n < cols; n += 32) {
// Pack depth in sets of 32 // Pack depth in sets of 32
for (Index k = 0; k < depth; k += 32) { for (Index k = 0; k < depth; k += 32) {
__m256i R_A = rhs.loadPacket(k, n); __m256i R_A = rhs.template loadPacket<Packet>(k, n);
__m256i R_B = rhs.loadPacket(k, n + 1); __m256i R_B = rhs.template loadPacket<Packet>(k, n + 1);
__m256i R_C = rhs.loadPacket(k, n + 2); __m256i R_C = rhs.template loadPacket<Packet>(k, n + 2);
__m256i R_D = rhs.loadPacket(k, n + 3); __m256i R_D = rhs.template loadPacket<Packet>(k, n + 3);
PACK_STEP; PACK_STEP;
R_A = rhs.loadPacket(k, n + 4); R_A = rhs.template loadPacket<Packet>(k, n + 4);
R_B = rhs.loadPacket(k, n + 5); R_B = rhs.template loadPacket<Packet>(k, n + 5);
R_C = rhs.loadPacket(k, n + 6); R_C = rhs.template loadPacket<Packet>(k, n + 6);
R_D = rhs.loadPacket(k, n + 7); R_D = rhs.template loadPacket<Packet>(k, n + 7);
PACK_STEP; PACK_STEP;
R_A = rhs.loadPacket(k, n + 8); R_A = rhs.template loadPacket<Packet>(k, n + 8);
R_B = rhs.loadPacket(k, n + 9); R_B = rhs.template loadPacket<Packet>(k, n + 9);
R_C = rhs.loadPacket(k, n + 10); R_C = rhs.template loadPacket<Packet>(k, n + 10);
R_D = rhs.loadPacket(k, n + 11); R_D = rhs.template loadPacket<Packet>(k, n + 11);
PACK_STEP; PACK_STEP;
R_A = rhs.loadPacket(k, n + 12); R_A = rhs.template loadPacket<Packet>(k, n + 12);
R_B = rhs.loadPacket(k, n + 13); R_B = rhs.template loadPacket<Packet>(k, n + 13);
R_C = rhs.loadPacket(k, n + 14); R_C = rhs.template loadPacket<Packet>(k, n + 14);
R_D = rhs.loadPacket(k, n + 15); R_D = rhs.template loadPacket<Packet>(k, n + 15);
PACK_STEP; PACK_STEP;
R_A = rhs.loadPacket(k, n + 16); R_A = rhs.template loadPacket<Packet>(k, n + 16);
R_B = rhs.loadPacket(k, n + 17); R_B = rhs.template loadPacket<Packet>(k, n + 17);
R_C = rhs.loadPacket(k, n + 18); R_C = rhs.template loadPacket<Packet>(k, n + 18);
R_D = rhs.loadPacket(k, n + 19); R_D = rhs.template loadPacket<Packet>(k, n + 19);
PACK_STEP; PACK_STEP;
R_A = rhs.loadPacket(k, n + 20); R_A = rhs.template loadPacket<Packet>(k, n + 20);
R_B = rhs.loadPacket(k, n + 21); R_B = rhs.template loadPacket<Packet>(k, n + 21);
R_C = rhs.loadPacket(k, n + 22); R_C = rhs.template loadPacket<Packet>(k, n + 22);
R_D = rhs.loadPacket(k, n + 23); R_D = rhs.template loadPacket<Packet>(k, n + 23);
PACK_STEP; PACK_STEP;
R_A = rhs.loadPacket(k, n + 24); R_A = rhs.template loadPacket<Packet>(k, n + 24);
R_B = rhs.loadPacket(k, n + 25); R_B = rhs.template loadPacket<Packet>(k, n + 25);
R_C = rhs.loadPacket(k, n + 26); R_C = rhs.template loadPacket<Packet>(k, n + 26);
R_D = rhs.loadPacket(k, n + 27); R_D = rhs.template loadPacket<Packet>(k, n + 27);
PACK_STEP; PACK_STEP;
R_A = rhs.loadPacket(k, n + 28); R_A = rhs.template loadPacket<Packet>(k, n + 28);
R_B = rhs.loadPacket(k, n + 29); R_B = rhs.template loadPacket<Packet>(k, n + 29);
R_C = rhs.loadPacket(k, n + 30); R_C = rhs.template loadPacket<Packet>(k, n + 30);
R_D = rhs.loadPacket(k, n + 31); R_D = rhs.template loadPacket<Packet>(k, n + 31);
PACK_STEP; PACK_STEP;
blockB_256 += 24; blockB_256 += 24;
@ -1953,24 +2003,26 @@ operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols,
} }
// Perform the actual multiplication on packed inputs // Perform the actual multiplication on packed inputs
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> template <typename Index, typename DataMapper, int mr, int nr,
struct gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> bool ConjugateLhs, bool ConjugateRhs>
{ struct gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs,
ConjugateRhs> {
typedef typename DataMapper::LinearMapper LinearMapper; typedef typename DataMapper::LinearMapper LinearMapper;
EIGEN_DONT_INLINE EIGEN_DONT_INLINE
void operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB, void operator()(const DataMapper& res, const QInt8* blockA,
Index rows, Index depth, Index cols, QInt32 alpha, const QUInt8* blockB, Index rows, Index depth, Index cols,
Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); QInt32 alpha, Index strideA = -1, Index strideB = -1,
Index offsetA = 0, Index offsetB = 0);
}; };
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> template <typename Index, typename DataMapper, int mr, int nr,
EIGEN_DONT_INLINE bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> EIGEN_DONT_INLINE void gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr,
::operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB, ConjugateLhs, ConjugateRhs>::
Index rows, Index depth, Index cols, QInt32 alpha, operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB,
Index strideA, Index strideB, Index offsetA, Index offsetB) Index rows, Index depth, Index cols, QInt32 alpha, Index strideA,
{ Index strideB, Index offsetA, Index offsetB) {
EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE); EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE); EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
eigen_assert(alpha.value == 1); eigen_assert(alpha.value == 1);
@ -1986,8 +2038,10 @@ void gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, Conjuga
// Use alternate function for weird sizes // Use alternate function for weird sizes
if (rows % 32 != 0 || cols % 32 != 0 || depth % 32 != 0) { if (rows % 32 != 0 || cols % 32 != 0 || depth % 32 != 0) {
gebp_kernel_any<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> gebp; gebp_kernel_any<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs,
return gebp(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); ConjugateRhs> gebp;
return gebp(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB,
offsetA, offsetB);
} }
// Create result block // Create result block
@ -2205,14 +2259,19 @@ void gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, Conjuga
LinearMapper r1 = res.getLinearMapper(m + 8, j); LinearMapper r1 = res.getLinearMapper(m + 8, j);
LinearMapper r2 = res.getLinearMapper(m + 16, j); LinearMapper r2 = res.getLinearMapper(m + 16, j);
LinearMapper r3 = res.getLinearMapper(m + 24, j); LinearMapper r3 = res.getLinearMapper(m + 24, j);
r0.storePacket( typedef typename packet_traits<QInt32>::type Packet;
0, _mm256_add_epi32(blockO_256[i++], r0.loadPacket(0))); r0.template storePacket<Packet>(
r1.storePacket( 0, _mm256_add_epi32(blockO_256[i++],
0, _mm256_add_epi32(blockO_256[i++], r1.loadPacket(0))); r0.template loadPacket<Packet>(0)));
r2.storePacket( r1.template storePacket<Packet>(
0, _mm256_add_epi32(blockO_256[i++], r2.loadPacket(0))); 0, _mm256_add_epi32(blockO_256[i++],
r3.storePacket( r1.template loadPacket<Packet>(0)));
0, _mm256_add_epi32(blockO_256[i++], r3.loadPacket(0))); r2.template storePacket<Packet>(
0, _mm256_add_epi32(blockO_256[i++],
r2.template loadPacket<Packet>(0)));
r3.template storePacket<Packet>(
0, _mm256_add_epi32(blockO_256[i++],
r3.template loadPacket<Packet>(0)));
} }
// Zero the result block so it can be reused // Zero the result block so it can be reused
@ -14,14 +14,13 @@
namespace Eigen { namespace Eigen {
namespace internal { namespace internal {
// AVX2 optimized implementation of the case where the lhs is encoded using
// AVX2 optimized implementation of the case where the lhs is encoded using signed 8bit // signed 8bit
// integers and the rhs using unsigned 8bit integers. // integers and the rhs using unsigned 8bit integers.
#ifdef EIGEN_USE_OPTIMIZED_INT8_UINT8_MAT_MAT_PRODUCT #ifdef EIGEN_USE_OPTIMIZED_INT8_UINT8_MAT_MAT_PRODUCT
template <bool _ConjLhs, bool _ConjRhs> template <bool _ConjLhs, bool _ConjRhs>
class gebp_traits<QInt8, QUInt8, _ConjLhs, _ConjRhs> class gebp_traits<QInt8, QUInt8, _ConjLhs, _ConjRhs> {
{
public: public:
typedef QInt8 LhsScalar; typedef QInt8 LhsScalar;
typedef QUInt8 RhsScalar; typedef QUInt8 RhsScalar;
@ -40,22 +39,24 @@ public:
}; };
// Mat-Mat product of a signed 8bit lhs with an unsigned 8bit rhs // Mat-Mat product of a signed 8bit lhs with an unsigned 8bit rhs
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> template <typename Index, typename DataMapper, int mr, int nr,
struct gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> bool ConjugateLhs, bool ConjugateRhs>
{ struct gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs,
ConjugateRhs> {
EIGEN_DONT_INLINE EIGEN_DONT_INLINE
void operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB, void operator()(const DataMapper& res, const QInt8* blockA,
Index rows, Index depth, Index cols, QInt32 alpha, const QUInt8* blockB, Index rows, Index depth, Index cols,
Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); QInt32 alpha, Index strideA = -1, Index strideB = -1,
Index offsetA = 0, Index offsetB = 0);
}; };
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> template <typename Index, typename DataMapper, int mr, int nr,
EIGEN_DONT_INLINE bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> EIGEN_DONT_INLINE void gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr,
::operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB, ConjugateLhs, ConjugateRhs>::
Index rows, Index depth, Index cols, QInt32 alpha, operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB,
Index strideA, Index strideB, Index offsetA, Index offsetB) Index rows, Index depth, Index cols, QInt32 alpha, Index strideA,
{ Index strideB, Index offsetA, Index offsetB) {
EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE); EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE); EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
@ -85,7 +86,6 @@ void gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, Conjuga
} }
#endif #endif
} // namespace internal } // namespace internal
} // namespace Eigen } // namespace Eigen
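// For reference (not part of this commit), the quantized product these gebp
// kernels compute, written as a plain scalar loop: signed 8-bit lhs times
// unsigned 8-bit rhs, accumulated in 32-bit integers, with alpha fixed to 1 as
// the asserts above require. Column-major storage is assumed throughout, and
// res is accumulated into, matching the kernels' read-modify-write of the
// result block.
#include <cstdint>

void ref_gemm_s8u8s32(long rows, long depth, long cols,
                      const std::int8_t* lhs,   // rows x depth
                      const std::uint8_t* rhs,  // depth x cols
                      std::int32_t* res) {      // rows x cols, pre-initialized
  for (long j = 0; j < cols; ++j) {
    for (long i = 0; i < rows; ++i) {
      std::int32_t acc = 0;
      for (long k = 0; k < depth; ++k) {
        acc += static_cast<std::int32_t>(lhs[k * rows + i]) *
               static_cast<std::int32_t>(rhs[j * depth + k]);
      }
      res[j * rows + i] += acc;
    }
  }
}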
@ -15,25 +15,23 @@ namespace internal {
// Mat-Vec product // Mat-Vec product
// Both lhs and rhs are encoded as 8bit signed integers // Both lhs and rhs are encoded as 8bit signed integers
template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> template <typename Index, typename LhsMapper, bool ConjugateLhs,
struct general_matrix_vector_product<Index,QInt8,LhsMapper,ColMajor,ConjugateLhs,QInt8,RhsMapper,ConjugateRhs,Version> typename RhsMapper, bool ConjugateRhs, int Version>
{ struct general_matrix_vector_product<Index, QInt8, LhsMapper, ColMajor,
EIGEN_DONT_INLINE static void run( ConjugateLhs, QInt8, RhsMapper,
Index rows, Index cols, ConjugateRhs, Version> {
const LhsMapper& lhs, EIGEN_DONT_INLINE static void run(Index rows, Index cols,
const RhsMapper& rhs, const LhsMapper& lhs, const RhsMapper& rhs,
QInt32* res, Index resIncr, QInt32* res, Index resIncr, QInt8 alpha);
QInt8 alpha);
}; };
template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> template <typename Index, typename LhsMapper, bool ConjugateLhs,
EIGEN_DONT_INLINE void general_matrix_vector_product<Index,QInt8,LhsMapper,ColMajor,ConjugateLhs,QInt8,RhsMapper,ConjugateRhs,Version>::run( typename RhsMapper, bool ConjugateRhs, int Version>
Index rows, Index cols, EIGEN_DONT_INLINE void general_matrix_vector_product<
const LhsMapper& lhs, Index, QInt8, LhsMapper, ColMajor, ConjugateLhs, QInt8, RhsMapper,
const RhsMapper& rhs, ConjugateRhs, Version>::run(Index rows, Index cols, const LhsMapper& lhs,
QInt32* res, Index resIncr, const RhsMapper& rhs, QInt32* res,
QInt8 alpha) Index resIncr, QInt8 alpha) {
{
eigen_assert(alpha.value == 1); eigen_assert(alpha.value == 1);
eigen_assert(resIncr == 1); eigen_assert(resIncr == 1);
eigen_assert(rows > 0); eigen_assert(rows > 0);
@ -78,26 +76,25 @@ EIGEN_DONT_INLINE void general_matrix_vector_product<
} }
// Mat-Vec product // Mat-Vec product
// The lhs is encoded using 8bit signed integers, the rhs using 8bit unsigned integers // The lhs is encoded using 8bit signed integers, the rhs using 8bit unsigned
template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> // integers
struct general_matrix_vector_product<Index,QInt8,LhsMapper,ColMajor,ConjugateLhs,QUInt8,RhsMapper,ConjugateRhs,Version> template <typename Index, typename LhsMapper, bool ConjugateLhs,
{ typename RhsMapper, bool ConjugateRhs, int Version>
EIGEN_DONT_INLINE static void run( struct general_matrix_vector_product<Index, QInt8, LhsMapper, ColMajor,
Index rows, Index cols, ConjugateLhs, QUInt8, RhsMapper,
const LhsMapper& lhs, ConjugateRhs, Version> {
const RhsMapper& rhs, EIGEN_DONT_INLINE static void run(Index rows, Index cols,
QInt32* res, Index resIncr, const LhsMapper& lhs, const RhsMapper& rhs,
QUInt8 alpha); QInt32* res, Index resIncr, QUInt8 alpha);
}; };
template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> template <typename Index, typename LhsMapper, bool ConjugateLhs,
EIGEN_DONT_INLINE void general_matrix_vector_product<Index,QInt8,LhsMapper,ColMajor,ConjugateLhs,QUInt8,RhsMapper,ConjugateRhs,Version>::run( typename RhsMapper, bool ConjugateRhs, int Version>
Index rows, Index cols, EIGEN_DONT_INLINE void general_matrix_vector_product<
const LhsMapper& lhs, Index, QInt8, LhsMapper, ColMajor, ConjugateLhs, QUInt8, RhsMapper,
const RhsMapper& rhs, ConjugateRhs, Version>::run(Index rows, Index cols, const LhsMapper& lhs,
QInt32* res, Index resIncr, const RhsMapper& rhs, QInt32* res,
QUInt8 alpha) Index resIncr, QUInt8 alpha) {
{
eigen_assert(alpha.value == 1); eigen_assert(alpha.value == 1);
eigen_assert(resIncr == 1); eigen_assert(resIncr == 1);
eigen_assert(rows > 0); eigen_assert(rows > 0);
@ -110,28 +107,26 @@ EIGEN_DONT_INLINE void general_matrix_vector_product<Index,QInt8,LhsMapper,ColMa
} }
} }
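// Reference semantics shared by the mat-vec specializations in this file
// (they differ only in the signedness of the 8-bit operands); illustrative,
// not part of this commit. Column-major lhs, unit result increment, alpha == 1,
// and res is accumulated into as in the SIMD paths.
#include <cstdint>

template <typename LhsScalar, typename RhsScalar>
void ref_gemv_q32(long rows, long cols, const LhsScalar* lhs,
                  const RhsScalar* rhs, std::int32_t* res) {
  for (long j = 0; j < cols; ++j) {
    for (long i = 0; i < rows; ++i) {
      res[i] += static_cast<std::int32_t>(lhs[j * rows + i]) *
                static_cast<std::int32_t>(rhs[j]);
    }
  }
}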
// Mat-Vec product // Mat-Vec product
// The lhs is encoded using 8bit unsigned integers, the rhs using 8bit signed // integers
template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> // integers
struct general_matrix_vector_product<Index,QUInt8,LhsMapper,ColMajor,ConjugateLhs,QInt8,RhsMapper,ConjugateRhs,Version> template <typename Index, typename LhsMapper, bool ConjugateLhs,
{ typename RhsMapper, bool ConjugateRhs, int Version>
EIGEN_DONT_INLINE static void run( struct general_matrix_vector_product<Index, QUInt8, LhsMapper, ColMajor,
Index rows, Index cols, ConjugateLhs, QInt8, RhsMapper,
const LhsMapper& lhs, ConjugateRhs, Version> {
const RhsMapper& rhs, EIGEN_DONT_INLINE static void run(Index rows, Index cols,
QInt32* res, Index resIncr, const LhsMapper& lhs, const RhsMapper& rhs,
QInt8 alpha); QInt32* res, Index resIncr, QInt8 alpha);
}; };
template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> template <typename Index, typename LhsMapper, bool ConjugateLhs,
EIGEN_DONT_INLINE void general_matrix_vector_product<Index,QUInt8,LhsMapper,ColMajor,ConjugateLhs,QInt8,RhsMapper,ConjugateRhs,Version>::run( typename RhsMapper, bool ConjugateRhs, int Version>
Index rows, Index cols, EIGEN_DONT_INLINE void general_matrix_vector_product<
const LhsMapper& lhs, Index, QUInt8, LhsMapper, ColMajor, ConjugateLhs, QInt8, RhsMapper,
const RhsMapper& rhs, ConjugateRhs, Version>::run(Index rows, Index cols, const LhsMapper& lhs,
QInt32* res, Index resIncr, const RhsMapper& rhs, QInt32* res,
QInt8 alpha) Index resIncr, QInt8 alpha) {
{
eigen_assert(alpha.value == 1); eigen_assert(alpha.value == 1);
eigen_assert(resIncr == 1); eigen_assert(resIncr == 1);
eigen_assert(rows > 0); eigen_assert(rows > 0);
@ -8,23 +8,19 @@
#endif #endif
inline int _mm256_extract_epi16_N0(const __m256i X) inline int _mm256_extract_epi16_N0(const __m256i X) {
{
return _mm_extract_epi16(_mm256_extractf128_si256(X, 0 >> 3), 0 % 8); return _mm_extract_epi16(_mm256_extractf128_si256(X, 0 >> 3), 0 % 8);
} }
inline int _mm256_extract_epi16_N1(const __m256i X) inline int _mm256_extract_epi16_N1(const __m256i X) {
{
return _mm_extract_epi16(_mm256_extractf128_si256(X, 1 >> 3), 1 % 8); return _mm_extract_epi16(_mm256_extractf128_si256(X, 1 >> 3), 1 % 8);
} }
inline int _mm256_extract_epi8_N0(const __m256i X) inline int _mm256_extract_epi8_N0(const __m256i X) {
{
return _mm_extract_epi8(_mm256_extractf128_si256((X), 0 >> 4), 0 % 16); return _mm_extract_epi8(_mm256_extractf128_si256((X), 0 >> 4), 0 % 16);
} }
inline int _mm256_extract_epi8_N1(const __m256i X) inline int _mm256_extract_epi8_N1(const __m256i X) {
{
return _mm_extract_epi8(_mm256_extractf128_si256((X), 1 >> 4), 1 % 16); return _mm_extract_epi8(_mm256_extractf128_si256((X), 1 >> 4), 1 % 16);
} }
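// The _N0/_N1 helpers above hard-code one pattern: for a compile-time index N,
// pick the 128-bit lane with N >> 3 (16-bit elements) or N >> 4 (8-bit
// elements), then extract position N % 8 or N % 16 inside it. The generic form
// is shown here only for illustration; N must be a compile-time constant
// (0..15 for epi16, 0..31 for epi8) because the intrinsics take immediates.
#include <immintrin.h>

template <int N>
int mm256_extract_epi16_generic(__m256i x) {
  return _mm_extract_epi16(_mm256_extractf128_si256(x, N >> 3), N % 8);
}

template <int N>
int mm256_extract_epi8_generic(__m256i x) {
  return _mm_extract_epi8(_mm256_extractf128_si256(x, N >> 4), N % 16);
}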
@ -34,56 +30,56 @@ namespace internal {
typedef struct Packet32q8i { typedef struct Packet32q8i {
__m256i val; __m256i val;
operator __m256i() const { return val; } operator __m256i() const { return val; }
Packet32q8i(); Packet32q8i() : val(_mm256_setzero_si256()){};
Packet32q8i(__m256i val) : val(val) {} Packet32q8i(__m256i val) : val(val) {}
} Packet32q8i; } Packet32q8i;
typedef struct Packet16q16i { typedef struct Packet16q16i {
__m256i val; __m256i val;
operator __m256i() const { return val; } operator __m256i() const { return val; }
Packet16q16i(); Packet16q16i() : val(_mm256_setzero_si256()){};
Packet16q16i(__m256i val) : val(val) {} Packet16q16i(__m256i val) : val(val) {}
} Packet16q16i; } Packet16q16i;
typedef struct Packet32q8u { typedef struct Packet32q8u {
__m256i val; __m256i val;
operator __m256i() const { return val; } operator __m256i() const { return val; }
Packet32q8u(); Packet32q8u() : val(_mm256_setzero_si256()){};
Packet32q8u(__m256i val) : val(val) {} Packet32q8u(__m256i val) : val(val) {}
} Packet32q8u; } Packet32q8u;
typedef struct Packet16q8i { typedef struct Packet16q8i {
__m128i val; __m128i val;
operator __m128i() const { return val; } operator __m128i() const { return val; }
Packet16q8i(); Packet16q8i() : val(_mm_setzero_si128()) {}
Packet16q8i(__m128i val) : val(val) {} Packet16q8i(__m128i val) : val(val) {}
} Packet16q8i; } Packet16q8i;
typedef struct Packet16q8u { typedef struct Packet16q8u {
__m128i val; __m128i val;
operator __m128i() const { return val; } operator __m128i() const { return val; }
Packet16q8u(); Packet16q8u() : val(_mm_setzero_si128()) {}
Packet16q8u(__m128i val) : val(val) {} Packet16q8u(__m128i val) : val(val) {}
} Packet16q8u; } Packet16q8u;
typedef struct Packet8q16i { typedef struct Packet8q16i {
__m128i val; __m128i val;
operator __m128i() const { return val; } operator __m128i() const { return val; }
Packet8q16i(); Packet8q16i() : val(_mm_setzero_si128()) {}
Packet8q16i(__m128i val) : val(val) {} Packet8q16i(__m128i val) : val(val) {}
} Packet8q16i; } Packet8q16i;
typedef struct Packet8q32i { typedef struct Packet8q32i {
__m256i val; __m256i val;
operator __m256i() const { return val; } operator __m256i() const { return val; }
Packet8q32i(); Packet8q32i() : val(_mm256_setzero_si256()){};
Packet8q32i(__m256i val) : val(val) {} Packet8q32i(__m256i val) : val(val) {}
} Packet8q32i; } Packet8q32i;
typedef struct Packet4q32i { typedef struct Packet4q32i {
__m128i val; __m128i val;
operator __m128i() const { return val; } operator __m128i() const { return val; }
Packet4q32i(); Packet4q32i() : val(_mm_setzero_si128()) {}
Packet4q32i(__m128i val) : val(val) {} Packet4q32i(__m128i val) : val(val) {}
} Packet4q32i; } Packet4q32i;
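// The packet wrappers above previously only declared their default
// constructors; they now define them and zero the underlying vector, so a
// default-constructed packet holds a well-defined all-zero value. A stand-in
// type (assumes AVX2; not Eigen's actual struct) showing the consequence:
#include <immintrin.h>
#include <cstdint>
#include <cstdio>

struct Packet32q8iSketch {
  __m256i val;
  Packet32q8iSketch() : val(_mm256_setzero_si256()) {}  // defined and zeroed
  Packet32q8iSketch(__m256i v) : val(v) {}
};

int main() {
  Packet32q8iSketch p;  // all 32 lanes are zero by construction
  alignas(32) std::int8_t out[32];
  _mm256_store_si256(reinterpret_cast<__m256i*>(out), p.val);
  std::printf("%d %d\n", out[0], out[31]);  // prints "0 0"
  return 0;
}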
@ -455,17 +451,21 @@ EIGEN_STRONG_INLINE QUInt8 predux_max<Packet32q8u>(const Packet32q8u& a) {
template <> template <>
EIGEN_STRONG_INLINE QInt8 predux_min<Packet32q8i>(const Packet32q8i& a) { EIGEN_STRONG_INLINE QInt8 predux_min<Packet32q8i>(const Packet32q8i& a) {
__m256i tmp = _mm256_min_epi8(a, _mm256_permute2f128_si256(a, a, 1)); __m256i tmp = _mm256_min_epi8(a, _mm256_permute2f128_si256(a, a, 1));
tmp = _mm256_min_epi8(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2))); tmp =
_mm256_min_epi8(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2)));
tmp = _mm256_min_epi8(tmp, _mm256_shuffle_epi32(tmp, 1)); tmp = _mm256_min_epi8(tmp, _mm256_shuffle_epi32(tmp, 1));
tmp = _mm256_min_epi8(tmp, _mm256_shufflelo_epi16(tmp, _MM_SHUFFLE(1, 0, 3, 2))); tmp = _mm256_min_epi8(tmp,
_mm256_shufflelo_epi16(tmp, _MM_SHUFFLE(1, 0, 3, 2)));
return std::min(_mm256_extract_epi8_N0(tmp), _mm256_extract_epi8_N1(tmp)); return std::min(_mm256_extract_epi8_N0(tmp), _mm256_extract_epi8_N1(tmp));
} }
template <> template <>
EIGEN_STRONG_INLINE QInt8 predux_max<Packet32q8i>(const Packet32q8i& a) { EIGEN_STRONG_INLINE QInt8 predux_max<Packet32q8i>(const Packet32q8i& a) {
__m256i tmp = _mm256_max_epi8(a, _mm256_permute2f128_si256(a, a, 1)); __m256i tmp = _mm256_max_epi8(a, _mm256_permute2f128_si256(a, a, 1));
tmp = _mm256_max_epi8(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2))); tmp =
_mm256_max_epi8(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2)));
tmp = _mm256_max_epi8(tmp, _mm256_shuffle_epi32(tmp, 1)); tmp = _mm256_max_epi8(tmp, _mm256_shuffle_epi32(tmp, 1));
tmp = _mm256_max_epi8(tmp, _mm256_shufflelo_epi16(tmp, _MM_SHUFFLE(1, 0, 3, 2))); tmp = _mm256_max_epi8(tmp,
_mm256_shufflelo_epi16(tmp, _MM_SHUFFLE(1, 0, 3, 2)));
return std::max(_mm256_extract_epi8_N0(tmp), _mm256_extract_epi8_N1(tmp)); return std::max(_mm256_extract_epi8_N0(tmp), _mm256_extract_epi8_N1(tmp));
} }
@ -476,19 +476,22 @@ struct scalar_product_op<QInt32, double> : binary_op_base<QInt32, double> {
#ifndef EIGEN_SCALAR_BINARY_OP_PLUGIN #ifndef EIGEN_SCALAR_BINARY_OP_PLUGIN
EIGEN_EMPTY_STRUCT_CTOR(scalar_product_op) EIGEN_EMPTY_STRUCT_CTOR(scalar_product_op)
#else #else
scalar_product_op() { scalar_product_op() { EIGEN_SCALAR_BINARY_OP_PLUGIN }
EIGEN_SCALAR_BINARY_OP_PLUGIN
}
#endif #endif
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const QInt32& a, const double& b) const { return a * b; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type
operator()(const QInt32& a, const double& b) const {
return a * b;
}
EIGEN_STRONG_INLINE const Packet8q32i packetOp(const Packet8q32i& a, const double& b) const { EIGEN_STRONG_INLINE const Packet8q32i packetOp(const Packet8q32i& a,
const double& b) const {
__m256d scale = _mm256_set1_pd(b); __m256d scale = _mm256_set1_pd(b);
__m256d a_lo = _mm256_cvtepi32_pd(_mm256_castsi256_si128(a)); __m256d a_lo = _mm256_cvtepi32_pd(_mm256_castsi256_si128(a));
__m128i result_lo = _mm256_cvtpd_epi32(_mm256_mul_pd(scale, a_lo)); __m128i result_lo = _mm256_cvtpd_epi32(_mm256_mul_pd(scale, a_lo));
__m256d a_hi = _mm256_cvtepi32_pd(_mm256_extracti128_si256(a, 1)); __m256d a_hi = _mm256_cvtepi32_pd(_mm256_extracti128_si256(a, 1));
__m128i result_hi = _mm256_cvtpd_epi32(_mm256_mul_pd(scale, a_hi)); __m128i result_hi = _mm256_cvtpd_epi32(_mm256_mul_pd(scale, a_hi));
return _mm256_insertf128_si256(_mm256_castsi128_si256(result_lo), result_hi, 1); return _mm256_insertf128_si256(_mm256_castsi128_si256(result_lo), result_hi,
1);
} }
}; };
@ -410,9 +410,7 @@ EIGEN_STRONG_INLINE QInt32 predux_min<Packet16q32i>(const Packet16q32i& a) {
_mm_min_epi32(_mm_min_epi32(lane0, lane1), _mm_min_epi32(lane2, lane3)); _mm_min_epi32(_mm_min_epi32(lane0, lane1), _mm_min_epi32(lane2, lane3));
res = _mm_min_epi32(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2))); res = _mm_min_epi32(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2)));
return pfirst( return pfirst(
_mm_min_epi32( _mm_min_epi32(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))));
res,
_mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))));
} }
template <> template <>
EIGEN_STRONG_INLINE QInt32 predux_max<Packet16q32i>(const Packet16q32i& a) { EIGEN_STRONG_INLINE QInt32 predux_max<Packet16q32i>(const Packet16q32i& a) {
@ -424,9 +422,7 @@ EIGEN_STRONG_INLINE QInt32 predux_max<Packet16q32i>(const Packet16q32i& a) {
_mm_max_epi32(_mm_max_epi32(lane0, lane1), _mm_max_epi32(lane2, lane3)); _mm_max_epi32(_mm_max_epi32(lane0, lane1), _mm_max_epi32(lane2, lane3));
res = _mm_max_epi32(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2))); res = _mm_max_epi32(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2)));
return pfirst( return pfirst(
_mm_max_epi32( _mm_max_epi32(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))));
res,
_mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))));
} }
template <> template <>
EIGEN_STRONG_INLINE QInt16 predux_min<Packet32q16i>(const Packet32q16i& a) { EIGEN_STRONG_INLINE QInt16 predux_min<Packet32q16i>(const Packet32q16i& a) {
@ -437,13 +433,10 @@ EIGEN_STRONG_INLINE QInt16 predux_min<Packet32q16i>(const Packet32q16i& a) {
Packet4i res = Packet4i res =
_mm_min_epi16(_mm_min_epi16(lane0, lane1), _mm_min_epi16(lane2, lane3)); _mm_min_epi16(_mm_min_epi16(lane0, lane1), _mm_min_epi16(lane2, lane3));
res = _mm_min_epi16(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2))); res = _mm_min_epi16(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2)));
std::uint32_t w = std::uint32_t w = pfirst(
pfirst(
_mm_min_epi16(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1)))); _mm_min_epi16(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))));
return std::min({ return std::min(
static_cast<std::int16_t>(w >> 16), {static_cast<std::int16_t>(w >> 16), static_cast<std::int16_t>(w)});
static_cast<std::int16_t>(w)
});
} }
template <> template <>
EIGEN_STRONG_INLINE QInt16 predux_max<Packet32q16i>(const Packet32q16i& a) { EIGEN_STRONG_INLINE QInt16 predux_max<Packet32q16i>(const Packet32q16i& a) {
@ -454,13 +447,10 @@ EIGEN_STRONG_INLINE QInt16 predux_max<Packet32q16i>(const Packet32q16i& a) {
Packet4i res = Packet4i res =
_mm_max_epi16(_mm_max_epi16(lane0, lane1), _mm_max_epi16(lane2, lane3)); _mm_max_epi16(_mm_max_epi16(lane0, lane1), _mm_max_epi16(lane2, lane3));
res = _mm_max_epi16(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2))); res = _mm_max_epi16(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2)));
std::uint32_t w = std::uint32_t w = pfirst(
pfirst(
_mm_max_epi16(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1)))); _mm_max_epi16(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))));
return std::max({ return std::max(
static_cast<std::int16_t>(w >> 16), {static_cast<std::int16_t>(w >> 16), static_cast<std::int16_t>(w)});
static_cast<std::int16_t>(w)
});
} }
template <> template <>
EIGEN_STRONG_INLINE QUInt8 predux_min<Packet64q8u>(const Packet64q8u& a) { EIGEN_STRONG_INLINE QUInt8 predux_min<Packet64q8u>(const Packet64q8u& a) {
@ -471,15 +461,11 @@ EIGEN_STRONG_INLINE QUInt8 predux_min<Packet64q8u>(const Packet64q8u& a) {
Packet4i res = Packet4i res =
_mm_min_epu8(_mm_min_epu8(lane0, lane1), _mm_min_epu8(lane2, lane3)); _mm_min_epu8(_mm_min_epu8(lane0, lane1), _mm_min_epu8(lane2, lane3));
res = _mm_min_epu8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2))); res = _mm_min_epu8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2)));
std::uint32_t w = std::uint32_t w = pfirst(
pfirst(
_mm_min_epu8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1)))); _mm_min_epu8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))));
return std::min({ return std::min(
static_cast<std::uint8_t>(w >> 24), {static_cast<std::uint8_t>(w >> 24), static_cast<std::uint8_t>(w >> 16),
static_cast<std::uint8_t>(w >> 16), static_cast<std::uint8_t>(w >> 8), static_cast<std::uint8_t>(w)});
static_cast<std::uint8_t>(w >> 8),
static_cast<std::uint8_t>(w)
});
} }
template <> template <>
EIGEN_STRONG_INLINE QUInt8 predux_max<Packet64q8u>(const Packet64q8u& a) { EIGEN_STRONG_INLINE QUInt8 predux_max<Packet64q8u>(const Packet64q8u& a) {
@ -490,15 +476,11 @@ EIGEN_STRONG_INLINE QUInt8 predux_max<Packet64q8u>(const Packet64q8u& a) {
Packet4i res = Packet4i res =
_mm_max_epu8(_mm_max_epu8(lane0, lane1), _mm_max_epu8(lane2, lane3)); _mm_max_epu8(_mm_max_epu8(lane0, lane1), _mm_max_epu8(lane2, lane3));
res = _mm_max_epu8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2))); res = _mm_max_epu8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2)));
std::uint32_t w = std::uint32_t w = pfirst(
pfirst(
_mm_max_epu8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1)))); _mm_max_epu8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))));
return std::max({ return std::max(
static_cast<std::uint8_t>(w >> 24), {static_cast<std::uint8_t>(w >> 24), static_cast<std::uint8_t>(w >> 16),
static_cast<std::uint8_t>(w >> 16), static_cast<std::uint8_t>(w >> 8), static_cast<std::uint8_t>(w)});
static_cast<std::uint8_t>(w >> 8),
static_cast<std::uint8_t>(w)
});
} }
template <> template <>
EIGEN_STRONG_INLINE QInt8 predux_min<Packet64q8i>(const Packet64q8i& a) { EIGEN_STRONG_INLINE QInt8 predux_min<Packet64q8i>(const Packet64q8i& a) {
@ -509,15 +491,11 @@ EIGEN_STRONG_INLINE QInt8 predux_min<Packet64q8i>(const Packet64q8i& a) {
Packet4i res = Packet4i res =
_mm_min_epi8(_mm_min_epi8(lane0, lane1), _mm_min_epi8(lane2, lane3)); _mm_min_epi8(_mm_min_epi8(lane0, lane1), _mm_min_epi8(lane2, lane3));
res = _mm_min_epi8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2))); res = _mm_min_epi8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2)));
std::uint32_t w = std::uint32_t w = pfirst(
pfirst(
_mm_min_epi8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1)))); _mm_min_epi8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))));
return std::min({ return std::min(
static_cast<std::int8_t>(w >> 24), {static_cast<std::int8_t>(w >> 24), static_cast<std::int8_t>(w >> 16),
static_cast<std::int8_t>(w >> 16), static_cast<std::int8_t>(w >> 8), static_cast<std::int8_t>(w)});
static_cast<std::int8_t>(w >> 8),
static_cast<std::int8_t>(w)
});
} }
template <> template <>
EIGEN_STRONG_INLINE QInt8 predux_max<Packet64q8i>(const Packet64q8i& a) { EIGEN_STRONG_INLINE QInt8 predux_max<Packet64q8i>(const Packet64q8i& a) {
@ -528,15 +506,11 @@ EIGEN_STRONG_INLINE QInt8 predux_max<Packet64q8i>(const Packet64q8i& a) {
Packet4i res = Packet4i res =
_mm_max_epi8(_mm_max_epi8(lane0, lane1), _mm_max_epi8(lane2, lane3)); _mm_max_epi8(_mm_max_epi8(lane0, lane1), _mm_max_epi8(lane2, lane3));
res = _mm_max_epi8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2))); res = _mm_max_epi8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2)));
std::uint32_t w = std::uint32_t w = pfirst(
pfirst(
_mm_max_epi8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1)))); _mm_max_epi8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))));
return std::min({ return std::min(
static_cast<std::int8_t>(w >> 24), {static_cast<std::int8_t>(w >> 24), static_cast<std::int8_t>(w >> 16),
static_cast<std::int8_t>(w >> 16), static_cast<std::int8_t>(w >> 8), static_cast<std::int8_t>(w)});
static_cast<std::int8_t>(w >> 8),
static_cast<std::int8_t>(w)
});
} }
} // end namespace internal } // end namespace internal
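// Scalar reference for the horizontal reductions above (illustrative, not part
// of this commit): predux_min / predux_max reduce every lane of a packet to a
// single scalar, which is all the shuffle/extract sequences compute.
#include <algorithm>
#include <array>
#include <cstdint>

std::int8_t predux_min_ref(const std::array<std::int8_t, 64>& lanes) {
  return *std::min_element(lanes.begin(), lanes.end());
}

std::int8_t predux_max_ref(const std::array<std::int8_t, 64>& lanes) {
  return *std::max_element(lanes.begin(), lanes.end());
}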
@ -33,28 +33,23 @@ struct type_casting_traits<float, QInt16> {
}; };
template <> template <>
EIGEN_STRONG_INLINE Packet32q16i EIGEN_STRONG_INLINE Packet32q16i pcast<Packet16f>(const Packet16f& a,
pcast<Packet16f>(const Packet16f& a, const Packet16f& b) { const Packet16f& b) {
Packet16i a_int = _mm512_cvtps_epi32(a); Packet16i a_int = _mm512_cvtps_epi32(a);
Packet16i b_int = _mm512_cvtps_epi32(b); Packet16i b_int = _mm512_cvtps_epi32(b);
#ifdef EIGEN_VECTORIZE_AVX512BW #ifdef EIGEN_VECTORIZE_AVX512BW
return _mm512_packs_epi32(a_int, b_int); return _mm512_packs_epi32(a_int, b_int);
#else #else
Packet8i ab_int16_low = Packet8i ab_int16_low = _mm256_permute4x64_epi64(
_mm256_permute4x64_epi64( _mm256_packs_epi32(_mm512_castsi512_si256(a_int),
_mm256_packs_epi32(
_mm512_castsi512_si256(a_int),
_mm512_castsi512_si256(b_int)), _mm512_castsi512_si256(b_int)),
_MM_SHUFFLE(0, 2, 1, 3)); _MM_SHUFFLE(0, 2, 1, 3));
Packet8i ab_int16_high = Packet8i ab_int16_high = _mm256_permute4x64_epi64(
_mm256_permute4x64_epi64( _mm256_packs_epi32(_mm512_extracti32x8_epi32(a_int, 1),
_mm256_packs_epi32(
_mm512_extracti32x8_epi32(a_int, 1),
_mm512_extracti32x8_epi32(b_int, 1)), _mm512_extracti32x8_epi32(b_int, 1)),
_MM_SHUFFLE(0, 2, 1, 3)); _MM_SHUFFLE(0, 2, 1, 3));
return _mm512_inserti32x8( return _mm512_inserti32x8(_mm512_castsi256_si512(ab_int16_low), ab_int16_high,
_mm512_castsi256_si512(ab_int16_low), 1);
ab_int16_high, 1);
#endif #endif
} }
@ -64,8 +59,7 @@ struct type_casting_traits<float, QInt8> {
}; };
template <> template <>
EIGEN_STRONG_INLINE Packet64q8i EIGEN_STRONG_INLINE Packet64q8i pcast<Packet16f>(const Packet16f& a,
pcast<Packet16f>(const Packet16f& a,
const Packet16f& b, const Packet16f& b,
const Packet16f& c, const Packet16f& c,
const Packet16f& d) { const Packet16f& d) {
@@ -74,44 +68,31 @@ pcast<Packet16f>(const Packet16f& a,
  Packet16i c_int = _mm512_cvtps_epi32(c);
  Packet16i d_int = _mm512_cvtps_epi32(d);
#ifdef EIGEN_VECTORIZE_AVX512BW
  return _mm512_packs_epi16(_mm512_packs_epi32(a_int, b_int),
                            _mm512_packs_epi32(c_int, d_int));
#else
  Packet8i ab_int16_low = _mm256_permute4x64_epi64(
      _mm256_packs_epi32(_mm512_castsi512_si256(a_int),
                         _mm512_castsi512_si256(b_int)),
      _MM_SHUFFLE(0, 2, 1, 3));
  Packet8i cd_int16_low = _mm256_permute4x64_epi64(
      _mm256_packs_epi32(_mm512_castsi512_si256(c_int),
                         _mm512_castsi512_si256(d_int)),
      _MM_SHUFFLE(0, 2, 1, 3));
  Packet8i ab_int16_high = _mm256_permute4x64_epi64(
      _mm256_packs_epi32(_mm512_extracti32x8_epi32(a_int, 1),
                         _mm512_extracti32x8_epi32(b_int, 1)),
      _MM_SHUFFLE(0, 2, 1, 3));
  Packet8i cd_int16_high = _mm256_permute4x64_epi64(
      _mm256_packs_epi32(_mm512_extracti32x8_epi32(c_int, 1),
                         _mm512_extracti32x8_epi32(d_int, 1)),
      _MM_SHUFFLE(0, 2, 1, 3));
  Packet8i abcd_int8_low = _mm256_permute4x64_epi64(
      _mm256_packs_epi16(ab_int16_low, cd_int16_low), _MM_SHUFFLE(0, 2, 1, 3));
  Packet8i abcd_int8_high =
      _mm256_permute4x64_epi64(_mm256_packs_epi16(ab_int16_high, cd_int16_high),
                               _MM_SHUFFLE(0, 2, 1, 3));
  return _mm512_inserti32x8(_mm512_castsi256_si512(abcd_int8_low),
                            abcd_int8_high, 1);
#endif
}
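As above, this hunk is a formatting-only re-wrap. The four-input overload extends the same pattern to float -> QInt8, and one way to see why stacking two saturating packs is sound is that clamping to the int16 range and then to the int8 range equals clamping directly to int8, so the nested packs cannot mis-saturate. A compile-time sketch (plain C++, not Eigen code):

// Sketch: two-stage signed saturation composes to direct saturation.
constexpr int clamp_int(int v, int lo, int hi) {
  return v < lo ? lo : (v > hi ? hi : v);
}
constexpr int sat8_direct(int v) { return clamp_int(v, -128, 127); }
constexpr int sat8_via_int16(int v) {
  return clamp_int(clamp_int(v, -32768, 32767), -128, 127);
}
static_assert(sat8_direct(1000000) == sat8_via_int16(1000000), "saturation composes");
static_assert(sat8_direct(-40000) == sat8_via_int16(-40000), "saturation composes");
static_assert(sat8_direct(37) == sat8_via_int16(37), "saturation composes");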
@@ -128,10 +109,8 @@ struct type_casting_traits<QInt32, QInt16> {
template <>
EIGEN_STRONG_INLINE Packet64q8i
pcast<Packet16q32i, Packet64q8i>(const Packet16q32i& a, const Packet16q32i& b,
                                 const Packet16q32i& c, const Packet16q32i& d) {
  __m128i a_part = _mm512_cvtsepi32_epi8(a);
  __m128i b_part = _mm512_cvtsepi32_epi8(b);
  __m128i c_part = _mm512_cvtsepi32_epi8(c);
@@ -145,9 +124,8 @@ pcast<Packet16q32i, Packet64q8i>(const Packet16q32i& a,
}
template <>
EIGEN_STRONG_INLINE Packet32q16i pcast<Packet16q32i, Packet32q16i>(
    const Packet16q32i& a, const Packet16q32i& b) {
  __m256i a_part = _mm512_cvtsepi32_epi16(a);
  __m256i b_part = _mm512_cvtsepi32_epi16(b);
  __m512i converted =

third_party/eigen_reshaped.patch (new vendored file, 48 lines)

@@ -0,0 +1,48 @@
--- a/Eigen/src/Core/util/ReshapedHelper.h (date 1541195478000)
+++ b/Eigen/src/Core/util/ReshapedHelper.h (date 1541195478000)
@@ -39,6 +39,11 @@
return total/other;
}
+template<int Flags, int Order>
+struct get_compiletime_reshape_order {
+ enum { value = Order == AutoOrder ? Flags & RowMajorBit : Order };
+};
+
}
} // end namespace Eigen
--- a/Eigen/src/plugins/ReshapedMethods.h (date 1541195254000)
+++ b/Eigen/src/plugins/ReshapedMethods.h (date 1541195254000)
@@ -105,13 +105,13 @@
inline Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived,
internal::get_compiletime_reshape_size<NRowsType,NColsType,SizeAtCompileTime>::value,
internal::get_compiletime_reshape_size<NColsType,NRowsType,SizeAtCompileTime>::value,
- (Order==AutoOrder?Flags&RowMajorBit:Order)>
+ internal::get_compiletime_reshape_order<Flags,Order>::value>
reshaped(NRowsType nRows, NColsType nCols) EIGEN_RESHAPED_METHOD_CONST
{
return Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived,
internal::get_compiletime_reshape_size<NRowsType,NColsType,SizeAtCompileTime>::value,
internal::get_compiletime_reshape_size<NColsType,NRowsType,SizeAtCompileTime>::value,
- (Order==AutoOrder?Flags&RowMajorBit:Order)>
+ internal::get_compiletime_reshape_order<Flags,Order>::value>
(derived(),
internal::get_runtime_reshape_size(nRows,internal::get_runtime_value(nCols),size()),
internal::get_runtime_reshape_size(nCols,internal::get_runtime_value(nRows),size()));
@@ -128,11 +128,13 @@
template<int Order>
EIGEN_DEVICE_FUNC
-inline Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived, SizeAtCompileTime, 1, (Order==AutoOrder?Flags&RowMajorBit:Order)>
+inline Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived, SizeAtCompileTime, 1,
+ internal::get_compiletime_reshape_order<Flags,Order>::value>
reshaped() EIGEN_RESHAPED_METHOD_CONST
{
EIGEN_STATIC_ASSERT(Order==RowMajor || Order==ColMajor || Order==AutoOrder, INVALID_TEMPLATE_PARAMETER);
- return Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived, SizeAtCompileTime, 1, (Order==AutoOrder?Flags&RowMajorBit:Order)>
+ return Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived, SizeAtCompileTime, 1,
+ internal::get_compiletime_reshape_order<Flags,Order>::value>
(derived(), size(), 1);
}
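For reference, the helper this vendored patch introduces simply factors the repeated `(Order==AutoOrder ? Flags&RowMajorBit : Order)` ternary into a named trait used by both `reshaped()` overloads. A standalone sketch of the same logic follows; the numeric enum values are illustrative stand-ins, not Eigen's actual constants.

// Standalone sketch of the trait added by eigen_reshaped.patch; the constant
// values below are assumptions for illustration only.
namespace sketch {
enum : int { ColMajor = 0, RowMajor = 1, AutoOrder = 2, RowMajorBit = 0x1 };

template <int Flags, int Order>
struct get_compiletime_reshape_order {
  // AutoOrder means "keep whatever storage order the Flags encode".
  enum { value = Order == AutoOrder ? (Flags & RowMajorBit) : Order };
};

// Reshaping a row-major expression with AutoOrder stays row-major ...
static_assert(get_compiletime_reshape_order<RowMajorBit, AutoOrder>::value == RowMajor,
              "AutoOrder follows Flags");
// ... while an explicit Order overrides the expression's storage order.
static_assert(get_compiletime_reshape_order<RowMajorBit, ColMajor>::value == ColMajor,
              "explicit Order wins");
}  // namespace sketch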