Ruy: Rough optimization of x86 AVX2 float kernel.

PiperOrigin-RevId: 270284629
Alex Stark 2019-09-20 09:25:24 -07:00, committed by TensorFlower Gardener
parent 97cfacbca3
commit 52084c979b


@@ -46,6 +46,82 @@ static constexpr int kAvxFloatBlockSize = 8;
static constexpr int kAvx8bitBlockSize = 8;
static constexpr int kAvx8bitInnerSize = 4;
namespace {
inline float mm256_get1_ps(const __m256 a, int i) {
__m256i ai = _mm256_castps_si256(a);
int float_val_as_int;
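// _mm256_extract_epi32 requires a compile-time-constant lane index, so a
// runtime index has to be dispatched through a switch.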
switch (i) {
case 0:
float_val_as_int = _mm256_extract_epi32(ai, 0);
break;
case 1:
float_val_as_int = _mm256_extract_epi32(ai, 1);
break;
case 2:
float_val_as_int = _mm256_extract_epi32(ai, 2);
break;
case 3:
float_val_as_int = _mm256_extract_epi32(ai, 3);
break;
case 4:
float_val_as_int = _mm256_extract_epi32(ai, 4);
break;
case 5:
float_val_as_int = _mm256_extract_epi32(ai, 5);
break;
case 6:
float_val_as_int = _mm256_extract_epi32(ai, 6);
break;
case 7:
float_val_as_int = _mm256_extract_epi32(ai, 7);
break;
default:
RUY_DCHECK_LT(i, 8);
return .0f;
}
return reinterpret_cast<float&>(float_val_as_int);
}
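// Hypothetical standalone check of mm256_get1_ps (illustration only, not
// part of this commit); assumes an AVX2 build, e.g. g++ -mavx2, with the
// helper above in scope.
inline void CheckMm256Get1Ps() {
  // Each lane holds its own index, so mm256_get1_ps(v, i) must return i.
  const __m256 v = _mm256_setr_ps(0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f);
  for (int i = 0; i < 8; ++i) {
    RUY_DCHECK(mm256_get1_ps(v, i) == static_cast<float>(i));
  }
}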
inline __m256 mm256_n_loadu_ps(int i, const float* src) {
switch (i) {
case 0:
return _mm256_setzero_ps();
case 1:
return _mm256_setr_m128(_mm_setr_ps(src[0], .0f, .0f, .0f),
_mm_setzero_ps());
case 2:
return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], .0f, .0f),
_mm_setzero_ps());
case 3:
return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], src[2], .0f),
_mm_setzero_ps());
case 4:
return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], src[2], src[3]),
_mm_setzero_ps());
case 5:
return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], .0f, .0f,
.0f);
case 6:
return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], src[5], .0f,
.0f);
case 7:
return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], src[5],
src[6], .0f);
case 8:
return _mm256_loadu_ps(src);
default:
RUY_DCHECK(i < 9);
return _mm256_setzero_ps();
}
}
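// Hypothetical alternative (not what this commit uses): the same
// partial-load semantics via an AVX masked load. Lane i is loaded iff
// i < n; the remaining lanes read as zero.
inline __m256 mm256_n_maskload_ps(int n, const float* src) {
  const __m256i iota = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
  // All-ones in lane i where n > i, zero elsewhere; maskload keys off the
  // sign bit of each mask lane.
  const __m256i mask = _mm256_cmpgt_epi32(_mm256_set1_epi32(n), iota);
  return _mm256_maskload_ps(src, mask);
}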
inline void _mm256_n_storeu_ps(float* dst, int residual_rows, const __m256 v) {
for (int i = 0; i < residual_rows; ++i) {
dst[i] = mm256_get1_ps(v, i);
}
}
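// Hypothetical store-side analogue (not what this commit uses): a masked
// store writes only lanes i < residual_rows and leaves the rest of dst
// untouched, avoiding the per-lane extraction loop above.
inline void mm256_n_maskstore_ps(float* dst, int n, const __m256 v) {
  const __m256i iota = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
  const __m256i mask = _mm256_cmpgt_epi32(_mm256_set1_epi32(n), iota);
  _mm256_maskstore_ps(dst, mask, v);
}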
} // namespace
void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) {
gemmlowp::ScopedProfilingLabel label("Kernel kAvx2");
@@ -110,8 +186,6 @@ void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) {
rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
}
if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && params.rhs_zero_point) {
for (int j = 0; j < kAvx8bitBlockSize; ++j) {
for (int i = 0; i < kAvx8bitBlockSize; ++i) {
@@ -172,8 +246,6 @@ void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) {
}
}
for (int j = 0; j < kAvx8bitBlockSize; ++j) {
for (int i = 0; i < kAvx8bitBlockSize; ++i) {
accum_data[j][i] =
@@ -316,105 +388,132 @@ void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) {
void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) {
gemmlowp::ScopedProfilingLabel label("Kernel kAvx2");
// The stride parameters are defined in bytes, so scale by sizeof(float)
// (>> 2) to obtain float-element strides.
const std::int64_t lhs_stride = params.lhs_stride >> 2;
const std::int64_t dst_stride = params.dst_stride >> 2;
const std::int64_t rhs_stride = params.rhs_stride >> 2;
//
int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
// kAvxFloatBlockSize = 8.
const int end_row = std::min(params.dst_rows, params.last_row + 8);
const int end_col = std::min(params.dst_cols, params.last_col + 8);
//
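// The adj_* pointers are pre-offset by -start_col / -start_row so that the
// loops below can index them directly with absolute col and row values.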
const float* adj_rhs_col_ptr =
params.rhs_base_ptr - params.start_col * rhs_stride;
float* adj_dst_col_ptr =
params.dst_base_ptr - params.start_col * dst_stride - params.start_row;
const float* adj_lhs_col_ptr =
params.lhs_base_ptr - params.start_row * lhs_stride;
const float* bias_col_ptr = params.bias;
if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
// Pre-offset likewise, so that (bias_col_ptr + row) below points at the
// bias for the absolute row index.
bias_col_ptr -= params.start_row;
}
const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max);
const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min);
int col = params.start_col;
// Loop over cols in full blocks of kAvxFloatBlockSize; the incomplete
// remainder block, if any, is handled after this loop.
for (; col <= end_col - 8; col += 8) {
__m256 accum_data_v[8];
const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
for (int row = params.start_row; row < end_row; row += 8) {
const int residual_rows = std::min(end_row - row, 8);
const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
float* dst_ptr = dst_col_ptr + row;
const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
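// bias_ptr_block_increment is 1 or 0, so bias_ptr either walks the bias
// vector with the row index or stays pinned at params.bias.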
// Initialize with bias.
const __m256 initial_accum_data =
    mm256_n_loadu_ps(residual_rows, bias_ptr);
for (int j = 0; j < 8; ++j) {
  accum_data_v[j] = initial_accum_data;
}
const float* lhs_ptr = lhs_col_ptr;
const float* rhs_ptr = rhs_col_ptr;
for (int d = 0; d < params.depth; ++d) {
const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
for (int j = 0; j < 8; ++j) {
  // Note: rhs_data[j] relies on the GCC/Clang vector-subscript extension.
  const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[j]);
  accum_data_v[j] =
      _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]);
}
lhs_ptr += 8;
rhs_ptr += 8;
}
if (residual_rows == 8) {
  for (int j = 0; j < 8; ++j) {
    float* block_ptr = dst_ptr + j * dst_stride;
    accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
    accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
    _mm256_storeu_ps(block_ptr, accum_data_v[j]);
  }
} else {
  for (int j = 0; j < 8; ++j) {
    float* block_ptr = dst_ptr + j * dst_stride;
    accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
    accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
    _mm256_n_storeu_ps(block_ptr, residual_rows, accum_data_v[j]);
  }
}
} // End row-block loop.
}  // End col-block loop.
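// The blocked loop above handles only full 8-wide column blocks; the
// conditional below covers a final partial block of columns with the same
// row loop, clamping and storing just residual_cols columns.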
if (col < end_col) {
// Remaining cols in [0, kAvxFloatBlockSize).
RUY_DCHECK_GE(end_col - col, 0);
RUY_DCHECK_LT(end_col - col, 8);
__m256 accum_data_v[8];
const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
const int residual_cols = std::min(end_col - col, 8);
for (int row = params.start_row; row < end_row; row += 8) {
const int residual_rows = std::min(end_row - row, 8);
const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
float* dst_ptr = dst_col_ptr + row;
const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
// Initialize with bias.
const __m256 initial_accum_data =
mm256_n_loadu_ps(residual_rows, bias_ptr);
for (int j = 0; j < 8; ++j) {
accum_data_v[j] = initial_accum_data;
}
const float* lhs_ptr = lhs_col_ptr;
const float* rhs_ptr = rhs_col_ptr;
for (int d = 0; d < params.depth; ++d) {
const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
for (int j = 0; j < 8; ++j) {
const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[j]);
accum_data_v[j] =
_mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]);
}
lhs_ptr += 8;
rhs_ptr += 8;
}
for (int j = 0; j < residual_cols; ++j) {
float* block_ptr = dst_ptr + j * dst_stride;
accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
_mm256_n_storeu_ps(block_ptr, residual_rows, accum_data_v[j]);
}
} // End row-block loop.
} // End col-block terminal conditional.
}
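// Illustration only (hypothetical helper, not part of this commit): the
// inner loop above is an outer-product update. Each rhs element j is
// broadcast and fused-multiply-added with the whole lhs column vector, so
// accumulator j builds column j of an 8x8 block; after one depth step,
// out[j][i] == lhs[i] * rhs[j]. Assumes an AVX2+FMA build (-mavx2 -mfma).
inline void OuterProductFmaDemo(const float* lhs, const float* rhs,
                                float out[8][8]) {
  const __m256 lhs_v = _mm256_loadu_ps(lhs);
  for (int j = 0; j < 8; ++j) {
    __m256 accum = _mm256_setzero_ps();  // Zero in place of a bias load.
    accum = _mm256_fmadd_ps(lhs_v, _mm256_set1_ps(rhs[j]), accum);
    _mm256_storeu_ps(out[j], accum);
  }
}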
#endif // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM)