Adds Student's t distribution to tf.contrib.distributions
The cdf is left out for now because it requires the incomplete beta function, which is not available in Eigen at the moment. Change: 122673237
This commit is contained in:
parent
996f797746
commit
0eb9af8148
tensorflow/contrib/distributions
@ -65,6 +65,17 @@ cuda_py_tests(
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_tests(
|
||||
name = "student_t_test",
|
||||
size = "small",
|
||||
srcs = ["python/kernel_tests/student_t_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_tests(
|
||||
name = "uniform_test",
|
||||
size = "small",
|
||||
|
@ -31,6 +31,7 @@ initialized with parameters that define the distributions.
|
||||
@@Exponential
|
||||
@@Gamma
|
||||
@@Gaussian
|
||||
@@StudentT
|
||||
@@Uniform
|
||||
|
||||
### Multivariate distributions
|
||||
@ -62,4 +63,5 @@ from tensorflow.contrib.distributions.python.ops.gamma import *
|
||||
from tensorflow.contrib.distributions.python.ops.gaussian import *
|
||||
from tensorflow.contrib.distributions.python.ops.gaussian_conjugate_posteriors import *
|
||||
from tensorflow.contrib.distributions.python.ops.mvn import *
|
||||
from tensorflow.contrib.distributions.python.ops.student_t import *
|
||||
from tensorflow.contrib.distributions.python.ops.uniform import *
|
||||
|
@ -0,0 +1,321 @@
|
||||
# Copyright 2016 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for Student t distribution."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
from scipy import stats
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class StudentTTest(tf.test.TestCase):
  """Tests for `tf.contrib.distributions.StudentT`, checked against scipy."""

  def testStudentPDFAndLogPDF(self):
    """pdf/log_pdf of a 1-D batch agree with `scipy.stats.t`."""
    with tf.Session():
      batch_size = 6
      df = tf.constant([3.0] * batch_size)
      mu = tf.constant([7.0] * batch_size)
      sigma = tf.constant([8.0] * batch_size)
      df_v = 3.0
      mu_v = 7.0
      sigma_v = 8.0
      t = np.array([-2.5, 2.5, 8.0, 0.0, -1.0, 2.0], dtype=np.float32)
      student = tf.contrib.distributions.StudentT(df, mu=mu, sigma=sigma)

      log_pdf = student.log_pdf(t)
      self.assertEqual(log_pdf.get_shape(), (6,))
      log_pdf_values = log_pdf.eval()
      pdf = student.pdf(t)
      self.assertEqual(pdf.get_shape(), (6,))
      pdf_values = pdf.eval()

      expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
      expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
      self.assertAllClose(expected_log_pdf, log_pdf_values)
      self.assertAllClose(np.log(expected_pdf), log_pdf_values)
      self.assertAllClose(expected_pdf, pdf_values)
      self.assertAllClose(np.exp(expected_log_pdf), pdf_values)

  def testStudentLogPDFMultidimensional(self):
    """pdf/log_pdf broadcast a (6, 1) input against (6, 2) parameters."""
    with tf.Session():
      batch_size = 6
      df = tf.constant([[1.5, 7.2]] * batch_size)
      mu = tf.constant([[3.0, -3.0]] * batch_size)
      sigma = tf.constant([[math.sqrt(10.0), math.sqrt(15.0)]] * batch_size)
      df_v = np.array([1.5, 7.2])
      mu_v = np.array([3.0, -3.0])
      sigma_v = np.array([np.sqrt(10.0), np.sqrt(15.0)])
      t = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T
      student = tf.contrib.distributions.StudentT(df, mu=mu, sigma=sigma)
      log_pdf = student.log_pdf(t)
      log_pdf_values = log_pdf.eval()
      self.assertEqual(log_pdf.get_shape(), (6, 2))
      pdf = student.pdf(t)
      pdf_values = pdf.eval()
      self.assertEqual(pdf.get_shape(), (6, 2))
      expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
      expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
      self.assertAllClose(expected_log_pdf, log_pdf_values)
      self.assertAllClose(np.log(expected_pdf), log_pdf_values)
      self.assertAllClose(expected_pdf, pdf_values)
      self.assertAllClose(np.exp(expected_log_pdf), pdf_values)

  def testStudentEntropy(self):
    """Entropy of broadcast (1x3 with 3x1) parameters matches scipy."""
    df_v = np.array([[2., 3., 7.]])  # 1x3
    mu_v = np.array([[1., -1, 0]])  # 1x3
    sigma_v = np.array([[1., 2., 3.]]).T  # transposed => 3x1
    with tf.Session():
      student = tf.contrib.distributions.StudentT(df=df_v,
                                                  mu=mu_v,
                                                  sigma=sigma_v)
      ent = student.entropy()
      ent_values = ent.eval()

    # Help scipy broadcast to 3x3
    ones = np.array([[1, 1, 1]])
    sigma_bc = sigma_v * ones
    mu_bc = ones.T * mu_v
    df_bc = ones.T * df_v
    expected_entropy = stats.t.entropy(
        np.reshape(df_bc, [-1]),
        loc=np.reshape(mu_bc, [-1]),
        scale=np.reshape(sigma_bc, [-1]))
    expected_entropy = np.reshape(expected_entropy, df_bc.shape)
    self.assertAllClose(expected_entropy, ent_values)

  def testStudentSample(self):
    """Scalar-parameter samples have the expected shape, mean and variance."""
    with tf.Session():
      df = tf.constant(4.0)
      mu = tf.constant(3.0)
      sigma = tf.constant(math.sqrt(10.0))
      df_v = 4.0
      mu_v = 3.0
      sigma_v = np.sqrt(10.0)
      n_val = 100000
      n = tf.constant(n_val)
      student = tf.contrib.distributions.StudentT(df=df, mu=mu, sigma=sigma)
      samples = student.sample(n, seed=137)
      sample_values = samples.eval()
      self.assertEqual(sample_values.shape, (n_val,))
      self.assertAllClose(sample_values.mean(), mu_v, atol=1e-2)
      self.assertAllClose(sample_values.var(),
                          sigma_v**2 * df_v / (df_v - 2),
                          atol=.25)
      self._checkKLApprox(df_v, mu_v, sigma_v, sample_values)

  def testStudentSampleMultiDimensional(self):
    """Samples of a (7, 2) batch match each component's own parameters."""
    with tf.Session():
      batch_size = 7
      df = tf.constant([[3.0, 7.0]] * batch_size)
      mu = tf.constant([[3.0, -3.0]] * batch_size)
      sigma = tf.constant([[math.sqrt(10.0), math.sqrt(15.0)]] * batch_size)
      df_v = [3.0, 7.0]
      mu_v = [3.0, -3.0]
      sigma_v = [np.sqrt(10.0), np.sqrt(15.0)]
      n_val = 100000
      n = tf.constant(n_val)
      student = tf.contrib.distributions.StudentT(df=df, mu=mu, sigma=sigma)
      samples = student.sample(n, seed=137)
      sample_values = samples.eval()
      self.assertEqual(samples.get_shape(), (n_val, batch_size, 2))
      self.assertAllClose(sample_values[:, 0, 0].mean(), mu_v[0], atol=.15)
      self.assertAllClose(sample_values[:, 0, 0].var(),
                          sigma_v[0]**2 * df_v[0] / (df_v[0] - 2),
                          atol=1)
      self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 0])
      self.assertAllClose(sample_values[:, 0, 1].mean(), mu_v[1], atol=.01)
      self.assertAllClose(sample_values[:, 0, 1].var(),
                          sigma_v[1]**2 * df_v[1] / (df_v[1] - 2),
                          atol=.25)
      # The second component must be compared against its own parameters
      # (index 1), not those of the first component.
      self._checkKLApprox(df_v[1], mu_v[1], sigma_v[1], sample_values[:, 0, 1])

  def _checkKLApprox(self, df, mu, sigma, samples):
    """Asserts `samples` are histogram-close to a scipy reference sample.

    Draws `samples.size` reference variates from `scipy.stats.t` (with a fixed
    numpy seed), histograms both sample sets over the central 99% interval of
    the reference distribution, and checks that the approximate KL divergence
    between the two (add-one smoothed) histograms is below 1.

    Args:
      df: scalar degrees of freedom of the reference distribution.
      mu: scalar location of the reference distribution.
      sigma: scalar scale of the reference distribution.
      samples: numpy array of samples to check.
    """
    n = samples.size
    np.random.seed(137)
    sample_scipy = stats.t.rvs(df, loc=mu, scale=sigma, size=n)
    covg = 0.99
    r = stats.t.interval(covg, df, loc=mu, scale=sigma)
    bins = 100
    hist, _ = np.histogram(samples, bins=bins, range=r)
    hist_scipy, _ = np.histogram(sample_scipy, bins=bins, range=r)
    # Both sample sets should put (almost) `covg` of their mass in the range.
    self.assertGreater(hist.sum(), n * (covg - .01))
    self.assertGreater(hist_scipy.sum(), n * (covg - .01))
    hist_min1 = hist + 1.  # put at least one item in each bucket
    hist_norm = hist_min1 / hist_min1.sum()
    hist_scipy_min1 = hist_scipy + 1.  # put at least one item in each bucket
    hist_scipy_norm = hist_scipy_min1 / hist_scipy_min1.sum()
    kl_appx = np.sum(np.log(hist_scipy_norm / hist_norm) * hist_scipy_norm)
    self.assertLess(kl_appx, 1)

  def testBroadcastingParams(self):
    """A length-3 parameter broadcasts against two scalar parameters."""

    def _check(student):
      self.assertEqual(student.mean.get_shape(), (3,))
      self.assertEqual(student.variance.get_shape(), (3,))
      self.assertEqual(student.entropy().get_shape(), (3,))
      self.assertEqual(student.log_pdf(2.).get_shape(), (3,))
      self.assertEqual(student.pdf(2.).get_shape(), (3,))
      self.assertEqual(student.sample(37).get_shape(), (37, 3,))

    _check(tf.contrib.distributions.StudentT(df=[2., 3., 4.,], mu=2., sigma=1.))
    _check(tf.contrib.distributions.StudentT(df=7., mu=[2., 3., 4.,], sigma=1.))
    _check(tf.contrib.distributions.StudentT(df=7., mu=3., sigma=[2., 3., 4.,]))

  def testBroadcastingPdfArgs(self):
    """pdf/log_pdf output shapes follow broadcasting of input and params."""

    def _assert_shape(student, arg, shape):
      self.assertEqual(student.log_pdf(arg).get_shape(), shape)
      self.assertEqual(student.pdf(arg).get_shape(), shape)

    def _check(student):
      _assert_shape(student, 2., (3,))
      xs = np.array([2., 3., 4.], dtype=np.float32)
      _assert_shape(student, xs, (3,))
      xs = np.array([xs])
      _assert_shape(student, xs, (1, 3))
      xs = xs.T
      _assert_shape(student, xs, (3, 3))

    _check(tf.contrib.distributions.StudentT(df=[2., 3., 4.,], mu=2., sigma=1.))
    _check(tf.contrib.distributions.StudentT(df=7., mu=[2., 3., 4.,], sigma=1.))
    _check(tf.contrib.distributions.StudentT(df=7., mu=3., sigma=[2., 3., 4.,]))

    def _check2d(student):
      _assert_shape(student, 2., (1, 3))
      xs = np.array([2., 3., 4.], dtype=np.float32)
      _assert_shape(student, xs, (1, 3))
      xs = np.array([xs])
      _assert_shape(student, xs, (1, 3))
      xs = xs.T
      _assert_shape(student, xs, (3, 3))

    _check2d(tf.contrib.distributions.StudentT(
        df=[[2., 3., 4.,]], mu=2., sigma=1.))
    _check2d(tf.contrib.distributions.StudentT(
        df=7., mu=[[2., 3., 4.,]], sigma=1.))
    _check2d(tf.contrib.distributions.StudentT(
        df=7., mu=3., sigma=[[2., 3., 4.,]]))

    def _check2d_rows(student):
      _assert_shape(student, 2., (3, 1))
      xs = np.array([2., 3., 4.], dtype=np.float32)  # (3,)
      _assert_shape(student, xs, (3, 3))
      xs = np.array([xs])  # (1,3)
      _assert_shape(student, xs, (3, 3))
      xs = xs.T  # (3,1)
      _assert_shape(student, xs, (3, 1))

    _check2d_rows(tf.contrib.distributions.StudentT(
        df=[[2.], [3.], [4.]], mu=2., sigma=1.))
    _check2d_rows(tf.contrib.distributions.StudentT(
        df=7., mu=[[2.], [3.], [4.]], sigma=1.))
    _check2d_rows(tf.contrib.distributions.StudentT(
        df=7., mu=3., sigma=[[2.], [3.], [4.]]))

  def testMeanVar(self):
    """mean broadcasts `mu`; variance matches `scipy.stats.t.var`."""
    with tf.Session():
      student = tf.contrib.distributions.StudentT(
          df=[1., 2., 3., 5., 7.],
          mu=np.exp(1, dtype=np.float32),
          sigma=[5., 4., 3., 2., 1.])
      # Test broadcast of mu across shape of df/sigma
      mean = student.mean.eval()
      self.assertAllClose([np.exp(1, dtype=np.float32)] * 5, mean)
      var = student.variance.eval()
      # loc does not affect variance, so we use 0.
      self.assertAllClose([stats.t.var(1., loc=0., scale=5.),
                           stats.t.var(2., loc=0., scale=4.),
                           stats.t.var(3., loc=0., scale=3.),
                           stats.t.var(5., loc=0., scale=2.),
                           stats.t.var(7., loc=0., scale=1.)], var)

  def testPdfOfSample(self):
    """The pdf evaluated at drawn samples integrates to approximately 1."""
    with tf.Session() as sess:
      student = tf.contrib.distributions.StudentT(df=3., mu=np.pi, sigma=1.)
      num = 20000
      samples = student.sample(num, seed=137)
      pdfs = student.pdf(samples)
      # Build the mean op once and reuse it below.
      mean = student.mean
      mean_pdf = student.pdf(mean)
      sample_vals, pdf_vals, mean_val, mean_pdf_val = sess.run(
          [samples, pdfs, mean, mean_pdf])
      self.assertEqual(samples.get_shape(), (num,))
      self.assertEqual(pdfs.get_shape(), (num,))
      self.assertEqual(mean.get_shape(), ())
      self.assertNear(np.pi, np.mean(sample_vals), err=0.02)
      self.assertNear(np.pi, mean_val, err=1e-6)
      self.assertNear(stats.t.pdf(np.pi, 3., loc=np.pi), mean_pdf_val, err=1e-6)
      # Verify integral over sample*pdf ~= 1.
      self._assertIntegral(sample_vals, pdf_vals)

  def testPdfOfSampleMultiDims(self):
    """Per-component sample statistics and pdf integrals for a 2x2 batch."""
    with tf.Session() as sess:
      student = tf.contrib.distributions.StudentT(df=[7., 11.],
                                                  mu=[[5.], [6.]],
                                                  sigma=3.)
      num = 50000
      samples = student.sample(num, seed=137)
      pdfs = student.pdf(samples)
      sample_vals, pdf_vals = sess.run([samples, pdfs])
      self.assertEqual(samples.get_shape(), (num, 2, 2))
      self.assertEqual(pdfs.get_shape(), (num, 2, 2))
      self.assertNear(5., np.mean(sample_vals[:, 0, :]), err=.03)
      self.assertNear(6., np.mean(sample_vals[:, 1, :]), err=.03)
      self.assertNear(stats.t.var(7., loc=0., scale=3.),  # loc does not affect var
                      np.var(sample_vals[:, :, 0]),
                      err=.25)
      self.assertNear(stats.t.var(11., loc=0., scale=3.),  # loc does not affect var
                      np.var(sample_vals[:, :, 1]),
                      err=.25)
      self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02)
      self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02)
      self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02)
      self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02)

  def _assertIntegral(self, sample_vals, pdf_vals, err=1e-3):
    """Asserts the trapezoid-rule integral of pdf over the samples is ~1.

    The (sample, pdf) pairs are sorted by sample value and integrated with the
    trapezoid rule. Integration starts from a synthetic point 1000 to the left
    of the smallest sample with pdf value 0.

    Args:
      sample_vals: 1-D numpy array of sample locations.
      pdf_vals: 1-D numpy array of pdf values at `sample_vals`.
      err: allowed absolute deviation of the integral from 1.
    """
    s_p = zip(sample_vals, pdf_vals)
    prev = (sample_vals.min() - 1000, 0)
    total = 0
    for k in sorted(s_p, key=lambda x: x[0]):
      pair_pdf = (k[1] + prev[1]) / 2
      total += (k[0] - prev[0]) * pair_pdf
      prev = k
    self.assertNear(1., total, err=err)

  def testNegativeDofFails(self):
    """A non-positive `df` entry triggers the positivity assertion."""
    with tf.Session():
      student = tf.contrib.distributions.StudentT(df=[2, -5.],
                                                  mu=0.,
                                                  sigma=1.,
                                                  name='S')
      with self.assertRaisesOpError(r'Condition x > 0 did not hold'):
        student.mean.eval()

  def testNegativeScaleFails(self):
    """A non-positive `sigma` entry triggers the positivity assertion."""
    with tf.Session():
      student = tf.contrib.distributions.StudentT(df=[5.],
                                                  mu=0.,
                                                  sigma=[[3.], [-2.]],
                                                  name='S')
      with self.assertRaisesOpError(r'Condition x > 0 did not hold'):
        student.mean.eval()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
