[TF] Fix CTC compilation on MacOS: make kLogZero a templated constexpr function.
PiperOrigin-RevId: 264729008
This commit is contained in:
parent
6fd82802c1
commit
b34c9ad405
@ -44,13 +44,11 @@ struct EmptyBeamState {};
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
struct BeamProbability {
|
struct BeamProbability {
|
||||||
BeamProbability()
|
BeamProbability()
|
||||||
: total(kLogZero<T>::val),
|
: total(kLogZero<T>()), blank(kLogZero<T>()), label(kLogZero<T>()) {}
|
||||||
blank(kLogZero<T>::val),
|
|
||||||
label(kLogZero<T>::val) {}
|
|
||||||
void Reset() {
|
void Reset() {
|
||||||
total = kLogZero<T>::val;
|
total = kLogZero<T>();
|
||||||
blank = kLogZero<T>::val;
|
blank = kLogZero<T>();
|
||||||
label = kLogZero<T>::val;
|
label = kLogZero<T>();
|
||||||
}
|
}
|
||||||
T total;
|
T total;
|
||||||
T blank;
|
T blank;
|
||||||
@ -65,7 +63,7 @@ struct BeamEntry {
|
|||||||
// BeamRoot<CTCBeamState>::AddEntry() serves as the factory method.
|
// BeamRoot<CTCBeamState>::AddEntry() serves as the factory method.
|
||||||
friend BeamEntry<T, CTCBeamState>* BeamRoot<T, CTCBeamState>::AddEntry(
|
friend BeamEntry<T, CTCBeamState>* BeamRoot<T, CTCBeamState>::AddEntry(
|
||||||
BeamEntry<T, CTCBeamState>* p, int l);
|
BeamEntry<T, CTCBeamState>* p, int l);
|
||||||
inline bool Active() const { return newp.total != kLogZero<T>::val; }
|
inline bool Active() const { return newp.total != kLogZero<T>(); }
|
||||||
// Return the child at the given index, or construct a new one in-place if
|
// Return the child at the given index, or construct a new one in-place if
|
||||||
// none was found.
|
// none was found.
|
||||||
BeamEntry<T, CTCBeamState>& GetChild(int ind) {
|
BeamEntry<T, CTCBeamState>& GetChild(int ind) {
|
||||||
|
|||||||
@ -327,7 +327,7 @@ void CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::Step(
|
|||||||
// isn't full, or the lowest probability entry in the beam has a
|
// isn't full, or the lowest probability entry in the beam has a
|
||||||
// lower probability than the leaf.
|
// lower probability than the leaf.
|
||||||
auto is_candidate = [this](const BeamProbability& prob) {
|
auto is_candidate = [this](const BeamProbability& prob) {
|
||||||
return (prob.total > kLogZero<T>::val &&
|
return (prob.total > kLogZero<T>() &&
|
||||||
(leaves_.size() < beam_width_ ||
|
(leaves_.size() < beam_width_ ||
|
||||||
prob.total > leaves_.peek_bottom()->newp.total));
|
prob.total > leaves_.peek_bottom()->newp.total));
|
||||||
};
|
};
|
||||||
@ -349,7 +349,7 @@ void CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::Step(
|
|||||||
BeamEntry& c = b->GetChild(label);
|
BeamEntry& c = b->GetChild(label);
|
||||||
if (!c.Active()) {
|
if (!c.Active()) {
|
||||||
// Pblank(l=abcd @ t=6) = 0
|
// Pblank(l=abcd @ t=6) = 0
|
||||||
c.newp.blank = kLogZero<T>::val;
|
c.newp.blank = kLogZero<T>();
|
||||||
// If new child label is identical to beam label:
|
// If new child label is identical to beam label:
|
||||||
// Plabel(l=abcc @ t=6) = Pblank(l=abc @ t=5) * P(c @ 6)
|
// Plabel(l=abcc @ t=6) = Pblank(l=abc @ t=5) * P(c @ 6)
|
||||||
// Otherwise:
|
// Otherwise:
|
||||||
|
|||||||
@ -230,7 +230,7 @@ Status CTCLossCalculator<T>::CalculateLoss(
|
|||||||
|
|
||||||
// The loss is computed as the log(p(z|x)) between the target and
|
// The loss is computed as the log(p(z|x)) between the target and
|
||||||
// prediction. Do lazy evaluation of log_prob here.
|
// prediction. Do lazy evaluation of log_prob here.
|
||||||
T log_p_z_x = kLogZero<T>::val;
|
T log_p_z_x = kLogZero<T>();
|
||||||
for (int u = 0; u < l_prime.size(); ++u) {
|
for (int u = 0; u < l_prime.size(); ++u) {
|
||||||
// (GravesTh) Eq 7.26, sum over all paths for t = 0.
|
// (GravesTh) Eq 7.26, sum over all paths for t = 0.
|
||||||
log_p_z_x = LogSumExp(log_p_z_x, log_alpha_b(u, 0) + log_beta_b(u, 0));
|
log_p_z_x = LogSumExp(log_p_z_x, log_alpha_b(u, 0) + log_beta_b(u, 0));
|
||||||
@ -377,7 +377,7 @@ void CTCLossCalculator<TT>::CalculateForwardVariables(
|
|||||||
|
|
||||||
// Number of cols is the number of time steps = number of cols in target
|
// Number of cols is the number of time steps = number of cols in target
|
||||||
// after the output delay.
|
// after the output delay.
|
||||||
log_alpha->setConstant(kLogZero<TT>::val);
|
log_alpha->setConstant(kLogZero<TT>());
|
||||||
|
|
||||||
int U = l_prime.size();
|
int U = l_prime.size();
|
||||||
int T = log_alpha->cols();
|
int T = log_alpha->cols();
|
||||||
@ -398,7 +398,7 @@ void CTCLossCalculator<TT>::CalculateForwardVariables(
|
|||||||
++u) {
|
++u) {
|
||||||
// Begin (GravesTh) Eq 7.9
|
// Begin (GravesTh) Eq 7.9
|
||||||
// Add in the u, t - 1 term.
|
// Add in the u, t - 1 term.
|
||||||
auto sum_log_alpha = kLogZero<TT>::val;
|
auto sum_log_alpha = kLogZero<TT>();
|
||||||
if (ctc_merge_repeated || l_prime[u] == blank_index_) {
|
if (ctc_merge_repeated || l_prime[u] == blank_index_) {
|
||||||
sum_log_alpha = log_alpha->coeff(u, t - 1);
|
sum_log_alpha = log_alpha->coeff(u, t - 1);
|
||||||
}
|
}
|
||||||
@ -436,7 +436,7 @@ void CTCLossCalculator<TT>::CalculateBackwardVariables(
|
|||||||
// kLogZero);
|
// kLogZero);
|
||||||
using Eigen::numext::log;
|
using Eigen::numext::log;
|
||||||
|
|
||||||
log_beta->setConstant(kLogZero<TT>::val);
|
log_beta->setConstant(kLogZero<TT>());
|
||||||
int T = log_beta->cols();
|
int T = log_beta->cols();
|
||||||
int U = l_prime.size();
|
int U = l_prime.size();
|
||||||
CHECK_EQ(U, log_beta->rows());
|
CHECK_EQ(U, log_beta->rows());
|
||||||
@ -495,7 +495,7 @@ void CTCLossCalculator<TT>::CalculateGradient(const std::vector<int>& l_prime,
|
|||||||
|
|
||||||
// It is possible that no valid path is found if the activations for the
|
// It is possible that no valid path is found if the activations for the
|
||||||
// targets are zero.
|
// targets are zero.
|
||||||
if (log_p_z_x == kLogZero<TT>::val) {
|
if (log_p_z_x == kLogZero<TT>()) {
|
||||||
LOG(WARNING) << "No valid path found.";
|
LOG(WARNING) << "No valid path found.";
|
||||||
dy_b = y;
|
dy_b = y;
|
||||||
return;
|
return;
|
||||||
@ -507,7 +507,7 @@ void CTCLossCalculator<TT>::CalculateGradient(const std::vector<int>& l_prime,
|
|||||||
|
|
||||||
for (int t = 0; t < T - output_delay_; ++t) {
|
for (int t = 0; t < T - output_delay_; ++t) {
|
||||||
Array prob_sum(L);
|
Array prob_sum(L);
|
||||||
prob_sum.setConstant(kLogZero<TT>::val);
|
prob_sum.setConstant(kLogZero<TT>());
|
||||||
|
|
||||||
for (int u = 0; u < U; ++u) {
|
for (int u = 0; u < U; ++u) {
|
||||||
int l = l_prime[u];
|
int l = l_prime[u];
|
||||||
|
|||||||
@ -24,9 +24,9 @@ namespace tensorflow {
|
|||||||
namespace ctc {
|
namespace ctc {
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
struct kLogZero {
|
constexpr T kLogZero() {
|
||||||
static constexpr T val = -std::numeric_limits<T>::infinity(); // NOLINT
|
return -std::numeric_limits<T>::infinity(); // NOLINT
|
||||||
};
|
}
|
||||||
|
|
||||||
// Add logarithmic probabilities using:
|
// Add logarithmic probabilities using:
|
||||||
// ln(a + b) = ln(a) + ln(1 + exp(ln(b) - ln(a)))
|
// ln(a + b) = ln(a) + ln(1 + exp(ln(b) - ln(a)))
|
||||||
@ -37,9 +37,9 @@ inline T LogSumExp(T log_prob_1, T log_prob_2) {
|
|||||||
// const T kLogZero = -std::numeric_limits<T>::infinity();
|
// const T kLogZero = -std::numeric_limits<T>::infinity();
|
||||||
// Always have 'b' be the smaller number to avoid the exponential from
|
// Always have 'b' be the smaller number to avoid the exponential from
|
||||||
// blowing up.
|
// blowing up.
|
||||||
if (log_prob_1 == kLogZero<T>::val) {
|
if (log_prob_1 == kLogZero<T>()) {
|
||||||
return log_prob_2;
|
return log_prob_2;
|
||||||
} else if (log_prob_2 == kLogZero<T>::val) {
|
} else if (log_prob_2 == kLogZero<T>()) {
|
||||||
return log_prob_1;
|
return log_prob_1;
|
||||||
} else {
|
} else {
|
||||||
return (log_prob_1 > log_prob_2)
|
return (log_prob_1 > log_prob_2)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user