Work around an ODR violation. This template specialization is not always seen
and the compiler/linker can choose methods and/or struct instantiations from either this specialization or the primary template. 1. The `blocks` method is marked static in one place and not in the other meaning the arguments passed were shifted by 8 bytes due to the implicit this-pointer. 2. The sizes of the structs in the specialization versus primary template were different and so reading the values were reading uninitialized memory passed the end of the struct. Note: THIS DOES NOT FIX THE UNDERLYING ODR VIOLATION. PiperOrigin-RevId: 321077691 Change-Id: I3998c19ed8983438001b6621719a18d032f4fafe
This commit is contained in:
parent
84ce94e5ea
commit
07daafc869
@ -524,14 +524,7 @@ struct GemmKernelProvider<Eigen::QInt32, Eigen::QInt8, Eigen::QUInt8,
|
||||
RhsMapper> { \
|
||||
TensorContractionKernel(StorageIndex m, StorageIndex k, StorageIndex n, \
|
||||
StorageIndex bm, StorageIndex bk, StorageIndex bn) \
|
||||
: m(m), \
|
||||
k(k), \
|
||||
n(n), \
|
||||
bm(bm), \
|
||||
bk(bk), \
|
||||
bn(bn), \
|
||||
nm0(bm > 0 ? divup(m, bm) : 0), \
|
||||
nn0(bn > 0 ? divup(n, bn) : 0) {} \
|
||||
: m(m), k(k), n(n), bm(bm), bk(bk), bn(bn) {} \
|
||||
\
|
||||
enum { HasBeta = true }; \
|
||||
\
|
||||
@ -616,7 +609,8 @@ struct GemmKernelProvider<Eigen::QInt32, Eigen::QInt8, Eigen::QUInt8,
|
||||
} \
|
||||
\
|
||||
template <typename Device> \
|
||||
EIGEN_DEVICE_FUNC void deallocate(Device& d, BlockMemHandle handle) { \
|
||||
EIGEN_DEVICE_FUNC static void deallocate(Device& d, \
|
||||
BlockMemHandle handle) { \
|
||||
BlockMemAllocator::deallocate(d, handle); \
|
||||
} \
|
||||
\
|
||||
@ -626,7 +620,8 @@ struct GemmKernelProvider<Eigen::QInt32, Eigen::QInt8, Eigen::QUInt8,
|
||||
if (UseCustomContractionKernels()) { \
|
||||
const bool is_direct_access = \
|
||||
DirectLhsAccess::value && \
|
||||
DirectLhsAccess::block(data_mapper, rows, depth, nn0, lhsBlock); \
|
||||
DirectLhsAccess::block(data_mapper, rows, depth, \
|
||||
bn > 0 ? divup(n, bn) : 0, lhsBlock); \
|
||||
\
|
||||
if (!is_direct_access) { \
|
||||
lhsBlock->is_direct_access = false; \
|
||||
@ -645,7 +640,8 @@ struct GemmKernelProvider<Eigen::QInt32, Eigen::QInt8, Eigen::QUInt8,
|
||||
if (UseCustomContractionKernels()) { \
|
||||
const bool is_direct_access = \
|
||||
DirectRhsAccess::value && \
|
||||
DirectRhsAccess::block(data_mapper, depth, cols, nm0, rhsBlock); \
|
||||
DirectRhsAccess::block(data_mapper, depth, cols, \
|
||||
bm > 0 ? divup(m, bm) : 0, rhsBlock); \
|
||||
\
|
||||
if (!is_direct_access) { \
|
||||
rhsBlock->is_direct_access = false; \
|
||||
@ -723,9 +719,6 @@ struct GemmKernelProvider<Eigen::QInt32, Eigen::QInt8, Eigen::QUInt8,
|
||||
const StorageIndex bm; \
|
||||
const StorageIndex bk; \
|
||||
const StorageIndex bn; \
|
||||
/* Number of kernels for each dimension. */ \
|
||||
const StorageIndex nm0; \
|
||||
const StorageIndex nn0; \
|
||||
}
|
||||
|
||||
// Tensor contraction kernel that do not fallback on Eigen. Currently not all
|
||||
@ -740,14 +733,7 @@ struct GemmKernelProvider<Eigen::QInt32, Eigen::QInt8, Eigen::QUInt8,
|
||||
RhsMapper> { \
|
||||
TensorContractionKernel(StorageIndex m, StorageIndex k, StorageIndex n, \
|
||||
StorageIndex bm, StorageIndex bk, StorageIndex bn) \
|
||||
: m(m), \
|
||||
k(k), \
|
||||
n(n), \
|
||||
bm(bm), \
|
||||
bk(bk), \
|
||||
bn(bn), \
|
||||
nm0(bm > 0 ? divup(m, bm) : 0), \
|
||||
nn0(bn > 0 ? divup(n, bn) : 0) {} \
|
||||
: m(m), k(k), n(n), bm(bm), bk(bk), bn(bn) {} \
|
||||
\
|
||||
enum { HasBeta = true }; \
|
||||
\
|
||||
@ -818,7 +804,8 @@ struct GemmKernelProvider<Eigen::QInt32, Eigen::QInt8, Eigen::QUInt8,
|
||||
} \
|
||||
\
|
||||
template <typename Device> \
|
||||
EIGEN_DEVICE_FUNC void deallocate(Device& d, BlockMemHandle handle) { \
|
||||
EIGEN_DEVICE_FUNC static void deallocate(Device& d, \
|
||||
BlockMemHandle handle) { \
|
||||
BlockMemAllocator::deallocate(d, handle); \
|
||||
} \
|
||||
\
|
||||
@ -827,7 +814,8 @@ struct GemmKernelProvider<Eigen::QInt32, Eigen::QInt8, Eigen::QUInt8,
|
||||
const StorageIndex depth, const StorageIndex rows) { \
|
||||
const bool is_direct_access = \
|
||||
DirectLhsAccess::value && \
|
||||
DirectLhsAccess::block(data_mapper, rows, depth, nn0, lhsBlock); \
|
||||
DirectLhsAccess::block(data_mapper, rows, depth, \
|
||||
bn > 0 ? divup(n, bn) : 0, lhsBlock); \
|
||||
\
|
||||
if (!is_direct_access) { \
|
||||
lhsBlock->is_direct_access = false; \
|
||||
@ -840,7 +828,8 @@ struct GemmKernelProvider<Eigen::QInt32, Eigen::QInt8, Eigen::QUInt8,
|
||||
const StorageIndex depth, const StorageIndex cols) { \
|
||||
const bool is_direct_access = \
|
||||
DirectRhsAccess::value && \
|
||||
DirectRhsAccess::block(data_mapper, depth, cols, nm0, rhsBlock); \
|
||||
DirectRhsAccess::block(data_mapper, depth, cols, \
|
||||
bm > 0 ? divup(m, bm) : 0, rhsBlock); \
|
||||
\
|
||||
if (!is_direct_access) { \
|
||||
rhsBlock->is_direct_access = false; \
|
||||
@ -890,9 +879,6 @@ struct GemmKernelProvider<Eigen::QInt32, Eigen::QInt8, Eigen::QUInt8,
|
||||
const StorageIndex bm; \
|
||||
const StorageIndex bk; \
|
||||
const StorageIndex bn; \
|
||||
/* Number of kernels for each dimension. */ \
|
||||
const StorageIndex nm0; \
|
||||
const StorageIndex nn0; \
|
||||
}
|
||||
|
||||
REGISTER_TENSOR_CONTRACTION_KERNEL_WITH_FALLBACK(float, float, float);
|
||||
|
Loading…
Reference in New Issue
Block a user