Add SSE4 path for Tanh and Logistic.
PiperOrigin-RevId: 311320167
Change-Id: Ie62fd09adf8e41827796d2102c5f1d505429a139
This commit is contained in:
parent
5530521a57
commit
e4702e19bb
@ -4332,6 +4332,41 @@ inline void Logistic(const LogisticParams& params,
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef GEMMLOWP_SSE4
|
||||
{
|
||||
// F0 uses 0 integer bits, range [-1, 1].
|
||||
// This is the return type of math functions such as tanh, logistic,
|
||||
// whose range is in [-1, 1].
|
||||
using F0 = gemmlowp::FixedPoint<gemmlowp::int16x8_m128i, 0>;
|
||||
// F3 uses 3 integer bits, range [-8, 8], the input range expected here.
|
||||
using F3 = gemmlowp::FixedPoint<gemmlowp::int16x8_m128i, 3>;
|
||||
|
||||
for (; c <= flat_size - 16; c += 16) {
|
||||
F3 input0 = F3::FromRaw(gemmlowp::to_int16x8_m128i(
|
||||
_mm_loadu_si128(reinterpret_cast<const __m128i*>(input_data_ptr))));
|
||||
F3 input1 = F3::FromRaw(gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
|
||||
reinterpret_cast<const __m128i*>(input_data_ptr + 8))));
|
||||
F0 output0 = gemmlowp::logistic(input0);
|
||||
F0 output1 = gemmlowp::logistic(input1);
|
||||
_mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
|
||||
output0.raw().v);
|
||||
_mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr + 8),
|
||||
output1.raw().v);
|
||||
input_data_ptr += 16;
|
||||
output_data_ptr += 16;
|
||||
}
|
||||
for (; c <= flat_size - 8; c += 8) {
|
||||
F3 input = F3::FromRaw(gemmlowp::to_int16x8_m128i(
|
||||
_mm_loadu_si128(reinterpret_cast<const __m128i*>(input_data_ptr))));
|
||||
F0 output = gemmlowp::logistic(input);
|
||||
_mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
|
||||
output.raw().v);
|
||||
input_data_ptr += 8;
|
||||
output_data_ptr += 8;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
{
|
||||
// F0 uses 0 integer bits, range [-1, 1].
|
||||
// This is the return type of math functions such as tanh, logistic,
|
||||
@ -4438,6 +4473,72 @@ inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef GEMMLOWP_SSE4
|
||||
{
|
||||
// F0 uses 0 integer bits, range [-1, 1].
|
||||
// This is the return type of math functions such as tanh, logistic,
|
||||
// whose range is in [-1, 1].
|
||||
using F0 = gemmlowp::FixedPoint<gemmlowp::int16x8_m128i, 0>;
|
||||
// F3 uses 3 integer bits, range [-8, 8], the input range expected here.
|
||||
using F3 = gemmlowp::FixedPoint<gemmlowp::int16x8_m128i, 3>;
|
||||
|
||||
if (input_left_shift == 0) {
|
||||
for (; c <= flat_size - 16; c += 16) {
|
||||
F3 input0 = F3::FromRaw(gemmlowp::to_int16x8_m128i(
|
||||
_mm_loadu_si128(reinterpret_cast<const __m128i*>(input_data_ptr))));
|
||||
F3 input1 = F3::FromRaw(gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
|
||||
reinterpret_cast<const __m128i*>(input_data_ptr + 8))));
|
||||
F0 output0 = gemmlowp::tanh(input0);
|
||||
F0 output1 = gemmlowp::tanh(input1);
|
||||
_mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
|
||||
output0.raw().v);
|
||||
_mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr + 8),
|
||||
output1.raw().v);
|
||||
|
||||
input_data_ptr += 16;
|
||||
output_data_ptr += 16;
|
||||
}
|
||||
for (; c <= flat_size - 8; c += 8) {
|
||||
F3 input = F3::FromRaw(gemmlowp::to_int16x8_m128i(
|
||||
_mm_loadu_si128(reinterpret_cast<const __m128i*>(input_data_ptr))));
|
||||
F0 output = gemmlowp::tanh(input);
|
||||
_mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
|
||||
output.raw().v);
|
||||
input_data_ptr += 8;
|
||||
output_data_ptr += 8;
|
||||
}
|
||||
} else {
|
||||
for (; c <= flat_size - 16; c += 16) {
|
||||
F3 input0 = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
|
||||
gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
|
||||
reinterpret_cast<const __m128i*>(input_data_ptr)))));
|
||||
F3 input1 = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
|
||||
gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
|
||||
reinterpret_cast<const __m128i*>(input_data_ptr + 8)))));
|
||||
F0 output0 = gemmlowp::tanh(input0);
|
||||
F0 output1 = gemmlowp::tanh(input1);
|
||||
_mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
|
||||
output0.raw().v);
|
||||
_mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr + 8),
|
||||
output1.raw().v);
|
||||
|
||||
input_data_ptr += 16;
|
||||
output_data_ptr += 16;
|
||||
}
|
||||
for (; c <= flat_size - 8; c += 8) {
|
||||
F3 input = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
|
||||
gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
|
||||
reinterpret_cast<const __m128i*>(input_data_ptr)))));
|
||||
F0 output = gemmlowp::tanh(input);
|
||||
_mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
|
||||
output.raw().v);
|
||||
input_data_ptr += 8;
|
||||
output_data_ptr += 8;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
{
|
||||
// F0 uses 0 integer bits, range [-1, 1].
|
||||
// This is the return type of math functions such as tanh, logistic,
|
||||
|
@ -354,11 +354,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
|
||||
|
||||
tf_http_archive(
|
||||
name = "gemmlowp",
|
||||
sha256 = "6678b484d929f2d0d3229d8ac4e3b815a950c86bb9f17851471d143f6d4f7834", # SHARED_GEMMLOWP_SHA
|
||||
strip_prefix = "gemmlowp-12fed0cd7cfcd9e169bf1925bc3a7a58725fdcc3",
|
||||
sha256 = "43146e6f56cb5218a8caaab6b5d1601a083f1f31c06ff474a4378a7d35be9cfb", # SHARED_GEMMLOWP_SHA
|
||||
strip_prefix = "gemmlowp-fda83bdc38b118cc6b56753bd540caa49e570745",
|
||||
urls = [
|
||||
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/gemmlowp/archive/12fed0cd7cfcd9e169bf1925bc3a7a58725fdcc3.zip",
|
||||
"https://github.com/google/gemmlowp/archive/12fed0cd7cfcd9e169bf1925bc3a7a58725fdcc3.zip",
|
||||
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/gemmlowp/archive/fda83bdc38b118cc6b56753bd540caa49e570745.zip",
|
||||
"https://github.com/google/gemmlowp/archive/fda83bdc38b118cc6b56753bd540caa49e570745.zip",
|
||||
],
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user