Adds Student's t distribution to tf.contrib.distributions

Leaving out the cdf for now, as it requires incomplete beta, not available in eigen at the moment.
Change: 122673237
This commit is contained in:
A. Unique TensorFlower 2016-05-18 14:12:52 -08:00 committed by TensorFlower Gardener
parent 996f797746
commit 0eb9af8148
4 changed files with 618 additions and 0 deletions
tensorflow/contrib/distributions

View File

@ -65,6 +65,17 @@ cuda_py_tests(
],
)
cuda_py_tests(
name = "student_t_test",
size = "small",
srcs = ["python/kernel_tests/student_t_test.py"],
additional_deps = [
":distributions_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
cuda_py_tests(
name = "uniform_test",
size = "small",

View File

@ -31,6 +31,7 @@ initialized with parameters that define the distributions.
@@Exponential
@@Gamma
@@Gaussian
@@StudentT
@@Uniform
### Multivariate distributions
@ -62,4 +63,5 @@ from tensorflow.contrib.distributions.python.ops.gamma import *
from tensorflow.contrib.distributions.python.ops.gaussian import *
from tensorflow.contrib.distributions.python.ops.gaussian_conjugate_posteriors import *
from tensorflow.contrib.distributions.python.ops.mvn import *
from tensorflow.contrib.distributions.python.ops.student_t import *
from tensorflow.contrib.distributions.python.ops.uniform import *

View File

@ -0,0 +1,321 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Student t distribution."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import numpy as np
from scipy import stats
import tensorflow as tf
class StudentTTest(tf.test.TestCase):
def testStudentPDFAndLogPDF(self):
with tf.Session():
batch_size = 6
df = tf.constant([3.0] * batch_size)
mu = tf.constant([7.0] * batch_size)
sigma = tf.constant([8.0] * batch_size)
df_v = 3.0
mu_v = 7.0
sigma_v = 8.0
t = np.array([-2.5, 2.5, 8.0, 0.0, -1.0, 2.0], dtype=np.float32)
student = tf.contrib.distributions.StudentT(df, mu=mu, sigma=sigma)
log_pdf = student.log_pdf(t)
self.assertEquals(log_pdf.get_shape(), (6,))
log_pdf_values = log_pdf.eval()
pdf = student.pdf(t)
self.assertEquals(pdf.get_shape(), (6,))
pdf_values = pdf.eval()
expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
self.assertAllClose(expected_log_pdf, log_pdf_values)
self.assertAllClose(np.log(expected_pdf), log_pdf_values)
self.assertAllClose(expected_pdf, pdf_values)
self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
def testStudentLogPDFMultidimensional(self):
with tf.Session():
batch_size = 6
df = tf.constant([[1.5, 7.2]] * batch_size)
mu = tf.constant([[3.0, -3.0]] * batch_size)
sigma = tf.constant([[math.sqrt(10.0), math.sqrt(15.0)]] * batch_size)
df_v = np.array([1.5, 7.2])
mu_v = np.array([3.0, -3.0])
sigma_v = np.array([np.sqrt(10.0), np.sqrt(15.0)])
t = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T
student = tf.contrib.distributions.StudentT(df, mu=mu, sigma=sigma)
log_pdf = student.log_pdf(t)
log_pdf_values = log_pdf.eval()
self.assertEqual(log_pdf.get_shape(), (6, 2))
pdf = student.pdf(t)
pdf_values = pdf.eval()
self.assertEqual(pdf.get_shape(), (6, 2))
expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
self.assertAllClose(expected_log_pdf, log_pdf_values)
self.assertAllClose(np.log(expected_pdf), log_pdf_values)
self.assertAllClose(expected_pdf, pdf_values)
self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
def testStudentEntropy(self):
df_v = np.array([[2., 3., 7.]]) # 1x3
mu_v = np.array([[1., -1, 0]]) # 1x3
sigma_v = np.array([[1., 2., 3.]]).T # transposed => 3x1
with tf.Session():
student = tf.contrib.distributions.StudentT(df=df_v,
mu=mu_v,
sigma=sigma_v)
ent = student.entropy()
ent_values = ent.eval()
# Help scipy broadcast to 3x3
ones = np.array([[1, 1, 1]])
sigma_bc = sigma_v * ones
mu_bc = ones.T * mu_v
df_bc = ones.T * df_v
expected_entropy = stats.t.entropy(
np.reshape(df_bc, [-1]),
loc=np.reshape(mu_bc, [-1]),
scale=np.reshape(sigma_bc, [-1]))
expected_entropy = np.reshape(expected_entropy, df_bc.shape)
self.assertAllClose(expected_entropy, ent_values)
def testStudentSample(self):
with tf.Session():
df = tf.constant(4.0)
mu = tf.constant(3.0)
sigma = tf.constant(math.sqrt(10.0))
df_v = 4.0
mu_v = 3.0
sigma_v = np.sqrt(10.0)
n = tf.constant(100000)
student = tf.contrib.distributions.StudentT(df=df, mu=mu, sigma=sigma)
samples = student.sample(n, seed=137)
sample_values = samples.eval()
n = 100000
self.assertEqual(sample_values.shape, (n,))
self.assertAllClose(sample_values.mean(), mu_v, atol=1e-2)
self.assertAllClose(sample_values.var(),
sigma_v**2 * df_v / (df_v - 2),
atol=.25)
self._checkKLApprox(df_v, mu_v, sigma_v, sample_values)
def testStudentSampleMultiDimensional(self):
with tf.Session():
batch_size = 7
df = tf.constant([[3.0, 7.0]] * batch_size)
mu = tf.constant([[3.0, -3.0]] * batch_size)
sigma = tf.constant([[math.sqrt(10.0), math.sqrt(15.0)]] * batch_size)
df_v = [3.0, 7.0]
mu_v = [3.0, -3.0]
sigma_v = [np.sqrt(10.0), np.sqrt(15.0)]
n = tf.constant(100000)
student = tf.contrib.distributions.StudentT(df=df, mu=mu, sigma=sigma)
samples = student.sample(n, seed=137)
sample_values = samples.eval()
self.assertEqual(samples.get_shape(), (100000, batch_size, 2))
self.assertAllClose(sample_values[:, 0, 0].mean(), mu_v[0], atol=.15)
self.assertAllClose(sample_values[:, 0, 0].var(),
sigma_v[0]**2 * df_v[0] / (df_v[0] - 2),
atol=1)
self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 0])
self.assertAllClose(sample_values[:, 0, 1].mean(), mu_v[1], atol=.01)
self.assertAllClose(sample_values[:, 0, 1].var(),
sigma_v[1]**2 * df_v[1] / (df_v[1] - 2),
atol=.25)
self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 1])
def _checkKLApprox(self, df, mu, sigma, samples):
n = samples.size
np.random.seed(137)
sample_scipy = stats.t.rvs(df, loc=mu, scale=sigma, size=n)
covg = 0.99
r = stats.t.interval(covg, df, loc=mu, scale=sigma)
bins = 100
hist, _ = np.histogram(samples, bins=bins, range=r)
hist_scipy, _ = np.histogram(sample_scipy, bins=bins, range=r)
self.assertGreater(hist.sum(), n * (covg - .01))
self.assertGreater(hist_scipy.sum(), n * (covg - .01))
hist_min1 = hist + 1. # put at least one item in each bucket
hist_norm = hist_min1 / hist_min1.sum()
hist_scipy_min1 = hist_scipy + 1. # put at least one item in each bucket
hist_scipy_norm = hist_scipy_min1 / hist_scipy_min1.sum()
kl_appx = np.sum(np.log(hist_scipy_norm / hist_norm) * hist_scipy_norm)
self.assertLess(kl_appx, 1)
def testBroadcastingParams(self):
def _check(student):
self.assertEqual(student.mean.get_shape(), (3,))
self.assertEqual(student.variance.get_shape(), (3,))
self.assertEqual(student.entropy().get_shape(), (3,))
self.assertEqual(student.log_pdf(2.).get_shape(), (3,))
self.assertEqual(student.pdf(2.).get_shape(), (3,))
self.assertEqual(student.sample(37).get_shape(), (37, 3,))
_check(tf.contrib.distributions.StudentT(df=[2., 3., 4.,], mu=2., sigma=1.))
_check(tf.contrib.distributions.StudentT(df=7., mu=[2., 3., 4.,], sigma=1.))
_check(tf.contrib.distributions.StudentT(df=7., mu=3., sigma=[2., 3., 4.,]))
def testBroadcastingPdfArgs(self):
def _assert_shape(student, arg, shape):
self.assertEqual(student.log_pdf(arg).get_shape(), shape)
self.assertEqual(student.pdf(arg).get_shape(), shape)
def _check(student):
_assert_shape(student, 2., (3,))
xs = np.array([2., 3., 4.], dtype=np.float32)
_assert_shape(student, xs, (3,))
xs = np.array([xs])
_assert_shape(student, xs, (1, 3))
xs = xs.T
_assert_shape(student, xs, (3, 3))
_check(tf.contrib.distributions.StudentT(df=[2., 3., 4.,], mu=2., sigma=1.))
_check(tf.contrib.distributions.StudentT(df=7., mu=[2., 3., 4.,], sigma=1.))
_check(tf.contrib.distributions.StudentT(df=7., mu=3., sigma=[2., 3., 4.,]))
def _check2d(student):
_assert_shape(student, 2., (1, 3))
xs = np.array([2., 3., 4.], dtype=np.float32)
_assert_shape(student, xs, (1, 3))
xs = np.array([xs])
_assert_shape(student, xs, (1, 3))
xs = xs.T
_assert_shape(student, xs, (3, 3))
_check2d(tf.contrib.distributions.StudentT(
df=[[2., 3., 4.,]], mu=2., sigma=1.))
_check2d(tf.contrib.distributions.StudentT(
df=7., mu=[[2., 3., 4.,]], sigma=1.))
_check2d(tf.contrib.distributions.StudentT(
df=7., mu=3., sigma=[[2., 3., 4.,]]))
def _check2d_rows(student):
_assert_shape(student, 2., (3, 1))
xs = np.array([2., 3., 4.], dtype=np.float32) # (3,)
_assert_shape(student, xs, (3, 3))
xs = np.array([xs]) # (1,3)
_assert_shape(student, xs, (3, 3))
xs = xs.T # (3,1)
_assert_shape(student, xs, (3, 1))
_check2d_rows(tf.contrib.distributions.StudentT(
df=[[2.], [3.], [4.]], mu=2., sigma=1.))
_check2d_rows(tf.contrib.distributions.StudentT(
df=7., mu=[[2.], [3.], [4.]], sigma=1.))
_check2d_rows(tf.contrib.distributions.StudentT(
df=7., mu=3., sigma=[[2.], [3.], [4.]]))
def testMeanVar(self):
with tf.Session():
student = tf.contrib.distributions.StudentT(
df=[1., 2., 3., 5., 7.],
mu=np.exp(1, dtype=np.float32),
sigma=[5., 4., 3., 2., 1.])
# Test broadcast of mu across shape of df/sigma
mean = student.mean.eval()
self.assertAllClose([np.exp(1, dtype=np.float32)] * 5, mean)
var = student.variance.eval()
# loc does not effect variance, so we use 0.
self.assertAllClose([stats.t.var(1., loc=0., scale=5.),
stats.t.var(2., loc=0., scale=4.),
stats.t.var(3., loc=0., scale=3.),
stats.t.var(5., loc=0., scale=2.),
stats.t.var(7., loc=0., scale=1.)], var)
def testPdfOfSample(self):
with tf.Session() as sess:
student = tf.contrib.distributions.StudentT(df=3., mu=np.pi, sigma=1.)
num = 20000
samples = student.sample(num, seed=137)
pdfs = student.pdf(samples)
mean = student.mean
mean_pdf = student.pdf(student.mean)
sample_vals, pdf_vals, mean_val, mean_pdf_val = sess.run(
[samples, pdfs, student.mean, mean_pdf])
self.assertEqual(samples.get_shape(), (num,))
self.assertEqual(pdfs.get_shape(), (num,))
self.assertEqual(mean.get_shape(), ())
self.assertNear(np.pi, np.mean(sample_vals), err=0.02)
self.assertNear(np.pi, mean_val, err=1e-6)
self.assertNear(stats.t.pdf(np.pi, 3., loc=np.pi), mean_pdf_val, err=1e-6)
# Verify integral over sample*pdf ~= 1.
self._assertIntegral(sample_vals, pdf_vals)
def testPdfOfSampleMultiDims(self):
with tf.Session() as sess:
student = tf.contrib.distributions.StudentT(df=[7., 11.],
mu=[[5.], [6.]],
sigma=3.)
num = 50000
samples = student.sample(num, seed=137)
pdfs = student.pdf(samples)
sample_vals, pdf_vals = sess.run([samples, pdfs])
self.assertEqual(samples.get_shape(), (num, 2, 2))
self.assertEqual(pdfs.get_shape(), (num, 2, 2))
self.assertNear(5., np.mean(sample_vals[:, 0, :]), err=.03)
self.assertNear(6., np.mean(sample_vals[:, 1, :]), err=.03)
self.assertNear(stats.t.var(7., loc=0., scale=3.), # loc d.n. effect var
np.var(sample_vals[:, :, 0]),
err=.25)
self.assertNear(stats.t.var(11., loc=0., scale=3.), # loc d.n. effect var
np.var(sample_vals[:, :, 1]),
err=.25)
self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02)
self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02)
self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02)
self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02)
def _assertIntegral(self, sample_vals, pdf_vals, err=1e-3):
s_p = zip(sample_vals, pdf_vals)
prev = (sample_vals.min() - 1000, 0)
total = 0
for k in sorted(s_p, key=lambda x: x[0]):
pair_pdf = (k[1] + prev[1]) / 2
total += (k[0] - prev[0]) * pair_pdf
prev = k
self.assertNear(1., total, err=err)
def testNegativeDofFails(self):
with tf.Session():
student = tf.contrib.distributions.StudentT(df=[2, -5.],
mu=0.,
sigma=1.,
name='S')
with self.assertRaisesOpError(r'Condition x > 0 did not hold'):
student.mean.eval()
def testNegativeScaleFails(self):
with tf.Session():
student = tf.contrib.distributions.StudentT(df=[5.],
mu=0.,
sigma=[[3.], [-2.]],
name='S')
with self.assertRaisesOpError(r'Condition x > 0 did not hold'):
student.mean.eval()
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,284 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Student's t distribution class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import numpy as np
from tensorflow.contrib.distributions.python.ops.distribution import ContinuousDistribution # pylint: disable=line-too-long
from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util # pylint: disable=line-too-long
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
class StudentT(ContinuousDistribution):
"""Student's t distribution with degree-of-freedom parameter df.
#### Mathematical details
The PDF of this distribution is:
`f(t) = gamma((df+1)/2)/sqrt(df*pi)/gamma(df/2)*(1+t^2/df)^(-(df+1)/2)`
#### Examples
Examples of initialization of one or a batch of distributions.
```python
# Define a single scalar Student t distribution.
single_dist = tf.contrib.distributions.StudentT(df=3)
# Evaluate the pdf at 1, returning a scalar Tensor.
single_dist.pdf(1.)
# Define a batch of two scalar valued Student t's.
# The first has degrees of freedom 2, mean 1, and scale 11.
# The second 3, 2 and 22.
multi_dist = tf.contrib.distributions.StudentT(df=[2, 3],
mu=[1, 2.],
sigma=[11, 22.])
# Evaluate the pdf of the first distribution on 0, and the second on 1.5,
# returning a length two tensor.
multi_dist.pdf([0, 1.5])
# Get 3 samples, returning a 3 x 2 tensor.
multi_dist.sample(3)
```
Arguments are broadcast when possible.
```python
# Define a batch of two Student's t distributions.
# Both have df 2 and mean 1, but different scales.
dist = tf.contrib.distributions.StudentT(df=2, mu=1, sigma=[11, 22.])
# Evaluate the pdf of both distributions on the same point, 3.0,
# returning a length 2 tensor.
dist.pdf(3.0)
```
"""
def __init__(self, df, mu, sigma, name="StudentT"):
"""Construct Student's t distributions.
The distributions have degree of freedom `df`, mean `mu`, and scale `sigma`.
The parameters `df`, `mu`, and `sigma` must be shaped in a way that supports
broadcasting (e.g. `df + mu + sigma` is a valid operation).
Args:
df: `float` or `double` tensor, the degrees of freedom of the
distribution(s). `df` must contain only positive values.
mu: `float` or `double` tensor, the means of the distribution(s).
sigma: `float` or `double` tensor, the scaling factor for the
distribution(s). `sigma` must contain only positive values.
Note that `sigma` is not the standard deviation of this distribution.
name: The name to give Ops created by the initializer.
Raises:
TypeError: if mu and sigma are different dtypes.
"""
super(StudentT, self).__init__()
with ops.op_scope([df, mu, sigma], name) as scope:
with ops.control_dependencies([check_ops.assert_positive(df),
check_ops.assert_positive(sigma)]):
self._df = ops.convert_to_tensor(df, name="df")
self._mu = ops.convert_to_tensor(mu, name="mu")
self._sigma = ops.convert_to_tensor(sigma, name="sigma")
contrib_tensor_util.assert_same_float_dtype(
(self._df, self._mu, self._sigma))
self._name = scope
self._batch_shape = self._ones().get_shape()
self._event_shape = tensor_shape.TensorShape([])
@property
def name(self):
return self._name
@property
def dtype(self):
return self._df.dtype
@property
def df(self):
"""Degrees of freedom in these Student's t distribution(s)."""
return self._df
@property
def mu(self):
"""Locations of these Student's t distribution(s)."""
return self._mu
@property
def sigma(self):
"""Scaling factors of these Student's t distribution(s)."""
return self._sigma
@property
def mean(self, name="mean"):
with ops.name_scope(self.name):
return math_ops.mul(self._mu, self._ones(), name=name)
@property
def variance(self, name="var"):
with ops.name_scope(self.name):
return math_ops.select(
(self._zeros() + self._df > 2),
self._zeros() + math_ops.square(self._sigma) * self._df /
(self._df - 2),
self._zeros() + np.inf,
name=name)
def batch_shape(self, name="batch_shape"):
with ops.name_scope(self.name):
return array_ops.shape(self._ones(), name=name)
def get_batch_shape(self):
return self._batch_shape
def event_shape(self, name="event_shape"):
with ops.name_scope(self.name):
return constant_op.constant(1, name=name)
def get_event_shape(self):
return self._event_shape
def log_pdf(self, x, name="log_pdf"):
"""Log pdf of observations in `x` under these Student's t-distribution(s).
Args:
x: tensor of dtype `dtype`, must be broadcastable with `mu` and `df`.
name: The name to give this op.
Returns:
log_pdf: tensor of dtype `dtype`, the log-PDFs of `x`.
"""
with ops.op_scope([self._df, self._mu, self._sigma, x], self.name):
with ops.name_scope(name):
x = ops.convert_to_tensor(x)
if x.dtype != self.dtype:
raise TypeError("Input x dtype does not match dtype: %s vs. %s" %
(x.dtype, self.dtype))
df_2 = self._df / 2
log_beta = (math_ops.lgamma(0.5) + math_ops.lgamma(df_2) -
math_ops.lgamma(0.5 + df_2))
return (-math_ops.log(self._df) / 2 - log_beta - (self._df + 1) / 2 *
math_ops.log(1 + math_ops.square((x - self._mu) / self._sigma) /
self._df) - math_ops.log(self._sigma))
def pdf(self, x, name="pdf"):
"""The PDF of observations in `x` under these Student's t distribution(s).
Args:
x: tensor of dtype `dtype`, must be broadcastable with `df`, `mu`, and
`sigma`.
name: The name to give this op.
Returns:
pdf: tensor of dtype `dtype`, the pdf values of `x`.
"""
with ops.op_scope([self._df, self._mu, self._sigma, x], self.name):
with ops.name_scope(name):
x = ops.convert_to_tensor(x)
if x.dtype != self.dtype:
raise TypeError("Input x dtype does not match dtype: %s vs. %s" %
(x.dtype, self.dtype))
reloc_scaled = (x - self._mu) / self._sigma
return (math_ops.exp(math_ops.lgamma((self._df + 1) / 2) -
math_ops.lgamma(self._df / 2)) /
math_ops.sqrt(self._df) / math.sqrt(np.pi) *
math_ops.pow(1 + math_ops.square(reloc_scaled) / self._df,
-(self._df + 1) / 2) / self.sigma)
def entropy(self, name="entropy"):
"""The entropy of Student t distribution(s).
Args:
name: The name to give this op.
Returns:
entropy: tensor of dtype `dtype`, the entropy.
"""
with ops.op_scope([self._df, self._sigma], self.name):
with ops.name_scope(name):
u = array_ops.expand_dims(self._df + self._zeros(), -1)
v = array_ops.expand_dims(self._ones(), -1)
beta_arg = array_ops.concat(len(u.get_shape()) - 1, [u, v]) / 2
return ((self._df + 1) / 2 * (math_ops.digamma((self._df + 1) / 2) -
math_ops.digamma(self._df / 2)) +
math_ops.log(self._df) / 2 +
special_math_ops.lbeta(beta_arg) +
math_ops.log(self._sigma))
def sample(self, n, seed=None, name="sample"):
"""Sample `n` observations from the Student t Distributions.
Args:
n: `Scalar`, type int32, the number of observations to sample.
seed: Python integer, the random seed.
name: The name to give this op.
Returns:
samples: a `Tensor` of shape `(n,) + self.batch_shape + self.event_shape`
with values of type `self.dtype`.
"""
with ops.op_scope([self._df, self._mu, self._sigma, n], self.name):
with ops.name_scope(name):
n = ops.convert_to_tensor(n, name="n")
n_val = tensor_util.constant_value(n)
# We use 2 uniform random floats to generate polar random variates.
# http://dl.acm.org/citation.cfm?id=179631
# Theorem 2. Let G, H be iid variates, uniformly distributed on [0,1].
# Let theta = 2*pi*H, let R = sqrt(df*(G^(-2/df) - 1)) for df > 0.
# Let X = R*cos(theta), and let Y = R*sin(theta).
# Then X ~ t_df and Y ~ t_df.
# The variates X and Y are not independent.
shape = array_ops.concat(0, [array_ops.pack([2, n]),
self.batch_shape()])
uniform = random_ops.random_uniform(shape=shape,
dtype=self.dtype,
seed=seed)
samples_g, samples_h = array_ops.unpack(uniform, num=2)
theta = (2 * np.pi) * samples_h
r = math_ops.sqrt(self._df *
(math_ops.pow(samples_g, -2 / self._df) - 1))
samples = r * math_ops.cos(theta)
# Provide some hints to shape inference
inferred_shape = tensor_shape.vector(n_val).concatenate(
self.get_batch_shape())
samples.set_shape(inferred_shape)
return samples * self._sigma + self._mu
def _ones(self):
return array_ops.ones_like(self._df + self._mu + self._sigma)
def _zeros(self):
return array_ops.zeros_like(self._df + self._mu + self._sigma)