Added additional test for regime of parameterized_truncated_normal sampler.
In the case where the stddev is small relative to the bounds, the sampler is currently generating some outliers. Added warning logging if the sampler terminates before sampling successfully. Increased kMaxIterations from 100 to 1000 (this may not be a good long-term solution but, in general, users probably prefer additional computation to spurious outliers). PiperOrigin-RevId: 209583394
This commit is contained in:
parent
fd833e87f1
commit
eac93feb41
@ -47,7 +47,7 @@ using random::PhiloxRandom;
|
||||
|
||||
template <typename T>
|
||||
struct TruncatedNormalFunctor<CPUDevice, T> {
|
||||
static const int kMaxIterations = 100;
|
||||
static const int kMaxIterations = 1000;
|
||||
|
||||
void operator()(OpKernelContext* ctx, const CPUDevice& d, int64 num_batches,
|
||||
int64 samples_per_batch, int64 num_elements,
|
||||
@ -124,6 +124,7 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
|
||||
(normMin * (normMin - sqrtFactor)) / T(4)) /
|
||||
(normMin + sqrtFactor);
|
||||
const T diff = normMax - normMin;
|
||||
|
||||
if (diff < cutoff) {
|
||||
// Sample from a uniform distribution on [normMin, normMax].
|
||||
|
||||
@ -143,15 +144,20 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
|
||||
|
||||
const auto u = dist(&gen_copy);
|
||||
for (int i = 0; i < size; i++) {
|
||||
if (u[i] <= Eigen::numext::exp(g[i]) ||
|
||||
numIterations + 1 >= kMaxIterations) {
|
||||
auto accept = u[i] <= Eigen::numext::exp(g[i]);
|
||||
if (accept || numIterations + 1 >= kMaxIterations) {
|
||||
// Accept the sample z.
|
||||
// If we run out of iterations, just use the current uniform
|
||||
// sample. Emperically, the probability of accepting each sample
|
||||
// is at least 50% for typical inputs, so we will always accept
|
||||
// by 100 iterations.
|
||||
// This introduces a slight inaccuracy when at least one bound
|
||||
// is large, minval is negative and maxval is positive.
|
||||
// sample, but emit a warning.
|
||||
// TODO(jjhunt) For small entropies (relative to the bounds),
|
||||
// this sampler is poor and may take many iterations since
|
||||
// the proposal distribution is the uniform distribution
|
||||
// U(lower_bound, upper_bound).
|
||||
if (!accept) {
|
||||
LOG(WARNING) << "TruncatedNormal uniform rejection sampler "
|
||||
<< "exceeded max iterations. Sample may contain "
|
||||
<< "outliers.";
|
||||
}
|
||||
output(sample) = z[i] * stddev + mean;
|
||||
sample++;
|
||||
if (sample >= limit_sample) {
|
||||
@ -181,13 +187,15 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
|
||||
const T g = Eigen::numext::exp(-x * x / T(2.0));
|
||||
const T u = rand[i];
|
||||
i++;
|
||||
if ((u <= g && z < normMax) ||
|
||||
numIterations + 1 >= kMaxIterations) {
|
||||
auto accept = (u <= g && z < normMax);
|
||||
if (accept || numIterations + 1 >= kMaxIterations) {
|
||||
if (!accept) {
|
||||
LOG(WARNING) << "TruncatedNormal exponential distribution "
|
||||
<< "rejection sampler exceeds max iterations. "
|
||||
<< "Sample may contain outliers.";
|
||||
}
|
||||
output(sample) = z * stddev + mean;
|
||||
sample++;
|
||||
if (sample >= limit_sample) {
|
||||
break;
|
||||
}
|
||||
numIterations = 0;
|
||||
} else {
|
||||
numIterations++;
|
||||
|
@ -190,7 +190,7 @@ __global__ void __launch_bounds__(1024)
|
||||
// Partial specialization for GPU
|
||||
template <typename T>
|
||||
struct TruncatedNormalFunctor<GPUDevice, T> {
|
||||
static const int kMaxIterations = 100;
|
||||
static const int kMaxIterations = 1000;
|
||||
|
||||
void operator()(OpKernelContext* ctx, const GPUDevice& d, int64 num_batches,
|
||||
int64 samples_per_batch, int64 num_elements,
|
||||
|
@ -664,7 +664,7 @@ cuda_py_test(
|
||||
|
||||
cuda_py_test(
|
||||
name = "parameterized_truncated_normal_op_test",
|
||||
size = "small",
|
||||
size = "medium",
|
||||
srcs = ["parameterized_truncated_normal_op_test.py"],
|
||||
additional_deps = [
|
||||
"//third_party/py/numpy",
|
||||
|
@ -182,6 +182,19 @@ class ParameterizedTruncatedNormalTest(test.TestCase):
|
||||
def testSmallStddev(self):
|
||||
self.validateKolmogorovSmirnov([10**5], 0.0, 0.1, 0.05, 0.10)
|
||||
|
||||
def testSamplingWithSmallStdDevFarFromBound(self):
|
||||
sample_op = random_ops.parameterized_truncated_normal(
|
||||
shape=(int(1e5),), means=0.8, stddevs=0.05, minvals=-1., maxvals=1.)
|
||||
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
samples = sess.run(sample_op)
|
||||
# 0. is more than 16 standard deviations from the mean, and
|
||||
# should have a likelihood < 1e-57.
|
||||
# TODO(jjhunt) Sampler is still numerically unstable in this case,
|
||||
# numbers less than 0 should never observed.
|
||||
no_neg_samples = np.sum(samples < 0.)
|
||||
self.assertLess(no_neg_samples, 2.)
|
||||
|
||||
|
||||
# Benchmarking code
|
||||
def parameterized_vs_naive(shape, num_iters, use_gpu=False):
|
||||
|
Loading…
Reference in New Issue
Block a user