diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index dda77f76396..fafc6e86289 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -13,6 +13,7 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/ctc:ctc_py", + "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/linear_optimizer:sdca_ops_py", "//tensorflow/contrib/testing:testing_py", diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 269d439e882..f9290eda1e2 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -21,6 +21,7 @@ from __future__ import print_function # Add projects here, they will show up under tf.contrib. from tensorflow.contrib import ctc +from tensorflow.contrib import distributions from tensorflow.contrib import layers from tensorflow.contrib import linear_optimizer from tensorflow.contrib import testing diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD new file mode 100644 index 00000000000..a5fde453cad --- /dev/null +++ b/tensorflow/contrib/distributions/BUILD @@ -0,0 +1,49 @@ +# Description: +# Contains ops to train linear models on top of TensorFlow. +# APIs here are meant to evolve over time. + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +package(default_visibility = ["//tensorflow:__subpackages__"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_tests") + +py_library( + name = "distributions_py", + srcs = ["__init__.py"] + glob(["python/ops/*.py"]), + srcs_version = "PY2AND3", +) + +cuda_py_tests( + name = "gaussian_test", + srcs = ["python/kernel_tests/gaussian_test.py"], + additional_deps = [ + ":distributions_py", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + +cuda_py_tests( + name = "gaussian_conjugate_posteriors_test", + srcs = ["python/kernel_tests/gaussian_conjugate_posteriors_test.py"], + additional_deps = [ + ":distributions_py", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py new file mode 100644 index 00000000000..46aae254a7a --- /dev/null +++ b/tensorflow/contrib/distributions/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Ops for representing statistical distributions. + +## This package provides classes for statistical distributions. + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,wildcard-import, line-too-long +from tensorflow.contrib.distributions.python.ops import gaussian_conjugate_posteriors +from tensorflow.contrib.distributions.python.ops.gaussian import * +# from tensorflow.contrib.distributions.python.ops.dirichlet import * # pylint: disable=line-too-long +# from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import * # pylint: disable=line-too-long diff --git a/tensorflow/contrib/distributions/python/__init__.py b/tensorflow/contrib/distributions/python/__init__.py new file mode 100644 index 00000000000..c9b177d43d5 --- /dev/null +++ b/tensorflow/contrib/distributions/python/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== diff --git a/tensorflow/contrib/distributions/python/kernel_tests/gaussian_conjugate_posteriors_test.py b/tensorflow/contrib/distributions/python/kernel_tests/gaussian_conjugate_posteriors_test.py new file mode 100644 index 00000000000..115f56f339c --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/gaussian_conjugate_posteriors_test.py @@ -0,0 +1,65 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for initializers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import tensorflow as tf + +gaussian_conjugate_posteriors = tf.contrib.distributions.gaussian_conjugate_posteriors # pylint: disable=line-too-long + + +class GaussianTest(tf.test.TestCase): + + def testGaussianConjugateKnownSigmaPosterior(self): + with tf.Session(): + mu0 = tf.constant(3.0) + sigma0 = tf.constant(math.sqrt(1/0.1)) + sigma = tf.constant(math.sqrt(1/0.5)) + x = tf.constant([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]) + s = tf.reduce_sum(x) + n = tf.size(x) + prior = tf.contrib.distributions.Gaussian(mu=mu0, sigma=sigma0) + posterior = gaussian_conjugate_posteriors.known_sigma_posterior( + prior=prior, sigma=sigma, s=s, n=n) + + # Smoke test + self.assertTrue(isinstance(posterior, tf.contrib.distributions.Gaussian)) + posterior_log_pdf = posterior.log_pdf(x).eval() + self.assertEqual(posterior_log_pdf.shape, (6,)) + + def testGaussianConjugateKnownSigmaPredictive(self): + with tf.Session(): + mu0 = tf.constant(3.0) + sigma0 = tf.constant(math.sqrt(1/0.1)) + sigma = tf.constant(math.sqrt(1/0.5)) + x = tf.constant([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]) + s = tf.reduce_sum(x) + n = tf.size(x) + prior = tf.contrib.distributions.Gaussian(mu=mu0, sigma=sigma0) + predictive = gaussian_conjugate_posteriors.known_sigma_predictive( + prior=prior, sigma=sigma, s=s, n=n) + + # Smoke test + self.assertTrue(isinstance(predictive, tf.contrib.distributions.Gaussian)) + predictive_log_pdf = predictive.log_pdf(x).eval() + self.assertEqual(predictive_log_pdf.shape, (6,)) + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/gaussian_test.py b/tensorflow/contrib/distributions/python/kernel_tests/gaussian_test.py new file mode 100644 index 00000000000..c20cb6dc4d3 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/gaussian_test.py @@ -0,0 +1,77 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for initializers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import numpy as np +import tensorflow as tf + + +class GaussianTest(tf.test.TestCase): + + def testGaussianLogLikelihoodPDF(self): + with tf.Session(): + mu = tf.constant(3.0) + sigma = tf.constant(math.sqrt(1/0.1)) + mu_v = 3.0 + sigma_v = np.sqrt(1/0.1) + x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]) + gaussian = tf.contrib.distributions.Gaussian(mu=mu, sigma=sigma) + expected_log_pdf = np.log( + 1/np.sqrt(2*np.pi)/sigma_v*np.exp(-1.0/(2*sigma_v**2)*(x-mu_v)**2)) + + log_pdf = gaussian.log_pdf(x) + self.assertAllClose(expected_log_pdf, log_pdf.eval()) + + pdf = gaussian.pdf(x) + self.assertAllClose(np.exp(expected_log_pdf), pdf.eval()) + + def testGaussianCDF(self): + with tf.Session(): + mu = tf.constant(3.0) + sigma = tf.constant(math.sqrt(1/0.1)) + mu_v = 3.0 + sigma_v = np.sqrt(1/0.1) + x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]) + gaussian = tf.contrib.distributions.Gaussian(mu=mu, sigma=sigma) + erf_fn = np.vectorize(math.erf) + + # From Wikipedia + expected_cdf = 0.5*(1.0 + erf_fn((x - mu_v)/(sigma_v*np.sqrt(2)))) + + cdf = gaussian.cdf(x) + self.assertAllClose(expected_cdf, cdf.eval()) + + def testGaussianSample(self): + with tf.Session(): + mu = tf.constant(3.0) + sigma = tf.constant(math.sqrt(1/0.1)) + mu_v = 3.0 + sigma_v = np.sqrt(1/0.1) + n = tf.constant(10000) + gaussian = tf.contrib.distributions.Gaussian(mu=mu, sigma=sigma) + samples = gaussian.sample(n, seed=137) + sample_values = samples.eval() + self.assertEqual(sample_values.shape, (10000,)) + self.assertAllClose(sample_values.mean(), mu_v, atol=1e-2) + self.assertAllClose(sample_values.std(), sigma_v, atol=1e-1) + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/distributions/python/ops/gaussian.py b/tensorflow/contrib/distributions/python/ops/gaussian.py new file mode 100644 index 00000000000..d54825c6949 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/gaussian.py @@ -0,0 +1,123 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The Normal (Gaussian) distribution class. + +@@Gaussian +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import constant_op +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops + + +class Gaussian(object): + """The Normal (Gaussian) distribution with mean mu and stddev sigma. + + The PDF of this distribution is: + f(x) = sqrt(1/(2*pi*sigma^2)) exp(-(x-mu)^2/(2*sigma^2)) + """ + + def __init__(self, mu, sigma): + """Construct a new Gaussian distribution with mean mu and stddev sigma. + + Args: + mu: Scalar tensor, the mean of the distribution. + sigma: Scalar tensor, the precision of the distribution. + + Raises: + TypeError: if mu and sigma are different dtypes. + """ + self._mu = ops.convert_to_tensor(mu) + self._sigma = ops.convert_to_tensor(sigma) + if mu.dtype != sigma.dtype: + raise TypeError("Expected same dtype for mu, sigma but got: %s vs. %s" + % (mu.dtype, sigma.dtype)) + + @property + def dtype(self): + return self._mu.dtype + + @property + def shape(self): + return constant_op.constant([]) # Scalar + + @property + def mu(self): + return self._mu + + @property + def sigma(self): + return self._sigma + + def log_pdf(self, x): + """Log likelihood of observations in x under Gaussian with mu and sigma. + + Args: + x: 1-D, a vector of observations. + + Returns: + log_lik: 1-D, a vector of log likelihoods of `x` under the model. + """ + return (-0.5*math.log(2 * math.pi) - math_ops.log(self._sigma) + -0.5*math_ops.square((x - self._mu) / self._sigma)) + + def cdf(self, x): + """CDF of observations in x under Gaussian with mu and sigma. + + Args: + x: 1-D, a vector of observations. + + Returns: + cdf: 1-D, a vector of CDFs of `x` under the model. + """ + return (0.5 + 0.5*math_ops.erf( + 1.0/(math.sqrt(2.0) * self._sigma)*(x - self._mu))) + + def log_cdf(self, x): + """Log of the CDF of observations x under Gaussian with mu and sigma.""" + return math_ops.log(self.cdf(x)) + + def pdf(self, x): + """The PDF for observations x. + + Args: + x: 1-D, a vector of observations. + + Returns: + pdf: 1-D, a vector of pdf values of `x` under the model. + """ + return math_ops.exp(self.log_pdf(x)) + + def sample(self, n, seed=None): + """Sample `n` observations from this Distribution. + + Args: + n: Scalar int `Tensor`, the number of observations to sample. + seed: Python integer, the random seed. + + Returns: + samples: A vector of samples with shape `[n]`. + """ + return random_ops.random_normal( + shape=array_ops.expand_dims(n, 0), mean=self._mu, + stddev=self._sigma, dtype=self._mu.dtype, seed=seed) diff --git a/tensorflow/contrib/distributions/python/ops/gaussian_conjugate_posteriors.py b/tensorflow/contrib/distributions/python/ops/gaussian_conjugate_posteriors.py new file mode 100644 index 00000000000..cd59a09d6fc --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/gaussian_conjugate_posteriors.py @@ -0,0 +1,126 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The Gaussian distribution: conjugate posterior closed form calculations. + +@@known_sigma_posterior +@@known_sigma_predictive +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distributions.python.ops.gaussian import Gaussian # pylint: disable=line-too-long + +from tensorflow.python.ops import math_ops + + +def known_sigma_posterior(prior, sigma, s, n): + """Return the conjugate posterior distribution with known sigma. + + Accepts a prior Gaussian distribution, having parameters `mu0` and `sigma0`, + a known `sigma` of the predictive distribution (also assumed Gaussian), + and statistical estimates `s` (the sum of the observations) and + `n` (the number of observations). + + Returns a posterior (also Gaussian) distribution object, with parameters + `(mu', sigma'^2)`, where: + ``` + sigma'^2 = 1/(1/sigma0^2 + n/sigma^2), + mu' = (mu0/sigma0^2 + s/sigma^2) * sigma'^2. + ``` + + Args: + prior: `Normal` object of type `dtype`, the prior distribution having + parameters `(mu0, sigma0)`. + sigma: Scalar of type `dtype`, `sigma > 0`. The known stddev parameter. + s: Scalar, of type `dtype`, the sum of observations. + n: Scalar int, the number of observations. + + Returns: + A new Gaussian posterior distribution. + + Raises: + TypeError: if dtype of `s` does not match `dtype`, or `prior` is not a + Gaussian object. + """ + if not isinstance(prior, Gaussian): + raise TypeError("Expected prior to be an instance of type Gaussian") + + if s.dtype != prior.dtype: + raise TypeError( + "Observation sum s.dtype does not match prior dtype: %s vs. %s" + % (s.dtype, prior.dtype)) + + n = math_ops.cast(n, prior.dtype) + sigma0_2 = math_ops.square(prior.sigma) + sigma_2 = math_ops.square(sigma) + sigmap_2 = 1.0/(1/sigma0_2 + n/sigma_2) + return Gaussian( + mu=(prior.mu/sigma0_2 + s/sigma_2) * sigmap_2, + sigma=math_ops.sqrt(sigmap_2)) + + +def known_sigma_predictive(prior, sigma, s, n): + """Return the posterior predictive distribution with known sigma. + + Accepts a prior Gaussian distribution, having parameters `mu0` and `sigma0`, + a known `sigma` of the predictive distribution (also assumed Gaussian), + and statistical estimates `s` (the sum of the observations) and + `n` (the number of observations). + + Calculates the Gaussian distribution p(x | sigma): + ``` + p(x | sigma) = int N(x | mu, sigma^2) N(mu | prior.mu, prior.sigma^2) dmu + = N(x | prior.mu, 1/(sigma^2 + prior.sigma^2)) + ``` + + Returns the predictive posterior distribution object, with parameters + `(mu', sigma'^2)`, where: + ``` + sigma_n^2 = 1/(1/sigma0^2 + n/sigma^2), + mu' = (mu0/sigma0^2 + s/sigma^2) * sigma_n^2. + sigma'^2 = sigma_n^2 + sigma^2, + ``` + + Args: + prior: `Normal` object of type `dtype`, the prior distribution having + parameters `(mu0, sigma0)`. + sigma: Scalar of type `dtype`, `sigma > 0`. The known stddev parameter. + s: Scalar, of type `dtype`, the sum of observations. + n: Scalar int, the number of observations. + + Returns: + A new Gaussian posterior distribution. + + Raises: + TypeError: if dtype of `s` does not match `dtype`, or `prior` is not a + Gaussian object. + """ + if not isinstance(prior, Gaussian): + raise TypeError("Expected prior to be an instance of type Gaussian") + + if s.dtype != prior.dtype: + raise TypeError( + "Observation sum s.dtype does not match prior dtype: %s vs. %s" + % (s.dtype, prior.dtype)) + + n = math_ops.cast(n, prior.dtype) + sigma0_2 = math_ops.square(prior.sigma) + sigma_2 = math_ops.square(sigma) + sigmap_2 = 1.0/(1/sigma0_2 + n/sigma_2) + return Gaussian( + mu=(prior.mu/sigma0_2 + s/sigma_2) * sigmap_2, + sigma=math_ops.sqrt(sigmap_2 + sigma_2))