diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 3a8c9f2321c..c9cfc922079 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -27,6 +27,33 @@ cuda_py_tests( ], ) +cuda_py_tests( + name = "gamma_test", + srcs = ["python/kernel_tests/gamma_test.py"], + additional_deps = [ + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + +cuda_py_tests( + name = "chi2_test", + srcs = ["python/kernel_tests/chi2_test.py"], + additional_deps = [ + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + +cuda_py_tests( + name = "exponential_test", + srcs = ["python/kernel_tests/exponential_test.py"], + additional_deps = [ + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + cuda_py_tests( name = "gaussian_test", size = "small", @@ -65,7 +92,6 @@ cuda_py_tests( srcs = ["python/kernel_tests/gaussian_conjugate_posteriors_test.py"], additional_deps = [ ":distributions_py", - "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", ], ) diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index 5b4bbac8270..74cedaa251e 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -27,6 +27,9 @@ initialized with parameters that define the distributions. ### Univariate (scalar) distributions +@@Chi2 +@@Exponential +@@Gamma @@Gaussian @@Uniform @@ -50,8 +53,12 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,wildcard-import,line-too-long + +from tensorflow.contrib.distributions.python.ops.chi2 import * from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import * from tensorflow.contrib.distributions.python.ops.distribution import * +from tensorflow.contrib.distributions.python.ops.exponential import * +from tensorflow.contrib.distributions.python.ops.gamma import * from tensorflow.contrib.distributions.python.ops.gaussian import * from tensorflow.contrib.distributions.python.ops.gaussian_conjugate_posteriors import * from tensorflow.contrib.distributions.python.ops.mvn import * diff --git a/tensorflow/contrib/distributions/python/kernel_tests/chi2_test.py b/tensorflow/contrib/distributions/python/kernel_tests/chi2_test.py new file mode 100644 index 00000000000..84763735637 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/chi2_test.py @@ -0,0 +1,85 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for initializers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from scipy import stats +import tensorflow as tf + + +class Chi2Test(tf.test.TestCase): + + def testChi2LogPDF(self): + with tf.Session(): + batch_size = 6 + df = tf.constant([2.0] * batch_size, dtype=np.float64) + df_v = 2.0 + x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float64) + chi2 = tf.contrib.distributions.Chi2(df=df) + expected_log_pdf = stats.chi2.logpdf(x, df_v) + + log_pdf = chi2.log_pdf(x) + self.assertEqual(log_pdf.get_shape(), (6,)) + self.assertAllClose(log_pdf.eval(), expected_log_pdf) + + pdf = chi2.pdf(x) + self.assertEqual(pdf.get_shape(), (6,)) + self.assertAllClose(pdf.eval(), np.exp(expected_log_pdf)) + + def testChi2CDF(self): + with tf.Session(): + batch_size = 6 + df = tf.constant([2.0] * batch_size, dtype=np.float64) + df_v = 2.0 + x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float64) + + chi2 = tf.contrib.distributions.Chi2(df=df) + expected_cdf = stats.chi2.cdf(x, df_v) + + cdf = chi2.cdf(x) + self.assertEqual(cdf.get_shape(), (6,)) + self.assertAllClose(cdf.eval(), expected_cdf) + + def testChi2Mean(self): + with tf.Session(): + df_v = np.array([1., 3, 5], dtype=np.float64) + expected_mean = stats.chi2.mean(df_v) + chi2 = tf.contrib.distributions.Chi2(df=df_v) + self.assertEqual(chi2.mean.get_shape(), (3,)) + self.assertAllClose(chi2.mean.eval(), expected_mean) + + def testChi2Variance(self): + with tf.Session(): + df_v = np.array([1., 3, 5], np.float64) + expected_variances = stats.chi2.var(df_v) + chi2 = tf.contrib.distributions.Chi2(df=df_v) + self.assertEqual(chi2.variance.get_shape(), (3,)) + self.assertAllClose(chi2.variance.eval(), expected_variances) + + def testChi2Entropy(self): + with tf.Session(): + df_v = np.array([1., 3, 5], dtype=np.float64) + expected_entropy = stats.chi2.entropy(df_v) + chi2 = tf.contrib.distributions.Chi2(df=df_v) + self.assertEqual(chi2.entropy().get_shape(), (3,)) + self.assertAllClose(chi2.entropy().eval(), expected_entropy) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/exponential_test.py b/tensorflow/contrib/distributions/python/kernel_tests/exponential_test.py new file mode 100644 index 00000000000..3113034b985 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/exponential_test.py @@ -0,0 +1,85 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for initializers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from scipy import stats +import tensorflow as tf + + +class ExponentialTest(tf.test.TestCase): + + def testExponentialLogPDF(self): + with tf.Session(): + batch_size = 6 + lam = tf.constant([2.0] * batch_size) + lam_v = 2.0 + x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) + exponential = tf.contrib.distributions.Exponential(lam=lam) + expected_log_pdf = stats.expon.logpdf(x, scale=1 / lam_v) + + log_pdf = exponential.log_pdf(x) + self.assertEqual(log_pdf.get_shape(), (6,)) + self.assertAllClose(log_pdf.eval(), expected_log_pdf) + + pdf = exponential.pdf(x) + self.assertEqual(pdf.get_shape(), (6,)) + self.assertAllClose(pdf.eval(), np.exp(expected_log_pdf)) + + def testExponentialCDF(self): + with tf.Session(): + batch_size = 6 + lam = tf.constant([2.0] * batch_size) + lam_v = 2.0 + x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) + + exponential = tf.contrib.distributions.Exponential(lam=lam) + expected_cdf = stats.expon.cdf(x, scale=1 / lam_v) + + cdf = exponential.cdf(x) + self.assertEqual(cdf.get_shape(), (6,)) + self.assertAllClose(cdf.eval(), expected_cdf) + + def testExponentialMean(self): + with tf.Session(): + lam_v = np.array([1.0, 4.0, 2.5]) + expected_mean = stats.expon.mean(scale=1 / lam_v) + exponential = tf.contrib.distributions.Exponential(lam=lam_v) + self.assertEqual(exponential.mean.get_shape(), (3,)) + self.assertAllClose(exponential.mean.eval(), expected_mean) + + def testExponentialVariance(self): + with tf.Session(): + lam_v = np.array([1.0, 4.0, 2.5]) + expected_variance = stats.expon.var(scale=1 / lam_v) + exponential = tf.contrib.distributions.Exponential(lam=lam_v) + self.assertEqual(exponential.variance.get_shape(), (3,)) + self.assertAllClose(exponential.variance.eval(), expected_variance) + + def testExponentialEntropy(self): + with tf.Session(): + lam_v = np.array([1.0, 4.0, 2.5]) + expected_entropy = stats.expon.entropy(scale=1 / lam_v) + exponential = tf.contrib.distributions.Exponential(lam=lam_v) + self.assertEqual(exponential.entropy().get_shape(), (3,)) + self.assertAllClose(exponential.entropy().eval(), expected_entropy) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py b/tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py new file mode 100644 index 00000000000..22f44aeaf46 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py @@ -0,0 +1,142 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for initializers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from scipy import stats +import tensorflow as tf + + +class GammaTest(tf.test.TestCase): + + def testGammaShape(self): + with tf.Session(): + alpha = tf.constant([3.0] * 5) + beta = tf.constant(11.0) + gamma = tf.contrib.distributions.Gamma(alpha=alpha, beta=beta) + + self.assertEqual(gamma.batch_shape().eval(), (5,)) + self.assertEqual(gamma.get_batch_shape(), tf.TensorShape([5])) + self.assertEqual(gamma.event_shape().eval(), 1) + self.assertEqual(gamma.get_event_shape(), tf.TensorShape([])) + + def testGammaLogPDF(self): + with tf.Session(): + batch_size = 6 + alpha = tf.constant([2.0] * batch_size) + beta = tf.constant([3.0] * batch_size) + alpha_v = 2.0 + beta_v = 3.0 + x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) + gamma = tf.contrib.distributions.Gamma(alpha=alpha, beta=beta) + expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v) + log_pdf = gamma.log_pdf(x) + self.assertEqual(log_pdf.get_shape(), (6,)) + self.assertAllClose(log_pdf.eval(), expected_log_pdf) + + pdf = gamma.pdf(x) + self.assertEqual(pdf.get_shape(), (6,)) + self.assertAllClose(pdf.eval(), np.exp(expected_log_pdf)) + + def testGammaLogPDFMultidimensional(self): + with tf.Session(): + batch_size = 6 + alpha = tf.constant([[2.0, 4.0]] * batch_size) + beta = tf.constant([[3.0, 4.0]] * batch_size) + alpha_v = np.array([2.0, 4.0]) + beta_v = np.array([3.0, 4.0]) + x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T + gamma = tf.contrib.distributions.Gamma(alpha=alpha, beta=beta) + expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v) + log_pdf = gamma.log_pdf(x) + log_pdf_values = log_pdf.eval() + self.assertEqual(log_pdf.get_shape(), (6, 2)) + self.assertAllClose(log_pdf_values, expected_log_pdf) + + pdf = gamma.pdf(x) + pdf_values = pdf.eval() + self.assertEqual(pdf.get_shape(), (6, 2)) + self.assertAllClose(pdf_values, np.exp(expected_log_pdf)) + + def testGammaLogPDFMultidimensionalBroadcasting(self): + with tf.Session(): + batch_size = 6 + alpha = tf.constant([[2.0, 4.0]] * batch_size) + beta = tf.constant(3.0) + alpha_v = np.array([2.0, 4.0]) + beta_v = 3.0 + x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T + gamma = tf.contrib.distributions.Gamma(alpha=alpha, beta=beta) + expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v) + log_pdf = gamma.log_pdf(x) + log_pdf_values = log_pdf.eval() + self.assertEqual(log_pdf.get_shape(), (6, 2)) + self.assertAllClose(log_pdf_values, expected_log_pdf) + + pdf = gamma.pdf(x) + pdf_values = pdf.eval() + self.assertEqual(pdf.get_shape(), (6, 2)) + self.assertAllClose(pdf_values, np.exp(expected_log_pdf)) + + def testGammaCDF(self): + with tf.Session(): + batch_size = 6 + alpha = tf.constant([2.0] * batch_size) + beta = tf.constant([3.0] * batch_size) + alpha_v = 2.0 + beta_v = 3.0 + x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) + + gamma = tf.contrib.distributions.Gamma(alpha=alpha, beta=beta) + expected_cdf = stats.gamma.cdf(x, alpha_v, scale=1 / beta_v) + + cdf = gamma.cdf(x) + self.assertEqual(cdf.get_shape(), (6,)) + self.assertAllClose(cdf.eval(), expected_cdf) + + def testGammaMean(self): + with tf.Session(): + alpha_v = np.array([1.0, 3.0, 2.5]) + beta_v = np.array([1.0, 4.0, 5.0]) + gamma = tf.contrib.distributions.Gamma(alpha=alpha_v, beta=beta_v) + expected_means = stats.gamma.mean(alpha_v, scale=1 / beta_v) + self.assertEqual(gamma.mean.get_shape(), (3,)) + self.assertAllClose(gamma.mean.eval(), expected_means) + + def testGammaVariance(self): + with tf.Session(): + alpha_v = np.array([1.0, 3.0, 2.5]) + beta_v = np.array([1.0, 4.0, 5.0]) + gamma = tf.contrib.distributions.Gamma(alpha=alpha_v, beta=beta_v) + expected_variances = stats.gamma.var(alpha_v, scale=1 / beta_v) + self.assertEqual(gamma.variance.get_shape(), (3,)) + self.assertAllClose(gamma.variance.eval(), expected_variances) + + def testGammaEntropy(self): + with tf.Session(): + alpha_v = np.array([1.0, 3.0, 2.5]) + beta_v = np.array([1.0, 4.0, 5.0]) + expected_entropy = stats.gamma.entropy(alpha_v, scale=1 / beta_v) + gamma = tf.contrib.distributions.Gamma(alpha=alpha_v, beta=beta_v) + self.assertEqual(gamma.entropy().get_shape(), (3,)) + self.assertAllClose(gamma.entropy().eval(), expected_entropy) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/distributions/python/ops/chi2.py b/tensorflow/contrib/distributions/python/ops/chi2.py new file mode 100644 index 00000000000..cdcb5620f20 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/chi2.py @@ -0,0 +1,46 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The Chi2 distribution class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distributions.python.ops import gamma +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops + + +class Chi2(gamma.Gamma): + """The Chi2 distribution with degrees of freedom df. + + The PDF of this distribution is: + + ```pdf(x) = (x^(df/2 - 1)e^(-x/2))/(2^(k/2)Gamma(k/2)), x > 0``` + + Note that the Chi2 distribution is a special case of the Gamma distribution, + with Chi2(df) = Gamma(df/2, 1/2). + """ + + def __init__(self, df, name="Chi2"): + with ops.op_scope([df], name, "init"): + df = ops.convert_to_tensor(df) + self._df = df + super(Chi2, self).__init__(alpha=df / 2, + beta=math_ops.cast(0.5, dtype=df.dtype)) + + @property + def df(self): + return self._df diff --git a/tensorflow/contrib/distributions/python/ops/exponential.py b/tensorflow/contrib/distributions/python/ops/exponential.py new file mode 100644 index 00000000000..4652e6b3ec7 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/exponential.py @@ -0,0 +1,47 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The Exponential distribution class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distributions.python.ops import gamma +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops + + +class Exponential(gamma.Gamma): + """The Exponential distribution with rate parameter lam. + + The PDF of this distribution is: + + ```pdf(x) = (lam * e^(-lam * x)), x > 0``` + + Note that the Exponential distribution is a special case of the Gamma + distribution, with Exponential(lam) = Gamma(1, lam). + """ + + def __init__(self, lam, name="Exponential"): + with ops.op_scope([lam], name, "init"): + lam = ops.convert_to_tensor(lam) + self._lam = lam + super(Exponential, self).__init__( + alpha=math_ops.cast(1.0, dtype=lam.dtype), + beta=lam) + + @property + def lam(self): + return self._lam diff --git a/tensorflow/contrib/distributions/python/ops/gamma.py b/tensorflow/contrib/distributions/python/ops/gamma.py new file mode 100644 index 00000000000..2c445a3f12d --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/gamma.py @@ -0,0 +1,208 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The Gamma distribution class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distributions.python.ops.distribution import ContinuousDistribution # pylint: disable=line-too-long +from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util # pylint: disable=line-too-long +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import constant_op +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops + + +class Gamma(ContinuousDistribution): + """The `Gamma` distribution with parameter alpha and beta. + + The parameters are the shape and inverse scale parameters alpha, beta. + + The PDF of this distribution is: + + ```pdf(x) = (beta^alpha)(x^(alpha-1))e^(-x*beta)/Gamma(alpha), x > 0``` + + and the CDF of this distribution is: + + ```cdf(x) = GammaInc(alpha, beta * x) / Gamma(alpha), x > 0``` + + where GammaInc is the incomplete lower Gamma function. + + Examples: + + ```python + dist = Gamma(alpha=3.0, beta=2.0) + dist2 = Gamma(alpha=[3.0, 4.0], beta=[2.0, 3.0]) + ``` + + """ + + def __init__(self, alpha, beta, name="Gamma"): + """Construct Gamma distributions with parameters `alpha` and `beta`. + + The parameters `alpha` and `beta` must be shaped in a way that supports + broadcasting (e.g. `alpha + beta` is a valid operation). + + Args: + alpha: `float` or `double` tensor, the shape params of the + distribution(s). + alpha must contain only positive values. + beta: `float` or `double` tensor, the inverse scale params of the + distribution(s). + beta must contain only positive values. + name: The name to give Ops created by the initializer. + + Raises: + TypeError: if `alpha` and `beta` are different dtypes. + """ + with ops.op_scope([alpha, beta], name): + alpha = ops.convert_to_tensor(alpha, name="alpha_before_dependencies") + beta = ops.convert_to_tensor(beta, name="beta_before_dependencies") + contrib_tensor_util.assert_same_float_dtype((alpha, beta)) + with ops.control_dependencies([ + check_ops.assert_positive(alpha), check_ops.assert_positive(beta) + ]): + self._alpha = alpha + self._beta = beta + self._name = name + + with ops.op_scope([self._alpha, self._beta], name, "mean"): + self._mean = self._alpha / self._beta + self._batch_shape = self._mean.get_shape() + + with ops.op_scope([self._alpha, self._beta], name, "variance"): + self._variance = self._alpha / math_ops.square(self._beta) + + self._event_shape = tensor_shape.TensorShape([]) + + @property + def name(self): + return self._name + + @property + def dtype(self): + return self._alpha.dtype + + @property + def alpha(self): + return self._alpha + + @property + def beta(self): + return self._beta + + def batch_shape(self, name="batch_shape"): + with ops.name_scope(self.name): + return array_ops.shape(self._mean, name=name) + + def get_batch_shape(self): + return self._batch_shape + + def event_shape(self, name="event_shape"): + with ops.name_scope(self.name): + return constant_op.constant(1, name=name) + + def get_event_shape(self): + return self._event_shape + + @property + def mean(self): + return self._mean + + @property + def variance(self): + return self._variance + + def log_pdf(self, x, name="log_pdf"): + """Log pdf of observations in `x` under these Gamma distribution(s). + + Args: + x: tensor of dtype `dtype`, must be broadcastable with `alpha` and `beta`. + name: The name to give this op. + + Returns: + log_pdf: tensor of dtype `dtype`, the log-PDFs of `x`. + Raises: + TypeError: if `x` and `alpha` are different dtypes. + """ + with ops.op_scope([self._alpha, self._beta, x], self.name): + with ops.name_scope(name): + alpha = self._alpha + beta = self._beta + x = ops.convert_to_tensor(x) + x = control_flow_ops.with_dependencies( + [check_ops.assert_positive(x)], x) + contrib_tensor_util.assert_same_float_dtype(tensors=[x,], + dtype=self.dtype) + + return (alpha * math_ops.log(beta) + (alpha - 1) * math_ops.log(x) - + beta * x - math_ops.lgamma(self._alpha)) + + def pdf(self, x, name="pdf"): + with ops.name_scope(name): + return math_ops.exp(self.log_pdf(x, name)) + + def log_cdf(self, x, name="log_cdf"): + """Log CDF of observations `x` under these Gamma distribution(s). + + Args: + x: tensor of dtype `dtype`, must be broadcastable with `alpha` and `beta`. + name: The name to give this op. + + Returns: + log_cdf: tensor of dtype `dtype`, the log-CDFs of `x`. + """ + with ops.op_scope([self._alpha, self._beta, x], self.name): + with ops.name_scope(name): + x = ops.convert_to_tensor(x) + x = control_flow_ops.with_dependencies( + [check_ops.assert_positive(x)], x) + contrib_tensor_util.assert_same_float_dtype(tensors=[x,], + dtype=self.dtype) + # Note that igamma returns the regularized incomplete gamma function, + # which is what we want for the CDF. + return math_ops.log(math_ops.igamma(self._alpha, self._beta * x)) + + def cdf(self, x, name="cdf"): + with ops.op_scope([self._alpha, self._beta, x], self.name): + with ops.name_scope(name): + return math_ops.igamma(self._alpha, self._beta * x) + + def entropy(self, name="entropy"): + """The entropy of Gamma distribution(s). + + This is defined to be + + ```entropy = alpha - log(beta) + log(Gamma(alpha)) + + (1-alpha)digamma(alpha)``` + + where digamma(alpha) is the digamma function. + + Args: + name: The name to give this op. + + Returns: + entropy: tensor of dtype `dtype`, the entropy. + """ + with ops.op_scope([self.alpha, self._beta], self.name): + with ops.name_scope(name): + alpha = self._alpha + beta = self._beta + return (alpha - math_ops.log(beta) + math_ops.lgamma(alpha) + + (1 - alpha) * math_ops.digamma(alpha))