Change log
to log1p
in Random Gamma sampler.
PiperOrigin-RevId: 307853432 Change-Id: Id5c7ffbf30c45a62c22a30086fe253c59e5fffc0
This commit is contained in:
parent
85d40e74ac
commit
0493a020d4
@ -208,6 +208,7 @@ class RandomGammaOp : public OpKernel {
|
|||||||
alpha_flat](int start_output, int limit_output) {
|
alpha_flat](int start_output, int limit_output) {
|
||||||
using Eigen::numext::exp;
|
using Eigen::numext::exp;
|
||||||
using Eigen::numext::log;
|
using Eigen::numext::log;
|
||||||
|
using Eigen::numext::log1p;
|
||||||
using Eigen::numext::pow;
|
using Eigen::numext::pow;
|
||||||
|
|
||||||
// Capturing "rng" by-value would only make a copy for the _shared_
|
// Capturing "rng" by-value would only make a copy for the _shared_
|
||||||
@ -241,7 +242,7 @@ class RandomGammaOp : public OpKernel {
|
|||||||
gen.Skip(kReservedSamplesPerOutput * output_idx);
|
gen.Skip(kReservedSamplesPerOutput * output_idx);
|
||||||
int16 uniform_remaining = 0;
|
int16 uniform_remaining = 0;
|
||||||
UNIFORM(u);
|
UNIFORM(u);
|
||||||
const double res = -log(1.0 - u);
|
const double res = -log1p(-u);
|
||||||
samples_alpha_offset[sample_idx * num_alphas] = static_cast<T>(res);
|
samples_alpha_offset[sample_idx * num_alphas] = static_cast<T>(res);
|
||||||
} // for (sample_idx)
|
} // for (sample_idx)
|
||||||
} else { // if alpha != 1.0
|
} else { // if alpha != 1.0
|
||||||
|
@ -259,6 +259,7 @@ class StatelessRandomGammaOp : public StatelessRandomOpBase {
|
|||||||
|
|
||||||
using Eigen::numext::exp;
|
using Eigen::numext::exp;
|
||||||
using Eigen::numext::log;
|
using Eigen::numext::log;
|
||||||
|
using Eigen::numext::log1p;
|
||||||
using Eigen::numext::pow;
|
using Eigen::numext::pow;
|
||||||
|
|
||||||
Normal normal;
|
Normal normal;
|
||||||
@ -288,7 +289,7 @@ class StatelessRandomGammaOp : public StatelessRandomOpBase {
|
|||||||
gen.Skip(kReservedSamplesPerOutput * output_idx);
|
gen.Skip(kReservedSamplesPerOutput * output_idx);
|
||||||
int16 uniform_remaining = 0;
|
int16 uniform_remaining = 0;
|
||||||
UNIFORM(u);
|
UNIFORM(u);
|
||||||
const double res = -log(1.0 - u);
|
const double res = -log1p(-u);
|
||||||
samples_alpha_offset[sample_idx * num_alphas] = static_cast<T>(res);
|
samples_alpha_offset[sample_idx * num_alphas] = static_cast<T>(res);
|
||||||
} // for (sample_idx)
|
} // for (sample_idx)
|
||||||
} else { // if alpha != 1.0
|
} else { // if alpha != 1.0
|
||||||
|
Loading…
Reference in New Issue
Block a user