Gamma, Chi2 and Exponential Distributions for Tensorflow
Change: 122546445
This commit is contained in:
parent
43ff0e9172
commit
da10ae8699
@ -27,6 +27,33 @@ cuda_py_tests(
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_tests(
|
||||
name = "gamma_test",
|
||||
srcs = ["python/kernel_tests/gamma_test.py"],
|
||||
additional_deps = [
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_tests(
|
||||
name = "chi2_test",
|
||||
srcs = ["python/kernel_tests/chi2_test.py"],
|
||||
additional_deps = [
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_tests(
|
||||
name = "exponential_test",
|
||||
srcs = ["python/kernel_tests/exponential_test.py"],
|
||||
additional_deps = [
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_tests(
|
||||
name = "gaussian_test",
|
||||
size = "small",
|
||||
@ -65,7 +92,6 @@ cuda_py_tests(
|
||||
srcs = ["python/kernel_tests/gaussian_conjugate_posteriors_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
@ -27,6 +27,9 @@ initialized with parameters that define the distributions.
|
||||
|
||||
### Univariate (scalar) distributions
|
||||
|
||||
@@Chi2
|
||||
@@Exponential
|
||||
@@Gamma
|
||||
@@Gaussian
|
||||
@@Uniform
|
||||
|
||||
@ -50,8 +53,12 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,wildcard-import,line-too-long
|
||||
|
||||
from tensorflow.contrib.distributions.python.ops.chi2 import *
|
||||
from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import *
|
||||
from tensorflow.contrib.distributions.python.ops.distribution import *
|
||||
from tensorflow.contrib.distributions.python.ops.exponential import *
|
||||
from tensorflow.contrib.distributions.python.ops.gamma import *
|
||||
from tensorflow.contrib.distributions.python.ops.gaussian import *
|
||||
from tensorflow.contrib.distributions.python.ops.gaussian_conjugate_posteriors import *
|
||||
from tensorflow.contrib.distributions.python.ops.mvn import *
|
||||
|
@ -0,0 +1,85 @@
|
||||
# Copyright 2016 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for initializers."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
from scipy import stats
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class Chi2Test(tf.test.TestCase):
|
||||
|
||||
def testChi2LogPDF(self):
|
||||
with tf.Session():
|
||||
batch_size = 6
|
||||
df = tf.constant([2.0] * batch_size, dtype=np.float64)
|
||||
df_v = 2.0
|
||||
x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float64)
|
||||
chi2 = tf.contrib.distributions.Chi2(df=df)
|
||||
expected_log_pdf = stats.chi2.logpdf(x, df_v)
|
||||
|
||||
log_pdf = chi2.log_pdf(x)
|
||||
self.assertEqual(log_pdf.get_shape(), (6,))
|
||||
self.assertAllClose(log_pdf.eval(), expected_log_pdf)
|
||||
|
||||
pdf = chi2.pdf(x)
|
||||
self.assertEqual(pdf.get_shape(), (6,))
|
||||
self.assertAllClose(pdf.eval(), np.exp(expected_log_pdf))
|
||||
|
||||
def testChi2CDF(self):
|
||||
with tf.Session():
|
||||
batch_size = 6
|
||||
df = tf.constant([2.0] * batch_size, dtype=np.float64)
|
||||
df_v = 2.0
|
||||
x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float64)
|
||||
|
||||
chi2 = tf.contrib.distributions.Chi2(df=df)
|
||||
expected_cdf = stats.chi2.cdf(x, df_v)
|
||||
|
||||
cdf = chi2.cdf(x)
|
||||
self.assertEqual(cdf.get_shape(), (6,))
|
||||
self.assertAllClose(cdf.eval(), expected_cdf)
|
||||
|
||||
def testChi2Mean(self):
|
||||
with tf.Session():
|
||||
df_v = np.array([1., 3, 5], dtype=np.float64)
|
||||
expected_mean = stats.chi2.mean(df_v)
|
||||
chi2 = tf.contrib.distributions.Chi2(df=df_v)
|
||||
self.assertEqual(chi2.mean.get_shape(), (3,))
|
||||
self.assertAllClose(chi2.mean.eval(), expected_mean)
|
||||
|
||||
def testChi2Variance(self):
|
||||
with tf.Session():
|
||||
df_v = np.array([1., 3, 5], np.float64)
|
||||
expected_variances = stats.chi2.var(df_v)
|
||||
chi2 = tf.contrib.distributions.Chi2(df=df_v)
|
||||
self.assertEqual(chi2.variance.get_shape(), (3,))
|
||||
self.assertAllClose(chi2.variance.eval(), expected_variances)
|
||||
|
||||
def testChi2Entropy(self):
|
||||
with tf.Session():
|
||||
df_v = np.array([1., 3, 5], dtype=np.float64)
|
||||
expected_entropy = stats.chi2.entropy(df_v)
|
||||
chi2 = tf.contrib.distributions.Chi2(df=df_v)
|
||||
self.assertEqual(chi2.entropy().get_shape(), (3,))
|
||||
self.assertAllClose(chi2.entropy().eval(), expected_entropy)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
@ -0,0 +1,85 @@
|
||||
# Copyright 2016 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for initializers."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
from scipy import stats
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class ExponentialTest(tf.test.TestCase):
|
||||
|
||||
def testExponentialLogPDF(self):
|
||||
with tf.Session():
|
||||
batch_size = 6
|
||||
lam = tf.constant([2.0] * batch_size)
|
||||
lam_v = 2.0
|
||||
x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
|
||||
exponential = tf.contrib.distributions.Exponential(lam=lam)
|
||||
expected_log_pdf = stats.expon.logpdf(x, scale=1 / lam_v)
|
||||
|
||||
log_pdf = exponential.log_pdf(x)
|
||||
self.assertEqual(log_pdf.get_shape(), (6,))
|
||||
self.assertAllClose(log_pdf.eval(), expected_log_pdf)
|
||||
|
||||
pdf = exponential.pdf(x)
|
||||
self.assertEqual(pdf.get_shape(), (6,))
|
||||
self.assertAllClose(pdf.eval(), np.exp(expected_log_pdf))
|
||||
|
||||
def testExponentialCDF(self):
|
||||
with tf.Session():
|
||||
batch_size = 6
|
||||
lam = tf.constant([2.0] * batch_size)
|
||||
lam_v = 2.0
|
||||
x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
|
||||
|
||||
exponential = tf.contrib.distributions.Exponential(lam=lam)
|
||||
expected_cdf = stats.expon.cdf(x, scale=1 / lam_v)
|
||||
|
||||
cdf = exponential.cdf(x)
|
||||
self.assertEqual(cdf.get_shape(), (6,))
|
||||
self.assertAllClose(cdf.eval(), expected_cdf)
|
||||
|
||||
def testExponentialMean(self):
|
||||
with tf.Session():
|
||||
lam_v = np.array([1.0, 4.0, 2.5])
|
||||
expected_mean = stats.expon.mean(scale=1 / lam_v)
|
||||
exponential = tf.contrib.distributions.Exponential(lam=lam_v)
|
||||
self.assertEqual(exponential.mean.get_shape(), (3,))
|
||||
self.assertAllClose(exponential.mean.eval(), expected_mean)
|
||||
|
||||
def testExponentialVariance(self):
|
||||
with tf.Session():
|
||||
lam_v = np.array([1.0, 4.0, 2.5])
|
||||
expected_variance = stats.expon.var(scale=1 / lam_v)
|
||||
exponential = tf.contrib.distributions.Exponential(lam=lam_v)
|
||||
self.assertEqual(exponential.variance.get_shape(), (3,))
|
||||
self.assertAllClose(exponential.variance.eval(), expected_variance)
|
||||
|
||||
def testExponentialEntropy(self):
|
||||
with tf.Session():
|
||||
lam_v = np.array([1.0, 4.0, 2.5])
|
||||
expected_entropy = stats.expon.entropy(scale=1 / lam_v)
|
||||
exponential = tf.contrib.distributions.Exponential(lam=lam_v)
|
||||
self.assertEqual(exponential.entropy().get_shape(), (3,))
|
||||
self.assertAllClose(exponential.entropy().eval(), expected_entropy)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
@ -0,0 +1,142 @@
|
||||
# Copyright 2016 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for initializers."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
from scipy import stats
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class GammaTest(tf.test.TestCase):
|
||||
|
||||
def testGammaShape(self):
|
||||
with tf.Session():
|
||||
alpha = tf.constant([3.0] * 5)
|
||||
beta = tf.constant(11.0)
|
||||
gamma = tf.contrib.distributions.Gamma(alpha=alpha, beta=beta)
|
||||
|
||||
self.assertEqual(gamma.batch_shape().eval(), (5,))
|
||||
self.assertEqual(gamma.get_batch_shape(), tf.TensorShape([5]))
|
||||
self.assertEqual(gamma.event_shape().eval(), 1)
|
||||
self.assertEqual(gamma.get_event_shape(), tf.TensorShape([]))
|
||||
|
||||
def testGammaLogPDF(self):
|
||||
with tf.Session():
|
||||
batch_size = 6
|
||||
alpha = tf.constant([2.0] * batch_size)
|
||||
beta = tf.constant([3.0] * batch_size)
|
||||
alpha_v = 2.0
|
||||
beta_v = 3.0
|
||||
x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
|
||||
gamma = tf.contrib.distributions.Gamma(alpha=alpha, beta=beta)
|
||||
expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
|
||||
log_pdf = gamma.log_pdf(x)
|
||||
self.assertEqual(log_pdf.get_shape(), (6,))
|
||||
self.assertAllClose(log_pdf.eval(), expected_log_pdf)
|
||||
|
||||
pdf = gamma.pdf(x)
|
||||
self.assertEqual(pdf.get_shape(), (6,))
|
||||
self.assertAllClose(pdf.eval(), np.exp(expected_log_pdf))
|
||||
|
||||
def testGammaLogPDFMultidimensional(self):
|
||||
with tf.Session():
|
||||
batch_size = 6
|
||||
alpha = tf.constant([[2.0, 4.0]] * batch_size)
|
||||
beta = tf.constant([[3.0, 4.0]] * batch_size)
|
||||
alpha_v = np.array([2.0, 4.0])
|
||||
beta_v = np.array([3.0, 4.0])
|
||||
x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
|
||||
gamma = tf.contrib.distributions.Gamma(alpha=alpha, beta=beta)
|
||||
expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
|
||||
log_pdf = gamma.log_pdf(x)
|
||||
log_pdf_values = log_pdf.eval()
|
||||
self.assertEqual(log_pdf.get_shape(), (6, 2))
|
||||
self.assertAllClose(log_pdf_values, expected_log_pdf)
|
||||
|
||||
pdf = gamma.pdf(x)
|
||||
pdf_values = pdf.eval()
|
||||
self.assertEqual(pdf.get_shape(), (6, 2))
|
||||
self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
|
||||
|
||||
def testGammaLogPDFMultidimensionalBroadcasting(self):
|
||||
with tf.Session():
|
||||
batch_size = 6
|
||||
alpha = tf.constant([[2.0, 4.0]] * batch_size)
|
||||
beta = tf.constant(3.0)
|
||||
alpha_v = np.array([2.0, 4.0])
|
||||
beta_v = 3.0
|
||||
x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
|
||||
gamma = tf.contrib.distributions.Gamma(alpha=alpha, beta=beta)
|
||||
expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
|
||||
log_pdf = gamma.log_pdf(x)
|
||||
log_pdf_values = log_pdf.eval()
|
||||
self.assertEqual(log_pdf.get_shape(), (6, 2))
|
||||
self.assertAllClose(log_pdf_values, expected_log_pdf)
|
||||
|
||||
pdf = gamma.pdf(x)
|
||||
pdf_values = pdf.eval()
|
||||
self.assertEqual(pdf.get_shape(), (6, 2))
|
||||
self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
|
||||
|
||||
def testGammaCDF(self):
|
||||
with tf.Session():
|
||||
batch_size = 6
|
||||
alpha = tf.constant([2.0] * batch_size)
|
||||
beta = tf.constant([3.0] * batch_size)
|
||||
alpha_v = 2.0
|
||||
beta_v = 3.0
|
||||
x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
|
||||
|
||||
gamma = tf.contrib.distributions.Gamma(alpha=alpha, beta=beta)
|
||||
expected_cdf = stats.gamma.cdf(x, alpha_v, scale=1 / beta_v)
|
||||
|
||||
cdf = gamma.cdf(x)
|
||||
self.assertEqual(cdf.get_shape(), (6,))
|
||||
self.assertAllClose(cdf.eval(), expected_cdf)
|
||||
|
||||
def testGammaMean(self):
|
||||
with tf.Session():
|
||||
alpha_v = np.array([1.0, 3.0, 2.5])
|
||||
beta_v = np.array([1.0, 4.0, 5.0])
|
||||
gamma = tf.contrib.distributions.Gamma(alpha=alpha_v, beta=beta_v)
|
||||
expected_means = stats.gamma.mean(alpha_v, scale=1 / beta_v)
|
||||
self.assertEqual(gamma.mean.get_shape(), (3,))
|
||||
self.assertAllClose(gamma.mean.eval(), expected_means)
|
||||
|
||||
def testGammaVariance(self):
|
||||
with tf.Session():
|
||||
alpha_v = np.array([1.0, 3.0, 2.5])
|
||||
beta_v = np.array([1.0, 4.0, 5.0])
|
||||
gamma = tf.contrib.distributions.Gamma(alpha=alpha_v, beta=beta_v)
|
||||
expected_variances = stats.gamma.var(alpha_v, scale=1 / beta_v)
|
||||
self.assertEqual(gamma.variance.get_shape(), (3,))
|
||||
self.assertAllClose(gamma.variance.eval(), expected_variances)
|
||||
|
||||
def testGammaEntropy(self):
|
||||
with tf.Session():
|
||||
alpha_v = np.array([1.0, 3.0, 2.5])
|
||||
beta_v = np.array([1.0, 4.0, 5.0])
|
||||
expected_entropy = stats.gamma.entropy(alpha_v, scale=1 / beta_v)
|
||||
gamma = tf.contrib.distributions.Gamma(alpha=alpha_v, beta=beta_v)
|
||||
self.assertEqual(gamma.entropy().get_shape(), (3,))
|
||||
self.assertAllClose(gamma.entropy().eval(), expected_entropy)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
46
tensorflow/contrib/distributions/python/ops/chi2.py
Normal file
46
tensorflow/contrib/distributions/python/ops/chi2.py
Normal file
@ -0,0 +1,46 @@
|
||||
# Copyright 2016 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""The Chi2 distribution class."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.distributions.python.ops import gamma
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
||||
|
||||
class Chi2(gamma.Gamma):
|
||||
"""The Chi2 distribution with degrees of freedom df.
|
||||
|
||||
The PDF of this distribution is:
|
||||
|
||||
```pdf(x) = (x^(df/2 - 1)e^(-x/2))/(2^(k/2)Gamma(k/2)), x > 0```
|
||||
|
||||
Note that the Chi2 distribution is a special case of the Gamma distribution,
|
||||
with Chi2(df) = Gamma(df/2, 1/2).
|
||||
"""
|
||||
|
||||
def __init__(self, df, name="Chi2"):
|
||||
with ops.op_scope([df], name, "init"):
|
||||
df = ops.convert_to_tensor(df)
|
||||
self._df = df
|
||||
super(Chi2, self).__init__(alpha=df / 2,
|
||||
beta=math_ops.cast(0.5, dtype=df.dtype))
|
||||
|
||||
@property
|
||||
def df(self):
|
||||
return self._df
|
47
tensorflow/contrib/distributions/python/ops/exponential.py
Normal file
47
tensorflow/contrib/distributions/python/ops/exponential.py
Normal file
@ -0,0 +1,47 @@
|
||||
# Copyright 2016 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""The Exponential distribution class."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.distributions.python.ops import gamma
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
||||
|
||||
class Exponential(gamma.Gamma):
|
||||
"""The Exponential distribution with rate parameter lam.
|
||||
|
||||
The PDF of this distribution is:
|
||||
|
||||
```pdf(x) = (lam * e^(-lam * x)), x > 0```
|
||||
|
||||
Note that the Exponential distribution is a special case of the Gamma
|
||||
distribution, with Exponential(lam) = Gamma(1, lam).
|
||||
"""
|
||||
|
||||
def __init__(self, lam, name="Exponential"):
|
||||
with ops.op_scope([lam], name, "init"):
|
||||
lam = ops.convert_to_tensor(lam)
|
||||
self._lam = lam
|
||||
super(Exponential, self).__init__(
|
||||
alpha=math_ops.cast(1.0, dtype=lam.dtype),
|
||||
beta=lam)
|
||||
|
||||
@property
|
||||
def lam(self):
|
||||
return self._lam
|
208
tensorflow/contrib/distributions/python/ops/gamma.py
Normal file
208
tensorflow/contrib/distributions/python/ops/gamma.py
Normal file
@ -0,0 +1,208 @@
|
||||
# Copyright 2016 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""The Gamma distribution class."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.distributions.python.ops.distribution import ContinuousDistribution # pylint: disable=line-too-long
|
||||
from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util # pylint: disable=line-too-long
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import constant_op
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
||||
|
||||
class Gamma(ContinuousDistribution):
|
||||
"""The `Gamma` distribution with parameter alpha and beta.
|
||||
|
||||
The parameters are the shape and inverse scale parameters alpha, beta.
|
||||
|
||||
The PDF of this distribution is:
|
||||
|
||||
```pdf(x) = (beta^alpha)(x^(alpha-1))e^(-x*beta)/Gamma(alpha), x > 0```
|
||||
|
||||
and the CDF of this distribution is:
|
||||
|
||||
```cdf(x) = GammaInc(alpha, beta * x) / Gamma(alpha), x > 0```
|
||||
|
||||
where GammaInc is the incomplete lower Gamma function.
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
dist = Gamma(alpha=3.0, beta=2.0)
|
||||
dist2 = Gamma(alpha=[3.0, 4.0], beta=[2.0, 3.0])
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, alpha, beta, name="Gamma"):
|
||||
"""Construct Gamma distributions with parameters `alpha` and `beta`.
|
||||
|
||||
The parameters `alpha` and `beta` must be shaped in a way that supports
|
||||
broadcasting (e.g. `alpha + beta` is a valid operation).
|
||||
|
||||
Args:
|
||||
alpha: `float` or `double` tensor, the shape params of the
|
||||
distribution(s).
|
||||
alpha must contain only positive values.
|
||||
beta: `float` or `double` tensor, the inverse scale params of the
|
||||
distribution(s).
|
||||
beta must contain only positive values.
|
||||
name: The name to give Ops created by the initializer.
|
||||
|
||||
Raises:
|
||||
TypeError: if `alpha` and `beta` are different dtypes.
|
||||
"""
|
||||
with ops.op_scope([alpha, beta], name):
|
||||
alpha = ops.convert_to_tensor(alpha, name="alpha_before_dependencies")
|
||||
beta = ops.convert_to_tensor(beta, name="beta_before_dependencies")
|
||||
contrib_tensor_util.assert_same_float_dtype((alpha, beta))
|
||||
with ops.control_dependencies([
|
||||
check_ops.assert_positive(alpha), check_ops.assert_positive(beta)
|
||||
]):
|
||||
self._alpha = alpha
|
||||
self._beta = beta
|
||||
self._name = name
|
||||
|
||||
with ops.op_scope([self._alpha, self._beta], name, "mean"):
|
||||
self._mean = self._alpha / self._beta
|
||||
self._batch_shape = self._mean.get_shape()
|
||||
|
||||
with ops.op_scope([self._alpha, self._beta], name, "variance"):
|
||||
self._variance = self._alpha / math_ops.square(self._beta)
|
||||
|
||||
self._event_shape = tensor_shape.TensorShape([])
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self._alpha.dtype
|
||||
|
||||
@property
|
||||
def alpha(self):
|
||||
return self._alpha
|
||||
|
||||
@property
|
||||
def beta(self):
|
||||
return self._beta
|
||||
|
||||
def batch_shape(self, name="batch_shape"):
|
||||
with ops.name_scope(self.name):
|
||||
return array_ops.shape(self._mean, name=name)
|
||||
|
||||
def get_batch_shape(self):
|
||||
return self._batch_shape
|
||||
|
||||
def event_shape(self, name="event_shape"):
|
||||
with ops.name_scope(self.name):
|
||||
return constant_op.constant(1, name=name)
|
||||
|
||||
def get_event_shape(self):
|
||||
return self._event_shape
|
||||
|
||||
@property
|
||||
def mean(self):
|
||||
return self._mean
|
||||
|
||||
@property
|
||||
def variance(self):
|
||||
return self._variance
|
||||
|
||||
def log_pdf(self, x, name="log_pdf"):
|
||||
"""Log pdf of observations in `x` under these Gamma distribution(s).
|
||||
|
||||
Args:
|
||||
x: tensor of dtype `dtype`, must be broadcastable with `alpha` and `beta`.
|
||||
name: The name to give this op.
|
||||
|
||||
Returns:
|
||||
log_pdf: tensor of dtype `dtype`, the log-PDFs of `x`.
|
||||
Raises:
|
||||
TypeError: if `x` and `alpha` are different dtypes.
|
||||
"""
|
||||
with ops.op_scope([self._alpha, self._beta, x], self.name):
|
||||
with ops.name_scope(name):
|
||||
alpha = self._alpha
|
||||
beta = self._beta
|
||||
x = ops.convert_to_tensor(x)
|
||||
x = control_flow_ops.with_dependencies(
|
||||
[check_ops.assert_positive(x)], x)
|
||||
contrib_tensor_util.assert_same_float_dtype(tensors=[x,],
|
||||
dtype=self.dtype)
|
||||
|
||||
return (alpha * math_ops.log(beta) + (alpha - 1) * math_ops.log(x) -
|
||||
beta * x - math_ops.lgamma(self._alpha))
|
||||
|
||||
def pdf(self, x, name="pdf"):
|
||||
with ops.name_scope(name):
|
||||
return math_ops.exp(self.log_pdf(x, name))
|
||||
|
||||
def log_cdf(self, x, name="log_cdf"):
|
||||
"""Log CDF of observations `x` under these Gamma distribution(s).
|
||||
|
||||
Args:
|
||||
x: tensor of dtype `dtype`, must be broadcastable with `alpha` and `beta`.
|
||||
name: The name to give this op.
|
||||
|
||||
Returns:
|
||||
log_cdf: tensor of dtype `dtype`, the log-CDFs of `x`.
|
||||
"""
|
||||
with ops.op_scope([self._alpha, self._beta, x], self.name):
|
||||
with ops.name_scope(name):
|
||||
x = ops.convert_to_tensor(x)
|
||||
x = control_flow_ops.with_dependencies(
|
||||
[check_ops.assert_positive(x)], x)
|
||||
contrib_tensor_util.assert_same_float_dtype(tensors=[x,],
|
||||
dtype=self.dtype)
|
||||
# Note that igamma returns the regularized incomplete gamma function,
|
||||
# which is what we want for the CDF.
|
||||
return math_ops.log(math_ops.igamma(self._alpha, self._beta * x))
|
||||
|
||||
def cdf(self, x, name="cdf"):
|
||||
with ops.op_scope([self._alpha, self._beta, x], self.name):
|
||||
with ops.name_scope(name):
|
||||
return math_ops.igamma(self._alpha, self._beta * x)
|
||||
|
||||
def entropy(self, name="entropy"):
|
||||
"""The entropy of Gamma distribution(s).
|
||||
|
||||
This is defined to be
|
||||
|
||||
```entropy = alpha - log(beta) + log(Gamma(alpha))
|
||||
+ (1-alpha)digamma(alpha)```
|
||||
|
||||
where digamma(alpha) is the digamma function.
|
||||
|
||||
Args:
|
||||
name: The name to give this op.
|
||||
|
||||
Returns:
|
||||
entropy: tensor of dtype `dtype`, the entropy.
|
||||
"""
|
||||
with ops.op_scope([self.alpha, self._beta], self.name):
|
||||
with ops.name_scope(name):
|
||||
alpha = self._alpha
|
||||
beta = self._beta
|
||||
return (alpha - math_ops.log(beta) + math_ops.lgamma(alpha) +
|
||||
(1 - alpha) * math_ops.digamma(alpha))
|
Loading…
Reference in New Issue
Block a user