[TF] Fix CTC compilation on MacOS: make kLogZero a templated constexpr function.

PiperOrigin-RevId: 264729008
This commit is contained in:
Eugene Brevdo 2019-08-21 17:20:56 -07:00 committed by TensorFlower Gardener
parent 6fd82802c1
commit b34c9ad405
4 changed files with 18 additions and 20 deletions

View File

@ -44,13 +44,11 @@ struct EmptyBeamState {};
template <typename T> template <typename T>
struct BeamProbability { struct BeamProbability {
BeamProbability() BeamProbability()
: total(kLogZero<T>::val), : total(kLogZero<T>()), blank(kLogZero<T>()), label(kLogZero<T>()) {}
blank(kLogZero<T>::val),
label(kLogZero<T>::val) {}
void Reset() { void Reset() {
total = kLogZero<T>::val; total = kLogZero<T>();
blank = kLogZero<T>::val; blank = kLogZero<T>();
label = kLogZero<T>::val; label = kLogZero<T>();
} }
T total; T total;
T blank; T blank;
@ -65,7 +63,7 @@ struct BeamEntry {
// BeamRoot<CTCBeamState>::AddEntry() serves as the factory method. // BeamRoot<CTCBeamState>::AddEntry() serves as the factory method.
friend BeamEntry<T, CTCBeamState>* BeamRoot<T, CTCBeamState>::AddEntry( friend BeamEntry<T, CTCBeamState>* BeamRoot<T, CTCBeamState>::AddEntry(
BeamEntry<T, CTCBeamState>* p, int l); BeamEntry<T, CTCBeamState>* p, int l);
inline bool Active() const { return newp.total != kLogZero<T>::val; } inline bool Active() const { return newp.total != kLogZero<T>(); }
// Return the child at the given index, or construct a new one in-place if // Return the child at the given index, or construct a new one in-place if
// none was found. // none was found.
BeamEntry<T, CTCBeamState>& GetChild(int ind) { BeamEntry<T, CTCBeamState>& GetChild(int ind) {

View File

@ -327,7 +327,7 @@ void CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::Step(
// isn't full, or the lowest probability entry in the beam has a // isn't full, or the lowest probability entry in the beam has a
// lower probability than the leaf. // lower probability than the leaf.
auto is_candidate = [this](const BeamProbability& prob) { auto is_candidate = [this](const BeamProbability& prob) {
return (prob.total > kLogZero<T>::val && return (prob.total > kLogZero<T>() &&
(leaves_.size() < beam_width_ || (leaves_.size() < beam_width_ ||
prob.total > leaves_.peek_bottom()->newp.total)); prob.total > leaves_.peek_bottom()->newp.total));
}; };
@ -349,7 +349,7 @@ void CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::Step(
BeamEntry& c = b->GetChild(label); BeamEntry& c = b->GetChild(label);
if (!c.Active()) { if (!c.Active()) {
// Pblank(l=abcd @ t=6) = 0 // Pblank(l=abcd @ t=6) = 0
c.newp.blank = kLogZero<T>::val; c.newp.blank = kLogZero<T>();
// If new child label is identical to beam label: // If new child label is identical to beam label:
// Plabel(l=abcc @ t=6) = Pblank(l=abc @ t=5) * P(c @ 6) // Plabel(l=abcc @ t=6) = Pblank(l=abc @ t=5) * P(c @ 6)
// Otherwise: // Otherwise:

View File

@ -230,7 +230,7 @@ Status CTCLossCalculator<T>::CalculateLoss(
// The loss is computed as the log(p(z|x)) between the target and // The loss is computed as the log(p(z|x)) between the target and
// prediction. Do lazy evaluation of log_prob here. // prediction. Do lazy evaluation of log_prob here.
T log_p_z_x = kLogZero<T>::val; T log_p_z_x = kLogZero<T>();
for (int u = 0; u < l_prime.size(); ++u) { for (int u = 0; u < l_prime.size(); ++u) {
// (GravesTh) Eq 7.26, sum over all paths for t = 0. // (GravesTh) Eq 7.26, sum over all paths for t = 0.
log_p_z_x = LogSumExp(log_p_z_x, log_alpha_b(u, 0) + log_beta_b(u, 0)); log_p_z_x = LogSumExp(log_p_z_x, log_alpha_b(u, 0) + log_beta_b(u, 0));
@ -377,7 +377,7 @@ void CTCLossCalculator<TT>::CalculateForwardVariables(
// Number of cols is the number of time steps = number of cols in target // Number of cols is the number of time steps = number of cols in target
// after the output delay. // after the output delay.
log_alpha->setConstant(kLogZero<TT>::val); log_alpha->setConstant(kLogZero<TT>());
int U = l_prime.size(); int U = l_prime.size();
int T = log_alpha->cols(); int T = log_alpha->cols();
@ -398,7 +398,7 @@ void CTCLossCalculator<TT>::CalculateForwardVariables(
++u) { ++u) {
// Begin (GravesTh) Eq 7.9 // Begin (GravesTh) Eq 7.9
// Add in the u, t - 1 term. // Add in the u, t - 1 term.
auto sum_log_alpha = kLogZero<TT>::val; auto sum_log_alpha = kLogZero<TT>();
if (ctc_merge_repeated || l_prime[u] == blank_index_) { if (ctc_merge_repeated || l_prime[u] == blank_index_) {
sum_log_alpha = log_alpha->coeff(u, t - 1); sum_log_alpha = log_alpha->coeff(u, t - 1);
} }
@ -436,7 +436,7 @@ void CTCLossCalculator<TT>::CalculateBackwardVariables(
// kLogZero); // kLogZero);
using Eigen::numext::log; using Eigen::numext::log;
log_beta->setConstant(kLogZero<TT>::val); log_beta->setConstant(kLogZero<TT>());
int T = log_beta->cols(); int T = log_beta->cols();
int U = l_prime.size(); int U = l_prime.size();
CHECK_EQ(U, log_beta->rows()); CHECK_EQ(U, log_beta->rows());
@ -495,7 +495,7 @@ void CTCLossCalculator<TT>::CalculateGradient(const std::vector<int>& l_prime,
// It is possible that no valid path is found if the activations for the // It is possible that no valid path is found if the activations for the
// targets are zero. // targets are zero.
if (log_p_z_x == kLogZero<TT>::val) { if (log_p_z_x == kLogZero<TT>()) {
LOG(WARNING) << "No valid path found."; LOG(WARNING) << "No valid path found.";
dy_b = y; dy_b = y;
return; return;
@ -507,7 +507,7 @@ void CTCLossCalculator<TT>::CalculateGradient(const std::vector<int>& l_prime,
for (int t = 0; t < T - output_delay_; ++t) { for (int t = 0; t < T - output_delay_; ++t) {
Array prob_sum(L); Array prob_sum(L);
prob_sum.setConstant(kLogZero<TT>::val); prob_sum.setConstant(kLogZero<TT>());
for (int u = 0; u < U; ++u) { for (int u = 0; u < U; ++u) {
int l = l_prime[u]; int l = l_prime[u];

View File

@ -24,9 +24,9 @@ namespace tensorflow {
namespace ctc { namespace ctc {
template <class T> template <class T>
struct kLogZero { constexpr T kLogZero() {
static constexpr T val = -std::numeric_limits<T>::infinity(); // NOLINT return -std::numeric_limits<T>::infinity(); // NOLINT
}; }
// Add logarithmic probabilities using: // Add logarithmic probabilities using:
// ln(a + b) = ln(a) + ln(1 + exp(ln(b) - ln(a))) // ln(a + b) = ln(a) + ln(1 + exp(ln(b) - ln(a)))
@ -37,9 +37,9 @@ inline T LogSumExp(T log_prob_1, T log_prob_2) {
// const T kLogZero = -std::numeric_limits<T>::infinity(); // const T kLogZero = -std::numeric_limits<T>::infinity();
// Always have 'b' be the smaller number to avoid the exponential from // Always have 'b' be the smaller number to avoid the exponential from
// blowing up. // blowing up.
if (log_prob_1 == kLogZero<T>::val) { if (log_prob_1 == kLogZero<T>()) {
return log_prob_2; return log_prob_2;
} else if (log_prob_2 == kLogZero<T>::val) { } else if (log_prob_2 == kLogZero<T>()) {
return log_prob_1; return log_prob_1;
} else { } else {
return (log_prob_1 > log_prob_2) return (log_prob_1 > log_prob_2)