Add SSE4 path for Tanh and Logistic.

PiperOrigin-RevId: 311320167
Change-Id: Ie62fd09adf8e41827796d2102c5f1d505429a139
This commit is contained in:
T.J. Alumbaugh 2020-05-13 06:58:18 -07:00 committed by TensorFlower Gardener
parent 5530521a57
commit e4702e19bb
2 changed files with 105 additions and 4 deletions

View File

@ -4332,6 +4332,41 @@ inline void Logistic(const LogisticParams& params,
}
}
#endif
#ifdef GEMMLOWP_SSE4
{
// F0 uses 0 integer bits, range [-1, 1].
// This is the return type of math functions such as tanh, logistic,
// whose range is in [-1, 1].
using F0 = gemmlowp::FixedPoint<gemmlowp::int16x8_m128i, 0>;
// F3 uses 3 integer bits, range [-8, 8], the input range expected here.
using F3 = gemmlowp::FixedPoint<gemmlowp::int16x8_m128i, 3>;
for (; c <= flat_size - 16; c += 16) {
F3 input0 = F3::FromRaw(gemmlowp::to_int16x8_m128i(
_mm_loadu_si128(reinterpret_cast<const __m128i*>(input_data_ptr))));
F3 input1 = F3::FromRaw(gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
reinterpret_cast<const __m128i*>(input_data_ptr + 8))));
F0 output0 = gemmlowp::logistic(input0);
F0 output1 = gemmlowp::logistic(input1);
_mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
output0.raw().v);
_mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr + 8),
output1.raw().v);
input_data_ptr += 16;
output_data_ptr += 16;
}
for (; c <= flat_size - 8; c += 8) {
F3 input = F3::FromRaw(gemmlowp::to_int16x8_m128i(
_mm_loadu_si128(reinterpret_cast<const __m128i*>(input_data_ptr))));
F0 output = gemmlowp::logistic(input);
_mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
output.raw().v);
input_data_ptr += 8;
output_data_ptr += 8;
}
}
#endif
{
// F0 uses 0 integer bits, range [-1, 1].
// This is the return type of math functions such as tanh, logistic,
@ -4438,6 +4473,72 @@ inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
}
}
#endif
#ifdef GEMMLOWP_SSE4
{
// F0 uses 0 integer bits, range [-1, 1].
// This is the return type of math functions such as tanh, logistic,
// whose range is in [-1, 1].
using F0 = gemmlowp::FixedPoint<gemmlowp::int16x8_m128i, 0>;
// F3 uses 3 integer bits, range [-8, 8], the input range expected here.
using F3 = gemmlowp::FixedPoint<gemmlowp::int16x8_m128i, 3>;
if (input_left_shift == 0) {
for (; c <= flat_size - 16; c += 16) {
F3 input0 = F3::FromRaw(gemmlowp::to_int16x8_m128i(
_mm_loadu_si128(reinterpret_cast<const __m128i*>(input_data_ptr))));
F3 input1 = F3::FromRaw(gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
reinterpret_cast<const __m128i*>(input_data_ptr + 8))));
F0 output0 = gemmlowp::tanh(input0);
F0 output1 = gemmlowp::tanh(input1);
_mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
output0.raw().v);
_mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr + 8),
output1.raw().v);
input_data_ptr += 16;
output_data_ptr += 16;
}
for (; c <= flat_size - 8; c += 8) {
F3 input = F3::FromRaw(gemmlowp::to_int16x8_m128i(
_mm_loadu_si128(reinterpret_cast<const __m128i*>(input_data_ptr))));
F0 output = gemmlowp::tanh(input);
_mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
output.raw().v);
input_data_ptr += 8;
output_data_ptr += 8;
}
} else {
for (; c <= flat_size - 16; c += 16) {
F3 input0 = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
reinterpret_cast<const __m128i*>(input_data_ptr)))));
F3 input1 = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
reinterpret_cast<const __m128i*>(input_data_ptr + 8)))));
F0 output0 = gemmlowp::tanh(input0);
F0 output1 = gemmlowp::tanh(input1);
_mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
output0.raw().v);
_mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr + 8),
output1.raw().v);
input_data_ptr += 16;
output_data_ptr += 16;
}
for (; c <= flat_size - 8; c += 8) {
F3 input = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
reinterpret_cast<const __m128i*>(input_data_ptr)))));
F0 output = gemmlowp::tanh(input);
_mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
output.raw().v);
input_data_ptr += 8;
output_data_ptr += 8;
}
}
}
#endif
{
// F0 uses 0 integer bits, range [-1, 1].
// This is the return type of math functions such as tanh, logistic,

View File

@ -354,11 +354,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "gemmlowp",
sha256 = "6678b484d929f2d0d3229d8ac4e3b815a950c86bb9f17851471d143f6d4f7834", # SHARED_GEMMLOWP_SHA
strip_prefix = "gemmlowp-12fed0cd7cfcd9e169bf1925bc3a7a58725fdcc3",
sha256 = "43146e6f56cb5218a8caaab6b5d1601a083f1f31c06ff474a4378a7d35be9cfb", # SHARED_GEMMLOWP_SHA
strip_prefix = "gemmlowp-fda83bdc38b118cc6b56753bd540caa49e570745",
urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/gemmlowp/archive/12fed0cd7cfcd9e169bf1925bc3a7a58725fdcc3.zip",
"https://github.com/google/gemmlowp/archive/12fed0cd7cfcd9e169bf1925bc3a7a58725fdcc3.zip",
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/gemmlowp/archive/fda83bdc38b118cc6b56753bd540caa49e570745.zip",
"https://github.com/google/gemmlowp/archive/fda83bdc38b118cc6b56753bd540caa49e570745.zip",
],
)