Initial API for tf.contrib.distributions.
Change: 115725802
This commit is contained in:
parent
52c73af0b0
commit
06b0813ddd
@ -13,6 +13,7 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/ctc:ctc_py",
|
||||
"//tensorflow/contrib/distributions:distributions_py",
|
||||
"//tensorflow/contrib/layers:layers_py",
|
||||
"//tensorflow/contrib/linear_optimizer:sdca_ops_py",
|
||||
"//tensorflow/contrib/testing:testing_py",
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
|
||||
# Add projects here, they will show up under tf.contrib.
|
||||
from tensorflow.contrib import ctc
|
||||
from tensorflow.contrib import distributions
|
||||
from tensorflow.contrib import layers
|
||||
from tensorflow.contrib import linear_optimizer
|
||||
from tensorflow.contrib import testing
|
||||
|
49
tensorflow/contrib/distributions/BUILD
Normal file
49
tensorflow/contrib/distributions/BUILD
Normal file
@ -0,0 +1,49 @@
|
||||
# Description:
|
||||
# Contains ops to train linear models on top of TensorFlow.
|
||||
# APIs here are meant to evolve over time.
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
package(default_visibility = ["//tensorflow:__subpackages__"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
|
||||
|
||||
py_library(
|
||||
name = "distributions_py",
|
||||
srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
|
||||
srcs_version = "PY2AND3",
|
||||
)
|
||||
|
||||
cuda_py_tests(
|
||||
name = "gaussian_test",
|
||||
srcs = ["python/kernel_tests/gaussian_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_tests(
|
||||
name = "gaussian_conjugate_posteriors_test",
|
||||
srcs = ["python/kernel_tests/gaussian_conjugate_posteriors_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
28
tensorflow/contrib/distributions/__init__.py
Normal file
28
tensorflow/contrib/distributions/__init__.py
Normal file
@ -0,0 +1,28 @@
|
||||
# Copyright 2016 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Ops for representing statistical distributions.
|
||||
|
||||
## This package provides classes for statistical distributions.
|
||||
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,wildcard-import, line-too-long
|
||||
from tensorflow.contrib.distributions.python.ops import gaussian_conjugate_posteriors
|
||||
from tensorflow.contrib.distributions.python.ops.gaussian import *
|
||||
# from tensorflow.contrib.distributions.python.ops.dirichlet import * # pylint: disable=line-too-long
|
||||
# from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import * # pylint: disable=line-too-long
|
14
tensorflow/contrib/distributions/python/__init__.py
Normal file
14
tensorflow/contrib/distributions/python/__init__.py
Normal file
@ -0,0 +1,14 @@
|
||||
# Copyright 2016 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
@ -0,0 +1,65 @@
|
||||
# Copyright 2015 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for initializers."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
gaussian_conjugate_posteriors = tf.contrib.distributions.gaussian_conjugate_posteriors # pylint: disable=line-too-long
|
||||
|
||||
|
||||
class GaussianTest(tf.test.TestCase):
|
||||
|
||||
def testGaussianConjugateKnownSigmaPosterior(self):
|
||||
with tf.Session():
|
||||
mu0 = tf.constant(3.0)
|
||||
sigma0 = tf.constant(math.sqrt(1/0.1))
|
||||
sigma = tf.constant(math.sqrt(1/0.5))
|
||||
x = tf.constant([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0])
|
||||
s = tf.reduce_sum(x)
|
||||
n = tf.size(x)
|
||||
prior = tf.contrib.distributions.Gaussian(mu=mu0, sigma=sigma0)
|
||||
posterior = gaussian_conjugate_posteriors.known_sigma_posterior(
|
||||
prior=prior, sigma=sigma, s=s, n=n)
|
||||
|
||||
# Smoke test
|
||||
self.assertTrue(isinstance(posterior, tf.contrib.distributions.Gaussian))
|
||||
posterior_log_pdf = posterior.log_pdf(x).eval()
|
||||
self.assertEqual(posterior_log_pdf.shape, (6,))
|
||||
|
||||
def testGaussianConjugateKnownSigmaPredictive(self):
|
||||
with tf.Session():
|
||||
mu0 = tf.constant(3.0)
|
||||
sigma0 = tf.constant(math.sqrt(1/0.1))
|
||||
sigma = tf.constant(math.sqrt(1/0.5))
|
||||
x = tf.constant([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0])
|
||||
s = tf.reduce_sum(x)
|
||||
n = tf.size(x)
|
||||
prior = tf.contrib.distributions.Gaussian(mu=mu0, sigma=sigma0)
|
||||
predictive = gaussian_conjugate_posteriors.known_sigma_predictive(
|
||||
prior=prior, sigma=sigma, s=s, n=n)
|
||||
|
||||
# Smoke test
|
||||
self.assertTrue(isinstance(predictive, tf.contrib.distributions.Gaussian))
|
||||
predictive_log_pdf = predictive.log_pdf(x).eval()
|
||||
self.assertEqual(predictive_log_pdf.shape, (6,))
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
@ -0,0 +1,77 @@
|
||||
# Copyright 2015 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for initializers."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class GaussianTest(tf.test.TestCase):
|
||||
|
||||
def testGaussianLogLikelihoodPDF(self):
|
||||
with tf.Session():
|
||||
mu = tf.constant(3.0)
|
||||
sigma = tf.constant(math.sqrt(1/0.1))
|
||||
mu_v = 3.0
|
||||
sigma_v = np.sqrt(1/0.1)
|
||||
x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0])
|
||||
gaussian = tf.contrib.distributions.Gaussian(mu=mu, sigma=sigma)
|
||||
expected_log_pdf = np.log(
|
||||
1/np.sqrt(2*np.pi)/sigma_v*np.exp(-1.0/(2*sigma_v**2)*(x-mu_v)**2))
|
||||
|
||||
log_pdf = gaussian.log_pdf(x)
|
||||
self.assertAllClose(expected_log_pdf, log_pdf.eval())
|
||||
|
||||
pdf = gaussian.pdf(x)
|
||||
self.assertAllClose(np.exp(expected_log_pdf), pdf.eval())
|
||||
|
||||
def testGaussianCDF(self):
|
||||
with tf.Session():
|
||||
mu = tf.constant(3.0)
|
||||
sigma = tf.constant(math.sqrt(1/0.1))
|
||||
mu_v = 3.0
|
||||
sigma_v = np.sqrt(1/0.1)
|
||||
x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0])
|
||||
gaussian = tf.contrib.distributions.Gaussian(mu=mu, sigma=sigma)
|
||||
erf_fn = np.vectorize(math.erf)
|
||||
|
||||
# From Wikipedia
|
||||
expected_cdf = 0.5*(1.0 + erf_fn((x - mu_v)/(sigma_v*np.sqrt(2))))
|
||||
|
||||
cdf = gaussian.cdf(x)
|
||||
self.assertAllClose(expected_cdf, cdf.eval())
|
||||
|
||||
def testGaussianSample(self):
|
||||
with tf.Session():
|
||||
mu = tf.constant(3.0)
|
||||
sigma = tf.constant(math.sqrt(1/0.1))
|
||||
mu_v = 3.0
|
||||
sigma_v = np.sqrt(1/0.1)
|
||||
n = tf.constant(10000)
|
||||
gaussian = tf.contrib.distributions.Gaussian(mu=mu, sigma=sigma)
|
||||
samples = gaussian.sample(n, seed=137)
|
||||
sample_values = samples.eval()
|
||||
self.assertEqual(sample_values.shape, (10000,))
|
||||
self.assertAllClose(sample_values.mean(), mu_v, atol=1e-2)
|
||||
self.assertAllClose(sample_values.std(), sigma_v, atol=1e-1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
123
tensorflow/contrib/distributions/python/ops/gaussian.py
Normal file
123
tensorflow/contrib/distributions/python/ops/gaussian.py
Normal file
@ -0,0 +1,123 @@
|
||||
# Copyright 2016 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""The Normal (Gaussian) distribution class.
|
||||
|
||||
@@Gaussian
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import constant_op
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
|
||||
|
||||
class Gaussian(object):
|
||||
"""The Normal (Gaussian) distribution with mean mu and stddev sigma.
|
||||
|
||||
The PDF of this distribution is:
|
||||
f(x) = sqrt(1/(2*pi*sigma^2)) exp(-(x-mu)^2/(2*sigma^2))
|
||||
"""
|
||||
|
||||
def __init__(self, mu, sigma):
|
||||
"""Construct a new Gaussian distribution with mean mu and stddev sigma.
|
||||
|
||||
Args:
|
||||
mu: Scalar tensor, the mean of the distribution.
|
||||
sigma: Scalar tensor, the precision of the distribution.
|
||||
|
||||
Raises:
|
||||
TypeError: if mu and sigma are different dtypes.
|
||||
"""
|
||||
self._mu = ops.convert_to_tensor(mu)
|
||||
self._sigma = ops.convert_to_tensor(sigma)
|
||||
if mu.dtype != sigma.dtype:
|
||||
raise TypeError("Expected same dtype for mu, sigma but got: %s vs. %s"
|
||||
% (mu.dtype, sigma.dtype))
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self._mu.dtype
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return constant_op.constant([]) # Scalar
|
||||
|
||||
@property
|
||||
def mu(self):
|
||||
return self._mu
|
||||
|
||||
@property
|
||||
def sigma(self):
|
||||
return self._sigma
|
||||
|
||||
def log_pdf(self, x):
|
||||
"""Log likelihood of observations in x under Gaussian with mu and sigma.
|
||||
|
||||
Args:
|
||||
x: 1-D, a vector of observations.
|
||||
|
||||
Returns:
|
||||
log_lik: 1-D, a vector of log likelihoods of `x` under the model.
|
||||
"""
|
||||
return (-0.5*math.log(2 * math.pi) - math_ops.log(self._sigma)
|
||||
-0.5*math_ops.square((x - self._mu) / self._sigma))
|
||||
|
||||
def cdf(self, x):
|
||||
"""CDF of observations in x under Gaussian with mu and sigma.
|
||||
|
||||
Args:
|
||||
x: 1-D, a vector of observations.
|
||||
|
||||
Returns:
|
||||
cdf: 1-D, a vector of CDFs of `x` under the model.
|
||||
"""
|
||||
return (0.5 + 0.5*math_ops.erf(
|
||||
1.0/(math.sqrt(2.0) * self._sigma)*(x - self._mu)))
|
||||
|
||||
def log_cdf(self, x):
|
||||
"""Log of the CDF of observations x under Gaussian with mu and sigma."""
|
||||
return math_ops.log(self.cdf(x))
|
||||
|
||||
def pdf(self, x):
|
||||
"""The PDF for observations x.
|
||||
|
||||
Args:
|
||||
x: 1-D, a vector of observations.
|
||||
|
||||
Returns:
|
||||
pdf: 1-D, a vector of pdf values of `x` under the model.
|
||||
"""
|
||||
return math_ops.exp(self.log_pdf(x))
|
||||
|
||||
def sample(self, n, seed=None):
|
||||
"""Sample `n` observations from this Distribution.
|
||||
|
||||
Args:
|
||||
n: Scalar int `Tensor`, the number of observations to sample.
|
||||
seed: Python integer, the random seed.
|
||||
|
||||
Returns:
|
||||
samples: A vector of samples with shape `[n]`.
|
||||
"""
|
||||
return random_ops.random_normal(
|
||||
shape=array_ops.expand_dims(n, 0), mean=self._mu,
|
||||
stddev=self._sigma, dtype=self._mu.dtype, seed=seed)
|
@ -0,0 +1,126 @@
|
||||
# Copyright 2016 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""The Gaussian distribution: conjugate posterior closed form calculations.
|
||||
|
||||
@@known_sigma_posterior
|
||||
@@known_sigma_predictive
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.distributions.python.ops.gaussian import Gaussian # pylint: disable=line-too-long
|
||||
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
||||
|
||||
def known_sigma_posterior(prior, sigma, s, n):
|
||||
"""Return the conjugate posterior distribution with known sigma.
|
||||
|
||||
Accepts a prior Gaussian distribution, having parameters `mu0` and `sigma0`,
|
||||
a known `sigma` of the predictive distribution (also assumed Gaussian),
|
||||
and statistical estimates `s` (the sum of the observations) and
|
||||
`n` (the number of observations).
|
||||
|
||||
Returns a posterior (also Gaussian) distribution object, with parameters
|
||||
`(mu', sigma'^2)`, where:
|
||||
```
|
||||
sigma'^2 = 1/(1/sigma0^2 + n/sigma^2),
|
||||
mu' = (mu0/sigma0^2 + s/sigma^2) * sigma'^2.
|
||||
```
|
||||
|
||||
Args:
|
||||
prior: `Normal` object of type `dtype`, the prior distribution having
|
||||
parameters `(mu0, sigma0)`.
|
||||
sigma: Scalar of type `dtype`, `sigma > 0`. The known stddev parameter.
|
||||
s: Scalar, of type `dtype`, the sum of observations.
|
||||
n: Scalar int, the number of observations.
|
||||
|
||||
Returns:
|
||||
A new Gaussian posterior distribution.
|
||||
|
||||
Raises:
|
||||
TypeError: if dtype of `s` does not match `dtype`, or `prior` is not a
|
||||
Gaussian object.
|
||||
"""
|
||||
if not isinstance(prior, Gaussian):
|
||||
raise TypeError("Expected prior to be an instance of type Gaussian")
|
||||
|
||||
if s.dtype != prior.dtype:
|
||||
raise TypeError(
|
||||
"Observation sum s.dtype does not match prior dtype: %s vs. %s"
|
||||
% (s.dtype, prior.dtype))
|
||||
|
||||
n = math_ops.cast(n, prior.dtype)
|
||||
sigma0_2 = math_ops.square(prior.sigma)
|
||||
sigma_2 = math_ops.square(sigma)
|
||||
sigmap_2 = 1.0/(1/sigma0_2 + n/sigma_2)
|
||||
return Gaussian(
|
||||
mu=(prior.mu/sigma0_2 + s/sigma_2) * sigmap_2,
|
||||
sigma=math_ops.sqrt(sigmap_2))
|
||||
|
||||
|
||||
def known_sigma_predictive(prior, sigma, s, n):
|
||||
"""Return the posterior predictive distribution with known sigma.
|
||||
|
||||
Accepts a prior Gaussian distribution, having parameters `mu0` and `sigma0`,
|
||||
a known `sigma` of the predictive distribution (also assumed Gaussian),
|
||||
and statistical estimates `s` (the sum of the observations) and
|
||||
`n` (the number of observations).
|
||||
|
||||
Calculates the Gaussian distribution p(x | sigma):
|
||||
```
|
||||
p(x | sigma) = int N(x | mu, sigma^2) N(mu | prior.mu, prior.sigma^2) dmu
|
||||
= N(x | prior.mu, 1/(sigma^2 + prior.sigma^2))
|
||||
```
|
||||
|
||||
Returns the predictive posterior distribution object, with parameters
|
||||
`(mu', sigma'^2)`, where:
|
||||
```
|
||||
sigma_n^2 = 1/(1/sigma0^2 + n/sigma^2),
|
||||
mu' = (mu0/sigma0^2 + s/sigma^2) * sigma_n^2.
|
||||
sigma'^2 = sigma_n^2 + sigma^2,
|
||||
```
|
||||
|
||||
Args:
|
||||
prior: `Normal` object of type `dtype`, the prior distribution having
|
||||
parameters `(mu0, sigma0)`.
|
||||
sigma: Scalar of type `dtype`, `sigma > 0`. The known stddev parameter.
|
||||
s: Scalar, of type `dtype`, the sum of observations.
|
||||
n: Scalar int, the number of observations.
|
||||
|
||||
Returns:
|
||||
A new Gaussian posterior distribution.
|
||||
|
||||
Raises:
|
||||
TypeError: if dtype of `s` does not match `dtype`, or `prior` is not a
|
||||
Gaussian object.
|
||||
"""
|
||||
if not isinstance(prior, Gaussian):
|
||||
raise TypeError("Expected prior to be an instance of type Gaussian")
|
||||
|
||||
if s.dtype != prior.dtype:
|
||||
raise TypeError(
|
||||
"Observation sum s.dtype does not match prior dtype: %s vs. %s"
|
||||
% (s.dtype, prior.dtype))
|
||||
|
||||
n = math_ops.cast(n, prior.dtype)
|
||||
sigma0_2 = math_ops.square(prior.sigma)
|
||||
sigma_2 = math_ops.square(sigma)
|
||||
sigmap_2 = 1.0/(1/sigma0_2 + n/sigma_2)
|
||||
return Gaussian(
|
||||
mu=(prior.mu/sigma0_2 + s/sigma_2) * sigmap_2,
|
||||
sigma=math_ops.sqrt(sigmap_2 + sigma_2))
|
Loading…
Reference in New Issue
Block a user