Registering documentation for contrib.distributions.
Change: 121899093
This commit is contained in:
parent
6d2623d03c
commit
506671c7b3
tensorflow
@ -12,17 +12,38 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Ops for representing statistical distributions.
|
"""Classes representing statistical distributions. Ops for working with them.
|
||||||
|
|
||||||
## This package provides classes for statistical distributions.
|
## Classes for statistical distributions.
|
||||||
|
|
||||||
|
Classes that represent batches of statistical distributions. Each class is
|
||||||
|
initialized with parameters that define the distributions.
|
||||||
|
|
||||||
|
### Univariate (scalar) distributions
|
||||||
|
|
||||||
|
@@Gaussian
|
||||||
|
|
||||||
|
### Multivariate distributions
|
||||||
|
|
||||||
|
@@MultivariateNormal
|
||||||
|
@@DirichletMultinomial
|
||||||
|
|
||||||
|
## Posterior inference with conjugate priors.
|
||||||
|
|
||||||
|
Functions that transform conjugate prior/likelihood pairs to distributions
|
||||||
|
representing the posterior or posterior predictive.
|
||||||
|
|
||||||
|
### Gaussian likelihood with conjugate prior.
|
||||||
|
|
||||||
|
@@gaussian_conjugates_known_sigma_posterior
|
||||||
|
@@gaussian_congugates_known_sigma_predictive
|
||||||
"""
|
"""
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
# pylint: disable=unused-import,wildcard-import,line-too-long
|
# pylint: disable=unused-import,wildcard-import,line-too-long
|
||||||
from tensorflow.contrib.distributions.python.ops import gaussian_conjugate_posteriors
|
|
||||||
from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import *
|
from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import *
|
||||||
from tensorflow.contrib.distributions.python.ops.gaussian import *
|
from tensorflow.contrib.distributions.python.ops.gaussian import *
|
||||||
|
from tensorflow.contrib.distributions.python.ops.gaussian_conjugate_posteriors import *
|
||||||
from tensorflow.contrib.distributions.python.ops.mvn import *
|
from tensorflow.contrib.distributions.python.ops.mvn import *
|
||||||
|
@ -22,7 +22,7 @@ import math
|
|||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
gaussian_conjugate_posteriors = tf.contrib.distributions.gaussian_conjugate_posteriors # pylint: disable=line-too-long
|
distributions = tf.contrib.distributions
|
||||||
|
|
||||||
|
|
||||||
class GaussianTest(tf.test.TestCase):
|
class GaussianTest(tf.test.TestCase):
|
||||||
@ -35,12 +35,12 @@ class GaussianTest(tf.test.TestCase):
|
|||||||
x = tf.constant([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0])
|
x = tf.constant([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0])
|
||||||
s = tf.reduce_sum(x)
|
s = tf.reduce_sum(x)
|
||||||
n = tf.size(x)
|
n = tf.size(x)
|
||||||
prior = tf.contrib.distributions.Gaussian(mu=mu0, sigma=sigma0)
|
prior = distributions.Gaussian(mu=mu0, sigma=sigma0)
|
||||||
posterior = gaussian_conjugate_posteriors.known_sigma_posterior(
|
posterior = distributions.gaussian_conjugates_known_sigma_posterior(
|
||||||
prior=prior, sigma=sigma, s=s, n=n)
|
prior=prior, sigma=sigma, s=s, n=n)
|
||||||
|
|
||||||
# Smoke test
|
# Smoke test
|
||||||
self.assertTrue(isinstance(posterior, tf.contrib.distributions.Gaussian))
|
self.assertTrue(isinstance(posterior, distributions.Gaussian))
|
||||||
posterior_log_pdf = posterior.log_pdf(x).eval()
|
posterior_log_pdf = posterior.log_pdf(x).eval()
|
||||||
self.assertEqual(posterior_log_pdf.shape, (6,))
|
self.assertEqual(posterior_log_pdf.shape, (6,))
|
||||||
|
|
||||||
@ -54,12 +54,12 @@ class GaussianTest(tf.test.TestCase):
|
|||||||
tf.constant([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=tf.float32))
|
tf.constant([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=tf.float32))
|
||||||
s = tf.reduce_sum(x)
|
s = tf.reduce_sum(x)
|
||||||
n = tf.size(x)
|
n = tf.size(x)
|
||||||
prior = tf.contrib.distributions.Gaussian(mu=mu0, sigma=sigma0)
|
prior = distributions.Gaussian(mu=mu0, sigma=sigma0)
|
||||||
posterior = gaussian_conjugate_posteriors.known_sigma_posterior(
|
posterior = distributions.gaussian_conjugates_known_sigma_posterior(
|
||||||
prior=prior, sigma=sigma, s=s, n=n)
|
prior=prior, sigma=sigma, s=s, n=n)
|
||||||
|
|
||||||
# Smoke test
|
# Smoke test
|
||||||
self.assertTrue(isinstance(posterior, tf.contrib.distributions.Gaussian))
|
self.assertTrue(isinstance(posterior, distributions.Gaussian))
|
||||||
posterior_log_pdf = posterior.log_pdf(x).eval()
|
posterior_log_pdf = posterior.log_pdf(x).eval()
|
||||||
self.assertEqual(posterior_log_pdf.shape, (6, 2))
|
self.assertEqual(posterior_log_pdf.shape, (6, 2))
|
||||||
|
|
||||||
@ -75,12 +75,12 @@ class GaussianTest(tf.test.TestCase):
|
|||||||
s = tf.reduce_sum(x, reduction_indices=[1])
|
s = tf.reduce_sum(x, reduction_indices=[1])
|
||||||
x = tf.transpose(x) # Reshape to shape (6, 2)
|
x = tf.transpose(x) # Reshape to shape (6, 2)
|
||||||
n = tf.constant([6] * 2)
|
n = tf.constant([6] * 2)
|
||||||
prior = tf.contrib.distributions.Gaussian(mu=mu0, sigma=sigma0)
|
prior = distributions.Gaussian(mu=mu0, sigma=sigma0)
|
||||||
posterior = gaussian_conjugate_posteriors.known_sigma_posterior(
|
posterior = distributions.gaussian_conjugates_known_sigma_posterior(
|
||||||
prior=prior, sigma=sigma, s=s, n=n)
|
prior=prior, sigma=sigma, s=s, n=n)
|
||||||
|
|
||||||
# Smoke test
|
# Smoke test
|
||||||
self.assertTrue(isinstance(posterior, tf.contrib.distributions.Gaussian))
|
self.assertTrue(isinstance(posterior, distributions.Gaussian))
|
||||||
|
|
||||||
# Calculate log_pdf under the 2 models
|
# Calculate log_pdf under the 2 models
|
||||||
posterior_log_pdf = posterior.log_pdf(x)
|
posterior_log_pdf = posterior.log_pdf(x)
|
||||||
@ -96,14 +96,15 @@ class GaussianTest(tf.test.TestCase):
|
|||||||
x = tf.constant([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0])
|
x = tf.constant([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0])
|
||||||
s = tf.reduce_sum(x)
|
s = tf.reduce_sum(x)
|
||||||
n = tf.size(x)
|
n = tf.size(x)
|
||||||
prior = tf.contrib.distributions.Gaussian(mu=mu0, sigma=sigma0)
|
prior = distributions.Gaussian(mu=mu0, sigma=sigma0)
|
||||||
predictive = gaussian_conjugate_posteriors.known_sigma_predictive(
|
predictive = distributions.gaussian_congugates_known_sigma_predictive(
|
||||||
prior=prior, sigma=sigma, s=s, n=n)
|
prior=prior, sigma=sigma, s=s, n=n)
|
||||||
|
|
||||||
# Smoke test
|
# Smoke test
|
||||||
self.assertTrue(isinstance(predictive, tf.contrib.distributions.Gaussian))
|
self.assertTrue(isinstance(predictive, distributions.Gaussian))
|
||||||
predictive_log_pdf = predictive.log_pdf(x).eval()
|
predictive_log_pdf = predictive.log_pdf(x).eval()
|
||||||
self.assertEqual(predictive_log_pdf.shape, (6,))
|
self.assertEqual(predictive_log_pdf.shape, (6,))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
@ -12,10 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""The Dirichlet Multinomial distribution class.
|
"""The Dirichlet Multinomial distribution class."""
|
||||||
|
|
||||||
@@DirichletMultinomial
|
|
||||||
"""
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
@ -61,6 +58,11 @@ def _log_combinations(counts, name='log_combinations'):
|
|||||||
class DirichletMultinomial(object):
|
class DirichletMultinomial(object):
|
||||||
"""DirichletMultinomial mixture distribution.
|
"""DirichletMultinomial mixture distribution.
|
||||||
|
|
||||||
|
This distribution is parameterized by a vector `alpha` of concentration
|
||||||
|
parameters for `k` classes.
|
||||||
|
|
||||||
|
#### Mathematical details
|
||||||
|
|
||||||
The Dirichlet Multinomial is a distribution over k-class count data, meaning
|
The Dirichlet Multinomial is a distribution over k-class count data, meaning
|
||||||
for each k-tuple of non-negative integer `counts = [c_1,...,c_k]`, we have a
|
for each k-tuple of non-negative integer `counts = [c_1,...,c_k]`, we have a
|
||||||
probability of these draws being made from the distribution. The distribution
|
probability of these draws being made from the distribution. The distribution
|
||||||
@ -85,7 +87,7 @@ class DirichletMultinomial(object):
|
|||||||
same shape (if possible). In all cases, the last dimension of alpha/counts
|
same shape (if possible). In all cases, the last dimension of alpha/counts
|
||||||
represents single Dirichlet Multinomial distributions.
|
represents single Dirichlet Multinomial distributions.
|
||||||
|
|
||||||
Examples:
|
#### Examples
|
||||||
|
|
||||||
```python
|
```python
|
||||||
alpha = [1, 2, 3]
|
alpha = [1, 2, 3]
|
||||||
|
@ -12,10 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""The Normal (Gaussian) distribution class.
|
"""The Normal (Gaussian) distribution class."""
|
||||||
|
|
||||||
@@Gaussian
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
@ -44,10 +41,47 @@ def _assert_all_positive(x):
|
|||||||
class Gaussian(object):
|
class Gaussian(object):
|
||||||
"""The scalar Gaussian distribution with mean and stddev parameters mu, sigma.
|
"""The scalar Gaussian distribution with mean and stddev parameters mu, sigma.
|
||||||
|
|
||||||
|
#### Mathematical details
|
||||||
|
|
||||||
The PDF of this distribution is:
|
The PDF of this distribution is:
|
||||||
|
|
||||||
```f(x) = sqrt(1/(2*pi*sigma^2)) exp(-(x-mu)^2/(2*sigma^2))```
|
```f(x) = sqrt(1/(2*pi*sigma^2)) exp(-(x-mu)^2/(2*sigma^2))```
|
||||||
|
|
||||||
|
#### Examples
|
||||||
|
|
||||||
|
Examples of initialization of one or a batch of distributions.
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Define a single scalar Gaussian distribution.
|
||||||
|
dist = tf.contrib.Gaussian(mu=0, sigma=3)
|
||||||
|
|
||||||
|
# Evaluate the cdf at 1, returning a scalar.
|
||||||
|
dist.cdf(1)
|
||||||
|
|
||||||
|
# Define a batch of two scalar valued Gaussians.
|
||||||
|
# The first has mean 1 and standard deviation 11, the second 2 and 22.
|
||||||
|
dist = tf.contrib.distributions.Gaussian(mu=[1, 2.], sigma=[11, 22.])
|
||||||
|
|
||||||
|
# Evaluate the pdf of the first distribution on 0, and the second on 1.5,
|
||||||
|
# returning a length two tensor.
|
||||||
|
dist.pdf([0, 1.5])
|
||||||
|
|
||||||
|
# Get 3 samples, returning a 3 x 2 tensor.
|
||||||
|
dist.sample(3)
|
||||||
|
```
|
||||||
|
|
||||||
|
Arguments are broadcast when possible.
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Define a batch of two scalar valued Gaussians.
|
||||||
|
# Both have mean 1, but different standard deviations.
|
||||||
|
dist = tf.contrib.distributions.Gaussian(mu=1, sigma=[11, 22.])
|
||||||
|
|
||||||
|
# Evaluate the pdf of both distributions on the same point, 3.0,
|
||||||
|
# returning a length 2 tensor.
|
||||||
|
dist.pdf(3.0)
|
||||||
|
```
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, mu, sigma, name=None):
|
def __init__(self, mu, sigma, name=None):
|
||||||
|
@ -12,11 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""The Gaussian distribution: conjugate posterior closed form calculations.
|
"""The Gaussian distribution: conjugate posterior closed form calculations."""
|
||||||
|
|
||||||
@@known_sigma_posterior
|
|
||||||
@@known_sigma_predictive
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
@ -27,7 +23,7 @@ from tensorflow.contrib.distributions.python.ops.gaussian import Gaussian # pyl
|
|||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
|
|
||||||
|
|
||||||
def known_sigma_posterior(prior, sigma, s, n):
|
def gaussian_conjugates_known_sigma_posterior(prior, sigma, s, n):
|
||||||
"""Posterior Gaussian distribution with conjugate prior on the mean.
|
"""Posterior Gaussian distribution with conjugate prior on the mean.
|
||||||
|
|
||||||
This model assumes that `n` observations (with sum `s`) come from a
|
This model assumes that `n` observations (with sum `s`) come from a
|
||||||
@ -86,7 +82,7 @@ def known_sigma_posterior(prior, sigma, s, n):
|
|||||||
sigma=math_ops.sqrt(sigmap_2))
|
sigma=math_ops.sqrt(sigmap_2))
|
||||||
|
|
||||||
|
|
||||||
def known_sigma_predictive(prior, sigma, s, n):
|
def gaussian_congugates_known_sigma_predictive(prior, sigma, s, n):
|
||||||
"""Posterior predictive Gaussian distribution w. conjugate prior on the mean.
|
"""Posterior predictive Gaussian distribution w. conjugate prior on the mean.
|
||||||
|
|
||||||
This model assumes that `n` observations (with sum `s`) come from a
|
This model assumes that `n` observations (with sum `s`) come from a
|
||||||
|
@ -12,10 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""The Multivariate Normal distribution class.
|
"""The Multivariate Normal distribution class."""
|
||||||
|
|
||||||
@@MultivariateNormal
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
@ -101,6 +98,8 @@ class MultivariateNormal(object):
|
|||||||
or alternatively mean `mu` and factored covariance (cholesky decomposed
|
or alternatively mean `mu` and factored covariance (cholesky decomposed
|
||||||
`sigma`) called `sigma_chol`.
|
`sigma`) called `sigma_chol`.
|
||||||
|
|
||||||
|
#### Mathematical details
|
||||||
|
|
||||||
The PDF of this distribution is:
|
The PDF of this distribution is:
|
||||||
|
|
||||||
```
|
```
|
||||||
@ -123,6 +122,34 @@ class MultivariateNormal(object):
|
|||||||
```
|
```
|
||||||
|
|
||||||
where `tri_solve()` solves a triangular system of equations.
|
where `tri_solve()` solves a triangular system of equations.
|
||||||
|
|
||||||
|
#### Examples
|
||||||
|
|
||||||
|
A single multi-variate Gaussian distribution is defined by a vector of means
|
||||||
|
of length `k`, and a covariance matrix of shape `k x k`.
|
||||||
|
|
||||||
|
Extra leading dimensions, if provided, allow for batches.
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Initialize a single 3-variate Gaussian with diagonal covariance.
|
||||||
|
mu = [1, 2, 3]
|
||||||
|
sigma = [[1, 0, 0], [0, 3, 0], [0, 0, 2]]
|
||||||
|
dist = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
|
||||||
|
|
||||||
|
# Evaluate this on an observation in R^3, returning a scalar.
|
||||||
|
dist.pdf([-1, 0, 1])
|
||||||
|
|
||||||
|
# Initialize a batch of two 3-variate Gaussians.
|
||||||
|
mu = [[1, 2, 3], [11, 22, 33]]
|
||||||
|
sigma = ... # shape 2 x 3 x 3
|
||||||
|
dist = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
|
||||||
|
|
||||||
|
# Evaluate this on a two observations, each in R^3, returning a length two
|
||||||
|
# tensor.
|
||||||
|
x = [[-1, 0, 1], [-11, 0, 11]] # Shape 2 x 3.
|
||||||
|
dist.pdf(x)
|
||||||
|
```
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, mu, sigma=None, sigma_chol=None, name=None):
|
def __init__(self, mu, sigma=None, sigma_chol=None, name=None):
|
||||||
|
@ -50,6 +50,7 @@ def get_module_to_name():
|
|||||||
tf.train: "tf.train",
|
tf.train: "tf.train",
|
||||||
tf.python_io: "tf.python_io",
|
tf.python_io: "tf.python_io",
|
||||||
tf.test: "tf.test",
|
tf.test: "tf.test",
|
||||||
|
tf.contrib.distributions: "tf.contrib.distributions",
|
||||||
tf.contrib.layers: "tf.contrib.layers",
|
tf.contrib.layers: "tf.contrib.layers",
|
||||||
tf.contrib.learn: "tf.contrib.learn",
|
tf.contrib.learn: "tf.contrib.learn",
|
||||||
tf.contrib.util: "tf.contrib.util",
|
tf.contrib.util: "tf.contrib.util",
|
||||||
@ -125,6 +126,8 @@ def all_libraries(module_to_name, members, documented):
|
|||||||
"RankingExample", "SequenceExample"]),
|
"RankingExample", "SequenceExample"]),
|
||||||
library("script_ops", "Wraps python functions", prefix=PREFIX_TEXT),
|
library("script_ops", "Wraps python functions", prefix=PREFIX_TEXT),
|
||||||
library("test", "Testing", tf.test),
|
library("test", "Testing", tf.test),
|
||||||
|
library("contrib.distributions", "Statistical distributions (contrib)",
|
||||||
|
tf.contrib.distributions),
|
||||||
library("contrib.layers", "Layers (contrib)", tf.contrib.layers),
|
library("contrib.layers", "Layers (contrib)", tf.contrib.layers),
|
||||||
library("contrib.learn", "Learn (contrib)", tf.contrib.learn),
|
library("contrib.learn", "Learn (contrib)", tf.contrib.learn),
|
||||||
library("contrib.util", "Utilities (contrib)", tf.contrib.util),
|
library("contrib.util", "Utilities (contrib)", tf.contrib.util),
|
||||||
|
Loading…
Reference in New Issue
Block a user