Replace log(1 + x) with numerically more stable log1p(x)

This commit is contained in:
Lukas Geiger 2019-02-27 00:35:53 +00:00
parent cf0f741491
commit f42d9846f6
5 changed files with 6 additions and 6 deletions

View File

@ -69,12 +69,12 @@ class LogisticLossUpdater : public DualLossUpdater {
if (y_wx > 0) { if (y_wx > 0) {
// 0 + log(e^(0) + e^(-ywx - 0)) // 0 + log(e^(0) + e^(-ywx - 0))
// log(1 + e^(-ywx)) // log(1 + e^(-ywx))
return log(1 + exp(-y_wx)) * example_weight; return log1p(exp(-y_wx)) * example_weight;
} }
// -ywx + log(e^(ywx) + e^(-ywx + ywx)) // -ywx + log(e^(ywx) + e^(-ywx + ywx))
// log(e^(ywx) + e^(0)) - ywx // log(e^(ywx) + e^(0)) - ywx
// log(1 + e^(ywx)) - ywx // log(1 + e^(ywx)) - ywx
return (log(1 + exp(y_wx)) - y_wx) * example_weight; return (log1p(exp(y_wx)) - y_wx) * example_weight;
} }
// Derivative of logistic loss // Derivative of logistic loss

View File

@ -196,7 +196,7 @@ void MfccMelFilterbank::Compute(const std::vector<double> &input,
} }
double MfccMelFilterbank::FreqToMel(double freq) const { double MfccMelFilterbank::FreqToMel(double freq) const {
return 1127.0 * log(1.0 + (freq / 700.0)); return 1127.0 * log1p(freq / 700.0);
} }
} // namespace tensorflow } // namespace tensorflow

View File

@ -154,7 +154,7 @@ int64 UniformSampler::Sample(random::SimplePhilox* rnd) const {
float UniformSampler::Probability(int64 value) const { return inv_range_; } float UniformSampler::Probability(int64 value) const { return inv_range_; }
LogUniformSampler::LogUniformSampler(int64 range) LogUniformSampler::LogUniformSampler(int64 range)
: RangeSampler(range), log_range_(log(range + 1)) {} : RangeSampler(range), log_range_(log1p(range)) {}
int64 LogUniformSampler::Sample(random::SimplePhilox* rnd) const { int64 LogUniformSampler::Sample(random::SimplePhilox* rnd) const {
const int64 value = const int64 value =

View File

@ -29,7 +29,7 @@ void FilterbankFillConfigWithDefaults(struct FilterbankConfig* config) {
} }
static float FreqToMel(float freq) { static float FreqToMel(float freq) {
return 1127.0 * log(1.0 + (freq / 700.0)); return 1127.0 * log1p(freq / 700.0);
} }
static void CalculateCenterFrequencies(const int num_channels, static void CalculateCenterFrequencies(const int num_channels,

View File

@ -197,7 +197,7 @@ void MfccMelFilterbank::Compute(const std::vector<double> &input,
} }
double MfccMelFilterbank::FreqToMel(double freq) const { double MfccMelFilterbank::FreqToMel(double freq) const {
return 1127.0 * log(1.0 + (freq / 700.0)); return 1127.0 * log1p(freq / 700.0);
} }
} // namespace internal } // namespace internal