Upgrade the version of Eigen to commit b4890dc6bc34.
PiperOrigin-RevId: 220359861
This commit is contained in:
parent
bfb4bda0ff
commit
cf02d61a83
@ -141,7 +141,7 @@ __global__ void lstm_gates(const T* icfo, const T* b, const T* cs_prev,
|
|||||||
//
|
//
|
||||||
const int gid = batch_id * cell_size * 4 + act_id;
|
const int gid = batch_id * cell_size * 4 + act_id;
|
||||||
const int cid = batch_id * cell_size + act_id;
|
const int cid = batch_id * cell_size + act_id;
|
||||||
Eigen::internal::scalar_sigmoid_op<T> sigmoid_op;
|
Eigen::internal::scalar_logistic_op<T> sigmoid_op;
|
||||||
Eigen::internal::scalar_tanh_op<T> tanh_op;
|
Eigen::internal::scalar_tanh_op<T> tanh_op;
|
||||||
Eigen::scalar_clip_op<T> clip_op;
|
Eigen::scalar_clip_op<T> clip_op;
|
||||||
|
|
||||||
|
@ -84,13 +84,13 @@ namespace tensorflow {
|
|||||||
// corresponding stream have completed. The following two classes
|
// corresponding stream have completed. The following two classes
|
||||||
// serve this purpose in two different compilation environments.
|
// serve this purpose in two different compilation environments.
|
||||||
|
|
||||||
class EigenCudaStreamDevice : public ::Eigen::StreamInterface {
|
class EigenGpuStreamDevice : public ::Eigen::StreamInterface {
|
||||||
public:
|
public:
|
||||||
EigenCudaStreamDevice()
|
EigenGpuStreamDevice()
|
||||||
: scratch_(nullptr), semaphore_(nullptr), context_(nullptr) {
|
: scratch_(nullptr), semaphore_(nullptr), context_(nullptr) {
|
||||||
Eigen::initializeDeviceProp();
|
Eigen::initializeDeviceProp();
|
||||||
}
|
}
|
||||||
~EigenCudaStreamDevice() override {}
|
~EigenGpuStreamDevice() override {}
|
||||||
void Reinitialize(OpKernelContext* context, const cudaStream_t* cuda_stream,
|
void Reinitialize(OpKernelContext* context, const cudaStream_t* cuda_stream,
|
||||||
TfGpuId tf_gpu_id, ::tensorflow::Allocator* alloc,
|
TfGpuId tf_gpu_id, ::tensorflow::Allocator* alloc,
|
||||||
char* scratch) {
|
char* scratch) {
|
||||||
@ -101,7 +101,7 @@ class EigenCudaStreamDevice : public ::Eigen::StreamInterface {
|
|||||||
context_ = context;
|
context_ = context;
|
||||||
scratch_ = scratch;
|
scratch_ = scratch;
|
||||||
semaphore_ =
|
semaphore_ =
|
||||||
reinterpret_cast<unsigned int*>(scratch + Eigen::kCudaScratchSize);
|
reinterpret_cast<unsigned int*>(scratch + Eigen::kGpuScratchSize);
|
||||||
stream_ = cuda_stream;
|
stream_ = cuda_stream;
|
||||||
allocator_ = alloc;
|
allocator_ = alloc;
|
||||||
PlatformGpuId platform_gpu_id;
|
PlatformGpuId platform_gpu_id;
|
||||||
@ -185,7 +185,7 @@ class EigenCudaStreamDevice : public ::Eigen::StreamInterface {
|
|||||||
mutable unsigned int* semaphore_;
|
mutable unsigned int* semaphore_;
|
||||||
OpKernelContext* context_;
|
OpKernelContext* context_;
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(EigenCudaStreamDevice);
|
TF_DISALLOW_COPY_AND_ASSIGN(EigenGpuStreamDevice);
|
||||||
};
|
};
|
||||||
|
|
||||||
// This factory helps to ensure that different GPU device objects that refer to
|
// This factory helps to ensure that different GPU device objects that refer to
|
||||||
@ -292,7 +292,7 @@ Status BaseGPUDevice::InitScratchBuffers() {
|
|||||||
DCHECK(streams_[i]);
|
DCHECK(streams_[i]);
|
||||||
if (scratch_.size() > i && scratch_[i]) continue;
|
if (scratch_.size() > i && scratch_[i]) continue;
|
||||||
size_t scratch_buffer_size =
|
size_t scratch_buffer_size =
|
||||||
Eigen::kCudaScratchSize + sizeof(unsigned int);
|
Eigen::kGpuScratchSize + sizeof(unsigned int);
|
||||||
void* scratch_buffer = gpu_allocator_->AllocateRaw(
|
void* scratch_buffer = gpu_allocator_->AllocateRaw(
|
||||||
Allocator::kAllocatorAlignment, scratch_buffer_size);
|
Allocator::kAllocatorAlignment, scratch_buffer_size);
|
||||||
if (scratch_buffer == nullptr) {
|
if (scratch_buffer == nullptr) {
|
||||||
@ -304,7 +304,7 @@ Status BaseGPUDevice::InitScratchBuffers() {
|
|||||||
se::DeviceMemoryBase(scratch_buffer, scratch_buffer_size));
|
se::DeviceMemoryBase(scratch_buffer, scratch_buffer_size));
|
||||||
|
|
||||||
bool ok = executor_->SynchronousMemZero(
|
bool ok = executor_->SynchronousMemZero(
|
||||||
&mem, Eigen::kCudaScratchSize + sizeof(unsigned int));
|
&mem, Eigen::kGpuScratchSize + sizeof(unsigned int));
|
||||||
if (!ok) {
|
if (!ok) {
|
||||||
return errors::FailedPrecondition(
|
return errors::FailedPrecondition(
|
||||||
"Failed to memcopy into scratch buffer for device ",
|
"Failed to memcopy into scratch buffer for device ",
|
||||||
@ -692,7 +692,7 @@ class ConcretePerOpGpuDevice : public PerOpGpuDevice {
|
|||||||
const Eigen::GpuDevice& device() const override { return device_; }
|
const Eigen::GpuDevice& device() const override { return device_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
EigenCudaStreamDevice stream_device_;
|
EigenGpuStreamDevice stream_device_;
|
||||||
Eigen::GpuDevice device_;
|
Eigen::GpuDevice device_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -311,8 +311,8 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
|
|||||||
{"Square", EIGEN_COST(scalar_square_op<float>)},
|
{"Square", EIGEN_COST(scalar_square_op<float>)},
|
||||||
{"Tanh", EIGEN_COST(scalar_tanh_op<float>)},
|
{"Tanh", EIGEN_COST(scalar_tanh_op<float>)},
|
||||||
{"Relu", EIGEN_COST(scalar_max_op<float>)},
|
{"Relu", EIGEN_COST(scalar_max_op<float>)},
|
||||||
{"Sigmoid", EIGEN_COST(scalar_sigmoid_op<float>)},
|
{"Sigmoid", EIGEN_COST(scalar_logistic_op<float>)},
|
||||||
{"QuantizedSigmoid", EIGEN_COST(scalar_sigmoid_op<float>)},
|
{"QuantizedSigmoid", EIGEN_COST(scalar_logistic_op<float>)},
|
||||||
{"Sign", EIGEN_COST(scalar_sign_op<float>)},
|
{"Sign", EIGEN_COST(scalar_sign_op<float>)},
|
||||||
{"Sin", EIGEN_COST(scalar_sin_op<float>)},
|
{"Sin", EIGEN_COST(scalar_sin_op<float>)},
|
||||||
{"Tan", EIGEN_COST(scalar_tan_op<float>)},
|
{"Tan", EIGEN_COST(scalar_tan_op<float>)},
|
||||||
|
@ -60,9 +60,9 @@ template <typename T>
|
|||||||
struct CheckNumericsLaunch {
|
struct CheckNumericsLaunch {
|
||||||
void Run(const GPUDevice &d, const T *data, int size,
|
void Run(const GPUDevice &d, const T *data, int size,
|
||||||
int abnormal_detected[2]) {
|
int abnormal_detected[2]) {
|
||||||
const int32 block_size = d.maxCudaThreadsPerBlock();
|
const int32 block_size = d.maxGpuThreadsPerBlock();
|
||||||
const int32 num_blocks =
|
const int32 num_blocks =
|
||||||
(d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor()) /
|
(d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor()) /
|
||||||
block_size;
|
block_size;
|
||||||
|
|
||||||
CheckNumericsKernel<T><<<num_blocks, block_size, 0, d.stream()>>>(
|
CheckNumericsKernel<T><<<num_blocks, block_size, 0, d.stream()>>>(
|
||||||
|
@ -656,7 +656,7 @@ template <typename T>
|
|||||||
struct erfc : base<T, Eigen::internal::scalar_erfc_op<T>> {};
|
struct erfc : base<T, Eigen::internal::scalar_erfc_op<T>> {};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct sigmoid : base<T, Eigen::internal::scalar_sigmoid_op<T>> {};
|
struct sigmoid : base<T, Eigen::internal::scalar_logistic_op<T>> {};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct sin : base<T, Eigen::internal::scalar_sin_op<T>> {};
|
struct sin : base<T, Eigen::internal::scalar_sin_op<T>> {};
|
||||||
|
@ -500,8 +500,9 @@ class GemmFilterPacker {
|
|||||||
typedef Eigen::internal::const_blas_data_mapper<T, int64, Eigen::RowMajor>
|
typedef Eigen::internal::const_blas_data_mapper<T, int64, Eigen::RowMajor>
|
||||||
LhsMapper;
|
LhsMapper;
|
||||||
typedef Eigen::internal::gebp_traits<T, T> Traits;
|
typedef Eigen::internal::gebp_traits<T, T> Traits;
|
||||||
Eigen::internal::gemm_pack_lhs<T, int64, LhsMapper, Traits::mr,
|
Eigen::internal::gemm_pack_lhs<
|
||||||
Traits::LhsProgress, Eigen::RowMajor>
|
T, int64, LhsMapper, Traits::mr, Traits::LhsProgress,
|
||||||
|
typename Traits::LhsPacket4Packing, Eigen::RowMajor>
|
||||||
pack_lhs;
|
pack_lhs;
|
||||||
|
|
||||||
GemmFilterPacker(const int64 rows, const int64 depth, const T* lhs_input,
|
GemmFilterPacker(const int64 rows, const int64 depth, const T* lhs_input,
|
||||||
|
@ -764,7 +764,7 @@ Status LaunchDepthwiseConv2dGPU(OpKernelContext* ctx, const DepthwiseArgs& args,
|
|||||||
const int max_block_count = kKnownFilterWidth < 0 || kKnownFilterHeight < 0 ||
|
const int max_block_count = kKnownFilterWidth < 0 || kKnownFilterHeight < 0 ||
|
||||||
kKnownDepthMultiplier < 0
|
kKnownDepthMultiplier < 0
|
||||||
? std::numeric_limits<int>::max()
|
? std::numeric_limits<int>::max()
|
||||||
: device.getNumCudaMultiProcessors();
|
: device.getNumGpuMultiProcessors();
|
||||||
kernel<<<std::min(max_block_count, config.block_count),
|
kernel<<<std::min(max_block_count, config.block_count),
|
||||||
config.thread_per_block, 0, device.stream()>>>(args, input, filter,
|
config.thread_per_block, 0, device.stream()>>>(args, input, filter,
|
||||||
output, num_outputs);
|
output, num_outputs);
|
||||||
|
@ -217,9 +217,9 @@ void FillPhiloxRandom<GPUDevice, Distribution>::operator()(
|
|||||||
OpKernelContext*, const GPUDevice& d, random::PhiloxRandom gen,
|
OpKernelContext*, const GPUDevice& d, random::PhiloxRandom gen,
|
||||||
typename Distribution::ResultElementType* data, int64 size,
|
typename Distribution::ResultElementType* data, int64 size,
|
||||||
Distribution dist) {
|
Distribution dist) {
|
||||||
const int32 block_size = d.maxCudaThreadsPerBlock();
|
const int32 block_size = d.maxGpuThreadsPerBlock();
|
||||||
const int32 num_blocks =
|
const int32 num_blocks =
|
||||||
(d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor()) /
|
(d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor()) /
|
||||||
block_size;
|
block_size;
|
||||||
|
|
||||||
FillPhiloxRandomKernelLaunch<Distribution>
|
FillPhiloxRandomKernelLaunch<Distribution>
|
||||||
|
@ -131,7 +131,7 @@ class CudaLaunchConfigTest : public ::testing::Test {
|
|||||||
protected:
|
protected:
|
||||||
const int bufsize = 1024;
|
const int bufsize = 1024;
|
||||||
int* outbuf = nullptr;
|
int* outbuf = nullptr;
|
||||||
Eigen::CudaStreamDevice stream;
|
Eigen::GpuStreamDevice stream;
|
||||||
Eigen::GpuDevice d = Eigen::GpuDevice(&stream);
|
Eigen::GpuDevice d = Eigen::GpuDevice(&stream);
|
||||||
|
|
||||||
virtual void SetUp() {
|
virtual void SetUp() {
|
||||||
|
@ -128,12 +128,12 @@ inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
|
|||||||
CudaLaunchConfig config;
|
CudaLaunchConfig config;
|
||||||
const int virtual_thread_count = work_element_count;
|
const int virtual_thread_count = work_element_count;
|
||||||
const int physical_thread_count = std::min(
|
const int physical_thread_count = std::min(
|
||||||
d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor(),
|
d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor(),
|
||||||
virtual_thread_count);
|
virtual_thread_count);
|
||||||
const int thread_per_block = std::min(1024, d.maxCudaThreadsPerBlock());
|
const int thread_per_block = std::min(1024, d.maxGpuThreadsPerBlock());
|
||||||
const int block_count =
|
const int block_count =
|
||||||
std::min(DivUp(physical_thread_count, thread_per_block),
|
std::min(DivUp(physical_thread_count, thread_per_block),
|
||||||
d.getNumCudaMultiProcessors());
|
d.getNumGpuMultiProcessors());
|
||||||
|
|
||||||
config.virtual_thread_count = virtual_thread_count;
|
config.virtual_thread_count = virtual_thread_count;
|
||||||
config.thread_per_block = thread_per_block;
|
config.thread_per_block = thread_per_block;
|
||||||
@ -184,7 +184,7 @@ inline CudaLaunchConfig GetCudaLaunchConfigFixedBlockSize(
|
|||||||
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||||
&block_count, func, fixed_block_size, dynamic_shared_memory_size);
|
&block_count, func, fixed_block_size, dynamic_shared_memory_size);
|
||||||
CHECK_EQ(err, cudaSuccess);
|
CHECK_EQ(err, cudaSuccess);
|
||||||
block_count = std::min(block_count * d.getNumCudaMultiProcessors(),
|
block_count = std::min(block_count * d.getNumGpuMultiProcessors(),
|
||||||
DivUp(work_element_count, fixed_block_size));
|
DivUp(work_element_count, fixed_block_size));
|
||||||
|
|
||||||
config.virtual_thread_count = work_element_count;
|
config.virtual_thread_count = work_element_count;
|
||||||
@ -213,7 +213,7 @@ inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(int xdim, int ydim,
|
|||||||
int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
|
int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
|
||||||
|
|
||||||
const int physical_thread_count =
|
const int physical_thread_count =
|
||||||
d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor();
|
d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor();
|
||||||
|
|
||||||
const int max_blocks = std::max(physical_thread_count / kThreadsPerBlock, 1);
|
const int max_blocks = std::max(physical_thread_count / kThreadsPerBlock, 1);
|
||||||
|
|
||||||
|
@ -12,25 +12,55 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_
|
|
||||||
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_
|
|
||||||
|
|
||||||
#define EIGEN_USE_CUSTOM_THREAD_POOL
|
// This is essentially unsupported/CXX11/Eigen/Tensor.h
|
||||||
#define EIGEN_USE_THREADS
|
// TODO(petewarden) - move this to a common location in Eigen itself.
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
|
|
||||||
#include <stdint.h>
|
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_
|
||||||
|
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_
|
||||||
|
|
||||||
|
|
||||||
|
#include "Eigen/Core"
|
||||||
|
|
||||||
|
#if defined(EIGEN_USE_SYCL)
|
||||||
|
#undef min
|
||||||
|
#undef max
|
||||||
|
#undef isnan
|
||||||
|
#undef isinf
|
||||||
|
#undef isfinite
|
||||||
|
#include <CL/sycl.hpp>
|
||||||
|
#include <iostream>
|
||||||
|
#include <map>
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
#endif
|
||||||
|
#include <cmath>
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <cmath>
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
typedef __int16 int16_t;
|
||||||
|
typedef unsigned __int16 uint16_t;
|
||||||
|
typedef __int32 int32_t;
|
||||||
|
typedef unsigned __int32 uint32_t;
|
||||||
|
typedef __int64 int64_t;
|
||||||
|
typedef unsigned __int64 uint64_t;
|
||||||
|
#include <windows.h>
|
||||||
|
#else
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if __cplusplus > 199711 || EIGEN_COMP_MSVC >= 1900
|
||||||
#include <random>
|
#include <random>
|
||||||
#include <atomic>
|
#endif
|
||||||
#include <condition_variable> // NOLINT(build/c++11)
|
|
||||||
#include <mutex> // NOLINT(build/c++11)
|
|
||||||
#include <thread> // NOLINT(build/c++11)
|
|
||||||
#include <functional>
|
|
||||||
|
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
#include <windows.h>
|
#include <windows.h>
|
||||||
@ -40,58 +70,53 @@ limitations under the License.
|
|||||||
#include <time.h>
|
#include <time.h>
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// #if defined(EIGEN_USE_LIBXSMM)
|
||||||
|
// #include "libxsmm.h"
|
||||||
|
// #endif
|
||||||
|
|
||||||
// Because some programs may link Eigen in through other frameworks with
|
#ifdef EIGEN_USE_THREADS
|
||||||
// different flags, we can run into multiple definition issues if we don't have
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool"
|
||||||
// a private namespace for our versions. This is a nasty hack, but a similar
|
#endif
|
||||||
// approach is used elsewhere to handle the problem, so it should be stable.
|
|
||||||
#define Eigen EigenForTFLite
|
|
||||||
|
|
||||||
#include "Eigen/src/Core/util/StaticAssert.h"
|
|
||||||
#include "unsupported/Eigen/CXX11/Core"
|
|
||||||
#include "unsupported/Eigen/SpecialFunctions"
|
|
||||||
|
|
||||||
#include "Eigen/src/Core/util/DisableStupidWarnings.h"
|
#include "Eigen/src/Core/util/DisableStupidWarnings.h"
|
||||||
|
|
||||||
#include "Eigen/Core"
|
#include "third_party/eigen3/unsupported/Eigen/SpecialFunctions"
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/util/CXX11Meta.h"
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/util/MaxSizeVector.h"
|
||||||
|
|
||||||
// Beware: the order of the include matters to some compilers. For example
|
|
||||||
// TensorIndexList.h should be included before TensorDimensions.h in order to
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorMacros.h"
|
||||||
// use index lists to encode tensor dimensions when compiling with llvm.
|
|
||||||
// We're defining this ourselves rather than using the Eigen Tensor header file
|
|
||||||
// so that we can alter the macro definition of TENSOR_CONTRACTION_DISPATCH to
|
|
||||||
// reduce binary size.
|
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h"
|
||||||
|
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorCostModel.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorCostModel.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/ThreadPoolInterface.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceDefault.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceType.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorNonBlockingThreadPool.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceGpu.h"
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceSycl.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDimensionList.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDimensionList.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorInitializer.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorInitializer.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h"
|
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorUInt128.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorUInt128.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h"
|
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h"
|
||||||
|
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorStats.h"
|
|
||||||
|
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h"
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h"
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReductionGpu.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h"
|
||||||
|
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMappers.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h"
|
||||||
|
|
||||||
#undef TENSOR_CONTRACTION_DISPATCH
|
#undef TENSOR_CONTRACTION_DISPATCH
|
||||||
#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS) \
|
#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS) \
|
||||||
if (this->m_lhs_inner_dim_contiguous && \
|
if (this->m_lhs_inner_dim_contiguous && \
|
||||||
@ -102,8 +127,9 @@ limitations under the License.
|
|||||||
eigen_assert(false && "Unsupported contraction formats"); \
|
eigen_assert(false && "Unsupported contraction formats"); \
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h"
|
||||||
@ -125,19 +151,18 @@ limitations under the License.
|
|||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorScan.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorScan.h"
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorTrace.h"
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorSycl.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h"
|
||||||
|
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/Tensor.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/Tensor.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorFixedSize.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorFixedSize.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h"
|
||||||
|
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReductionCuda.h"
|
|
||||||
|
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h"
|
||||||
|
|
||||||
#include "Eigen/src/Core/util/ReenableStupidWarnings.h"
|
#include "Eigen/src/Core/util/ReenableStupidWarnings.h"
|
||||||
|
|
||||||
|
|
||||||
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_
|
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_
|
||||||
|
@ -94,7 +94,7 @@ typedef unsigned __int64 uint64_t;
|
|||||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorCostModel.h"
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorCostModel.h"
|
||||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceDefault.h"
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceDefault.h"
|
||||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h"
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h"
|
||||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceCuda.h"
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceGpu.h"
|
||||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceSycl.h"
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceSycl.h"
|
||||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h"
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h"
|
||||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorDimensionList.h"
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorDimensionList.h"
|
||||||
@ -106,10 +106,11 @@ typedef unsigned __int64 uint64_t;
|
|||||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h"
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h"
|
||||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h"
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h"
|
||||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorBase.h"
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorBase.h"
|
||||||
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h"
|
||||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h"
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h"
|
||||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h"
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h"
|
||||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h"
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h"
|
||||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorReductionCuda.h"
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorReductionGpu.h"
|
||||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h"
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h"
|
||||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h"
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h"
|
||||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h"
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h"
|
||||||
@ -128,7 +129,7 @@ typedef unsigned __int64 uint64_t;
|
|||||||
|
|
||||||
|
|
||||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h"
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h"
|
||||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h"
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h"
|
||||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h"
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h"
|
||||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h"
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h"
|
||||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h"
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h"
|
||||||
|
@ -3151,12 +3151,12 @@ inline void LstmCell(
|
|||||||
// Combined memory state and final output calculation
|
// Combined memory state and final output calculation
|
||||||
gemmlowp::ScopedProfilingLabel label2("MemoryStateAndFinalOutput");
|
gemmlowp::ScopedProfilingLabel label2("MemoryStateAndFinalOutput");
|
||||||
output_state_map =
|
output_state_map =
|
||||||
input_gate_sm.unaryExpr(Eigen::internal::scalar_sigmoid_op<float>()) *
|
input_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
|
||||||
new_input_sm.tanh() +
|
new_input_sm.tanh() +
|
||||||
forget_gate_sm.unaryExpr(Eigen::internal::scalar_sigmoid_op<float>()) *
|
forget_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
|
||||||
prev_state_map;
|
prev_state_map;
|
||||||
output_activ_map =
|
output_activ_map =
|
||||||
output_gate_sm.unaryExpr(Eigen::internal::scalar_sigmoid_op<float>()) *
|
output_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
|
||||||
output_state_map.tanh();
|
output_state_map.tanh();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -4367,7 +4367,7 @@ inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
|
|||||||
auto input_map = MapAsVector(input_data, input_shape);
|
auto input_map = MapAsVector(input_data, input_shape);
|
||||||
auto output_map = MapAsVector(output_data, output_shape);
|
auto output_map = MapAsVector(output_data, output_shape);
|
||||||
output_map.array() =
|
output_map.array() =
|
||||||
input_map.array().unaryExpr(Eigen::internal::scalar_sigmoid_op<float>());
|
input_map.array().unaryExpr(Eigen::internal::scalar_logistic_op<float>());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convenience version that allows, for example, generated-code calls to be
|
// Convenience version that allows, for example, generated-code calls to be
|
||||||
|
@ -1570,7 +1570,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
if not context.executing_eagerly():
|
if not context.executing_eagerly():
|
||||||
self.evaluate(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
|
|
||||||
self.assertAllEqual([[[[4.0]]]], self.evaluate(y))
|
self.assertAllClose([[[[4.0]]]], self.evaluate(y))
|
||||||
|
|
||||||
# Remove reference cycles in model
|
# Remove reference cycles in model
|
||||||
test_util.dismantle_polymorphic_function(model)
|
test_util.dismantle_polymorphic_function(model)
|
||||||
|
@ -134,11 +134,12 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
|
|||||||
tf_http_archive(
|
tf_http_archive(
|
||||||
name = "eigen_archive",
|
name = "eigen_archive",
|
||||||
build_file = clean_dep("//third_party:eigen.BUILD"),
|
build_file = clean_dep("//third_party:eigen.BUILD"),
|
||||||
sha256 = "d956415d784fa4e42b6a2a45c32556d6aec9d0a3d8ef48baee2522ab762556a9",
|
patch_file = clean_dep("//third_party:eigen_reshaped.patch"),
|
||||||
strip_prefix = "eigen-eigen-fd6845384b86",
|
sha256 = "d66cec3b54b3dfaa4666c1d49481a7197f93fc078cd53c54e2b4a8893a529c9f",
|
||||||
|
strip_prefix = "eigen-eigen-b4890dc6bc34",
|
||||||
urls = [
|
urls = [
|
||||||
"https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/fd6845384b86.tar.gz",
|
"https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/b4890dc6bc34.tar.gz",
|
||||||
"https://bitbucket.org/eigen/eigen/get/fd6845384b86.tar.gz",
|
"https://bitbucket.org/eigen/eigen/get/b4890dc6bc34.tar.gz",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
1
third_party/eigen.BUILD
vendored
1
third_party/eigen.BUILD
vendored
@ -65,6 +65,7 @@ cc_library(
|
|||||||
# code. We use it, but we do not rely on it, as evidenced above.
|
# code. We use it, but we do not rely on it, as evidenced above.
|
||||||
"EIGEN_MPL2_ONLY",
|
"EIGEN_MPL2_ONLY",
|
||||||
"EIGEN_MAX_ALIGN_BYTES=64",
|
"EIGEN_MAX_ALIGN_BYTES=64",
|
||||||
|
"EIGEN_HAS_TYPE_TRAITS=0",
|
||||||
],
|
],
|
||||||
includes = ["."],
|
includes = ["."],
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
|
@ -249,9 +249,7 @@ EIGEN_STRONG_INLINE QInt32& operator/=(QInt32& a, const QInt32 b) {
|
|||||||
a.value /= b.value;
|
a.value /= b.value;
|
||||||
return a;
|
return a;
|
||||||
}
|
}
|
||||||
EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a) {
|
EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a) { return -a.value; }
|
||||||
return -a.value;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Scaling QInt32 by double. We do the arithmetic in double because
|
// Scaling QInt32 by double. We do the arithmetic in double because
|
||||||
// float only has 23 bits of mantissa, so casting QInt32 to float might reduce
|
// float only has 23 bits of mantissa, so casting QInt32 to float might reduce
|
||||||
|
@ -15,11 +15,9 @@ namespace internal {
|
|||||||
|
|
||||||
// Accumulate the product of 2 QInt8 inputs on 32 bits to prevent
|
// Accumulate the product of 2 QInt8 inputs on 32 bits to prevent
|
||||||
// overflows
|
// overflows
|
||||||
template<> struct scalar_product_traits<QInt8, QInt8>
|
template <>
|
||||||
{
|
struct scalar_product_traits<QInt8, QInt8> {
|
||||||
enum {
|
enum { Defined = 1 };
|
||||||
Defined = 1
|
|
||||||
};
|
|
||||||
typedef QInt32 ReturnType;
|
typedef QInt32 ReturnType;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -33,11 +31,9 @@ struct scalar_product_traits<QInt16, QInt16> {
|
|||||||
|
|
||||||
// Accumulate the product of QInt8 inputs with QUint8 inputs on 32 bits
|
// Accumulate the product of QInt8 inputs with QUint8 inputs on 32 bits
|
||||||
// to prevent overflows
|
// to prevent overflows
|
||||||
template<> struct scalar_product_traits<QInt8, QUInt8>
|
template <>
|
||||||
{
|
struct scalar_product_traits<QInt8, QUInt8> {
|
||||||
enum {
|
enum { Defined = 1 };
|
||||||
Defined = 1
|
|
||||||
};
|
|
||||||
typedef QInt32 ReturnType;
|
typedef QInt32 ReturnType;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -47,14 +43,16 @@ template<> struct scalar_product_traits<QInt8, QUInt8>
|
|||||||
// signed 8bit integers
|
// signed 8bit integers
|
||||||
#ifndef EIGEN_USE_OPTIMIZED_INT8_INT8_MAT_MAT_PRODUCT
|
#ifndef EIGEN_USE_OPTIMIZED_INT8_INT8_MAT_MAT_PRODUCT
|
||||||
|
|
||||||
template<bool _ConjLhs, bool _ConjRhs>
|
template <bool _ConjLhs, bool _ConjRhs>
|
||||||
class gebp_traits<QInt8, QInt8, _ConjLhs, _ConjRhs>
|
class gebp_traits<QInt8, QInt8, _ConjLhs, _ConjRhs> {
|
||||||
{
|
public:
|
||||||
public:
|
|
||||||
typedef QInt8 LhsScalar;
|
typedef QInt8 LhsScalar;
|
||||||
typedef QInt8 RhsScalar;
|
typedef QInt8 RhsScalar;
|
||||||
typedef QInt32 ResScalar;
|
typedef QInt32 ResScalar;
|
||||||
|
|
||||||
|
typedef typename packet_traits<LhsScalar>::type LhsPacket;
|
||||||
|
typedef LhsPacket LhsPacket4Packing;
|
||||||
|
|
||||||
enum {
|
enum {
|
||||||
// register block size along the M and N directions
|
// register block size along the M and N directions
|
||||||
// One for the current implementation
|
// One for the current implementation
|
||||||
@ -68,22 +66,24 @@ public:
|
|||||||
};
|
};
|
||||||
|
|
||||||
// The signed 8bit Mat-Mat product itself.
|
// The signed 8bit Mat-Mat product itself.
|
||||||
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
|
template <typename Index, typename DataMapper, int mr, int nr,
|
||||||
struct gebp_kernel<QInt8, QInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
|
bool ConjugateLhs, bool ConjugateRhs>
|
||||||
{
|
struct gebp_kernel<QInt8, QInt8, Index, DataMapper, mr, nr, ConjugateLhs,
|
||||||
|
ConjugateRhs> {
|
||||||
EIGEN_DONT_INLINE
|
EIGEN_DONT_INLINE
|
||||||
void operator()(const DataMapper& res, const QInt8* blockA, const QInt8* blockB,
|
void operator()(const DataMapper& res, const QInt8* blockA,
|
||||||
Index rows, Index depth, Index cols, QInt32 alpha,
|
const QInt8* blockB, Index rows, Index depth, Index cols,
|
||||||
Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
|
QInt32 alpha, Index strideA = -1, Index strideB = -1,
|
||||||
|
Index offsetA = 0, Index offsetB = 0);
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
|
template <typename Index, typename DataMapper, int mr, int nr,
|
||||||
EIGEN_DONT_INLINE
|
bool ConjugateLhs, bool ConjugateRhs>
|
||||||
void gebp_kernel<QInt8, QInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
|
EIGEN_DONT_INLINE void gebp_kernel<QInt8, QInt8, Index, DataMapper, mr, nr,
|
||||||
::operator()(const DataMapper& res, const QInt8* blockA, const QInt8* blockB,
|
ConjugateLhs, ConjugateRhs>::
|
||||||
Index rows, Index depth, Index cols, QInt32 alpha,
|
operator()(const DataMapper& res, const QInt8* blockA, const QInt8* blockB,
|
||||||
Index strideA, Index strideB, Index offsetA, Index offsetB)
|
Index rows, Index depth, Index cols, QInt32 alpha, Index strideA,
|
||||||
{
|
Index strideB, Index offsetA, Index offsetB) {
|
||||||
EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
|
EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
|
EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
|
||||||
@ -113,18 +113,19 @@ void gebp_kernel<QInt8, QInt8, Index, DataMapper, mr, nr, ConjugateLhs, Conjugat
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
// This definition tackles the case where the lhs is encoded using signed 8bit
|
// This definition tackles the case where the lhs is encoded using signed 8bit
|
||||||
// integers and the rhs using unsigned 8bit integers.
|
// integers and the rhs using unsigned 8bit integers.
|
||||||
#ifndef EIGEN_USE_OPTIMIZED_INT8_UINT8_MAT_MAT_PRODUCT
|
#ifndef EIGEN_USE_OPTIMIZED_INT8_UINT8_MAT_MAT_PRODUCT
|
||||||
template<bool _ConjLhs, bool _ConjRhs>
|
template <bool _ConjLhs, bool _ConjRhs>
|
||||||
class gebp_traits<QInt8, QUInt8, _ConjLhs, _ConjRhs>
|
class gebp_traits<QInt8, QUInt8, _ConjLhs, _ConjRhs> {
|
||||||
{
|
public:
|
||||||
public:
|
|
||||||
typedef QInt8 LhsScalar;
|
typedef QInt8 LhsScalar;
|
||||||
typedef QUInt8 RhsScalar;
|
typedef QUInt8 RhsScalar;
|
||||||
typedef QInt32 ResScalar;
|
typedef QInt32 ResScalar;
|
||||||
|
|
||||||
|
typedef typename packet_traits<LhsScalar>::type LhsPacket;
|
||||||
|
typedef LhsPacket LhsPacket4Packing;
|
||||||
|
|
||||||
enum {
|
enum {
|
||||||
// register block size along the M and N directions
|
// register block size along the M and N directions
|
||||||
// One for the current implementation
|
// One for the current implementation
|
||||||
@ -138,22 +139,24 @@ public:
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Mat-Mat product of a signed 8bit lhs with an unsigned 8bit rhs
|
// Mat-Mat product of a signed 8bit lhs with an unsigned 8bit rhs
|
||||||
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
|
template <typename Index, typename DataMapper, int mr, int nr,
|
||||||
struct gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
|
bool ConjugateLhs, bool ConjugateRhs>
|
||||||
{
|
struct gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs,
|
||||||
|
ConjugateRhs> {
|
||||||
EIGEN_DONT_INLINE
|
EIGEN_DONT_INLINE
|
||||||
void operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB,
|
void operator()(const DataMapper& res, const QInt8* blockA,
|
||||||
Index rows, Index depth, Index cols, QInt32 alpha,
|
const QUInt8* blockB, Index rows, Index depth, Index cols,
|
||||||
Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
|
QInt32 alpha, Index strideA = -1, Index strideB = -1,
|
||||||
|
Index offsetA = 0, Index offsetB = 0);
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
|
template <typename Index, typename DataMapper, int mr, int nr,
|
||||||
EIGEN_DONT_INLINE
|
bool ConjugateLhs, bool ConjugateRhs>
|
||||||
void gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
|
EIGEN_DONT_INLINE void gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr,
|
||||||
::operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB,
|
ConjugateLhs, ConjugateRhs>::
|
||||||
Index rows, Index depth, Index cols, QInt32 alpha,
|
operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB,
|
||||||
Index strideA, Index strideB, Index offsetA, Index offsetB)
|
Index rows, Index depth, Index cols, QInt32 alpha, Index strideA,
|
||||||
{
|
Index strideB, Index offsetA, Index offsetB) {
|
||||||
EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
|
EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
|
EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
|
||||||
@ -183,18 +186,19 @@ void gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, Conjuga
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
// This definition tackles the case where the lhs is encoded using unsigned 8bit
|
// This definition tackles the case where the lhs is encoded using unsigned 8bit
|
||||||
// integers and the rhs using signed 8bit integers.
|
// integers and the rhs using signed 8bit integers.
|
||||||
#ifndef EIGEN_USE_OPTIMIZED_UINT8_INT8_MAT_MAT_PRODUCT
|
#ifndef EIGEN_USE_OPTIMIZED_UINT8_INT8_MAT_MAT_PRODUCT
|
||||||
template<bool _ConjLhs, bool _ConjRhs>
|
template <bool _ConjLhs, bool _ConjRhs>
|
||||||
class gebp_traits<QUInt8, QInt8, _ConjLhs, _ConjRhs>
|
class gebp_traits<QUInt8, QInt8, _ConjLhs, _ConjRhs> {
|
||||||
{
|
public:
|
||||||
public:
|
|
||||||
typedef QUInt8 LhsScalar;
|
typedef QUInt8 LhsScalar;
|
||||||
typedef QInt8 RhsScalar;
|
typedef QInt8 RhsScalar;
|
||||||
typedef QInt32 ResScalar;
|
typedef QInt32 ResScalar;
|
||||||
|
|
||||||
|
typedef typename packet_traits<LhsScalar>::type LhsPacket;
|
||||||
|
typedef LhsPacket LhsPacket4Packing;
|
||||||
|
|
||||||
enum {
|
enum {
|
||||||
// register block size along the M and N directions
|
// register block size along the M and N directions
|
||||||
// One for the current implementation
|
// One for the current implementation
|
||||||
@ -207,24 +211,25 @@ public:
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
// Mat-Mat product of an unsigned 8bit lhs with a signed 8bit rhs
|
// Mat-Mat product of an unsigned 8bit lhs with a signed 8bit rhs
|
||||||
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
|
template <typename Index, typename DataMapper, int mr, int nr,
|
||||||
struct gebp_kernel<QUInt8, QInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
|
bool ConjugateLhs, bool ConjugateRhs>
|
||||||
{
|
struct gebp_kernel<QUInt8, QInt8, Index, DataMapper, mr, nr, ConjugateLhs,
|
||||||
|
ConjugateRhs> {
|
||||||
EIGEN_DONT_INLINE
|
EIGEN_DONT_INLINE
|
||||||
void operator()(const DataMapper& res, const QUInt8* blockA, const QInt8* blockB,
|
void operator()(const DataMapper& res, const QUInt8* blockA,
|
||||||
Index rows, Index depth, Index cols, QInt32 alpha,
|
const QInt8* blockB, Index rows, Index depth, Index cols,
|
||||||
Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
|
QInt32 alpha, Index strideA = -1, Index strideB = -1,
|
||||||
|
Index offsetA = 0, Index offsetB = 0);
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
|
template <typename Index, typename DataMapper, int mr, int nr,
|
||||||
EIGEN_DONT_INLINE
|
bool ConjugateLhs, bool ConjugateRhs>
|
||||||
void gebp_kernel<QUInt8, QInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
|
EIGEN_DONT_INLINE void gebp_kernel<QUInt8, QInt8, Index, DataMapper, mr, nr,
|
||||||
::operator()(const DataMapper& res, const QUInt8* blockA, const QInt8* blockB,
|
ConjugateLhs, ConjugateRhs>::
|
||||||
Index rows, Index depth, Index cols, QInt32 alpha,
|
operator()(const DataMapper& res, const QUInt8* blockA, const QInt8* blockB,
|
||||||
Index strideA, Index strideB, Index offsetA, Index offsetB)
|
Index rows, Index depth, Index cols, QInt32 alpha, Index strideA,
|
||||||
{
|
Index strideB, Index offsetA, Index offsetB) {
|
||||||
EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
|
EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
|
EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
|
||||||
@ -263,6 +268,9 @@ class gebp_traits<QInt16, QInt16, _ConjLhs, _ConjRhs> {
|
|||||||
typedef QInt16 RhsScalar;
|
typedef QInt16 RhsScalar;
|
||||||
typedef QInt32 ResScalar;
|
typedef QInt32 ResScalar;
|
||||||
|
|
||||||
|
typedef typename packet_traits<LhsScalar>::type LhsPacket;
|
||||||
|
typedef LhsPacket LhsPacket4Packing;
|
||||||
|
|
||||||
enum {
|
enum {
|
||||||
// register block size along the M and N directions
|
// register block size along the M and N directions
|
||||||
// One for the current implementation
|
// One for the current implementation
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -14,15 +14,14 @@
|
|||||||
namespace Eigen {
|
namespace Eigen {
|
||||||
namespace internal {
|
namespace internal {
|
||||||
|
|
||||||
|
// AVX2 optimized implementation of the case where the lhs is encoded using
|
||||||
// AVX2 optimized implementation of the case where the lhs is encoded using signed 8bit
|
// signed 8bit
|
||||||
// integers and the rhs using unsigned 8bit integers.
|
// integers and the rhs using unsigned 8bit integers.
|
||||||
#ifdef EIGEN_USE_OPTIMIZED_INT8_UINT8_MAT_MAT_PRODUCT
|
#ifdef EIGEN_USE_OPTIMIZED_INT8_UINT8_MAT_MAT_PRODUCT
|
||||||
|
|
||||||
template<bool _ConjLhs, bool _ConjRhs>
|
template <bool _ConjLhs, bool _ConjRhs>
|
||||||
class gebp_traits<QInt8, QUInt8, _ConjLhs, _ConjRhs>
|
class gebp_traits<QInt8, QUInt8, _ConjLhs, _ConjRhs> {
|
||||||
{
|
public:
|
||||||
public:
|
|
||||||
typedef QInt8 LhsScalar;
|
typedef QInt8 LhsScalar;
|
||||||
typedef QUInt8 RhsScalar;
|
typedef QUInt8 RhsScalar;
|
||||||
typedef QInt32 ResScalar;
|
typedef QInt32 ResScalar;
|
||||||
@ -40,22 +39,24 @@ public:
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Mat-Mat product of a signed 8bit lhs with an unsigned 8bit rhs
|
// Mat-Mat product of a signed 8bit lhs with an unsigned 8bit rhs
|
||||||
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
|
template <typename Index, typename DataMapper, int mr, int nr,
|
||||||
struct gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
|
bool ConjugateLhs, bool ConjugateRhs>
|
||||||
{
|
struct gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs,
|
||||||
|
ConjugateRhs> {
|
||||||
EIGEN_DONT_INLINE
|
EIGEN_DONT_INLINE
|
||||||
void operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB,
|
void operator()(const DataMapper& res, const QInt8* blockA,
|
||||||
Index rows, Index depth, Index cols, QInt32 alpha,
|
const QUInt8* blockB, Index rows, Index depth, Index cols,
|
||||||
Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
|
QInt32 alpha, Index strideA = -1, Index strideB = -1,
|
||||||
|
Index offsetA = 0, Index offsetB = 0);
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
|
template <typename Index, typename DataMapper, int mr, int nr,
|
||||||
EIGEN_DONT_INLINE
|
bool ConjugateLhs, bool ConjugateRhs>
|
||||||
void gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
|
EIGEN_DONT_INLINE void gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr,
|
||||||
::operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB,
|
ConjugateLhs, ConjugateRhs>::
|
||||||
Index rows, Index depth, Index cols, QInt32 alpha,
|
operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB,
|
||||||
Index strideA, Index strideB, Index offsetA, Index offsetB)
|
Index rows, Index depth, Index cols, QInt32 alpha, Index strideA,
|
||||||
{
|
Index strideB, Index offsetA, Index offsetB) {
|
||||||
EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
|
EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
|
EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
|
||||||
@ -85,7 +86,6 @@ void gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, Conjuga
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
} // namespace internal
|
} // namespace internal
|
||||||
} // namespace Eigen
|
} // namespace Eigen
|
||||||
|
|
||||||
|
@ -15,25 +15,23 @@ namespace internal {
|
|||||||
|
|
||||||
// Mat-Vec product
|
// Mat-Vec product
|
||||||
// Both lhs and rhs are encoded as 8bit signed integers
|
// Both lhs and rhs are encoded as 8bit signed integers
|
||||||
template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version>
|
template <typename Index, typename LhsMapper, bool ConjugateLhs,
|
||||||
struct general_matrix_vector_product<Index,QInt8,LhsMapper,ColMajor,ConjugateLhs,QInt8,RhsMapper,ConjugateRhs,Version>
|
typename RhsMapper, bool ConjugateRhs, int Version>
|
||||||
{
|
struct general_matrix_vector_product<Index, QInt8, LhsMapper, ColMajor,
|
||||||
EIGEN_DONT_INLINE static void run(
|
ConjugateLhs, QInt8, RhsMapper,
|
||||||
Index rows, Index cols,
|
ConjugateRhs, Version> {
|
||||||
const LhsMapper& lhs,
|
EIGEN_DONT_INLINE static void run(Index rows, Index cols,
|
||||||
const RhsMapper& rhs,
|
const LhsMapper& lhs, const RhsMapper& rhs,
|
||||||
QInt32* res, Index resIncr,
|
QInt32* res, Index resIncr, QInt8 alpha);
|
||||||
QInt8 alpha);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version>
|
template <typename Index, typename LhsMapper, bool ConjugateLhs,
|
||||||
EIGEN_DONT_INLINE void general_matrix_vector_product<Index,QInt8,LhsMapper,ColMajor,ConjugateLhs,QInt8,RhsMapper,ConjugateRhs,Version>::run(
|
typename RhsMapper, bool ConjugateRhs, int Version>
|
||||||
Index rows, Index cols,
|
EIGEN_DONT_INLINE void general_matrix_vector_product<
|
||||||
const LhsMapper& lhs,
|
Index, QInt8, LhsMapper, ColMajor, ConjugateLhs, QInt8, RhsMapper,
|
||||||
const RhsMapper& rhs,
|
ConjugateRhs, Version>::run(Index rows, Index cols, const LhsMapper& lhs,
|
||||||
QInt32* res, Index resIncr,
|
const RhsMapper& rhs, QInt32* res,
|
||||||
QInt8 alpha)
|
Index resIncr, QInt8 alpha) {
|
||||||
{
|
|
||||||
eigen_assert(alpha.value == 1);
|
eigen_assert(alpha.value == 1);
|
||||||
eigen_assert(resIncr == 1);
|
eigen_assert(resIncr == 1);
|
||||||
eigen_assert(rows > 0);
|
eigen_assert(rows > 0);
|
||||||
@ -78,26 +76,25 @@ EIGEN_DONT_INLINE void general_matrix_vector_product<
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Mat-Vec product
|
// Mat-Vec product
|
||||||
// The lhs is encoded using 8bit signed integers, the rhs using 8bit unsigned integers
|
// The lhs is encoded using 8bit signed integers, the rhs using 8bit unsigned
|
||||||
template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version>
|
// integers
|
||||||
struct general_matrix_vector_product<Index,QInt8,LhsMapper,ColMajor,ConjugateLhs,QUInt8,RhsMapper,ConjugateRhs,Version>
|
template <typename Index, typename LhsMapper, bool ConjugateLhs,
|
||||||
{
|
typename RhsMapper, bool ConjugateRhs, int Version>
|
||||||
EIGEN_DONT_INLINE static void run(
|
struct general_matrix_vector_product<Index, QInt8, LhsMapper, ColMajor,
|
||||||
Index rows, Index cols,
|
ConjugateLhs, QUInt8, RhsMapper,
|
||||||
const LhsMapper& lhs,
|
ConjugateRhs, Version> {
|
||||||
const RhsMapper& rhs,
|
EIGEN_DONT_INLINE static void run(Index rows, Index cols,
|
||||||
QInt32* res, Index resIncr,
|
const LhsMapper& lhs, const RhsMapper& rhs,
|
||||||
QUInt8 alpha);
|
QInt32* res, Index resIncr, QUInt8 alpha);
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version>
|
template <typename Index, typename LhsMapper, bool ConjugateLhs,
|
||||||
EIGEN_DONT_INLINE void general_matrix_vector_product<Index,QInt8,LhsMapper,ColMajor,ConjugateLhs,QUInt8,RhsMapper,ConjugateRhs,Version>::run(
|
typename RhsMapper, bool ConjugateRhs, int Version>
|
||||||
Index rows, Index cols,
|
EIGEN_DONT_INLINE void general_matrix_vector_product<
|
||||||
const LhsMapper& lhs,
|
Index, QInt8, LhsMapper, ColMajor, ConjugateLhs, QUInt8, RhsMapper,
|
||||||
const RhsMapper& rhs,
|
ConjugateRhs, Version>::run(Index rows, Index cols, const LhsMapper& lhs,
|
||||||
QInt32* res, Index resIncr,
|
const RhsMapper& rhs, QInt32* res,
|
||||||
QUInt8 alpha)
|
Index resIncr, QUInt8 alpha) {
|
||||||
{
|
|
||||||
eigen_assert(alpha.value == 1);
|
eigen_assert(alpha.value == 1);
|
||||||
eigen_assert(resIncr == 1);
|
eigen_assert(resIncr == 1);
|
||||||
eigen_assert(rows > 0);
|
eigen_assert(rows > 0);
|
||||||
@ -110,28 +107,26 @@ EIGEN_DONT_INLINE void general_matrix_vector_product<Index,QInt8,LhsMapper,ColMa
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// Mat-Vec product
|
// Mat-Vec product
|
||||||
// The lhs is encoded using 8bit unsigned integers, the rhs using 8bit signed integers
|
// The lhs is encoded using 8bit unsigned integers, the rhs using 8bit signed
|
||||||
template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version>
|
// integers
|
||||||
struct general_matrix_vector_product<Index,QUInt8,LhsMapper,ColMajor,ConjugateLhs,QInt8,RhsMapper,ConjugateRhs,Version>
|
template <typename Index, typename LhsMapper, bool ConjugateLhs,
|
||||||
{
|
typename RhsMapper, bool ConjugateRhs, int Version>
|
||||||
EIGEN_DONT_INLINE static void run(
|
struct general_matrix_vector_product<Index, QUInt8, LhsMapper, ColMajor,
|
||||||
Index rows, Index cols,
|
ConjugateLhs, QInt8, RhsMapper,
|
||||||
const LhsMapper& lhs,
|
ConjugateRhs, Version> {
|
||||||
const RhsMapper& rhs,
|
EIGEN_DONT_INLINE static void run(Index rows, Index cols,
|
||||||
QInt32* res, Index resIncr,
|
const LhsMapper& lhs, const RhsMapper& rhs,
|
||||||
QInt8 alpha);
|
QInt32* res, Index resIncr, QInt8 alpha);
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version>
|
template <typename Index, typename LhsMapper, bool ConjugateLhs,
|
||||||
EIGEN_DONT_INLINE void general_matrix_vector_product<Index,QUInt8,LhsMapper,ColMajor,ConjugateLhs,QInt8,RhsMapper,ConjugateRhs,Version>::run(
|
typename RhsMapper, bool ConjugateRhs, int Version>
|
||||||
Index rows, Index cols,
|
EIGEN_DONT_INLINE void general_matrix_vector_product<
|
||||||
const LhsMapper& lhs,
|
Index, QUInt8, LhsMapper, ColMajor, ConjugateLhs, QInt8, RhsMapper,
|
||||||
const RhsMapper& rhs,
|
ConjugateRhs, Version>::run(Index rows, Index cols, const LhsMapper& lhs,
|
||||||
QInt32* res, Index resIncr,
|
const RhsMapper& rhs, QInt32* res,
|
||||||
QInt8 alpha)
|
Index resIncr, QInt8 alpha) {
|
||||||
{
|
|
||||||
eigen_assert(alpha.value == 1);
|
eigen_assert(alpha.value == 1);
|
||||||
eigen_assert(resIncr == 1);
|
eigen_assert(resIncr == 1);
|
||||||
eigen_assert(rows > 0);
|
eigen_assert(rows > 0);
|
||||||
|
@ -8,24 +8,20 @@
|
|||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
inline int _mm256_extract_epi16_N0(const __m256i X)
|
inline int _mm256_extract_epi16_N0(const __m256i X) {
|
||||||
{
|
return _mm_extract_epi16(_mm256_extractf128_si256(X, 0 >> 3), 0 % 8);
|
||||||
return _mm_extract_epi16(_mm256_extractf128_si256(X, 0 >> 3), 0 % 8);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
inline int _mm256_extract_epi16_N1(const __m256i X)
|
inline int _mm256_extract_epi16_N1(const __m256i X) {
|
||||||
{
|
return _mm_extract_epi16(_mm256_extractf128_si256(X, 1 >> 3), 1 % 8);
|
||||||
return _mm_extract_epi16(_mm256_extractf128_si256(X, 1 >> 3), 1 % 8);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
inline int _mm256_extract_epi8_N0(const __m256i X)
|
inline int _mm256_extract_epi8_N0(const __m256i X) {
|
||||||
{
|
return _mm_extract_epi8(_mm256_extractf128_si256((X), 0 >> 4), 0 % 16);
|
||||||
return _mm_extract_epi8(_mm256_extractf128_si256((X), 0 >> 4), 0 % 16);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
inline int _mm256_extract_epi8_N1(const __m256i X)
|
inline int _mm256_extract_epi8_N1(const __m256i X) {
|
||||||
{
|
return _mm_extract_epi8(_mm256_extractf128_si256((X), 1 >> 4), 1 % 16);
|
||||||
return _mm_extract_epi8(_mm256_extractf128_si256((X), 1 >> 4), 1 % 16);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace Eigen {
|
namespace Eigen {
|
||||||
@ -34,56 +30,56 @@ namespace internal {
|
|||||||
typedef struct Packet32q8i {
|
typedef struct Packet32q8i {
|
||||||
__m256i val;
|
__m256i val;
|
||||||
operator __m256i() const { return val; }
|
operator __m256i() const { return val; }
|
||||||
Packet32q8i();
|
Packet32q8i() : val(_mm256_setzero_si256()){};
|
||||||
Packet32q8i(__m256i val) : val(val) {}
|
Packet32q8i(__m256i val) : val(val) {}
|
||||||
} Packet32q8i;
|
} Packet32q8i;
|
||||||
|
|
||||||
typedef struct Packet16q16i {
|
typedef struct Packet16q16i {
|
||||||
__m256i val;
|
__m256i val;
|
||||||
operator __m256i() const { return val; }
|
operator __m256i() const { return val; }
|
||||||
Packet16q16i();
|
Packet16q16i() : val(_mm256_setzero_si256()){};
|
||||||
Packet16q16i(__m256i val) : val(val) {}
|
Packet16q16i(__m256i val) : val(val) {}
|
||||||
} Packet16q16i;
|
} Packet16q16i;
|
||||||
|
|
||||||
typedef struct Packet32q8u {
|
typedef struct Packet32q8u {
|
||||||
__m256i val;
|
__m256i val;
|
||||||
operator __m256i() const { return val; }
|
operator __m256i() const { return val; }
|
||||||
Packet32q8u();
|
Packet32q8u() : val(_mm256_setzero_si256()){};
|
||||||
Packet32q8u(__m256i val) : val(val) {}
|
Packet32q8u(__m256i val) : val(val) {}
|
||||||
} Packet32q8u;
|
} Packet32q8u;
|
||||||
|
|
||||||
typedef struct Packet16q8i {
|
typedef struct Packet16q8i {
|
||||||
__m128i val;
|
__m128i val;
|
||||||
operator __m128i() const { return val; }
|
operator __m128i() const { return val; }
|
||||||
Packet16q8i();
|
Packet16q8i() : val(_mm_setzero_si128()) {}
|
||||||
Packet16q8i(__m128i val) : val(val) {}
|
Packet16q8i(__m128i val) : val(val) {}
|
||||||
} Packet16q8i;
|
} Packet16q8i;
|
||||||
|
|
||||||
typedef struct Packet16q8u {
|
typedef struct Packet16q8u {
|
||||||
__m128i val;
|
__m128i val;
|
||||||
operator __m128i() const { return val; }
|
operator __m128i() const { return val; }
|
||||||
Packet16q8u();
|
Packet16q8u() : val(_mm_setzero_si128()) {}
|
||||||
Packet16q8u(__m128i val) : val(val) {}
|
Packet16q8u(__m128i val) : val(val) {}
|
||||||
} Packet16q8u;
|
} Packet16q8u;
|
||||||
|
|
||||||
typedef struct Packet8q16i {
|
typedef struct Packet8q16i {
|
||||||
__m128i val;
|
__m128i val;
|
||||||
operator __m128i() const { return val; }
|
operator __m128i() const { return val; }
|
||||||
Packet8q16i();
|
Packet8q16i() : val(_mm_setzero_si128()) {}
|
||||||
Packet8q16i(__m128i val) : val(val) {}
|
Packet8q16i(__m128i val) : val(val) {}
|
||||||
} Packet8q16i;
|
} Packet8q16i;
|
||||||
|
|
||||||
typedef struct Packet8q32i {
|
typedef struct Packet8q32i {
|
||||||
__m256i val;
|
__m256i val;
|
||||||
operator __m256i() const { return val; }
|
operator __m256i() const { return val; }
|
||||||
Packet8q32i();
|
Packet8q32i() : val(_mm256_setzero_si256()){};
|
||||||
Packet8q32i(__m256i val) : val(val) {}
|
Packet8q32i(__m256i val) : val(val) {}
|
||||||
} Packet8q32i;
|
} Packet8q32i;
|
||||||
|
|
||||||
typedef struct Packet4q32i {
|
typedef struct Packet4q32i {
|
||||||
__m128i val;
|
__m128i val;
|
||||||
operator __m128i() const { return val; }
|
operator __m128i() const { return val; }
|
||||||
Packet4q32i();
|
Packet4q32i() : val(_mm_setzero_si128()) {}
|
||||||
Packet4q32i(__m128i val) : val(val) {}
|
Packet4q32i(__m128i val) : val(val) {}
|
||||||
} Packet4q32i;
|
} Packet4q32i;
|
||||||
|
|
||||||
@ -182,25 +178,25 @@ template <>
|
|||||||
struct unpacket_traits<Packet32q8i> {
|
struct unpacket_traits<Packet32q8i> {
|
||||||
typedef QInt8 type;
|
typedef QInt8 type;
|
||||||
typedef Packet16q8i half;
|
typedef Packet16q8i half;
|
||||||
enum { size = 32, alignment=Aligned32 };
|
enum { size = 32, alignment = Aligned32 };
|
||||||
};
|
};
|
||||||
template <>
|
template <>
|
||||||
struct unpacket_traits<Packet16q16i> {
|
struct unpacket_traits<Packet16q16i> {
|
||||||
typedef QInt16 type;
|
typedef QInt16 type;
|
||||||
typedef Packet8q16i half;
|
typedef Packet8q16i half;
|
||||||
enum { size = 16, alignment=Aligned32 };
|
enum { size = 16, alignment = Aligned32 };
|
||||||
};
|
};
|
||||||
template <>
|
template <>
|
||||||
struct unpacket_traits<Packet32q8u> {
|
struct unpacket_traits<Packet32q8u> {
|
||||||
typedef QUInt8 type;
|
typedef QUInt8 type;
|
||||||
typedef Packet16q8u half;
|
typedef Packet16q8u half;
|
||||||
enum { size = 32, alignment=Aligned32 };
|
enum { size = 32, alignment = Aligned32 };
|
||||||
};
|
};
|
||||||
template <>
|
template <>
|
||||||
struct unpacket_traits<Packet8q32i> {
|
struct unpacket_traits<Packet8q32i> {
|
||||||
typedef QInt32 type;
|
typedef QInt32 type;
|
||||||
typedef Packet4q32i half;
|
typedef Packet4q32i half;
|
||||||
enum { size = 8, alignment=Aligned32 };
|
enum { size = 8, alignment = Aligned32 };
|
||||||
};
|
};
|
||||||
|
|
||||||
// Unaligned load
|
// Unaligned load
|
||||||
@ -455,40 +451,47 @@ EIGEN_STRONG_INLINE QUInt8 predux_max<Packet32q8u>(const Packet32q8u& a) {
|
|||||||
template <>
|
template <>
|
||||||
EIGEN_STRONG_INLINE QInt8 predux_min<Packet32q8i>(const Packet32q8i& a) {
|
EIGEN_STRONG_INLINE QInt8 predux_min<Packet32q8i>(const Packet32q8i& a) {
|
||||||
__m256i tmp = _mm256_min_epi8(a, _mm256_permute2f128_si256(a, a, 1));
|
__m256i tmp = _mm256_min_epi8(a, _mm256_permute2f128_si256(a, a, 1));
|
||||||
tmp = _mm256_min_epi8(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2)));
|
tmp =
|
||||||
|
_mm256_min_epi8(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2)));
|
||||||
tmp = _mm256_min_epi8(tmp, _mm256_shuffle_epi32(tmp, 1));
|
tmp = _mm256_min_epi8(tmp, _mm256_shuffle_epi32(tmp, 1));
|
||||||
tmp = _mm256_min_epi8(tmp, _mm256_shufflelo_epi16(tmp, _MM_SHUFFLE(1, 0, 3, 2)));
|
tmp = _mm256_min_epi8(tmp,
|
||||||
|
_mm256_shufflelo_epi16(tmp, _MM_SHUFFLE(1, 0, 3, 2)));
|
||||||
return std::min(_mm256_extract_epi8_N0(tmp), _mm256_extract_epi8_N1(tmp));
|
return std::min(_mm256_extract_epi8_N0(tmp), _mm256_extract_epi8_N1(tmp));
|
||||||
}
|
}
|
||||||
template <>
|
template <>
|
||||||
EIGEN_STRONG_INLINE QInt8 predux_max<Packet32q8i>(const Packet32q8i& a) {
|
EIGEN_STRONG_INLINE QInt8 predux_max<Packet32q8i>(const Packet32q8i& a) {
|
||||||
__m256i tmp = _mm256_max_epi8(a, _mm256_permute2f128_si256(a, a, 1));
|
__m256i tmp = _mm256_max_epi8(a, _mm256_permute2f128_si256(a, a, 1));
|
||||||
tmp = _mm256_max_epi8(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2)));
|
tmp =
|
||||||
|
_mm256_max_epi8(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2)));
|
||||||
tmp = _mm256_max_epi8(tmp, _mm256_shuffle_epi32(tmp, 1));
|
tmp = _mm256_max_epi8(tmp, _mm256_shuffle_epi32(tmp, 1));
|
||||||
tmp = _mm256_max_epi8(tmp, _mm256_shufflelo_epi16(tmp, _MM_SHUFFLE(1, 0, 3, 2)));
|
tmp = _mm256_max_epi8(tmp,
|
||||||
|
_mm256_shufflelo_epi16(tmp, _MM_SHUFFLE(1, 0, 3, 2)));
|
||||||
return std::max(_mm256_extract_epi8_N0(tmp), _mm256_extract_epi8_N1(tmp));
|
return std::max(_mm256_extract_epi8_N0(tmp), _mm256_extract_epi8_N1(tmp));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Vectorized scaling of Packet32q8i by float.
|
// Vectorized scaling of Packet32q8i by float.
|
||||||
template<>
|
template <>
|
||||||
struct scalar_product_op<QInt32, double> : binary_op_base<QInt32, double> {
|
struct scalar_product_op<QInt32, double> : binary_op_base<QInt32, double> {
|
||||||
typedef typename ScalarBinaryOpTraits<QInt32, double>::ReturnType result_type;
|
typedef typename ScalarBinaryOpTraits<QInt32, double>::ReturnType result_type;
|
||||||
#ifndef EIGEN_SCALAR_BINARY_OP_PLUGIN
|
#ifndef EIGEN_SCALAR_BINARY_OP_PLUGIN
|
||||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_product_op)
|
EIGEN_EMPTY_STRUCT_CTOR(scalar_product_op)
|
||||||
#else
|
#else
|
||||||
scalar_product_op() {
|
scalar_product_op() { EIGEN_SCALAR_BINARY_OP_PLUGIN }
|
||||||
EIGEN_SCALAR_BINARY_OP_PLUGIN
|
|
||||||
}
|
|
||||||
#endif
|
#endif
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const QInt32& a, const double& b) const { return a * b; }
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type
|
||||||
|
operator()(const QInt32& a, const double& b) const {
|
||||||
|
return a * b;
|
||||||
|
}
|
||||||
|
|
||||||
EIGEN_STRONG_INLINE const Packet8q32i packetOp(const Packet8q32i& a, const double& b) const {
|
EIGEN_STRONG_INLINE const Packet8q32i packetOp(const Packet8q32i& a,
|
||||||
|
const double& b) const {
|
||||||
__m256d scale = _mm256_set1_pd(b);
|
__m256d scale = _mm256_set1_pd(b);
|
||||||
__m256d a_lo = _mm256_cvtepi32_pd(_mm256_castsi256_si128(a));
|
__m256d a_lo = _mm256_cvtepi32_pd(_mm256_castsi256_si128(a));
|
||||||
__m128i result_lo = _mm256_cvtpd_epi32(_mm256_mul_pd(scale, a_lo));
|
__m128i result_lo = _mm256_cvtpd_epi32(_mm256_mul_pd(scale, a_lo));
|
||||||
__m256d a_hi = _mm256_cvtepi32_pd(_mm256_extracti128_si256(a, 1));
|
__m256d a_hi = _mm256_cvtepi32_pd(_mm256_extracti128_si256(a, 1));
|
||||||
__m128i result_hi = _mm256_cvtpd_epi32(_mm256_mul_pd(scale, a_hi));
|
__m128i result_hi = _mm256_cvtpd_epi32(_mm256_mul_pd(scale, a_hi));
|
||||||
return _mm256_insertf128_si256(_mm256_castsi128_si256(result_lo), result_hi, 1);
|
return _mm256_insertf128_si256(_mm256_castsi128_si256(result_lo), result_hi,
|
||||||
|
1);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -127,25 +127,25 @@ template <>
|
|||||||
struct unpacket_traits<Packet64q8i> {
|
struct unpacket_traits<Packet64q8i> {
|
||||||
typedef QInt8 type;
|
typedef QInt8 type;
|
||||||
typedef Packet32q8i half;
|
typedef Packet32q8i half;
|
||||||
enum { size = 64, alignment=Aligned64 };
|
enum { size = 64, alignment = Aligned64 };
|
||||||
};
|
};
|
||||||
template <>
|
template <>
|
||||||
struct unpacket_traits<Packet32q16i> {
|
struct unpacket_traits<Packet32q16i> {
|
||||||
typedef QInt16 type;
|
typedef QInt16 type;
|
||||||
typedef Packet16q16i half;
|
typedef Packet16q16i half;
|
||||||
enum { size = 32, alignment=Aligned64 };
|
enum { size = 32, alignment = Aligned64 };
|
||||||
};
|
};
|
||||||
template <>
|
template <>
|
||||||
struct unpacket_traits<Packet64q8u> {
|
struct unpacket_traits<Packet64q8u> {
|
||||||
typedef QUInt8 type;
|
typedef QUInt8 type;
|
||||||
typedef Packet32q8u half;
|
typedef Packet32q8u half;
|
||||||
enum { size = 64, alignment=Aligned64 };
|
enum { size = 64, alignment = Aligned64 };
|
||||||
};
|
};
|
||||||
template <>
|
template <>
|
||||||
struct unpacket_traits<Packet16q32i> {
|
struct unpacket_traits<Packet16q32i> {
|
||||||
typedef QInt32 type;
|
typedef QInt32 type;
|
||||||
typedef Packet8q32i half;
|
typedef Packet8q32i half;
|
||||||
enum { size = 16, alignment=Aligned64 };
|
enum { size = 16, alignment = Aligned64 };
|
||||||
};
|
};
|
||||||
|
|
||||||
// Unaligned load
|
// Unaligned load
|
||||||
@ -244,7 +244,7 @@ EIGEN_STRONG_INLINE QInt32 pfirst<Packet16q32i>(const Packet16q32i& a) {
|
|||||||
template <>
|
template <>
|
||||||
EIGEN_STRONG_INLINE QUInt8 pfirst<Packet64q8u>(const Packet64q8u& a) {
|
EIGEN_STRONG_INLINE QUInt8 pfirst<Packet64q8u>(const Packet64q8u& a) {
|
||||||
return static_cast<uint8_t>(
|
return static_cast<uint8_t>(
|
||||||
_mm_extract_epi8(_mm512_extracti32x4_epi32(a.val, 0), 0));
|
_mm_extract_epi8(_mm512_extracti32x4_epi32(a.val, 0), 0));
|
||||||
}
|
}
|
||||||
template <>
|
template <>
|
||||||
EIGEN_STRONG_INLINE QInt8 pfirst<Packet64q8i>(const Packet64q8i& a) {
|
EIGEN_STRONG_INLINE QInt8 pfirst<Packet64q8i>(const Packet64q8i& a) {
|
||||||
@ -410,9 +410,7 @@ EIGEN_STRONG_INLINE QInt32 predux_min<Packet16q32i>(const Packet16q32i& a) {
|
|||||||
_mm_min_epi32(_mm_min_epi32(lane0, lane1), _mm_min_epi32(lane2, lane3));
|
_mm_min_epi32(_mm_min_epi32(lane0, lane1), _mm_min_epi32(lane2, lane3));
|
||||||
res = _mm_min_epi32(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2)));
|
res = _mm_min_epi32(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2)));
|
||||||
return pfirst(
|
return pfirst(
|
||||||
_mm_min_epi32(
|
_mm_min_epi32(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))));
|
||||||
res,
|
|
||||||
_mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))));
|
|
||||||
}
|
}
|
||||||
template <>
|
template <>
|
||||||
EIGEN_STRONG_INLINE QInt32 predux_max<Packet16q32i>(const Packet16q32i& a) {
|
EIGEN_STRONG_INLINE QInt32 predux_max<Packet16q32i>(const Packet16q32i& a) {
|
||||||
@ -424,9 +422,7 @@ EIGEN_STRONG_INLINE QInt32 predux_max<Packet16q32i>(const Packet16q32i& a) {
|
|||||||
_mm_max_epi32(_mm_max_epi32(lane0, lane1), _mm_max_epi32(lane2, lane3));
|
_mm_max_epi32(_mm_max_epi32(lane0, lane1), _mm_max_epi32(lane2, lane3));
|
||||||
res = _mm_max_epi32(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2)));
|
res = _mm_max_epi32(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2)));
|
||||||
return pfirst(
|
return pfirst(
|
||||||
_mm_max_epi32(
|
_mm_max_epi32(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))));
|
||||||
res,
|
|
||||||
_mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))));
|
|
||||||
}
|
}
|
||||||
template <>
|
template <>
|
||||||
EIGEN_STRONG_INLINE QInt16 predux_min<Packet32q16i>(const Packet32q16i& a) {
|
EIGEN_STRONG_INLINE QInt16 predux_min<Packet32q16i>(const Packet32q16i& a) {
|
||||||
@ -437,13 +433,10 @@ EIGEN_STRONG_INLINE QInt16 predux_min<Packet32q16i>(const Packet32q16i& a) {
|
|||||||
Packet4i res =
|
Packet4i res =
|
||||||
_mm_min_epi16(_mm_min_epi16(lane0, lane1), _mm_min_epi16(lane2, lane3));
|
_mm_min_epi16(_mm_min_epi16(lane0, lane1), _mm_min_epi16(lane2, lane3));
|
||||||
res = _mm_min_epi16(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2)));
|
res = _mm_min_epi16(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2)));
|
||||||
std::uint32_t w =
|
std::uint32_t w = pfirst(
|
||||||
pfirst(
|
_mm_min_epi16(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))));
|
||||||
_mm_min_epi16(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))));
|
return std::min(
|
||||||
return std::min({
|
{static_cast<std::int16_t>(w >> 16), static_cast<std::int16_t>(w)});
|
||||||
static_cast<std::int16_t>(w >> 16),
|
|
||||||
static_cast<std::int16_t>(w)
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
template <>
|
template <>
|
||||||
EIGEN_STRONG_INLINE QInt16 predux_max<Packet32q16i>(const Packet32q16i& a) {
|
EIGEN_STRONG_INLINE QInt16 predux_max<Packet32q16i>(const Packet32q16i& a) {
|
||||||
@ -454,13 +447,10 @@ EIGEN_STRONG_INLINE QInt16 predux_max<Packet32q16i>(const Packet32q16i& a) {
|
|||||||
Packet4i res =
|
Packet4i res =
|
||||||
_mm_max_epi16(_mm_max_epi16(lane0, lane1), _mm_max_epi16(lane2, lane3));
|
_mm_max_epi16(_mm_max_epi16(lane0, lane1), _mm_max_epi16(lane2, lane3));
|
||||||
res = _mm_max_epi16(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2)));
|
res = _mm_max_epi16(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2)));
|
||||||
std::uint32_t w =
|
std::uint32_t w = pfirst(
|
||||||
pfirst(
|
_mm_max_epi16(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))));
|
||||||
_mm_max_epi16(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))));
|
return std::max(
|
||||||
return std::max({
|
{static_cast<std::int16_t>(w >> 16), static_cast<std::int16_t>(w)});
|
||||||
static_cast<std::int16_t>(w >> 16),
|
|
||||||
static_cast<std::int16_t>(w)
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
template <>
|
template <>
|
||||||
EIGEN_STRONG_INLINE QUInt8 predux_min<Packet64q8u>(const Packet64q8u& a) {
|
EIGEN_STRONG_INLINE QUInt8 predux_min<Packet64q8u>(const Packet64q8u& a) {
|
||||||
@ -471,15 +461,11 @@ EIGEN_STRONG_INLINE QUInt8 predux_min<Packet64q8u>(const Packet64q8u& a) {
|
|||||||
Packet4i res =
|
Packet4i res =
|
||||||
_mm_min_epu8(_mm_min_epu8(lane0, lane1), _mm_min_epu8(lane2, lane3));
|
_mm_min_epu8(_mm_min_epu8(lane0, lane1), _mm_min_epu8(lane2, lane3));
|
||||||
res = _mm_min_epu8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2)));
|
res = _mm_min_epu8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2)));
|
||||||
std::uint32_t w =
|
std::uint32_t w = pfirst(
|
||||||
pfirst(
|
_mm_min_epu8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))));
|
||||||
_mm_min_epu8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))));
|
return std::min(
|
||||||
return std::min({
|
{static_cast<std::uint8_t>(w >> 24), static_cast<std::uint8_t>(w >> 16),
|
||||||
static_cast<std::uint8_t>(w >> 24),
|
static_cast<std::uint8_t>(w >> 8), static_cast<std::uint8_t>(w)});
|
||||||
static_cast<std::uint8_t>(w >> 16),
|
|
||||||
static_cast<std::uint8_t>(w >> 8),
|
|
||||||
static_cast<std::uint8_t>(w)
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
template <>
|
template <>
|
||||||
EIGEN_STRONG_INLINE QUInt8 predux_max<Packet64q8u>(const Packet64q8u& a) {
|
EIGEN_STRONG_INLINE QUInt8 predux_max<Packet64q8u>(const Packet64q8u& a) {
|
||||||
@ -490,15 +476,11 @@ EIGEN_STRONG_INLINE QUInt8 predux_max<Packet64q8u>(const Packet64q8u& a) {
|
|||||||
Packet4i res =
|
Packet4i res =
|
||||||
_mm_max_epu8(_mm_max_epu8(lane0, lane1), _mm_max_epu8(lane2, lane3));
|
_mm_max_epu8(_mm_max_epu8(lane0, lane1), _mm_max_epu8(lane2, lane3));
|
||||||
res = _mm_max_epu8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2)));
|
res = _mm_max_epu8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2)));
|
||||||
std::uint32_t w =
|
std::uint32_t w = pfirst(
|
||||||
pfirst(
|
_mm_max_epu8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))));
|
||||||
_mm_max_epu8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))));
|
return std::max(
|
||||||
return std::max({
|
{static_cast<std::uint8_t>(w >> 24), static_cast<std::uint8_t>(w >> 16),
|
||||||
static_cast<std::uint8_t>(w >> 24),
|
static_cast<std::uint8_t>(w >> 8), static_cast<std::uint8_t>(w)});
|
||||||
static_cast<std::uint8_t>(w >> 16),
|
|
||||||
static_cast<std::uint8_t>(w >> 8),
|
|
||||||
static_cast<std::uint8_t>(w)
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
template <>
|
template <>
|
||||||
EIGEN_STRONG_INLINE QInt8 predux_min<Packet64q8i>(const Packet64q8i& a) {
|
EIGEN_STRONG_INLINE QInt8 predux_min<Packet64q8i>(const Packet64q8i& a) {
|
||||||
@ -509,15 +491,11 @@ EIGEN_STRONG_INLINE QInt8 predux_min<Packet64q8i>(const Packet64q8i& a) {
|
|||||||
Packet4i res =
|
Packet4i res =
|
||||||
_mm_min_epi8(_mm_min_epi8(lane0, lane1), _mm_min_epi8(lane2, lane3));
|
_mm_min_epi8(_mm_min_epi8(lane0, lane1), _mm_min_epi8(lane2, lane3));
|
||||||
res = _mm_min_epi8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2)));
|
res = _mm_min_epi8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2)));
|
||||||
std::uint32_t w =
|
std::uint32_t w = pfirst(
|
||||||
pfirst(
|
_mm_min_epi8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))));
|
||||||
_mm_min_epi8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))));
|
return std::min(
|
||||||
return std::min({
|
{static_cast<std::int8_t>(w >> 24), static_cast<std::int8_t>(w >> 16),
|
||||||
static_cast<std::int8_t>(w >> 24),
|
static_cast<std::int8_t>(w >> 8), static_cast<std::int8_t>(w)});
|
||||||
static_cast<std::int8_t>(w >> 16),
|
|
||||||
static_cast<std::int8_t>(w >> 8),
|
|
||||||
static_cast<std::int8_t>(w)
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
template <>
|
template <>
|
||||||
EIGEN_STRONG_INLINE QInt8 predux_max<Packet64q8i>(const Packet64q8i& a) {
|
EIGEN_STRONG_INLINE QInt8 predux_max<Packet64q8i>(const Packet64q8i& a) {
|
||||||
@ -528,15 +506,11 @@ EIGEN_STRONG_INLINE QInt8 predux_max<Packet64q8i>(const Packet64q8i& a) {
|
|||||||
Packet4i res =
|
Packet4i res =
|
||||||
_mm_max_epi8(_mm_max_epi8(lane0, lane1), _mm_max_epi8(lane2, lane3));
|
_mm_max_epi8(_mm_max_epi8(lane0, lane1), _mm_max_epi8(lane2, lane3));
|
||||||
res = _mm_max_epi8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2)));
|
res = _mm_max_epi8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2)));
|
||||||
std::uint32_t w =
|
std::uint32_t w = pfirst(
|
||||||
pfirst(
|
_mm_max_epi8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))));
|
||||||
_mm_max_epi8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))));
|
return std::min(
|
||||||
return std::min({
|
{static_cast<std::int8_t>(w >> 24), static_cast<std::int8_t>(w >> 16),
|
||||||
static_cast<std::int8_t>(w >> 24),
|
static_cast<std::int8_t>(w >> 8), static_cast<std::int8_t>(w)});
|
||||||
static_cast<std::int8_t>(w >> 16),
|
|
||||||
static_cast<std::int8_t>(w >> 8),
|
|
||||||
static_cast<std::int8_t>(w)
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
|
@ -33,28 +33,23 @@ struct type_casting_traits<float, QInt16> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
EIGEN_STRONG_INLINE Packet32q16i
|
EIGEN_STRONG_INLINE Packet32q16i pcast<Packet16f>(const Packet16f& a,
|
||||||
pcast<Packet16f>(const Packet16f& a, const Packet16f& b) {
|
const Packet16f& b) {
|
||||||
Packet16i a_int = _mm512_cvtps_epi32(a);
|
Packet16i a_int = _mm512_cvtps_epi32(a);
|
||||||
Packet16i b_int = _mm512_cvtps_epi32(b);
|
Packet16i b_int = _mm512_cvtps_epi32(b);
|
||||||
#ifdef EIGEN_VECTORIZE_AVX512BW
|
#ifdef EIGEN_VECTORIZE_AVX512BW
|
||||||
return _mm512_packs_epi32(a_int, b_int);
|
return _mm512_packs_epi32(a_int, b_int);
|
||||||
#else
|
#else
|
||||||
Packet8i ab_int16_low =
|
Packet8i ab_int16_low = _mm256_permute4x64_epi64(
|
||||||
_mm256_permute4x64_epi64(
|
_mm256_packs_epi32(_mm512_castsi512_si256(a_int),
|
||||||
_mm256_packs_epi32(
|
_mm512_castsi512_si256(b_int)),
|
||||||
_mm512_castsi512_si256(a_int),
|
_MM_SHUFFLE(0, 2, 1, 3));
|
||||||
_mm512_castsi512_si256(b_int)),
|
Packet8i ab_int16_high = _mm256_permute4x64_epi64(
|
||||||
_MM_SHUFFLE(0, 2, 1, 3));
|
_mm256_packs_epi32(_mm512_extracti32x8_epi32(a_int, 1),
|
||||||
Packet8i ab_int16_high =
|
_mm512_extracti32x8_epi32(b_int, 1)),
|
||||||
_mm256_permute4x64_epi64(
|
_MM_SHUFFLE(0, 2, 1, 3));
|
||||||
_mm256_packs_epi32(
|
return _mm512_inserti32x8(_mm512_castsi256_si512(ab_int16_low), ab_int16_high,
|
||||||
_mm512_extracti32x8_epi32(a_int, 1),
|
1);
|
||||||
_mm512_extracti32x8_epi32(b_int, 1)),
|
|
||||||
_MM_SHUFFLE(0, 2, 1, 3));
|
|
||||||
return _mm512_inserti32x8(
|
|
||||||
_mm512_castsi256_si512(ab_int16_low),
|
|
||||||
ab_int16_high, 1);
|
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -64,55 +59,41 @@ struct type_casting_traits<float, QInt8> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
EIGEN_STRONG_INLINE Packet64q8i
|
EIGEN_STRONG_INLINE Packet64q8i pcast<Packet16f>(const Packet16f& a,
|
||||||
pcast<Packet16f>(const Packet16f& a,
|
const Packet16f& b,
|
||||||
const Packet16f& b,
|
const Packet16f& c,
|
||||||
const Packet16f& c,
|
const Packet16f& d) {
|
||||||
const Packet16f& d) {
|
|
||||||
Packet16i a_int = _mm512_cvtps_epi32(a);
|
Packet16i a_int = _mm512_cvtps_epi32(a);
|
||||||
Packet16i b_int = _mm512_cvtps_epi32(b);
|
Packet16i b_int = _mm512_cvtps_epi32(b);
|
||||||
Packet16i c_int = _mm512_cvtps_epi32(c);
|
Packet16i c_int = _mm512_cvtps_epi32(c);
|
||||||
Packet16i d_int = _mm512_cvtps_epi32(d);
|
Packet16i d_int = _mm512_cvtps_epi32(d);
|
||||||
#ifdef EIGEN_VECTORIZE_AVX512BW
|
#ifdef EIGEN_VECTORIZE_AVX512BW
|
||||||
return _mm512_packs_epi16(
|
return _mm512_packs_epi16(_mm512_packs_epi32(a_int, b_int),
|
||||||
_mm512_packs_epi32(a_int, b_int),
|
_mm512_packs_epi32(c_int, d_int));
|
||||||
_mm512_packs_epi32(c_int, d_int));
|
|
||||||
#else
|
#else
|
||||||
Packet8i ab_int16_low =
|
Packet8i ab_int16_low = _mm256_permute4x64_epi64(
|
||||||
_mm256_permute4x64_epi64(
|
_mm256_packs_epi32(_mm512_castsi512_si256(a_int),
|
||||||
_mm256_packs_epi32(
|
_mm512_castsi512_si256(b_int)),
|
||||||
_mm512_castsi512_si256(a_int),
|
_MM_SHUFFLE(0, 2, 1, 3));
|
||||||
_mm512_castsi512_si256(b_int)),
|
Packet8i cd_int16_low = _mm256_permute4x64_epi64(
|
||||||
_MM_SHUFFLE(0, 2, 1, 3));
|
_mm256_packs_epi32(_mm512_castsi512_si256(c_int),
|
||||||
Packet8i cd_int16_low =
|
_mm512_castsi512_si256(d_int)),
|
||||||
_mm256_permute4x64_epi64(
|
_MM_SHUFFLE(0, 2, 1, 3));
|
||||||
_mm256_packs_epi32(
|
Packet8i ab_int16_high = _mm256_permute4x64_epi64(
|
||||||
_mm512_castsi512_si256(c_int),
|
_mm256_packs_epi32(_mm512_extracti32x8_epi32(a_int, 1),
|
||||||
_mm512_castsi512_si256(d_int)),
|
_mm512_extracti32x8_epi32(b_int, 1)),
|
||||||
_MM_SHUFFLE(0, 2, 1, 3));
|
_MM_SHUFFLE(0, 2, 1, 3));
|
||||||
Packet8i ab_int16_high =
|
Packet8i cd_int16_high = _mm256_permute4x64_epi64(
|
||||||
_mm256_permute4x64_epi64(
|
_mm256_packs_epi32(_mm512_extracti32x8_epi32(c_int, 1),
|
||||||
_mm256_packs_epi32(
|
_mm512_extracti32x8_epi32(d_int, 1)),
|
||||||
_mm512_extracti32x8_epi32(a_int, 1),
|
_MM_SHUFFLE(0, 2, 1, 3));
|
||||||
_mm512_extracti32x8_epi32(b_int, 1)),
|
Packet8i abcd_int8_low = _mm256_permute4x64_epi64(
|
||||||
_MM_SHUFFLE(0, 2, 1, 3));
|
_mm256_packs_epi16(ab_int16_low, cd_int16_low), _MM_SHUFFLE(0, 2, 1, 3));
|
||||||
Packet8i cd_int16_high =
|
|
||||||
_mm256_permute4x64_epi64(
|
|
||||||
_mm256_packs_epi32(
|
|
||||||
_mm512_extracti32x8_epi32(c_int, 1),
|
|
||||||
_mm512_extracti32x8_epi32(d_int, 1)),
|
|
||||||
_MM_SHUFFLE(0, 2, 1, 3));
|
|
||||||
Packet8i abcd_int8_low =
|
|
||||||
_mm256_permute4x64_epi64(
|
|
||||||
_mm256_packs_epi16(ab_int16_low, cd_int16_low),
|
|
||||||
_MM_SHUFFLE(0, 2, 1, 3));
|
|
||||||
Packet8i abcd_int8_high =
|
Packet8i abcd_int8_high =
|
||||||
_mm256_permute4x64_epi64(
|
_mm256_permute4x64_epi64(_mm256_packs_epi16(ab_int16_high, cd_int16_high),
|
||||||
_mm256_packs_epi16(ab_int16_high, cd_int16_high),
|
_MM_SHUFFLE(0, 2, 1, 3));
|
||||||
_MM_SHUFFLE(0, 2, 1, 3));
|
return _mm512_inserti32x8(_mm512_castsi256_si512(abcd_int8_low),
|
||||||
return _mm512_inserti32x8(
|
abcd_int8_high, 1);
|
||||||
_mm512_castsi256_si512(abcd_int8_low),
|
|
||||||
abcd_int8_high, 1);
|
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -128,10 +109,8 @@ struct type_casting_traits<QInt32, QInt16> {
|
|||||||
|
|
||||||
template <>
|
template <>
|
||||||
EIGEN_STRONG_INLINE Packet64q8i
|
EIGEN_STRONG_INLINE Packet64q8i
|
||||||
pcast<Packet16q32i, Packet64q8i>(const Packet16q32i& a,
|
pcast<Packet16q32i, Packet64q8i>(const Packet16q32i& a, const Packet16q32i& b,
|
||||||
const Packet16q32i& b,
|
const Packet16q32i& c, const Packet16q32i& d) {
|
||||||
const Packet16q32i& c,
|
|
||||||
const Packet16q32i& d) {
|
|
||||||
__m128i a_part = _mm512_cvtsepi32_epi8(a);
|
__m128i a_part = _mm512_cvtsepi32_epi8(a);
|
||||||
__m128i b_part = _mm512_cvtsepi32_epi8(b);
|
__m128i b_part = _mm512_cvtsepi32_epi8(b);
|
||||||
__m128i c_part = _mm512_cvtsepi32_epi8(c);
|
__m128i c_part = _mm512_cvtsepi32_epi8(c);
|
||||||
@ -145,9 +124,8 @@ pcast<Packet16q32i, Packet64q8i>(const Packet16q32i& a,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
EIGEN_STRONG_INLINE Packet32q16i
|
EIGEN_STRONG_INLINE Packet32q16i pcast<Packet16q32i, Packet32q16i>(
|
||||||
pcast<Packet16q32i, Packet32q16i>(const Packet16q32i& a,
|
const Packet16q32i& a, const Packet16q32i& b) {
|
||||||
const Packet16q32i& b) {
|
|
||||||
__m256i a_part = _mm512_cvtsepi32_epi16(a);
|
__m256i a_part = _mm512_cvtsepi32_epi16(a);
|
||||||
__m256i b_part = _mm512_cvtsepi32_epi16(b);
|
__m256i b_part = _mm512_cvtsepi32_epi16(b);
|
||||||
__m512i converted =
|
__m512i converted =
|
||||||
|
48
third_party/eigen_reshaped.patch
vendored
Normal file
48
third_party/eigen_reshaped.patch
vendored
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
--- a/Eigen/src/Core/util/ReshapedHelper.h (date 1541195478000)
|
||||||
|
+++ b/Eigen/src/Core/util/ReshapedHelper.h (date 1541195478000)
|
||||||
|
@@ -39,6 +39,11 @@
|
||||||
|
return total/other;
|
||||||
|
}
|
||||||
|
|
||||||
|
+template<int Flags, int Order>
|
||||||
|
+struct get_compiletime_reshape_order {
|
||||||
|
+ enum { value = Order == AutoOrder ? Flags & RowMajorBit : Order };
|
||||||
|
+};
|
||||||
|
+
|
||||||
|
}
|
||||||
|
|
||||||
|
} // end namespace Eigen
|
||||||
|
--- a/Eigen/src/plugins/ReshapedMethods.h (date 1541195254000)
|
||||||
|
+++ b/Eigen/src/plugins/ReshapedMethods.h (date 1541195254000)
|
||||||
|
@@ -105,13 +105,13 @@
|
||||||
|
inline Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived,
|
||||||
|
internal::get_compiletime_reshape_size<NRowsType,NColsType,SizeAtCompileTime>::value,
|
||||||
|
internal::get_compiletime_reshape_size<NColsType,NRowsType,SizeAtCompileTime>::value,
|
||||||
|
- (Order==AutoOrder?Flags&RowMajorBit:Order)>
|
||||||
|
+ internal::get_compiletime_reshape_order<Flags,Order>::value>
|
||||||
|
reshaped(NRowsType nRows, NColsType nCols) EIGEN_RESHAPED_METHOD_CONST
|
||||||
|
{
|
||||||
|
return Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived,
|
||||||
|
internal::get_compiletime_reshape_size<NRowsType,NColsType,SizeAtCompileTime>::value,
|
||||||
|
internal::get_compiletime_reshape_size<NColsType,NRowsType,SizeAtCompileTime>::value,
|
||||||
|
- (Order==AutoOrder?Flags&RowMajorBit:Order)>
|
||||||
|
+ internal::get_compiletime_reshape_order<Flags,Order>::value>
|
||||||
|
(derived(),
|
||||||
|
internal::get_runtime_reshape_size(nRows,internal::get_runtime_value(nCols),size()),
|
||||||
|
internal::get_runtime_reshape_size(nCols,internal::get_runtime_value(nRows),size()));
|
||||||
|
@@ -128,11 +128,13 @@
|
||||||
|
|
||||||
|
template<int Order>
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
-inline Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived, SizeAtCompileTime, 1, (Order==AutoOrder?Flags&RowMajorBit:Order)>
|
||||||
|
+inline Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived, SizeAtCompileTime, 1,
|
||||||
|
+ internal::get_compiletime_reshape_order<Flags,Order>::value>
|
||||||
|
reshaped() EIGEN_RESHAPED_METHOD_CONST
|
||||||
|
{
|
||||||
|
EIGEN_STATIC_ASSERT(Order==RowMajor || Order==ColMajor || Order==AutoOrder, INVALID_TEMPLATE_PARAMETER);
|
||||||
|
- return Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived, SizeAtCompileTime, 1, (Order==AutoOrder?Flags&RowMajorBit:Order)>
|
||||||
|
+ return Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived, SizeAtCompileTime, 1,
|
||||||
|
+ internal::get_compiletime_reshape_order<Flags,Order>::value>
|
||||||
|
(derived(), size(), 1);
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user