Rename tf.contrib.distributions.Gaussian -> tf.contrib.distributions.Normal

Change: 123260073
Eugene Brevdo 2016-05-25 14:25:35 -08:00 committed by TensorFlower Gardener
parent 8e9f29598a
commit 1db1272f7d
6 changed files with 98 additions and 98 deletions
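For readers hitting this rename in their own code, here is a minimal before/after sketch (illustrative only; it assumes the TF 0.9-era `tf.contrib.distributions` API shown in this diff):

```python
import tensorflow as tf

# Before this commit: the class was exposed as Gaussian.
# dist = tf.contrib.distributions.Gaussian(mu=0., sigma=3.)

# After this commit: same constructor and methods, new name.
dist = tf.contrib.distributions.Normal(mu=0., sigma=3.)
cdf_at_one = dist.cdf(1.)  # evaluate the CDF at 1, as in the class docstring
```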


@@ -55,9 +55,9 @@ cuda_py_tests(
)
cuda_py_tests(
name = "gaussian_test",
name = "normal_test",
size = "small",
srcs = ["python/kernel_tests/gaussian_test.py"],
srcs = ["python/kernel_tests/normal_test.py"],
additional_deps = [
":distributions_py",
"//tensorflow/python:framework_test_lib",
@@ -98,9 +98,9 @@ cuda_py_tests(
)
cuda_py_tests(
name = "gaussian_conjugate_posteriors_test",
name = "normal_conjugate_posteriors_test",
size = "small",
srcs = ["python/kernel_tests/gaussian_conjugate_posteriors_test.py"],
srcs = ["python/kernel_tests/normal_conjugate_posteriors_test.py"],
additional_deps = [
":distributions_py",
"//tensorflow/python:platform_test",


@@ -30,7 +30,7 @@ initialized with parameters that define the distributions.
@@Chi2
@@Exponential
@@Gamma
@@Gaussian
@@Normal
@@StudentT
@@Uniform
@@ -44,10 +44,10 @@ initialized with parameters that define the distributions.
Functions that transform conjugate prior/likelihood pairs to distributions
representing the posterior or posterior predictive.
### Gaussian likelihood with conjugate prior.
### Normal likelihood with conjugate prior.
@@gaussian_conjugates_known_sigma_posterior
@@gaussian_congugates_known_sigma_predictive
@@normal_conjugates_known_sigma_posterior
@@normal_congugates_known_sigma_predictive
"""
from __future__ import absolute_import
from __future__ import division
@@ -60,8 +60,8 @@ from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import *
from tensorflow.contrib.distributions.python.ops.distribution import *
from tensorflow.contrib.distributions.python.ops.exponential import *
from tensorflow.contrib.distributions.python.ops.gamma import *
from tensorflow.contrib.distributions.python.ops.gaussian import *
from tensorflow.contrib.distributions.python.ops.gaussian_conjugate_posteriors import *
from tensorflow.contrib.distributions.python.ops.mvn import *
from tensorflow.contrib.distributions.python.ops.normal import *
from tensorflow.contrib.distributions.python.ops.normal_conjugate_posteriors import *
from tensorflow.contrib.distributions.python.ops.student_t import *
from tensorflow.contrib.distributions.python.ops.uniform import *


@@ -25,9 +25,9 @@ import tensorflow as tf
distributions = tf.contrib.distributions
class GaussianTest(tf.test.TestCase):
class NormalTest(tf.test.TestCase):
def testGaussianConjugateKnownSigmaPosterior(self):
def testNormalConjugateKnownSigmaPosterior(self):
with tf.Session():
mu0 = tf.constant([3.0])
sigma0 = tf.constant([math.sqrt(10.0)])
@@ -35,16 +35,16 @@ class GaussianTest(tf.test.TestCase):
x = tf.constant([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0])
s = tf.reduce_sum(x)
n = tf.size(x)
prior = distributions.Gaussian(mu=mu0, sigma=sigma0)
posterior = distributions.gaussian_conjugates_known_sigma_posterior(
prior = distributions.Normal(mu=mu0, sigma=sigma0)
posterior = distributions.normal_conjugates_known_sigma_posterior(
prior=prior, sigma=sigma, s=s, n=n)
# Smoke test
self.assertTrue(isinstance(posterior, distributions.Gaussian))
self.assertTrue(isinstance(posterior, distributions.Normal))
posterior_log_pdf = posterior.log_pdf(x).eval()
self.assertEqual(posterior_log_pdf.shape, (6,))
def testGaussianConjugateKnownSigmaPosteriorND(self):
def testNormalConjugateKnownSigmaPosteriorND(self):
with tf.Session():
batch_size = 6
mu0 = tf.constant([[3.0, -3.0]] * batch_size)
@@ -54,16 +54,16 @@ class GaussianTest(tf.test.TestCase):
tf.constant([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=tf.float32))
s = tf.reduce_sum(x)
n = tf.size(x)
prior = distributions.Gaussian(mu=mu0, sigma=sigma0)
posterior = distributions.gaussian_conjugates_known_sigma_posterior(
prior = distributions.Normal(mu=mu0, sigma=sigma0)
posterior = distributions.normal_conjugates_known_sigma_posterior(
prior=prior, sigma=sigma, s=s, n=n)
# Smoke test
self.assertTrue(isinstance(posterior, distributions.Gaussian))
self.assertTrue(isinstance(posterior, distributions.Normal))
posterior_log_pdf = posterior.log_pdf(x).eval()
self.assertEqual(posterior_log_pdf.shape, (6, 2))
def testGaussianConjugateKnownSigmaNDPosteriorND(self):
def testNormalConjugateKnownSigmaNDPosteriorND(self):
with tf.Session():
batch_size = 6
mu0 = tf.constant([[3.0, -3.0]] * batch_size)
@@ -75,19 +75,19 @@ class GaussianTest(tf.test.TestCase):
s = tf.reduce_sum(x, reduction_indices=[1])
x = tf.transpose(x) # Reshape to shape (6, 2)
n = tf.constant([6] * 2)
prior = distributions.Gaussian(mu=mu0, sigma=sigma0)
posterior = distributions.gaussian_conjugates_known_sigma_posterior(
prior = distributions.Normal(mu=mu0, sigma=sigma0)
posterior = distributions.normal_conjugates_known_sigma_posterior(
prior=prior, sigma=sigma, s=s, n=n)
# Smoke test
self.assertTrue(isinstance(posterior, distributions.Gaussian))
self.assertTrue(isinstance(posterior, distributions.Normal))
# Calculate log_pdf under the 2 models
posterior_log_pdf = posterior.log_pdf(x)
self.assertEqual(posterior_log_pdf.get_shape(), (6, 2))
self.assertEqual(posterior_log_pdf.eval().shape, (6, 2))
def testGaussianConjugateKnownSigmaPredictive(self):
def testNormalConjugateKnownSigmaPredictive(self):
with tf.Session():
batch_size = 6
mu0 = tf.constant([3.0] * batch_size)
@@ -96,12 +96,12 @@ class GaussianTest(tf.test.TestCase):
x = tf.constant([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0])
s = tf.reduce_sum(x)
n = tf.size(x)
prior = distributions.Gaussian(mu=mu0, sigma=sigma0)
predictive = distributions.gaussian_congugates_known_sigma_predictive(
prior = distributions.Normal(mu=mu0, sigma=sigma0)
predictive = distributions.normal_congugates_known_sigma_predictive(
prior=prior, sigma=sigma, s=s, n=n)
# Smoke test
self.assertTrue(isinstance(predictive, distributions.Gaussian))
self.assertTrue(isinstance(predictive, distributions.Normal))
predictive_log_pdf = predictive.log_pdf(x).eval()
self.assertEqual(predictive_log_pdf.shape, (6,))


@@ -24,9 +24,9 @@ import numpy as np
import tensorflow as tf
class GaussianTest(tf.test.TestCase):
class NormalTest(tf.test.TestCase):
def testGaussianLogPDF(self):
def testNormalLogPDF(self):
with tf.Session():
batch_size = 6
mu = tf.constant([3.0] * batch_size)
@@ -34,18 +34,18 @@ class GaussianTest(tf.test.TestCase):
mu_v = 3.0
sigma_v = np.sqrt(10.0)
x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32)
gaussian = tf.contrib.distributions.Gaussian(mu=mu, sigma=sigma)
normal = tf.contrib.distributions.Normal(mu=mu, sigma=sigma)
expected_log_pdf = np.log(
1 / np.sqrt(2 * np.pi) / sigma_v
* np.exp(-1.0 / (2 * sigma_v**2) * (x - mu_v)**2))
log_pdf = gaussian.log_pdf(x)
log_pdf = normal.log_pdf(x)
self.assertAllClose(expected_log_pdf, log_pdf.eval())
pdf = gaussian.pdf(x)
pdf = normal.pdf(x)
self.assertAllClose(np.exp(expected_log_pdf), pdf.eval())
def testGaussianLogPDFMultidimensional(self):
def testNormalLogPDFMultidimensional(self):
with tf.Session():
batch_size = 6
mu = tf.constant([[3.0, -3.0]] * batch_size)
@@ -53,22 +53,22 @@ class GaussianTest(tf.test.TestCase):
mu_v = np.array([3.0, -3.0])
sigma_v = np.array([np.sqrt(10.0), np.sqrt(15.0)])
x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T
gaussian = tf.contrib.distributions.Gaussian(mu=mu, sigma=sigma)
normal = tf.contrib.distributions.Normal(mu=mu, sigma=sigma)
expected_log_pdf = np.log(
1 / np.sqrt(2 * np.pi) / sigma_v
* np.exp(-1.0 / (2 * sigma_v**2) * (x - mu_v)**2))
log_pdf = gaussian.log_pdf(x)
log_pdf = normal.log_pdf(x)
log_pdf_values = log_pdf.eval()
self.assertEqual(log_pdf.get_shape(), (6, 2))
self.assertAllClose(expected_log_pdf, log_pdf_values)
pdf = gaussian.pdf(x)
pdf = normal.pdf(x)
pdf_values = pdf.eval()
self.assertEqual(pdf.get_shape(), (6, 2))
self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
def testGaussianCDF(self):
def testNormalCDF(self):
with tf.Session():
batch_size = 6
mu = tf.constant([3.0] * batch_size)
@@ -77,40 +77,40 @@ class GaussianTest(tf.test.TestCase):
sigma_v = np.sqrt(10.0)
x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32)
gaussian = tf.contrib.distributions.Gaussian(mu=mu, sigma=sigma)
normal = tf.contrib.distributions.Normal(mu=mu, sigma=sigma)
erf_fn = np.vectorize(math.erf)
# From Wikipedia
expected_cdf = 0.5 * (1.0 + erf_fn((x - mu_v)/(sigma_v*np.sqrt(2))))
cdf = gaussian.cdf(x)
cdf = normal.cdf(x)
self.assertAllClose(expected_cdf, cdf.eval())
def testGaussianEntropy(self):
def testNormalEntropy(self):
with tf.Session():
mu_v = np.array([1.0, 1.0, 1.0])
sigma_v = np.array([[1.0, 2.0, 3.0]]).T
gaussian = tf.contrib.distributions.Gaussian(mu=mu_v, sigma=sigma_v)
normal = tf.contrib.distributions.Normal(mu=mu_v, sigma=sigma_v)
sigma_broadcast = mu_v * sigma_v
expected_entropy = 0.5 * np.log(2*np.pi*np.exp(1)*sigma_broadcast**2)
self.assertAllClose(expected_entropy, gaussian.entropy().eval())
self.assertAllClose(expected_entropy, normal.entropy().eval())
def testGaussianSample(self):
def testNormalSample(self):
with tf.Session():
mu = tf.constant(3.0)
sigma = tf.constant(math.sqrt(10.0))
mu_v = 3.0
sigma_v = np.sqrt(10.0)
n = tf.constant(100000)
gaussian = tf.contrib.distributions.Gaussian(mu=mu, sigma=sigma)
samples = gaussian.sample(n, seed=137)
normal = tf.contrib.distributions.Normal(mu=mu, sigma=sigma)
samples = normal.sample(n, seed=137)
sample_values = samples.eval()
self.assertEqual(sample_values.shape, (100000,))
self.assertAllClose(sample_values.mean(), mu_v, atol=1e-2)
self.assertAllClose(sample_values.std(), sigma_v, atol=1e-1)
def testGaussianSampleMultiDimensional(self):
def testNormalSampleMultiDimensional(self):
with tf.Session():
batch_size = 2
mu = tf.constant([[3.0, -3.0]] * batch_size)
@@ -118,8 +118,8 @@ class GaussianTest(tf.test.TestCase):
mu_v = [3.0, -3.0]
sigma_v = [np.sqrt(10.0), np.sqrt(15.0)]
n = tf.constant(100000)
gaussian = tf.contrib.distributions.Gaussian(mu=mu, sigma=sigma)
samples = gaussian.sample(n, seed=137)
normal = tf.contrib.distributions.Normal(mu=mu, sigma=sigma)
samples = normal.sample(n, seed=137)
sample_values = samples.eval()
self.assertEqual(samples.get_shape(), (100000, batch_size, 2))
self.assertAllClose(sample_values[:, 0, 0].mean(), mu_v[0], atol=1e-2)
@@ -129,13 +129,13 @@ class GaussianTest(tf.test.TestCase):
def testNegativeSigmaFails(self):
with tf.Session():
gaussian = tf.contrib.distributions.Gaussian(
normal = tf.contrib.distributions.Normal(
mu=[1.],
sigma=[-5.],
name='G')
with self.assertRaisesOpError(
r'should contain only positive values'):
gaussian.mean.eval()
normal.mean.eval()
if __name__ == '__main__':
tf.test.main()


@@ -38,8 +38,8 @@ def _assert_all_positive(x):
["Tensor %s should contain only positive values: " % x.name, x])
class Gaussian(object):
"""The scalar Gaussian distribution with mean and stddev parameters mu, sigma.
class Normal(object):
"""The scalar Normal distribution with mean and stddev parameters mu, sigma.
#### Mathematical details
@@ -52,15 +52,15 @@ class Gaussian(object):
Examples of initialization of one or a batch of distributions.
```python
# Define a single scalar Gaussian distribution.
dist = tf.contrib.distributions.Gaussian(mu=0, sigma=3)
# Define a single scalar Normal distribution.
dist = tf.contrib.distributions.Normal(mu=0, sigma=3)
# Evaluate the cdf at 1, returning a scalar.
dist.cdf(1)
# Define a batch of two scalar valued Gaussians.
# Define a batch of two scalar valued Normals.
# The first has mean 1 and standard deviation 11, the second 2 and 22.
dist = tf.contrib.distributions.Gaussian(mu=[1, 2.], sigma=[11, 22.])
dist = tf.contrib.distributions.Normal(mu=[1, 2.], sigma=[11, 22.])
# Evaluate the pdf of the first distribution on 0, and the second on 1.5,
# returning a length two tensor.
@@ -73,9 +73,9 @@ class Gaussian(object):
Arguments are broadcast when possible.
```python
# Define a batch of two scalar valued Gaussians.
# Define a batch of two scalar valued Normals.
# Both have mean 1, but different standard deviations.
dist = tf.contrib.distributions.Gaussian(mu=1, sigma=[11, 22.])
dist = tf.contrib.distributions.Normal(mu=1, sigma=[11, 22.])
# Evaluate the pdf of both distributions on the same point, 3.0,
# returning a length 2 tensor.
@@ -85,7 +85,7 @@ class Gaussian(object):
"""
def __init__(self, mu, sigma, name=None):
"""Construct Gaussian distributions with mean and stddev `mu` and `sigma`.
"""Construct Normal distributions with mean and stddev `mu` and `sigma`.
The parameters `mu` and `sigma` must be shaped in a way that supports
broadcasting (e.g. `mu + sigma` is a valid operation).
@@ -99,7 +99,7 @@ class Gaussian(object):
Raises:
TypeError: if mu and sigma are different dtypes.
"""
with ops.op_scope([mu, sigma], name, "Gaussian"):
with ops.op_scope([mu, sigma], name, "Normal"):
mu = ops.convert_to_tensor(mu)
sigma = ops.convert_to_tensor(sigma)
with ops.control_dependencies([_assert_all_positive(sigma)]):
@@ -125,7 +125,7 @@ class Gaussian(object):
return self._mu * array_ops.ones_like(self._sigma)
def log_pdf(self, x, name=None):
"""Log pdf of observations in `x` under these Gaussian distribution(s).
"""Log pdf of observations in `x` under these Normal distribution(s).
Args:
x: tensor of dtype `dtype`, must be broadcastable with `mu` and `sigma`.
@@ -134,7 +134,7 @@ class Gaussian(object):
Returns:
log_pdf: tensor of dtype `dtype`, the log-PDFs of `x`.
"""
with ops.op_scope([self._mu, self._sigma, x], name, "GaussianLogPdf"):
with ops.op_scope([self._mu, self._sigma, x], name, "NormalLogPdf"):
x = ops.convert_to_tensor(x)
if x.dtype != self.dtype:
raise TypeError("Input x dtype does not match dtype: %s vs. %s"
@@ -144,7 +144,7 @@ class Gaussian(object):
-0.5*math_ops.square((x - self._mu) / self._sigma))
def cdf(self, x, name=None):
"""CDF of observations in `x` under these Gaussian distribution(s).
"""CDF of observations in `x` under these Normal distribution(s).
Args:
x: tensor of dtype `dtype`, must be broadcastable with `mu` and `sigma`.
@@ -153,7 +153,7 @@ class Gaussian(object):
Returns:
cdf: tensor of dtype `dtype`, the CDFs of `x`.
"""
with ops.op_scope([self._mu, self._sigma, x], name, "GaussianCdf"):
with ops.op_scope([self._mu, self._sigma, x], name, "NormalCdf"):
x = ops.convert_to_tensor(x)
if x.dtype != self.dtype:
raise TypeError("Input x dtype does not match dtype: %s vs. %s"
@@ -162,7 +162,7 @@ class Gaussian(object):
1.0/(math.sqrt(2.0) * self._sigma)*(x - self._mu)))
def log_cdf(self, x, name=None):
"""Log CDF of observations `x` under these Gaussian distribution(s).
"""Log CDF of observations `x` under these Normal distribution(s).
Args:
x: tensor of dtype `dtype`, must be broadcastable with `mu` and `sigma`.
@@ -171,11 +171,11 @@ class Gaussian(object):
Returns:
log_cdf: tensor of dtype `dtype`, the log-CDFs of `x`.
"""
with ops.op_scope([self._mu, self._sigma, x], name, "GaussianLogCdf"):
with ops.op_scope([self._mu, self._sigma, x], name, "NormalLogCdf"):
return math_ops.log(self.cdf(x))
def pdf(self, x, name=None):
"""The PDF of observations in `x` under these Gaussian distribution(s).
"""The PDF of observations in `x` under these Normal distribution(s).
Args:
x: tensor of dtype `dtype`, must be broadcastable with `mu` and `sigma`.
@@ -184,11 +184,11 @@ class Gaussian(object):
Returns:
pdf: tensor of dtype `dtype`, the pdf values of `x`.
"""
with ops.op_scope([self._mu, self._sigma, x], name, "GaussianPdf"):
with ops.op_scope([self._mu, self._sigma, x], name, "NormalPdf"):
return math_ops.exp(self.log_pdf(x))
def entropy(self, name=None):
"""The entropy of Gaussian distribution(s).
"""The entropy of Normal distribution(s).
Args:
name: The name to give this op.
@@ -196,7 +196,7 @@ class Gaussian(object):
Returns:
entropy: tensor of dtype `dtype`, the entropy.
"""
with ops.op_scope([self._mu, self._sigma], name, "GaussianEntropy"):
with ops.op_scope([self._mu, self._sigma], name, "NormalEntropy"):
two_pi_e1 = constant_op.constant(
2 * math.pi * math.exp(1), dtype=self.dtype)
# Use broadcasting rules to calculate the full broadcast sigma.
@@ -204,7 +204,7 @@ class Gaussian(object):
return 0.5 * math_ops.log(two_pi_e1 * math_ops.square(sigma))
def sample(self, n, seed=None, name=None):
"""Sample `n` observations from the Gaussian Distributions.
"""Sample `n` observations from the Normal Distributions.
Args:
n: `Scalar`, type int32, the number of observations to sample.
@@ -215,7 +215,7 @@ class Gaussian(object):
samples: `[n, ...]`, a `Tensor` of `n` samples for each
of the distributions determined by broadcasting the hyperparameters.
"""
with ops.op_scope([self._mu, self._sigma, n], name, "GaussianSample"):
with ops.op_scope([self._mu, self._sigma, n], name, "NormalSample"):
broadcast_shape = (self._mu + self._sigma).get_shape()
n = ops.convert_to_tensor(n)
shape = array_ops.concat(

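As a side note (not part of the commit), the closed forms the renamed class implements are easy to sanity-check in NumPy; the `scipy.stats.norm` reference below is my own assumption, not something this diff depends on:

```python
import numpy as np
from scipy.stats import norm  # reference only; an assumption, not used by the commit

mu, sigma = 3.0, np.sqrt(10.0)
x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0])

# log N(x | mu, sigma^2) = -log(sigma * sqrt(2*pi)) - (x - mu)^2 / (2 * sigma^2),
# matching the expected_log_pdf used in normal_test above.
log_pdf = -np.log(sigma * np.sqrt(2.0 * np.pi)) - (x - mu) ** 2 / (2.0 * sigma ** 2)
assert np.allclose(log_pdf, norm(mu, sigma).logpdf(x))

# Entropy of a scalar Normal: 0.5 * log(2 * pi * e * sigma^2), matching NormalEntropy above.
entropy = 0.5 * np.log(2.0 * np.pi * np.e * sigma ** 2)
assert np.allclose(entropy, norm(mu, sigma).entropy())
```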

@@ -12,32 +12,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Gaussian distribution: conjugate posterior closed form calculations."""
"""The Normal distribution: conjugate posterior closed form calculations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distributions.python.ops.gaussian import Gaussian # pylint: disable=line-too-long
from tensorflow.contrib.distributions.python.ops.normal import Normal # pylint: disable=line-too-long
from tensorflow.python.ops import math_ops
def gaussian_conjugates_known_sigma_posterior(prior, sigma, s, n):
"""Posterior Gaussian distribution with conjugate prior on the mean.
def normal_conjugates_known_sigma_posterior(prior, sigma, s, n):
"""Posterior Normal distribution with conjugate prior on the mean.
This model assumes that `n` observations (with sum `s`) come from a
Gaussian with unknown mean `mu` (described by the Gaussian `prior`)
Normal with unknown mean `mu` (described by the Normal `prior`)
and known variance `sigma^2`. The "known sigma posterior" is
the distribution of the unknown `mu`.
Accepts a prior Gaussian distribution object, having parameters
Accepts a prior Normal distribution object, having parameters
`mu0` and `sigma0`, as well as known `sigma` values of the predictive
distribution(s) (also assumed Gaussian),
distribution(s) (also assumed Normal),
and statistical estimates `s` (the sum(s) of the observations) and
`n` (the number(s) of observations).
Returns a posterior (also Gaussian) distribution object, with parameters
Returns a posterior (also Normal) distribution object, with parameters
`(mu', sigma'^2)`, where:
```
@@ -50,7 +50,7 @@ def gaussian_conjugates_known_sigma_posterior(prior, sigma, s, n):
will broadcast in the case of multidimensional sets of parameters.
Args:
prior: `Gaussian` object of type `dtype`:
prior: `Normal` object of type `dtype`:
the prior distribution having parameters `(mu0, sigma0)`.
sigma: tensor of type `dtype`, taking values `sigma > 0`.
The known stddev parameter(s).
@@ -58,15 +58,15 @@ def gaussian_conjugates_known_sigma_posterior(prior, sigma, s, n):
n: Tensor of type `int`. The number(s) of observations.
Returns:
A new Gaussian posterior distribution object for the unknown observation
A new Normal posterior distribution object for the unknown observation
mean `mu`.
Raises:
TypeError: if dtype of `s` does not match `dtype`, or `prior` is not a
Gaussian object.
Normal object.
"""
if not isinstance(prior, Gaussian):
raise TypeError("Expected prior to be an instance of type Gaussian")
if not isinstance(prior, Normal):
raise TypeError("Expected prior to be an instance of type Normal")
if s.dtype != prior.dtype:
raise TypeError(
@@ -77,27 +77,27 @@ def gaussian_conjugates_known_sigma_posterior(prior, sigma, s, n):
sigma0_2 = math_ops.square(prior.sigma)
sigma_2 = math_ops.square(sigma)
sigmap_2 = 1.0/(1/sigma0_2 + n/sigma_2)
return Gaussian(
return Normal(
mu=(prior.mu/sigma0_2 + s/sigma_2) * sigmap_2,
sigma=math_ops.sqrt(sigmap_2))
def gaussian_congugates_known_sigma_predictive(prior, sigma, s, n):
"""Posterior predictive Gaussian distribution w. conjugate prior on the mean.
def normal_congugates_known_sigma_predictive(prior, sigma, s, n):
"""Posterior predictive Normal distribution w. conjugate prior on the mean.
This model assumes that `n` observations (with sum `s`) come from a
Gaussian with unknown mean `mu` (described by the Gaussian `prior`)
Normal with unknown mean `mu` (described by the Normal `prior`)
and known variance `sigma^2`. The "known sigma predictive"
is the distribution of new observations, conditioned on the existing
observations and our prior.
Accepts a prior Gaussian distribution object, having parameters
Accepts a prior Normal distribution object, having parameters
`mu0` and `sigma0`, as well as known `sigma` values of the predictive
distribution(s) (also assumed Gaussian),
distribution(s) (also assumed Normal),
and statistical estimates `s` (the sum(s) of the observations) and
`n` (the number(s) of observations).
Calculates the Gaussian distribution(s) `p(x | sigma^2)`:
Calculates the Normal distribution(s) `p(x | sigma^2)`:
```
p(x | sigma^2) = int N(x | mu, sigma^2) N(mu | prior.mu, prior.sigma^2) dmu
@@ -117,7 +117,7 @@ def gaussian_congugates_known_sigma_predictive(prior, sigma, s, n):
will broadcast in the case of multidimensional sets of parameters.
Args:
prior: `Gaussian` object of type `dtype`:
prior: `Normal` object of type `dtype`:
the prior distribution having parameters `(mu0, sigma0)`.
sigma: tensor of type `dtype`, taking values `sigma > 0`.
The known stddev parameter(s).
@@ -125,14 +125,14 @@ def gaussian_congugates_known_sigma_predictive(prior, sigma, s, n):
n: Tensor of type `int`. The number(s) of observations.
Returns:
A new Gaussian predictive distribution object.
A new Normal predictive distribution object.
Raises:
TypeError: if dtype of `s` does not match `dtype`, or `prior` is not a
Gaussian object.
Normal object.
"""
if not isinstance(prior, Gaussian):
raise TypeError("Expected prior to be an instance of type Gaussian")
if not isinstance(prior, Normal):
raise TypeError("Expected prior to be an instance of type Normal")
if s.dtype != prior.dtype:
raise TypeError(
@@ -143,6 +143,6 @@ def gaussian_congugates_known_sigma_predictive(prior, sigma, s, n):
sigma0_2 = math_ops.square(prior.sigma)
sigma_2 = math_ops.square(sigma)
sigmap_2 = 1.0/(1/sigma0_2 + n/sigma_2)
return Gaussian(
return Normal(
mu=(prior.mu/sigma0_2 + s/sigma_2) * sigmap_2,
sigma=math_ops.sqrt(sigmap_2 + sigma_2))
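For completeness, a small NumPy check (mine, not part of the commit) of the known-sigma conjugate update that both functions above compute; `mu0`, `sigma0`, and `x` mirror the test file, while the known `sigma` here is an arbitrary illustrative value:

```python
import numpy as np

mu0, sigma0 = 3.0, np.sqrt(10.0)  # prior N(mu0, sigma0^2) on the unknown mean
sigma = 2.0                       # known observation stddev (illustrative choice)
x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0])
s, n = x.sum(), x.size

# Posterior over mu: sigma_p^2 = 1 / (1/sigma0^2 + n/sigma^2),
#                    mu_p      = (mu0/sigma0^2 + s/sigma^2) * sigma_p^2
sigma_p2 = 1.0 / (1.0 / sigma0 ** 2 + n / sigma ** 2)
mu_p = (mu0 / sigma0 ** 2 + s / sigma ** 2) * sigma_p2

# Posterior predictive over a new observation: same mean, variance sigma_p^2 + sigma^2.
pred_mean, pred_std = mu_p, np.sqrt(sigma_p2 + sigma ** 2)
print(mu_p, np.sqrt(sigma_p2), pred_mean, pred_std)
```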