Gamma, Chi2, and Exponential Distributions for TensorFlow

Change: 122546445
This commit is contained in:
A. Unique TensorFlower 2016-05-17 10:37:02 -08:00 committed by TensorFlower Gardener
parent 43ff0e9172
commit da10ae8699
8 changed files with 647 additions and 1 deletion


@@ -27,6 +27,33 @@ cuda_py_tests(
],
)
cuda_py_tests(
name = "gamma_test",
srcs = ["python/kernel_tests/gamma_test.py"],
additional_deps = [
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
cuda_py_tests(
name = "chi2_test",
srcs = ["python/kernel_tests/chi2_test.py"],
additional_deps = [
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
cuda_py_tests(
name = "exponential_test",
srcs = ["python/kernel_tests/exponential_test.py"],
additional_deps = [
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
cuda_py_tests(
name = "gaussian_test",
size = "small",
@@ -65,7 +92,6 @@ cuda_py_tests(
srcs = ["python/kernel_tests/gaussian_conjugate_posteriors_test.py"],
additional_deps = [
":distributions_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)


@@ -27,6 +27,9 @@ initialized with parameters that define the distributions.
### Univariate (scalar) distributions
@@Chi2
@@Exponential
@@Gamma
@@Gaussian
@@Uniform
@@ -50,8 +53,12 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,wildcard-import,line-too-long
from tensorflow.contrib.distributions.python.ops.chi2 import *
from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import *
from tensorflow.contrib.distributions.python.ops.distribution import *
from tensorflow.contrib.distributions.python.ops.exponential import *
from tensorflow.contrib.distributions.python.ops.gamma import *
from tensorflow.contrib.distributions.python.ops.gaussian import *
from tensorflow.contrib.distributions.python.ops.gaussian_conjugate_posteriors import *
from tensorflow.contrib.distributions.python.ops.mvn import *
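With these re-exports in place, the new classes become available under `tf.contrib.distributions`. A minimal usage sketch (TF 0.x-era API as added in this commit; the parameter values are illustrative only, not taken from the change itself):

```python
import tensorflow as tf

# One instance of each newly exported distribution.
gamma = tf.contrib.distributions.Gamma(alpha=3.0, beta=2.0)   # shape, inverse scale
chi2 = tf.contrib.distributions.Chi2(df=4.0)                  # degrees of freedom
exponential = tf.contrib.distributions.Exponential(lam=0.5)   # rate

with tf.Session():
    print(gamma.log_pdf(1.5).eval())      # log-density at x = 1.5
    print(chi2.cdf(2.0).eval())           # regularized lower incomplete gamma
    print(exponential.entropy().eval())   # equals 1 - log(lam)
```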


@@ -0,0 +1,85 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for initializers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from scipy import stats
import tensorflow as tf
class Chi2Test(tf.test.TestCase):
def testChi2LogPDF(self):
with tf.Session():
batch_size = 6
df = tf.constant([2.0] * batch_size, dtype=np.float64)
df_v = 2.0
x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float64)
chi2 = tf.contrib.distributions.Chi2(df=df)
expected_log_pdf = stats.chi2.logpdf(x, df_v)
log_pdf = chi2.log_pdf(x)
self.assertEqual(log_pdf.get_shape(), (6,))
self.assertAllClose(log_pdf.eval(), expected_log_pdf)
pdf = chi2.pdf(x)
self.assertEqual(pdf.get_shape(), (6,))
self.assertAllClose(pdf.eval(), np.exp(expected_log_pdf))
def testChi2CDF(self):
with tf.Session():
batch_size = 6
df = tf.constant([2.0] * batch_size, dtype=np.float64)
df_v = 2.0
x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float64)
chi2 = tf.contrib.distributions.Chi2(df=df)
expected_cdf = stats.chi2.cdf(x, df_v)
cdf = chi2.cdf(x)
self.assertEqual(cdf.get_shape(), (6,))
self.assertAllClose(cdf.eval(), expected_cdf)
def testChi2Mean(self):
with tf.Session():
df_v = np.array([1., 3, 5], dtype=np.float64)
expected_mean = stats.chi2.mean(df_v)
chi2 = tf.contrib.distributions.Chi2(df=df_v)
self.assertEqual(chi2.mean.get_shape(), (3,))
self.assertAllClose(chi2.mean.eval(), expected_mean)
def testChi2Variance(self):
with tf.Session():
df_v = np.array([1., 3, 5], np.float64)
expected_variances = stats.chi2.var(df_v)
chi2 = tf.contrib.distributions.Chi2(df=df_v)
self.assertEqual(chi2.variance.get_shape(), (3,))
self.assertAllClose(chi2.variance.eval(), expected_variances)
def testChi2Entropy(self):
with tf.Session():
df_v = np.array([1., 3, 5], dtype=np.float64)
expected_entropy = stats.chi2.entropy(df_v)
chi2 = tf.contrib.distributions.Chi2(df=df_v)
self.assertEqual(chi2.entropy().get_shape(), (3,))
self.assertAllClose(chi2.entropy().eval(), expected_entropy)
if __name__ == '__main__':
tf.test.main()


@@ -0,0 +1,85 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for initializers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from scipy import stats
import tensorflow as tf
class ExponentialTest(tf.test.TestCase):
def testExponentialLogPDF(self):
with tf.Session():
batch_size = 6
lam = tf.constant([2.0] * batch_size)
lam_v = 2.0
x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
exponential = tf.contrib.distributions.Exponential(lam=lam)
expected_log_pdf = stats.expon.logpdf(x, scale=1 / lam_v)
log_pdf = exponential.log_pdf(x)
self.assertEqual(log_pdf.get_shape(), (6,))
self.assertAllClose(log_pdf.eval(), expected_log_pdf)
pdf = exponential.pdf(x)
self.assertEqual(pdf.get_shape(), (6,))
self.assertAllClose(pdf.eval(), np.exp(expected_log_pdf))
def testExponentialCDF(self):
with tf.Session():
batch_size = 6
lam = tf.constant([2.0] * batch_size)
lam_v = 2.0
x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
exponential = tf.contrib.distributions.Exponential(lam=lam)
expected_cdf = stats.expon.cdf(x, scale=1 / lam_v)
cdf = exponential.cdf(x)
self.assertEqual(cdf.get_shape(), (6,))
self.assertAllClose(cdf.eval(), expected_cdf)
def testExponentialMean(self):
with tf.Session():
lam_v = np.array([1.0, 4.0, 2.5])
expected_mean = stats.expon.mean(scale=1 / lam_v)
exponential = tf.contrib.distributions.Exponential(lam=lam_v)
self.assertEqual(exponential.mean.get_shape(), (3,))
self.assertAllClose(exponential.mean.eval(), expected_mean)
def testExponentialVariance(self):
with tf.Session():
lam_v = np.array([1.0, 4.0, 2.5])
expected_variance = stats.expon.var(scale=1 / lam_v)
exponential = tf.contrib.distributions.Exponential(lam=lam_v)
self.assertEqual(exponential.variance.get_shape(), (3,))
self.assertAllClose(exponential.variance.eval(), expected_variance)
def testExponentialEntropy(self):
with tf.Session():
lam_v = np.array([1.0, 4.0, 2.5])
expected_entropy = stats.expon.entropy(scale=1 / lam_v)
exponential = tf.contrib.distributions.Exponential(lam=lam_v)
self.assertEqual(exponential.entropy().get_shape(), (3,))
self.assertAllClose(exponential.entropy().eval(), expected_entropy)
if __name__ == '__main__':
tf.test.main()


@@ -0,0 +1,142 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for initializers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from scipy import stats
import tensorflow as tf
class GammaTest(tf.test.TestCase):
def testGammaShape(self):
with tf.Session():
alpha = tf.constant([3.0] * 5)
beta = tf.constant(11.0)
gamma = tf.contrib.distributions.Gamma(alpha=alpha, beta=beta)
self.assertEqual(gamma.batch_shape().eval(), (5,))
self.assertEqual(gamma.get_batch_shape(), tf.TensorShape([5]))
self.assertEqual(gamma.event_shape().eval(), 1)
self.assertEqual(gamma.get_event_shape(), tf.TensorShape([]))
def testGammaLogPDF(self):
with tf.Session():
batch_size = 6
alpha = tf.constant([2.0] * batch_size)
beta = tf.constant([3.0] * batch_size)
alpha_v = 2.0
beta_v = 3.0
x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
gamma = tf.contrib.distributions.Gamma(alpha=alpha, beta=beta)
expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
log_pdf = gamma.log_pdf(x)
self.assertEqual(log_pdf.get_shape(), (6,))
self.assertAllClose(log_pdf.eval(), expected_log_pdf)
pdf = gamma.pdf(x)
self.assertEqual(pdf.get_shape(), (6,))
self.assertAllClose(pdf.eval(), np.exp(expected_log_pdf))
def testGammaLogPDFMultidimensional(self):
with tf.Session():
batch_size = 6
alpha = tf.constant([[2.0, 4.0]] * batch_size)
beta = tf.constant([[3.0, 4.0]] * batch_size)
alpha_v = np.array([2.0, 4.0])
beta_v = np.array([3.0, 4.0])
x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
gamma = tf.contrib.distributions.Gamma(alpha=alpha, beta=beta)
expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
log_pdf = gamma.log_pdf(x)
log_pdf_values = log_pdf.eval()
self.assertEqual(log_pdf.get_shape(), (6, 2))
self.assertAllClose(log_pdf_values, expected_log_pdf)
pdf = gamma.pdf(x)
pdf_values = pdf.eval()
self.assertEqual(pdf.get_shape(), (6, 2))
self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
def testGammaLogPDFMultidimensionalBroadcasting(self):
with tf.Session():
batch_size = 6
alpha = tf.constant([[2.0, 4.0]] * batch_size)
beta = tf.constant(3.0)
alpha_v = np.array([2.0, 4.0])
beta_v = 3.0
x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
gamma = tf.contrib.distributions.Gamma(alpha=alpha, beta=beta)
expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
log_pdf = gamma.log_pdf(x)
log_pdf_values = log_pdf.eval()
self.assertEqual(log_pdf.get_shape(), (6, 2))
self.assertAllClose(log_pdf_values, expected_log_pdf)
pdf = gamma.pdf(x)
pdf_values = pdf.eval()
self.assertEqual(pdf.get_shape(), (6, 2))
self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
def testGammaCDF(self):
with tf.Session():
batch_size = 6
alpha = tf.constant([2.0] * batch_size)
beta = tf.constant([3.0] * batch_size)
alpha_v = 2.0
beta_v = 3.0
x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
gamma = tf.contrib.distributions.Gamma(alpha=alpha, beta=beta)
expected_cdf = stats.gamma.cdf(x, alpha_v, scale=1 / beta_v)
cdf = gamma.cdf(x)
self.assertEqual(cdf.get_shape(), (6,))
self.assertAllClose(cdf.eval(), expected_cdf)
def testGammaMean(self):
with tf.Session():
alpha_v = np.array([1.0, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0])
gamma = tf.contrib.distributions.Gamma(alpha=alpha_v, beta=beta_v)
expected_means = stats.gamma.mean(alpha_v, scale=1 / beta_v)
self.assertEqual(gamma.mean.get_shape(), (3,))
self.assertAllClose(gamma.mean.eval(), expected_means)
def testGammaVariance(self):
with tf.Session():
alpha_v = np.array([1.0, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0])
gamma = tf.contrib.distributions.Gamma(alpha=alpha_v, beta=beta_v)
expected_variances = stats.gamma.var(alpha_v, scale=1 / beta_v)
self.assertEqual(gamma.variance.get_shape(), (3,))
self.assertAllClose(gamma.variance.eval(), expected_variances)
def testGammaEntropy(self):
with tf.Session():
alpha_v = np.array([1.0, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0])
expected_entropy = stats.gamma.entropy(alpha_v, scale=1 / beta_v)
gamma = tf.contrib.distributions.Gamma(alpha=alpha_v, beta=beta_v)
self.assertEqual(gamma.entropy().get_shape(), (3,))
self.assertAllClose(gamma.entropy().eval(), expected_entropy)
if __name__ == '__main__':
tf.test.main()


@@ -0,0 +1,46 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Chi2 distribution class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import gamma
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
class Chi2(gamma.Gamma):
"""The Chi2 distribution with degrees of freedom df.
The PDF of this distribution is:
```pdf(x) = (x^(df/2 - 1)e^(-x/2))/(2^(df/2)Gamma(df/2)), x > 0```
Note that the Chi2 distribution is a special case of the Gamma distribution,
with Chi2(df) = Gamma(df/2, 1/2).
"""
def __init__(self, df, name="Chi2"):
with ops.op_scope([df], name, "init"):
df = ops.convert_to_tensor(df)
self._df = df
super(Chi2, self).__init__(alpha=df / 2,
beta=math_ops.cast(0.5, dtype=df.dtype))
@property
def df(self):
return self._df
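Since `Chi2(df)` is implemented as `Gamma(df/2, 1/2)` with `df` kept only for the property above, a quick sketch (not part of the commit) checking the reparameterization against `scipy`, mirroring the tests earlier in this change:

```python
import numpy as np
from scipy import stats
import tensorflow as tf

df = np.array([2.0, 5.0], dtype=np.float64)
x = np.array([1.3, 4.2], dtype=np.float64)

chi2 = tf.contrib.distributions.Chi2(df=df)
gamma = tf.contrib.distributions.Gamma(alpha=df / 2., beta=np.float64(0.5))

with tf.Session():
    # Same density whether parameterized as Chi2(df) or Gamma(df/2, 1/2),
    # and both agree with scipy.stats.chi2.
    np.testing.assert_allclose(chi2.log_pdf(x).eval(), gamma.log_pdf(x).eval())
    np.testing.assert_allclose(chi2.log_pdf(x).eval(), stats.chi2.logpdf(x, df))
```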


@@ -0,0 +1,47 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Exponential distribution class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import gamma
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
class Exponential(gamma.Gamma):
"""The Exponential distribution with rate parameter lam.
The PDF of this distribution is:
```pdf(x) = (lam * e^(-lam * x)), x > 0```
Note that the Exponential distribution is a special case of the Gamma
distribution, with Exponential(lam) = Gamma(1, lam).
"""
def __init__(self, lam, name="Exponential"):
with ops.op_scope([lam], name, "init"):
lam = ops.convert_to_tensor(lam)
self._lam = lam
super(Exponential, self).__init__(
alpha=math_ops.cast(1.0, dtype=lam.dtype),
beta=lam)
@property
def lam(self):
return self._lam
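As with `Chi2`, `lam` is stored only so it can be exposed as a property; the math is delegated to `Gamma(1, lam)`. Note that `lam` is a rate, whereas `scipy.stats.expon` is parameterized by `scale = 1/lam`. A small sketch along the lines of the tests above (not part of the commit):

```python
import numpy as np
from scipy import stats
import tensorflow as tf

lam = np.array([1.0, 4.0, 2.5])
x = np.array([0.5, 1.0, 2.0])

exponential = tf.contrib.distributions.Exponential(lam=lam)

with tf.Session():
    # scipy uses scale = 1 / rate.
    np.testing.assert_allclose(exponential.log_pdf(x).eval(),
                               stats.expon.logpdf(x, scale=1 / lam))
    # Mean of Exp(lam) is 1 / lam.
    np.testing.assert_allclose(exponential.mean.eval(), 1 / lam)
```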


@@ -0,0 +1,208 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Gamma distribution class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distributions.python.ops.distribution import ContinuousDistribution # pylint: disable=line-too-long
from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util # pylint: disable=line-too-long
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
class Gamma(ContinuousDistribution):
"""The `Gamma` distribution with parameter alpha and beta.
The parameters are the shape and inverse scale parameters alpha, beta.
The PDF of this distribution is:
```pdf(x) = (beta^alpha)(x^(alpha-1))e^(-x*beta)/Gamma(alpha), x > 0```
and the CDF of this distribution is:
```cdf(x) = GammaInc(alpha, beta * x) / Gamma(alpha), x > 0```
where GammaInc is the lower incomplete Gamma function.
Examples:
```python
dist = Gamma(alpha=3.0, beta=2.0)
dist2 = Gamma(alpha=[3.0, 4.0], beta=[2.0, 3.0])
```
"""
def __init__(self, alpha, beta, name="Gamma"):
"""Construct Gamma distributions with parameters `alpha` and `beta`.
The parameters `alpha` and `beta` must be shaped in a way that supports
broadcasting (e.g. `alpha + beta` is a valid operation).
Args:
alpha: `float` or `double` tensor, the shape params of the
distribution(s).
alpha must contain only positive values.
beta: `float` or `double` tensor, the inverse scale params of the
distribution(s).
beta must contain only positive values.
name: The name to give Ops created by the initializer.
Raises:
TypeError: if `alpha` and `beta` are different dtypes.
"""
with ops.op_scope([alpha, beta], name):
alpha = ops.convert_to_tensor(alpha, name="alpha_before_dependencies")
beta = ops.convert_to_tensor(beta, name="beta_before_dependencies")
contrib_tensor_util.assert_same_float_dtype((alpha, beta))
with ops.control_dependencies([
check_ops.assert_positive(alpha), check_ops.assert_positive(beta)
]):
self._alpha = alpha
self._beta = beta
self._name = name
with ops.op_scope([self._alpha, self._beta], name, "mean"):
self._mean = self._alpha / self._beta
self._batch_shape = self._mean.get_shape()
with ops.op_scope([self._alpha, self._beta], name, "variance"):
self._variance = self._alpha / math_ops.square(self._beta)
self._event_shape = tensor_shape.TensorShape([])
@property
def name(self):
return self._name
@property
def dtype(self):
return self._alpha.dtype
@property
def alpha(self):
return self._alpha
@property
def beta(self):
return self._beta
def batch_shape(self, name="batch_shape"):
with ops.name_scope(self.name):
return array_ops.shape(self._mean, name=name)
def get_batch_shape(self):
return self._batch_shape
def event_shape(self, name="event_shape"):
with ops.name_scope(self.name):
return constant_op.constant(1, name=name)
def get_event_shape(self):
return self._event_shape
@property
def mean(self):
return self._mean
@property
def variance(self):
return self._variance
def log_pdf(self, x, name="log_pdf"):
"""Log pdf of observations in `x` under these Gamma distribution(s).
Args:
x: tensor of dtype `dtype`, must be broadcastable with `alpha` and `beta`.
name: The name to give this op.
Returns:
log_pdf: tensor of dtype `dtype`, the log-PDFs of `x`.
Raises:
TypeError: if `x` and `alpha` are different dtypes.
"""
with ops.op_scope([self._alpha, self._beta, x], self.name):
with ops.name_scope(name):
alpha = self._alpha
beta = self._beta
x = ops.convert_to_tensor(x)
x = control_flow_ops.with_dependencies(
[check_ops.assert_positive(x)], x)
contrib_tensor_util.assert_same_float_dtype(tensors=[x,],
dtype=self.dtype)
return (alpha * math_ops.log(beta) + (alpha - 1) * math_ops.log(x) -
beta * x - math_ops.lgamma(self._alpha))
def pdf(self, x, name="pdf"):
with ops.name_scope(name):
return math_ops.exp(self.log_pdf(x, name))
def log_cdf(self, x, name="log_cdf"):
"""Log CDF of observations `x` under these Gamma distribution(s).
Args:
x: tensor of dtype `dtype`, must be broadcastable with `alpha` and `beta`.
name: The name to give this op.
Returns:
log_cdf: tensor of dtype `dtype`, the log-CDFs of `x`.
"""
with ops.op_scope([self._alpha, self._beta, x], self.name):
with ops.name_scope(name):
x = ops.convert_to_tensor(x)
x = control_flow_ops.with_dependencies(
[check_ops.assert_positive(x)], x)
contrib_tensor_util.assert_same_float_dtype(tensors=[x,],
dtype=self.dtype)
# Note that igamma returns the regularized incomplete gamma function,
# which is what we want for the CDF.
return math_ops.log(math_ops.igamma(self._alpha, self._beta * x))
def cdf(self, x, name="cdf"):
with ops.op_scope([self._alpha, self._beta, x], self.name):
with ops.name_scope(name):
return math_ops.igamma(self._alpha, self._beta * x)
def entropy(self, name="entropy"):
"""The entropy of Gamma distribution(s).
This is defined to be
```entropy = alpha - log(beta) + log(Gamma(alpha))
+ (1-alpha)digamma(alpha)```
where digamma(alpha) is the digamma function.
Args:
name: The name to give this op.
Returns:
entropy: tensor of dtype `dtype`, the entropy.
"""
with ops.op_scope([self.alpha, self._beta], self.name):
with ops.name_scope(name):
alpha = self._alpha
beta = self._beta
return (alpha - math_ops.log(beta) + math_ops.lgamma(alpha) +
(1 - alpha) * math_ops.digamma(alpha))
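For reference, the closed-form entropy in the docstring above can be checked numerically against `scipy`, using the same `scale = 1/beta` convention as the tests (a small sketch, not part of the commit):

```python
import numpy as np
from scipy import special, stats
import tensorflow as tf

alpha = np.array([1.0, 3.0, 2.5])
beta = np.array([1.0, 4.0, 5.0])

# entropy = alpha - log(beta) + log(Gamma(alpha)) + (1 - alpha) * digamma(alpha)
closed_form = (alpha - np.log(beta) + special.gammaln(alpha)
               + (1 - alpha) * special.digamma(alpha))

gamma = tf.contrib.distributions.Gamma(alpha=alpha, beta=beta)
with tf.Session():
    np.testing.assert_allclose(gamma.entropy().eval(), closed_form)
    np.testing.assert_allclose(closed_form,
                               stats.gamma.entropy(alpha, scale=1 / beta))
```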