Registering documentation for contrib.distributions.

Change: 121899093
A. Unique TensorFlower 2016-05-09 16:54:05 -08:00 committed by TensorFlower Gardener
parent 6d2623d03c
commit 506671c7b3
7 changed files with 120 additions and 36 deletions

View File

@@ -12,17 +12,38 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for representing statistical distributions.
"""Classes representing statistical distributions. Ops for working with them.
## This package provides classes for statistical distributions.
## Classes for statistical distributions.
Classes that represent batches of statistical distributions. Each class is
initialized with parameters that define the distributions.
### Univariate (scalar) distributions
@@Gaussian
### Multivariate distributions
@@MultivariateNormal
@@DirichletMultinomial
## Posterior inference with conjugate priors.
Functions that transform conjugate prior/likelihood pairs to distributions
representing the posterior or posterior predictive.
### Gaussian likelihood with conjugate prior.
@@gaussian_conjugates_known_sigma_posterior
@@gaussian_congugates_known_sigma_predictive
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,wildcard-import,line-too-long
from tensorflow.contrib.distributions.python.ops import gaussian_conjugate_posteriors
from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import *
from tensorflow.contrib.distributions.python.ops.gaussian import *
from tensorflow.contrib.distributions.python.ops.gaussian_conjugate_posteriors import *
from tensorflow.contrib.distributions.python.ops.mvn import *
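For orientation (not part of the commit), here is a minimal usage sketch assembled from the names exported above; the constructor and function names come from this commit's imports and tests, while the numeric values are illustrative only.

```python
import tensorflow as tf

distributions = tf.contrib.distributions

# A scalar Gaussian prior over an unknown mean, plus a known likelihood stddev.
mu0 = tf.constant(0.0)
sigma0 = tf.constant(1.0)
sigma = tf.constant(1.0)
prior = distributions.Gaussian(mu=mu0, sigma=sigma0)

# Summary statistics (sum and count) of the observations.
x = tf.constant([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0])
s = tf.reduce_sum(x)
n = tf.size(x)

# Closed-form posterior over the mean; also a Gaussian.
posterior = distributions.gaussian_conjugates_known_sigma_posterior(
    prior=prior, sigma=sigma, s=s, n=n)

with tf.Session():
  print(posterior.log_pdf(x).eval())  # shape (6,)
```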

View File

@@ -22,7 +22,7 @@ import math
import tensorflow as tf
gaussian_conjugate_posteriors = tf.contrib.distributions.gaussian_conjugate_posteriors # pylint: disable=line-too-long
distributions = tf.contrib.distributions
class GaussianTest(tf.test.TestCase):
@@ -35,12 +35,12 @@ class GaussianTest(tf.test.TestCase):
x = tf.constant([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0])
s = tf.reduce_sum(x)
n = tf.size(x)
prior = tf.contrib.distributions.Gaussian(mu=mu0, sigma=sigma0)
posterior = gaussian_conjugate_posteriors.known_sigma_posterior(
prior = distributions.Gaussian(mu=mu0, sigma=sigma0)
posterior = distributions.gaussian_conjugates_known_sigma_posterior(
prior=prior, sigma=sigma, s=s, n=n)
# Smoke test
self.assertTrue(isinstance(posterior, tf.contrib.distributions.Gaussian))
self.assertTrue(isinstance(posterior, distributions.Gaussian))
posterior_log_pdf = posterior.log_pdf(x).eval()
self.assertEqual(posterior_log_pdf.shape, (6,))
@@ -54,12 +54,12 @@ class GaussianTest(tf.test.TestCase):
tf.constant([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=tf.float32))
s = tf.reduce_sum(x)
n = tf.size(x)
prior = tf.contrib.distributions.Gaussian(mu=mu0, sigma=sigma0)
posterior = gaussian_conjugate_posteriors.known_sigma_posterior(
prior = distributions.Gaussian(mu=mu0, sigma=sigma0)
posterior = distributions.gaussian_conjugates_known_sigma_posterior(
prior=prior, sigma=sigma, s=s, n=n)
# Smoke test
self.assertTrue(isinstance(posterior, tf.contrib.distributions.Gaussian))
self.assertTrue(isinstance(posterior, distributions.Gaussian))
posterior_log_pdf = posterior.log_pdf(x).eval()
self.assertEqual(posterior_log_pdf.shape, (6, 2))
@@ -75,12 +75,12 @@ class GaussianTest(tf.test.TestCase):
s = tf.reduce_sum(x, reduction_indices=[1])
x = tf.transpose(x) # Reshape to shape (6, 2)
n = tf.constant([6] * 2)
prior = tf.contrib.distributions.Gaussian(mu=mu0, sigma=sigma0)
posterior = gaussian_conjugate_posteriors.known_sigma_posterior(
prior = distributions.Gaussian(mu=mu0, sigma=sigma0)
posterior = distributions.gaussian_conjugates_known_sigma_posterior(
prior=prior, sigma=sigma, s=s, n=n)
# Smoke test
self.assertTrue(isinstance(posterior, tf.contrib.distributions.Gaussian))
self.assertTrue(isinstance(posterior, distributions.Gaussian))
# Calculate log_pdf under the 2 models
posterior_log_pdf = posterior.log_pdf(x)
@@ -96,14 +96,15 @@ class GaussianTest(tf.test.TestCase):
x = tf.constant([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0])
s = tf.reduce_sum(x)
n = tf.size(x)
prior = tf.contrib.distributions.Gaussian(mu=mu0, sigma=sigma0)
predictive = gaussian_conjugate_posteriors.known_sigma_predictive(
prior = distributions.Gaussian(mu=mu0, sigma=sigma0)
predictive = distributions.gaussian_congugates_known_sigma_predictive(
prior=prior, sigma=sigma, s=s, n=n)
# Smoke test
self.assertTrue(isinstance(predictive, tf.contrib.distributions.Gaussian))
self.assertTrue(isinstance(predictive, distributions.Gaussian))
predictive_log_pdf = predictive.log_pdf(x).eval()
self.assertEqual(predictive_log_pdf.shape, (6,))
if __name__ == '__main__':
tf.test.main()

View File

@@ -12,10 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Dirichlet Multinomial distribution class.
@@DirichletMultinomial
"""
"""The Dirichlet Multinomial distribution class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -61,6 +58,11 @@ def _log_combinations(counts, name='log_combinations'):
class DirichletMultinomial(object):
"""DirichletMultinomial mixture distribution.
This distribution is parameterized by a vector `alpha` of concentration
parameters for `k` classes.
#### Mathematical details
The Dirichlet Multinomial is a distribution over k-class count data, meaning
for each k-tuple of non-negative integer `counts = [c_1,...,c_k]`, we have a
probability of these draws being made from the distribution. The distribution
@@ -85,7 +87,7 @@ class DirichletMultinomial(object):
same shape (if possible). In all cases, the last dimension of alpha/counts
represents single Dirichlet Multinomial distributions.
Examples:
#### Examples
```python
alpha = [1, 2, 3]
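# NOTE: the example above is cut off by the diff. The lines below are a
# hedged sketch of how it might continue; they assume the constructor takes
# `alpha` and that a `pmf(counts)` method exists, as the class description
# implies -- treat both as assumptions rather than the commit's own code.
dist = DirichletMultinomial(alpha)  # one distribution over k=3 classes

# `counts` has the same shape as `alpha`; the result is the probability of
# observing exactly these per-class counts.
counts = [0., 0., 2.]
dist.pmf(counts)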

View File

@@ -12,10 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Normal (Gaussian) distribution class.
@@Gaussian
"""
"""The Normal (Gaussian) distribution class."""
from __future__ import absolute_import
from __future__ import division
@@ -44,10 +41,47 @@ def _assert_all_positive(x):
class Gaussian(object):
"""The scalar Gaussian distribution with mean and stddev parameters mu, sigma.
#### Mathematical details
The PDF of this distribution is:
```f(x) = sqrt(1/(2*pi*sigma^2)) exp(-(x-mu)^2/(2*sigma^2))```
#### Examples
Examples of initialization of one or a batch of distributions.
```python
# Define a single scalar Gaussian distribution.
dist = tf.contrib.distributions.Gaussian(mu=0, sigma=3)
# Evaluate the cdf at 1, returning a scalar.
dist.cdf(1)
# Define a batch of two scalar valued Gaussians.
# The first has mean 1 and standard deviation 11, the second 2 and 22.
dist = tf.contrib.distributions.Gaussian(mu=[1, 2.], sigma=[11, 22.])
# Evaluate the pdf of the first distribution on 0, and the second on 1.5,
# returning a length two tensor.
dist.pdf([0, 1.5])
# Get 3 samples, returning a 3 x 2 tensor.
dist.sample(3)
```
Arguments are broadcast when possible.
```python
# Define a batch of two scalar valued Gaussians.
# Both have mean 1, but different standard deviations.
dist = tf.contrib.distributions.Gaussian(mu=1, sigma=[11, 22.])
# Evaluate the pdf of both distributions on the same point, 3.0,
# returning a length 2 tensor.
dist.pdf(3.0)
```
"""
def __init__(self, mu, sigma, name=None):

View File

@@ -12,11 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Gaussian distribution: conjugate posterior closed form calculations.
@@known_sigma_posterior
@@known_sigma_predictive
"""
"""The Gaussian distribution: conjugate posterior closed form calculations."""
from __future__ import absolute_import
from __future__ import division
@@ -27,7 +23,7 @@ from tensorflow.contrib.distributions.python.ops.gaussian import Gaussian # pyl
from tensorflow.python.ops import math_ops
def known_sigma_posterior(prior, sigma, s, n):
def gaussian_conjugates_known_sigma_posterior(prior, sigma, s, n):
"""Posterior Gaussian distribution with conjugate prior on the mean.
This model assumes that `n` observations (with sum `s`) come from a
@@ -86,7 +82,7 @@ def known_sigma_posterior(prior, sigma, s, n):
sigma=math_ops.sqrt(sigmap_2))
def known_sigma_predictive(prior, sigma, s, n):
def gaussian_congugates_known_sigma_predictive(prior, sigma, s, n):
"""Posterior predictive Gaussian distribution w. conjugate prior on the mean.
This model assumes that `n` observations (with sum `s`) come from a
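The docstrings of both functions are cut off by the diff, but they describe the standard conjugate-normal updates for a Gaussian mean when the likelihood stddev `sigma` is known. As a rough reference, here is a plain-Python sketch of those standard formulas; the helper name is hypothetical and this is not the commit's own code.

```python
import math

def known_sigma_posterior_params(mu0, sigma0, sigma, s, n):
  """Conjugate update for a Normal mean when the likelihood stddev is known.

  mu0, sigma0: prior mean and stddev; sigma: known likelihood stddev;
  s, n: sum and count of the observations. Returns (post. mean, post. stddev).
  """
  precision = 1.0 / sigma0 ** 2 + n / sigma ** 2
  mu_post = (mu0 / sigma0 ** 2 + s / sigma ** 2) / precision
  return mu_post, math.sqrt(1.0 / precision)

# The posterior predictive for a single new observation keeps the same mean
# and adds the known observation variance: sigma_pred^2 = sigma_post^2 + sigma^2.
```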

View File

@@ -12,10 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Multivariate Normal distribution class.
@@MultivariateNormal
"""
"""The Multivariate Normal distribution class."""
from __future__ import absolute_import
from __future__ import division
@@ -101,6 +98,8 @@ class MultivariateNormal(object):
or alternatively mean `mu` and factored covariance (cholesky decomposed
`sigma`) called `sigma_chol`.
#### Mathematical details
The PDF of this distribution is:
```
@@ -123,6 +122,34 @@ class MultivariateNormal(object):
```
where `tri_solve()` solves a triangular system of equations.
#### Examples
A single multi-variate Gaussian distribution is defined by a vector of means
of length `k`, and a covariance matrix of shape `k x k`.
Extra leading dimensions, if provided, allow for batches.
```python
# Initialize a single 3-variate Gaussian with diagonal covariance.
mu = [1, 2, 3]
sigma = [[1, 0, 0], [0, 3, 0], [0, 0, 2]]
dist = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
# Evaluate this on an observation in R^3, returning a scalar.
dist.pdf([-1, 0, 1])
# Initialize a batch of two 3-variate Gaussians.
mu = [[1, 2, 3], [11, 22, 33]]
sigma = ... # shape 2 x 3 x 3
dist = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
# Evaluate this on two observations, each in R^3, returning a length two
# tensor.
x = [[-1, 0, 1], [-11, 0, 11]] # Shape 2 x 3.
dist.pdf(x)
```
"""
def __init__(self, mu, sigma=None, sigma_chol=None, name=None):

View File

@@ -50,6 +50,7 @@ def get_module_to_name():
tf.train: "tf.train",
tf.python_io: "tf.python_io",
tf.test: "tf.test",
tf.contrib.distributions: "tf.contrib.distributions",
tf.contrib.layers: "tf.contrib.layers",
tf.contrib.learn: "tf.contrib.learn",
tf.contrib.util: "tf.contrib.util",
@@ -125,6 +126,8 @@ def all_libraries(module_to_name, members, documented):
"RankingExample", "SequenceExample"]),
library("script_ops", "Wraps python functions", prefix=PREFIX_TEXT),
library("test", "Testing", tf.test),
library("contrib.distributions", "Statistical distributions (contrib)",
tf.contrib.distributions),
library("contrib.layers", "Layers (contrib)", tf.contrib.layers),
library("contrib.learn", "Learn (contrib)", tf.contrib.learn),
library("contrib.util", "Utilities (contrib)", tf.contrib.util),