Ruy: Introduce x86 (AVX-512) code.
PiperOrigin-RevId: 260750932
Parent: 338de90a40
Commit: 72867cc5bc
@ -264,11 +264,13 @@ cc_library(
srcs = [
"kernel_arm32.cc",
"kernel_arm64.cc",
"kernel_avx512.cc",
],
hdrs = [
"kernel.h",
],
deps = [
":check_macros",
":common",
":internal_matrix",
":opt_set",
@ -287,6 +289,7 @@ cc_library(
name = "pack",
srcs = [
"pack_arm.cc",
"pack_avx512.cc",
],
hdrs = [
"pack.h",
@ -19,6 +19,7 @@ limitations under the License.
#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_COMMON_H_

#include <atomic>
#include <cstdint>
#include <limits>
#include <type_traits>

@ -28,6 +29,11 @@ limitations under the License.
#include "tensorflow/lite/experimental/ruy/path.h"
#include "tensorflow/lite/experimental/ruy/platform.h"

// TODO(b/138449463): also guard by RUY_OPT_ENABLED(RUY_OPT_INTRINSICS).
#if RUY_PLATFORM(AVX512)
#include <immintrin.h>
#endif

#if (RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32))
#include <arm_neon.h>
#endif
@ -211,12 +211,18 @@ struct Kernel<Path::kStandardCpp, LhsScalar, RhsScalar, DstScalar, Spec> {
: Kernel<PARENT, LhsScalar, RhsScalar, DstScalar, Spec>(tuning) {} \
};

#if RUY_PLATFORM(NEON)
RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kNeon)
RUY_INHERIT_KERNEL(Path::kNeon, Path::kNeonDotprod)
#elif RUY_PLATFORM(AVX512)
RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kAvx512)
#endif

// KernelParams are shared across 32-bit and 64-bit NEON code.
#if (RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) && \
(RUY_OPT_ENABLED(RUY_OPT_ASM))
// KernelParams are shared across 32-bit and 64-bit NEON code, and x86 AVX-512
// code.
#if (RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32) || \
RUY_PLATFORM(AVX512)) && \
RUY_OPT_ENABLED(RUY_OPT_ASM)

#define RUY_ASM_FLAG_HAS_BIAS 0x1
#define RUY_ASM_FLAG_HAS_LHS_SUMS 0x2
@ -362,10 +368,12 @@ void MakeKernelParams8bit(const PackedMatrix<std::int8_t>& lhs,
|
||||
dst->data.get() + start_col * dst->layout.stride + start_row;
|
||||
}
|
||||
|
||||
#if RUY_PLATFORM(NEON)
|
||||
void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 4>& params);
|
||||
void Kernel8bitNeonInOrder(const KernelParams8bit<4, 4>& params);
|
||||
void Kernel8bitNeonDotprodOutOfOrder(const KernelParams8bit<8, 8>& params);
|
||||
void Kernel8bitNeonDotprodInOrder(const KernelParams8bit<8, 8>& params);
|
||||
#endif
|
||||
|
||||
#if RUY_PLATFORM(NEON_64)
|
||||
template <typename DstScalar>
|
||||
@ -482,10 +490,12 @@ inline void MakeKernelParamsFloat(const PackedMatrix<float>& lhs,
|
||||
RUY_DCHECK_LT(params->last_col, params->dst_cols);
|
||||
}
|
||||
|
||||
#if RUY_PLATFORM(NEON)
|
||||
void KernelFloatNeonOutOfOrder(const KernelParamsFloat<8, 8>& params);
|
||||
void KernelFloatNeonInOrder(const KernelParamsFloat<8, 8>& params);
|
||||
void KernelFloat32NeonOutOfOrder(const KernelParamsFloat<8, 4>& params);
|
||||
void KernelFloatNeonDotprodInOrder(const KernelParamsFloat<8, 8>& params);
|
||||
#endif
|
||||
|
||||
#if RUY_PLATFORM(NEON_64)
|
||||
// A Float kernel for ARM64 Neon.
|
||||
@ -531,6 +541,7 @@ struct Kernel<Path::kNeon, float, float, float, BasicSpec<float, float>> {
|
||||
};
|
||||
#endif
|
||||
|
||||
#if RUY_PLATFORM(NEON)
|
||||
// While the dotprod NEON extension does not concern floating-point arithmetic,
|
||||
// its presence allows us to distinguish, in the in-order tuning case, between
|
||||
// A53 and A55r1. TODO: should this be folded into tuning?
|
||||
@ -556,9 +567,53 @@ struct Kernel<Path::kNeonDotprod, float, float, float,
|
||||
}
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#if RUY_PLATFORM(AVX512)
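// The AVX-512 kernels below operate on 16x16 destination blocks (vs. 8x8 for
// the NEON dotprod path): the 8-bit kernel consumes LHS/RHS packed as
// col-major 4x16 cells, and the float kernel as row-major 1x16 cells, as
// declared in the specializations that follow.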
|
||||
void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params);
|
||||
|
||||
template <typename DstScalar>
|
||||
struct Kernel<Path::kAvx512, std::int8_t, std::int8_t, DstScalar,
|
||||
BasicSpec<std::int32_t, DstScalar>> {
|
||||
Tuning tuning = Tuning::kAuto;
|
||||
using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
|
||||
using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
|
||||
explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
|
||||
void Run(const PackedMatrix<std::int8_t>& lhs,
|
||||
const PackedMatrix<std::int8_t>& rhs,
|
||||
const BasicSpec<std::int32_t, DstScalar>& spec, int start_row,
|
||||
int start_col, int end_row, int end_col,
|
||||
Matrix<DstScalar>* dst) const {
|
||||
KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
|
||||
MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col,
|
||||
dst, &params);
|
||||
Kernel8bitAvx512(params);
|
||||
}
|
||||
};
|
||||
|
||||
void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params);
|
||||
|
||||
template <>
|
||||
struct Kernel<Path::kAvx512, float, float, float, BasicSpec<float, float>> {
|
||||
Tuning tuning = Tuning::kAuto;
|
||||
using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
|
||||
using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
|
||||
explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
|
||||
void Run(const PackedMatrix<float>& lhs, const PackedMatrix<float>& rhs,
|
||||
const BasicSpec<float, float>& spec, int start_row, int start_col,
|
||||
int end_row, int end_col, Matrix<float>* dst) const {
|
||||
KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
|
||||
MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row,
|
||||
end_col, dst, &params);
|
||||
KernelFloatAvx512(params);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#endif // (RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32) || \
|
||||
// RUY_PLATFORM(AVX512)) && \
|
||||
// RUY_OPT_ENABLED(RUY_OPT_ASM)
|
||||
|
||||
#endif // (RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) &&
|
||||
// (RUY_OPT_ENABLED(RUY_OPT_ASM)
|
||||
} // namespace ruy
|
||||
|
||||
#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_H_
|
||||
|
805
tensorflow/lite/experimental/ruy/kernel_avx512.cc
Normal file
@ -0,0 +1,805 @@
|
||||
/* Copyright 2019 Google LLC. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "profiling/instrumentation.h"
|
||||
#include "tensorflow/lite/experimental/ruy/check_macros.h"
|
||||
#include "tensorflow/lite/experimental/ruy/kernel.h"
|
||||
#include "tensorflow/lite/experimental/ruy/platform.h"
|
||||
|
||||
namespace ruy {
|
||||
|
||||
#if RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
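// Per-lane helpers for __m512i values. AVX-512 has no single intrinsic that
// extracts or inserts one 32-bit lane at a runtime index, so
// mm512_get1_epi32 dispatches over the 16 compile-time extract positions and
// mm512_set1_epi32 writes one lane through a one-hot mask. They support the
// straightforward intrinsics code in this file.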
|
||||
|
||||
inline std::int32_t mm512_get1_epi32(const __m512i v, int i) {
|
||||
__m256i a =
|
||||
i < 8 ? _mm512_extracti32x8_epi32(v, 0) : _mm512_extracti32x8_epi32(v, 1);
|
||||
switch (i & ~8) {
|
||||
case 0:
|
||||
return _mm256_extract_epi32(a, 0);
|
||||
case 1:
|
||||
return _mm256_extract_epi32(a, 1);
|
||||
case 2:
|
||||
return _mm256_extract_epi32(a, 2);
|
||||
case 3:
|
||||
return _mm256_extract_epi32(a, 3);
|
||||
case 4:
|
||||
return _mm256_extract_epi32(a, 4);
|
||||
case 5:
|
||||
return _mm256_extract_epi32(a, 5);
|
||||
case 6:
|
||||
return _mm256_extract_epi32(a, 6);
|
||||
case 7:
|
||||
return _mm256_extract_epi32(a, 7);
|
||||
default:
|
||||
RUY_DCHECK(i < 16);
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
inline __m512i mm512_set1_epi32(__m512i* v, int i, std::int32_t x) {
|
||||
return *v = _mm512_mask_set1_epi32(*v, 1 << i, x);
|
||||
}
|
||||
|
||||
void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
|
||||
gemmlowp::ScopedProfilingLabel label("Kernel kAvx512");
|
||||
|
||||
std::int32_t dst_stride;
|
||||
if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) ||
|
||||
(params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) {
|
||||
dst_stride = params.dst_stride;
|
||||
} else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
|
||||
dst_stride = params.dst_stride / sizeof(std::int16_t);
|
||||
} else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
|
||||
dst_stride = params.dst_stride / sizeof(std::int32_t);
|
||||
} else {
|
||||
RUY_DCHECK(false);
|
||||
}
|
||||
|
||||
int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 16 : 0;
|
||||
|
||||
const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
|
||||
void* dst_col_ptr = params.dst_base_ptr;
|
||||
const std::int32_t* bias_col_ptr = params.bias;
|
||||
if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
|
||||
bias_col_ptr += params.start_row;
|
||||
}
|
||||
|
||||
for (int col = params.start_col; col <= params.last_col; col += 16) {
|
||||
const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
|
||||
void* dst_ptr = dst_col_ptr;
|
||||
const std::int32_t* bias_ptr = bias_col_ptr;
|
||||
|
||||
for (int row = params.start_row; row <= params.last_row; row += 16) {
|
||||
const int residual_rows = std::min(params.dst_rows - row, 16);
|
||||
const int residual_cols = std::min(params.dst_cols - col, 16);
|
||||
|
||||
__m512i accum_data_v[16];
|
||||
__m512i accum_data_v_low[16];
|
||||
__m512i accum_data_v_high[16];
|
||||
|
||||
// Initialize with bias.
|
||||
const __mmask16 row_mask =
|
||||
(static_cast<std::uint32_t>(1) << residual_rows) - 1;
|
||||
const __m512i initial_accum_data =
|
||||
_mm512_maskz_loadu_epi32(row_mask, bias_ptr);
|
||||
__m512i initial_accum_data_low = initial_accum_data;
|
||||
__m512i initial_accum_data_high = _mm512_setzero_epi32();
|
||||
bias_ptr += bias_ptr_block_increment;
|
||||
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
accum_data_v_low[j] = initial_accum_data_low;
|
||||
accum_data_v_high[j] = initial_accum_data_high;
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
const std::int8_t* lhs_ptr = lhs_col_ptr;
|
||||
const std::int8_t* rhs_ptr = rhs_col_ptr;
|
||||
for (int d = 0; d < params.depth; d += 4) {
|
||||
const __m512i lhs_data = _mm512_loadu_epi8(lhs_ptr);
|
||||
__m512i rhs_data = _mm512_loadu_epi8(rhs_ptr);
|
||||
|
||||
// Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit.
|
||||
__m512i lhs_16_bit_low =
|
||||
_mm512_cvtepi8_epi16(_mm512_cvtepi32_epi16(lhs_data));
|
||||
// Take bytes 2, 3, 6, 7, 10, 11, ... and expand to 16-bit.
|
||||
__m512i lhs_16_bit_high = _mm512_cvtepi8_epi16(
|
||||
_mm512_cvtepi32_epi16(_mm512_srli_epi32(lhs_data, 16)));
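// Each 64-byte load covers 16 destination rows x 4 depth levels, with a
// row's 4 depth values stored contiguously. The two conversions above split
// those into (depth 0,1) and (depth 2,3) int16 pairs per row. In the loop
// below, the matching byte pair of one RHS column is broadcast and
// sign-extended the same way, so _mm512_madd_epi16 multiplies the pairs and
// sums them into a 32-bit partial dot product per row; the low and high
// depth pairs are accumulated separately and combined after the depth loop.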
|
||||
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
// Mask that drops the 0th element.
|
||||
static constexpr std::uint16_t shift_mask = 0xfffe;
|
||||
const __m256i dup_rhs_element_low =
|
||||
_mm256_broadcastw_epi16(_mm512_castsi512_si128(rhs_data));
|
||||
// Shift rhs_data, moving next element into 0 position.
|
||||
const __m256i dup_rhs_element_high = _mm256_set1_epi16(
|
||||
_mm_extract_epi16(_mm512_castsi512_si128(rhs_data), 1));
|
||||
// Shift rhs_data, moving next element into 0 position.
|
||||
rhs_data = _mm512_maskz_compress_epi32(shift_mask, rhs_data);
|
||||
|
||||
__m512i rhs_16_bit_dup_low =
|
||||
_mm512_cvtepi8_epi16(dup_rhs_element_low);
|
||||
__m512i rhs_16_bit_dup_high =
|
||||
_mm512_cvtepi8_epi16(dup_rhs_element_high);
|
||||
|
||||
accum_data_v_low[j] = _mm512_add_epi32(
|
||||
accum_data_v_low[j],
|
||||
_mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
|
||||
accum_data_v_high[j] = _mm512_add_epi32(
|
||||
accum_data_v_high[j],
|
||||
_mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
|
||||
}
|
||||
|
||||
lhs_ptr += 16 * 4;
|
||||
rhs_ptr += 16 * 4;
|
||||
}
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
accum_data_v[j] =
|
||||
_mm512_add_epi32(accum_data_v_low[j], accum_data_v_high[j]);
|
||||
}
|
||||
|
||||
// Move most of this up to bias, or even outside row loop.
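// Zero-point corrections for quantized matmul: expanding
// sum_k (lhs[r][k] - lhs_zp) * (rhs[k][c] - rhs_zp) gives the raw int32
// accumulator minus rhs_zp * lhs_sums[r], minus lhs_zp * rhs_sums[c], plus
// lhs_zp * rhs_zp * depth. prod_zp_depth carries that last term, and the
// lhs_sums / rhs_sums arrays come from the packing stage.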
|
||||
|
||||
const std::int32_t lhs_zero_point = params.lhs_zero_point;
|
||||
const std::int32_t rhs_zero_point = params.rhs_zero_point;
|
||||
const std::int32_t prod_zp_depth = params.prod_zp_depth;
|
||||
if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
|
||||
const __m512i lhs_sums_offset =
|
||||
_mm512_mullo_epi32(_mm512_set1_epi32(rhs_zero_point),
|
||||
_mm512_loadu_epi32(&params.lhs_sums[row]));
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
accum_data_v[j] = _mm512_sub_epi32(accum_data_v[j], lhs_sums_offset);
|
||||
}
|
||||
}
|
||||
if (((params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point) ||
|
||||
prod_zp_depth) {
|
||||
__m512i non_lhs_sums_offset =
|
||||
_mm512_mullo_epi32(_mm512_set1_epi32(lhs_zero_point),
|
||||
_mm512_loadu_epi32(&params.rhs_sums[col]));
|
||||
non_lhs_sums_offset = _mm512_sub_epi32(
|
||||
non_lhs_sums_offset, _mm512_set1_epi32(prod_zp_depth));
|
||||
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
accum_data_v[j] = _mm512_sub_epi32(
|
||||
accum_data_v[j],
|
||||
_mm512_set1_epi32(mm512_get1_epi32(non_lhs_sums_offset, j)));
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
|
||||
__m512i m_vector;
|
||||
__m512i e_vector;
|
||||
// Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
|
||||
if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) {
|
||||
m_vector = _mm512_maskz_loadu_epi32(
|
||||
row_mask, &params.multiplier_fixedpoint[row]);
|
||||
e_vector = _mm512_maskz_loadu_epi32(row_mask,
|
||||
&params.multiplier_exponent[row]);
|
||||
} else {
|
||||
// These arrays have size LhsCols, and are pre-filled.
|
||||
m_vector =
|
||||
_mm512_maskz_loadu_epi32(row_mask, params.multiplier_fixedpoint);
|
||||
e_vector =
|
||||
_mm512_maskz_loadu_epi32(row_mask, params.multiplier_exponent);
|
||||
}
|
||||
|
||||
const __m512i m_64bit_low =
|
||||
_mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 0));
|
||||
const __m512i m_64bit_high =
|
||||
_mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 1));
|
||||
|
||||
const __m512i zero_vector = _mm512_setzero_epi32();
|
||||
const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector);
|
||||
const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector);
|
||||
const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector);
|
||||
const __m512i final_right_shift =
|
||||
_mm512_add_epi32(right_shift, _mm512_set1_epi32(31));
|
||||
const __m512i final_right_shift_low = _mm512_cvtepi32_epi64(
|
||||
_mm512_extracti32x8_epi32(final_right_shift, 0));
|
||||
const __m512i final_right_shift_high = _mm512_cvtepi32_epi64(
|
||||
_mm512_extracti32x8_epi32(final_right_shift, 1));
|
||||
|
||||
const __m512i offset_vector =
|
||||
_mm512_slli_epi64(_mm512_set1_epi64(1), 30);
|
||||
// Really these should be shifted by neg_e_vector, but tests pass when
|
||||
// using right_shift.
|
||||
const __m512i offset_vector_low = _mm512_sllv_epi64(
|
||||
offset_vector,
|
||||
_mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0)));
|
||||
const __m512i offset_vector_high = _mm512_sllv_epi64(
|
||||
offset_vector,
|
||||
_mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1)));
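// The per-channel multiplier is applied as a Q31 fixed-point value:
// dst ~= round((acc << left_shift) * m / 2^(31 + right_shift)), where the
// exponent e splits into left_shift = max(e, 0) and right_shift = max(-e, 0),
// and the (1 << 30) offset, pre-shifted above, makes the final arithmetic
// right shift round to nearest.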
|
||||
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
accum_data_v[j] = _mm512_sllv_epi32(accum_data_v[j], left_shift);
|
||||
// Apply the fixed-point part of the multiplier.
|
||||
__m512i scaled_v_low =
|
||||
_mm512_mul_epi32(_mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(
|
||||
accum_data_v[j], 0)),
|
||||
m_64bit_low);
|
||||
__m512i scaled_v_high =
|
||||
_mm512_mul_epi32(_mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(
|
||||
accum_data_v[j], 1)),
|
||||
m_64bit_high);
|
||||
|
||||
scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low);
|
||||
scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high);
|
||||
|
||||
scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
|
||||
scaled_v_high =
|
||||
_mm512_srav_epi64(scaled_v_high, final_right_shift_high);
|
||||
|
||||
accum_data_v[j] =
|
||||
_mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
|
||||
accum_data_v[j] = _mm512_inserti32x8(
|
||||
accum_data_v[j], _mm512_cvtepi64_epi32(scaled_v_high), 1);
|
||||
|
||||
#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)
|
||||
RUY_DCHECK(false);
|
||||
#endif
|
||||
}
|
||||
|
||||
if (params.dst_zero_point) {
|
||||
__m512i dst_zero_point = _mm512_set1_epi32(params.dst_zero_point);
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
accum_data_v[j] = _mm512_add_epi32(accum_data_v[j], dst_zero_point);
|
||||
}
|
||||
}
|
||||
__m512i clamp_max_v = _mm512_set1_epi32(params.clamp_max);
|
||||
__m512i clamp_min_v = _mm512_set1_epi32(params.clamp_min);
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
accum_data_v[j] = _mm512_min_epi32(accum_data_v[j], clamp_max_v);
|
||||
accum_data_v[j] = _mm512_max_epi32(accum_data_v[j], clamp_min_v);
|
||||
}
|
||||
}
|
||||
const bool store_full_block =
|
||||
(residual_rows == 16) && (residual_cols == 16);
|
||||
|
||||
if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
|
||||
std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
|
||||
const int block_col_offset = dst_stride;
|
||||
if (store_full_block) {
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
_mm_storeu_epi8(tmp_ptr, _mm512_cvtepi32_epi8(accum_data_v[j]));
|
||||
tmp_ptr += block_col_offset;
|
||||
}
|
||||
} else {
|
||||
for (int j = 0; j < residual_cols; ++j) {
|
||||
_mm_mask_storeu_epi8(tmp_ptr, row_mask,
|
||||
_mm512_cvtepi32_epi8(accum_data_v[j]));
|
||||
tmp_ptr += block_col_offset;
|
||||
}
|
||||
}
|
||||
dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + 16);
|
||||
} else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
|
||||
std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
|
||||
const int block_col_offset = dst_stride;
|
||||
if (store_full_block) {
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
_mm_storeu_epi8(tmp_ptr, _mm512_cvtepi32_epi8(accum_data_v[j]));
|
||||
tmp_ptr += block_col_offset;
|
||||
}
|
||||
} else {
|
||||
for (int j = 0; j < residual_cols; ++j) {
|
||||
_mm_mask_storeu_epi8(tmp_ptr, row_mask,
|
||||
_mm512_cvtepi32_epi8(accum_data_v[j]));
|
||||
tmp_ptr += block_col_offset;
|
||||
}
|
||||
}
|
||||
dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + 16);
|
||||
} else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
|
||||
std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
|
||||
const int block_col_offset = dst_stride;
|
||||
if (store_full_block) {
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
_mm256_storeu_epi16(tmp_ptr,
|
||||
_mm512_cvtepi32_epi16(accum_data_v[j]));
|
||||
tmp_ptr += block_col_offset;
|
||||
}
|
||||
} else {
|
||||
for (int j = 0; j < residual_cols; ++j) {
|
||||
_mm256_mask_storeu_epi16(tmp_ptr, row_mask,
|
||||
_mm512_cvtepi32_epi16(accum_data_v[j]));
|
||||
tmp_ptr += block_col_offset;
|
||||
}
|
||||
}
|
||||
dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + 16);
|
||||
} else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
|
||||
if (store_full_block) {
|
||||
std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
|
||||
const int block_col_offset = dst_stride;
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
_mm512_storeu_epi32(tmp_ptr, accum_data_v[j]);
|
||||
tmp_ptr += block_col_offset;
|
||||
}
|
||||
} else {
|
||||
std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr);
|
||||
for (int j = 0; j < residual_cols; ++j) {
|
||||
_mm512_mask_storeu_epi32(dst_block_ptr, row_mask, accum_data_v[j]);
|
||||
dst_block_ptr += dst_stride;
|
||||
}
|
||||
}
|
||||
dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + 16);
|
||||
} else {
|
||||
RUY_DCHECK(false);
|
||||
}
|
||||
|
||||
lhs_col_ptr += 16 * params.lhs_stride;
|
||||
} // End row-block loop.
|
||||
|
||||
dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
|
||||
16 * params.dst_stride);
|
||||
rhs_col_ptr += 16 * params.rhs_stride;
|
||||
} // End col-block loop.
|
||||
}
|
||||
|
||||
void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) {
|
||||
gemmlowp::ScopedProfilingLabel label("Kernel kAvx512");
|
||||
RUY_DCHECK_EQ(16, 16);
|
||||
|
||||
// As parameters are defined, we need to scale by sizeof(float).
|
||||
const std::int64_t lhs_stride = params.lhs_stride >> 2;
|
||||
const std::int64_t dst_stride = params.dst_stride >> 2;
|
||||
const std::int64_t rhs_stride = params.rhs_stride >> 2;
|
||||
|
||||
int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
|
||||
const int end_row = std::min(params.dst_rows, params.last_row + 16);
|
||||
const int end_col = std::min(params.dst_cols, params.last_col + 16);
|
||||
|
||||
const float* adj_rhs_col_ptr =
|
||||
params.rhs_base_ptr - params.start_col * rhs_stride;
|
||||
float* adj_dst_col_ptr =
|
||||
params.dst_base_ptr - params.start_col * dst_stride - params.start_row;
|
||||
const float* adj_lhs_col_ptr =
|
||||
params.lhs_base_ptr - params.start_row * lhs_stride;
|
||||
const float* bias_col_ptr = params.bias;
|
||||
|
||||
const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max);
|
||||
const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min);
|
||||
|
||||
int col = params.start_col;
|
||||
for (; col <= end_col - 16; col += 16) {
|
||||
const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
|
||||
float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
|
||||
|
||||
int row = params.start_row;
|
||||
for (; row <= end_row - 16; row += 16) {
|
||||
const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
|
||||
float* dst_ptr = dst_col_ptr + row;
|
||||
const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
|
||||
|
||||
// Initialize with bias.
|
||||
const __m512 initial_accum_data = _mm512_loadu_ps(bias_ptr);
|
||||
|
||||
// Process block in two halves, split by columns.
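// Each half handles 8 of the 16 destination columns: one _mm256_loadu_ps
// grabs that half's 8 RHS values at the current depth, each value is
// broadcast with _mm512_set1_ps, and one FMA updates the corresponding
// column's 16-row accumulator. The rhs pointer still advances by 16 per
// depth step because the packed RHS stores all 16 columns per depth level.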
|
||||
{
|
||||
constexpr int mmm = 0;
|
||||
|
||||
__m512 accum_data_v0 = initial_accum_data;
|
||||
__m512 accum_data_v1 = initial_accum_data;
|
||||
__m512 accum_data_v2 = initial_accum_data;
|
||||
__m512 accum_data_v3 = initial_accum_data;
|
||||
__m512 accum_data_v4 = initial_accum_data;
|
||||
__m512 accum_data_v5 = initial_accum_data;
|
||||
__m512 accum_data_v6 = initial_accum_data;
|
||||
__m512 accum_data_v7 = initial_accum_data;
|
||||
|
||||
const float* lhs_ptr = lhs_col_ptr;
|
||||
const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
|
||||
for (int d = 0; d < (params.depth - 1); ++d) {
|
||||
const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
|
||||
const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
|
||||
lhs_ptr += 16;
|
||||
rhs_ptr += 16;
|
||||
|
||||
{
|
||||
const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
|
||||
accum_data_v0 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
|
||||
const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
|
||||
accum_data_v1 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
|
||||
const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
|
||||
accum_data_v2 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
|
||||
const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
|
||||
accum_data_v3 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
|
||||
const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
|
||||
accum_data_v4 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
|
||||
const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
|
||||
accum_data_v5 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
|
||||
const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
|
||||
accum_data_v6 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
|
||||
const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
|
||||
accum_data_v7 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
|
||||
}
|
||||
}
|
||||
{
|
||||
const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
|
||||
const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
|
||||
{
|
||||
const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
|
||||
accum_data_v0 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
|
||||
const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
|
||||
accum_data_v1 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
|
||||
const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
|
||||
accum_data_v2 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
|
||||
const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
|
||||
accum_data_v3 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
|
||||
const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
|
||||
accum_data_v4 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
|
||||
const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
|
||||
accum_data_v5 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
|
||||
const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
|
||||
accum_data_v6 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
|
||||
const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
|
||||
accum_data_v7 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
|
||||
}
|
||||
{
|
||||
float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
|
||||
accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
|
||||
accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
|
||||
_mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0);
|
||||
accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
|
||||
accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
|
||||
_mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1);
|
||||
accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
|
||||
accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
|
||||
_mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2);
|
||||
accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
|
||||
accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
|
||||
_mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3);
|
||||
accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
|
||||
accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
|
||||
_mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4);
|
||||
accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
|
||||
accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
|
||||
_mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5);
|
||||
accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
|
||||
accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
|
||||
_mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6);
|
||||
accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
|
||||
accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
|
||||
_mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7);
|
||||
}
|
||||
}
|
||||
} // Inner half-block loop, unrolled, first iteration.
|
||||
{
|
||||
constexpr int mmm = 1;
|
||||
|
||||
__m512 accum_data_v0 = initial_accum_data;
|
||||
__m512 accum_data_v1 = initial_accum_data;
|
||||
__m512 accum_data_v2 = initial_accum_data;
|
||||
__m512 accum_data_v3 = initial_accum_data;
|
||||
__m512 accum_data_v4 = initial_accum_data;
|
||||
__m512 accum_data_v5 = initial_accum_data;
|
||||
__m512 accum_data_v6 = initial_accum_data;
|
||||
__m512 accum_data_v7 = initial_accum_data;
|
||||
|
||||
const float* lhs_ptr = lhs_col_ptr;
|
||||
const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
|
||||
for (int d = 0; d < (params.depth - 1); ++d) {
|
||||
const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
|
||||
const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
|
||||
lhs_ptr += 16;
|
||||
rhs_ptr += 16;
|
||||
{
|
||||
const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
|
||||
accum_data_v0 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
|
||||
const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
|
||||
accum_data_v1 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
|
||||
const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
|
||||
accum_data_v2 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
|
||||
const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
|
||||
accum_data_v3 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
|
||||
const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
|
||||
accum_data_v4 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
|
||||
const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
|
||||
accum_data_v5 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
|
||||
const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
|
||||
accum_data_v6 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
|
||||
const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
|
||||
accum_data_v7 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
|
||||
}
|
||||
}
|
||||
{
|
||||
const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
|
||||
const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
|
||||
{
|
||||
const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
|
||||
accum_data_v0 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
|
||||
const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
|
||||
accum_data_v1 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
|
||||
const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
|
||||
accum_data_v2 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
|
||||
const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
|
||||
accum_data_v3 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
|
||||
const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
|
||||
accum_data_v4 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
|
||||
const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
|
||||
accum_data_v5 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
|
||||
const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
|
||||
accum_data_v6 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
|
||||
const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
|
||||
accum_data_v7 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
|
||||
}
|
||||
{
|
||||
float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
|
||||
accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
|
||||
accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
|
||||
_mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0);
|
||||
accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
|
||||
accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
|
||||
_mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1);
|
||||
accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
|
||||
accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
|
||||
_mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2);
|
||||
accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
|
||||
accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
|
||||
_mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3);
|
||||
accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
|
||||
accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
|
||||
_mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4);
|
||||
accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
|
||||
accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
|
||||
_mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5);
|
||||
accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
|
||||
accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
|
||||
_mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6);
|
||||
accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
|
||||
accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
|
||||
_mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7);
|
||||
}
|
||||
}
|
||||
} // Inner half-block loop, unrolled, second iteration.
|
||||
} // End row-block loop.
|
||||
|
||||
// The unrolling within this conditional may be somewhat pointless. It
|
||||
// depends on the kinds of models.
|
||||
if (row < end_row) {
|
||||
const int residual_rows = end_row - row;
|
||||
|
||||
const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
|
||||
float* dst_ptr = dst_col_ptr + row;
|
||||
const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
|
||||
|
||||
// Initialize with bias.
|
||||
const __mmask16 row_mask =
|
||||
(static_cast<std::uint32_t>(1) << residual_rows) - 1;
|
||||
const __m512 initial_accum_data =
|
||||
_mm512_maskz_loadu_ps(row_mask, bias_ptr);
|
||||
|
||||
// Process block in two halves, split by columns.
|
||||
for (int mmm = 0; mmm < 2; ++mmm) {
|
||||
__m512 accum_data_v0 = initial_accum_data;
|
||||
__m512 accum_data_v1 = initial_accum_data;
|
||||
__m512 accum_data_v2 = initial_accum_data;
|
||||
__m512 accum_data_v3 = initial_accum_data;
|
||||
__m512 accum_data_v4 = initial_accum_data;
|
||||
__m512 accum_data_v5 = initial_accum_data;
|
||||
__m512 accum_data_v6 = initial_accum_data;
|
||||
__m512 accum_data_v7 = initial_accum_data;
|
||||
|
||||
const float* lhs_ptr = lhs_col_ptr;
|
||||
const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
|
||||
for (int d = 0; d < (params.depth - 1); ++d) {
|
||||
const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
|
||||
const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
|
||||
lhs_ptr += 16;
|
||||
rhs_ptr += 16;
|
||||
{
|
||||
const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
|
||||
accum_data_v0 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
|
||||
const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
|
||||
accum_data_v1 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
|
||||
const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
|
||||
accum_data_v2 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
|
||||
const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
|
||||
accum_data_v3 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
|
||||
const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
|
||||
accum_data_v4 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
|
||||
const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
|
||||
accum_data_v5 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
|
||||
const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
|
||||
accum_data_v6 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
|
||||
const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
|
||||
accum_data_v7 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
|
||||
}
|
||||
}
|
||||
{
|
||||
const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
|
||||
const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
|
||||
{
|
||||
const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
|
||||
accum_data_v0 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
|
||||
const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
|
||||
accum_data_v1 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
|
||||
const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
|
||||
accum_data_v2 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
|
||||
const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
|
||||
accum_data_v3 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
|
||||
const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
|
||||
accum_data_v4 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
|
||||
const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
|
||||
accum_data_v5 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
|
||||
const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
|
||||
accum_data_v6 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
|
||||
const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
|
||||
accum_data_v7 =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
|
||||
}
|
||||
{
|
||||
float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
|
||||
accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
|
||||
accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
|
||||
_mm512_mask_storeu_ps(block_ptr + 0 * dst_stride, row_mask,
|
||||
accum_data_v0);
|
||||
accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
|
||||
accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
|
||||
_mm512_mask_storeu_ps(block_ptr + 1 * dst_stride, row_mask,
|
||||
accum_data_v1);
|
||||
accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
|
||||
accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
|
||||
_mm512_mask_storeu_ps(block_ptr + 2 * dst_stride, row_mask,
|
||||
accum_data_v2);
|
||||
accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
|
||||
accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
|
||||
_mm512_mask_storeu_ps(block_ptr + 3 * dst_stride, row_mask,
|
||||
accum_data_v3);
|
||||
accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
|
||||
accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
|
||||
_mm512_mask_storeu_ps(block_ptr + 4 * dst_stride, row_mask,
|
||||
accum_data_v4);
|
||||
accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
|
||||
accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
|
||||
_mm512_mask_storeu_ps(block_ptr + 5 * dst_stride, row_mask,
|
||||
accum_data_v5);
|
||||
accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
|
||||
accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
|
||||
_mm512_mask_storeu_ps(block_ptr + 6 * dst_stride, row_mask,
|
||||
accum_data_v6);
|
||||
accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
|
||||
accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
|
||||
_mm512_mask_storeu_ps(block_ptr + 7 * dst_stride, row_mask,
|
||||
accum_data_v7);
|
||||
}
|
||||
}
|
||||
} // Inner half-block loop.
|
||||
} // Residual rows, main col-block loop.
|
||||
} // End col-block loop.
|
||||
|
||||
if (col < end_col) {
|
||||
RUY_DCHECK_GE(end_col - col, 0);
|
||||
RUY_DCHECK_LT(end_col - col, 16);
|
||||
|
||||
__m512 accum_data_v[8];
|
||||
|
||||
const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
|
||||
float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
|
||||
|
||||
for (int row = params.start_row; row < end_row; row += 16) {
|
||||
const int residual_rows = std::min(end_row - row, 16);
|
||||
|
||||
const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
|
||||
float* dst_ptr = dst_col_ptr + row;
|
||||
const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
|
||||
|
||||
// Initialize with bias.
|
||||
const __mmask16 row_mask =
|
||||
(static_cast<std::uint32_t>(1) << residual_rows) - 1;
|
||||
const __m512 initial_accum_data =
|
||||
_mm512_maskz_loadu_ps(row_mask, bias_ptr);
|
||||
|
||||
// Process block in two halves, split by columns.
|
||||
for (int mmm = 0; mmm < 2; ++mmm) {
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
accum_data_v[j] = initial_accum_data;
|
||||
}
|
||||
|
||||
const float* lhs_ptr = lhs_col_ptr;
|
||||
const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
|
||||
for (int d = 0; d < params.depth; ++d) {
|
||||
const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
|
||||
const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
|
||||
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data[j]);
|
||||
accum_data_v[j] =
|
||||
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]);
|
||||
}
|
||||
lhs_ptr += 16;
|
||||
rhs_ptr += 16;
|
||||
}
|
||||
|
||||
const int residual_cols = std::min(end_col - col - 8 * mmm, 8);
|
||||
|
||||
if (residual_rows == 16) {
|
||||
if (residual_cols == 8) {
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
|
||||
accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v);
|
||||
accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v);
|
||||
_mm512_storeu_ps(block_ptr, accum_data_v[j]);
|
||||
}
|
||||
} else {
|
||||
for (int j = 0; j < residual_cols; ++j) {
|
||||
float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
|
||||
accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v);
|
||||
accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v);
|
||||
_mm512_storeu_ps(block_ptr, accum_data_v[j]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int j = 0; j < residual_cols; ++j) {
|
||||
float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
|
||||
accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v);
|
||||
accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v);
|
||||
_mm512_mask_storeu_ps(block_ptr, row_mask, accum_data_v[j]);
|
||||
}
|
||||
}
|
||||
} // Inner half-block loop.
|
||||
} // End row-block loop.
|
||||
} // Residual cols.
|
||||
}
|
||||
|
||||
#endif // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
|
||||
|
||||
} // namespace ruy
|
@ -98,6 +98,7 @@ struct PackedTypeImpl {
|
||||
using Type = Scalar;
|
||||
};
|
||||
|
||||
#if RUY_PLATFORM(NEON)
|
||||
template <>
|
||||
struct PackedTypeImpl<Path::kNeon, std::uint8_t> {
|
||||
using Type = std::int8_t;
|
||||
@ -106,6 +107,12 @@ template <>
|
||||
struct PackedTypeImpl<Path::kNeonDotprod, std::uint8_t> {
|
||||
using Type = std::int8_t;
|
||||
};
|
||||
#elif RUY_PLATFORM(AVX512)
|
||||
template <>
|
||||
struct PackedTypeImpl<Path::kAvx512, std::uint8_t> {
|
||||
using Type = std::int8_t;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <Path ThePath, typename Scalar>
|
||||
using PackedType = typename PackedTypeImpl<ThePath, Scalar>::Type;
|
||||
@ -155,10 +162,14 @@ struct PackImpl<Path::kStandardCpp, FixedKernelLayout, Scalar, PackedScalar,
|
||||
}
|
||||
};
|
||||
|
||||
#if RUY_PLATFORM(NEON)
|
||||
RUY_INHERIT_PACK(Path::kStandardCpp, Path::kNeon)
|
||||
#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)
|
||||
RUY_INHERIT_PACK(Path::kNeon, Path::kNeonDotprod)
|
||||
#endif
|
||||
#elif RUY_PLATFORM(AVX512)
|
||||
RUY_INHERIT_PACK(Path::kStandardCpp, Path::kAvx512)
|
||||
#endif
|
||||
|
||||
#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)
|
||||
void Pack8bitNeonOutOfOrder(const void* src_ptr0, const void* src_ptr1,
|
||||
@ -478,6 +489,92 @@ struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kRowMajor, 1, 4>, float,
|
||||
#endif // (RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) && \
|
||||
// RUY_OPT_ENABLED(RUY_OPT_ASM)
|
||||
|
||||
#if RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
|
||||
// Note that source and zero buffers can be uint8 type, but in the packing
|
||||
// function are reinterpreted as int8, and are XOR-ed with input_xor.
|
||||
void Pack8bitAvx512(const std::int8_t* src_ptr, std::int8_t input_xor,
|
||||
const std::int8_t* zerobuf, int src_stride,
|
||||
int remaining_src_cols, int src_rows,
|
||||
std::int8_t* packed_ptr, std::int32_t* sums_ptr);
|
||||
|
||||
template <typename Scalar>
|
||||
struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
|
||||
Scalar, std::int8_t, std::int32_t> {
|
||||
static_assert(std::is_same<Scalar, std::int8_t>::value ||
|
||||
std::is_same<Scalar, std::uint8_t>::value,
|
||||
"");
|
||||
using Layout = FixedKernelLayout<Order::kColMajor, 4, 16>;
|
||||
static constexpr int kHalfLayoutCols =
|
||||
8; // Half the number of cols in a block.
|
||||
static constexpr std::int8_t kInputXor =
|
||||
std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
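// kInputXor == 0x80 flips the sign bit of each uint8 source byte, i.e.
// (x ^ 0x80) equals x - 128 reinterpreted as int8, so the same int8 packing
// and kernel code serves uint8 sources. Note that zerobuf below is filled
// with zero_point ^ kInputXor so padding stays consistent with that mapping.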
|
||||
|
||||
static void Run(Tuning tuning, const Matrix<Scalar>& src_matrix,
|
||||
PackedMatrix<std::int8_t>* packed_matrix, int start_col,
|
||||
int end_col) {
|
||||
gemmlowp::ScopedProfilingLabel label("Pack (AVX-512)");
|
||||
|
||||
RUY_DCHECK(IsColMajor(src_matrix.layout));
|
||||
RUY_DCHECK(IsColMajor(packed_matrix->layout));
|
||||
RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
|
||||
RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
|
||||
RUY_DCHECK_EQ(kHalfLayoutCols * 2, Layout::kCols);
|
||||
std::int32_t* sums = packed_matrix->sums;
|
||||
Scalar zerobuf[kHalfLayoutCols * Layout::kRows];
|
||||
memset(zerobuf, packed_matrix->zero_point ^ kInputXor,
|
||||
kHalfLayoutCols * Layout::kRows * sizeof(Scalar));
|
||||
for (int block_col = start_col; block_col < end_col;
|
||||
block_col += Layout::kCols) {
|
||||
std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
|
||||
int src_stride = src_matrix.layout.stride;
|
||||
const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col;
|
||||
int remaining_src_cols = src_matrix.layout.cols - block_col;
|
||||
|
||||
static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits.
|
||||
std::int8_t* packed_ptr =
|
||||
packed_matrix->data +
|
||||
packed_matrix->layout.stride * (block_col & block_col_mask);
|
||||
Pack8bitAvx512(reinterpret_cast<const std::int8_t*>(src_ptr), kInputXor,
|
||||
reinterpret_cast<const std::int8_t*>(zerobuf), src_stride,
|
||||
remaining_src_cols, src_matrix.layout.rows, packed_ptr,
|
||||
sums_ptr);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void PackFloatAvx512(const float* src_ptr, const float* zerobuf, int src_stride,
|
||||
int remaining_src_cols, int src_rows, float* packed_ptr);
|
||||
|
||||
template <>
|
||||
struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kRowMajor, 1, 16>,
|
||||
float, float, float> {
|
||||
static void Run(Tuning, const Matrix<float>& src_matrix,
|
||||
PackedMatrix<float>* packed_matrix, int start_col,
|
||||
int end_col) {
|
||||
using Layout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
|
||||
RUY_DCHECK(IsColMajor(src_matrix.layout));
|
||||
RUY_DCHECK(IsColMajor(packed_matrix->layout));
|
||||
RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
|
||||
RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
|
||||
const float zerobuf[Layout::kCols] = {
|
||||
0.0f}; // Remainder default inits to 0.0f.
|
||||
for (int block_col = start_col; block_col < end_col;
|
||||
block_col += Layout::kCols) {
|
||||
int src_stride = src_matrix.layout.stride;
|
||||
const float* src_ptr = src_matrix.data.get() + src_stride * block_col;
|
||||
int remaining_src_cols = src_matrix.layout.cols - block_col;
|
||||
|
||||
static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits.
|
||||
float* packed_ptr =
|
||||
packed_matrix->data +
|
||||
packed_matrix->layout.stride * (block_col & block_col_mask);
|
||||
PackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
|
||||
src_matrix.layout.rows, packed_ptr);
|
||||
}
|
||||
}
|
||||
};
|
||||
#endif // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
|
||||
|
||||
// Main entry point for packing.
|
||||
template <Path ThePath, typename FixedKernelLayout, typename Scalar,
|
||||
typename PackedScalar>
|
||||
|
531
tensorflow/lite/experimental/ruy/pack_avx512.cc
Normal file
@ -0,0 +1,531 @@
|
||||
/* Copyright 2019 Google LLC. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/experimental/ruy/pack.h"
|
||||
#include "tensorflow/lite/experimental/ruy/platform.h"
|
||||
|
||||
namespace ruy {
|
||||
|
||||
#if RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
|
||||
|
||||
// The first int8_t template parameter is arbitrary: this routine is common to
|
||||
// all 8-bit source matrix types.
|
||||
using PackImpl8bitAvx512 =
|
||||
PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
|
||||
std::int8_t, std::int8_t, std::int32_t>;
|
||||
|
||||
namespace {
|
||||
|
||||
inline void ZeroHalf8bitAvx512(int src_rows, std::int8_t packed_zero_point,
|
||||
std::int8_t* packed_ptr) {
|
||||
using Layout = PackImpl8bitAvx512::Layout;
|
||||
static constexpr int kHalfLayoutCols =
|
||||
PackImpl8bitAvx512::kHalfLayoutCols; // Half the number of cols in a
|
||||
// block.
|
||||
RUY_DCHECK_EQ(kHalfLayoutCols, 8);
|
||||
RUY_DCHECK_EQ(Layout::kCols, 16);
|
||||
RUY_DCHECK_EQ(Layout::kRows, 4);
|
||||
|
||||
const int non_trailing_blocks = (src_rows & ~31) >> 2;
|
||||
// This routine fills half blocks, and typically fills the second halves.
|
||||
// Thus packed_ptr is already offset by 8 * 4.
|
||||
for (int k = 0; k < non_trailing_blocks; ++k) {
|
||||
for (int j = 0; j < (kHalfLayoutCols * Layout::kRows); ++j) {
|
||||
packed_ptr[Layout::kCols * Layout::kRows * k + j] = packed_zero_point;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void HalfPack8bitAvx512(const std::int8_t* src_ptr,
|
||||
std::int8_t input_xor,
|
||||
const std::int8_t* zerobuf, int src_stride,
|
||||
int remaining_src_cols, int src_rows,
|
||||
std::int8_t* packed_ptr, std::int32_t* sums_ptr,
|
||||
std::int8_t* trailing_buf) {
|
||||
using Layout = PackImpl8bitAvx512::Layout;
|
||||
static constexpr int kHalfLayoutCols =
|
||||
PackImpl8bitAvx512::kHalfLayoutCols; // Half the number of cols in a
|
||||
// block.
|
||||
RUY_DCHECK_EQ(Layout::kCols, 16);
|
||||
RUY_DCHECK_EQ(Layout::kRows, 4);
|
||||
RUY_DCHECK_EQ(kHalfLayoutCols, 8);
|
||||
|
||||
std::int8_t in_data[kHalfLayoutCols][kHalfLayoutCols][Layout::kCols];
|
||||
|
||||
const std::int8_t* src_ptr0 = src_ptr;
|
||||
const std::int8_t* src_ptr1 = src_ptr0 + src_stride;
|
||||
const std::int8_t* src_ptr2 = src_ptr1 + src_stride;
|
||||
const std::int8_t* src_ptr3 = src_ptr2 + src_stride;
|
||||
const std::int8_t* src_ptr4 = src_ptr3 + src_stride;
|
||||
const std::int8_t* src_ptr5 = src_ptr4 + src_stride;
|
||||
const std::int8_t* src_ptr6 = src_ptr5 + src_stride;
|
||||
const std::int8_t* src_ptr7 = src_ptr6 + src_stride;
|
||||
// Each Layout::Rows is 4 contiguous input, contiguous packed elements.
|
||||
// We process 8 of these chunks at a time, padding short input chunks.
|
||||
constexpr int kNumRowChunks = 8;
|
||||
constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows;
|
||||
std::int64_t src_inc0 = kNumChunkedSrcRows;
|
||||
std::int64_t src_inc1 = kNumChunkedSrcRows;
|
||||
std::int64_t src_inc2 = kNumChunkedSrcRows;
|
||||
std::int64_t src_inc3 = kNumChunkedSrcRows;
|
||||
std::int64_t src_inc4 = kNumChunkedSrcRows;
|
||||
std::int64_t src_inc5 = kNumChunkedSrcRows;
|
||||
std::int64_t src_inc6 = kNumChunkedSrcRows;
|
||||
std::int64_t src_inc7 = kNumChunkedSrcRows;
|
||||
// Handle cases where source does not have kHalfLayoutCols (8) columns.
|
||||
if (remaining_src_cols < 8) {
|
||||
if (remaining_src_cols <= 0) {
|
||||
src_ptr0 = zerobuf;
|
||||
src_inc0 = 0;
|
||||
}
|
||||
if (remaining_src_cols <= 1) {
|
||||
src_ptr1 = zerobuf;
|
||||
src_inc1 = 0;
|
||||
}
|
||||
if (remaining_src_cols <= 2) {
|
||||
src_ptr2 = zerobuf;
|
||||
src_inc2 = 0;
|
||||
}
|
||||
if (remaining_src_cols <= 3) {
|
||||
src_ptr3 = zerobuf;
|
||||
src_inc3 = 0;
|
||||
}
|
||||
if (remaining_src_cols <= 4) {
|
||||
src_ptr4 = zerobuf;
|
||||
src_inc4 = 0;
|
||||
}
|
||||
if (remaining_src_cols <= 5) {
|
||||
src_ptr5 = zerobuf;
|
||||
src_inc5 = 0;
|
||||
}
|
||||
if (remaining_src_cols <= 6) {
|
||||
src_ptr6 = zerobuf;
|
||||
src_inc6 = 0;
|
||||
}
|
||||
src_ptr7 = zerobuf;
|
||||
src_inc7 = 0;
|
||||
}
|
||||
|
||||
const std::int8_t zero_point = zerobuf[0];
|
||||
|
||||
if (sums_ptr) {
|
||||
// i: kHalfLayoutCols.
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
sums_ptr[i] = 0;
|
||||
}
|
||||
}
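// When sums_ptr is non-null, per-column sums of the packed (XOR-adjusted)
// int8 values are accumulated below; the 8-bit kernel consumes them for the
// zero-point corrections gated by RUY_ASM_FLAG_HAS_LHS_SUMS /
// RUY_ASM_FLAG_HAS_RHS_SUMS.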
|
||||
|
||||
// The overall packing effectively pads the source rows to
|
||||
// (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we
|
||||
// only pack for (src_rows + 31) & ~31. When there is an incomplete
|
||||
// destination block, this is stored into trailing_buf instead of packed_ptr.
|
||||
for (int k = 0; k < src_rows; k += 2 * kNumChunkedSrcRows) {
|
||||
// m: {0, 1} for 2 chunks of rows.
|
||||
for (int m = 0; m < 2; ++m) {
|
||||
// Available source rows.
|
||||
// If this is less than 0 (for m=1), we skip, having filled trailing
|
||||
// buffer for m=0. Also, if source rows is zero on m=1, then we filled
|
||||
// exactly to the end of the column in the packed buffer.
|
||||
const int available_src_rows = src_rows - k - m * kNumChunkedSrcRows;
|
||||
// Effectively,
|
||||
// available rows = std::max(0, std::min(8, src_rows - k - 8 * 4 * m));
|
||||
// treat each case separately.
|
||||
if (available_src_rows >= kNumChunkedSrcRows) {
// i: chunks, s: Layout::Rows.
for (int i = 0; i < 8; ++i) {
for (int s = 0; s < 4; ++s) {
in_data[0][i][s] = src_ptr0[i * 4 + s];
in_data[1][i][s] = src_ptr1[i * 4 + s];
in_data[2][i][s] = src_ptr2[i * 4 + s];
in_data[3][i][s] = src_ptr3[i * 4 + s];
in_data[4][i][s] = src_ptr4[i * 4 + s];
in_data[5][i][s] = src_ptr5[i * 4 + s];
in_data[6][i][s] = src_ptr6[i * 4 + s];
in_data[7][i][s] = src_ptr7[i * 4 + s];
}
}
// i: chunks, j: kHalfLayoutCols, s: Layout::Rows.
for (int i = 0; i < 8; ++i) {
for (int j = 0; j < 8; ++j) {
for (int s = 0; s < 4; ++s) {
// 16 * 4 * i is offset for each block, that is
// (Layout::kCols * Layout::kRows * i)
packed_ptr[(16 * i + j) * 4 + s] = in_data[j][i][s] ^ input_xor;
}
if (sums_ptr) {
for (int s = 0; s < 4; ++s) {
sums_ptr[j] += in_data[j][i][s] ^ input_xor;
}
}
}
}
} else if (available_src_rows > 0) {
RUY_DCHECK_LT(available_src_rows >> 2, kNumChunkedSrcRows);
int i = 0;
// Consume chunks of 4 rows that are complete.
for (; i < (available_src_rows >> 2); ++i) {
for (int s = 0; s < 4; ++s) {
in_data[0][i][s] = src_ptr0[i * 4 + s];
in_data[1][i][s] = src_ptr1[i * 4 + s];
in_data[2][i][s] = src_ptr2[i * 4 + s];
in_data[3][i][s] = src_ptr3[i * 4 + s];
in_data[4][i][s] = src_ptr4[i * 4 + s];
in_data[5][i][s] = src_ptr5[i * 4 + s];
in_data[6][i][s] = src_ptr6[i * 4 + s];
in_data[7][i][s] = src_ptr7[i * 4 + s];
}
}
// Consume any incomplete chunk.
if (i < ((available_src_rows + 3) >> 2)) {
int s = 0;
for (; s < (available_src_rows & 3); ++s) {
in_data[0][i][s] = src_ptr0[i * 4 + s];
in_data[1][i][s] = src_ptr1[i * 4 + s];
in_data[2][i][s] = src_ptr2[i * 4 + s];
in_data[3][i][s] = src_ptr3[i * 4 + s];
in_data[4][i][s] = src_ptr4[i * 4 + s];
in_data[5][i][s] = src_ptr5[i * 4 + s];
in_data[6][i][s] = src_ptr6[i * 4 + s];
in_data[7][i][s] = src_ptr7[i * 4 + s];
}
RUY_DCHECK_LE(s, 4);
for (; s < 4; ++s) {
// j: kHalfLayoutCols.
for (int j = 0; j < 8; ++j) {
in_data[j][i][s] = zero_point;
}
}
++i;
}
// We do not care what goes into the trailing buffer, but we want
// in_data[...] ^ input_xor == 0 for irrelevant values in the summation.
//
// It might prove better in optimized code to pad uniformly with
// zero_point, and compensate by initializing the summations with the
// compensating offset, effectively
// ((input_xor - zero_point) ^ input_xor) *
// 4 * (8 - ((available_src_rows + 3) >> 2)).
for (; i < 8; ++i) {
for (int s = 0; s < 4; ++s) {
for (int j = 0; j < 8; ++j) {
in_data[j][i][s] = input_xor;
}
}
}
// We loop through [0, 8) rather than
// [0, (available_src_rows + 3) >> 2), since that emulates what we might
// do in fully-optimized code.
//
// i: chunks, j: kHalfLayoutCols, s: Layout::Rows.
if (sums_ptr) {
for (int i = 0; i < 8; ++i) {
for (int j = 0; j < 8; ++j) {
for (int s = 0; s < 4; ++s) {
trailing_buf[(16 * i + j) * 4 + s] =
in_data[j][i][s] ^ input_xor;
sums_ptr[j] = sums_ptr[j] + (in_data[j][i][s] ^ input_xor);
}
}
}
} else {
for (int i = 0; i < 8; ++i) {
for (int j = 0; j < 8; ++j) {
for (int s = 0; s < 4; ++s) {
trailing_buf[(16 * i + j) * 4 + s] =
in_data[j][i][s] ^ input_xor;
}
}
}
}
}

packed_ptr += 16 * kNumChunkedSrcRows;
src_ptr0 += src_inc0;
src_ptr1 += src_inc1;
src_ptr2 += src_inc2;
src_ptr3 += src_inc3;
src_ptr4 += src_inc4;
src_ptr5 += src_inc5;
src_ptr6 += src_inc6;
src_ptr7 += src_inc7;
}
}
}
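
// Worked example of the indexing above: within a half-block, the value from
// column j (0..7), row chunk i (4 rows per chunk), row-within-chunk s lands at
// (Layout::kCols * i + j) * Layout::kRows + s == (16 * i + j) * 4 + s.
// E.g. column 2, chunk 1, row 3 is stored at (16 * 1 + 2) * 4 + 3 == 75.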
inline __m512 LoaduTwo(const float* addr_lo, const float* addr_hi) {
__m512 lower_filled = _mm512_castps256_ps512(_mm256_loadu_ps(addr_lo));
return _mm512_insertf32x8(lower_filled, _mm256_loadu_ps(addr_hi), 1);
}

inline __m512 MaskLoaduTwo(__mmask8 row_mask, const float* addr_lo,
const float* addr_hi) {
__m512 lower_filled =
_mm512_castps256_ps512(_mm256_maskz_loadu_ps(row_mask, addr_lo));
return _mm512_insertf32x8(lower_filled,
_mm256_maskz_loadu_ps(row_mask, addr_hi), 1);
}
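
// Usage sketch for the two helpers above (values illustrative only): given
// float a[8] = {0, 1, ..., 7} and float b[8] = {8, 9, ..., 15},
// LoaduTwo(a, b) yields lanes {0..7, 8..15}. With row_mask = (1 << 5) - 1,
// MaskLoaduTwo(row_mask, a, b) yields {0..4, 0, 0, 0, 8..12, 0, 0, 0}:
// masked-off lanes are zeroed and are not read from memory at all, which is
// what makes the partial (trailing) row loads below safe.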

inline void HalfPackFloatAvx512(const float* src_ptr, const float* zerobuf,
int src_stride, int remaining_src_cols,
int src_rows, float* packed_ptr,
float* trailing_buf) {
const float* src_ptr0 = src_ptr;
const float* src_ptr1 = src_ptr0 + src_stride;
const float* src_ptr2 = src_ptr1 + src_stride;
const float* src_ptr3 = src_ptr2 + src_stride;
const float* src_ptr4 = src_ptr3 + src_stride;
const float* src_ptr5 = src_ptr4 + src_stride;
const float* src_ptr6 = src_ptr5 + src_stride;
const float* src_ptr7 = src_ptr6 + src_stride;
std::int64_t src_inc0 = 8;
std::int64_t src_inc1 = 8;
std::int64_t src_inc2 = 8;
std::int64_t src_inc3 = 8;
std::int64_t src_inc4 = 8;
std::int64_t src_inc5 = 8;
std::int64_t src_inc6 = 8;
std::int64_t src_inc7 = 8;
if (remaining_src_cols < 8) {
if (remaining_src_cols <= 0) {
src_ptr0 = zerobuf;
src_inc0 = 0;
}
if (remaining_src_cols <= 1) {
src_ptr1 = zerobuf;
src_inc1 = 0;
}
if (remaining_src_cols <= 2) {
src_ptr2 = zerobuf;
src_inc2 = 0;
}
if (remaining_src_cols <= 3) {
src_ptr3 = zerobuf;
src_inc3 = 0;
}
if (remaining_src_cols <= 4) {
src_ptr4 = zerobuf;
src_inc4 = 0;
}
if (remaining_src_cols <= 5) {
src_ptr5 = zerobuf;
src_inc5 = 0;
}
if (remaining_src_cols <= 6) {
src_ptr6 = zerobuf;
src_inc6 = 0;
}
src_ptr7 = zerobuf;
src_inc7 = 0;
}

for (int k = 0; k < src_rows; k += 16) {
for (int m = 0; m < 2; ++m) {
const int available_src_rows = src_rows - k - 8 * m;
// Effectively,
// available_src_rows = std::max(0, std::min(8, src_rows - k - 8 * m));
// but treat each case separately.
if (available_src_rows > 7) {
__m512i t0, t1, t2, t3;
__m512i r0, r1, r2, r3;

t0 = _mm512_castps_si512(LoaduTwo(src_ptr0, src_ptr4));
t1 = _mm512_castps_si512(LoaduTwo(src_ptr1, src_ptr5));
t2 = _mm512_castps_si512(LoaduTwo(src_ptr2, src_ptr6));
t3 = _mm512_castps_si512(LoaduTwo(src_ptr3, src_ptr7));

r0 = _mm512_unpacklo_epi32(t0, t1);
r2 = _mm512_unpackhi_epi32(t0, t1);
r1 = _mm512_unpacklo_epi32(t2, t3);
r3 = _mm512_unpackhi_epi32(t2, t3);

t0 = _mm512_unpacklo_epi64(r0, r1);
t2 = _mm512_unpackhi_epi64(r0, r1);
t1 = _mm512_unpacklo_epi64(r2, r3);
t3 = _mm512_unpackhi_epi64(r2, r3);

r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
r2 = _mm512_shuffle_i32x4(t2, t3, 0x88);
r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd);

_mm256_storeu_epi32(packed_ptr + 0 * 16, _mm512_castsi512_si256(r0));
_mm256_storeu_epi32(packed_ptr + 2 * 16,
_mm512_extracti64x4_epi64(r0, 1));
_mm256_storeu_epi32(packed_ptr + 4 * 16, _mm512_castsi512_si256(r1));
_mm256_storeu_epi32(packed_ptr + 6 * 16,
_mm512_extracti64x4_epi64(r1, 1));
_mm256_storeu_epi32(packed_ptr + 1 * 16, _mm512_castsi512_si256(r2));
_mm256_storeu_epi32(packed_ptr + 3 * 16,
_mm512_extracti64x4_epi64(r2, 1));
_mm256_storeu_epi32(packed_ptr + 5 * 16, _mm512_castsi512_si256(r3));
_mm256_storeu_epi32(packed_ptr + 7 * 16,
_mm512_extracti64x4_epi64(r3, 1));
} else if (available_src_rows > 0) {
const __mmask8 row_mask =
(static_cast<std::uint32_t>(1) << available_src_rows) - 1;

__m512i t0, t1, t2, t3;
__m512i r0, r1, r2, r3;

t0 = _mm512_castps_si512(MaskLoaduTwo(row_mask, src_ptr0, src_ptr4));
t1 = _mm512_castps_si512(MaskLoaduTwo(row_mask, src_ptr1, src_ptr5));
t2 = _mm512_castps_si512(MaskLoaduTwo(row_mask, src_ptr2, src_ptr6));
t3 = _mm512_castps_si512(MaskLoaduTwo(row_mask, src_ptr3, src_ptr7));

r0 = _mm512_unpacklo_epi32(t0, t1);
r2 = _mm512_unpackhi_epi32(t0, t1);
r1 = _mm512_unpacklo_epi32(t2, t3);
r3 = _mm512_unpackhi_epi32(t2, t3);

t0 = _mm512_unpacklo_epi64(r0, r1);
t2 = _mm512_unpackhi_epi64(r0, r1);
t1 = _mm512_unpacklo_epi64(r2, r3);
t3 = _mm512_unpackhi_epi64(r2, r3);

r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
r2 = _mm512_shuffle_i32x4(t2, t3, 0x88);
r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd);

_mm256_storeu_epi32(trailing_buf + 0 * 16, _mm512_castsi512_si256(r0));
_mm256_storeu_epi32(trailing_buf + 2 * 16,
_mm512_extracti64x4_epi64(r0, 1));
_mm256_storeu_epi32(trailing_buf + 4 * 16, _mm512_castsi512_si256(r1));
_mm256_storeu_epi32(trailing_buf + 6 * 16,
_mm512_extracti64x4_epi64(r1, 1));
_mm256_storeu_epi32(trailing_buf + 1 * 16, _mm512_castsi512_si256(r2));
_mm256_storeu_epi32(trailing_buf + 3 * 16,
_mm512_extracti64x4_epi64(r2, 1));
_mm256_storeu_epi32(trailing_buf + 5 * 16, _mm512_castsi512_si256(r3));
// Do not store _mm512_extracti64x4_epi64(r3, 1).
}

packed_ptr += 16 * 8;
src_ptr0 += src_inc0;
src_ptr1 += src_inc1;
src_ptr2 += src_inc2;
src_ptr3 += src_inc3;
src_ptr4 += src_inc4;
src_ptr5 += src_inc5;
src_ptr6 += src_inc6;
src_ptr7 += src_inc7;
}
}
}
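
// Note on the layout produced above: the packed destination consists of
// 16-float rows, and each 256-bit store writes the 8 lanes belonging to this
// half of the columns, i.e. lanes 0..7 of packed_ptr + n * 16. The other 8
// lanes are filled by a second call with packed_ptr offset by 8, or zeroed,
// in PackFloatAvx512 below. Rows past the last complete chunk of 8 go to
// trailing_buf and are copied into place afterwards.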

inline void ZeroHalfFloatAvx512(int src_rows, float* packed_ptr) {
const int non_trailing_rows = src_rows & ~7;
for (int k = 0; k < non_trailing_rows; ++k) {
for (int j = 0; j < 8; ++j) {
packed_ptr[j] = 0.0f;
}
packed_ptr += 16;
}
}
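
// Note: PackFloatAvx512 below calls this with packed_ptr already offset by 8,
// so the zeroed entries are lanes 8..15 of each 16-float packed row for the
// non-trailing rows; the trailing rows get those lanes from the zero-filled
// trailing_buf via the final memcpy there.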
} // namespace.

void Pack8bitAvx512(const std::int8_t* src_ptr, std::int8_t input_xor,
const std::int8_t* zerobuf, int src_stride,
int remaining_src_cols, int src_rows,
std::int8_t* packed_ptr, std::int32_t* sums_ptr) {
gemmlowp::ScopedProfilingLabel label("Pack kAvx512 8bit");

using Layout = PackImpl8bitAvx512::Layout;
constexpr int kHalfBlockOffset = 32;
RUY_DCHECK_EQ(kHalfBlockOffset * 2, Layout::kCols * Layout::kRows);
static constexpr int kHalfLayoutCols =
PackImpl8bitAvx512::kHalfLayoutCols; // Half the number of cols in a
// block.
RUY_DCHECK_EQ(kHalfLayoutCols, 8);
RUY_DCHECK_EQ(Layout::kCols, 16);
RUY_DCHECK_EQ(Layout::kRows, 4);

// Each group of Layout::kRows (4) source rows is contiguous in the input and
// stays contiguous in the packed output. We process 8 of these chunks at a
// time, padding short input chunks.
constexpr int kNumRowChunks = 8;

// Each packed block is 4*16, and there are normally 8. The trailing block is
// only slightly shorter.
constexpr int kTrailingBufSize =
kNumRowChunks * Layout::kCols * Layout::kRows;
std::int8_t trailing_buf[kTrailingBufSize];
memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t));

std::int32_t* second_sums_ptr =
sums_ptr ? sums_ptr + kHalfLayoutCols : nullptr;
if (remaining_src_cols > kHalfLayoutCols) {
HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride,
remaining_src_cols, src_rows, packed_ptr, sums_ptr,
trailing_buf);
HalfPack8bitAvx512(src_ptr + src_stride * kHalfLayoutCols, input_xor,
zerobuf, src_stride,
remaining_src_cols - kHalfLayoutCols, src_rows,
packed_ptr + kHalfBlockOffset, second_sums_ptr,
trailing_buf + kHalfBlockOffset);
} else {
HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride,
remaining_src_cols, src_rows, packed_ptr, sums_ptr,
trailing_buf);
ZeroHalf8bitAvx512(src_rows, zerobuf[0] ^ input_xor,
packed_ptr + kHalfBlockOffset);
// The kernel may not need the second half-block's sums to be set.
if (second_sums_ptr) {
for (int i = 0; i < kHalfLayoutCols; ++i) {
second_sums_ptr[i] = (zerobuf[0] ^ input_xor) * ((src_rows + 3) & ~3);
}
}
}
constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1;
const bool trailing_data = (src_rows & kChunkedRowMask) > 0;
// If the number of source rows is not a multiple of kChunkedRowMask + 1, there
// will be data in the trailing buffer.
if (trailing_data) {
const int non_trailing_rows = src_rows & ~kChunkedRowMask;
// Destination "rows" are padded to next highest multiple of Layout::kRows.
const int dst_rows = (src_rows + 3) & ~3;
const int trailing_rows = dst_rows - non_trailing_rows;
memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf,
Layout::kCols * trailing_rows * sizeof(std::int8_t));
}
}
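
// Worked example of the trailing copy above: with src_rows = 70 and
// kChunkedRowMask = 31, non_trailing_rows = 64, dst_rows = 72 and
// trailing_rows = 8, so the memcpy moves 16 * 8 = 128 bytes from trailing_buf
// into the packed buffer starting at row 64.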

void PackFloatAvx512(const float* src_ptr, const float* zerobuf, int src_stride,
int remaining_src_cols, int src_rows, float* packed_ptr) {
gemmlowp::ScopedProfilingLabel label("Pack kAvx512 float");
float trailing_buf[7 * 16];
if (remaining_src_cols > 8) {
HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
src_rows, packed_ptr, trailing_buf);
HalfPackFloatAvx512(src_ptr + src_stride * 8, zerobuf, src_stride,
remaining_src_cols - 8, src_rows, packed_ptr + 8,
trailing_buf + 8);
} else {
memset(trailing_buf, 0, sizeof(trailing_buf));
HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
src_rows, packed_ptr, trailing_buf);
ZeroHalfFloatAvx512(src_rows, packed_ptr + 8);
}
const int trailing_rows = src_rows & 7;
if (trailing_rows > 0) {
const int non_trailing_rows = src_rows & ~7;
memcpy(packed_ptr + 16 * non_trailing_rows, trailing_buf,
16 * trailing_rows * sizeof(float));
}
}
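
// Similarly for the float path above: with src_rows = 70 we get
// trailing_rows = 6 and non_trailing_rows = 64, so 16 * 6 floats are copied
// from trailing_buf to packed_ptr + 16 * 64.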

#endif // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)

} // namespace ruy
@ -51,6 +51,11 @@ namespace ruy {
// given base architecture (such as ARM). Higher values of this enum correspond
// to "better" code paths within a given base architecture for which Ruy has
// optimized code paths.
//
// Values are reused across architectures.
// Rationale: to scale better to N architectures, it is good to have small
// values both for the compile-time logic to select paths, and when manually
// spelling out Path values, such as when invoking a test or benchmark.
enum class Path : std::uint8_t {
// This is a special null value, representing the absence of any path.
kNone = 0,
@ -66,11 +71,19 @@ enum class Path : std::uint8_t {
//
// This is intended for testing/development.
kStandardCpp = 0x2,
//
// ARM architectures.
//
// Optimized path using a widely available subset of ARM NEON instructions.
kNeon = 0x4,
// Optimized path making use of ARM NEON dot product instructions that are
// available on newer ARM cores.
kNeonDotprod = 0x8,
//
// x86 architectures.
//
// Optimized for AVX-512.
kAvx512 = 0x4,
};
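
// For instance, a set of paths is expressed as a bitmask of these values,
// e.g. Path::kStandardCpp | Path::kAvx512, using the bitwise operators defined
// below (see also kAllPaths further down).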

inline constexpr Path operator|(Path p, Path q) {
@ -104,6 +117,11 @@ constexpr Path kAllPaths =
Path::kReference | Path::kStandardCpp | Path::kNeon | Path::kNeonDotprod;
#elif RUY_PLATFORM(NEON_32)
constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp | Path::kNeon;
#elif RUY_PLATFORM(AVX512)
// TODO(b/138433137): kAllPaths should always contain kAvx512 regardless of
// whether AVX-512 is enabled in the translation unit #including this header.
constexpr Path kAllPaths =
Path::kReference | Path::kStandardCpp | Path::kAvx512;
#else
constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp;
#endif
@ -111,6 +129,9 @@ constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp;
// We don't know how to do runtime dotprod detection outside of linux for now.
#if RUY_PLATFORM(NEON)
constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp | Path::kNeon;
#elif RUY_PLATFORM(AVX512)
constexpr Path kAllPaths =
Path::kReference | Path::kStandardCpp | Path::kAvx512;
#else
constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp;
#endif
@ -49,6 +49,17 @@ limitations under the License.
#define RUY_DONOTUSEDIRECTLY_NEON_64 \
(RUY_DONOTUSEDIRECTLY_NEON && RUY_DONOTUSEDIRECTLY_ARM_64)

// These CPU capabilities will all be true when Skylake is enabled during
// compilation.
//
// TODO(b/138433137) Select AVX-512 at runtime rather than via compile options.
#if defined(__AVX512F__) && defined(__AVX512DQ__) && defined(__AVX512CD__) && \
defined(__AVX512BW__) && defined(__AVX512VL__)
#define RUY_DONOTUSEDIRECTLY_AVX512 1
#else
#define RUY_DONOTUSEDIRECTLY_AVX512 0
#endif
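
// For example (GCC/Clang), compiling with -march=skylake-avx512, or with
// -mavx512f -mavx512dq -mavx512cd -mavx512bw -mavx512vl, defines all five of
// the macros tested above.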

// Detect APPLE
#ifdef __APPLE__
#define RUY_DONOTUSEDIRECTLY_APPLE 1
@ -66,8 +66,12 @@ const char* PathName(Path path) {
switch (path) {
RUY_PATHNAME_CASE(kReference)
RUY_PATHNAME_CASE(kStandardCpp)
#if RUY_PLATFORM(NEON)
RUY_PATHNAME_CASE(kNeon)
RUY_PATHNAME_CASE(kNeonDotprod)
#elif RUY_PLATFORM(AVX512)
RUY_PATHNAME_CASE(kAvx512)
#endif
default:
RUY_CHECK(false);
return nullptr;
@ -245,7 +249,7 @@ struct RandomRangeBounds<Scalar, false> {
inline std::default_random_engine& global_random_engine() {
static std::default_random_engine engine;
return engine;
};
}

template <typename Scalar>
struct UniformRandomDistribution {
@ -660,7 +664,7 @@ void EvalGemmlowp(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
&GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst,
-lhs.zero_point, -rhs.zero_point, output_pipeline);
} else
} else  // NOLINT[readability/braces]
#endif
{
const auto& output_pipeline =
@ -680,7 +684,7 @@ void EvalGemmlowp(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
&GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst,
-lhs.zero_point, -rhs.zero_point, output_pipeline);
} else
} else  // NOLINT[readability/braces]
#endif
{
const auto& output_pipeline = std::make_tuple(