Work around an ODR violation. This template specialization is not always seen

and the compiler/linker can choose methods and/or struct instantiations from
either this specialization or the primary template.

1. The `blocks` method is marked static in one place and not in the other
meaning the arguments passed were shifted by 8 bytes due to the implicit
this-pointer.
2. The sizes of the structs in the specialization versus primary template were
different and so reading the values were reading uninitialized memory passed
the end of the struct.

Note: THIS DOES NOT FIX THE UNDERLYING ODR VIOLATION.
PiperOrigin-RevId: 321077691
Change-Id: I3998c19ed8983438001b6621719a18d032f4fafe
This commit is contained in:
A. Unique TensorFlower 2020-07-13 18:53:57 -07:00 committed by TensorFlower Gardener
parent 84ce94e5ea
commit 07daafc869

View File

@ -524,14 +524,7 @@ struct GemmKernelProvider<Eigen::QInt32, Eigen::QInt8, Eigen::QUInt8,
RhsMapper> { \
TensorContractionKernel(StorageIndex m, StorageIndex k, StorageIndex n, \
StorageIndex bm, StorageIndex bk, StorageIndex bn) \
: m(m), \
k(k), \
n(n), \
bm(bm), \
bk(bk), \
bn(bn), \
nm0(bm > 0 ? divup(m, bm) : 0), \
nn0(bn > 0 ? divup(n, bn) : 0) {} \
: m(m), k(k), n(n), bm(bm), bk(bk), bn(bn) {} \
\
enum { HasBeta = true }; \
\
@ -616,7 +609,8 @@ struct GemmKernelProvider<Eigen::QInt32, Eigen::QInt8, Eigen::QUInt8,
} \
\
template <typename Device> \
EIGEN_DEVICE_FUNC void deallocate(Device& d, BlockMemHandle handle) { \
EIGEN_DEVICE_FUNC static void deallocate(Device& d, \
BlockMemHandle handle) { \
BlockMemAllocator::deallocate(d, handle); \
} \
\
@ -626,7 +620,8 @@ struct GemmKernelProvider<Eigen::QInt32, Eigen::QInt8, Eigen::QUInt8,
if (UseCustomContractionKernels()) { \
const bool is_direct_access = \
DirectLhsAccess::value && \
DirectLhsAccess::block(data_mapper, rows, depth, nn0, lhsBlock); \
DirectLhsAccess::block(data_mapper, rows, depth, \
bn > 0 ? divup(n, bn) : 0, lhsBlock); \
\
if (!is_direct_access) { \
lhsBlock->is_direct_access = false; \
@ -645,7 +640,8 @@ struct GemmKernelProvider<Eigen::QInt32, Eigen::QInt8, Eigen::QUInt8,
if (UseCustomContractionKernels()) { \
const bool is_direct_access = \
DirectRhsAccess::value && \
DirectRhsAccess::block(data_mapper, depth, cols, nm0, rhsBlock); \
DirectRhsAccess::block(data_mapper, depth, cols, \
bm > 0 ? divup(m, bm) : 0, rhsBlock); \
\
if (!is_direct_access) { \
rhsBlock->is_direct_access = false; \
@ -723,9 +719,6 @@ struct GemmKernelProvider<Eigen::QInt32, Eigen::QInt8, Eigen::QUInt8,
const StorageIndex bm; \
const StorageIndex bk; \
const StorageIndex bn; \
/* Number of kernels for each dimension. */ \
const StorageIndex nm0; \
const StorageIndex nn0; \
}
// Tensor contraction kernel that do not fallback on Eigen. Currently not all
@ -740,14 +733,7 @@ struct GemmKernelProvider<Eigen::QInt32, Eigen::QInt8, Eigen::QUInt8,
RhsMapper> { \
TensorContractionKernel(StorageIndex m, StorageIndex k, StorageIndex n, \
StorageIndex bm, StorageIndex bk, StorageIndex bn) \
: m(m), \
k(k), \
n(n), \
bm(bm), \
bk(bk), \
bn(bn), \
nm0(bm > 0 ? divup(m, bm) : 0), \
nn0(bn > 0 ? divup(n, bn) : 0) {} \
: m(m), k(k), n(n), bm(bm), bk(bk), bn(bn) {} \
\
enum { HasBeta = true }; \
\
@ -818,7 +804,8 @@ struct GemmKernelProvider<Eigen::QInt32, Eigen::QInt8, Eigen::QUInt8,
} \
\
template <typename Device> \
EIGEN_DEVICE_FUNC void deallocate(Device& d, BlockMemHandle handle) { \
EIGEN_DEVICE_FUNC static void deallocate(Device& d, \
BlockMemHandle handle) { \
BlockMemAllocator::deallocate(d, handle); \
} \
\
@ -827,7 +814,8 @@ struct GemmKernelProvider<Eigen::QInt32, Eigen::QInt8, Eigen::QUInt8,
const StorageIndex depth, const StorageIndex rows) { \
const bool is_direct_access = \
DirectLhsAccess::value && \
DirectLhsAccess::block(data_mapper, rows, depth, nn0, lhsBlock); \
DirectLhsAccess::block(data_mapper, rows, depth, \
bn > 0 ? divup(n, bn) : 0, lhsBlock); \
\
if (!is_direct_access) { \
lhsBlock->is_direct_access = false; \
@ -840,7 +828,8 @@ struct GemmKernelProvider<Eigen::QInt32, Eigen::QInt8, Eigen::QUInt8,
const StorageIndex depth, const StorageIndex cols) { \
const bool is_direct_access = \
DirectRhsAccess::value && \
DirectRhsAccess::block(data_mapper, depth, cols, nm0, rhsBlock); \
DirectRhsAccess::block(data_mapper, depth, cols, \
bm > 0 ? divup(m, bm) : 0, rhsBlock); \
\
if (!is_direct_access) { \
rhsBlock->is_direct_access = false; \
@ -890,9 +879,6 @@ struct GemmKernelProvider<Eigen::QInt32, Eigen::QInt8, Eigen::QUInt8,
const StorageIndex bm; \
const StorageIndex bk; \
const StorageIndex bn; \
/* Number of kernels for each dimension. */ \
const StorageIndex nm0; \
const StorageIndex nn0; \
}
REGISTER_TENSOR_CONTRACTION_KERNEL_WITH_FALLBACK(float, float, float);