From 506671c7b33f0705aa4fbe5527cb1d3264a8b9b6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 9 May 2016 16:54:05 -0800 Subject: [PATCH] Registering documentation for contrib.distributions. Change: 121899093 --- tensorflow/contrib/distributions/__init__.py | 27 ++++++++++-- .../gaussian_conjugate_posteriors_test.py | 27 ++++++------ .../python/ops/dirichlet_multinomial.py | 12 +++--- .../distributions/python/ops/gaussian.py | 42 +++++++++++++++++-- .../ops/gaussian_conjugate_posteriors.py | 10 ++--- .../contrib/distributions/python/ops/mvn.py | 35 ++++++++++++++-- .../python/framework/gen_docs_combined.py | 3 ++ 7 files changed, 120 insertions(+), 36 deletions(-) diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index 54607a7379e..f3263ff7858 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -12,17 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Ops for representing statistical distributions. +"""Classes representing statistical distributions. Ops for working with them. -## This package provides classes for statistical distributions. +## Classes for statistical distributions. +Classes that represent batches of statistical distributions. Each class is +initialized with parameters that define the distributions. + +### Univariate (scalar) distributions + +@@Gaussian + +### Multivariate distributions + +@@MultivariateNormal +@@DirichletMultinomial + +## Posterior inference with conjugate priors. + +Functions that transform conjugate prior/likelihood pairs to distributions +representing the posterior or posterior predictive. + +### Gaussian likelihood with conjugate prior. + +@@gaussian_conjugates_known_sigma_posterior +@@gaussian_congugates_known_sigma_predictive """ from __future__ import absolute_import from __future__ import division from __future__ import print_function # pylint: disable=unused-import,wildcard-import,line-too-long -from tensorflow.contrib.distributions.python.ops import gaussian_conjugate_posteriors from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import * from tensorflow.contrib.distributions.python.ops.gaussian import * +from tensorflow.contrib.distributions.python.ops.gaussian_conjugate_posteriors import * from tensorflow.contrib.distributions.python.ops.mvn import * diff --git a/tensorflow/contrib/distributions/python/kernel_tests/gaussian_conjugate_posteriors_test.py b/tensorflow/contrib/distributions/python/kernel_tests/gaussian_conjugate_posteriors_test.py index ef15f8316bf..c3a2464b5bd 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/gaussian_conjugate_posteriors_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/gaussian_conjugate_posteriors_test.py @@ -22,7 +22,7 @@ import math import tensorflow as tf -gaussian_conjugate_posteriors = tf.contrib.distributions.gaussian_conjugate_posteriors # pylint: disable=line-too-long +distributions = tf.contrib.distributions class GaussianTest(tf.test.TestCase): @@ -35,12 +35,12 @@ class GaussianTest(tf.test.TestCase): x = tf.constant([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]) s = tf.reduce_sum(x) n = tf.size(x) - prior = tf.contrib.distributions.Gaussian(mu=mu0, sigma=sigma0) - posterior = gaussian_conjugate_posteriors.known_sigma_posterior( + prior = distributions.Gaussian(mu=mu0, sigma=sigma0) + posterior = distributions.gaussian_conjugates_known_sigma_posterior( prior=prior, sigma=sigma, s=s, n=n) # Smoke test - self.assertTrue(isinstance(posterior, tf.contrib.distributions.Gaussian)) + self.assertTrue(isinstance(posterior, distributions.Gaussian)) posterior_log_pdf = posterior.log_pdf(x).eval() self.assertEqual(posterior_log_pdf.shape, (6,)) @@ -54,12 +54,12 @@ class GaussianTest(tf.test.TestCase): tf.constant([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=tf.float32)) s = tf.reduce_sum(x) n = tf.size(x) - prior = tf.contrib.distributions.Gaussian(mu=mu0, sigma=sigma0) - posterior = gaussian_conjugate_posteriors.known_sigma_posterior( + prior = distributions.Gaussian(mu=mu0, sigma=sigma0) + posterior = distributions.gaussian_conjugates_known_sigma_posterior( prior=prior, sigma=sigma, s=s, n=n) # Smoke test - self.assertTrue(isinstance(posterior, tf.contrib.distributions.Gaussian)) + self.assertTrue(isinstance(posterior, distributions.Gaussian)) posterior_log_pdf = posterior.log_pdf(x).eval() self.assertEqual(posterior_log_pdf.shape, (6, 2)) @@ -75,12 +75,12 @@ class GaussianTest(tf.test.TestCase): s = tf.reduce_sum(x, reduction_indices=[1]) x = tf.transpose(x) # Reshape to shape (6, 2) n = tf.constant([6] * 2) - prior = tf.contrib.distributions.Gaussian(mu=mu0, sigma=sigma0) - posterior = gaussian_conjugate_posteriors.known_sigma_posterior( + prior = distributions.Gaussian(mu=mu0, sigma=sigma0) + posterior = distributions.gaussian_conjugates_known_sigma_posterior( prior=prior, sigma=sigma, s=s, n=n) # Smoke test - self.assertTrue(isinstance(posterior, tf.contrib.distributions.Gaussian)) + self.assertTrue(isinstance(posterior, distributions.Gaussian)) # Calculate log_pdf under the 2 models posterior_log_pdf = posterior.log_pdf(x) @@ -96,14 +96,15 @@ class GaussianTest(tf.test.TestCase): x = tf.constant([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]) s = tf.reduce_sum(x) n = tf.size(x) - prior = tf.contrib.distributions.Gaussian(mu=mu0, sigma=sigma0) - predictive = gaussian_conjugate_posteriors.known_sigma_predictive( + prior = distributions.Gaussian(mu=mu0, sigma=sigma0) + predictive = distributions.gaussian_congugates_known_sigma_predictive( prior=prior, sigma=sigma, s=s, n=n) # Smoke test - self.assertTrue(isinstance(predictive, tf.contrib.distributions.Gaussian)) + self.assertTrue(isinstance(predictive, distributions.Gaussian)) predictive_log_pdf = predictive.log_pdf(x).eval() self.assertEqual(predictive_log_pdf.shape, (6,)) + if __name__ == '__main__': tf.test.main() diff --git a/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py b/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py index 358af118255..29436ef6e38 100644 --- a/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py +++ b/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""The Dirichlet Multinomial distribution class. - -@@DirichletMultinomial -""" +"""The Dirichlet Multinomial distribution class.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -61,6 +58,11 @@ def _log_combinations(counts, name='log_combinations'): class DirichletMultinomial(object): """DirichletMultinomial mixture distribution. + This distribution is parameterized by a vector `alpha` of concentration + parameters for `k` classes. + + #### Mathematical details + The Dirichlet Multinomial is a distribution over k-class count data, meaning for each k-tuple of non-negative integer `counts = [c_1,...,c_k]`, we have a probability of these draws being made from the distribution. The distribution @@ -85,7 +87,7 @@ class DirichletMultinomial(object): same shape (if possible). In all cases, the last dimension of alpha/counts represents single Dirichlet Multinomial distributions. - Examples: + #### Examples ```python alpha = [1, 2, 3] diff --git a/tensorflow/contrib/distributions/python/ops/gaussian.py b/tensorflow/contrib/distributions/python/ops/gaussian.py index cbb98624d97..8c7cd80d0fe 100644 --- a/tensorflow/contrib/distributions/python/ops/gaussian.py +++ b/tensorflow/contrib/distributions/python/ops/gaussian.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""The Normal (Gaussian) distribution class. - -@@Gaussian -""" +"""The Normal (Gaussian) distribution class.""" from __future__ import absolute_import from __future__ import division @@ -44,10 +41,47 @@ def _assert_all_positive(x): class Gaussian(object): """The scalar Gaussian distribution with mean and stddev parameters mu, sigma. + #### Mathematical details + The PDF of this distribution is: ```f(x) = sqrt(1/(2*pi*sigma^2)) exp(-(x-mu)^2/(2*sigma^2))``` + #### Examples + + Examples of initialization of one or a batch of distributions. + + ```python + # Define a single scalar Gaussian distribution. + dist = tf.contrib.Gaussian(mu=0, sigma=3) + + # Evaluate the cdf at 1, returning a scalar. + dist.cdf(1) + + # Define a batch of two scalar valued Gaussians. + # The first has mean 1 and standard deviation 11, the second 2 and 22. + dist = tf.contrib.distributions.Gaussian(mu=[1, 2.], sigma=[11, 22.]) + + # Evaluate the pdf of the first distribution on 0, and the second on 1.5, + # returning a length two tensor. + dist.pdf([0, 1.5]) + + # Get 3 samples, returning a 3 x 2 tensor. + dist.sample(3) + ``` + + Arguments are broadcast when possible. + + ```python + # Define a batch of two scalar valued Gaussians. + # Both have mean 1, but different standard deviations. + dist = tf.contrib.distributions.Gaussian(mu=1, sigma=[11, 22.]) + + # Evaluate the pdf of both distributions on the same point, 3.0, + # returning a length 2 tensor. + dist.pdf(3.0) + ``` + """ def __init__(self, mu, sigma, name=None): diff --git a/tensorflow/contrib/distributions/python/ops/gaussian_conjugate_posteriors.py b/tensorflow/contrib/distributions/python/ops/gaussian_conjugate_posteriors.py index f5536d8534a..c0089964152 100644 --- a/tensorflow/contrib/distributions/python/ops/gaussian_conjugate_posteriors.py +++ b/tensorflow/contrib/distributions/python/ops/gaussian_conjugate_posteriors.py @@ -12,11 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""The Gaussian distribution: conjugate posterior closed form calculations. - -@@known_sigma_posterior -@@known_sigma_predictive -""" +"""The Gaussian distribution: conjugate posterior closed form calculations.""" from __future__ import absolute_import from __future__ import division @@ -27,7 +23,7 @@ from tensorflow.contrib.distributions.python.ops.gaussian import Gaussian # pyl from tensorflow.python.ops import math_ops -def known_sigma_posterior(prior, sigma, s, n): +def gaussian_conjugates_known_sigma_posterior(prior, sigma, s, n): """Posterior Gaussian distribution with conjugate prior on the mean. This model assumes that `n` observations (with sum `s`) come from a @@ -86,7 +82,7 @@ def known_sigma_posterior(prior, sigma, s, n): sigma=math_ops.sqrt(sigmap_2)) -def known_sigma_predictive(prior, sigma, s, n): +def gaussian_congugates_known_sigma_predictive(prior, sigma, s, n): """Posterior predictive Gaussian distribution w. conjugate prior on the mean. This model assumes that `n` observations (with sum `s`) come from a diff --git a/tensorflow/contrib/distributions/python/ops/mvn.py b/tensorflow/contrib/distributions/python/ops/mvn.py index 4ddd577d46b..0faa36df149 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn.py +++ b/tensorflow/contrib/distributions/python/ops/mvn.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""The Multivariate Normal distribution class. - -@@MultivariateNormal -""" +"""The Multivariate Normal distribution class.""" from __future__ import absolute_import from __future__ import division @@ -101,6 +98,8 @@ class MultivariateNormal(object): or alternatively mean `mu` and factored covariance (cholesky decomposed `sigma`) called `sigma_chol`. + #### Mathematical details + The PDF of this distribution is: ``` @@ -123,6 +122,34 @@ class MultivariateNormal(object): ``` where `tri_solve()` solves a triangular system of equations. + + #### Examples + + A single multi-variate Gaussian distribution is defined by a vector of means + of length `k`, and a covariance matrix of shape `k x k`. + + Extra leading dimensions, if provided, allow for batches. + + ```python + # Initialize a single 3-variate Gaussian with diagonal covariance. + mu = [1, 2, 3] + sigma = [[1, 0, 0], [0, 3, 0], [0, 0, 2]] + dist = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma) + + # Evaluate this on an observation in R^3, returning a scalar. + dist.pdf([-1, 0, 1]) + + # Initialize a batch of two 3-variate Gaussians. + mu = [[1, 2, 3], [11, 22, 33]] + sigma = ... # shape 2 x 3 x 3 + dist = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma) + + # Evaluate this on a two observations, each in R^3, returning a length two + # tensor. + x = [[-1, 0, 1], [-11, 0, 11]] # Shape 2 x 3. + dist.pdf(x) + ``` + """ def __init__(self, mu, sigma=None, sigma_chol=None, name=None): diff --git a/tensorflow/python/framework/gen_docs_combined.py b/tensorflow/python/framework/gen_docs_combined.py index 2dbc5e656f0..fc7b616fb2d 100644 --- a/tensorflow/python/framework/gen_docs_combined.py +++ b/tensorflow/python/framework/gen_docs_combined.py @@ -50,6 +50,7 @@ def get_module_to_name(): tf.train: "tf.train", tf.python_io: "tf.python_io", tf.test: "tf.test", + tf.contrib.distributions: "tf.contrib.distributions", tf.contrib.layers: "tf.contrib.layers", tf.contrib.learn: "tf.contrib.learn", tf.contrib.util: "tf.contrib.util", @@ -125,6 +126,8 @@ def all_libraries(module_to_name, members, documented): "RankingExample", "SequenceExample"]), library("script_ops", "Wraps python functions", prefix=PREFIX_TEXT), library("test", "Testing", tf.test), + library("contrib.distributions", "Statistical distributions (contrib)", + tf.contrib.distributions), library("contrib.layers", "Layers (contrib)", tf.contrib.layers), library("contrib.learn", "Learn (contrib)", tf.contrib.learn), library("contrib.util", "Utilities (contrib)", tf.contrib.util),