Modified `cumulative_logsumexp` to improve handling of `-inf` elements.
Fixes b/153928926 PiperOrigin-RevId: 307475953 Change-Id: I2d59e0076a08d04e88bac6d66dfad8b227078f42
This commit is contained in:
parent
a29adb7b49
commit
e885af2941
|
@ -24,6 +24,7 @@ namespace functor {
|
||||||
|
|
||||||
typedef Eigen::Index Index;
|
typedef Eigen::Index Index;
|
||||||
|
|
||||||
|
// TODO(b/154339590): Needs to be vectorized.
|
||||||
template <typename Device, typename Reducer, typename T>
|
template <typename Device, typename Reducer, typename T>
|
||||||
struct Scan {
|
struct Scan {
|
||||||
void operator()(const Device& d, typename TTypes<T, 3>::ConstTensor in,
|
void operator()(const Device& d, typename TTypes<T, 3>::ConstTensor in,
|
||||||
|
@ -44,18 +45,33 @@ template <typename T>
|
||||||
struct LogSumExp {
|
struct LogSumExp {
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a,
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a,
|
||||||
const T& b) const {
|
const T& b) const {
|
||||||
Eigen::internal::scalar_sum_op<T> sum_op;
|
auto mi = Eigen::internal::scalar_min_op<T>()(a, b);
|
||||||
Eigen::internal::scalar_exp_op<T> exp_op;
|
auto ma = Eigen::internal::scalar_max_op<T>()(a, b);
|
||||||
Eigen::internal::scalar_log_op<T> log_op;
|
|
||||||
Eigen::internal::scalar_max_op<T> max_op;
|
|
||||||
Eigen::internal::scalar_min_op<T> min_op;
|
|
||||||
Eigen::internal::scalar_log1p_op<T> log1p_op;
|
|
||||||
Eigen::internal::scalar_difference_op<T> diff_op;
|
|
||||||
|
|
||||||
auto mi = min_op(a, b);
|
auto sub = Eigen::internal::scalar_difference_op<T>();
|
||||||
auto ma = max_op(a, b);
|
auto add = Eigen::internal::scalar_sum_op<T>();
|
||||||
|
auto exp = Eigen::internal::scalar_exp_op<T>();
|
||||||
|
auto log1p = Eigen::internal::scalar_log1p_op<T>();
|
||||||
|
auto cmp_lt =
|
||||||
|
Eigen::internal::scalar_cmp_op<T, T, Eigen::internal::cmp_LT>();
|
||||||
|
|
||||||
return sum_op(log1p_op(exp_op(diff_op(mi, ma))), ma);
|
auto logsumexp = add(log1p(exp(sub(mi, ma))), ma);
|
||||||
|
return cmp_lt(ma, Eigen::NumTraits<T>::lowest()) ? ma : logsumexp;
|
||||||
|
}
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const T& a,
|
||||||
|
const T& b) const {
|
||||||
|
auto mi = Eigen::internal::pmin(a, b);
|
||||||
|
auto ma = Eigen::internal::pmax(a, b);
|
||||||
|
using Eigen::internal::padd;
|
||||||
|
using Eigen::internal::pcmp_lt;
|
||||||
|
using Eigen::internal::pexp;
|
||||||
|
using Eigen::internal::plog1p;
|
||||||
|
using Eigen::internal::pset1;
|
||||||
|
using Eigen::internal::psub;
|
||||||
|
|
||||||
|
auto logsumexp = padd(plog1p(pexp(psub(mi, ma))), ma);
|
||||||
|
return pselect(pcmp_lt(ma, pset1(Eigen::NumTraits<T>::lowest())), ma,
|
||||||
|
logsumexp);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -66,13 +82,58 @@ struct LogSumExpReducer {
|
||||||
*accum = logsumexp(*accum, t);
|
*accum = logsumexp(*accum, t);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename Packet>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reducePacket(const Packet& p,
|
||||||
|
Packet* accum) const {
|
||||||
|
LogSumExp<T> logsumexp;
|
||||||
|
*accum = logsumexp.packetOp(*accum, p);
|
||||||
|
}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
|
||||||
return Eigen::NumTraits<T>::lowest();
|
return -Eigen::NumTraits<T>::infinity();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Packet>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const {
|
||||||
|
return Eigen::internal::pset1(initialize());
|
||||||
}
|
}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const {
|
||||||
return accum;
|
return accum;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename Packet>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet
|
||||||
|
finalizePacket(const Packet& vaccum) const {
|
||||||
|
return vaccum;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Packet>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T
|
||||||
|
finalizeBoth(const T saccum, const Packet& vaccum) const {
|
||||||
|
auto max_reducer = Eigen::internal::MaxReducer<T>();
|
||||||
|
auto sum_reducer = Eigen::internal::SumReducer<T>();
|
||||||
|
auto exp = Eigen::internal::scalar_exp_op<T>();
|
||||||
|
auto cmp_lt =
|
||||||
|
Eigen::internal::scalar_cmp_op<T, T, Eigen::internal::cmp_LT>();
|
||||||
|
auto log = Eigen::internal::scalar_log_op<T>();
|
||||||
|
auto add = Eigen::internal::scalar_sum_op<T>();
|
||||||
|
|
||||||
|
using Eigen::internal::pexp;
|
||||||
|
using Eigen::internal::psub;
|
||||||
|
|
||||||
|
// `ma = max(x1, ..., xn)`
|
||||||
|
// If the max of all of the `xi` is `-infinity` then the result is
|
||||||
|
// -infinity. If the max is larger than `-infinity` then it's safe to use
|
||||||
|
// for normalization even if the other elements are `-infinity`.
|
||||||
|
//
|
||||||
|
// `logsumexp(x1, ..., xn) = ma + log (exp(x1 - ma) + ... + exp(xn - ma))`
|
||||||
|
auto ma = max_reducer.finalizeBoth(saccum, vaccum);
|
||||||
|
auto logsumexp = add(log(sum_reducer.finalizeBoth(
|
||||||
|
exp(saccum - ma), pexp(psub(vaccum, pset1(ma))))),
|
||||||
|
ma);
|
||||||
|
return cmp_lt(ma, Eigen::NumTraits<T>::lowest()) ? initialize() : logsumexp;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace functor
|
} // namespace functor
|
||||||
|
|
|
@ -55,6 +55,11 @@ class CumulativeLogsumexpTest(test.TestCase):
|
||||||
reverse=reverse, exclusive=exclusive,
|
reverse=reverse, exclusive=exclusive,
|
||||||
axis=axis)
|
axis=axis)
|
||||||
|
|
||||||
|
def testMinusInfinity(self):
|
||||||
|
x = np.log([0., 0., 1., 1., 1., 1., 0., 0.])
|
||||||
|
self._testLogSumExpAllArgs(x, use_gpu=False)
|
||||||
|
self._testLogSumExpAllArgs(x, use_gpu=True)
|
||||||
|
|
||||||
def test1D(self):
|
def test1D(self):
|
||||||
x = np.arange(10) / 10.0 - 0.5
|
x = np.arange(10) / 10.0 - 0.5
|
||||||
self._testLogSumExpAllArgs(x, use_gpu=False)
|
self._testLogSumExpAllArgs(x, use_gpu=False)
|
||||||
|
|
Loading…
Reference in New Issue