Added additional test for regime of parameterized_truncated_normal sampler.
In the case where the stddev is small relative to the bounds, the sampler is currently generating some outliers. Added warning logging if the sampler terminates before sampling successfully. Increased kMaxIterations from 100 to 1000 (this may not be a good long-term solution but, in general, users probably prefer additional computation to spurious outliers). PiperOrigin-RevId: 209583394
This commit is contained in:
parent
fd833e87f1
commit
eac93feb41
@ -47,7 +47,7 @@ using random::PhiloxRandom;
|
|||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct TruncatedNormalFunctor<CPUDevice, T> {
|
struct TruncatedNormalFunctor<CPUDevice, T> {
|
||||||
static const int kMaxIterations = 100;
|
static const int kMaxIterations = 1000;
|
||||||
|
|
||||||
void operator()(OpKernelContext* ctx, const CPUDevice& d, int64 num_batches,
|
void operator()(OpKernelContext* ctx, const CPUDevice& d, int64 num_batches,
|
||||||
int64 samples_per_batch, int64 num_elements,
|
int64 samples_per_batch, int64 num_elements,
|
||||||
@ -124,6 +124,7 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
|
|||||||
(normMin * (normMin - sqrtFactor)) / T(4)) /
|
(normMin * (normMin - sqrtFactor)) / T(4)) /
|
||||||
(normMin + sqrtFactor);
|
(normMin + sqrtFactor);
|
||||||
const T diff = normMax - normMin;
|
const T diff = normMax - normMin;
|
||||||
|
|
||||||
if (diff < cutoff) {
|
if (diff < cutoff) {
|
||||||
// Sample from a uniform distribution on [normMin, normMax].
|
// Sample from a uniform distribution on [normMin, normMax].
|
||||||
|
|
||||||
@ -143,15 +144,20 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
|
|||||||
|
|
||||||
const auto u = dist(&gen_copy);
|
const auto u = dist(&gen_copy);
|
||||||
for (int i = 0; i < size; i++) {
|
for (int i = 0; i < size; i++) {
|
||||||
if (u[i] <= Eigen::numext::exp(g[i]) ||
|
auto accept = u[i] <= Eigen::numext::exp(g[i]);
|
||||||
numIterations + 1 >= kMaxIterations) {
|
if (accept || numIterations + 1 >= kMaxIterations) {
|
||||||
// Accept the sample z.
|
// Accept the sample z.
|
||||||
// If we run out of iterations, just use the current uniform
|
// If we run out of iterations, just use the current uniform
|
||||||
// sample. Empirically, the probability of accepting each sample
|
// sample, but emit a warning.
|
||||||
// is at least 50% for typical inputs, so we will always accept
|
// TODO(jjhunt) For small entropies (relative to the bounds),
|
||||||
// by 100 iterations.
|
// this sampler is poor and may take many iterations since
|
||||||
// This introduces a slight inaccuracy when at least one bound
|
// the proposal distribution is the uniform distribution
|
||||||
// is large, minval is negative and maxval is positive.
|
// U(lower_bound, upper_bound).
|
||||||
|
if (!accept) {
|
||||||
|
LOG(WARNING) << "TruncatedNormal uniform rejection sampler "
|
||||||
|
<< "exceeded max iterations. Sample may contain "
|
||||||
|
<< "outliers.";
|
||||||
|
}
|
||||||
output(sample) = z[i] * stddev + mean;
|
output(sample) = z[i] * stddev + mean;
|
||||||
sample++;
|
sample++;
|
||||||
if (sample >= limit_sample) {
|
if (sample >= limit_sample) {
|
||||||
@ -181,13 +187,15 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
|
|||||||
const T g = Eigen::numext::exp(-x * x / T(2.0));
|
const T g = Eigen::numext::exp(-x * x / T(2.0));
|
||||||
const T u = rand[i];
|
const T u = rand[i];
|
||||||
i++;
|
i++;
|
||||||
if ((u <= g && z < normMax) ||
|
auto accept = (u <= g && z < normMax);
|
||||||
numIterations + 1 >= kMaxIterations) {
|
if (accept || numIterations + 1 >= kMaxIterations) {
|
||||||
|
if (!accept) {
|
||||||
|
LOG(WARNING) << "TruncatedNormal exponential distribution "
|
||||||
|
<< "rejection sampler exceeds max iterations. "
|
||||||
|
<< "Sample may contain outliers.";
|
||||||
|
}
|
||||||
output(sample) = z * stddev + mean;
|
output(sample) = z * stddev + mean;
|
||||||
sample++;
|
sample++;
|
||||||
if (sample >= limit_sample) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
numIterations = 0;
|
numIterations = 0;
|
||||||
} else {
|
} else {
|
||||||
numIterations++;
|
numIterations++;
|
||||||
|
@ -190,7 +190,7 @@ __global__ void __launch_bounds__(1024)
|
|||||||
// Partial specialization for GPU
|
// Partial specialization for GPU
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct TruncatedNormalFunctor<GPUDevice, T> {
|
struct TruncatedNormalFunctor<GPUDevice, T> {
|
||||||
static const int kMaxIterations = 100;
|
static const int kMaxIterations = 1000;
|
||||||
|
|
||||||
void operator()(OpKernelContext* ctx, const GPUDevice& d, int64 num_batches,
|
void operator()(OpKernelContext* ctx, const GPUDevice& d, int64 num_batches,
|
||||||
int64 samples_per_batch, int64 num_elements,
|
int64 samples_per_batch, int64 num_elements,
|
||||||
|
@ -664,7 +664,7 @@ cuda_py_test(
|
|||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
name = "parameterized_truncated_normal_op_test",
|
name = "parameterized_truncated_normal_op_test",
|
||||||
size = "small",
|
size = "medium",
|
||||||
srcs = ["parameterized_truncated_normal_op_test.py"],
|
srcs = ["parameterized_truncated_normal_op_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
"//third_party/py/numpy",
|
"//third_party/py/numpy",
|
||||||
|
@ -182,6 +182,19 @@ class ParameterizedTruncatedNormalTest(test.TestCase):
|
|||||||
def testSmallStddev(self):
|
def testSmallStddev(self):
|
||||||
self.validateKolmogorovSmirnov([10**5], 0.0, 0.1, 0.05, 0.10)
|
self.validateKolmogorovSmirnov([10**5], 0.0, 0.1, 0.05, 0.10)
|
||||||
|
|
||||||
|
def testSamplingWithSmallStdDevFarFromBound(self):
|
||||||
|
sample_op = random_ops.parameterized_truncated_normal(
|
||||||
|
shape=(int(1e5),), means=0.8, stddevs=0.05, minvals=-1., maxvals=1.)
|
||||||
|
|
||||||
|
with self.test_session(use_gpu=True) as sess:
|
||||||
|
samples = sess.run(sample_op)
|
||||||
|
# 0. is more than 16 standard deviations from the mean, and
|
||||||
|
# should have a likelihood < 1e-57.
|
||||||
|
# TODO(jjhunt) Sampler is still numerically unstable in this case,
|
||||||
|
# numbers less than 0 should never be observed.
|
||||||
|
no_neg_samples = np.sum(samples < 0.)
|
||||||
|
self.assertLess(no_neg_samples, 2.)
|
||||||
|
|
||||||
|
|
||||||
# Benchmarking code
|
# Benchmarking code
|
||||||
def parameterized_vs_naive(shape, num_iters, use_gpu=False):
|
def parameterized_vs_naive(shape, num_iters, use_gpu=False):
|
||||||
|
Loading…
Reference in New Issue
Block a user