Change log to log1p in Random Gamma sampler.

PiperOrigin-RevId: 307853432
Change-Id: Id5c7ffbf30c45a62c22a30086fe253c59e5fffc0
This commit is contained in:
Srinivas Vasudevan 2020-04-22 11:04:12 -07:00 committed by TensorFlower Gardener
parent 85d40e74ac
commit 0493a020d4
2 changed files with 4 additions and 2 deletions

View File

@ -208,6 +208,7 @@ class RandomGammaOp : public OpKernel {
alpha_flat](int start_output, int limit_output) { alpha_flat](int start_output, int limit_output) {
using Eigen::numext::exp; using Eigen::numext::exp;
using Eigen::numext::log; using Eigen::numext::log;
using Eigen::numext::log1p;
using Eigen::numext::pow; using Eigen::numext::pow;
// Capturing "rng" by-value would only make a copy for the _shared_ // Capturing "rng" by-value would only make a copy for the _shared_
@ -241,7 +242,7 @@ class RandomGammaOp : public OpKernel {
gen.Skip(kReservedSamplesPerOutput * output_idx); gen.Skip(kReservedSamplesPerOutput * output_idx);
int16 uniform_remaining = 0; int16 uniform_remaining = 0;
UNIFORM(u); UNIFORM(u);
const double res = -log(1.0 - u); const double res = -log1p(-u);
samples_alpha_offset[sample_idx * num_alphas] = static_cast<T>(res); samples_alpha_offset[sample_idx * num_alphas] = static_cast<T>(res);
} // for (sample_idx) } // for (sample_idx)
} else { // if alpha != 1.0 } else { // if alpha != 1.0

View File

@ -259,6 +259,7 @@ class StatelessRandomGammaOp : public StatelessRandomOpBase {
using Eigen::numext::exp; using Eigen::numext::exp;
using Eigen::numext::log; using Eigen::numext::log;
using Eigen::numext::log1p;
using Eigen::numext::pow; using Eigen::numext::pow;
Normal normal; Normal normal;
@ -288,7 +289,7 @@ class StatelessRandomGammaOp : public StatelessRandomOpBase {
gen.Skip(kReservedSamplesPerOutput * output_idx); gen.Skip(kReservedSamplesPerOutput * output_idx);
int16 uniform_remaining = 0; int16 uniform_remaining = 0;
UNIFORM(u); UNIFORM(u);
const double res = -log(1.0 - u); const double res = -log1p(-u);
samples_alpha_offset[sample_idx * num_alphas] = static_cast<T>(res); samples_alpha_offset[sample_idx * num_alphas] = static_cast<T>(res);
} // for (sample_idx) } // for (sample_idx)
} else { // if alpha != 1.0 } else { // if alpha != 1.0