diff --git a/tensorflow/core/kernels/scan_ops.h b/tensorflow/core/kernels/scan_ops.h
index 1fd98f6656d..8afcac86c3f 100644
--- a/tensorflow/core/kernels/scan_ops.h
+++ b/tensorflow/core/kernels/scan_ops.h
@@ -24,6 +24,7 @@ namespace functor {
 
 typedef Eigen::Index Index;
 
+// TODO(b/154339590): Needs to be vectorized.
 template <typename Device, typename Reducer, typename T>
 struct Scan {
   void operator()(const Device& d, typename TTypes<T, 3>::ConstTensor in,
@@ -44,18 +45,33 @@ template <typename T>
 struct LogSumExp {
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a,
                                                      const T& b) const {
-    Eigen::internal::scalar_sum_op<T> sum_op;
-    Eigen::internal::scalar_exp_op<T> exp_op;
-    Eigen::internal::scalar_log_op<T> log_op;
-    Eigen::internal::scalar_max_op<T> max_op;
-    Eigen::internal::scalar_min_op<T> min_op;
-    Eigen::internal::scalar_log1p_op<T> log1p_op;
-    Eigen::internal::scalar_difference_op<T> diff_op;
+    auto mi = Eigen::internal::scalar_min_op<T>()(a, b);
+    auto ma = Eigen::internal::scalar_max_op<T>()(a, b);
 
-    auto mi = min_op(a, b);
-    auto ma = max_op(a, b);
+    auto sub = Eigen::internal::scalar_difference_op<T>();
+    auto add = Eigen::internal::scalar_sum_op<T>();
+    auto exp = Eigen::internal::scalar_exp_op<T>();
+    auto log1p = Eigen::internal::scalar_log1p_op<T>();
+    auto cmp_lt =
+        Eigen::internal::scalar_cmp_op<T, T, Eigen::internal::cmp_LT>();
 
-    return sum_op(log1p_op(exp_op(diff_op(mi, ma))), ma);
+    auto logsumexp = add(log1p(exp(sub(mi, ma))), ma);
+    return cmp_lt(ma, Eigen::NumTraits<T>::lowest()) ? ma : logsumexp;
+  }
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const T& a,
+                                                   const T& b) const {
+    auto mi = Eigen::internal::pmin(a, b);
+    auto ma = Eigen::internal::pmax(a, b);
+    using Eigen::internal::padd;
+    using Eigen::internal::pcmp_lt;
+    using Eigen::internal::pexp;
+    using Eigen::internal::plog1p;
+    using Eigen::internal::pset1;
+    using Eigen::internal::psub;
+
+    auto logsumexp = padd(plog1p(pexp(psub(mi, ma))), ma);
+    return pselect(pcmp_lt(ma, pset1<T>(Eigen::NumTraits<T>::lowest())), ma,
+                   logsumexp);
   }
 };
 
@@ -66,13 +82,58 @@ struct LogSumExpReducer {
     *accum = logsumexp(*accum, t);
   }
 
+  template <typename Packet>
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reducePacket(const Packet& p,
+                                                          Packet* accum) const {
+    LogSumExp<Packet> logsumexp;
+    *accum = logsumexp.packetOp(*accum, p);
+  }
+
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
-    return Eigen::NumTraits<T>::lowest();
+    return -Eigen::NumTraits<T>::infinity();
+  }
+
+  template <typename Packet>
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const {
+    return Eigen::internal::pset1<Packet>(initialize());
   }
 
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const {
     return accum;
   }
+
+  template <typename Packet>
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet
+  finalizePacket(const Packet& vaccum) const {
+    return vaccum;
+  }
+
+  template <typename Packet>
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T
+  finalizeBoth(const T saccum, const Packet& vaccum) const {
+    auto max_reducer = Eigen::internal::MaxReducer<T>();
+    auto sum_reducer = Eigen::internal::SumReducer<T>();
+    auto exp = Eigen::internal::scalar_exp_op<T>();
+    auto cmp_lt =
+        Eigen::internal::scalar_cmp_op<T, T, Eigen::internal::cmp_LT>();
+    auto log = Eigen::internal::scalar_log_op<T>();
+    auto add = Eigen::internal::scalar_sum_op<T>();
+
+    using Eigen::internal::pexp;
+    using Eigen::internal::psub;
+
+    // `ma = max(x1, ..., xn)`
+    // If the max of all of the `xi` is `-infinity` then the result is
+    // -infinity. If the max is larger than `-infinity` then it's safe to use
+    // for normalization even if the other elements are `-infinity`.
+    //
+    // `logsumexp(x1, ..., xn) = ma + log (exp(x1 - ma) + ... + exp(xn - ma))`
+    auto ma = max_reducer.finalizeBoth(saccum, vaccum);
+    auto logsumexp = add(log(sum_reducer.finalizeBoth(
+                             exp(saccum - ma),
+                             pexp(psub(vaccum, pset1<Packet>(ma))))),
+                         ma);
+    return cmp_lt(ma, Eigen::NumTraits<T>::lowest()) ? initialize() : logsumexp;
+  }
 };
 
 }  // namespace functor
diff --git a/tensorflow/python/kernel_tests/cumulative_logsumexp_test.py b/tensorflow/python/kernel_tests/cumulative_logsumexp_test.py
index aae624f6605..2b0309f26c4 100644
--- a/tensorflow/python/kernel_tests/cumulative_logsumexp_test.py
+++ b/tensorflow/python/kernel_tests/cumulative_logsumexp_test.py
@@ -55,6 +55,11 @@ class CumulativeLogsumexpTest(test.TestCase):
               reverse=reverse, exclusive=exclusive, axis=axis)
 
+  def testMinusInfinity(self):
+    x = np.log([0., 0., 1., 1., 1., 1., 0., 0.])
+    self._testLogSumExpAllArgs(x, use_gpu=False)
+    self._testLogSumExpAllArgs(x, use_gpu=True)
+
   def test1D(self):
     x = np.arange(10) / 10.0 - 0.5
     self._testLogSumExpAllArgs(x, use_gpu=False)