Modified `cumulative_logsumexp` to improve handling of `-inf` elements.

Fixes b/153928926

PiperOrigin-RevId: 307475953
Change-Id: I2d59e0076a08d04e88bac6d66dfad8b227078f42
This commit is contained in:
A. Unique TensorFlower 2020-04-20 14:23:26 -07:00 committed by TensorFlower Gardener
parent a29adb7b49
commit e885af2941
2 changed files with 77 additions and 11 deletions

View File

@ -24,6 +24,7 @@ namespace functor {
typedef Eigen::Index Index; typedef Eigen::Index Index;
// TODO(b/154339590): Needs to be vectorized.
template <typename Device, typename Reducer, typename T> template <typename Device, typename Reducer, typename T>
struct Scan { struct Scan {
void operator()(const Device& d, typename TTypes<T, 3>::ConstTensor in, void operator()(const Device& d, typename TTypes<T, 3>::ConstTensor in,
@ -44,18 +45,33 @@ template <typename T>
struct LogSumExp { struct LogSumExp {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a, EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a,
const T& b) const { const T& b) const {
Eigen::internal::scalar_sum_op<T> sum_op; auto mi = Eigen::internal::scalar_min_op<T>()(a, b);
Eigen::internal::scalar_exp_op<T> exp_op; auto ma = Eigen::internal::scalar_max_op<T>()(a, b);
Eigen::internal::scalar_log_op<T> log_op;
Eigen::internal::scalar_max_op<T> max_op;
Eigen::internal::scalar_min_op<T> min_op;
Eigen::internal::scalar_log1p_op<T> log1p_op;
Eigen::internal::scalar_difference_op<T> diff_op;
auto mi = min_op(a, b); auto sub = Eigen::internal::scalar_difference_op<T>();
auto ma = max_op(a, b); auto add = Eigen::internal::scalar_sum_op<T>();
auto exp = Eigen::internal::scalar_exp_op<T>();
auto log1p = Eigen::internal::scalar_log1p_op<T>();
auto cmp_lt =
Eigen::internal::scalar_cmp_op<T, T, Eigen::internal::cmp_LT>();
return sum_op(log1p_op(exp_op(diff_op(mi, ma))), ma); auto logsumexp = add(log1p(exp(sub(mi, ma))), ma);
return cmp_lt(ma, Eigen::NumTraits<T>::lowest()) ? ma : logsumexp;
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const T& a,
const T& b) const {
auto mi = Eigen::internal::pmin(a, b);
auto ma = Eigen::internal::pmax(a, b);
using Eigen::internal::padd;
using Eigen::internal::pcmp_lt;
using Eigen::internal::pexp;
using Eigen::internal::plog1p;
using Eigen::internal::pset1;
using Eigen::internal::psub;
auto logsumexp = padd(plog1p(pexp(psub(mi, ma))), ma);
return pselect(pcmp_lt(ma, pset1(Eigen::NumTraits<T>::lowest())), ma,
logsumexp);
} }
}; };
@ -66,13 +82,58 @@ struct LogSumExpReducer {
*accum = logsumexp(*accum, t); *accum = logsumexp(*accum, t);
} }
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reducePacket(const Packet& p,
Packet* accum) const {
LogSumExp<T> logsumexp;
*accum = logsumexp.packetOp(*accum, p);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
return Eigen::NumTraits<T>::lowest(); return -Eigen::NumTraits<T>::infinity();
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const {
return Eigen::internal::pset1(initialize());
} }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const {
return accum; return accum;
} }
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet
finalizePacket(const Packet& vaccum) const {
return vaccum;
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T
finalizeBoth(const T saccum, const Packet& vaccum) const {
auto max_reducer = Eigen::internal::MaxReducer<T>();
auto sum_reducer = Eigen::internal::SumReducer<T>();
auto exp = Eigen::internal::scalar_exp_op<T>();
auto cmp_lt =
Eigen::internal::scalar_cmp_op<T, T, Eigen::internal::cmp_LT>();
auto log = Eigen::internal::scalar_log_op<T>();
auto add = Eigen::internal::scalar_sum_op<T>();
using Eigen::internal::pexp;
using Eigen::internal::psub;
// `ma = max(x1, ..., xn)`
// If the max of all of the `xi` is `-infinity` then the result is
// -infinity. If the max is larger than `-infinity` then it's safe to use
// for normalization even if the other elements are `-infinity`.
//
// `logsumexp(x1, ..., xn) = ma + log (exp(x1 - ma) + ... + exp(xn - ma))`
auto ma = max_reducer.finalizeBoth(saccum, vaccum);
auto logsumexp = add(log(sum_reducer.finalizeBoth(
exp(saccum - ma), pexp(psub(vaccum, pset1(ma))))),
ma);
return cmp_lt(ma, Eigen::NumTraits<T>::lowest()) ? initialize() : logsumexp;
}
}; };
} // namespace functor } // namespace functor

View File

@ -55,6 +55,11 @@ class CumulativeLogsumexpTest(test.TestCase):
reverse=reverse, exclusive=exclusive, reverse=reverse, exclusive=exclusive,
axis=axis) axis=axis)
def testMinusInfinity(self):
x = np.log([0., 0., 1., 1., 1., 1., 0., 0.])
self._testLogSumExpAllArgs(x, use_gpu=False)
self._testLogSumExpAllArgs(x, use_gpu=True)
def test1D(self): def test1D(self):
x = np.arange(10) / 10.0 - 0.5 x = np.arange(10) / 10.0 - 0.5
self._testLogSumExpAllArgs(x, use_gpu=False) self._testLogSumExpAllArgs(x, use_gpu=False)