284
tensorflow/contrib/distributions/python/ops/student_t.py
Normal file
284
tensorflow/contrib/distributions/python/ops/student_t.py
Normal file
@ -0,0 +1,284 @@
|
||||
# Copyright 2016 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Student's t distribution class."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.distributions.python.ops.distribution import ContinuousDistribution # pylint: disable=line-too-long
|
||||
from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util # pylint: disable=line-too-long
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import constant_op
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import special_math_ops
|
||||
|
||||
|
||||
class StudentT(ContinuousDistribution):
  """Student's t distribution with degree-of-freedom parameter df.

  #### Mathematical details

  The PDF of this distribution, with location `mu` and scale `sigma`, is:

  `f(t) = gamma((df+1)/2)/sqrt(df*pi)/gamma(df/2)/sigma*
          (1+((t-mu)/sigma)^2/df)^(-(df+1)/2)`

  #### Examples

  Examples of initialization of one or a batch of distributions.

  ```python
  # Define a single scalar Student t distribution.
  single_dist = tf.contrib.distributions.StudentT(df=3., mu=0., sigma=1.)

  # Evaluate the pdf at 1, returning a scalar Tensor.
  single_dist.pdf(1.)

  # Define a batch of two scalar valued Student t's.
  # The first has degrees of freedom 2, mean 1, and scale 11.
  # The second 3, 2 and 22.
  multi_dist = tf.contrib.distributions.StudentT(df=[2., 3.],
                                                 mu=[1., 2.],
                                                 sigma=[11., 22.])

  # Evaluate the pdf of the first distribution on 0, and the second on 1.5,
  # returning a length two tensor.
  multi_dist.pdf([0., 1.5])

  # Get 3 samples, returning a 3 x 2 tensor.
  multi_dist.sample(3)
  ```

  Arguments are broadcast when possible.

  ```python
  # Define a batch of two Student's t distributions.
  # Both have df 2 and mean 1, but different scales.
  dist = tf.contrib.distributions.StudentT(df=2., mu=1., sigma=[11., 22.])

  # Evaluate the pdf of both distributions on the same point, 3.0,
  # returning a length 2 tensor.
  dist.pdf(3.0)
  ```
  """

  def __init__(self, df, mu, sigma, name="StudentT"):
    """Construct Student's t distributions.

    The distributions have degree of freedom `df`, mean `mu`, and scale `sigma`.

    The parameters `df`, `mu`, and `sigma` must be shaped in a way that supports
    broadcasting (e.g. `df + mu + sigma` is a valid operation).

    Args:
      df: `float` or `double` tensor, the degrees of freedom of the
        distribution(s). `df` must contain only positive values.
      mu: `float` or `double` tensor, the means of the distribution(s).
      sigma: `float` or `double` tensor, the scaling factor for the
        distribution(s). `sigma` must contain only positive values.
        Note that `sigma` is not the standard deviation of this distribution.
      name: The name to give Ops created by the initializer.

    Raises:
      TypeError: if `df`, `mu`, and `sigma` are different dtypes.
        (NOTE(review): the exception type is whatever
        `assert_same_float_dtype` raises -- confirm.)
    """
    super(StudentT, self).__init__()
    with ops.op_scope([df, mu, sigma], name) as scope:
      df = ops.convert_to_tensor(df)
      mu = ops.convert_to_tensor(mu)
      sigma = ops.convert_to_tensor(sigma)
      with ops.control_dependencies([check_ops.assert_positive(df),
                                     check_ops.assert_positive(sigma)]):
        # `identity` always creates a new op inside this scope, so the
        # positivity assertions are attached even when the caller passed in
        # already-constructed Tensors (for which `convert_to_tensor` is a
        # no-op that creates nothing to hang the control dependency on).
        self._df = array_ops.identity(df, name="df")
        self._mu = array_ops.identity(mu, name="mu")
        self._sigma = array_ops.identity(sigma, name="sigma")
        contrib_tensor_util.assert_same_float_dtype(
            (self._df, self._mu, self._sigma))
      self._name = scope
      self._batch_shape = self._ones().get_shape()
      self._event_shape = tensor_shape.TensorShape([])

  @property
  def name(self):
    """Name scope under which this distribution's ops are created."""
    return self._name

  @property
  def dtype(self):
    """dtype of `df` (and, by the constructor check, of `mu` and `sigma`)."""
    return self._df.dtype

  @property
  def df(self):
    """Degrees of freedom in these Student's t distribution(s)."""
    return self._df

  @property
  def mu(self):
    """Locations of these Student's t distribution(s)."""
    return self._mu

  @property
  def sigma(self):
    """Scaling factors of these Student's t distribution(s)."""
    return self._sigma

  @property
  def mean(self, name="mean"):
    """`mu` broadcast to the full batch shape."""
    with ops.name_scope(self.name):
      return math_ops.mul(self._mu, self._ones(), name=name)

  @property
  def variance(self, name="var"):
    """`sigma^2 * df / (df - 2)` where `df > 2`, and `inf` elsewhere."""
    with ops.name_scope(self.name):
      return math_ops.select(
          (self._zeros() + self._df > 2),
          self._zeros() + math_ops.square(self._sigma) * self._df /
          (self._df - 2),
          self._zeros() + np.inf,
          name=name)

  def batch_shape(self, name="batch_shape"):
    """Shape `Tensor` of the broadcast of `df`, `mu` and `sigma`."""
    with ops.name_scope(self.name):
      return array_ops.shape(self._ones(), name=name)

  def get_batch_shape(self):
    """Static `TensorShape` of the broadcast of `df`, `mu` and `sigma`."""
    return self._batch_shape

  def event_shape(self, name="event_shape"):
    """Shape `Tensor` of a single event: empty, since events are scalars."""
    with ops.name_scope(self.name):
      # An empty int32 vector, consistent with `get_event_shape()` returning
      # `TensorShape([])` and with the sample shape documented in `sample`.
      return constant_op.constant(np.array([], dtype=np.int32), name=name)

  def get_event_shape(self):
    """Static `TensorShape` of a single event (scalar)."""
    return self._event_shape

  def log_pdf(self, x, name="log_pdf"):
    """Log pdf of observations in `x` under these Student's t-distribution(s).

    Args:
      x: tensor of dtype `dtype`, must be broadcastable with `df`, `mu`, and
        `sigma`.
      name: The name to give this op.

    Returns:
      log_pdf: tensor of dtype `dtype`, the log-PDFs of `x`.

    Raises:
      TypeError: if the dtype of `x` differs from `dtype`.
    """
    with ops.op_scope([self._df, self._mu, self._sigma, x], self.name):
      with ops.name_scope(name):
        x = ops.convert_to_tensor(x)
        if x.dtype != self.dtype:
          raise TypeError("Input x dtype does not match dtype: %s vs. %s" %
                          (x.dtype, self.dtype))
        df_2 = self._df / 2
        # lgamma(1/2) == log(sqrt(pi)). It is kept as a Python float so that
        # it adopts the dtype of `df` rather than forcing float32.
        log_gamma_half = 0.5 * math.log(np.pi)
        log_beta = (log_gamma_half + math_ops.lgamma(df_2) -
                    math_ops.lgamma(0.5 + df_2))
        return (-math_ops.log(self._df) / 2 - log_beta - (self._df + 1) / 2 *
                math_ops.log(1 + math_ops.square((x - self._mu) / self._sigma) /
                             self._df) - math_ops.log(self._sigma))

  def pdf(self, x, name="pdf"):
    """The PDF of observations in `x` under these Student's t distribution(s).

    Args:
      x: tensor of dtype `dtype`, must be broadcastable with `df`, `mu`, and
        `sigma`.
      name: The name to give this op.

    Returns:
      pdf: tensor of dtype `dtype`, the pdf values of `x`.

    Raises:
      TypeError: if the dtype of `x` differs from `dtype`.
    """
    with ops.op_scope([self._df, self._mu, self._sigma, x], self.name):
      with ops.name_scope(name):
        x = ops.convert_to_tensor(x)
        if x.dtype != self.dtype:
          raise TypeError("Input x dtype does not match dtype: %s vs. %s" %
                          (x.dtype, self.dtype))
        reloc_scaled = (x - self._mu) / self._sigma
        return (math_ops.exp(math_ops.lgamma((self._df + 1) / 2) -
                             math_ops.lgamma(self._df / 2)) /
                math_ops.sqrt(self._df) / math.sqrt(np.pi) *
                math_ops.pow(1 + math_ops.square(reloc_scaled) / self._df,
                             -(self._df + 1) / 2) / self._sigma)

  def entropy(self, name="entropy"):
    """The entropy of Student t distribution(s).

    Args:
      name: The name to give this op.

    Returns:
      entropy: tensor of dtype `dtype`, the entropy.
    """
    with ops.op_scope([self._df, self._sigma], self.name):
      with ops.name_scope(name):
        # Stack (df, 1) / 2 along a new last axis, broadcast to the batch
        # shape, as the argument for `lbeta`.
        u = array_ops.expand_dims(self._df + self._zeros(), -1)
        v = array_ops.expand_dims(self._ones(), -1)
        beta_arg = array_ops.concat(len(u.get_shape()) - 1, [u, v]) / 2
        return ((self._df + 1) / 2 * (math_ops.digamma((self._df + 1) / 2) -
                                      math_ops.digamma(self._df / 2)) +
                math_ops.log(self._df) / 2 +
                special_math_ops.lbeta(beta_arg) +
                math_ops.log(self._sigma))

  def sample(self, n, seed=None, name="sample"):
    """Sample `n` observations from the Student t Distributions.

    Args:
      n: `Scalar`, type int32, the number of observations to sample.
      seed: Python integer, the random seed.
      name: The name to give this op.

    Returns:
      samples: a `Tensor` of shape `(n,) + self.batch_shape + self.event_shape`
          with values of type `self.dtype`.
    """
    with ops.op_scope([self._df, self._mu, self._sigma, n], self.name):
      with ops.name_scope(name):
        n = ops.convert_to_tensor(n, name="n")
        n_val = tensor_util.constant_value(n)

        # We use 2 uniform random floats to generate polar random variates.
        # http://dl.acm.org/citation.cfm?id=179631
        # Theorem 2. Let G, H be iid variates, uniformly distributed on [0,1].
        # Let theta = 2*pi*H, let R = sqrt(df*(G^(-2/df) - 1)) for df > 0.
        # Let X = R*cos(theta), and let Y = R*sin(theta).
        # Then X ~ t_df and Y ~ t_df.
        # The variates X and Y are not independent.
        shape = array_ops.concat(0, [array_ops.pack([2, n]),
                                     self.batch_shape()])
        uniform = random_ops.random_uniform(shape=shape,
                                            dtype=self.dtype,
                                            seed=seed)
        samples_g, samples_h = array_ops.unpack(uniform, num=2)
        theta = (2 * np.pi) * samples_h
        # NOTE(review): if `random_uniform` can return exactly 0, `r` below
        # becomes infinite for that draw -- confirm whether a guard is needed.
        r = math_ops.sqrt(self._df *
                          (math_ops.pow(samples_g, -2 / self._df) - 1))
        samples = r * math_ops.cos(theta)

        # Provide some hints to shape inference
        inferred_shape = tensor_shape.vector(n_val).concatenate(
            self.get_batch_shape())
        samples.set_shape(inferred_shape)

        return samples * self._sigma + self._mu

  def _ones(self):
    """Tensor of ones with the broadcast shape of `df + mu + sigma`."""
    return array_ops.ones_like(self._df + self._mu + self._sigma)

  def _zeros(self):
    """Tensor of zeros with the broadcast shape of `df + mu + sigma`."""
    return array_ops.zeros_like(self._df + self._mu + self._sigma)
|
Loading…
Reference in New Issue
Block a user