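BREAKING CHANGE: Use Covariance where appropriate. For vector-valued distributions (Dirichlet, DirichletMultinomial, Multinomial, MultivariateNormal), `variance()` now returns only the per-coordinate variances (the diagonal), and the new `covariance()` returns the full covariance matrix.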
BUGFIX: Fix broadcasting in DirichletMultinomial.mean. Change: 145851311
This commit is contained in:
parent
ca39ed7c07
commit
333b7ce729
@ -207,7 +207,49 @@ class DirichletMultinomialTest(test.TestCase):
|
|||||||
self.assertAllClose(mean2[class_num], 2 * mean1[class_num])
|
self.assertAllClose(mean2[class_num], 2 * mean1[class_num])
|
||||||
self.assertTupleEqual((3,), mean1.shape)
|
self.assertTupleEqual((3,), mean1.shape)
|
||||||
|
|
||||||
def testVariance(self):
|
def testCovarianceFromSampling(self):
|
||||||
|
# We will test mean, cov, var, stddev on a DirichletMultinomial constructed
|
||||||
|
# via broadcast between alpha, n.
|
||||||
|
alpha = np.array([[1., 2, 3],
|
||||||
|
[2.5, 4, 0.01]], dtype=np.float32)
|
||||||
|
# Ideally we'd be able to test broadcasting, but the multinomial sampler
|
||||||
|
# doesn't support different total counts.
|
||||||
|
n = np.float32(5)
|
||||||
|
with self.test_session() as sess:
|
||||||
|
# batch_shape=[2], event_shape=[3]
|
||||||
|
dist = ds.DirichletMultinomial(n, alpha)
|
||||||
|
x = dist.sample(int(250e3), seed=1)
|
||||||
|
sample_mean = math_ops.reduce_mean(x, 0)
|
||||||
|
x_centered = x - sample_mean[None, ...]
|
||||||
|
sample_cov = math_ops.reduce_mean(math_ops.matmul(
|
||||||
|
x_centered[..., None], x_centered[..., None, :]), 0)
|
||||||
|
sample_var = array_ops.matrix_diag_part(sample_cov)
|
||||||
|
sample_stddev = math_ops.sqrt(sample_var)
|
||||||
|
[
|
||||||
|
sample_mean_,
|
||||||
|
sample_cov_,
|
||||||
|
sample_var_,
|
||||||
|
sample_stddev_,
|
||||||
|
analytic_mean,
|
||||||
|
analytic_cov,
|
||||||
|
analytic_var,
|
||||||
|
analytic_stddev,
|
||||||
|
] = sess.run([
|
||||||
|
sample_mean,
|
||||||
|
sample_cov,
|
||||||
|
sample_var,
|
||||||
|
sample_stddev,
|
||||||
|
dist.mean(),
|
||||||
|
dist.covariance(),
|
||||||
|
dist.variance(),
|
||||||
|
dist.stddev(),
|
||||||
|
])
|
||||||
|
self.assertAllClose(sample_mean_, analytic_mean, atol=0., rtol=0.04)
|
||||||
|
self.assertAllClose(sample_cov_, analytic_cov, atol=0., rtol=0.05)
|
||||||
|
self.assertAllClose(sample_var_, analytic_var, atol=0., rtol=0.03)
|
||||||
|
self.assertAllClose(sample_stddev_, analytic_stddev, atol=0., rtol=0.02)
|
||||||
|
|
||||||
|
def testCovariance(self):
|
||||||
# Shape [2]
|
# Shape [2]
|
||||||
alpha = [1., 2]
|
alpha = [1., 2]
|
||||||
ns = [2., 3., 4., 5.]
|
ns = [2., 3., 4., 5.]
|
||||||
@ -234,13 +276,13 @@ class DirichletMultinomialTest(test.TestCase):
|
|||||||
for n in ns:
|
for n in ns:
|
||||||
# n is shape [] and alpha is shape [2].
|
# n is shape [] and alpha is shape [2].
|
||||||
dist = ds.DirichletMultinomial(n, alpha)
|
dist = ds.DirichletMultinomial(n, alpha)
|
||||||
variance = dist.variance()
|
covariance = dist.covariance()
|
||||||
expected_variance = n * (n + alpha_0) / (1 + alpha_0) * shared_matrix
|
expected_covariance = n * (n + alpha_0) / (1 + alpha_0) * shared_matrix
|
||||||
|
|
||||||
self.assertEqual((2, 2), variance.get_shape())
|
self.assertEqual((2, 2), covariance.get_shape())
|
||||||
self.assertAllClose(expected_variance, variance.eval())
|
self.assertAllClose(expected_covariance, covariance.eval())
|
||||||
|
|
||||||
def testVarianceNAlphaBroadcast(self):
|
def testCovarianceNAlphaBroadcast(self):
|
||||||
alpha_v = [1., 2, 3]
|
alpha_v = [1., 2, 3]
|
||||||
alpha_0 = 6.
|
alpha_0 = 6.
|
||||||
|
|
||||||
@ -271,14 +313,14 @@ class DirichletMultinomialTest(test.TestCase):
|
|||||||
with self.test_session():
|
with self.test_session():
|
||||||
# ns is shape [4, 1], and alpha is shape [4, 3].
|
# ns is shape [4, 1], and alpha is shape [4, 3].
|
||||||
dist = ds.DirichletMultinomial(ns, alpha)
|
dist = ds.DirichletMultinomial(ns, alpha)
|
||||||
variance = dist.variance()
|
covariance = dist.covariance()
|
||||||
expected_variance = np.expand_dims(ns * (ns + alpha_0) / (1 + alpha_0),
|
expected_covariance = shared_matrix * (
|
||||||
-1) * shared_matrix
|
ns * (ns + alpha_0) / (1 + alpha_0))[..., None]
|
||||||
|
|
||||||
self.assertEqual((4, 3, 3), variance.get_shape())
|
self.assertEqual([4, 3, 3], covariance.get_shape())
|
||||||
self.assertAllClose(expected_variance, variance.eval())
|
self.assertAllClose(expected_covariance, covariance.eval())
|
||||||
|
|
||||||
def testVarianceMultidimensional(self):
|
def testCovarianceMultidimensional(self):
|
||||||
alpha = np.random.rand(3, 5, 4).astype(np.float32)
|
alpha = np.random.rand(3, 5, 4).astype(np.float32)
|
||||||
alpha2 = np.random.rand(6, 3, 3).astype(np.float32)
|
alpha2 = np.random.rand(6, 3, 3).astype(np.float32)
|
||||||
|
|
||||||
@ -289,10 +331,10 @@ class DirichletMultinomialTest(test.TestCase):
|
|||||||
dist = ds.DirichletMultinomial(ns, alpha)
|
dist = ds.DirichletMultinomial(ns, alpha)
|
||||||
dist2 = ds.DirichletMultinomial(ns2, alpha2)
|
dist2 = ds.DirichletMultinomial(ns2, alpha2)
|
||||||
|
|
||||||
variance = dist.variance()
|
covariance = dist.covariance()
|
||||||
variance2 = dist2.variance()
|
covariance2 = dist2.covariance()
|
||||||
self.assertEqual((3, 5, 4, 4), variance.get_shape())
|
self.assertEqual([3, 5, 4, 4], covariance.get_shape())
|
||||||
self.assertEqual((6, 3, 3, 3), variance2.get_shape())
|
self.assertEqual([6, 3, 3, 3], covariance2.get_shape())
|
||||||
|
|
||||||
def testZeroCountsResultsInPmfEqualToOne(self):
|
def testZeroCountsResultsInPmfEqualToOne(self):
|
||||||
# There is only one way for zero items to be selected, and this happens with
|
# There is only one way for zero items to be selected, and this happens with
|
||||||
@ -390,7 +432,7 @@ class DirichletMultinomialTest(test.TestCase):
|
|||||||
sample_mean,
|
sample_mean,
|
||||||
sample_covariance,
|
sample_covariance,
|
||||||
dist.mean(),
|
dist.mean(),
|
||||||
dist.variance(),
|
dist.covariance(),
|
||||||
])
|
])
|
||||||
self.assertAllEqual([4, 3, 2], sample_mean.get_shape())
|
self.assertAllEqual([4, 3, 2], sample_mean.get_shape())
|
||||||
self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.15)
|
self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.15)
|
||||||
@ -417,7 +459,7 @@ class DirichletMultinomialTest(test.TestCase):
|
|||||||
sample_mean,
|
sample_mean,
|
||||||
sample_covariance,
|
sample_covariance,
|
||||||
dist.mean(),
|
dist.mean(),
|
||||||
dist.variance(),
|
dist.covariance(),
|
||||||
])
|
])
|
||||||
self.assertAllEqual([4], sample_mean.get_shape())
|
self.assertAllEqual([4], sample_mean.get_shape())
|
||||||
self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.05)
|
self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.05)
|
||||||
|
@ -21,6 +21,8 @@ from scipy import stats
|
|||||||
from tensorflow.contrib.distributions.python.ops import dirichlet as dirichlet_lib
|
from tensorflow.contrib.distributions.python.ops import dirichlet as dirichlet_lib
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@ -134,16 +136,52 @@ class DirichletTest(test.TestCase):
|
|||||||
self.assertEqual(dirichlet.mean().get_shape(), (3,))
|
self.assertEqual(dirichlet.mean().get_shape(), (3,))
|
||||||
self.assertAllClose(dirichlet.mean().eval(), expected_mean)
|
self.assertAllClose(dirichlet.mean().eval(), expected_mean)
|
||||||
|
|
||||||
def testDirichletVariance(self):
|
def testCovarianceFromSampling(self):
|
||||||
|
alpha = np.array([[1., 2, 3],
|
||||||
|
[2.5, 4, 0.01]], dtype=np.float32)
|
||||||
|
with self.test_session() as sess:
|
||||||
|
dist = dirichlet_lib.Dirichlet(alpha) # batch_shape=[2], event_shape=[3]
|
||||||
|
x = dist.sample(int(250e3), seed=1)
|
||||||
|
sample_mean = math_ops.reduce_mean(x, 0)
|
||||||
|
x_centered = x - sample_mean[None, ...]
|
||||||
|
sample_cov = math_ops.reduce_mean(math_ops.matmul(
|
||||||
|
x_centered[..., None], x_centered[..., None, :]), 0)
|
||||||
|
sample_var = array_ops.matrix_diag_part(sample_cov)
|
||||||
|
sample_stddev = math_ops.sqrt(sample_var)
|
||||||
|
[
|
||||||
|
sample_mean_,
|
||||||
|
sample_cov_,
|
||||||
|
sample_var_,
|
||||||
|
sample_stddev_,
|
||||||
|
analytic_mean,
|
||||||
|
analytic_cov,
|
||||||
|
analytic_var,
|
||||||
|
analytic_stddev,
|
||||||
|
] = sess.run([
|
||||||
|
sample_mean,
|
||||||
|
sample_cov,
|
||||||
|
sample_var,
|
||||||
|
sample_stddev,
|
||||||
|
dist.mean(),
|
||||||
|
dist.covariance(),
|
||||||
|
dist.variance(),
|
||||||
|
dist.stddev(),
|
||||||
|
])
|
||||||
|
self.assertAllClose(sample_mean_, analytic_mean, atol=0., rtol=0.04)
|
||||||
|
self.assertAllClose(sample_cov_, analytic_cov, atol=0., rtol=0.06)
|
||||||
|
self.assertAllClose(sample_var_, analytic_var, atol=0., rtol=0.03)
|
||||||
|
self.assertAllClose(sample_stddev_, analytic_stddev, atol=0., rtol=0.02)
|
||||||
|
|
||||||
|
def testDirichletCovariance(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
alpha = [1., 2, 3]
|
alpha = [1., 2, 3]
|
||||||
denominator = np.sum(alpha)**2 * (np.sum(alpha) + 1)
|
denominator = np.sum(alpha)**2 * (np.sum(alpha) + 1)
|
||||||
expected_variance = np.diag(stats.dirichlet.var(alpha))
|
expected_covariance = np.diag(stats.dirichlet.var(alpha))
|
||||||
expected_variance += [[0., -2, -3], [-2, 0, -6],
|
expected_covariance += [[0., -2, -3], [-2, 0, -6],
|
||||||
[-3, -6, 0]] / denominator
|
[-3, -6, 0]] / denominator
|
||||||
dirichlet = dirichlet_lib.Dirichlet(alpha=alpha)
|
dirichlet = dirichlet_lib.Dirichlet(alpha=alpha)
|
||||||
self.assertEqual(dirichlet.variance().get_shape(), (3, 3))
|
self.assertEqual(dirichlet.covariance().get_shape(), (3, 3))
|
||||||
self.assertAllClose(dirichlet.variance().eval(), expected_variance)
|
self.assertAllClose(dirichlet.covariance().eval(), expected_covariance)
|
||||||
|
|
||||||
def testDirichletMode(self):
|
def testDirichletMode(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
|
@ -191,18 +191,18 @@ class MultinomialTest(test.TestCase):
|
|||||||
self.assertEqual((3,), dist.mean().get_shape())
|
self.assertEqual((3,), dist.mean().get_shape())
|
||||||
self.assertAllClose(expected_means, dist.mean().eval())
|
self.assertAllClose(expected_means, dist.mean().eval())
|
||||||
|
|
||||||
def testMultinomialVariance(self):
|
def testMultinomialCovariance(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
n = 5.
|
n = 5.
|
||||||
p = [0.1, 0.2, 0.7]
|
p = [0.1, 0.2, 0.7]
|
||||||
dist = ds.Multinomial(total_count=n, probs=p)
|
dist = ds.Multinomial(total_count=n, probs=p)
|
||||||
expected_variances = [[9. / 20, -1 / 10, -7 / 20],
|
expected_covariances = [[9. / 20, -1 / 10, -7 / 20],
|
||||||
[-1 / 10, 4 / 5, -7 / 10],
|
[-1 / 10, 4 / 5, -7 / 10],
|
||||||
[-7 / 20, -7 / 10, 21 / 20]]
|
[-7 / 20, -7 / 10, 21 / 20]]
|
||||||
self.assertEqual((3, 3), dist.variance().get_shape())
|
self.assertEqual((3, 3), dist.covariance().get_shape())
|
||||||
self.assertAllClose(expected_variances, dist.variance().eval())
|
self.assertAllClose(expected_covariances, dist.covariance().eval())
|
||||||
|
|
||||||
def testMultinomialVarianceBatch(self):
|
def testMultinomialCovarianceBatch(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
# Shape [2]
|
# Shape [2]
|
||||||
n = [5.] * 2
|
n = [5.] * 2
|
||||||
@ -212,11 +212,11 @@ class MultinomialTest(test.TestCase):
|
|||||||
# Shape [2, 2]
|
# Shape [2, 2]
|
||||||
inner_var = [[9. / 20, -9 / 20], [-9 / 20, 9 / 20]]
|
inner_var = [[9. / 20, -9 / 20], [-9 / 20, 9 / 20]]
|
||||||
# Shape [4, 2, 2, 2]
|
# Shape [4, 2, 2, 2]
|
||||||
expected_variances = [[inner_var, inner_var]] * 4
|
expected_covariances = [[inner_var, inner_var]] * 4
|
||||||
self.assertEqual((4, 2, 2, 2), dist.variance().get_shape())
|
self.assertEqual((4, 2, 2, 2), dist.covariance().get_shape())
|
||||||
self.assertAllClose(expected_variances, dist.variance().eval())
|
self.assertAllClose(expected_covariances, dist.covariance().eval())
|
||||||
|
|
||||||
def testVarianceMultidimensional(self):
|
def testCovarianceMultidimensional(self):
|
||||||
# Shape [3, 5, 4]
|
# Shape [3, 5, 4]
|
||||||
p = np.random.dirichlet([.25, .25, .25, .25], [3, 5]).astype(np.float32)
|
p = np.random.dirichlet([.25, .25, .25, .25], [3, 5]).astype(np.float32)
|
||||||
# Shape [6, 3, 3]
|
# Shape [6, 3, 3]
|
||||||
@ -229,10 +229,52 @@ class MultinomialTest(test.TestCase):
|
|||||||
dist = ds.Multinomial(ns, p)
|
dist = ds.Multinomial(ns, p)
|
||||||
dist2 = ds.Multinomial(ns2, p2)
|
dist2 = ds.Multinomial(ns2, p2)
|
||||||
|
|
||||||
variance = dist.variance()
|
covariance = dist.covariance()
|
||||||
variance2 = dist2.variance()
|
covariance2 = dist2.covariance()
|
||||||
self.assertEqual((3, 5, 4, 4), variance.get_shape())
|
self.assertEqual((3, 5, 4, 4), covariance.get_shape())
|
||||||
self.assertEqual((6, 3, 3, 3), variance2.get_shape())
|
self.assertEqual((6, 3, 3, 3), covariance2.get_shape())
|
||||||
|
|
||||||
|
def testCovarianceFromSampling(self):
|
||||||
|
# We will test mean, cov, var, stddev on a Multinomial constructed
|
||||||
|
# via broadcast between theta, n.
|
||||||
|
theta = np.array([[1., 2, 3],
|
||||||
|
[2.5, 4, 0.01]], dtype=np.float32)
|
||||||
|
theta /= np.sum(theta, 1)[..., None]
|
||||||
|
# Ideally we'd be able to test broadcasting, but the multinomial sampler
|
||||||
|
# doesn't support different total counts.
|
||||||
|
n = np.float32(5)
|
||||||
|
with self.test_session() as sess:
|
||||||
|
dist = ds.Multinomial(n, theta) # batch_shape=[2], event_shape=[3]
|
||||||
|
x = dist.sample(int(250e3), seed=1)
|
||||||
|
sample_mean = math_ops.reduce_mean(x, 0)
|
||||||
|
x_centered = x - sample_mean[None, ...]
|
||||||
|
sample_cov = math_ops.reduce_mean(math_ops.matmul(
|
||||||
|
x_centered[..., None], x_centered[..., None, :]), 0)
|
||||||
|
sample_var = array_ops.matrix_diag_part(sample_cov)
|
||||||
|
sample_stddev = math_ops.sqrt(sample_var)
|
||||||
|
[
|
||||||
|
sample_mean_,
|
||||||
|
sample_cov_,
|
||||||
|
sample_var_,
|
||||||
|
sample_stddev_,
|
||||||
|
analytic_mean,
|
||||||
|
analytic_cov,
|
||||||
|
analytic_var,
|
||||||
|
analytic_stddev,
|
||||||
|
] = sess.run([
|
||||||
|
sample_mean,
|
||||||
|
sample_cov,
|
||||||
|
sample_var,
|
||||||
|
sample_stddev,
|
||||||
|
dist.mean(),
|
||||||
|
dist.covariance(),
|
||||||
|
dist.variance(),
|
||||||
|
dist.stddev(),
|
||||||
|
])
|
||||||
|
self.assertAllClose(sample_mean_, analytic_mean, atol=0., rtol=0.01)
|
||||||
|
self.assertAllClose(sample_cov_, analytic_cov, atol=0., rtol=0.01)
|
||||||
|
self.assertAllClose(sample_var_, analytic_var, atol=0., rtol=0.01)
|
||||||
|
self.assertAllClose(sample_stddev_, analytic_stddev, atol=0., rtol=0.01)
|
||||||
|
|
||||||
def testSampleUnbiasedNonScalarBatch(self):
|
def testSampleUnbiasedNonScalarBatch(self):
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
@ -255,7 +297,7 @@ class MultinomialTest(test.TestCase):
|
|||||||
sample_mean,
|
sample_mean,
|
||||||
sample_covariance,
|
sample_covariance,
|
||||||
dist.mean(),
|
dist.mean(),
|
||||||
dist.variance(),
|
dist.covariance(),
|
||||||
])
|
])
|
||||||
self.assertAllEqual([4, 3, 2], sample_mean.get_shape())
|
self.assertAllEqual([4, 3, 2], sample_mean.get_shape())
|
||||||
self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.07)
|
self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.07)
|
||||||
@ -283,7 +325,7 @@ class MultinomialTest(test.TestCase):
|
|||||||
sample_mean,
|
sample_mean,
|
||||||
sample_covariance,
|
sample_covariance,
|
||||||
dist.mean(),
|
dist.mean(),
|
||||||
dist.variance(),
|
dist.covariance(),
|
||||||
])
|
])
|
||||||
self.assertAllEqual([4], sample_mean.get_shape())
|
self.assertAllEqual([4], sample_mean.get_shape())
|
||||||
self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.07)
|
self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.07)
|
||||||
|
@ -211,18 +211,22 @@ class Dirichlet(distribution.Distribution):
|
|||||||
return entropy
|
return entropy
|
||||||
|
|
||||||
def _mean(self):
|
def _mean(self):
|
||||||
return self.alpha / array_ops.expand_dims(self.alpha_sum, -1)
|
return self.alpha / self.alpha_sum[..., None]
|
||||||
|
|
||||||
|
def _covariance(self):
|
||||||
|
x = self._variance_scale_term() * self._mean()
|
||||||
|
return array_ops.matrix_set_diag(
|
||||||
|
-math_ops.matmul(x[..., None], x[..., None, :]), # outer prod
|
||||||
|
self._variance())
|
||||||
|
|
||||||
def _variance(self):
|
def _variance(self):
|
||||||
scale = self.alpha_sum * math_ops.sqrt(1. + self.alpha_sum)
|
scale = self._variance_scale_term()
|
||||||
alpha = self.alpha / scale
|
x = scale * self._mean()
|
||||||
outer_prod = -math_ops.matmul(
|
return x * (scale - x)
|
||||||
array_ops.expand_dims(
|
|
||||||
alpha, dim=-1), # column
|
def _variance_scale_term(self):
|
||||||
array_ops.expand_dims(
|
"""Helper to `_covariance` and `_variance` which computes a shared scale."""
|
||||||
alpha, dim=-2)) # row
|
return math_ops.rsqrt(1. + self.alpha_sum[..., None])
|
||||||
return array_ops.matrix_set_diag(outer_prod,
|
|
||||||
alpha * (self.alpha_sum / scale - alpha))
|
|
||||||
|
|
||||||
@distribution_util.AppendDocstring(
|
@distribution_util.AppendDocstring(
|
||||||
"""Note that the mode for the Dirichlet distribution is only defined
|
"""Note that the mode for the Dirichlet distribution is only defined
|
||||||
|
@ -252,11 +252,10 @@ class DirichletMultinomial(distribution.Distribution):
|
|||||||
return math_ops.exp(self._log_prob(counts))
|
return math_ops.exp(self._log_prob(counts))
|
||||||
|
|
||||||
def _mean(self):
|
def _mean(self):
|
||||||
normalized_alpha = self.alpha / array_ops.expand_dims(self.alpha_sum, -1)
|
return self.n * (self.alpha / self.alpha_sum[..., None])
|
||||||
return self.n[..., None] * normalized_alpha
|
|
||||||
|
|
||||||
@distribution_util.AppendDocstring(
|
@distribution_util.AppendDocstring(
|
||||||
"""The variance for each batch member is defined as the following:
|
"""The covariance for each batch member is defined as the following:
|
||||||
|
|
||||||
```
|
```
|
||||||
Var(X_j) = n * alpha_j / alpha_0 * (1 - alpha_j / alpha_0) *
|
Var(X_j) = n * alpha_j / alpha_0 * (1 - alpha_j / alpha_0) *
|
||||||
@ -272,18 +271,23 @@ class DirichletMultinomial(distribution.Distribution):
|
|||||||
(n + alpha_0) / (1 + alpha_0)
|
(n + alpha_0) / (1 + alpha_0)
|
||||||
```
|
```
|
||||||
""")
|
""")
|
||||||
|
def _covariance(self):
|
||||||
|
x = self._variance_scale_term() * self._mean()
|
||||||
|
return array_ops.matrix_set_diag(
|
||||||
|
-math_ops.matmul(x[..., None], x[..., None, :]), # outer prod
|
||||||
|
self._variance())
|
||||||
|
|
||||||
def _variance(self):
|
def _variance(self):
|
||||||
alpha_sum = array_ops.expand_dims(self.alpha_sum, -1)
|
scale = self._variance_scale_term()
|
||||||
normalized_alpha = self.alpha / alpha_sum
|
x = scale * self._mean()
|
||||||
variance = -math_ops.matmul(
|
return x * (self.n * scale - x)
|
||||||
array_ops.expand_dims(normalized_alpha, -1),
|
|
||||||
array_ops.expand_dims(normalized_alpha, -2))
|
def _variance_scale_term(self):
|
||||||
variance = array_ops.matrix_set_diag(variance, normalized_alpha *
|
"""Helper to `_covariance` and `_variance` which computes a shared scale."""
|
||||||
(1. - normalized_alpha))
|
# We must take care to expand back the last dim whenever we use the
|
||||||
shared_factor = (self.n * (alpha_sum + self.n) /
|
# alpha_sum.
|
||||||
(alpha_sum + 1) * array_ops.ones_like(self.alpha))
|
c0 = self.alpha_sum[..., None]
|
||||||
variance *= array_ops.expand_dims(shared_factor, -1)
|
return math_ops.sqrt((1. + c0 / self.n) / (1. + c0))
|
||||||
return variance
|
|
||||||
|
|
||||||
def _assert_valid_counts(self, counts):
|
def _assert_valid_counts(self, counts):
|
||||||
"""Check counts for proper shape, values, then return tensor version."""
|
"""Check counts for proper shape, values, then return tensor version."""
|
||||||
|
@ -40,7 +40,8 @@ from tensorflow.python.ops import math_ops
|
|||||||
_DISTRIBUTION_PUBLIC_METHOD_WRAPPERS = [
|
_DISTRIBUTION_PUBLIC_METHOD_WRAPPERS = [
|
||||||
"batch_shape", "get_batch_shape", "event_shape", "get_event_shape",
|
"batch_shape", "get_batch_shape", "event_shape", "get_event_shape",
|
||||||
"sample", "log_prob", "prob", "log_cdf", "cdf", "log_survival_function",
|
"sample", "log_prob", "prob", "log_cdf", "cdf", "log_survival_function",
|
||||||
"survival_function", "entropy", "mean", "variance", "stddev", "mode"]
|
"survival_function", "entropy", "mean", "variance", "stddev", "mode",
|
||||||
|
"covariance"]
|
||||||
|
|
||||||
|
|
||||||
@six.add_metaclass(abc.ABCMeta)
|
@six.add_metaclass(abc.ABCMeta)
|
||||||
@ -877,7 +878,24 @@ class Distribution(_BaseDistribution):
|
|||||||
raise NotImplementedError("variance is not implemented")
|
raise NotImplementedError("variance is not implemented")
|
||||||
|
|
||||||
def variance(self, name="variance"):
|
def variance(self, name="variance"):
|
||||||
"""Variance."""
|
"""Variance.
|
||||||
|
|
||||||
|
Variance is defined as,
|
||||||
|
|
||||||
|
```none
|
||||||
|
Var = E[(X - E[X])**2]
|
||||||
|
```
|
||||||
|
|
||||||
|
where `X` is the random variable associated with this distribution, `E`
|
||||||
|
denotes expectation, and `Var.shape = batch_shape + event_shape`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The name to give this op.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
variance: Floating-point `Tensor` with shape identical to
|
||||||
|
`batch_shape + event_shape`, i.e., the same shape as `self.mean()`.
|
||||||
|
"""
|
||||||
with self._name_scope(name):
|
with self._name_scope(name):
|
||||||
try:
|
try:
|
||||||
return self._variance()
|
return self._variance()
|
||||||
@ -891,7 +909,25 @@ class Distribution(_BaseDistribution):
|
|||||||
raise NotImplementedError("stddev is not implemented")
|
raise NotImplementedError("stddev is not implemented")
|
||||||
|
|
||||||
def stddev(self, name="stddev"):
|
def stddev(self, name="stddev"):
|
||||||
"""Standard deviation."""
|
"""Standard deviation.
|
||||||
|
|
||||||
|
Standard deviation is defined as,
|
||||||
|
|
||||||
|
```none
|
||||||
|
stddev = E[(X - E[X])**2]**0.5
|
||||||
|
```
|
||||||
|
|
||||||
|
where `X` is the random variable associated with this distribution, `E`
|
||||||
|
denotes expectation, and `stddev.shape = batch_shape + event_shape`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The name to give this op.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
stddev: Floating-point `Tensor` with shape identical to
|
||||||
|
`batch_shape + event_shape`, i.e., the same shape as `self.mean()`.
|
||||||
|
"""
|
||||||
|
|
||||||
with self._name_scope(name):
|
with self._name_scope(name):
|
||||||
try:
|
try:
|
||||||
return self._stddev()
|
return self._stddev()
|
||||||
@ -901,6 +937,48 @@ class Distribution(_BaseDistribution):
|
|||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
raise original_exception
|
raise original_exception
|
||||||
|
|
||||||
|
def _covariance(self):
|
||||||
|
raise NotImplementedError("covariance is not implemented")
|
||||||
|
|
||||||
|
def covariance(self, name="covariance"):
|
||||||
|
"""Covariance.
|
||||||
|
|
||||||
|
Covariance is (possibly) defined only for non-scalar-event distributions.
|
||||||
|
|
||||||
|
For example, for a length-`k`, vector-valued distribution, it is calculated
|
||||||
|
as,
|
||||||
|
|
||||||
|
```none
|
||||||
|
Cov[i, j] = Covariance(X_i, X_j) = E[(X_i - E[X_i]) (X_j - E[X_j])]
|
||||||
|
```
|
||||||
|
|
||||||
|
where `Cov` is a (batch of) `k x k` matrix, `0 <= (i, j) < k`, and `E`
|
||||||
|
denotes expectation.
|
||||||
|
|
||||||
|
Alternatively, for non-vector, multivariate distributions (e.g.,
|
||||||
|
matrix-valued, Wishart), `Covariance` shall return a (batch of) matrices
|
||||||
|
under some vectorization of the events, i.e.,
|
||||||
|
|
||||||
|
```none
|
||||||
|
Cov[i, j] = Covariance(Vec(X)_i, Vec(X)_j) = [as above]
|
||||||
|
```
|
||||||
|
|
||||||
|
where `Cov` is a (batch of) `k' x k'` matrices,
|
||||||
|
`0 <= (i, j) < k' = reduce_prod(event_shape)`, and `Vec` is some function
|
||||||
|
mapping indices of this distribution's event dimensions to indices of a
|
||||||
|
length-`k'` vector.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The name to give this op.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
covariance: Floating-point `Tensor` with shape `[B1, ..., Bn, k', k']`
|
||||||
|
where the first `n` dimensions are batch coordinates and
|
||||||
|
`k' = reduce_prod(self.event_shape)`.
|
||||||
|
"""
|
||||||
|
with self._name_scope(name):
|
||||||
|
return self._covariance()
|
||||||
|
|
||||||
def _mode(self):
|
def _mode(self):
|
||||||
raise NotImplementedError("mode is not implemented")
|
raise NotImplementedError("mode is not implemented")
|
||||||
|
|
||||||
|
@ -195,9 +195,7 @@ class Multinomial(distribution.Distribution):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def probs(self):
|
def probs(self):
|
||||||
"""Vector of probabilities summing to one.
|
"""Probability of drawing a `1` in that coordinate."""
|
||||||
|
|
||||||
Each element is the probability of drawing that coordinate."""
|
|
||||||
return self._probs
|
return self._probs
|
||||||
|
|
||||||
def _batch_shape(self):
|
def _batch_shape(self):
|
||||||
@ -256,11 +254,15 @@ class Multinomial(distribution.Distribution):
|
|||||||
def _mean(self):
|
def _mean(self):
|
||||||
return array_ops.identity(self._mean_val)
|
return array_ops.identity(self._mean_val)
|
||||||
|
|
||||||
def _variance(self):
|
def _covariance(self):
|
||||||
p = self.probs * array_ops.ones_like(self.total_count)[..., None]
|
p = self.probs * array_ops.ones_like(self.total_count)[..., None]
|
||||||
return array_ops.matrix_set_diag(
|
return array_ops.matrix_set_diag(
|
||||||
-math_ops.matmul(self._mean_val[..., None], p[..., None, :]),
|
-math_ops.matmul(self._mean_val[..., None], p[..., None, :]),
|
||||||
self._mean_val - self._mean_val * p)
|
self._variance())
|
||||||
|
|
||||||
|
def _variance(self):
|
||||||
|
p = self.probs * array_ops.ones_like(self.total_count)[..., None]
|
||||||
|
return self._mean_val - self._mean_val * p
|
||||||
|
|
||||||
def _maybe_assert_valid_total_count(self, total_count, validate_args):
|
def _maybe_assert_valid_total_count(self, total_count, validate_args):
|
||||||
if not validate_args:
|
if not validate_args:
|
||||||
|
@ -300,9 +300,12 @@ class _MultivariateNormalOperatorPD(distribution.Distribution):
|
|||||||
def _mean(self):
|
def _mean(self):
|
||||||
return array_ops.identity(self._mu)
|
return array_ops.identity(self._mu)
|
||||||
|
|
||||||
def _variance(self):
|
def _covariance(self):
|
||||||
return self.sigma
|
return self.sigma
|
||||||
|
|
||||||
|
def _variance(self):
|
||||||
|
return array_ops.matrix_diag_part(self.sigma)
|
||||||
|
|
||||||
def _mode(self):
|
def _mode(self):
|
||||||
return array_ops.identity(self._mu)
|
return array_ops.identity(self._mu)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user