Ruy: Introduce x86 (AVX-512) code.

PiperOrigin-RevId: 260750932
Authored by A. Unique TensorFlower on 2019-07-30 10:54:43 -07:00; committed by TensorFlower Gardener
parent 338de90a40
commit 72867cc5bc
9 changed files with 1541 additions and 8 deletions


@@ -264,11 +264,13 @@ cc_library(
srcs = [
"kernel_arm32.cc",
"kernel_arm64.cc",
"kernel_avx512.cc",
],
hdrs = [
"kernel.h",
],
deps = [
":check_macros",
":common",
":internal_matrix",
":opt_set",
@@ -287,6 +289,7 @@ cc_library(
name = "pack",
srcs = [
"pack_arm.cc",
"pack_avx512.cc",
],
hdrs = [
"pack.h",


@@ -19,6 +19,7 @@ limitations under the License.
#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_COMMON_H_
#include <atomic>
#include <cstdint>
#include <limits>
#include <type_traits>
@@ -28,6 +29,11 @@ limitations under the License.
#include "tensorflow/lite/experimental/ruy/path.h"
#include "tensorflow/lite/experimental/ruy/platform.h"
// TODO(b/138449463): also guard by RUY_OPT_ENABLED(RUY_OPT_INTRINSICS).
#if RUY_PLATFORM(AVX512)
#include <immintrin.h>
#endif
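// Editorial sketch (not part of the original commit): once the TODO above is
// addressed, the guard would presumably read
//   #if RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS)
// mirroring how RUY_OPT_ENABLED is already used around the kernels and packers.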
#if (RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32))
#include <arm_neon.h>
#endif


@@ -211,12 +211,18 @@ struct Kernel<Path::kStandardCpp, LhsScalar, RhsScalar, DstScalar, Spec> {
: Kernel<PARENT, LhsScalar, RhsScalar, DstScalar, Spec>(tuning) {} \
};
#if RUY_PLATFORM(NEON)
RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kNeon)
RUY_INHERIT_KERNEL(Path::kNeon, Path::kNeonDotprod)
#elif RUY_PLATFORM(AVX512)
RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kAvx512)
#endif
// KernelParams are shared across 32-bit and 64-bit NEON code.
#if (RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) && \
(RUY_OPT_ENABLED(RUY_OPT_ASM))
// KernelParams are shared across 32-bit and 64-bit NEON code, and x86 AVX-512
// code.
#if (RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32) || \
RUY_PLATFORM(AVX512)) && \
RUY_OPT_ENABLED(RUY_OPT_ASM)
#define RUY_ASM_FLAG_HAS_BIAS 0x1
#define RUY_ASM_FLAG_HAS_LHS_SUMS 0x2
@@ -362,10 +368,12 @@ void MakeKernelParams8bit(const PackedMatrix<std::int8_t>& lhs,
dst->data.get() + start_col * dst->layout.stride + start_row;
}
#if RUY_PLATFORM(NEON)
void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 4>& params);
void Kernel8bitNeonInOrder(const KernelParams8bit<4, 4>& params);
void Kernel8bitNeonDotprodOutOfOrder(const KernelParams8bit<8, 8>& params);
void Kernel8bitNeonDotprodInOrder(const KernelParams8bit<8, 8>& params);
#endif
#if RUY_PLATFORM(NEON_64)
template <typename DstScalar>
@@ -482,10 +490,12 @@ inline void MakeKernelParamsFloat(const PackedMatrix<float>& lhs,
RUY_DCHECK_LT(params->last_col, params->dst_cols);
}
#if RUY_PLATFORM(NEON)
void KernelFloatNeonOutOfOrder(const KernelParamsFloat<8, 8>& params);
void KernelFloatNeonInOrder(const KernelParamsFloat<8, 8>& params);
void KernelFloat32NeonOutOfOrder(const KernelParamsFloat<8, 4>& params);
void KernelFloatNeonDotprodInOrder(const KernelParamsFloat<8, 8>& params);
#endif
#if RUY_PLATFORM(NEON_64)
// A Float kernel for ARM64 Neon.
@@ -531,6 +541,7 @@ struct Kernel<Path::kNeon, float, float, float, BasicSpec<float, float>> {
};
#endif
#if RUY_PLATFORM(NEON)
// While the dotprod NEON extension does not concern floating-point arithmetic,
// its presence allows us to distinguish, in the in-order tuning case, between
// A53 and A55r1. TODO: should this be folded into tuning?
@@ -556,9 +567,53 @@ struct Kernel<Path::kNeonDotprod, float, float, float,
}
}
};
#endif
#if RUY_PLATFORM(AVX512)
void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params);
template <typename DstScalar>
struct Kernel<Path::kAvx512, std::int8_t, std::int8_t, DstScalar,
BasicSpec<std::int32_t, DstScalar>> {
Tuning tuning = Tuning::kAuto;
using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
void Run(const PackedMatrix<std::int8_t>& lhs,
const PackedMatrix<std::int8_t>& rhs,
const BasicSpec<std::int32_t, DstScalar>& spec, int start_row,
int start_col, int end_row, int end_col,
Matrix<DstScalar>* dst) const {
KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col,
dst, &params);
Kernel8bitAvx512(params);
}
};
void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params);
template <>
struct Kernel<Path::kAvx512, float, float, float, BasicSpec<float, float>> {
Tuning tuning = Tuning::kAuto;
using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
void Run(const PackedMatrix<float>& lhs, const PackedMatrix<float>& rhs,
const BasicSpec<float, float>& spec, int start_row, int start_col,
int end_row, int end_col, Matrix<float>* dst) const {
KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row,
end_col, dst, &params);
KernelFloatAvx512(params);
}
};
#endif
#endif // (RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32) || \
// RUY_PLATFORM(AVX512)) && \
// RUY_OPT_ENABLED(RUY_OPT_ASM)
#endif // (RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) &&
// (RUY_OPT_ENABLED(RUY_OPT_ASM)
} // namespace ruy
#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_H_


@@ -0,0 +1,805 @@
/* Copyright 2019 Google LLC. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "profiling/instrumentation.h"
#include "tensorflow/lite/experimental/ruy/check_macros.h"
#include "tensorflow/lite/experimental/ruy/kernel.h"
#include "tensorflow/lite/experimental/ruy/platform.h"
namespace ruy {
#if RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
inline std::int32_t mm512_get1_epi32(const __m512i v, int i) {
__m256i a =
i < 8 ? _mm512_extracti32x8_epi32(v, 0) : _mm512_extracti32x8_epi32(v, 1);
switch (i & ~8) {
case 0:
return _mm256_extract_epi32(a, 0);
case 1:
return _mm256_extract_epi32(a, 1);
case 2:
return _mm256_extract_epi32(a, 2);
case 3:
return _mm256_extract_epi32(a, 3);
case 4:
return _mm256_extract_epi32(a, 4);
case 5:
return _mm256_extract_epi32(a, 5);
case 6:
return _mm256_extract_epi32(a, 6);
case 7:
return _mm256_extract_epi32(a, 7);
default:
RUY_DCHECK(i < 16);
return 0;
}
}
inline __m512i mm512_set1_epi32(__m512i* v, int i, std::int32_t x) {
return *v = _mm512_mask_set1_epi32(*v, 1 << i, x);
}
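// Editorial note (not part of the original commit): the two helpers above read
// and write a single 32-bit lane of a 512-bit register at a runtime index.
// AVX-512 has no intrinsic that takes a runtime lane index, hence the
// extract-plus-switch for reads and the one-bit write mask (1 << i) for writes.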
void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
gemmlowp::ScopedProfilingLabel label("Kernel kAvx512");
std::int32_t dst_stride;
if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) ||
(params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) {
dst_stride = params.dst_stride;
} else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
dst_stride = params.dst_stride / sizeof(std::int16_t);
} else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
dst_stride = params.dst_stride / sizeof(std::int32_t);
} else {
RUY_DCHECK(false);
}
int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 16 : 0;
const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
void* dst_col_ptr = params.dst_base_ptr;
const std::int32_t* bias_col_ptr = params.bias;
if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
bias_col_ptr += params.start_row;
}
for (int col = params.start_col; col <= params.last_col; col += 16) {
const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
void* dst_ptr = dst_col_ptr;
const std::int32_t* bias_ptr = bias_col_ptr;
for (int row = params.start_row; row <= params.last_row; row += 16) {
const int residual_rows = std::min(params.dst_rows - row, 16);
const int residual_cols = std::min(params.dst_cols - col, 16);
__m512i accum_data_v[16];
__m512i accum_data_v_low[16];
__m512i accum_data_v_high[16];
// Initialize with bias.
const __mmask16 row_mask =
(static_cast<std::uint32_t>(1) << residual_rows) - 1;
const __m512i initial_accum_data =
_mm512_maskz_loadu_epi32(row_mask, bias_ptr);
__m512i initial_accum_data_low = initial_accum_data;
__m512i initial_accum_data_high = _mm512_setzero_epi32();
bias_ptr += bias_ptr_block_increment;
for (int j = 0; j < 16; ++j) {
accum_data_v_low[j] = initial_accum_data_low;
accum_data_v_high[j] = initial_accum_data_high;
}
//
const std::int8_t* lhs_ptr = lhs_col_ptr;
const std::int8_t* rhs_ptr = rhs_col_ptr;
for (int d = 0; d < params.depth; d += 4) {
const __m512i lhs_data = _mm512_loadu_epi8(lhs_ptr);
__m512i rhs_data = _mm512_loadu_epi8(rhs_ptr);
// Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit.
__m512i lhs_16_bit_low =
_mm512_cvtepi8_epi16(_mm512_cvtepi32_epi16(lhs_data));
// Take bytes 2, 3, 6, 7, 10, 11, ... and expand to 16-bit.
__m512i lhs_16_bit_high = _mm512_cvtepi8_epi16(
_mm512_cvtepi32_epi16(_mm512_srli_epi32(lhs_data, 16)));
for (int j = 0; j < 16; ++j) {
// Mask that drops the 0th element.
static constexpr std::uint16_t shift_mask = 0xfffe;
const __m256i dup_rhs_element_low =
_mm256_broadcastw_epi16(_mm512_castsi512_si128(rhs_data));
// Broadcast the upper 16 bits (bytes 2, 3) of the current rhs element.
const __m256i dup_rhs_element_high = _mm256_set1_epi16(
_mm_extract_epi16(_mm512_castsi512_si128(rhs_data), 1));
// Shift rhs_data, moving next element into 0 position.
rhs_data = _mm512_maskz_compress_epi32(shift_mask, rhs_data);
__m512i rhs_16_bit_dup_low =
_mm512_cvtepi8_epi16(dup_rhs_element_low);
__m512i rhs_16_bit_dup_high =
_mm512_cvtepi8_epi16(dup_rhs_element_high);
accum_data_v_low[j] = _mm512_add_epi32(
accum_data_v_low[j],
_mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
accum_data_v_high[j] = _mm512_add_epi32(
accum_data_v_high[j],
_mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
}
lhs_ptr += 16 * 4;
rhs_ptr += 16 * 4;
}
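// Editorial note (not part of the original commit): each _mm512_madd_epi16
// above multiplies 16-bit values and adds adjacent pairs, so the "low"
// accumulator gathers depth elements 0 and 1 of each 4-element chunk and the
// "high" accumulator gathers elements 2 and 3. Summing the two below yields
// the complete dot product over the chunk.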
for (int j = 0; j < 16; ++j) {
accum_data_v[j] =
_mm512_add_epi32(accum_data_v_low[j], accum_data_v_high[j]);
}
// Move most of this up to the bias handling, or even outside the row loop.
const std::int32_t lhs_zero_point = params.lhs_zero_point;
const std::int32_t rhs_zero_point = params.rhs_zero_point;
const std::int32_t prod_zp_depth = params.prod_zp_depth;
if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
const __m512i lhs_sums_offset =
_mm512_mullo_epi32(_mm512_set1_epi32(rhs_zero_point),
_mm512_loadu_epi32(&params.lhs_sums[row]));
for (int j = 0; j < 16; ++j) {
accum_data_v[j] = _mm512_sub_epi32(accum_data_v[j], lhs_sums_offset);
}
}
if (((params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point) ||
prod_zp_depth) {
__m512i non_lhs_sums_offset =
_mm512_mullo_epi32(_mm512_set1_epi32(lhs_zero_point),
_mm512_loadu_epi32(&params.rhs_sums[col]));
non_lhs_sums_offset = _mm512_sub_epi32(
non_lhs_sums_offset, _mm512_set1_epi32(prod_zp_depth));
for (int j = 0; j < 16; ++j) {
accum_data_v[j] = _mm512_sub_epi32(
accum_data_v[j],
_mm512_set1_epi32(mm512_get1_epi32(non_lhs_sums_offset, j)));
}
}
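// Editorial note (not part of the original commit): the two adjustments above
// implement the usual zero-point expansion
//   sum_d (lhs[r,d] - lhs_zp) * (rhs[c,d] - rhs_zp)
//     = sum_d lhs[r,d] * rhs[c,d]
//       - rhs_zp * lhs_sums[r] - lhs_zp * rhs_sums[c] + depth * lhs_zp * rhs_zp,
// assuming, from the parameter names and signs, that prod_zp_depth equals
// lhs_zero_point * rhs_zero_point * depth.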
//
if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
__m512i m_vector;
__m512i e_vector;
// Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) {
m_vector = _mm512_maskz_loadu_epi32(
row_mask, &params.multiplier_fixedpoint[row]);
e_vector = _mm512_maskz_loadu_epi32(row_mask,
&params.multiplier_exponent[row]);
} else {
// These arrays have size LhsCols, and are pre-filled.
m_vector =
_mm512_maskz_loadu_epi32(row_mask, params.multiplier_fixedpoint);
e_vector =
_mm512_maskz_loadu_epi32(row_mask, params.multiplier_exponent);
}
const __m512i m_64bit_low =
_mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 0));
const __m512i m_64bit_high =
_mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 1));
const __m512i zero_vector = _mm512_setzero_epi32();
const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector);
const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector);
const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector);
const __m512i final_right_shift =
_mm512_add_epi32(right_shift, _mm512_set1_epi32(31));
const __m512i final_right_shift_low = _mm512_cvtepi32_epi64(
_mm512_extracti32x8_epi32(final_right_shift, 0));
const __m512i final_right_shift_high = _mm512_cvtepi32_epi64(
_mm512_extracti32x8_epi32(final_right_shift, 1));
const __m512i offset_vector =
_mm512_slli_epi64(_mm512_set1_epi64(1), 30);
// Really these should be shifted by neg_e_vector, but tests pass when
// using right_shift.
const __m512i offset_vector_low = _mm512_sllv_epi64(
offset_vector,
_mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0)));
const __m512i offset_vector_high = _mm512_sllv_epi64(
offset_vector,
_mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1)));
for (int j = 0; j < 16; ++j) {
accum_data_v[j] = _mm512_sllv_epi32(accum_data_v[j], left_shift);
// Apply the fixed-point part of the multiplier.
__m512i scaled_v_low =
_mm512_mul_epi32(_mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(
accum_data_v[j], 0)),
m_64bit_low);
__m512i scaled_v_high =
_mm512_mul_epi32(_mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(
accum_data_v[j], 1)),
m_64bit_high);
scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low);
scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high);
scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
scaled_v_high =
_mm512_srav_epi64(scaled_v_high, final_right_shift_high);
accum_data_v[j] =
_mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
accum_data_v[j] = _mm512_inserti32x8(
accum_data_v[j], _mm512_cvtepi64_epi32(scaled_v_high), 1);
#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)
RUY_DCHECK(false);
#endif
}
if (params.dst_zero_point) {
__m512i dst_zero_point = _mm512_set1_epi32(params.dst_zero_point);
for (int j = 0; j < 16; ++j) {
accum_data_v[j] = _mm512_add_epi32(accum_data_v[j], dst_zero_point);
}
}
__m512i clamp_max_v = _mm512_set1_epi32(params.clamp_max);
__m512i clamp_min_v = _mm512_set1_epi32(params.clamp_min);
for (int j = 0; j < 16; ++j) {
accum_data_v[j] = _mm512_min_epi32(accum_data_v[j], clamp_max_v);
accum_data_v[j] = _mm512_max_epi32(accum_data_v[j], clamp_min_v);
}
}
const bool store_full_block =
(residual_rows == 16) && (residual_cols == 16);
if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
const int block_col_offset = dst_stride;
if (store_full_block) {
for (int j = 0; j < 16; ++j) {
_mm_storeu_epi8(tmp_ptr, _mm512_cvtepi32_epi8(accum_data_v[j]));
tmp_ptr += block_col_offset;
}
} else {
for (int j = 0; j < residual_cols; ++j) {
_mm_mask_storeu_epi8(tmp_ptr, row_mask,
_mm512_cvtepi32_epi8(accum_data_v[j]));
tmp_ptr += block_col_offset;
}
}
dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + 16);
} else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
const int block_col_offset = dst_stride;
if (store_full_block) {
for (int j = 0; j < 16; ++j) {
_mm_storeu_epi8(tmp_ptr, _mm512_cvtepi32_epi8(accum_data_v[j]));
tmp_ptr += block_col_offset;
}
} else {
for (int j = 0; j < residual_cols; ++j) {
_mm_mask_storeu_epi8(tmp_ptr, row_mask,
_mm512_cvtepi32_epi8(accum_data_v[j]));
tmp_ptr += block_col_offset;
}
}
dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + 16);
} else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
const int block_col_offset = dst_stride;
if (store_full_block) {
for (int j = 0; j < 16; ++j) {
_mm256_storeu_epi16(tmp_ptr,
_mm512_cvtepi32_epi16(accum_data_v[j]));
tmp_ptr += block_col_offset;
}
} else {
for (int j = 0; j < residual_cols; ++j) {
_mm256_mask_storeu_epi16(tmp_ptr, row_mask,
_mm512_cvtepi32_epi16(accum_data_v[j]));
tmp_ptr += block_col_offset;
}
}
dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + 16);
} else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
if (store_full_block) {
std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
const int block_col_offset = dst_stride;
for (int j = 0; j < 16; ++j) {
_mm512_storeu_epi32(tmp_ptr, accum_data_v[j]);
tmp_ptr += block_col_offset;
}
} else {
std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr);
for (int j = 0; j < residual_cols; ++j) {
_mm512_mask_storeu_epi32(dst_block_ptr, row_mask, accum_data_v[j]);
dst_block_ptr += dst_stride;
}
}
dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + 16);
} else {
RUY_DCHECK(false);
}
lhs_col_ptr += 16 * params.lhs_stride;
} // End row-block loop.
dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
16 * params.dst_stride);
rhs_col_ptr += 16 * params.rhs_stride;
} // End col-block loop.
}
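// Editorial sketch (not part of the original commit): ignoring vectorization,
// the requantization performed above is equivalent, per 32-bit accumulator
// lane, to the following scalar computation. It keeps the same simplifications
// as the vector code (e.g. the rounding offset is shifted by right_shift, per
// the comment above). The function name is illustrative only.
inline std::int32_t RequantizeOneLaneForIllustration(
    std::int32_t accum, std::int32_t multiplier_fixedpoint,
    std::int32_t multiplier_exponent, std::int32_t dst_zero_point,
    std::int32_t clamp_min, std::int32_t clamp_max) {
  const int left_shift = multiplier_exponent > 0 ? multiplier_exponent : 0;
  const int right_shift = multiplier_exponent > 0 ? 0 : -multiplier_exponent;
  // 32-bit left shift (wrapping, as in _mm512_sllv_epi32 above), then widen to
  // 64-bit and apply the fixed-point multiplier.
  const std::int32_t shifted = static_cast<std::int32_t>(
      static_cast<std::uint32_t>(accum) << left_shift);
  std::int64_t scaled =
      static_cast<std::int64_t>(shifted) * multiplier_fixedpoint;
  scaled += std::int64_t{1} << (30 + right_shift);  // Rounding offset.
  scaled >>= (31 + right_shift);                    // Rounding right shift.
  std::int32_t result = static_cast<std::int32_t>(scaled) + dst_zero_point;
  result = result > clamp_max ? clamp_max : result;
  result = result < clamp_min ? clamp_min : result;
  return result;
}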
void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) {
gemmlowp::ScopedProfilingLabel label("Kernel kAvx512");
RUY_DCHECK_EQ(16, 16);
// Strides in params are given in bytes; divide by sizeof(float) to get
// element strides.
const std::int64_t lhs_stride = params.lhs_stride >> 2;
const std::int64_t dst_stride = params.dst_stride >> 2;
const std::int64_t rhs_stride = params.rhs_stride >> 2;
int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
const int end_row = std::min(params.dst_rows, params.last_row + 16);
const int end_col = std::min(params.dst_cols, params.last_col + 16);
const float* adj_rhs_col_ptr =
params.rhs_base_ptr - params.start_col * rhs_stride;
float* adj_dst_col_ptr =
params.dst_base_ptr - params.start_col * dst_stride - params.start_row;
const float* adj_lhs_col_ptr =
params.lhs_base_ptr - params.start_row * lhs_stride;
const float* bias_col_ptr = params.bias;
const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max);
const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min);
int col = params.start_col;
for (; col <= end_col - 16; col += 16) {
const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
int row = params.start_row;
for (; row <= end_row - 16; row += 16) {
const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
float* dst_ptr = dst_col_ptr + row;
const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
// Initialize with bias.
const __m512 initial_accum_data = _mm512_loadu_ps(bias_ptr);
// Process block in two halves, split by columns.
{
constexpr int mmm = 0;
__m512 accum_data_v0 = initial_accum_data;
__m512 accum_data_v1 = initial_accum_data;
__m512 accum_data_v2 = initial_accum_data;
__m512 accum_data_v3 = initial_accum_data;
__m512 accum_data_v4 = initial_accum_data;
__m512 accum_data_v5 = initial_accum_data;
__m512 accum_data_v6 = initial_accum_data;
__m512 accum_data_v7 = initial_accum_data;
const float* lhs_ptr = lhs_col_ptr;
const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
for (int d = 0; d < (params.depth - 1); ++d) {
const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
lhs_ptr += 16;
rhs_ptr += 16;
{
const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
accum_data_v0 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
accum_data_v1 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
accum_data_v2 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
accum_data_v3 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
accum_data_v4 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
accum_data_v5 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
accum_data_v6 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
accum_data_v7 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
}
}
{
const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
{
const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
accum_data_v0 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
accum_data_v1 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
accum_data_v2 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
accum_data_v3 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
accum_data_v4 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
accum_data_v5 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
accum_data_v6 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
accum_data_v7 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
}
{
float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
_mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0);
accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
_mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1);
accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
_mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2);
accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
_mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3);
accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
_mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4);
accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
_mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5);
accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
_mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6);
accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
_mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7);
}
}
} // Inner half-block loop, unrolled, first iteration.
{
constexpr int mmm = 1;
__m512 accum_data_v0 = initial_accum_data;
__m512 accum_data_v1 = initial_accum_data;
__m512 accum_data_v2 = initial_accum_data;
__m512 accum_data_v3 = initial_accum_data;
__m512 accum_data_v4 = initial_accum_data;
__m512 accum_data_v5 = initial_accum_data;
__m512 accum_data_v6 = initial_accum_data;
__m512 accum_data_v7 = initial_accum_data;
const float* lhs_ptr = lhs_col_ptr;
const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
for (int d = 0; d < (params.depth - 1); ++d) {
const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
lhs_ptr += 16;
rhs_ptr += 16;
{
const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
accum_data_v0 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
accum_data_v1 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
accum_data_v2 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
accum_data_v3 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
accum_data_v4 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
accum_data_v5 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
accum_data_v6 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
accum_data_v7 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
}
}
{
const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
{
const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
accum_data_v0 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
accum_data_v1 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
accum_data_v2 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
accum_data_v3 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
accum_data_v4 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
accum_data_v5 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
accum_data_v6 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
accum_data_v7 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
}
{
float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
_mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0);
accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
_mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1);
accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
_mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2);
accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
_mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3);
accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
_mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4);
accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
_mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5);
accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
_mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6);
accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
_mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7);
}
}
} // Inner half-block loop, unrolled, second iteration.
} // End row-block loop.
// The unrolling within this conditional may be somewhat pointless. It
// depends on the kinds of models.
if (row < end_row) {
const int residual_rows = end_row - row;
const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
float* dst_ptr = dst_col_ptr + row;
const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
// Initialize with bias.
const __mmask16 row_mask =
(static_cast<std::uint32_t>(1) << residual_rows) - 1;
const __m512 initial_accum_data =
_mm512_maskz_loadu_ps(row_mask, bias_ptr);
// Process block in two halves, split by columns.
for (int mmm = 0; mmm < 2; ++mmm) {
__m512 accum_data_v0 = initial_accum_data;
__m512 accum_data_v1 = initial_accum_data;
__m512 accum_data_v2 = initial_accum_data;
__m512 accum_data_v3 = initial_accum_data;
__m512 accum_data_v4 = initial_accum_data;
__m512 accum_data_v5 = initial_accum_data;
__m512 accum_data_v6 = initial_accum_data;
__m512 accum_data_v7 = initial_accum_data;
const float* lhs_ptr = lhs_col_ptr;
const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
for (int d = 0; d < (params.depth - 1); ++d) {
const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
lhs_ptr += 16;
rhs_ptr += 16;
{
const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
accum_data_v0 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
accum_data_v1 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
accum_data_v2 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
accum_data_v3 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
accum_data_v4 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
accum_data_v5 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
accum_data_v6 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
accum_data_v7 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
}
}
{
const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
{
const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
accum_data_v0 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
accum_data_v1 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
accum_data_v2 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
accum_data_v3 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
accum_data_v4 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
accum_data_v5 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
accum_data_v6 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
accum_data_v7 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
}
{
float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
_mm512_mask_storeu_ps(block_ptr + 0 * dst_stride, row_mask,
accum_data_v0);
accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
_mm512_mask_storeu_ps(block_ptr + 1 * dst_stride, row_mask,
accum_data_v1);
accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
_mm512_mask_storeu_ps(block_ptr + 2 * dst_stride, row_mask,
accum_data_v2);
accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
_mm512_mask_storeu_ps(block_ptr + 3 * dst_stride, row_mask,
accum_data_v3);
accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
_mm512_mask_storeu_ps(block_ptr + 4 * dst_stride, row_mask,
accum_data_v4);
accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
_mm512_mask_storeu_ps(block_ptr + 5 * dst_stride, row_mask,
accum_data_v5);
accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
_mm512_mask_storeu_ps(block_ptr + 6 * dst_stride, row_mask,
accum_data_v6);
accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
_mm512_mask_storeu_ps(block_ptr + 7 * dst_stride, row_mask,
accum_data_v7);
}
}
} // Inner half-block loop.
} // Residual rows, main col-block loop.
} // End col-block loop.
if (col < end_col) {
RUY_DCHECK_GE(end_col - col, 0);
RUY_DCHECK_LT(end_col - col, 16);
__m512 accum_data_v[8];
const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
for (int row = params.start_row; row < end_row; row += 16) {
const int residual_rows = std::min(end_row - row, 16);
const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
float* dst_ptr = dst_col_ptr + row;
const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
// Initialize with bias.
const __mmask16 row_mask =
(static_cast<std::uint32_t>(1) << residual_rows) - 1;
const __m512 initial_accum_data =
_mm512_maskz_loadu_ps(row_mask, bias_ptr);
// Process block in two halves, split by columns.
for (int mmm = 0; mmm < 2; ++mmm) {
for (int j = 0; j < 8; ++j) {
accum_data_v[j] = initial_accum_data;
}
const float* lhs_ptr = lhs_col_ptr;
const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
for (int d = 0; d < params.depth; ++d) {
const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
for (int j = 0; j < 8; ++j) {
const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data[j]);
accum_data_v[j] =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]);
}
lhs_ptr += 16;
rhs_ptr += 16;
}
const int residual_cols = std::min(end_col - col - 8 * mmm, 8);
if (residual_rows == 16) {
if (residual_cols == 8) {
for (int j = 0; j < 8; ++j) {
float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v);
accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v);
_mm512_storeu_ps(block_ptr, accum_data_v[j]);
}
} else {
for (int j = 0; j < residual_cols; ++j) {
float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v);
accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v);
_mm512_storeu_ps(block_ptr, accum_data_v[j]);
}
}
} else {
for (int j = 0; j < residual_cols; ++j) {
float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v);
accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v);
_mm512_mask_storeu_ps(block_ptr, row_mask, accum_data_v[j]);
}
}
} // Inner half-block loop.
} // End row-block loop.
} // Residual cols.
}
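// Editorial sketch (not part of the original commit): stripped of blocking,
// unrolling and masking, the float kernel above computes, for each element of
// its (up to) 16x16 destination block,
//   dst[row, col] = clamp(bias[row] + sum_d lhs[row, d] * rhs[col, d]).
// A scalar reference for the full-block case, assuming the packed buffers
// store 16 consecutive values per depth step, as the loads above do:
inline void ReferenceFloatKernelForIllustration(
    const float* packed_lhs, const float* packed_rhs, const float* bias,
    int depth, float clamp_min, float clamp_max, float* dst, int dst_stride) {
  for (int col = 0; col < 16; ++col) {
    for (int row = 0; row < 16; ++row) {
      float accum = bias[row];
      for (int d = 0; d < depth; ++d) {
        accum += packed_lhs[d * 16 + row] * packed_rhs[d * 16 + col];
      }
      accum = accum > clamp_max ? clamp_max : accum;
      accum = accum < clamp_min ? clamp_min : accum;
      dst[col * dst_stride + row] = accum;
    }
  }
}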
#endif // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
} // namespace ruy


@@ -98,6 +98,7 @@ struct PackedTypeImpl {
using Type = Scalar;
};
#if RUY_PLATFORM(NEON)
template <>
struct PackedTypeImpl<Path::kNeon, std::uint8_t> {
using Type = std::int8_t;
@@ -106,6 +107,12 @@ template <>
struct PackedTypeImpl<Path::kNeonDotprod, std::uint8_t> {
using Type = std::int8_t;
};
#elif RUY_PLATFORM(AVX512)
template <>
struct PackedTypeImpl<Path::kAvx512, std::uint8_t> {
using Type = std::int8_t;
};
#endif
template <Path ThePath, typename Scalar>
using PackedType = typename PackedTypeImpl<ThePath, Scalar>::Type;
@@ -155,10 +162,14 @@ struct PackImpl<Path::kStandardCpp, FixedKernelLayout, Scalar, PackedScalar,
}
};
#if RUY_PLATFORM(NEON)
RUY_INHERIT_PACK(Path::kStandardCpp, Path::kNeon)
#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)
RUY_INHERIT_PACK(Path::kNeon, Path::kNeonDotprod)
#endif
#elif RUY_PLATFORM(AVX512)
RUY_INHERIT_PACK(Path::kStandardCpp, Path::kAvx512)
#endif
#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)
void Pack8bitNeonOutOfOrder(const void* src_ptr0, const void* src_ptr1,
@@ -478,6 +489,92 @@ struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kRowMajor, 1, 4>, float,
#endif // (RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) && \
// RUY_OPT_ENABLED(RUY_OPT_ASM)
#if RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
// Note that the source and zero buffers can be of uint8 type, but in the
// packing function they are reinterpreted as int8 and XOR-ed with input_xor.
void Pack8bitAvx512(const std::int8_t* src_ptr, std::int8_t input_xor,
const std::int8_t* zerobuf, int src_stride,
int remaining_src_cols, int src_rows,
std::int8_t* packed_ptr, std::int32_t* sums_ptr);
template <typename Scalar>
struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
Scalar, std::int8_t, std::int32_t> {
static_assert(std::is_same<Scalar, std::int8_t>::value ||
std::is_same<Scalar, std::uint8_t>::value,
"");
using Layout = FixedKernelLayout<Order::kColMajor, 4, 16>;
static constexpr int kHalfLayoutCols =
8; // Half the number of cols in a block.
static constexpr std::int8_t kInputXor =
std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
static void Run(Tuning tuning, const Matrix<Scalar>& src_matrix,
PackedMatrix<std::int8_t>* packed_matrix, int start_col,
int end_col) {
gemmlowp::ScopedProfilingLabel label("Pack (AVX-512)");
RUY_DCHECK(IsColMajor(src_matrix.layout));
RUY_DCHECK(IsColMajor(packed_matrix->layout));
RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
RUY_DCHECK_EQ(kHalfLayoutCols * 2, Layout::kCols);
std::int32_t* sums = packed_matrix->sums;
Scalar zerobuf[kHalfLayoutCols * Layout::kRows];
memset(zerobuf, packed_matrix->zero_point ^ kInputXor,
kHalfLayoutCols * Layout::kRows * sizeof(Scalar));
for (int block_col = start_col; block_col < end_col;
block_col += Layout::kCols) {
std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
int src_stride = src_matrix.layout.stride;
const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col;
int remaining_src_cols = src_matrix.layout.cols - block_col;
static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits.
std::int8_t* packed_ptr =
packed_matrix->data +
packed_matrix->layout.stride * (block_col & block_col_mask);
Pack8bitAvx512(reinterpret_cast<const std::int8_t*>(src_ptr), kInputXor,
reinterpret_cast<const std::int8_t*>(zerobuf), src_stride,
remaining_src_cols, src_matrix.layout.rows, packed_ptr,
sums_ptr);
}
}
};
void PackFloatAvx512(const float* src_ptr, const float* zerobuf, int src_stride,
int remaining_src_cols, int src_rows, float* packed_ptr);
template <>
struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kRowMajor, 1, 16>,
float, float, float> {
static void Run(Tuning, const Matrix<float>& src_matrix,
PackedMatrix<float>* packed_matrix, int start_col,
int end_col) {
using Layout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
RUY_DCHECK(IsColMajor(src_matrix.layout));
RUY_DCHECK(IsColMajor(packed_matrix->layout));
RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
const float zerobuf[Layout::kCols] = {
0.0f}; // Remainder default inits to 0.0f.
for (int block_col = start_col; block_col < end_col;
block_col += Layout::kCols) {
int src_stride = src_matrix.layout.stride;
const float* src_ptr = src_matrix.data.get() + src_stride * block_col;
int remaining_src_cols = src_matrix.layout.cols - block_col;
static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits.
float* packed_ptr =
packed_matrix->data +
packed_matrix->layout.stride * (block_col & block_col_mask);
PackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
src_matrix.layout.rows, packed_ptr);
}
}
};
#endif // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
// Main entry point for packing.
template <Path ThePath, typename FixedKernelLayout, typename Scalar,
typename PackedScalar>


@@ -0,0 +1,531 @@
/* Copyright 2019 Google LLC. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/ruy/pack.h"
#include "tensorflow/lite/experimental/ruy/platform.h"
namespace ruy {
#if RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
// The first int8_t template parameter is arbitrary: this routine is common to
// all 8-bit source matrix types.
using PackImpl8bitAvx512 =
PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
std::int8_t, std::int8_t, std::int32_t>;
namespace {
inline void ZeroHalf8bitAvx512(int src_rows, std::int8_t packed_zero_point,
std::int8_t* packed_ptr) {
using Layout = PackImpl8bitAvx512::Layout;
static constexpr int kHalfLayoutCols =
PackImpl8bitAvx512::kHalfLayoutCols; // Half the number of cols in a
// block.
RUY_DCHECK_EQ(kHalfLayoutCols, 8);
RUY_DCHECK_EQ(Layout::kCols, 16);
RUY_DCHECK_EQ(Layout::kRows, 4);
const int non_trailing_blocks = (src_rows & ~31) >> 2;
// This routine fills half blocks, and typically fills the second halves.
// Thus packed_ptr is already offset by 8 * 4.
for (int k = 0; k < non_trailing_blocks; ++k) {
for (int j = 0; j < (kHalfLayoutCols * Layout::kRows); ++j) {
packed_ptr[Layout::kCols * Layout::kRows * k + j] = packed_zero_point;
}
}
}
inline void HalfPack8bitAvx512(const std::int8_t* src_ptr,
std::int8_t input_xor,
const std::int8_t* zerobuf, int src_stride,
int remaining_src_cols, int src_rows,
std::int8_t* packed_ptr, std::int32_t* sums_ptr,
std::int8_t* trailing_buf) {
using Layout = PackImpl8bitAvx512::Layout;
static constexpr int kHalfLayoutCols =
PackImpl8bitAvx512::kHalfLayoutCols; // Half the number of cols in a
// block.
RUY_DCHECK_EQ(Layout::kCols, 16);
RUY_DCHECK_EQ(Layout::kRows, 4);
RUY_DCHECK_EQ(kHalfLayoutCols, 8);
std::int8_t in_data[kHalfLayoutCols][kHalfLayoutCols][Layout::kCols];
const std::int8_t* src_ptr0 = src_ptr;
const std::int8_t* src_ptr1 = src_ptr0 + src_stride;
const std::int8_t* src_ptr2 = src_ptr1 + src_stride;
const std::int8_t* src_ptr3 = src_ptr2 + src_stride;
const std::int8_t* src_ptr4 = src_ptr3 + src_stride;
const std::int8_t* src_ptr5 = src_ptr4 + src_stride;
const std::int8_t* src_ptr6 = src_ptr5 + src_stride;
const std::int8_t* src_ptr7 = src_ptr6 + src_stride;
// Each group of Layout::kRows (4) source rows is contiguous in the input and
// stays contiguous in the packed output.
// We process 8 of these chunks at a time, padding short input chunks.
constexpr int kNumRowChunks = 8;
constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows;
std::int64_t src_inc0 = kNumChunkedSrcRows;
std::int64_t src_inc1 = kNumChunkedSrcRows;
std::int64_t src_inc2 = kNumChunkedSrcRows;
std::int64_t src_inc3 = kNumChunkedSrcRows;
std::int64_t src_inc4 = kNumChunkedSrcRows;
std::int64_t src_inc5 = kNumChunkedSrcRows;
std::int64_t src_inc6 = kNumChunkedSrcRows;
std::int64_t src_inc7 = kNumChunkedSrcRows;
// Handle cases where source does not have kHalfLayoutCols (8) columns.
if (remaining_src_cols < 8) {
if (remaining_src_cols <= 0) {
src_ptr0 = zerobuf;
src_inc0 = 0;
}
if (remaining_src_cols <= 1) {
src_ptr1 = zerobuf;
src_inc1 = 0;
}
if (remaining_src_cols <= 2) {
src_ptr2 = zerobuf;
src_inc2 = 0;
}
if (remaining_src_cols <= 3) {
src_ptr3 = zerobuf;
src_inc3 = 0;
}
if (remaining_src_cols <= 4) {
src_ptr4 = zerobuf;
src_inc4 = 0;
}
if (remaining_src_cols <= 5) {
src_ptr5 = zerobuf;
src_inc5 = 0;
}
if (remaining_src_cols <= 6) {
src_ptr6 = zerobuf;
src_inc6 = 0;
}
src_ptr7 = zerobuf;
src_inc7 = 0;
}
const std::int8_t zero_point = zerobuf[0];
if (sums_ptr) {
// i: kHalfLayoutCols.
for (int i = 0; i < 8; ++i) {
sums_ptr[i] = 0;
}
}
// The overall packing effectively pads the source rows to
// (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we
// only pack for (src_rows + 31) & ~31. When there is an incomplete
// destination block, this is stored into trailing_buf instead of packed_ptr.
for (int k = 0; k < src_rows; k += 2 * kNumChunkedSrcRows) {
// m: {0, 1} for 2 chunks of rows.
for (int m = 0; m < 2; ++m) {
// Available source rows.
// If this is less than 0 (for m=1), we skip, having filled trailing
// buffer for m=0. Also, if source rows is zero on m=1, then we filled
// exactly to the end of the column in the packed buffer.
const int available_src_rows = src_rows - k - m * kNumChunkedSrcRows;
// Effectively,
// available_src_rows = std::max(0, std::min(32, src_rows - k - 32 * m));
// but treat each case separately.
if (available_src_rows >= kNumChunkedSrcRows) {
// i: chunks, s: Layout::Rows.
for (int i = 0; i < 8; ++i) {
for (int s = 0; s < 4; ++s) {
in_data[0][i][s] = src_ptr0[i * 4 + s];
in_data[1][i][s] = src_ptr1[i * 4 + s];
in_data[2][i][s] = src_ptr2[i * 4 + s];
in_data[3][i][s] = src_ptr3[i * 4 + s];
in_data[4][i][s] = src_ptr4[i * 4 + s];
in_data[5][i][s] = src_ptr5[i * 4 + s];
in_data[6][i][s] = src_ptr6[i * 4 + s];
in_data[7][i][s] = src_ptr7[i * 4 + s];
}
}
// i: chunks, j: kHalfLayoutCols, s: Layout::Rows.
for (int i = 0; i < 8; ++i) {
for (int j = 0; j < 8; ++j) {
for (int s = 0; s < 4; ++s) {
// 16 * 4 * i is offset for each block, that is
// (Layout::kCols * Layout::kRows * i)
packed_ptr[(16 * i + j) * 4 + s] = in_data[j][i][s] ^ input_xor;
}
if (sums_ptr) {
for (int s = 0; s < 4; ++s) {
sums_ptr[j] += in_data[j][i][s] ^ input_xor;
}
}
}
}
} else if (available_src_rows > 0) {
RUY_DCHECK_LT(available_src_rows >> 2, kNumChunkedSrcRows);
int i = 0;
// Consume chunks of 4 rows that are complete.
for (; i < (available_src_rows >> 2); ++i) {
for (int s = 0; s < 4; ++s) {
in_data[0][i][s] = src_ptr0[i * 4 + s];
in_data[1][i][s] = src_ptr1[i * 4 + s];
in_data[2][i][s] = src_ptr2[i * 4 + s];
in_data[3][i][s] = src_ptr3[i * 4 + s];
in_data[4][i][s] = src_ptr4[i * 4 + s];
in_data[5][i][s] = src_ptr5[i * 4 + s];
in_data[6][i][s] = src_ptr6[i * 4 + s];
in_data[7][i][s] = src_ptr7[i * 4 + s];
}
}
// Consume any incomplete chunk.
if (i < ((available_src_rows + 3) >> 2)) {
int s = 0;
for (; s < (available_src_rows & 3); ++s) {
in_data[0][i][s] = src_ptr0[i * 4 + s];
in_data[1][i][s] = src_ptr1[i * 4 + s];
in_data[2][i][s] = src_ptr2[i * 4 + s];
in_data[3][i][s] = src_ptr3[i * 4 + s];
in_data[4][i][s] = src_ptr4[i * 4 + s];
in_data[5][i][s] = src_ptr5[i * 4 + s];
in_data[6][i][s] = src_ptr6[i * 4 + s];
in_data[7][i][s] = src_ptr7[i * 4 + s];
}
RUY_DCHECK_LE(s, 4);
for (; s < 4; ++s) {
// j: kHalfLayoutCols.
for (int j = 0; j < 8; ++j) {
in_data[j][i][s] = zero_point;
}
}
++i;
}
// We do not care what goes into the trailing buffer, but we want
// in_data[...] ^ input_xor == 0 for irrelevant values in the summation.
//
// It might prove better in optimized code to pad uniformly with
// zero_point, and compensate by initializing the summations with the
// compensating offset, effectively
// ((input_xor - zero_point) ^ input_xor) *
// 4 * (8 - ((available_src_rows + 3) >> 2)).
for (; i < 8; ++i) {
for (int s = 0; s < 4; ++s) {
for (int j = 0; j < 8; ++j) {
in_data[j][i][s] = input_xor;
}
}
}
// We loop through [0, 8) rather than
// [0, (available_src_rows + 3) >> 2), since that emulates what we might
// do in fully-optimized code.
//
// i: chunks, j: kHalfLayoutCols, s: Layout::Rows.
if (sums_ptr) {
for (int i = 0; i < 8; ++i) {
for (int j = 0; j < 8; ++j) {
for (int s = 0; s < 4; ++s) {
trailing_buf[(16 * i + j) * 4 + s] =
in_data[j][i][s] ^ input_xor;
sums_ptr[j] = sums_ptr[j] + (in_data[j][i][s] ^ input_xor);
}
}
}
} else {
for (int i = 0; i < 8; ++i) {
for (int j = 0; j < 8; ++j) {
for (int s = 0; s < 4; ++s) {
trailing_buf[(16 * i + j) * 4 + s] =
in_data[j][i][s] ^ input_xor;
}
}
}
}
}
packed_ptr += 16 * kNumChunkedSrcRows;
src_ptr0 += src_inc0;
src_ptr1 += src_inc1;
src_ptr2 += src_inc2;
src_ptr3 += src_inc3;
src_ptr4 += src_inc4;
src_ptr5 += src_inc5;
src_ptr6 += src_inc6;
src_ptr7 += src_inc7;
}
}
}
inline __m512 LoaduTwo(const float* addr_lo, const float* addr_hi) {
__m512 lower_filled = _mm512_castps256_ps512(_mm256_loadu_ps(addr_lo));
return _mm512_insertf32x8(lower_filled, _mm256_loadu_ps(addr_hi), 1);
}
inline __m512 MaskLoaduTwo(__mmask8 row_mask, const float* addr_lo,
const float* addr_hi) {
__m512 lower_filled =
_mm512_castps256_ps512(_mm256_maskz_loadu_ps(row_mask, addr_lo));
return _mm512_insertf32x8(lower_filled,
_mm256_maskz_loadu_ps(row_mask, addr_hi), 1);
}
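// Editorial note (not part of the original commit): LoaduTwo and MaskLoaduTwo
// concatenate two unaligned 8-float loads into one 512-bit register (addr_lo
// in the lower 256 bits, addr_hi in the upper 256 bits); the masked variant
// zero-fills lanes past the end of a short source column.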
inline void HalfPackFloatAvx512(const float* src_ptr, const float* zerobuf,
int src_stride, int remaining_src_cols,
int src_rows, float* packed_ptr,
float* trailing_buf) {
const float* src_ptr0 = src_ptr;
const float* src_ptr1 = src_ptr0 + src_stride;
const float* src_ptr2 = src_ptr1 + src_stride;
const float* src_ptr3 = src_ptr2 + src_stride;
const float* src_ptr4 = src_ptr3 + src_stride;
const float* src_ptr5 = src_ptr4 + src_stride;
const float* src_ptr6 = src_ptr5 + src_stride;
const float* src_ptr7 = src_ptr6 + src_stride;
std::int64_t src_inc0 = 8;
std::int64_t src_inc1 = 8;
std::int64_t src_inc2 = 8;
std::int64_t src_inc3 = 8;
std::int64_t src_inc4 = 8;
std::int64_t src_inc5 = 8;
std::int64_t src_inc6 = 8;
std::int64_t src_inc7 = 8;
if (remaining_src_cols < 8) {
if (remaining_src_cols <= 0) {
src_ptr0 = zerobuf;
src_inc0 = 0;
}
if (remaining_src_cols <= 1) {
src_ptr1 = zerobuf;
src_inc1 = 0;
}
if (remaining_src_cols <= 2) {
src_ptr2 = zerobuf;
src_inc2 = 0;
}
if (remaining_src_cols <= 3) {
src_ptr3 = zerobuf;
src_inc3 = 0;
}
if (remaining_src_cols <= 4) {
src_ptr4 = zerobuf;
src_inc4 = 0;
}
if (remaining_src_cols <= 5) {
src_ptr5 = zerobuf;
src_inc5 = 0;
}
if (remaining_src_cols <= 6) {
src_ptr6 = zerobuf;
src_inc6 = 0;
}
src_ptr7 = zerobuf;
src_inc7 = 0;
}
for (int k = 0; k < src_rows; k += 16) {
for (int m = 0; m < 2; ++m) {
const int available_src_rows = src_rows - k - 8 * m;
// Effectively,
// available_src_rows = std::max(0, std::min(8, src_rows - k - 8 * m));
// but treat each case separately.
if (available_src_rows > 7) {
__m512i t0, t1, t2, t3;
__m512i r0, r1, r2, r3;
t0 = LoaduTwo(src_ptr0, src_ptr4);
t1 = LoaduTwo(src_ptr1, src_ptr5);
t2 = LoaduTwo(src_ptr2, src_ptr6);
t3 = LoaduTwo(src_ptr3, src_ptr7);
r0 = _mm512_unpacklo_epi32(t0, t1);
r2 = _mm512_unpackhi_epi32(t0, t1);
r1 = _mm512_unpacklo_epi32(t2, t3);
r3 = _mm512_unpackhi_epi32(t2, t3);
t0 = _mm512_unpacklo_epi64(r0, r1);
t2 = _mm512_unpackhi_epi64(r0, r1);
t1 = _mm512_unpacklo_epi64(r2, r3);
t3 = _mm512_unpackhi_epi64(r2, r3);
r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
r2 = _mm512_shuffle_i32x4(t2, t3, 0x88);
r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd);
_mm256_storeu_epi32(packed_ptr + 0 * 16, _mm512_castsi512_si256(r0));
_mm256_storeu_epi32(packed_ptr + 2 * 16,
_mm512_extracti64x4_epi64(r0, 1));
_mm256_storeu_epi32(packed_ptr + 4 * 16, _mm512_castsi512_si256(r1));
_mm256_storeu_epi32(packed_ptr + 6 * 16,
_mm512_extracti64x4_epi64(r1, 1));
_mm256_storeu_epi32(packed_ptr + 1 * 16, _mm512_castsi512_si256(r2));
_mm256_storeu_epi32(packed_ptr + 3 * 16,
_mm512_extracti64x4_epi64(r2, 1));
_mm256_storeu_epi32(packed_ptr + 5 * 16, _mm512_castsi512_si256(r3));
_mm256_storeu_epi32(packed_ptr + 7 * 16,
_mm512_extracti64x4_epi64(r3, 1));
} else if (available_src_rows > 0) {
const __mmask8 row_mask =
(static_cast<std::uint32_t>(1) << available_src_rows) - 1;
__m512i t0, t1, t2, t3;
__m512i r0, r1, r2, r3;
t0 = MaskLoaduTwo(row_mask, src_ptr0, src_ptr4);
t1 = MaskLoaduTwo(row_mask, src_ptr1, src_ptr5);
t2 = MaskLoaduTwo(row_mask, src_ptr2, src_ptr6);
t3 = MaskLoaduTwo(row_mask, src_ptr3, src_ptr7);
r0 = _mm512_unpacklo_epi32(t0, t1);
r2 = _mm512_unpackhi_epi32(t0, t1);
r1 = _mm512_unpacklo_epi32(t2, t3);
r3 = _mm512_unpackhi_epi32(t2, t3);
t0 = _mm512_unpacklo_epi64(r0, r1);
t2 = _mm512_unpackhi_epi64(r0, r1);
t1 = _mm512_unpacklo_epi64(r2, r3);
t3 = _mm512_unpackhi_epi64(r2, r3);
r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
r2 = _mm512_shuffle_i32x4(t2, t3, 0x88);
r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd);
_mm256_storeu_epi32(trailing_buf + 0 * 16, _mm512_castsi512_si256(r0));
_mm256_storeu_epi32(trailing_buf + 2 * 16,
_mm512_extracti64x4_epi64(r0, 1));
_mm256_storeu_epi32(trailing_buf + 4 * 16, _mm512_castsi512_si256(r1));
_mm256_storeu_epi32(trailing_buf + 6 * 16,
_mm512_extracti64x4_epi64(r1, 1));
_mm256_storeu_epi32(trailing_buf + 1 * 16, _mm512_castsi512_si256(r2));
_mm256_storeu_epi32(trailing_buf + 3 * 16,
_mm512_extracti64x4_epi64(r2, 1));
_mm256_storeu_epi32(trailing_buf + 5 * 16, _mm512_castsi512_si256(r3));
// Do not store _mm512_extracti64x4_epi64(r3, 1).
}
packed_ptr += 16 * 8;
src_ptr0 += src_inc0;
src_ptr1 += src_inc1;
src_ptr2 += src_inc2;
src_ptr3 += src_inc3;
src_ptr4 += src_inc4;
src_ptr5 += src_inc5;
src_ptr6 += src_inc6;
src_ptr7 += src_inc7;
}
}
}
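// Editorial note (not part of the original commit): the unpacklo/unpackhi and
// _mm512_shuffle_i32x4 sequence above is, in effect, an 8x8 float transpose
// applied to both 256-bit halves at once. After it, the 8-float group stored
// at packed_ptr + r * 16 holds source row (k + 8 * m + r) across the 8 source
// columns handled by this call; the remaining 8 lanes of each packed row are
// written by the second HalfPackFloatAvx512 call (packed_ptr + 8 in
// PackFloatAvx512 below).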
inline void ZeroHalfFloatAvx512(int src_rows, float* packed_ptr) {
const int non_trailing_rows = src_rows & ~7;
for (int k = 0; k < non_trailing_rows; ++k) {
for (int j = 0; j < 8; ++j) {
packed_ptr[j] = 0.0f;
}
packed_ptr += 16;
}
}
} // namespace.
void Pack8bitAvx512(const std::int8_t* src_ptr, std::int8_t input_xor,
const std::int8_t* zerobuf, int src_stride,
int remaining_src_cols, int src_rows,
std::int8_t* packed_ptr, std::int32_t* sums_ptr) {
gemmlowp::ScopedProfilingLabel label("Pack kAvx512 8bit");
using Layout = PackImpl8bitAvx512::Layout;
constexpr int kHalfBlockOffset = 32;
RUY_DCHECK_EQ(kHalfBlockOffset * 2, Layout::kRows * Layout::kCols);
static constexpr int kHalfLayoutCols =
PackImpl8bitAvx512::kHalfLayoutCols; // Half the number of cols in a
// block.
RUY_DCHECK_EQ(kHalfLayoutCols, 8);
RUY_DCHECK_EQ(Layout::kCols, 16);
RUY_DCHECK_EQ(Layout::kRows, 4);
// Each group of Layout::kRows (4) source rows is contiguous in the input and
// stays contiguous in the packed output.
// We process 8 of these chunks at a time, padding short input chunks.
constexpr int kNumRowChunks = 8;
// Each packed block is 4*16, and there are normally 8. The trailing block is
// only slightly shorter.
constexpr int kTrailingBufSize =
kNumRowChunks * Layout::kCols * Layout::kRows;
std::int8_t trailing_buf[kTrailingBufSize];
memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t));
std::int32_t* second_sums_ptr =
sums_ptr ? sums_ptr + kHalfLayoutCols : nullptr;
if (remaining_src_cols > kHalfLayoutCols) {
HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride,
remaining_src_cols, src_rows, packed_ptr, sums_ptr,
trailing_buf);
HalfPack8bitAvx512(src_ptr + src_stride * kHalfLayoutCols, input_xor,
zerobuf, src_stride,
remaining_src_cols - kHalfLayoutCols, src_rows,
packed_ptr + kHalfBlockOffset, second_sums_ptr,
trailing_buf + kHalfBlockOffset);
} else {
HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride,
remaining_src_cols, src_rows, packed_ptr, sums_ptr,
trailing_buf);
ZeroHalf8bitAvx512(src_rows, zerobuf[0] ^ input_xor,
packed_ptr + kHalfBlockOffset);
// The kernel may not need the second half-blocks sums to be set.
if (second_sums_ptr) {
for (int i = 0; i < kHalfLayoutCols; ++i) {
second_sums_ptr[i] = (zerobuf[0] ^ input_xor) * ((src_rows + 3) & ~3);
}
}
}
constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1;
const bool trailing_data = (src_rows & kChunkedRowMask) > 0;
// If the number of source rows is not a multiple of kChunkedRowMask + 1, there
// will be data in the trailing buffer that still needs to be copied into the
// packed output.
if (trailing_data) {
const int non_trailing_rows = src_rows & ~kChunkedRowMask;
// Destination "rows" are padded to next highest multiple of Layout::kRows.
const int dst_rows = (src_rows + 3) & ~3;
const int trailing_rows = dst_rows - non_trailing_rows;
memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf,
Layout::kCols * trailing_rows * sizeof(std::int8_t));
}
}
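// Editorial note (not part of the original commit): the packed 8-bit layout
// produced above stores, per group of 4 source rows, all 16 columns
// contiguously: packed_ptr[(16 * i + j) * 4 + s] holds the source element at
// row offset 4 * i + s (within the current 32-row chunk) and column j,
// XOR-ed with input_xor, with the second half-block call covering columns
// 8..15 via the +kHalfBlockOffset offset. When sums_ptr is provided,
// sums_ptr[j] accumulates column j over the padded rows; the kernel consumes
// these sums for its zero-point corrections
// (RUY_ASM_FLAG_HAS_LHS_SUMS / RUY_ASM_FLAG_HAS_RHS_SUMS).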
void PackFloatAvx512(const float* src_ptr, const float* zerobuf, int src_stride,
int remaining_src_cols, int src_rows, float* packed_ptr) {
gemmlowp::ScopedProfilingLabel label("Pack kAvx512 float");
float trailing_buf[7 * 16];
if (remaining_src_cols > 8) {
HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
src_rows, packed_ptr, trailing_buf);
HalfPackFloatAvx512(src_ptr + src_stride * 8, zerobuf, src_stride,
remaining_src_cols - 8, src_rows, packed_ptr + 8,
trailing_buf + 8);
} else {
memset(trailing_buf, 0, sizeof(trailing_buf));
HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
src_rows, packed_ptr, trailing_buf);
ZeroHalfFloatAvx512(src_rows, packed_ptr + 8);
}
const int trailing_rows = src_rows & 7;
if (trailing_rows > 0) {
const int non_trailing_rows = src_rows & ~7;
memcpy(packed_ptr + 16 * non_trailing_rows, trailing_buf,
16 * trailing_rows * sizeof(float));
}
}
#endif // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
} // namespace ruy


@@ -51,6 +51,11 @@ namespace ruy {
// given base architecture (such as ARM). Higher values of this enum correspond
// to "better" code paths within a given base architecture for which Ruy has
// optimized code paths.
//
// Values are reused across architectures.
// Rationale: to scale better to N architectures, it is good to keep values
// small, both for the compile-time logic that selects paths and for manually
// spelling out Path values, such as when invoking a test or benchmark.
enum class Path : std::uint8_t {
// This is a special null value, representing the absence of any path.
kNone = 0,
@@ -66,11 +71,19 @@ enum class Path : std::uint8_t {
//
// This is intended for testing/development.
kStandardCpp = 0x2,
//
// ARM architectures.
//
// Optimized path using a widely available subset of ARM NEON instructions.
kNeon = 0x4,
// Optimized path making use of ARM NEON dot product instructions that are
// available on newer ARM cores.
kNeonDotprod = 0x8,
//
// x86 architectures.
//
// Optimized for AVX-512.
kAvx512 = 0x4,
};
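// For example (editorial note, not part of this commit): kAvx512 shares the
// value 0x4 with the ARM-only kNeon. This is safe because a given build only
// compiles the paths of one base architecture, so the two names never appear
// in the same bit set (see kAllPaths below).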
inline constexpr Path operator|(Path p, Path q) {
@@ -104,6 +117,11 @@ constexpr Path kAllPaths =
Path::kReference | Path::kStandardCpp | Path::kNeon | Path::kNeonDotprod;
#elif RUY_PLATFORM(NEON_32)
constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp | Path::kNeon;
#elif RUY_PLATFORM(AVX512)
// TODO(b/138433137): kAllPaths should always contain kAvx512 regardless of
// whether AVX-512 is enabled in the translation unit #including this header.
constexpr Path kAllPaths =
Path::kReference | Path::kStandardCpp | Path::kAvx512;
#else
constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp;
#endif
@@ -111,6 +129,9 @@ constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp;
// We don't know how to do runtime dotprod detection outside of linux for now.
#if RUY_PLATFORM(NEON)
constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp | Path::kNeon;
#elif RUY_PLATFORM(AVX512)
constexpr Path kAllPaths =
Path::kReference | Path::kStandardCpp | Path::kAvx512;
#else
constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp;
#endif


@@ -49,6 +49,17 @@ limitations under the License.
#define RUY_DONOTUSEDIRECTLY_NEON_64 \
(RUY_DONOTUSEDIRECTLY_NEON && RUY_DONOTUSEDIRECTLY_ARM_64)
// These CPU capability macros will all be defined when compiling for a
// Skylake-class AVX-512 target.
//
// TODO(b/138433137) Select AVX-512 at runtime rather than via compile options.
#if defined(__AVX512F__) && defined(__AVX512DQ__) && defined(__AVX512CD__) && \
defined(__AVX512BW__) && defined(__AVX512VL__)
#define RUY_DONOTUSEDIRECTLY_AVX512 1
#else
#define RUY_DONOTUSEDIRECTLY_AVX512 0
#endif
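// Editorial note (not part of the original commit): with GCC or Clang these
// macros are defined when compiling with, e.g., -march=skylake-avx512, or with
// the individual flags -mavx512f -mavx512dq -mavx512cd -mavx512bw -mavx512vl.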
// Detect APPLE
#ifdef __APPLE__
#define RUY_DONOTUSEDIRECTLY_APPLE 1


@@ -66,8 +66,12 @@ const char* PathName(Path path) {
switch (path) {
RUY_PATHNAME_CASE(kReference)
RUY_PATHNAME_CASE(kStandardCpp)
#if RUY_PLATFORM(NEON)
RUY_PATHNAME_CASE(kNeon)
RUY_PATHNAME_CASE(kNeonDotprod)
#elif RUY_PLATFORM(AVX512)
RUY_PATHNAME_CASE(kAvx512)
#endif
default:
RUY_CHECK(false);
return nullptr;
@@ -245,7 +249,7 @@ struct RandomRangeBounds<Scalar, false> {
inline std::default_random_engine& global_random_engine() {
static std::default_random_engine engine;
return engine;
};
}
template <typename Scalar>
struct UniformRandomDistribution {
@@ -660,7 +664,7 @@ void EvalGemmlowp(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
&GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst,
-lhs.zero_point, -rhs.zero_point, output_pipeline);
} else
} else // NOLINT[readability/braces]
#endif
{
const auto& output_pipeline =
@@ -680,7 +684,7 @@ void EvalGemmlowp(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
&GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst,
-lhs.zero_point, -rhs.zero_point, output_pipeline);
} else
} else // NOLINT[readability/braces]
#endif
{
const auto& output_pipeline = std::make_tuple(