Ruy: Rough optimization of x86 AVX2 float kernel.
PiperOrigin-RevId: 270284629
commit 52084c979b
parent 97cfacbca3
@@ -46,6 +46,82 @@ static constexpr int kAvxFloatBlockSize = 8;
static constexpr int kAvx8bitBlockSize = 8;
static constexpr int kAvx8bitInnerSize = 4;

namespace {
inline float mm256_get1_ps(const __m256 a, int i) {
  __m256i ai = _mm256_castps_si256(a);
  int float_val_as_int;
  switch (i) {
    case 0:
      float_val_as_int = _mm256_extract_epi32(ai, 0);
      break;
    case 1:
      float_val_as_int = _mm256_extract_epi32(ai, 1);
      break;
    case 2:
      float_val_as_int = _mm256_extract_epi32(ai, 2);
      break;
    case 3:
      float_val_as_int = _mm256_extract_epi32(ai, 3);
      break;
    case 4:
      float_val_as_int = _mm256_extract_epi32(ai, 4);
      break;
    case 5:
      float_val_as_int = _mm256_extract_epi32(ai, 5);
      break;
    case 6:
      float_val_as_int = _mm256_extract_epi32(ai, 6);
      break;
    case 7:
      float_val_as_int = _mm256_extract_epi32(ai, 7);
      break;
    default:
      RUY_DCHECK_LT(i, 8);
      return .0f;
  }
  return reinterpret_cast<float&>(float_val_as_int);
}
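
// Note on the helper above: _mm256_extract_epi32 requires a compile-time
// immediate for the lane index, hence the switch over the runtime `i`. The
// final reinterpret_cast type-puns the int bits back to float; an
// aliasing-safe sketch of that last step (an alternative, not the committed
// code) would be:
//   float f;
//   std::memcpy(&f, &float_val_as_int, sizeof(f));  // needs <cstring>
//   return f;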

inline __m256 mm256_n_loadu_ps(int i, const float* src) {
  switch (i) {
    case 0:
      return _mm256_setzero_ps();
    case 1:
      return _mm256_setr_m128(_mm_setr_ps(src[0], .0f, .0f, .0f),
                              _mm_setzero_ps());
    case 2:
      return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], .0f, .0f),
                              _mm_setzero_ps());
    case 3:
      return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], src[2], .0f),
                              _mm_setzero_ps());
    case 4:
      return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], src[2], src[3]),
                              _mm_setzero_ps());
    case 5:
      return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], .0f, .0f,
                            .0f);
    case 6:
      return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], src[5],
                            .0f, .0f);
    case 7:
      return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], src[5],
                            src[6], .0f);
    case 8:
      return _mm256_loadu_ps(src);
    default:
      RUY_DCHECK(i < 9);
      return _mm256_setzero_ps();
  }
}
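
// Note on mm256_n_loadu_ps: it emulates a masked load of the first `i`
// floats, zero-filling the remaining lanes. A minimal sketch of the same idea
// using AVX masked loads (an alternative, not the committed code):
//   const __m256i lane_ids = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
//   const __m256i mask = _mm256_cmpgt_epi32(_mm256_set1_epi32(i), lane_ids);
//   return _mm256_maskload_ps(src, mask);  // lanes >= i read as 0.0f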

inline void _mm256_n_storeu_ps(float* dst, int residual_rows, const __m256 v) {
  for (int i = 0; i < residual_rows; ++i) {
    dst[i] = mm256_get1_ps(v, i);
  }
}
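
// Note on _mm256_n_storeu_ps: it writes the low `residual_rows` lanes one at
// a time through mm256_get1_ps. A masked-store sketch (an alternative, not
// the committed code):
//   const __m256i lane_ids = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
//   const __m256i mask =
//       _mm256_cmpgt_epi32(_mm256_set1_epi32(residual_rows), lane_ids);
//   _mm256_maskstore_ps(dst, mask, v);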
} // namespace

void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) {
  gemmlowp::ScopedProfilingLabel label("Kernel kAvx2");
@@ -110,8 +186,6 @@ void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) {
        rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
      }

      //

      if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && params.rhs_zero_point) {
        for (int j = 0; j < kAvx8bitBlockSize; ++j) {
          for (int i = 0; i < kAvx8bitBlockSize; ++i) {
@@ -172,8 +246,6 @@ void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) {
        }
      }

      //

      for (int j = 0; j < kAvx8bitBlockSize; ++j) {
        for (int i = 0; i < kAvx8bitBlockSize; ++i) {
          accum_data[j][i] =
@@ -316,105 +388,132 @@ void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) {
void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) {
  gemmlowp::ScopedProfilingLabel label("Kernel kAvx2");

  float lhs_data[kAvxFloatBlockSize];
  float rhs_data[kAvxFloatBlockSize];
  float accum_data[kAvxFloatBlockSize][kAvxFloatBlockSize];
  int bias_ptr_block_increment =
      params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvxFloatBlockSize : 0;

  const float* rhs_col_ptr = params.rhs_base_ptr;
  float* dst_col_ptr = params.dst_base_ptr;
  // As parameters are defined, we need to scale by sizeof(float).
  const std::int64_t lhs_stride = params.lhs_stride >> 2;
  const std::int64_t dst_stride = params.dst_stride >> 2;
  const std::int64_t rhs_stride = params.rhs_stride >> 2;
  //
  int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
  // kAvxFloatBlockSize = 8.
  const int end_row = std::min(params.dst_rows, params.last_row + 8);
  const int end_col = std::min(params.dst_cols, params.last_col + 8);
  //
  const float* adj_rhs_col_ptr =
      params.rhs_base_ptr - params.start_col * rhs_stride;
  float* adj_dst_col_ptr =
      params.dst_base_ptr - params.start_col * dst_stride - params.start_row;
  const float* adj_lhs_col_ptr =
      params.lhs_base_ptr - params.start_row * lhs_stride;
  const float* bias_col_ptr = params.bias;
  if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
    bias_col_ptr += params.start_row;
  }

  for (int col = params.start_col; col <= params.last_col;
       col += kAvxFloatBlockSize) {
    const float* lhs_col_ptr = params.lhs_base_ptr;
    float* dst_ptr = dst_col_ptr;
    const float* bias_ptr = bias_col_ptr;
  const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max);
  const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min);

    for (int row = params.start_row; row <= params.last_row;
         row += kAvxFloatBlockSize) {
      const int residual_rows =
          std::min(params.dst_rows - row, kAvxFloatBlockSize);
      const int residual_cols =
          std::min(params.dst_cols - col, kAvxFloatBlockSize);
  int col = params.start_col;
  // Loop through cols by kAvxFloatBlockSize, leaving incomplete remainder
  for (; col <= end_col - 8; col += 8) {
    __m256 accum_data_v[8];

    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;

    for (int row = params.start_row; row < end_row; row += 8) {
      const int residual_rows = std::min(end_row - row, 8);

      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
      float* dst_ptr = dst_col_ptr + row;
      const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;

      // Initialize with bias.
      float initial_accum_data[kAvxFloatBlockSize];
      for (int i = 0; i < kAvxFloatBlockSize; ++i) {
        initial_accum_data[i] = 0.0f;
      const __m256 initial_accum_data =
          mm256_n_loadu_ps(residual_rows, bias_ptr);

      for (int j = 0; j < 8; ++j) {
        accum_data_v[j] = initial_accum_data;
      }
      for (int i = 0; i < residual_rows; ++i) {
        initial_accum_data[i] = bias_ptr[i];
      }
      for (int j = 0; j < kAvxFloatBlockSize; ++j) {
        for (int i = 0; i < kAvxFloatBlockSize; ++i) {
          accum_data[j][i] = initial_accum_data[i];
        }
      }
      bias_ptr += bias_ptr_block_increment;

      const float* lhs_ptr = lhs_col_ptr;
      const float* rhs_ptr = rhs_col_ptr;
      for (int d = 0; d < params.depth; ++d) {
        for (int i = 0; i < kAvxFloatBlockSize; ++i) {
          lhs_data[i] = lhs_ptr[i];
          rhs_data[i] = rhs_ptr[i];
        const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
        const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);

        for (int j = 0; j < 8; ++j) {
          const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[j]);
          accum_data_v[j] =
              _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]);
        }
        for (int j = 0; j < kAvxFloatBlockSize; ++j) {
          for (int i = 0; i < kAvxFloatBlockSize; ++i) {
            accum_data[j][i] += lhs_data[i] * rhs_data[j];
          }
        }
        lhs_ptr += kAvxFloatBlockSize;
        rhs_ptr += kAvxFloatBlockSize;
        lhs_ptr += 8;
        rhs_ptr += 8;
      }

      for (int j = 0; j < kAvxFloatBlockSize; ++j) {
        for (int i = 0; i < kAvxFloatBlockSize; ++i) {
          accum_data[j][i] =
              std::min<float>(accum_data[j][i], params.clamp_max);
          accum_data[j][i] =
              std::max<float>(accum_data[j][i], params.clamp_min);
      if (residual_rows == 8) {
        for (int j = 0; j < 8; ++j) {
          float* block_ptr = dst_ptr + j * dst_stride;
          accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
          accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
          _mm256_storeu_ps(block_ptr, accum_data_v[j]);
        }
      } else {
        for (int j = 0; j < 8; ++j) {
          float* block_ptr = dst_ptr + j * dst_stride;
          accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
          accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
          _mm256_n_storeu_ps(block_ptr, residual_rows, accum_data_v[j]);
        }
      }

      const bool store_full_block = (residual_rows == kAvxFloatBlockSize) &&
                                    (residual_cols == kAvxFloatBlockSize);

      {
        float* block_ptr =
            store_full_block ? dst_ptr : const_cast<float*>(params.dst_tmp_buf);
        const int block_col_offset = store_full_block
                                         ? params.dst_stride / sizeof(float)
                                         : kAvxFloatBlockSize;
        for (int j = 0; j < kAvxFloatBlockSize; ++j) {
          for (int i = 0; i < kAvxFloatBlockSize; ++i) {
            block_ptr[i] = accum_data[j][i];
          }
          block_ptr += block_col_offset;
        }
      }
      if (!store_full_block) {
        const float* block_ptr = params.dst_tmp_buf;
        for (int j = 0; j < residual_cols; ++j) {
          for (int i = 0; i < residual_rows; ++i) {
            dst_ptr[j * params.dst_stride / sizeof(float) + i] = block_ptr[i];
          }
          block_ptr += kAvxFloatBlockSize;
        }
      }

      lhs_col_ptr += kAvxFloatBlockSize * params.lhs_stride / sizeof(float);
      dst_ptr += kAvxFloatBlockSize;
    } // End row-block loop.
  } // End col-block loop.

    dst_col_ptr += kAvxFloatBlockSize * params.dst_stride / sizeof(float);
    rhs_col_ptr += kAvxFloatBlockSize * params.rhs_stride / sizeof(float);
  } // End col-block loop.
  if (col < end_col) {
    // Remaining cols in [0, kAvxFloatBlockSize).
    RUY_DCHECK_GE(end_col - col, 0);
    RUY_DCHECK_LT(end_col - col, 8);

    __m256 accum_data_v[8];

    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
    const int residual_cols = std::min(end_col - col, 8);

    for (int row = params.start_row; row < end_row; row += 8) {
      const int residual_rows = std::min(end_row - row, 8);

      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
      float* dst_ptr = dst_col_ptr + row;
      const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;

      // Initialize with bias.
      const __m256 initial_accum_data =
          mm256_n_loadu_ps(residual_rows, bias_ptr);

      for (int j = 0; j < 8; ++j) {
        accum_data_v[j] = initial_accum_data;
      }

      const float* lhs_ptr = lhs_col_ptr;
      const float* rhs_ptr = rhs_col_ptr;
      for (int d = 0; d < params.depth; ++d) {
        const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
        const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);

        for (int j = 0; j < 8; ++j) {
          const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[j]);
          accum_data_v[j] =
              _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]);
        }
        lhs_ptr += 8;
        rhs_ptr += 8;
      }

      for (int j = 0; j < residual_cols; ++j) {
        float* block_ptr = dst_ptr + j * dst_stride;
        accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
        accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
        _mm256_n_storeu_ps(block_ptr, residual_rows, accum_data_v[j]);
      }
    } // End row-block loop.
  } // End col-block terminal conditional.
}

#endif // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM)
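
The heart of the float-kernel change is the broadcast-FMA outer product in the
depth loop: each step loads eight LHS floats, broadcasts one RHS element per
destination column, and accumulates with _mm256_fmadd_ps into eight column
accumulators. A minimal self-contained sketch of that pattern follows
(illustrative only: the function name BlockFma8x8 and its parameters are
assumptions made for the sketch, not part of the commit):

    #include <immintrin.h>

    // Computes one 8x8 column-major block: dst column j receives
    // sum over d of lhs[d*8 + 0..7] * rhs[d*8 + j], for packed 8-wide panels.
    void BlockFma8x8(const float* lhs, const float* rhs, int depth, float* dst,
                     int dst_stride) {
      __m256 accum[8];
      for (int j = 0; j < 8; ++j) accum[j] = _mm256_setzero_ps();
      for (int d = 0; d < depth; ++d) {
        // Eight rows of one depth slice.
        const __m256 lhs_data = _mm256_loadu_ps(lhs);
        for (int j = 0; j < 8; ++j) {
          // Broadcast RHS element j and fuse multiply-add into column j.
          const __m256 dup_rhs_j = _mm256_set1_ps(rhs[j]);
          accum[j] = _mm256_fmadd_ps(lhs_data, dup_rhs_j, accum[j]);
        }
        lhs += 8;
        rhs += 8;
      }
      for (int j = 0; j < 8; ++j) {
        _mm256_storeu_ps(dst + j * dst_stride, accum[j]);
      }
    }

Note that the committed kernel indexes the loaded RHS vector directly
(rhs_data[j] on a __m256), which relies on a GCC/Clang vector-subscript
extension; the sketch above reads the element from the packed buffer instead,
which is portable.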