BREAKING CHANGE: Use Covariance
where appropriate.
BUGFIX: Fix broadcasting in DirichletMultinomial.mean. Change: 145851311
This commit is contained in:
parent
ca39ed7c07
commit
333b7ce729
@ -207,7 +207,49 @@ class DirichletMultinomialTest(test.TestCase):
|
||||
self.assertAllClose(mean2[class_num], 2 * mean1[class_num])
|
||||
self.assertTupleEqual((3,), mean1.shape)
|
||||
|
||||
def testVariance(self):
|
||||
def testCovarianceFromSampling(self):
|
||||
# We will test mean, cov, var, stddev on a DirichletMultinomial constructed
|
||||
# via broadcast between alpha, n.
|
||||
alpha = np.array([[1., 2, 3],
|
||||
[2.5, 4, 0.01]], dtype=np.float32)
|
||||
# Ideally we'd be able to test broadcasting but, the multinomial sampler
|
||||
# doesn't support different total counts.
|
||||
n = np.float32(5)
|
||||
with self.test_session() as sess:
|
||||
# batch_shape=[2], event_shape=[3]
|
||||
dist = ds.DirichletMultinomial(n, alpha)
|
||||
x = dist.sample(int(250e3), seed=1)
|
||||
sample_mean = math_ops.reduce_mean(x, 0)
|
||||
x_centered = x - sample_mean[None, ...]
|
||||
sample_cov = math_ops.reduce_mean(math_ops.matmul(
|
||||
x_centered[..., None], x_centered[..., None, :]), 0)
|
||||
sample_var = array_ops.matrix_diag_part(sample_cov)
|
||||
sample_stddev = math_ops.sqrt(sample_var)
|
||||
[
|
||||
sample_mean_,
|
||||
sample_cov_,
|
||||
sample_var_,
|
||||
sample_stddev_,
|
||||
analytic_mean,
|
||||
analytic_cov,
|
||||
analytic_var,
|
||||
analytic_stddev,
|
||||
] = sess.run([
|
||||
sample_mean,
|
||||
sample_cov,
|
||||
sample_var,
|
||||
sample_stddev,
|
||||
dist.mean(),
|
||||
dist.covariance(),
|
||||
dist.variance(),
|
||||
dist.stddev(),
|
||||
])
|
||||
self.assertAllClose(sample_mean_, analytic_mean, atol=0., rtol=0.04)
|
||||
self.assertAllClose(sample_cov_, analytic_cov, atol=0., rtol=0.05)
|
||||
self.assertAllClose(sample_var_, analytic_var, atol=0., rtol=0.03)
|
||||
self.assertAllClose(sample_stddev_, analytic_stddev, atol=0., rtol=0.02)
|
||||
|
||||
def testCovariance(self):
|
||||
# Shape [2]
|
||||
alpha = [1., 2]
|
||||
ns = [2., 3., 4., 5.]
|
||||
@ -234,13 +276,13 @@ class DirichletMultinomialTest(test.TestCase):
|
||||
for n in ns:
|
||||
# n is shape [] and alpha is shape [2].
|
||||
dist = ds.DirichletMultinomial(n, alpha)
|
||||
variance = dist.variance()
|
||||
expected_variance = n * (n + alpha_0) / (1 + alpha_0) * shared_matrix
|
||||
covariance = dist.covariance()
|
||||
expected_covariance = n * (n + alpha_0) / (1 + alpha_0) * shared_matrix
|
||||
|
||||
self.assertEqual((2, 2), variance.get_shape())
|
||||
self.assertAllClose(expected_variance, variance.eval())
|
||||
self.assertEqual((2, 2), covariance.get_shape())
|
||||
self.assertAllClose(expected_covariance, covariance.eval())
|
||||
|
||||
def testVarianceNAlphaBroadcast(self):
|
||||
def testCovarianceNAlphaBroadcast(self):
|
||||
alpha_v = [1., 2, 3]
|
||||
alpha_0 = 6.
|
||||
|
||||
@ -271,14 +313,14 @@ class DirichletMultinomialTest(test.TestCase):
|
||||
with self.test_session():
|
||||
# ns is shape [4, 1], and alpha is shape [4, 3].
|
||||
dist = ds.DirichletMultinomial(ns, alpha)
|
||||
variance = dist.variance()
|
||||
expected_variance = np.expand_dims(ns * (ns + alpha_0) / (1 + alpha_0),
|
||||
-1) * shared_matrix
|
||||
covariance = dist.covariance()
|
||||
expected_covariance = shared_matrix * (
|
||||
ns * (ns + alpha_0) / (1 + alpha_0))[..., None]
|
||||
|
||||
self.assertEqual((4, 3, 3), variance.get_shape())
|
||||
self.assertAllClose(expected_variance, variance.eval())
|
||||
self.assertEqual([4, 3, 3], covariance.get_shape())
|
||||
self.assertAllClose(expected_covariance, covariance.eval())
|
||||
|
||||
def testVarianceMultidimensional(self):
|
||||
def testCovarianceMultidimensional(self):
|
||||
alpha = np.random.rand(3, 5, 4).astype(np.float32)
|
||||
alpha2 = np.random.rand(6, 3, 3).astype(np.float32)
|
||||
|
||||
@ -289,10 +331,10 @@ class DirichletMultinomialTest(test.TestCase):
|
||||
dist = ds.DirichletMultinomial(ns, alpha)
|
||||
dist2 = ds.DirichletMultinomial(ns2, alpha2)
|
||||
|
||||
variance = dist.variance()
|
||||
variance2 = dist2.variance()
|
||||
self.assertEqual((3, 5, 4, 4), variance.get_shape())
|
||||
self.assertEqual((6, 3, 3, 3), variance2.get_shape())
|
||||
covariance = dist.covariance()
|
||||
covariance2 = dist2.covariance()
|
||||
self.assertEqual([3, 5, 4, 4], covariance.get_shape())
|
||||
self.assertEqual([6, 3, 3, 3], covariance2.get_shape())
|
||||
|
||||
def testZeroCountsResultsInPmfEqualToOne(self):
|
||||
# There is only one way for zero items to be selected, and this happens with
|
||||
@ -390,7 +432,7 @@ class DirichletMultinomialTest(test.TestCase):
|
||||
sample_mean,
|
||||
sample_covariance,
|
||||
dist.mean(),
|
||||
dist.variance(),
|
||||
dist.covariance(),
|
||||
])
|
||||
self.assertAllEqual([4, 3, 2], sample_mean.get_shape())
|
||||
self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.15)
|
||||
@ -417,7 +459,7 @@ class DirichletMultinomialTest(test.TestCase):
|
||||
sample_mean,
|
||||
sample_covariance,
|
||||
dist.mean(),
|
||||
dist.variance(),
|
||||
dist.covariance(),
|
||||
])
|
||||
self.assertAllEqual([4], sample_mean.get_shape())
|
||||
self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.05)
|
||||
|
@ -21,6 +21,8 @@ from scipy import stats
|
||||
from tensorflow.contrib.distributions.python.ops import dirichlet as dirichlet_lib
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -134,16 +136,52 @@ class DirichletTest(test.TestCase):
|
||||
self.assertEqual(dirichlet.mean().get_shape(), (3,))
|
||||
self.assertAllClose(dirichlet.mean().eval(), expected_mean)
|
||||
|
||||
def testDirichletVariance(self):
|
||||
def testCovarianceFromSampling(self):
|
||||
alpha = np.array([[1., 2, 3],
|
||||
[2.5, 4, 0.01]], dtype=np.float32)
|
||||
with self.test_session() as sess:
|
||||
dist = dirichlet_lib.Dirichlet(alpha) # batch_shape=[2], event_shape=[3]
|
||||
x = dist.sample(int(250e3), seed=1)
|
||||
sample_mean = math_ops.reduce_mean(x, 0)
|
||||
x_centered = x - sample_mean[None, ...]
|
||||
sample_cov = math_ops.reduce_mean(math_ops.matmul(
|
||||
x_centered[..., None], x_centered[..., None, :]), 0)
|
||||
sample_var = array_ops.matrix_diag_part(sample_cov)
|
||||
sample_stddev = math_ops.sqrt(sample_var)
|
||||
[
|
||||
sample_mean_,
|
||||
sample_cov_,
|
||||
sample_var_,
|
||||
sample_stddev_,
|
||||
analytic_mean,
|
||||
analytic_cov,
|
||||
analytic_var,
|
||||
analytic_stddev,
|
||||
] = sess.run([
|
||||
sample_mean,
|
||||
sample_cov,
|
||||
sample_var,
|
||||
sample_stddev,
|
||||
dist.mean(),
|
||||
dist.covariance(),
|
||||
dist.variance(),
|
||||
dist.stddev(),
|
||||
])
|
||||
self.assertAllClose(sample_mean_, analytic_mean, atol=0., rtol=0.04)
|
||||
self.assertAllClose(sample_cov_, analytic_cov, atol=0., rtol=0.06)
|
||||
self.assertAllClose(sample_var_, analytic_var, atol=0., rtol=0.03)
|
||||
self.assertAllClose(sample_stddev_, analytic_stddev, atol=0., rtol=0.02)
|
||||
|
||||
def testDirichletCovariance(self):
|
||||
with self.test_session():
|
||||
alpha = [1., 2, 3]
|
||||
denominator = np.sum(alpha)**2 * (np.sum(alpha) + 1)
|
||||
expected_variance = np.diag(stats.dirichlet.var(alpha))
|
||||
expected_variance += [[0., -2, -3], [-2, 0, -6],
|
||||
[-3, -6, 0]] / denominator
|
||||
expected_covariance = np.diag(stats.dirichlet.var(alpha))
|
||||
expected_covariance += [[0., -2, -3], [-2, 0, -6],
|
||||
[-3, -6, 0]] / denominator
|
||||
dirichlet = dirichlet_lib.Dirichlet(alpha=alpha)
|
||||
self.assertEqual(dirichlet.variance().get_shape(), (3, 3))
|
||||
self.assertAllClose(dirichlet.variance().eval(), expected_variance)
|
||||
self.assertEqual(dirichlet.covariance().get_shape(), (3, 3))
|
||||
self.assertAllClose(dirichlet.covariance().eval(), expected_covariance)
|
||||
|
||||
def testDirichletMode(self):
|
||||
with self.test_session():
|
||||
|
@ -191,18 +191,18 @@ class MultinomialTest(test.TestCase):
|
||||
self.assertEqual((3,), dist.mean().get_shape())
|
||||
self.assertAllClose(expected_means, dist.mean().eval())
|
||||
|
||||
def testMultinomialVariance(self):
|
||||
def testMultinomialCovariance(self):
|
||||
with self.test_session():
|
||||
n = 5.
|
||||
p = [0.1, 0.2, 0.7]
|
||||
dist = ds.Multinomial(total_count=n, probs=p)
|
||||
expected_variances = [[9. / 20, -1 / 10, -7 / 20],
|
||||
[-1 / 10, 4 / 5, -7 / 10],
|
||||
[-7 / 20, -7 / 10, 21 / 20]]
|
||||
self.assertEqual((3, 3), dist.variance().get_shape())
|
||||
self.assertAllClose(expected_variances, dist.variance().eval())
|
||||
expected_covariances = [[9. / 20, -1 / 10, -7 / 20],
|
||||
[-1 / 10, 4 / 5, -7 / 10],
|
||||
[-7 / 20, -7 / 10, 21 / 20]]
|
||||
self.assertEqual((3, 3), dist.covariance().get_shape())
|
||||
self.assertAllClose(expected_covariances, dist.covariance().eval())
|
||||
|
||||
def testMultinomialVarianceBatch(self):
|
||||
def testMultinomialCovarianceBatch(self):
|
||||
with self.test_session():
|
||||
# Shape [2]
|
||||
n = [5.] * 2
|
||||
@ -212,11 +212,11 @@ class MultinomialTest(test.TestCase):
|
||||
# Shape [2, 2]
|
||||
inner_var = [[9. / 20, -9 / 20], [-9 / 20, 9 / 20]]
|
||||
# Shape [4, 2, 2, 2]
|
||||
expected_variances = [[inner_var, inner_var]] * 4
|
||||
self.assertEqual((4, 2, 2, 2), dist.variance().get_shape())
|
||||
self.assertAllClose(expected_variances, dist.variance().eval())
|
||||
expected_covariances = [[inner_var, inner_var]] * 4
|
||||
self.assertEqual((4, 2, 2, 2), dist.covariance().get_shape())
|
||||
self.assertAllClose(expected_covariances, dist.covariance().eval())
|
||||
|
||||
def testVarianceMultidimensional(self):
|
||||
def testCovarianceMultidimensional(self):
|
||||
# Shape [3, 5, 4]
|
||||
p = np.random.dirichlet([.25, .25, .25, .25], [3, 5]).astype(np.float32)
|
||||
# Shape [6, 3, 3]
|
||||
@ -229,10 +229,52 @@ class MultinomialTest(test.TestCase):
|
||||
dist = ds.Multinomial(ns, p)
|
||||
dist2 = ds.Multinomial(ns2, p2)
|
||||
|
||||
variance = dist.variance()
|
||||
variance2 = dist2.variance()
|
||||
self.assertEqual((3, 5, 4, 4), variance.get_shape())
|
||||
self.assertEqual((6, 3, 3, 3), variance2.get_shape())
|
||||
covariance = dist.covariance()
|
||||
covariance2 = dist2.covariance()
|
||||
self.assertEqual((3, 5, 4, 4), covariance.get_shape())
|
||||
self.assertEqual((6, 3, 3, 3), covariance2.get_shape())
|
||||
|
||||
def testCovarianceFromSampling(self):
|
||||
# We will test mean, cov, var, stddev on a DirichletMultinomial constructed
|
||||
# via broadcast between alpha, n.
|
||||
theta = np.array([[1., 2, 3],
|
||||
[2.5, 4, 0.01]], dtype=np.float32)
|
||||
theta /= np.sum(theta, 1)[..., None]
|
||||
# Ideally we'd be able to test broadcasting but, the multinomial sampler
|
||||
# doesn't support different total counts.
|
||||
n = np.float32(5)
|
||||
with self.test_session() as sess:
|
||||
dist = ds.Multinomial(n, theta) # batch_shape=[2], event_shape=[3]
|
||||
x = dist.sample(int(250e3), seed=1)
|
||||
sample_mean = math_ops.reduce_mean(x, 0)
|
||||
x_centered = x - sample_mean[None, ...]
|
||||
sample_cov = math_ops.reduce_mean(math_ops.matmul(
|
||||
x_centered[..., None], x_centered[..., None, :]), 0)
|
||||
sample_var = array_ops.matrix_diag_part(sample_cov)
|
||||
sample_stddev = math_ops.sqrt(sample_var)
|
||||
[
|
||||
sample_mean_,
|
||||
sample_cov_,
|
||||
sample_var_,
|
||||
sample_stddev_,
|
||||
analytic_mean,
|
||||
analytic_cov,
|
||||
analytic_var,
|
||||
analytic_stddev,
|
||||
] = sess.run([
|
||||
sample_mean,
|
||||
sample_cov,
|
||||
sample_var,
|
||||
sample_stddev,
|
||||
dist.mean(),
|
||||
dist.covariance(),
|
||||
dist.variance(),
|
||||
dist.stddev(),
|
||||
])
|
||||
self.assertAllClose(sample_mean_, analytic_mean, atol=0., rtol=0.01)
|
||||
self.assertAllClose(sample_cov_, analytic_cov, atol=0., rtol=0.01)
|
||||
self.assertAllClose(sample_var_, analytic_var, atol=0., rtol=0.01)
|
||||
self.assertAllClose(sample_stddev_, analytic_stddev, atol=0., rtol=0.01)
|
||||
|
||||
def testSampleUnbiasedNonScalarBatch(self):
|
||||
with self.test_session() as sess:
|
||||
@ -255,7 +297,7 @@ class MultinomialTest(test.TestCase):
|
||||
sample_mean,
|
||||
sample_covariance,
|
||||
dist.mean(),
|
||||
dist.variance(),
|
||||
dist.covariance(),
|
||||
])
|
||||
self.assertAllEqual([4, 3, 2], sample_mean.get_shape())
|
||||
self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.07)
|
||||
@ -283,7 +325,7 @@ class MultinomialTest(test.TestCase):
|
||||
sample_mean,
|
||||
sample_covariance,
|
||||
dist.mean(),
|
||||
dist.variance(),
|
||||
dist.covariance(),
|
||||
])
|
||||
self.assertAllEqual([4], sample_mean.get_shape())
|
||||
self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.07)
|
||||
|
@ -211,18 +211,22 @@ class Dirichlet(distribution.Distribution):
|
||||
return entropy
|
||||
|
||||
def _mean(self):
|
||||
return self.alpha / array_ops.expand_dims(self.alpha_sum, -1)
|
||||
return self.alpha / self.alpha_sum[..., None]
|
||||
|
||||
def _covariance(self):
|
||||
x = self._variance_scale_term() * self._mean()
|
||||
return array_ops.matrix_set_diag(
|
||||
-math_ops.matmul(x[..., None], x[..., None, :]), # outer prod
|
||||
self._variance())
|
||||
|
||||
def _variance(self):
|
||||
scale = self.alpha_sum * math_ops.sqrt(1. + self.alpha_sum)
|
||||
alpha = self.alpha / scale
|
||||
outer_prod = -math_ops.matmul(
|
||||
array_ops.expand_dims(
|
||||
alpha, dim=-1), # column
|
||||
array_ops.expand_dims(
|
||||
alpha, dim=-2)) # row
|
||||
return array_ops.matrix_set_diag(outer_prod,
|
||||
alpha * (self.alpha_sum / scale - alpha))
|
||||
scale = self._variance_scale_term()
|
||||
x = scale * self._mean()
|
||||
return x * (scale - x)
|
||||
|
||||
def _variance_scale_term(self):
|
||||
"""Helper to `_covariance` and `_variance` which computes a shared scale."""
|
||||
return math_ops.rsqrt(1. + self.alpha_sum[..., None])
|
||||
|
||||
@distribution_util.AppendDocstring(
|
||||
"""Note that the mode for the Dirichlet distribution is only defined
|
||||
|
@ -252,11 +252,10 @@ class DirichletMultinomial(distribution.Distribution):
|
||||
return math_ops.exp(self._log_prob(counts))
|
||||
|
||||
def _mean(self):
|
||||
normalized_alpha = self.alpha / array_ops.expand_dims(self.alpha_sum, -1)
|
||||
return self.n[..., None] * normalized_alpha
|
||||
return self.n * (self.alpha / self.alpha_sum[..., None])
|
||||
|
||||
@distribution_util.AppendDocstring(
|
||||
"""The variance for each batch member is defined as the following:
|
||||
"""The covariance for each batch member is defined as the following:
|
||||
|
||||
```
|
||||
Var(X_j) = n * alpha_j / alpha_0 * (1 - alpha_j / alpha_0) *
|
||||
@ -272,18 +271,23 @@ class DirichletMultinomial(distribution.Distribution):
|
||||
(n + alpha_0) / (1 + alpha_0)
|
||||
```
|
||||
""")
|
||||
def _covariance(self):
|
||||
x = self._variance_scale_term() * self._mean()
|
||||
return array_ops.matrix_set_diag(
|
||||
-math_ops.matmul(x[..., None], x[..., None, :]), # outer prod
|
||||
self._variance())
|
||||
|
||||
def _variance(self):
|
||||
alpha_sum = array_ops.expand_dims(self.alpha_sum, -1)
|
||||
normalized_alpha = self.alpha / alpha_sum
|
||||
variance = -math_ops.matmul(
|
||||
array_ops.expand_dims(normalized_alpha, -1),
|
||||
array_ops.expand_dims(normalized_alpha, -2))
|
||||
variance = array_ops.matrix_set_diag(variance, normalized_alpha *
|
||||
(1. - normalized_alpha))
|
||||
shared_factor = (self.n * (alpha_sum + self.n) /
|
||||
(alpha_sum + 1) * array_ops.ones_like(self.alpha))
|
||||
variance *= array_ops.expand_dims(shared_factor, -1)
|
||||
return variance
|
||||
scale = self._variance_scale_term()
|
||||
x = scale * self._mean()
|
||||
return x * (self.n * scale - x)
|
||||
|
||||
def _variance_scale_term(self):
|
||||
"""Helper to `_covariance` and `_variance` which computes a shared scale."""
|
||||
# We must take care to expand back the last dim whenever we use the
|
||||
# alpha_sum.
|
||||
c0 = self.alpha_sum[..., None]
|
||||
return math_ops.sqrt((1. + c0 / self.n) / (1. + c0))
|
||||
|
||||
def _assert_valid_counts(self, counts):
|
||||
"""Check counts for proper shape, values, then return tensor version."""
|
||||
|
@ -40,7 +40,8 @@ from tensorflow.python.ops import math_ops
|
||||
_DISTRIBUTION_PUBLIC_METHOD_WRAPPERS = [
|
||||
"batch_shape", "get_batch_shape", "event_shape", "get_event_shape",
|
||||
"sample", "log_prob", "prob", "log_cdf", "cdf", "log_survival_function",
|
||||
"survival_function", "entropy", "mean", "variance", "stddev", "mode"]
|
||||
"survival_function", "entropy", "mean", "variance", "stddev", "mode",
|
||||
"covariance"]
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
@ -877,7 +878,24 @@ class Distribution(_BaseDistribution):
|
||||
raise NotImplementedError("variance is not implemented")
|
||||
|
||||
def variance(self, name="variance"):
|
||||
"""Variance."""
|
||||
"""Variance.
|
||||
|
||||
Variance is defined as,
|
||||
|
||||
```none
|
||||
Var = E[(X - E[X])**2]
|
||||
```
|
||||
|
||||
where `X` is the random variable associated with this distribution, `E`
|
||||
denotes expectation, and `Var.shape = batch_shape + event_shape`.
|
||||
|
||||
Args:
|
||||
name: The name to give this op.
|
||||
|
||||
Returns:
|
||||
variance: Floating-point `Tensor` with shape identical to
|
||||
`batch_shape + event_shape`, i.e., the same shape as `self.mean()`.
|
||||
"""
|
||||
with self._name_scope(name):
|
||||
try:
|
||||
return self._variance()
|
||||
@ -891,7 +909,25 @@ class Distribution(_BaseDistribution):
|
||||
raise NotImplementedError("stddev is not implemented")
|
||||
|
||||
def stddev(self, name="stddev"):
|
||||
"""Standard deviation."""
|
||||
"""Standard deviation.
|
||||
|
||||
Standard deviation is defined as,
|
||||
|
||||
```none
|
||||
stddev = E[(X - E[X])**2]**0.5
|
||||
```
|
||||
|
||||
where `X` is the random variable associated with this distribution, `E`
|
||||
denotes expectation, and `stddev.shape = batch_shape + event_shape`.
|
||||
|
||||
Args:
|
||||
name: The name to give this op.
|
||||
|
||||
Returns:
|
||||
stddev: Floating-point `Tensor` with shape identical to
|
||||
`batch_shape + event_shape`, i.e., the same shape as `self.mean()`.
|
||||
"""
|
||||
|
||||
with self._name_scope(name):
|
||||
try:
|
||||
return self._stddev()
|
||||
@ -901,6 +937,48 @@ class Distribution(_BaseDistribution):
|
||||
except NotImplementedError:
|
||||
raise original_exception
|
||||
|
||||
def _covariance(self):
|
||||
raise NotImplementedError("covariance is not implemented")
|
||||
|
||||
def covariance(self, name="covariance"):
|
||||
"""Covariance.
|
||||
|
||||
Covariance is (possibly) defined only for non-scalar-event distributions.
|
||||
|
||||
For example, for a length-`k`, vector-valued distribution, it is calculated
|
||||
as,
|
||||
|
||||
```none
|
||||
Cov[i, j] = Covariance(X_i, X_j) = E[(X_i - E[X_i]) (X_j - E[X_j])]
|
||||
```
|
||||
|
||||
where `Cov` is a (batch of) `k x k` matrix, `0 <= (i, j) < k`, and `E`
|
||||
denotes expectation.
|
||||
|
||||
Alternatively, for non-vector, multivariate distributions (e.g.,
|
||||
matrix-valued, Wishart), `Covariance` shall return a (batch of) matrices
|
||||
under some vectorization of the events, i.e.,
|
||||
|
||||
```none
|
||||
Cov[i, j] = Covariance(Vec(X)_i, Vec(X)_j) = [as above]
|
||||
````
|
||||
|
||||
where `Cov` is a (batch of) `k' x k'` matrices,
|
||||
`0 <= (i, j) < k' = reduce_prod(event_shape)`, and `Vec` is some function
|
||||
mapping indices of this distribution's event dimensions to indices of a
|
||||
length-`k'` vector.
|
||||
|
||||
Args:
|
||||
name: The name to give this op.
|
||||
|
||||
Returns:
|
||||
covariance: Floating-point `Tensor` with shape `[B1, ..., Bn, k', k']`
|
||||
where the first `n` dimensions are batch coordinates and
|
||||
`k' = reduce_prod(self.event_shape)`.
|
||||
"""
|
||||
with self._name_scope(name):
|
||||
return self._covariance()
|
||||
|
||||
def _mode(self):
|
||||
raise NotImplementedError("mode is not implemented")
|
||||
|
||||
|
@ -195,9 +195,7 @@ class Multinomial(distribution.Distribution):
|
||||
|
||||
@property
|
||||
def probs(self):
|
||||
"""Vector of probabilities summing to one.
|
||||
|
||||
Each element is the probability of drawing that coordinate."""
|
||||
"""Probability of of drawing a `1` in that coordinate."""
|
||||
return self._probs
|
||||
|
||||
def _batch_shape(self):
|
||||
@ -256,11 +254,15 @@ class Multinomial(distribution.Distribution):
|
||||
def _mean(self):
|
||||
return array_ops.identity(self._mean_val)
|
||||
|
||||
def _variance(self):
|
||||
def _covariance(self):
|
||||
p = self.probs * array_ops.ones_like(self.total_count)[..., None]
|
||||
return array_ops.matrix_set_diag(
|
||||
-math_ops.matmul(self._mean_val[..., None], p[..., None, :]),
|
||||
self._mean_val - self._mean_val * p)
|
||||
self._variance())
|
||||
|
||||
def _variance(self):
|
||||
p = self.probs * array_ops.ones_like(self.total_count)[..., None]
|
||||
return self._mean_val - self._mean_val * p
|
||||
|
||||
def _maybe_assert_valid_total_count(self, total_count, validate_args):
|
||||
if not validate_args:
|
||||
|
@ -300,9 +300,12 @@ class _MultivariateNormalOperatorPD(distribution.Distribution):
|
||||
def _mean(self):
|
||||
return array_ops.identity(self._mu)
|
||||
|
||||
def _variance(self):
|
||||
def _covariance(self):
|
||||
return self.sigma
|
||||
|
||||
def _variance(self):
|
||||
return array_ops.matrix_diag_part(self.sigma)
|
||||
|
||||
def _mode(self):
|
||||
return array_ops.identity(self._mu)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user