From 52084c979baf3ee45ddb30a25c1753311198434a Mon Sep 17 00:00:00 2001
From: Alex Stark
Date: Fri, 20 Sep 2019 09:25:24 -0700
Subject: [PATCH] Ruy: Rough optimization of x86 AVX2 float kernel.

PiperOrigin-RevId: 270284629
---
 .../lite/experimental/ruy/kernel_avx2.cc      | 271 ++++++++++++------
 1 file changed, 185 insertions(+), 86 deletions(-)

diff --git a/tensorflow/lite/experimental/ruy/kernel_avx2.cc b/tensorflow/lite/experimental/ruy/kernel_avx2.cc
index eb38addc725..b97f723d42f 100644
--- a/tensorflow/lite/experimental/ruy/kernel_avx2.cc
+++ b/tensorflow/lite/experimental/ruy/kernel_avx2.cc
@@ -46,6 +46,82 @@ static constexpr int kAvxFloatBlockSize = 8;
 static constexpr int kAvx8bitBlockSize = 8;
 static constexpr int kAvx8bitInnerSize = 4;
 
+namespace {
+inline float mm256_get1_ps(const __m256 a, int i) {
+  __m256i ai = _mm256_castps_si256(a);
+  int float_val_as_int;
+  switch (i) {
+    case 0:
+      float_val_as_int = _mm256_extract_epi32(ai, 0);
+      break;
+    case 1:
+      float_val_as_int = _mm256_extract_epi32(ai, 1);
+      break;
+    case 2:
+      float_val_as_int = _mm256_extract_epi32(ai, 2);
+      break;
+    case 3:
+      float_val_as_int = _mm256_extract_epi32(ai, 3);
+      break;
+    case 4:
+      float_val_as_int = _mm256_extract_epi32(ai, 4);
+      break;
+    case 5:
+      float_val_as_int = _mm256_extract_epi32(ai, 5);
+      break;
+    case 6:
+      float_val_as_int = _mm256_extract_epi32(ai, 6);
+      break;
+    case 7:
+      float_val_as_int = _mm256_extract_epi32(ai, 7);
+      break;
+    default:
+      RUY_DCHECK_LT(i, 8);
+      return .0f;
+  }
+  return reinterpret_cast<float&>(float_val_as_int);
+}
+
+inline __m256 mm256_n_loadu_ps(int i, const float* src) {
+  switch (i) {
+    case 0:
+      return _mm256_setzero_ps();
+    case 1:
+      return _mm256_setr_m128(_mm_setr_ps(src[0], .0f, .0f, .0f),
+                              _mm_setzero_ps());
+    case 2:
+      return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], .0f, .0f),
+                              _mm_setzero_ps());
+    case 3:
+      return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], src[2], .0f),
+                              _mm_setzero_ps());
+    case 4:
+      return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], src[2], src[3]),
+                              _mm_setzero_ps());
+    case 5:
+      return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], .0f, .0f,
+                            .0f);
+    case 6:
+      return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], src[5], .0f,
+                            .0f);
+    case 7:
+      return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], src[5],
+                            src[6], .0f);
+    case 8:
+      return _mm256_loadu_ps(src);
+    default:
+      RUY_DCHECK(i < 9);
+      return _mm256_setzero_ps();
+  }
+}
+
+inline void _mm256_n_storeu_ps(float* dst, int residual_rows, const __m256 v) {
+  for (int i = 0; i < residual_rows; ++i) {
+    dst[i] = mm256_get1_ps(v, i);
+  }
+}
+}  // namespace
+
 void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) {
   gemmlowp::ScopedProfilingLabel label("Kernel kAvx2");
 
@@ -110,8 +186,6 @@ void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) {
       rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
     }
 
-    //
-
     if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && params.rhs_zero_point) {
       for (int j = 0; j < kAvx8bitBlockSize; ++j) {
         for (int i = 0; i < kAvx8bitBlockSize; ++i) {
@@ -172,8 +246,6 @@ void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) {
       }
     }
 
-    //
-
    for (int j = 0; j < kAvx8bitBlockSize; ++j) {
       for (int i = 0; i < kAvx8bitBlockSize; ++i) {
         accum_data[j][i] =
@@ -316,105 +388,132 @@ void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) {
 
 void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) {
   gemmlowp::ScopedProfilingLabel label("Kernel kAvx2");
 
-  float lhs_data[kAvxFloatBlockSize];
-  float rhs_data[kAvxFloatBlockSize];
-  float accum_data[kAvxFloatBlockSize][kAvxFloatBlockSize];
-  int bias_ptr_block_increment =
-      params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvxFloatBlockSize : 0;
-
-  const float* rhs_col_ptr = params.rhs_base_ptr;
-  float* dst_col_ptr = params.dst_base_ptr;
+  // As parameters are defined, we need to scale by sizeof(float).
+  const std::int64_t lhs_stride = params.lhs_stride >> 2;
+  const std::int64_t dst_stride = params.dst_stride >> 2;
+  const std::int64_t rhs_stride = params.rhs_stride >> 2;
+  //
+  int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
+  // kAvxFloatBlockSize = 8.
+  const int end_row = std::min(params.dst_rows, params.last_row + 8);
+  const int end_col = std::min(params.dst_cols, params.last_col + 8);
+  //
+  const float* adj_rhs_col_ptr =
+      params.rhs_base_ptr - params.start_col * rhs_stride;
+  float* adj_dst_col_ptr =
+      params.dst_base_ptr - params.start_col * dst_stride - params.start_row;
+  const float* adj_lhs_col_ptr =
+      params.lhs_base_ptr - params.start_row * lhs_stride;
   const float* bias_col_ptr = params.bias;
-  if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
-    bias_col_ptr += params.start_row;
-  }
 
-  for (int col = params.start_col; col <= params.last_col;
-       col += kAvxFloatBlockSize) {
-    const float* lhs_col_ptr = params.lhs_base_ptr;
-    float* dst_ptr = dst_col_ptr;
-    const float* bias_ptr = bias_col_ptr;
+  const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max);
+  const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min);
 
-    for (int row = params.start_row; row <= params.last_row;
-         row += kAvxFloatBlockSize) {
-      const int residual_rows =
-          std::min(params.dst_rows - row, kAvxFloatBlockSize);
-      const int residual_cols =
-          std::min(params.dst_cols - col, kAvxFloatBlockSize);
+  int col = params.start_col;
+  // Loop through cols by kAvxFloatBlockSize, leaving incomplete remainder
+  for (; col <= end_col - 8; col += 8) {
+    __m256 accum_data_v[8];
+
+    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
+    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
+
+    for (int row = params.start_row; row < end_row; row += 8) {
+      const int residual_rows = std::min(end_row - row, 8);
+
+      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+      float* dst_ptr = dst_col_ptr + row;
+      const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
 
       // Initialize with bias.
-      float initial_accum_data[kAvxFloatBlockSize];
-      for (int i = 0; i < kAvxFloatBlockSize; ++i) {
-        initial_accum_data[i] = 0.0f;
+      const __m256 initial_accum_data =
+          mm256_n_loadu_ps(residual_rows, bias_ptr);
+
+      for (int j = 0; j < 8; ++j) {
+        accum_data_v[j] = initial_accum_data;
       }
-      for (int i = 0; i < residual_rows; ++i) {
-        initial_accum_data[i] = bias_ptr[i];
-      }
-      for (int j = 0; j < kAvxFloatBlockSize; ++j) {
-        for (int i = 0; i < kAvxFloatBlockSize; ++i) {
-          accum_data[j][i] = initial_accum_data[i];
-        }
-      }
-      bias_ptr += bias_ptr_block_increment;
 
       const float* lhs_ptr = lhs_col_ptr;
       const float* rhs_ptr = rhs_col_ptr;
       for (int d = 0; d < params.depth; ++d) {
-        for (int i = 0; i < kAvxFloatBlockSize; ++i) {
-          lhs_data[i] = lhs_ptr[i];
-          rhs_data[i] = rhs_ptr[i];
+        const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
+        const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
+
+        for (int j = 0; j < 8; ++j) {
+          const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[j]);
+          accum_data_v[j] =
+              _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]);
         }
-        for (int j = 0; j < kAvxFloatBlockSize; ++j) {
-          for (int i = 0; i < kAvxFloatBlockSize; ++i) {
-            accum_data[j][i] += lhs_data[i] * rhs_data[j];
-          }
-        }
-        lhs_ptr += kAvxFloatBlockSize;
-        rhs_ptr += kAvxFloatBlockSize;
+        lhs_ptr += 8;
+        rhs_ptr += 8;
       }
 
-      for (int j = 0; j < kAvxFloatBlockSize; ++j) {
-        for (int i = 0; i < kAvxFloatBlockSize; ++i) {
-          accum_data[j][i] =
-              std::min<float>(accum_data[j][i], params.clamp_max);
-          accum_data[j][i] =
-              std::max<float>(accum_data[j][i], params.clamp_min);
+      if (residual_rows == 8) {
+        for (int j = 0; j < 8; ++j) {
+          float* block_ptr = dst_ptr + j * dst_stride;
+          accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
+          accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
+          _mm256_storeu_ps(block_ptr, accum_data_v[j]);
+        }
+      } else {
+        for (int j = 0; j < 8; ++j) {
+          float* block_ptr = dst_ptr + j * dst_stride;
+          accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
+          accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
+          _mm256_n_storeu_ps(block_ptr, residual_rows, accum_data_v[j]);
         }
       }
-
-      const bool store_full_block = (residual_rows == kAvxFloatBlockSize) &&
-                                    (residual_cols == kAvxFloatBlockSize);
-
-      {
-        float* block_ptr =
-            store_full_block ? dst_ptr : const_cast<float*>(params.dst_tmp_buf);
-        const int block_col_offset = store_full_block
-                                         ? params.dst_stride / sizeof(float)
                                        : kAvxFloatBlockSize;
-        for (int j = 0; j < kAvxFloatBlockSize; ++j) {
-          for (int i = 0; i < kAvxFloatBlockSize; ++i) {
-            block_ptr[i] = accum_data[j][i];
-          }
-          block_ptr += block_col_offset;
-        }
-      }
-      if (!store_full_block) {
-        const float* block_ptr = params.dst_tmp_buf;
-        for (int j = 0; j < residual_cols; ++j) {
-          for (int i = 0; i < residual_rows; ++i) {
-            dst_ptr[j * params.dst_stride / sizeof(float) + i] = block_ptr[i];
-          }
-          block_ptr += kAvxFloatBlockSize;
-        }
-      }
-
-      lhs_col_ptr += kAvxFloatBlockSize * params.lhs_stride / sizeof(float);
-      dst_ptr += kAvxFloatBlockSize;
     }  // End row-block loop.
+  }  // End col-block loop.
 
-    dst_col_ptr += kAvxFloatBlockSize * params.dst_stride / sizeof(float);
-    rhs_col_ptr += kAvxFloatBlockSize * params.rhs_stride / sizeof(float);
-  }  // End col-block loop.
+  if (col < end_col) {
+    // Remaining cols in [0, kAvxFloatBlockSize).
+    RUY_DCHECK_GE(end_col - col, 0);
+    RUY_DCHECK_LT(end_col - col, 8);
+
+    __m256 accum_data_v[8];
+
+    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
+    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
+    const int residual_cols = std::min(end_col - col, 8);
+
+    for (int row = params.start_row; row < end_row; row += 8) {
+      const int residual_rows = std::min(end_row - row, 8);
+
+      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+      float* dst_ptr = dst_col_ptr + row;
+      const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
+
+      // Initialize with bias.
+      const __m256 initial_accum_data =
+          mm256_n_loadu_ps(residual_rows, bias_ptr);
+
+      for (int j = 0; j < 8; ++j) {
+        accum_data_v[j] = initial_accum_data;
+      }
+
+      const float* lhs_ptr = lhs_col_ptr;
+      const float* rhs_ptr = rhs_col_ptr;
+      for (int d = 0; d < params.depth; ++d) {
+        const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
+        const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
+
+        for (int j = 0; j < 8; ++j) {
+          const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[j]);
+          accum_data_v[j] =
+              _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]);
+        }
+        lhs_ptr += 8;
+        rhs_ptr += 8;
+      }
+
+      for (int j = 0; j < residual_cols; ++j) {
+        float* block_ptr = dst_ptr + j * dst_stride;
+        accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
+        accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
+        _mm256_n_storeu_ps(block_ptr, residual_rows, accum_data_v[j]);
+      }
+    }  // End row-block loop.
+  }  // End col-block terminal conditional.
 }
 
 #endif  //  RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM)
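
---

The three helpers added at the top of the file implement partial vector I/O one lane at a time. As a sanity sketch (not part of the patch), the following standalone check exercises their intended contract: mm256_n_loadu_ps(n, src) fills lanes [0, n) from src and zeroes the rest, and _mm256_n_storeu_ps(dst, n, v) writes exactly n floats. It assumes the helpers above are visible in the translation unit and the file is built with -mavx2 -mfma.

  #include <immintrin.h>

  #include <cassert>
  #include <cstdio>

  // Assumes mm256_n_loadu_ps and _mm256_n_storeu_ps (from the patch above)
  // are in scope.
  int main() {
    const float src[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    for (int n = 0; n <= 8; ++n) {
      // Lanes [0, n) must mirror src; lanes [n, 8) must be zero.
      const __m256 v = mm256_n_loadu_ps(n, src);
      float lanes[8];
      _mm256_storeu_ps(lanes, v);
      for (int i = 0; i < n; ++i) assert(lanes[i] == src[i]);
      for (int i = n; i < 8; ++i) assert(lanes[i] == 0.0f);

      // The partial store writes exactly n floats and leaves the rest alone.
      float out[8] = {-1, -1, -1, -1, -1, -1, -1, -1};
      _mm256_n_storeu_ps(out, n, v);
      for (int i = 0; i < n; ++i) assert(out[i] == src[i]);
      for (int i = n; i < 8; ++i) assert(out[i] == -1.0f);
    }
    std::printf("partial load/store helpers behave as expected\n");
    return 0;
  }

AVX2 does provide _mm256_maskload_ps / _mm256_maskstore_ps, so a masked-move variant is a possible follow-up; the switch-based helpers keep the change simple and only run on remainder blocks, which fits the "rough optimization" scope of the commit.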
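The stride arithmetic at the top of the rewritten kernel deserves spelling out: the params strides are byte strides, so ">> 2" converts them to float-element strides, and the adj_*_ptr definitions fold start_row/start_col into the base pointers so the inner loops can index with absolute row and col values. A small self-contained sketch of that pointer algebra, using made-up stride and offset values:

  #include <cstdint>
  #include <cstdio>

  int main() {
    // Hypothetical values; the kernel's params strides are in bytes.
    const std::int64_t dst_stride_bytes = 96;
    const std::int64_t dst_stride = dst_stride_bytes >> 2;  // 24 floats/column
    const int start_row = 8, start_col = 16;
    const int row = 24, col = 32;  // absolute coordinates used by the loops

    // Offset computed directly from coordinates relative to the block
    // origin...
    const std::int64_t direct =
        (col - start_col) * dst_stride + (row - start_row);
    // ...equals the offset reached through the pre-adjusted base pointer,
    // mirroring adj_dst_col_ptr = dst_base_ptr - start_col * dst_stride
    // - start_row followed by "+ col * dst_stride + row" in the loops.
    const std::int64_t adj_base = -start_col * dst_stride - start_row;
    const std::int64_t via_adjusted = adj_base + col * dst_stride + row;

    std::printf("direct=%lld via_adjusted=%lld\n", (long long)direct,
                (long long)via_adjusted);  // both print 400
    return 0;
  }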
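In the depth loop, one _mm256_loadu_ps pulls an 8-row LHS column and another pulls 8 RHS values; each _mm256_set1_ps broadcast of rhs_data[j] then feeds a single FMA that updates all 8 rows of destination column j, so a depth step costs 8 FMAs instead of 64 scalar multiply-adds. A scalar model of the same accumulation order, with toy data (illustrative only):

  #include <cstdio>

  int main() {
    const int depth = 2;
    float lhs[2][8];         // lhs[d][i]: row i of the 8-row LHS column at depth d
    float rhs[2][8];         // rhs[d][j]: element j of the 8-wide RHS row at depth d
    float accum[8][8] = {};  // accum[j][i]: destination column j, row i

    for (int d = 0; d < depth; ++d) {
      for (int k = 0; k < 8; ++k) {
        lhs[d][k] = static_cast<float>(k + 1);
        rhs[d][k] = static_cast<float>(d + 1);
      }
    }

    for (int d = 0; d < depth; ++d) {
      for (int j = 0; j < 8; ++j) {    // one _mm256_fmadd_ps per column j...
        for (int i = 0; i < 8; ++i) {  // ...covering these 8 row lanes at once
          accum[j][i] += lhs[d][i] * rhs[d][j];
        }
      }
    }
    std::printf("accum[0][0] = %g\n", accum[0][0]);  // 1*1 + 1*2 = 3
    return 0;
  }

One caveat: the kernel's rhs_data[j] subscript on a __m256 relies on the GCC/Clang vector-subscript extension; MSVC would need an explicit lane extract instead.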