Adding covariance, covariance_cholesky, mvn_cov to distributions/

These allow alternative representations of covariance matrices (e.g. diagonal, sparse, or functional) to be used in a multivariate normal. Change: 126226644
parent f2c3b2e702 · commit 94cbf42a07
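For orientation before the diff itself, here is a minimal sketch of how the pieces added in this change fit together. It assumes the TensorFlow 0.x `tf.contrib.distributions` API exactly as introduced below; the concrete numbers are illustrative only.

```python
import tensorflow as tf

distributions = tf.contrib.distributions

# A 2-variate normal parameterized directly by a lower-triangular
# Cholesky factor with positive diagonal (so sigma = chol . chol^T).
mu = [1.0, 2.0]
chol = [[1.0, 0.0],
        [0.5, 2.0]]
mvn = distributions.MultivariateNormalCholesky(mu, chol)

with tf.Session():
  print(mvn.mean().eval())               # => [1.0, 2.0]
  print(mvn.log_pdf([0.0, 0.0]).eval())  # scalar log-density at the origin
```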
@@ -10,6 +10,39 @@ package(default_visibility = ["//tensorflow:__subpackages__"])

load("//tensorflow:tensorflow.bzl", "cuda_py_tests")

cuda_py_tests(
    name = "operator_pd_test",
    size = "small",
    srcs = ["python/kernel_tests/operator_pd_test.py"],
    additional_deps = [
        ":distributions_py",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:platform_test",
    ],
)

cuda_py_tests(
    name = "operator_pd_cholesky_test",
    size = "small",
    srcs = ["python/kernel_tests/operator_pd_cholesky_test.py"],
    additional_deps = [
        ":distributions_py",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:platform_test",
    ],
)

cuda_py_tests(
    name = "operator_pd_full_test",
    size = "small",
    srcs = ["python/kernel_tests/operator_pd_full_test.py"],
    additional_deps = [
        ":distributions_py",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:platform_test",
    ],
)

py_library(
    name = "distributions_py",
    srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
@@ -38,9 +38,26 @@ initialized with parameters that define the distributions.

### Multivariate distributions

@@MultivariateNormal
#### Multivariate normal

@@MultivariateNormalFull
@@MultivariateNormalCholesky

#### Other multivariate distributions

@@DirichletMultinomial

## Operators allowing for matrix-free methods

### Positive definite operators

A matrix is positive definite if it is symmetric with all positive eigenvalues.

@@OperatorPDBase
@@OperatorPDFull
@@OperatorPDCholesky
@@batch_matrix_diag_transform
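To make the definition above concrete (an editorial aside, not part of the commit), positive definiteness is easy to check numerically with NumPy:

```python
import numpy as np

a = np.array([[2.0, 1.0],
              [1.0, 2.0]])  # symmetric

# Symmetric + all eigenvalues > 0  =>  positive definite.
assert np.allclose(a, a.T)
assert np.all(np.linalg.eigvalsh(a) > 0)

# Equivalently, the Cholesky factorization succeeds only for PD matrices.
chol = np.linalg.cholesky(a)  # lower triangular, a == chol.dot(chol.T)
```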
## Posterior inference with conjugate priors.

Functions that transform conjugate prior/likelihood pairs to distributions
@@ -61,7 +78,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=unused-import,wildcard-import,line-too-long
# pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member

from tensorflow.contrib.distributions.python.ops.bernoulli import *
from tensorflow.contrib.distributions.python.ops.categorical import *
@@ -74,5 +91,8 @@ from tensorflow.contrib.distributions.python.ops.kullback_leibler import *
from tensorflow.contrib.distributions.python.ops.mvn import *
from tensorflow.contrib.distributions.python.ops.normal import *
from tensorflow.contrib.distributions.python.ops.normal_conjugate_posteriors import *
from tensorflow.contrib.distributions.python.ops.operator_pd import *
from tensorflow.contrib.distributions.python.ops.operator_pd_cholesky import *
from tensorflow.contrib.distributions.python.ops.operator_pd_full import *
from tensorflow.contrib.distributions.python.ops.student_t import *
from tensorflow.contrib.distributions.python.ops.uniform import *
@@ -19,249 +19,153 @@ from __future__ import division
from __future__ import print_function

import numpy as np
from scipy import stats
import tensorflow as tf

distributions = tf.contrib.distributions

class MultivariateNormalTest(tf.test.TestCase):

class MultivariateNormalCholeskyTest(tf.test.TestCase):

  def setUp(self):
    self._rng = np.random.RandomState(42)

  def _random_chol(self, *shape):
    mat = self._rng.rand(*shape)
    chol = distributions.batch_matrix_diag_transform(
        mat, transform=tf.nn.softplus)
    chol = tf.batch_matrix_band_part(chol, -1, 0)
    sigma = tf.batch_matmul(chol, chol, adj_y=True)
    return chol.eval(), sigma.eval()
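An editorial note on why `_random_chol` yields a valid Cholesky factor: `batch_matrix_diag_transform` maps the diagonal through `softplus`, making it strictly positive, and `batch_matrix_band_part(chol, -1, 0)` zeroes the upper triangle; any lower-triangular matrix with positive diagonal is the Cholesky factor of the positive definite matrix `sigma = chol . chol^*`. The same construction in plain NumPy:

```python
import numpy as np

rng = np.random.RandomState(42)

def softplus(x):
  return np.log(1 + np.exp(x))

mat = rng.rand(3, 3)
chol = np.tril(mat)                             # keep the lower triangle only
np.fill_diagonal(chol, softplus(np.diag(mat)))  # strictly positive diagonal
sigma = chol.dot(chol.T)                        # positive definite by construction

assert np.all(np.linalg.eigvalsh(sigma) > 0)
```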
  def testNonmatchingMuSigmaFails(self):
    with tf.Session():
      mvn = tf.contrib.distributions.MultivariateNormal(
          mu=[1.0, 2.0],
          sigma=[[[1.0, 0.0],
                  [0.0, 1.0]],
                 [[1.0, 0.0],
                  [0.0, 1.0]]])
      with self.assertRaisesOpError(
          r"Rank of mu should be one less than rank of sigma"):
        mvn.mean.eval()
    with self.test_session():
      mu = self._rng.rand(2)
      chol, _ = self._random_chol(2, 2, 2)
      mvn = distributions.MultivariateNormalCholesky(mu, chol)
      with self.assertRaisesOpError("mu should have rank 1 less than cov"):
        mvn.mean().eval()

      mvn = tf.contrib.distributions.MultivariateNormal(
          mu=[[1.0], [2.0]],
          sigma=[[[1.0, 0.0],
                  [0.0, 1.0]],
                 [[1.0, 0.0],
                  [0.0, 1.0]]])
      with self.assertRaisesOpError(
          r"mu.shape and sigma.shape\[\:-1\] must match"):
        mvn.mean.eval()
      mu = self._rng.rand(2, 1)
      chol, _ = self._random_chol(2, 2, 2)
      mvn = distributions.MultivariateNormalCholesky(mu, chol)
      with self.assertRaisesOpError("mu.shape and cov.shape.*should match"):
        mvn.mean().eval()

  def testNotPositiveDefiniteSigmaFails(self):
    with tf.Session():
      mvn = tf.contrib.distributions.MultivariateNormal(
          mu=[[1.0, 2.0], [1.0, 2.0]],
          sigma=[[[1.0, 0.0],
                  [0.0, 1.0]],
                 [[1.0, 1.0],
                  [1.0, 1.0]]])
      with self.assertRaisesOpError(
          r"LLT decomposition was not successful."):
        mvn.mean.eval()
      mvn = tf.contrib.distributions.MultivariateNormal(
          mu=[[1.0, 2.0], [1.0, 2.0]],
          sigma=[[[1.0, 0.0],
                  [0.0, 1.0]],
                 [[-1.0, 0.0],
                  [0.0, 1.0]]])
      with self.assertRaisesOpError(
          r"LLT decomposition was not successful."):
        mvn.mean.eval()
      mvn = tf.contrib.distributions.MultivariateNormal(
          mu=[[1.0, 2.0], [1.0, 2.0]],
          sigma_chol=[[[1.0, 0.0],
                       [0.0, 1.0]],
                      [[-1.0, 0.0],
                       [0.0, 1.0]]])
      with self.assertRaisesOpError(
          r"sigma_chol is not positive definite."):
        mvn.mean.eval()

  def testLogPDFScalar(self):
    with tf.Session():
      mu_v = np.array([-3.0, 3.0], dtype=np.float32)
      mu = tf.constant(mu_v)
      sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
      sigma = tf.constant(sigma_v)
      x = np.array([-2.5, 2.5], dtype=np.float32)
      mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
  def testLogPDFScalarBatch(self):
    with self.test_session():
      mu = self._rng.rand(2)
      chol, sigma = self._random_chol(2, 2)
      mvn = distributions.MultivariateNormalCholesky(mu, chol)
      x = self._rng.rand(2)

      log_pdf = mvn.log_pdf(x)
      pdf = mvn.pdf(x)

      try:
        from scipy import stats  # pylint: disable=g-import-not-at-top
        scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
        expected_log_pdf = scipy_mvn.logpdf(x)
        expected_pdf = scipy_mvn.pdf(x)
        self.assertAllClose(expected_log_pdf, log_pdf.eval())
        self.assertAllClose(expected_pdf, pdf.eval())
      except ImportError as e:
        tf.logging.warn("Cannot test stats functions: %s" % str(e))
      scipy_mvn = stats.multivariate_normal(mean=mu, cov=sigma)

  def testLogPDFScalarSigmaHalf(self):
    with tf.Session():
      mu_v = np.array([-3.0, 3.0, 1.0], dtype=np.float32)
      mu = tf.constant(mu_v)
      sigma_v = np.array([[1.0, 0.1, 0.2],
                          [0.1, 2.0, 0.05],
                          [0.2, 0.05, 3.0]], dtype=np.float32)
      sigma_chol_v = np.linalg.cholesky(sigma_v)
      sigma_chol = tf.constant(sigma_chol_v)
      x = np.array([-2.5, 2.5, 1.0], dtype=np.float32)
      mvn = tf.contrib.distributions.MultivariateNormal(
          mu=mu, sigma_chol=sigma_chol)
      log_pdf = mvn.log_pdf(x)
      pdf = mvn.pdf(x)
      sigma = mvn.sigma
      expected_log_pdf = scipy_mvn.logpdf(x)
      expected_pdf = scipy_mvn.pdf(x)
      self.assertEqual((), log_pdf.get_shape())
      self.assertEqual((), pdf.get_shape())
      self.assertAllClose(expected_log_pdf, log_pdf.eval())
      self.assertAllClose(expected_pdf, pdf.eval())

      try:
        from scipy import stats  # pylint: disable=g-import-not-at-top
        scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
        expected_log_pdf = scipy_mvn.logpdf(x)
        expected_pdf = scipy_mvn.pdf(x)
        self.assertEqual(sigma.get_shape(), (3, 3))
        self.assertAllClose(sigma_v, sigma.eval())
        self.assertAllClose(expected_log_pdf, log_pdf.eval())
        self.assertAllClose(expected_pdf, pdf.eval())
      except ImportError as e:
        tf.logging.warn("Cannot test stats functions: %s" % str(e))
  def testLogPDFXIsHigherRank(self):
    with self.test_session():
      mu = self._rng.rand(2)
      chol, sigma = self._random_chol(2, 2)
      mvn = distributions.MultivariateNormalCholesky(mu, chol)
      x = self._rng.rand(3, 2)

  def testLogPDF(self):
    with tf.Session():
      mu_v = np.array([-3.0, 3.0], dtype=np.float32)
      mu = tf.constant(mu_v)
      sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
      sigma = tf.constant(sigma_v)
      x = np.array([[-2.5, 2.5], [4.0, 0.0], [-1.0, 2.0]], dtype=np.float32)
      mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
      log_pdf = mvn.log_pdf(x)
      pdf = mvn.pdf(x)

      try:
        from scipy import stats  # pylint: disable=g-import-not-at-top
        scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
        expected_log_pdf = scipy_mvn.logpdf(x)
        expected_pdf = scipy_mvn.pdf(x)
        self.assertEqual(log_pdf.get_shape(), (3,))
        self.assertAllClose(expected_log_pdf, log_pdf.eval())
        self.assertAllClose(expected_pdf, pdf.eval())
      except ImportError as e:
        tf.logging.warn("Cannot test stats functions: %s" % str(e))
      scipy_mvn = stats.multivariate_normal(mean=mu, cov=sigma)

      expected_log_pdf = scipy_mvn.logpdf(x)
      expected_pdf = scipy_mvn.pdf(x)
      self.assertEqual((3,), log_pdf.get_shape())
      self.assertEqual((3,), pdf.get_shape())
      self.assertAllClose(expected_log_pdf, log_pdf.eval())
      self.assertAllClose(expected_pdf, pdf.eval())

  def testLogPDFXLowerDimension(self):
    with self.test_session():
      mu = self._rng.rand(3, 2)
      chol, sigma = self._random_chol(3, 2, 2)
      mvn = distributions.MultivariateNormalCholesky(mu, chol)
      x = self._rng.rand(2)

  def testLogPDFMatchingDimension(self):
    with tf.Session():
      mu_v = np.array([-3.0, 3.0], dtype=np.float32)
      mu = tf.constant(np.vstack(3 * [mu_v]))
      sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
      sigma = tf.constant(np.vstack(3 * [sigma_v[np.newaxis, :]]))
      x = np.array([[-2.5, 2.5], [4.0, 0.0], [-1.0, 2.0]], dtype=np.float32)
      mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
      log_pdf = mvn.log_pdf(x)
      pdf = mvn.pdf(x)

      try:
        from scipy import stats  # pylint: disable=g-import-not-at-top
        scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
        expected_log_pdf = scipy_mvn.logpdf(x)
        expected_pdf = scipy_mvn.pdf(x)
        self.assertEqual(log_pdf.get_shape(), (3,))
        self.assertAllClose(expected_log_pdf, log_pdf.eval())
        self.assertAllClose(expected_pdf, pdf.eval())
      except ImportError as e:
        tf.logging.warn("Cannot test stats functions: %s" % str(e))
      self.assertEqual((3,), log_pdf.get_shape())
      self.assertEqual((3,), pdf.get_shape())

  def testLogPDFMultidimensional(self):
    with tf.Session():
      mu_v = np.array([-3.0, 3.0], dtype=np.float32)
      mu = tf.constant(np.vstack(15 * [mu_v]).reshape(3, 5, 2))
      sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
      sigma = tf.constant(
          np.vstack(15 * [sigma_v[np.newaxis, :]]).reshape(3, 5, 2, 2))
      x = np.array([-2.5, 2.5], dtype=np.float32)
      mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
      log_pdf = mvn.log_pdf(x)
      pdf = mvn.pdf(x)
      # scipy can't do batches, so just test one of them.
      scipy_mvn = stats.multivariate_normal(mean=mu[1, :], cov=sigma[1, :, :])
      expected_log_pdf = scipy_mvn.logpdf(x)
      expected_pdf = scipy_mvn.pdf(x)

      try:
        from scipy import stats  # pylint: disable=g-import-not-at-top
        scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
        expected_log_pdf = np.vstack(15 * [scipy_mvn.logpdf(x)]).reshape(3, 5)
        expected_pdf = np.vstack(15 * [scipy_mvn.pdf(x)]).reshape(3, 5)
        self.assertEqual(log_pdf.get_shape(), (3, 5))
        self.assertAllClose(expected_log_pdf, log_pdf.eval())
        self.assertAllClose(expected_pdf, pdf.eval())
      except ImportError as e:
        tf.logging.warn("Cannot test stats functions: %s" % str(e))
      self.assertAllClose(expected_log_pdf, log_pdf.eval()[1])
      self.assertAllClose(expected_pdf, pdf.eval()[1])

  def testEntropy(self):
    with tf.Session():
      mu_v = np.array([-3.0, 3.0], dtype=np.float32)
      mu = tf.constant(mu_v)
      sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
      sigma = tf.constant(sigma_v)
      mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
    with self.test_session():
      mu = self._rng.rand(2)
      chol, sigma = self._random_chol(2, 2)
      mvn = distributions.MultivariateNormalCholesky(mu, chol)
      entropy = mvn.entropy()

      try:
        from scipy import stats  # pylint: disable=g-import-not-at-top
        scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
        expected_entropy = scipy_mvn.entropy()
        self.assertEqual(entropy.get_shape(), ())
        self.assertAllClose(expected_entropy, entropy.eval())
      except ImportError as e:
        tf.logging.warn("Cannot test stats functions: %s" % str(e))
      scipy_mvn = stats.multivariate_normal(mean=mu, cov=sigma)
      expected_entropy = scipy_mvn.entropy()
      self.assertEqual(entropy.get_shape(), ())
      self.assertAllClose(expected_entropy, entropy.eval())

  def testEntropyMultidimensional(self):
    with tf.Session():
      mu_v = np.array([-3.0, 3.0], dtype=np.float32)
      mu = tf.constant(np.vstack(15 * [mu_v]).reshape(3, 5, 2))
      sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
      sigma = tf.constant(
          np.vstack(15 * [sigma_v[np.newaxis, :]]).reshape(3, 5, 2, 2))
      mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
    with self.test_session():
      mu = self._rng.rand(3, 5, 2)
      chol, sigma = self._random_chol(3, 5, 2, 2)
      mvn = distributions.MultivariateNormalCholesky(mu, chol)
      entropy = mvn.entropy()

      try:
        from scipy import stats  # pylint: disable=g-import-not-at-top
        scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
        expected_entropy = np.vstack(15 * [scipy_mvn.entropy()]).reshape(3, 5)
        self.assertEqual(entropy.get_shape(), (3, 5))
        self.assertAllClose(expected_entropy, entropy.eval())
      except ImportError as e:
        tf.logging.warn("Cannot test stats functions: %s" % str(e))
      # Scipy doesn't do batches, so test one of them.
      expected_entropy = stats.multivariate_normal(
          mean=mu[1, 1, :], cov=sigma[1, 1, :, :]).entropy()
      self.assertEqual(entropy.get_shape(), (3, 5))
      self.assertAllClose(expected_entropy, entropy.eval()[1, 1])

  def testSample(self):
    with tf.Session():
      mu_v = np.array([-3.0, 3.0], dtype=np.float32)
      mu = tf.constant(mu_v)
      sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
      sigma = tf.constant(sigma_v)
    with self.test_session():
      mu = self._rng.rand(2)
      chol, sigma = self._random_chol(2, 2)

      n = tf.constant(100000)
      mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
      mvn = distributions.MultivariateNormalCholesky(mu, chol)
      samples = mvn.sample(n, seed=137)
      sample_values = samples.eval()
      self.assertEqual(samples.get_shape(), (100000, 2))
      self.assertAllClose(sample_values.mean(axis=0), mu_v, atol=1e-2)
      self.assertAllClose(np.cov(sample_values, rowvar=0), sigma_v, atol=1e-1)
      self.assertAllClose(sample_values.mean(axis=0), mu, atol=1e-2)
      self.assertAllClose(np.cov(sample_values, rowvar=0), sigma, atol=1e-1)

  def testSampleMultiDimensional(self):
    with tf.Session():
      mu_v = np.array([-3.0, 3.0], dtype=np.float32)
      mu = tf.constant(np.vstack(15 * [mu_v]).reshape(3, 5, 2))
      sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
      sigma = tf.constant(
          np.vstack(15 * [sigma_v[np.newaxis, :]]).reshape(3, 5, 2, 2))
    with self.test_session():
      mu = self._rng.rand(3, 5, 2)
      chol, sigma = self._random_chol(3, 5, 2, 2)

      n = tf.constant(100000)
      mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
      mvn = distributions.MultivariateNormalCholesky(mu, chol)
      samples = mvn.sample(n, seed=137)
      sample_values = samples.eval()

      self.assertEqual(samples.get_shape(), (100000, 3, 5, 2))
      sample_values = sample_values.reshape(100000, 15, 2)
      for i in range(15):
        self.assertAllClose(
            sample_values[:, i, :].mean(axis=0), mu_v, atol=1e-2)
        self.assertAllClose(
            np.cov(sample_values[:, i, :], rowvar=0), sigma_v, atol=1e-1)
      self.assertAllClose(
          sample_values[:, 1, 1, :].mean(axis=0),
          mu[1, 1, :], atol=0.05)
      self.assertAllClose(
          np.cov(sample_values[:, 1, 1, :], rowvar=0),
          sigma[1, 1, :, :], atol=1e-1)


if __name__ == "__main__":
@@ -0,0 +1,327 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

distributions = tf.contrib.distributions


def softplus(x):
  return np.log(1 + np.exp(x))


class OperatorPDCholeskyTest(tf.test.TestCase):

  def setUp(self):
    self._rng = np.random.RandomState(42)

  def _random_cholesky_array(self, shape):
    mat = self._rng.rand(*shape)
    chol = distributions.batch_matrix_diag_transform(mat,
                                                     transform=tf.nn.softplus)
    # Zero the upper triangle because we're using this as a true Cholesky
    # factor in our tests.
    return tf.batch_matrix_band_part(chol, -1, 0).eval()

  def _numpy_inv_quadratic_form(self, chol, x):
    # Numpy works with batches now (calls them "stacks").
    x_expanded = np.expand_dims(x, -1)
    whitened = np.linalg.solve(chol, x_expanded)
    return (whitened**2).sum(axis=-1).sum(axis=-1)
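An aside on the quantity this helper computes: for `sigma = chol . chol^*`, the inverse quadratic form is `x^* . sigma^{-1} . x = ||chol^{-1} x||^2`, which is why the helper solves a triangular system and sums squares. A self-contained NumPy check of that identity:

```python
import numpy as np

rng = np.random.RandomState(0)
m = rng.rand(3, 3)
chol = np.tril(m) + 3 * np.eye(3)   # lower triangular, positive diagonal
sigma = chol.dot(chol.T)
x = rng.randn(3)

direct = x.dot(np.linalg.solve(sigma, x))   # x^T sigma^{-1} x
whitened = np.linalg.solve(chol, x)         # chol^{-1} x
via_chol = (whitened ** 2).sum()

assert np.allclose(direct, via_chol)
```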
  def test_inv_quadratic_form_x_rank_same_as_broadcast_rank(self):
    with self.test_session():
      for batch_shape in [(), (2,)]:
        for k in [1, 3]:

          x_shape = batch_shape + (k,)
          x = self._rng.randn(*x_shape)

          chol_shape = batch_shape + (k, k)
          chol = self._random_cholesky_array(chol_shape)
          operator = distributions.OperatorPDCholesky(chol)
          qf = operator.inv_quadratic_form(x)

          self.assertEqual(batch_shape, qf.get_shape())

          numpy_qf = self._numpy_inv_quadratic_form(chol, x)
          self.assertAllClose(numpy_qf, qf.eval())

  def test_inv_quadratic_form_x_and_chol_batch_shape_dont_match(self):
    # In this case, chol will have to be stretched to match x.
    with self.test_session():
      k = 3
      x_shape = (2, k)
      chol_shape = (1, k, k)
      broadcast_batch_shape = (2,)

      x = self._rng.randn(*x_shape)
      chol = self._random_cholesky_array(chol_shape)

      operator = distributions.OperatorPDCholesky(chol)
      qf = operator.inv_quadratic_form(x)

      self.assertEqual(broadcast_batch_shape, qf.get_shape())

      numpy_qf = self._numpy_inv_quadratic_form(chol, x)
      self.assertAllClose(numpy_qf, qf.eval())

  def test_inv_quadratic_form_x_rank_less_than_broadcast_rank(self):
    with self.test_session():
      for batch_shape in [(2,), (2, 3)]:
        for k in [1, 4]:

          # x will not have the leading dimension.
          x_shape = batch_shape[1:] + (k,)
          x = self._rng.randn(*x_shape)

          chol_shape = batch_shape + (k, k)
          chol = self._random_cholesky_array(chol_shape)
          operator = distributions.OperatorPDCholesky(chol)
          qf = operator.inv_quadratic_form(x)

          self.assertEqual(batch_shape, qf.get_shape())

          x_upshaped = x + np.zeros(chol.shape[:-1])
          numpy_qf = self._numpy_inv_quadratic_form(chol, x_upshaped)
          numpy_qf = numpy_qf.reshape(batch_shape)
          self.assertAllClose(numpy_qf, qf.eval())

  def test_inv_quadratic_form_x_rank_greater_than_broadcast_rank(self):
    with self.test_session():
      for batch_shape in [(2,), (2, 3)]:
        for k in [1, 4]:

          x_shape = batch_shape + (k,)
          x = self._rng.randn(*x_shape)

          # chol will not have the leading dimension.
          chol_shape = batch_shape[1:] + (k, k)
          chol = self._random_cholesky_array(chol_shape)
          operator = distributions.OperatorPDCholesky(chol)
          qf = operator.inv_quadratic_form(x)
          numpy_qf = self._numpy_inv_quadratic_form(chol, x)

          self.assertEqual(batch_shape, qf.get_shape())
          self.assertAllClose(numpy_qf, qf.eval())

  def test_inv_quadratic_form_x_rank_two_greater_than_broadcast_rank(self):
    with self.test_session():
      for batch_shape in [(2, 3), (2, 3, 4), (2, 3, 4, 5)]:
        for k in [1, 4]:

          x_shape = batch_shape + (k,)
          x = self._rng.randn(*x_shape)

          # chol will not have the leading two dimensions.
          chol_shape = batch_shape[2:] + (k, k)
          chol = self._random_cholesky_array(chol_shape)
          operator = distributions.OperatorPDCholesky(chol)
          qf = operator.inv_quadratic_form(x)
          numpy_qf = self._numpy_inv_quadratic_form(chol, x)

          self.assertEqual(batch_shape, qf.get_shape())
          self.assertAllClose(numpy_qf, qf.eval())

  def test_log_det(self):
    with self.test_session():
      batch_shape = ()
      for k in [1, 4]:
        chol_shape = batch_shape + (k, k)
        chol = self._random_cholesky_array(chol_shape)
        operator = distributions.OperatorPDCholesky(chol)
        log_det = operator.log_det()
        expected_log_det = np.log(np.prod(np.diag(chol))**2)

        self.assertEqual(batch_shape, log_det.get_shape())
        self.assertAllClose(expected_log_det, log_det.eval())
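The `expected_log_det` above uses the standard Cholesky determinant identity: `det(sigma) = det(chol)^2 = prod(diag(chol))^2`, so `log det(sigma) = 2 * sum(log(diag(chol)))`. A NumPy sanity check (editorial, not part of the commit):

```python
import numpy as np

rng = np.random.RandomState(1)
chol = np.tril(rng.rand(4, 4)) + 4 * np.eye(4)  # positive diagonal
sigma = chol.dot(chol.T)

stable = 2 * np.log(np.diag(chol)).sum()   # 2 * sum_i log(chol_ii)
direct = np.log(np.linalg.det(sigma))      # log det(chol . chol^T)

assert np.allclose(stable, direct)
```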
  def test_log_det_batch_matrix(self):
    with self.test_session():
      batch_shape = (2, 3)
      for k in [1, 4]:
        chol_shape = batch_shape + (k, k)
        chol = self._random_cholesky_array(chol_shape)
        operator = distributions.OperatorPDCholesky(chol)
        log_det = operator.log_det()

        self.assertEqual(batch_shape, log_det.get_shape())

        # Test the log-determinant of the [1, 1] matrix.
        chol_11 = chol[1, 1, :, :]
        expected_log_det = np.log(np.prod(np.diag(chol_11))**2)
        self.assertAllClose(expected_log_det, log_det.eval()[1, 1])

  def test_sqrt_matmul_single_matrix(self):
    with self.test_session():
      batch_shape = ()
      for k in [1, 4]:
        x_shape = batch_shape + (k, 3)
        x = self._rng.rand(*x_shape)
        chol_shape = batch_shape + (k, k)
        chol = self._random_cholesky_array(chol_shape)

        operator = distributions.OperatorPDCholesky(chol)

        sqrt_operator_times_x = operator.sqrt_matmul(x)
        expected = tf.batch_matmul(chol, x)

        self.assertEqual(expected.get_shape(),
                         sqrt_operator_times_x.get_shape())
        self.assertAllClose(expected.eval(), sqrt_operator_times_x.eval())

  def test_sqrt_matmul_batch_matrix(self):
    with self.test_session():
      batch_shape = (2, 3)
      for k in [1, 4]:
        x_shape = batch_shape + (k, 5)
        x = self._rng.rand(*x_shape)
        chol_shape = batch_shape + (k, k)
        chol = self._random_cholesky_array(chol_shape)

        operator = distributions.OperatorPDCholesky(chol)

        sqrt_operator_times_x = operator.sqrt_matmul(x)
        expected = tf.batch_matmul(chol, x)

        self.assertEqual(expected.get_shape(),
                         sqrt_operator_times_x.get_shape())
        self.assertAllClose(expected.eval(), sqrt_operator_times_x.eval())

  def test_matmul_batch_matrix(self):
    with self.test_session():
      batch_shape = (2, 3)
      for k in [1, 4]:
        x_shape = batch_shape + (k, 5)
        x = self._rng.rand(*x_shape)
        chol_shape = batch_shape + (k, k)
        chol = self._random_cholesky_array(chol_shape)

        operator = distributions.OperatorPDCholesky(chol)

        chol_times_x = tf.batch_matmul(chol, x, adj_x=True)
        expected = tf.batch_matmul(chol, chol_times_x)

        self.assertEqual(expected.get_shape(), operator.matmul(x).get_shape())
        self.assertAllClose(expected.eval(), operator.matmul(x).eval())

  def test_shape(self):
    # All other shapes are defined by the abstractmethod shape, so we only need
    # to test this.
    with self.test_session():
      for shape in [(3, 3), (2, 3, 3), (1, 2, 3, 3)]:
        chol = self._random_cholesky_array(shape)
        operator = distributions.OperatorPDCholesky(chol)
        self.assertAllEqual(shape, operator.shape().eval())

  def test_to_dense(self):
    with self.test_session():
      chol = self._random_cholesky_array((3, 3))
      operator = distributions.OperatorPDCholesky(chol)
      self.assertAllClose(chol.dot(chol.T), operator.to_dense().eval())

  def test_to_dense_sqrt(self):
    with self.test_session():
      chol = self._random_cholesky_array((2, 3, 3))
      operator = distributions.OperatorPDCholesky(chol)
      self.assertAllClose(chol, operator.to_dense_sqrt().eval())

  def test_non_positive_definite_matrix_raises(self):
    # Singular matrix with one positive eigenvalue and one zero eigenvalue.
    with self.test_session():
      lower_mat = [[1.0, 0.0], [2.0, 0.0]]
      operator = distributions.OperatorPDCholesky(lower_mat)
      with self.assertRaisesOpError('x > 0 did not hold'):
        operator.to_dense().eval()

  def test_non_positive_definite_matrix_does_not_raise_if_not_verify_pd(self):
    # Singular matrix with one positive eigenvalue and one zero eigenvalue.
    with self.test_session():
      lower_mat = [[1.0, 0.0], [2.0, 0.0]]
      operator = distributions.OperatorPDCholesky(lower_mat, verify_pd=False)
      operator.to_dense().eval()  # Should not raise.

  def test_not_having_two_identical_last_dims_raises(self):
    # Unless the last two dims are equal, this cannot represent a matrix, and
    # it should raise.
    with self.test_session():
      batch_vec = [[1.0], [2.0]]  # shape 2 x 1
      with self.assertRaisesRegexp(ValueError, '.*Dimensions.*'):
        operator = distributions.OperatorPDCholesky(batch_vec)
        operator.to_dense().eval()


class BatchMatrixDiagTransformTest(tf.test.TestCase):

  def setUp(self):
    self._rng = np.random.RandomState(0)

  def check_off_diagonal_same(self, m1, m2):
    """Check that the strictly lower and upper triangular parts agree."""
    self.assertAllClose(np.tril(m1, k=-1), np.tril(m2, k=-1))
    self.assertAllClose(np.triu(m1, k=1), np.triu(m2, k=1))

  def test_non_batch_matrix_with_transform(self):
    mat = self._rng.rand(4, 4)
    with self.test_session():
      chol = distributions.batch_matrix_diag_transform(mat,
                                                       transform=tf.nn.softplus)
      self.assertEqual((4, 4), chol.get_shape())

      self.check_off_diagonal_same(mat, chol.eval())
      self.assertAllClose(softplus(np.diag(mat)), np.diag(chol.eval()))

  def test_non_batch_matrix_no_transform(self):
    mat = self._rng.rand(4, 4)
    with self.test_session():
      # Default is no transform.
      chol = distributions.batch_matrix_diag_transform(mat)
      self.assertEqual((4, 4), chol.get_shape())
      self.assertAllClose(mat, chol.eval())

  def test_batch_matrix_with_transform(self):
    mat = self._rng.rand(2, 4, 4)
    mat_0 = mat[0, :, :]
    with self.test_session():
      chol = distributions.batch_matrix_diag_transform(mat,
                                                       transform=tf.nn.softplus)

      self.assertEqual((2, 4, 4), chol.get_shape())

      chol_0 = chol.eval()[0, :, :]

      self.check_off_diagonal_same(mat_0, chol_0)
      self.assertAllClose(softplus(np.diag(mat_0)), np.diag(chol_0))

  def test_batch_matrix_no_transform(self):
    mat = self._rng.rand(2, 4, 4)
    with self.test_session():
      # Default is no transform.
      chol = distributions.batch_matrix_diag_transform(mat)

      self.assertEqual((2, 4, 4), chol.get_shape())
      self.assertAllClose(mat, chol.eval())


if __name__ == '__main__':
  tf.test.main()
@@ -0,0 +1,62 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

distributions = tf.contrib.distributions


class OperatorPDFullTest(tf.test.TestCase):
  # The only method needing to be checked (because it isn't part of the parent
  # class) is the check for symmetry.

  def setUp(self):
    self._rng = np.random.RandomState(42)

  def _random_positive_def_array(self, *shape):
    matrix = self._rng.rand(*shape)
    return tf.batch_matmul(matrix, matrix, adj_y=True).eval()
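This helper leans on the fact that `m . m^*` is always symmetric positive semidefinite, and is almost surely strictly positive definite when `m` is a square matrix of independent random entries. A quick NumPy confirmation (editorial, not part of the commit):

```python
import numpy as np

rng = np.random.RandomState(42)
m = rng.rand(3, 3)
sigma = m.dot(m.T)   # symmetric by construction

assert np.allclose(sigma, sigma.T)
assert np.all(np.linalg.eigvalsh(sigma) > 0)  # a.s. strictly PD for random m
```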
  def test_positive_definite_matrix_doesnt_raise(self):
    with self.test_session():
      matrix = self._random_positive_def_array(2, 3, 3)
      operator = distributions.OperatorPDFull(matrix, verify_pd=True)
      operator.to_dense().eval()  # Should not raise

  def test_negative_definite_matrix_raises(self):
    with self.test_session():
      matrix = -1 * self._random_positive_def_array(3, 2, 2)
      operator = distributions.OperatorPDFull(matrix, verify_pd=True)
      # Could fail inside Cholesky decomposition, or later when we test the
      # diag.
      with self.assertRaisesOpError('x > 0|LLT'):
        operator.to_dense().eval()

  def test_non_symmetric_matrix_raises(self):
    with self.test_session():
      matrix = self._random_positive_def_array(3, 2, 2)
      matrix[0, 0, 1] += 0.001
      operator = distributions.OperatorPDFull(matrix, verify_pd=True)
      with self.assertRaisesOpError('x == y'):
        operator.to_dense().eval()


if __name__ == '__main__':
  tf.test.main()
@@ -0,0 +1,83 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

distributions = tf.contrib.distributions


class BaseCovarianceTest(tf.test.TestCase):

  def test_all_shapes_methods_defined_by_the_one_abstractproperty_shape(self):

    class OperatorShape(distributions.OperatorPDBase):
      """Operator implementing only the ABC method .shape."""

      def __init__(self, shape):
        self._shape = shape

      @property
      def verify_pd(self):
        return True

      def get_shape(self):
        return tf.TensorShape(self._shape)

      def shape(self, name='shape'):
        return tf.shape(np.random.rand(*self._shape))

      @property
      def name(self):
        return 'OperatorShape'

      def dtype(self):
        return tf.int32

      def inv_quadratic_form(
          self, x, name='inv_quadratic_form'):
        return x

      def log_det(self, name='log_det'):
        raise NotImplementedError()

      @property
      def inputs(self):
        return []

      def sqrt_matmul(self, x, name='sqrt_matmul'):
        return x

    shape = (1, 2, 3, 3)
    with self.test_session():
      operator = OperatorShape(shape)

      self.assertAllEqual(shape, operator.shape().eval())
      self.assertAllEqual(4, operator.rank().eval())
      self.assertAllEqual((1, 2), operator.batch_shape().eval())
      self.assertAllEqual((1, 2, 3), operator.vector_shape().eval())
      self.assertAllEqual(3, operator.vector_space_dimension().eval())

      self.assertEqual(shape, operator.get_shape())
      self.assertEqual((1, 2), operator.get_batch_shape())
      self.assertEqual((1, 2, 3), operator.get_vector_shape())


if __name__ == '__main__':
  tf.test.main()
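For readers skimming the test above, the shape conventions it encodes are worth spelling out (editorial summary): for an operator of shape `[N1,...,Nb, k, k]`, `rank` is the length of the shape, `batch_shape` is everything but the last two dimensions, `vector_shape` drops only the last dimension, and `vector_space_dimension` is `k`. In plain Python terms:

```python
shape = (1, 2, 3, 3)                 # [N1,...,Nb, k, k] with b = 2, k = 3

rank = len(shape)                    # 4
batch_shape = shape[:-2]             # (1, 2)
vector_shape = shape[:-1]            # (1, 2, 3)
vector_space_dimension = shape[-1]   # 3
```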
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Multivariate Normal distribution class."""
"""Multivariate Normal distribution classes."""

from __future__ import absolute_import
from __future__ import division
@@ -20,83 +20,33 @@ from __future__ import print_function

import math

from tensorflow.contrib.distributions.python.ops import distribution  # pylint: disable=line-too-long
from tensorflow.contrib.distributions.python.ops import operator_pd_cholesky  # pylint: disable=line-too-long
from tensorflow.contrib.distributions.python.ops import operator_pd_full  # pylint: disable=line-too-long
from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util  # pylint: disable=line-too-long
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops


def _assert_compatible_shapes(mu, sigma):
  r_mu = array_ops.rank(mu)
  r_sigma = array_ops.rank(sigma)
  sigma_shape = array_ops.shape(sigma)
  sigma_rank = array_ops.rank(sigma)
  mu_shape = array_ops.shape(mu)
  return control_flow_ops.group(
      logging_ops.Assert(
          math_ops.equal(r_mu + 1, r_sigma),
          ["Rank of mu should be one less than rank of sigma, but saw: ",
           r_mu, " vs. ", r_sigma]),
      logging_ops.Assert(
          math_ops.equal(
              array_ops.gather(sigma_shape, sigma_rank - 2),
              array_ops.gather(sigma_shape, sigma_rank - 1)),
          ["Last two dimensions of sigma (%s) must be equal: " % sigma.name,
           sigma_shape]),
      logging_ops.Assert(
          math_ops.reduce_all(math_ops.equal(
              mu_shape,
              array_ops.slice(
                  sigma_shape, [0], array_ops.pack([sigma_rank - 1])))),
          ["mu.shape and sigma.shape[:-1] must match, but saw: ",
           mu_shape, " vs. ", sigma_shape]))
__all__ = [
    "MultivariateNormalCholesky",
    "MultivariateNormalFull",
]


def _assert_batch_positive_definite(sigma_chol):
  """Add assertions checking that the sigmas are all Positive Definite.
class MultivariateNormalOperatorPD(distribution.ContinuousDistribution):
  """The multivariate normal distribution on `R^k`.

  Given `sigma_chol == cholesky(sigma)`, it is sufficient to check that
  `all(diag(sigma_chol) > 0)`. This is because to check that a matrix is PD,
  it is sufficient that its cholesky factorization is PD, and to check that a
  triangular matrix is PD, it is sufficient to check that its diagonal
  entries are positive.

  Args:
    sigma_chol: N-D. The lower triangular cholesky decomposition of `sigma`.

  Returns:
    An assertion op to use with `control_dependencies`, verifying that
    `sigma_chol` is positive definite.
  """
  sigma_batch_diag = array_ops.batch_matrix_diag_part(sigma_chol)
  return logging_ops.Assert(
      math_ops.reduce_all(sigma_batch_diag > 0),
      ["sigma_chol is not positive definite. batched diagonals: ",
       sigma_batch_diag, " shaped: ", array_ops.shape(sigma_batch_diag)])


def _log_determinant_from_sigma_chol(sigma_chol):
  det_last_dim = array_ops.rank(sigma_chol) - 2
  sigma_batch_diag = array_ops.batch_matrix_diag_part(sigma_chol)
  log_det = 2.0 * math_ops.reduce_sum(
      math_ops.log(sigma_batch_diag), reduction_indices=det_last_dim)
  log_det.set_shape(sigma_chol.get_shape()[:-2])
  return log_det


class MultivariateNormal(object):
  """The Multivariate Normal distribution on `R^k`.

  The distribution has mean and covariance parameters mu (1-D), sigma (2-D),
  or alternatively mean `mu` and factored covariance (cholesky decomposed
  `sigma`) called `sigma_chol`.
  This distribution is defined by a 1-D mean `mu` and an instance of
  `OperatorPDBase`, which provides access to a symmetric positive definite
  operator defining the covariance.

  #### Mathematical details

@@ -108,21 +58,6 @@ class MultivariateNormal(object):

  where `.` denotes the inner product on `R^k` and `^*` denotes transpose.

  Alternatively, if `sigma` is positive definite, it can be represented in
  terms of its lower triangular cholesky factorization

  ```sigma = sigma_chol . sigma_chol^*```

  and the pdf above allows simpler computation:

  ```
  |det(sigma)| = reduce_prod(diag(sigma_chol))^2
  x_whitened = sigma^{-1/2} . (x - mu) = tri_solve(sigma_chol, x - mu)
  (x-mu)^* . sigma^{-1} . (x-mu) = x_whitened^* . x_whitened
  ```

  where `tri_solve()` solves a triangular system of equations.
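  Putting these identities together, the log-density that the Cholesky
  representation makes cheap to evaluate is (an editorial restatement of
  standard multivariate normal algebra, in the same pseudocode notation):

  ```
  log pdf(x) = -0.5 * (k * log(2 * pi)
                       + log(|det(sigma)|)
                       + x_whitened^* . x_whitened)
  ```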
#### Examples
|
||||
|
||||
A single multi-variate Gaussian distribution is defined by a vector of means
|
||||
@ -131,128 +66,177 @@ class MultivariateNormal(object):
|
||||
Extra leading dimensions, if provided, allow for batches.
|
||||
|
||||
```python
|
||||
# Initialize a single 3-variate Gaussian with diagonal covariance.
|
||||
# Initialize a single 3-variate Gaussian.
|
||||
mu = [1, 2, 3]
|
||||
sigma = [[1, 0, 0], [0, 3, 0], [0, 0, 2]]
|
||||
dist = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
|
||||
chol = [[1, 0, 0.], [1, 3, 0], [1, 2, 3]]
|
||||
cov = tf.contrib.distributions.OperatorPDCholesky(chol)
|
||||
dist = tf.contrib.distributions.MultivariateNormalOperatorPD(mu, cov)
|
||||
|
||||
# Evaluate this on an observation in R^3, returning a scalar.
|
||||
dist.pdf([-1, 0, 1])
|
||||
dist.pdf([-1, 0, 1.])
|
||||
|
||||
# Initialize a batch of two 3-variate Gaussians.
|
||||
mu = [[1, 2, 3], [11, 22, 33]]
|
||||
sigma = ... # shape 2 x 3 x 3
|
||||
dist = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
|
||||
mu = [[1, 2, 3], [11, 22, 33.]]
|
||||
chol = ... # shape 2 x 3 x 3, lower triangular, positive diagonal.
|
||||
cov = tf.contrib.distributions.OperatorPDCholesky(chol)
|
||||
dist = tf.contrib.distributions.MultivariateNormalOperatorPD(mu, cov)
|
||||
|
||||
# Evaluate this on a two observations, each in R^3, returning a length two
|
||||
# tensor.
|
||||
x = [[-1, 0, 1], [-11, 0, 11]] # Shape 2 x 3.
|
||||
x = [[-1, 0, 1], [-11, 0, 11.]] # Shape 2 x 3.
|
||||
dist.pdf(x)
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, mu, sigma=None, sigma_chol=None, name=None):
|
||||
def __init__(
|
||||
self,
|
||||
mu,
|
||||
cov,
|
||||
allow_nan=False,
|
||||
strict=True,
|
||||
strict_statistics=True,
|
||||
name="MultivariateNormalCov"):
|
||||
"""Multivariate Normal distributions on `R^k`.
|
||||
|
||||
User must provide means `mu`, which are tensors of rank `N+1` (`N >= 0`)
|
||||
with the last dimension having length `k`.
|
||||
|
||||
User must provide exactly one of `sigma` (the covariance matrices) or
|
||||
`sigma_chol` (the cholesky decompositions of the covariance matrices).
|
||||
`sigma` or `sigma_chol` must be of rank `N+2`. The last two dimensions
|
||||
must both have length `k`. The first `N` dimensions correspond to batch
|
||||
indices.
|
||||
|
||||
If `sigma_chol` is not provided, the batch cholesky factorization of `sigma`
|
||||
is calculated for you.
|
||||
|
||||
The shapes of `mu` and `sigma` must match for the first `N` dimensions.
|
||||
|
||||
Regardless of which parameter is provided, the covariance matrices must all
|
||||
be **positive definite** (an error is raised if one of them is not).
|
||||
User must provide means `mu`, and an instance of `OperatorPDBase`, `cov`,
|
||||
which determines the covariance.
|
||||
|
||||
Args:
|
||||
mu: (N+1)-D. `float` or `double` tensor, the means of the distributions.
|
||||
sigma: (N+2)-D. (optional) `float` or `double` tensor, the covariances
|
||||
of the distribution(s). The first `N+1` dimensions must match
|
||||
those of `mu`. Must be batch-positive-definite.
|
||||
sigma_chol: (N+2)-D. (optional) `float` or `double` tensor, a
|
||||
lower-triangular factorization of `sigma`
|
||||
(`sigma = sigma_chol . sigma_chol^*`). The first `N+1` dimensions
|
||||
must match those of `mu`. The tensor itself need not be batch
|
||||
lower triangular: we ignore the upper triangular part. However,
|
||||
the batch diagonals must be positive (i.e., sigma_chol must be
|
||||
batch-positive-definite).
|
||||
mu: `float` or `double` tensor with shape `[N1,...,Nb, k]`, `b >= 0`.
|
||||
cov: `float` or `double` instance of `OperatorPDBase` with same `dtype`
|
||||
as `mu` and shape `[N1,...,Nb, k, k]`.
|
||||
allow_nan: Boolean, default False. If False, raise an exception if
|
||||
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
|
||||
If True, batch members with valid parameters leading to undefined
|
||||
statistics will return NaN for this statistic.
|
||||
strict: Whether to validate input with asserts. If `strict` is `False`,
|
||||
and the inputs are invalid, correct behavior is not guaranteed.
|
||||
strict_statistics: Boolean, default True. If True, raise an exception if
|
||||
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
|
||||
If False, batch members with valid parameters leading to undefined
|
||||
statistics will return NaN for this statistic.
|
||||
name: The name to give Ops created by the initializer.
|
||||
|
||||
Raises:
|
||||
ValueError: if neither sigma nor sigma_chol is provided.
|
||||
TypeError: if mu and sigma (resp. sigma_chol) are different dtypes.
|
||||
TypeError: If `mu` and `cov` are different dtypes.
|
||||
"""
|
||||
if (sigma is None) == (sigma_chol is None):
|
||||
raise ValueError("Exactly one of sigma and sigma_chol must be provided")
|
||||
self._strict_statistics = strict_statistics
|
||||
self._strict = strict
|
||||
with ops.name_scope(name):
|
||||
with ops.op_scope([mu] + cov.inputs, "init"):
|
||||
self._cov = cov
|
||||
self._mu = self._check_mu(mu)
|
||||
self._name = name
|
||||
|
||||
with ops.op_scope([mu, sigma, sigma_chol], name, "MultivariateNormal"):
|
||||
sigma_or_half = sigma_chol if sigma is None else sigma
|
||||
def _check_mu(self, mu):
|
||||
"""Return `mu` after validity checks and possibly with assertations."""
|
||||
mu = ops.convert_to_tensor(mu)
|
||||
cov = self._cov
|
||||
|
||||
mu = ops.convert_to_tensor(mu)
|
||||
sigma_or_half = ops.convert_to_tensor(sigma_or_half)
|
||||
if mu.dtype != cov.dtype:
|
||||
raise TypeError(
|
||||
"mu and cov must have the same dtype. Found mu.dtype = %s, "
|
||||
"cov.dtype = %s"
|
||||
% (mu.dtype, cov.dtype))
|
||||
if not self.strict:
|
||||
return mu
|
||||
else:
|
||||
assert_compatible_shapes = control_flow_ops.group(
|
||||
check_ops.assert_equal(
|
||||
array_ops.rank(mu) + 1,
|
||||
cov.rank(),
|
||||
data=["mu should have rank 1 less than cov. Found: rank(mu) = ",
|
||||
array_ops.rank(mu), " rank(cov) = ", cov.rank()],
|
||||
),
|
||||
check_ops.assert_equal(
|
||||
array_ops.shape(mu),
|
||||
cov.vector_shape(),
|
||||
data=["mu.shape and cov.shape[:-1] should match. "
|
||||
"Found: shape(mu) = "
|
||||
, array_ops.shape(mu), " shape(cov) = ", cov.shape()],
|
||||
),
|
||||
)
|
||||
return control_flow_ops.with_dependencies([assert_compatible_shapes], mu)
|
||||
|
||||
contrib_tensor_util.assert_same_float_dtype((mu, sigma_or_half))
|
||||
@property
|
||||
def strict(self):
|
||||
"""Boolean describing behavior on invalid input."""
|
||||
return self._strict
|
||||
|
||||
with ops.control_dependencies([
|
||||
_assert_compatible_shapes(mu, sigma_or_half)]):
|
||||
mu = array_ops.identity(mu, name="mu")
|
||||
|
||||
# Store the dimensionality of the MVNs
|
||||
self._k = array_ops.gather(array_ops.shape(mu), array_ops.rank(mu) - 1)
|
||||
|
||||
if sigma_chol is not None:
|
||||
# Ensure we only keep the lower triangular part.
|
||||
sigma_chol = array_ops.batch_matrix_band_part(
|
||||
sigma_chol, num_lower=-1, num_upper=0)
|
||||
log_sigma_det = _log_determinant_from_sigma_chol(sigma_chol)
|
||||
with ops.control_dependencies([
|
||||
_assert_batch_positive_definite(sigma_chol)]):
|
||||
self._sigma = math_ops.batch_matmul(
|
||||
sigma_chol, sigma_chol, adj_y=True, name="sigma")
|
||||
self._sigma_chol = array_ops.identity(sigma_chol, "sigma_chol")
|
||||
self._log_sigma_det = array_ops.identity(
|
||||
log_sigma_det, "log_sigma_det")
|
||||
self._mu = array_ops.identity(mu, "mu")
|
||||
else: # sigma is not None
|
||||
sigma_chol = linalg_ops.batch_cholesky(sigma)
|
||||
log_sigma_det = _log_determinant_from_sigma_chol(sigma_chol)
|
||||
# batch_cholesky checks for PSD; so we can just use it here.
|
||||
with ops.control_dependencies([sigma_chol]):
|
||||
self._sigma = array_ops.identity(sigma, "sigma")
|
||||
self._sigma_chol = array_ops.identity(sigma_chol, "sigma_chol")
|
||||
self._log_sigma_det = array_ops.identity(
|
||||
log_sigma_det, "log_sigma_det")
|
||||
self._mu = array_ops.identity(mu, "mu")
|
||||
@property
|
||||
def strict_statistics(self):
|
||||
"""Boolean describing behavior when a stat is undefined for batch member."""
|
||||
return self._strict_statistics
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self._mu.dtype
|
||||
|
||||
def get_event_shape(self):
|
||||
"""`TensorShape` available at graph construction time."""
|
||||
# Recall _check_mu ensures mu and self._cov have same batch shape.
|
||||
return self._cov.get_shape()[-1:]
|
||||
|
||||
def event_shape(self, name="event_shape"):
|
||||
"""Shape of a sample from a single distribution as a 1-D int32 `Tensor`."""
|
||||
# Recall _check_mu ensures mu and self._cov have same batch shape.
|
||||
with ops.name_scope(self.name):
|
||||
with ops.op_scope(self._cov.inputs, name):
|
||||
return array_ops.pack([self._cov.vector_space_dimension()])
|
||||
|
||||
def batch_shape(self, name="batch_shape"):
|
||||
"""Batch dimensions of this instance as a 1-D int32 `Tensor`."""
|
||||
# Recall _check_mu ensures mu and self._cov have same batch shape.
|
||||
with ops.name_scope(self.name):
|
||||
with ops.op_scope(self._cov.inputs, name):
|
||||
return self._cov.batch_shape()
|
||||
|
||||
def get_batch_shape(self):
|
||||
"""`TensorShape` available at graph construction time."""
|
||||
# Recall _check_mu ensures mu and self._cov have same batch shape.
|
||||
return self._cov.get_batch_shape()
|
||||
|
||||
@property
|
||||
def mu(self):
|
||||
return self._mu
|
||||
|
||||
@property
|
||||
def sigma(self):
|
||||
return self._sigma
|
||||
"""Dense (batch) covariance matrix, if available."""
|
||||
with ops.name_scope(self.name):
|
||||
return self._cov.to_dense()
|
||||
|
||||
@property
|
||||
def mean(self):
|
||||
return self._mu
|
||||
def mean(self, name="mean"):
|
||||
"""Mean of each batch member."""
|
||||
with ops.name_scope(self.name):
|
||||
with ops.op_scope([self._mu], name):
|
||||
return array_ops.identity(self._mu)
|
||||
|
||||
@property
|
||||
def sigma_det(self):
|
||||
return math_ops.exp(self._log_sigma_det)
|
||||
def mode(self, name="mode"):
|
||||
"""Mode of each batch member."""
|
||||
with ops.name_scope(self.name):
|
||||
with ops.op_scope([self._mu], name):
|
||||
return array_ops.identity(self._mu)
|
||||
|
||||
def log_pdf(self, x, name=None):
|
||||
def variance(self, name="variance"):
|
||||
"""Variance of each batch member."""
|
||||
with ops.name_scope(self.name):
|
||||
return self.sigma
|
||||
|
||||
def log_sigma_det(self, name="log_sigma_det"):
|
||||
"""Log of determinant of covariance matrix."""
|
||||
with ops.name_scope(self.name):
|
||||
with ops.op_scope(self._cov.inputs, name):
|
||||
return self._cov.log_det()
|
||||
|
||||
def sigma_det(self, name="sigma_det"):
|
||||
"""Determinant of covariance matrix."""
|
||||
with ops.name_scope(self.name):
|
||||
with ops.op_scope(self._cov.inputs, name):
|
||||
return math_ops.exp(self._cov.log_det())
|
||||
|
||||
def log_pdf(self, x, name="log_pdf"):
|
||||
"""Log pdf of observations `x` given these Multivariate Normals.
|
||||
|
||||
Args:
|
||||
@ -262,120 +246,25 @@ class MultivariateNormal(object):
|
||||
Returns:
|
||||
      log_pdf: tensor of dtype `dtype`, the log-PDFs of `x`.
    """
    with ops.op_scope(
        [self._mu, self._sigma_chol, x], name, "MultivariateNormalLogPdf"):
      x = ops.convert_to_tensor(x)
      contrib_tensor_util.assert_same_float_dtype((self._mu, x))
    with ops.name_scope(self.name):
      with ops.op_scope([self._mu, x] + self._cov.inputs, name):
        x = ops.convert_to_tensor(x)
        contrib_tensor_util.assert_same_float_dtype((self._mu, x))

      x_centered = x - self.mu
        x_centered = x - self.mu
        x_whitened_norm = self._cov.inv_quadratic_form(x_centered)
        log_sigma_det = self.log_sigma_det()

      x_rank = array_ops.rank(x_centered)
      sigma_rank = array_ops.rank(self._sigma_chol)
        log_two_pi = constant_op.constant(
            math.log(2 * math.pi), dtype=self.dtype)
        k = math_ops.cast(self._cov.vector_space_dimension(), self.dtype)
        log_pdf_value = -(log_sigma_det + k * log_two_pi + x_whitened_norm) / 2

      x_rank_vec = array_ops.pack([x_rank])
      sigma_rank_vec = array_ops.pack([sigma_rank])
      x_shape = array_ops.shape(x_centered)
        output_static_shape = x_centered.get_shape()[:-1]
        log_pdf_value.set_shape(output_static_shape)
        return log_pdf_value

      # sigma_chol is shaped [D, E, F, ..., k, k]
      # x_centered shape is one of:
      #   [D, E, F, ..., k], or [F, ..., k], or
      #   [A, B, C, D, E, F, ..., k]
      # and we need to convert x_centered to shape:
      #   [D, E, F, ..., k, A*B*C] (or 1 if A, B, C don't exist)
      # then transpose and reshape x_whitened back to one of the shapes:
      #   [D, E, F, ..., k], or [1, 1, F, ..., k], or
      #   [A, B, C, D, E, F, ..., k]
      # Note that if rank(x) <= rank(sigma) - 1, the first dimensions of
      # x_centered will match sigma exactly because x_centered = x - self.mu.

      # This helper handles the case where rank(x_centered) < rank(sigma).
      def _broadcast_x_not_higher_rank_than_sigma():
        return array_ops.reshape(
            x_centered,
            array_ops.concat(
                # Reshape to ones(deficient x rank) + x_shape + [1]
                0, (array_ops.ones(array_ops.pack([sigma_rank - x_rank - 1]),
                                   dtype=x_rank.dtype),
                    x_shape,
                    [1])))

      # These helpers handle the case where rank(x_centered) >= rank(sigma).
      def _broadcast_x_higher_rank_than_sigma():
        x_shape_left = array_ops.slice(
            x_shape, [0], sigma_rank_vec - 1)
        x_shape_right = array_ops.slice(
            x_shape, sigma_rank_vec - 1, x_rank_vec - 1)
        x_shape_perm = array_ops.concat(
            0, (math_ops.range(sigma_rank - 1, x_rank),
                math_ops.range(0, sigma_rank - 1)))
        return array_ops.reshape(
            # Convert to [D, E, F, ..., k, B, C]
            array_ops.transpose(
                x_centered, perm=x_shape_perm),
            # Reshape to [D, E, F, ..., k, B*C]
            array_ops.concat(
                0, (x_shape_right,
                    array_ops.pack([
                        math_ops.reduce_prod(x_shape_left, 0)]))))

      def _unbroadcast_x_higher_rank_than_sigma():
        x_shape_left = array_ops.slice(
            x_shape, [0], sigma_rank_vec - 1)
        x_shape_right = array_ops.slice(
            x_shape, sigma_rank_vec - 1, x_rank_vec - 1)
        x_shape_perm = array_ops.concat(
            0, (math_ops.range(sigma_rank - 1, x_rank),
                math_ops.range(0, sigma_rank - 1)))
        return array_ops.transpose(
            # [D, E, F, ..., k, B, C] => [B, C, D, E, F, ..., k]
            array_ops.reshape(
                # convert to [D, E, F, ..., k, B, C]
                x_whitened_broadcast,
                array_ops.concat(0, (x_shape_right, x_shape_left))),
            perm=x_shape_perm)

      # Step 1: reshape x_centered
      x_centered_broadcast = control_flow_ops.cond(
          # x_centered == [D, E, F, ..., k] => [D, E, F, ..., k, 1]
          # or         == [F, ..., k] => [1, 1, F, ..., k, 1]
          x_rank <= sigma_rank - 1,
          _broadcast_x_not_higher_rank_than_sigma,
          # x_centered == [B, C, D, E, F, ..., k] => [D, E, F, ..., k, B*C]
          _broadcast_x_higher_rank_than_sigma)

      x_whitened_broadcast = linalg_ops.batch_matrix_triangular_solve(
          self._sigma_chol, x_centered_broadcast)

      # Reshape x_whitened_broadcast back to x_whitened
      x_whitened = control_flow_ops.cond(
          x_rank <= sigma_rank - 1,
          lambda: array_ops.reshape(x_whitened_broadcast, x_shape),
          _unbroadcast_x_higher_rank_than_sigma)

      x_whitened = array_ops.expand_dims(x_whitened, -1)
      # Reshape x_whitened to contain row vectors
      # Returns a batchwise scalar
      x_whitened_norm = math_ops.batch_matmul(
          x_whitened, x_whitened, adj_x=True)
      x_whitened_norm = control_flow_ops.cond(
          x_rank <= sigma_rank - 1,
          lambda: array_ops.squeeze(x_whitened_norm, [-2, -1]),
          lambda: array_ops.squeeze(x_whitened_norm, [-1]))

      log_two_pi = constant_op.constant(math.log(2 * math.pi), dtype=self.dtype)
      k = math_ops.cast(self._k, self.dtype)
      log_pdf_value = (
          -self._log_sigma_det - (k * log_two_pi) - x_whitened_norm) / 2
      final_shaped_value = control_flow_ops.cond(
          x_rank <= sigma_rank - 1,
          lambda: log_pdf_value,
          lambda: array_ops.squeeze(log_pdf_value, [-1]))

      output_static_shape = x_centered.get_shape()[:-1]
      final_shaped_value.set_shape(output_static_shape)
      return final_shaped_value

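Both the removed and the added code paths above evaluate the same quantity,
`log f(x) = -1/2 * (k*log(2*pi) + log|sigma| + (x-mu)^T sigma^{-1} (x-mu))`;
the new path simply delegates the determinant and the quadratic form to the
covariance operator. A minimal NumPy sketch of that identity (names and
shapes here are illustrative only, not part of this commit):

```python
# Illustrative only: the Cholesky-based log-pdf matches the dense formula.
import numpy as np

k = 3
rng = np.random.RandomState(0)
chol = np.tril(rng.rand(k, k)) + np.eye(k)  # lower triangular, positive diag
sigma = chol.dot(chol.T)
mu = rng.rand(k)
x = rng.rand(k)

# Whitening: solve L z = (x - mu), so z.dot(z) = (x-mu)^T sigma^{-1} (x-mu).
z = np.linalg.solve(chol, x - mu)
log_sigma_det = 2.0 * np.sum(np.log(np.diag(chol)))
log_pdf = -0.5 * (k * np.log(2 * np.pi) + log_sigma_det + z.dot(z))

# Direct dense evaluation of the same density.
quad = (x - mu).dot(np.linalg.solve(sigma, x - mu))
log_pdf_dense = -0.5 * (k * np.log(2 * np.pi)
                        + np.log(np.linalg.det(sigma)) + quad)
np.testing.assert_allclose(log_pdf, log_pdf_dense)
```
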
  def pdf(self, x, name=None):
  def pdf(self, x, name="pdf"):
    """The PDF of observations `x` under these Multivariate Normals.

    Args:
@ -385,11 +274,11 @@ class MultivariateNormal(object):
    Returns:
      pdf: tensor of dtype `dtype`, the pdf values of `x`.
    """
    with ops.op_scope(
        [self._mu, self._sigma_chol, x], name, "MultivariateNormalPdf"):
      return math_ops.exp(self.log_pdf(x))
    with ops.name_scope(self.name):
      with ops.op_scope([self._mu, x] + self._cov.inputs, name):
        return math_ops.exp(self.log_pdf(x))

  def entropy(self, name=None):
  def entropy(self, name="entropy"):
    """The entropies of these Multivariate Normals.

    Args:
@ -398,19 +287,19 @@ class MultivariateNormal(object):
    Returns:
      entropy: tensor of dtype `dtype`, the entropies.
    """
    with ops.op_scope(
        [self._mu, self._sigma_chol], name, "MultivariateNormalEntropy"):
      one_plus_log_two_pi = constant_op.constant(
          1 + math.log(2 * math.pi), dtype=self.dtype)
    with ops.name_scope(self.name):
      with ops.op_scope([self._mu] + self._cov.inputs, name):
        log_sigma_det = self.log_sigma_det()
        one_plus_log_two_pi = constant_op.constant(1 + math.log(2 * math.pi),
                                                   dtype=self.dtype)

      # Use broadcasting rules to calculate the full broadcast sigma.
      k = math_ops.cast(self._k, dtype=self.dtype)
      entropy_value = (
          k * one_plus_log_two_pi + self._log_sigma_det) / 2
      entropy_value.set_shape(self._log_sigma_det.get_shape())
      return entropy_value
        # Use broadcasting rules to calculate the full broadcast sigma.
        k = math_ops.cast(self._cov.vector_space_dimension(), dtype=self.dtype)
        entropy_value = (k * one_plus_log_two_pi + log_sigma_det) / 2
        entropy_value.set_shape(log_sigma_det.get_shape())
        return entropy_value

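The closed form computed above is `H = (k * (1 + log(2*pi)) + log|sigma|) / 2`.
A quick sketch checking it against a Monte Carlo estimate (all names here are
illustrative, not part of the commit):

```python
# Illustrative only: entropy closed form vs. H ~= -mean(log_pdf(samples)).
import numpy as np

k = 2
rng = np.random.RandomState(0)
chol = np.tril(rng.rand(k, k)) + np.eye(k)
sigma = chol.dot(chol.T)
log_sigma_det = 2.0 * np.sum(np.log(np.diag(chol)))

entropy = 0.5 * (k * (1 + np.log(2 * np.pi)) + log_sigma_det)

samples = rng.multivariate_normal(np.zeros(k), sigma, size=200000)
z = np.linalg.solve(chol, samples.T)  # whitened samples, shape [k, n]
log_pdfs = -0.5 * (k * np.log(2 * np.pi) + log_sigma_det
                   + np.sum(z * z, axis=0))
print(entropy, -log_pdfs.mean())  # agree to roughly two decimal places
```
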
  def sample(self, n, seed=None, name=None):
  def sample(self, n, seed=None, name="sample"):
    """Sample `n` observations from the Multivariate Normal Distributions.

    Args:
@ -422,43 +311,203 @@ class MultivariateNormal(object):
      samples: `[n, ...]`, a `Tensor` of `n` samples for each
        of the distributions determined by broadcasting the hyperparameters.
    """
    with ops.op_scope(
        [self._mu, self._sigma_chol, n], name, "MultivariateNormalSample"):
      # TODO(ebrevdo): Is there a better way to get broadcast_shape?
      broadcast_shape = self.mu.get_shape()
      n = ops.convert_to_tensor(n)
      sigma_shape_left = array_ops.slice(
          array_ops.shape(self._sigma_chol),
          [0], array_ops.pack([array_ops.rank(self._sigma_chol) - 2]))
    with ops.name_scope(self.name):
      with ops.op_scope([self._mu, n] + self._cov.inputs, name):
        # Recall _check_mu ensures mu and self._cov have same batch shape.
        broadcast_shape = self.mu.get_shape()
        n = ops.convert_to_tensor(n)

      k_n = array_ops.pack([self._k, n])
      shape = array_ops.concat(0, [sigma_shape_left, k_n])
      white_samples = random_ops.random_normal(
          shape=shape, mean=0, stddev=1, dtype=self._mu.dtype, seed=seed)
        shape = array_ops.concat(0, [self._cov.vector_shape(), [n]])
        white_samples = random_ops.random_normal(shape=shape,
                                                 mean=0,
                                                 stddev=1,
                                                 dtype=self.dtype,
                                                 seed=seed)

      correlated_samples = math_ops.batch_matmul(
          self._sigma_chol, white_samples)
        correlated_samples = self._cov.sqrt_matmul(white_samples)

      # Move the last dimension to the front
      perm = array_ops.concat(
          0,
          (array_ops.pack([array_ops.rank(correlated_samples) - 1]),
           math_ops.range(0, array_ops.rank(correlated_samples) - 1)))
        # Move the last dimension to the front
        perm = array_ops.concat(0, (
            array_ops.pack([array_ops.rank(correlated_samples) - 1]),
            math_ops.range(0, array_ops.rank(correlated_samples) - 1)))

      # TODO(ebrevdo): Once we get a proper tensor contraction op,
      # perform the inner product using that instead of batch_matmul
      # and this slow transpose can go away!
      correlated_samples = array_ops.transpose(correlated_samples, perm)
        # TODO(ebrevdo): Once we get a proper tensor contraction op,
        # perform the inner product using that instead of batch_matmul
        # and this slow transpose can go away!
        correlated_samples = array_ops.transpose(correlated_samples, perm)

      samples = correlated_samples + self.mu
        samples = correlated_samples + self.mu

      # Provide some hints to shape inference
      n_val = tensor_util.constant_value(n)
      final_shape = tensor_shape.vector(n_val).concatenate(broadcast_shape)
      samples.set_shape(final_shape)
        # Provide some hints to shape inference
        n_val = tensor_util.constant_value(n)
        final_shape = tensor_shape.vector(n_val).concatenate(broadcast_shape)
        samples.set_shape(final_shape)

      return samples
        return samples

  @property
  def is_reparameterized(self):
    return True

  @property
  def name(self):
    return self._name

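The sampling trick used above: if `w ~ N(0, I)`, then `S w + mu` has
covariance `S E[w w^T] S^T = S S^T = A`. A small NumPy sketch of why this
works (illustrative names, not part of the commit):

```python
# Illustrative only: sqrt_matmul of white noise yields the right covariance.
import numpy as np

k, n = 3, 500000
rng = np.random.RandomState(0)
chol = np.tril(rng.rand(k, k)) + np.eye(k)  # the "sqrt" S of A = S S^T
a = chol.dot(chol.T)

white = rng.randn(k, n)          # shape [k, n], as in the code above
correlated = chol.dot(white)     # the sqrt_matmul step
empirical_cov = correlated.dot(correlated.T) / n

print(np.max(np.abs(empirical_cov - a)))  # small for large n
```
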
class MultivariateNormalCholesky(MultivariateNormalOperatorPD):
  """The multivariate normal distribution on `R^k`.

  This distribution is defined by a 1-D mean `mu` and a Cholesky factor `chol`.
  Providing the Cholesky factor allows for `O(k^2)` pdf evaluation and
  sampling, and requires `O(k^2)` storage.

  #### Mathematical details

  The PDF of this distribution is:

  ```
  f(x) = (2*pi)^(-k/2) |det(sigma)|^(-1/2) exp(-1/2*(x-mu)^*.sigma^{-1}.(x-mu))
  ```

  where `.` denotes the inner product on `R^k` and `^*` denotes transpose.

  #### Examples

  A single multi-variate Gaussian distribution is defined by a vector of means
  of length `k`, and a covariance matrix of shape `k x k`.

  Extra leading dimensions, if provided, allow for batches.

  ```python
  # Initialize a single 3-variate Gaussian with diagonal covariance.
  mu = [1, 2, 3.]
  chol = [[1, 0, 0], [0, 3, 0], [0, 0, 2]]
  dist = tf.contrib.distributions.MultivariateNormalCholesky(mu, chol)

  # Evaluate this on an observation in R^3, returning a scalar.
  dist.pdf([-1, 0, 1])

  # Initialize a batch of two 3-variate Gaussians.
  mu = [[1, 2, 3], [11, 22, 33]]
  chol = ...  # shape 2 x 3 x 3, lower triangular, positive diagonal.
  dist = tf.contrib.distributions.MultivariateNormalCholesky(mu, chol)

  # Evaluate this on two observations, each in R^3, returning a length-two
  # tensor.
  x = [[-1, 0, 1], [-11, 0, 11]]  # Shape 2 x 3.
  dist.pdf(x)
  ```

  Trainable (batch) Cholesky matrices can be created with
  `tf.contrib.distributions.batch_matrix_diag_transform()`.
  """

  def __init__(
      self,
      mu,
      chol,
      strict=True,
      strict_statistics=True,
      name="MultivariateNormalCholesky"):
    """Multivariate Normal distributions on `R^k`.

    User must provide means `mu` and `chol` which holds the (batch) Cholesky
    factors `S`, such that the covariance of each batch member is `S S^*`.

    Args:
      mu: `(N+1)-D` `float` or `double` tensor with shape `[N1,...,Nb, k]`,
        `b >= 0`.
      chol: `(N+2)-D` `Tensor` with same `dtype` as `mu` and shape
        `[N1,...,Nb, k, k]`.
      strict: Whether to validate input with asserts.  If `strict` is `False`,
        and the inputs are invalid, correct behavior is not guaranteed.
      strict_statistics: Boolean, default `True`.  If `True`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member.  If `False`, batch members with valid parameters leading
        to undefined statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.

    Raises:
      TypeError: If `mu` and `chol` are different dtypes.
    """
    cov = operator_pd_cholesky.OperatorPDCholesky(chol, verify_pd=strict)
    super(MultivariateNormalCholesky, self).__init__(
        mu, cov, strict_statistics=strict_statistics, strict=strict, name=name)

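The `O(k^2)` claim in the docstring comes from the triangular solve: given the
Cholesky factor `L`, the quadratic form needs only back substitution rather
than a fresh `O(k^3)` factorization. A hedged sketch (using
`scipy.linalg.solve_triangular`, not part of this commit):

```python
# Illustrative only: O(k^2) quadratic form from a precomputed Cholesky factor.
import numpy as np
from scipy.linalg import solve_triangular

k = 4
rng = np.random.RandomState(0)
chol = np.tril(rng.rand(k, k)) + np.eye(k)
x = rng.rand(k)

z = solve_triangular(chol, x, lower=True)  # back substitution, O(k^2)
quad_form = z.dot(z)                       # x^T (L L^T)^{-1} x

sigma = chol.dot(chol.T)
np.testing.assert_allclose(quad_form, x.dot(np.linalg.solve(sigma, x)))
```
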
class MultivariateNormalFull(MultivariateNormalOperatorPD):
  """The multivariate normal distribution on `R^k`.

  This distribution is defined by a 1-D mean `mu` and covariance matrix
  `sigma`.  Evaluation of the pdf, determinant, and sampling are all `O(k^3)`
  operations.

  #### Mathematical details

  The PDF of this distribution is:

  ```
  f(x) = (2*pi)^(-k/2) |det(sigma)|^(-1/2) exp(-1/2*(x-mu)^*.sigma^{-1}.(x-mu))
  ```

  where `.` denotes the inner product on `R^k` and `^*` denotes transpose.

  #### Examples

  A single multi-variate Gaussian distribution is defined by a vector of means
  of length `k`, and a covariance matrix of shape `k x k`.

  Extra leading dimensions, if provided, allow for batches.

  ```python
  # Initialize a single 3-variate Gaussian with diagonal covariance.
  mu = [1, 2, 3.]
  sigma = [[1, 0, 0], [0, 3, 0], [0, 0, 2.]]
  dist = tf.contrib.distributions.MultivariateNormalFull(mu, sigma)

  # Evaluate this on an observation in R^3, returning a scalar.
  dist.pdf([-1, 0, 1])

  # Initialize a batch of two 3-variate Gaussians.
  mu = [[1, 2, 3], [11, 22, 33.]]
  sigma = ...  # shape 2 x 3 x 3, positive definite.
  dist = tf.contrib.distributions.MultivariateNormalFull(mu, sigma)

  # Evaluate this on two observations, each in R^3, returning a length-two
  # tensor.
  x = [[-1, 0, 1], [-11, 0, 11.]]  # Shape 2 x 3.
  dist.pdf(x)
  ```
  """

  def __init__(
      self,
      mu,
      sigma,
      strict=True,
      strict_statistics=True,
      name="MultivariateNormalFull"):
    """Multivariate Normal distributions on `R^k`.

    User must provide means `mu` and `sigma`, the mean and covariance.

    Args:
      mu: `(N+1)-D` `float` or `double` tensor with shape `[N1,...,Nb, k]`,
        `b >= 0`.
      sigma: `(N+2)-D` `Tensor` with same `dtype` as `mu` and shape
        `[N1,...,Nb, k, k]`.
      strict: Whether to validate input with asserts.  If `strict` is `False`,
        and the inputs are invalid, correct behavior is not guaranteed.
      strict_statistics: Boolean, default `True`.  If `True`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member.  If `False`, batch members with valid parameters leading
        to undefined statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.

    Raises:
      TypeError: If `mu` and `sigma` are different dtypes.
    """
    cov = operator_pd_full.OperatorPDFull(sigma, verify_pd=strict)
    super(MultivariateNormalFull, self).__init__(
        mu, cov, strict_statistics=strict_statistics, strict=strict, name=name)

284  tensorflow/contrib/distributions/python/ops/operator_pd.py  Normal file
@ -0,0 +1,284 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for symmetric positive definite operator."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import six

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops


__all__ = [
    'OperatorPDBase',
]


@six.add_metaclass(abc.ABCMeta)
class OperatorPDBase(object):
  """Class representing a (batch) of positive definite matrices `A`.

  This class provides access to functions of a (batch) symmetric positive
  definite (PD) matrix, without the need to materialize them.  In other words,
  this provides means to do "matrix free" computations.

  For example, `my_operator.matmul(x)` computes the result of matrix
  multiplication, and this class is free to do this computation with or
  without ever materializing a matrix.

  In practice, this operator represents a (batch) matrix `A` with shape
  `[N1,...,Nb, k, k]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(n1,...,nb)`, `A[n1,...,nb, : :]` is
  a `k x k` matrix.  Again, this matrix `A` may not be materialized, but for
  purposes of broadcasting this shape will be relevant.

  Since `A` is (batch) positive definite, it has a (or several) square roots
  `S` such that `A = SS^T`.

  For example, if `MyOperator` inherits from `OperatorPDBase`, the user can do

  ```python
  operator = MyOperator(...)  # Initialize with some tensors.
  operator.log_det()

  # Compute the quadratic form x^T A^{-1} x for vector x.
  x = ...  # some shape [..., k] tensor
  operator.inv_quadratic_form(x)

  # Matrix multiplication by the square root, S w.
  # If w is iid normal, S w has covariance A.
  w = ...  # some shape [..., k, L] tensor, L >= 1
  operator.sqrt_matmul(w)
  ```

  The above three methods, `log_det`, `inv_quadratic_form`, and
  `sqrt_matmul` provide "all" that is necessary to use a covariance matrix
  in a multi-variate normal distribution.  See the class `MVNOperatorPD`.
  """
  @abc.abstractproperty
  def name(self):
    """String name identifying this `Operator`."""
    # return self._name
    pass

  @abc.abstractproperty
  def verify_pd(self):
    """Whether to verify that this `Operator` is positive definite."""
    # return self._verify_pd
    pass

  @abc.abstractproperty
  def dtype(self):
    """Data type of matrix elements of `A`."""
    pass

  def inv_quadratic_form(self, x, name='inv_quadratic_form'):
    """Compute the quadratic form: x^T A^{-1} x.

    Args:
      x: `Tensor` with shape broadcastable to `[N1,...,Nb, k]` and same `dtype`
        as self.
      name: A name scope to use for ops added by this method.

    Returns:
      `Tensor` holding the square of the norm induced by the inverse of `A`,
      for every broadcast batch member.
    """
    # with ops.name_scope(self.name):
    #   with ops.op_scope([x] + self.inputs, name):
    #     # ... your code here
    pass

  def det(self, name='det'):
    """Determinant for every batch member.

    Args:
      name: A name scope to use for ops added by this method.

    Returns:
      Determinant for every batch member.
    """
    # Derived classes are encouraged to implement log_det() (since it is
    # usually more stable), and then det() comes for free.
    with ops.name_scope(self.name):
      with ops.op_scope(self.inputs, name):
        return math_ops.exp(self.log_det())

  def log_det(self, name='log_det'):
    """Log of the determinant for every batch member.

    Args:
      name: A name scope to use for ops added by this method.

    Returns:
      Logarithm of determinant for every batch member.
    """
    # with ops.name_scope(self.name):
    #   with ops.op_scope(self.inputs, name):
    #     # ... your code here
    pass

  @abc.abstractproperty
  def inputs(self):
    """List of tensors that were provided as initialization inputs."""
    pass

  def sqrt_matmul(self, x, name='sqrt_matmul'):
    """Left (batch) matmul `x` by a sqrt of this matrix: `Sx` where `A = S S^T`.

    Args:
      x: `Tensor` with shape broadcastable to `[N1,...,Nb, k]` and same `dtype`
        as self.
      name: A name scope to use for ops added by this method.

    Returns:
      Shape `[N1,...,Nb, k]` `Tensor` holding the product `S x`.
    """
    # with ops.name_scope(self.name):
    #   with ops.op_scope([x] + self.inputs, name):
    #     # ... your code here
    pass

  @abc.abstractmethod
  def get_shape(self):
    """`TensorShape` giving static shape."""
    pass

  def get_batch_shape(self):
    """`TensorShape` with batch shape."""
    return self.get_shape()[:-2]

  def get_vector_shape(self):
    """`TensorShape` of vectors this operator will work with."""
    return self.get_shape()[:-1]

  @abc.abstractmethod
  def shape(self, name='shape'):
    """Equivalent to `tf.shape(A)`.  Equal to `[N1,...,Nb, k, k]`, `b >= 0`.

    Args:
      name: A name scope to use for ops added by this method.

    Returns:
      `int32` `Tensor`
    """
    # with ops.name_scope(self.name):
    #   with ops.op_scope(self.inputs, name):
    #     # ... your code here
    pass

  def rank(self, name='rank'):
    """Tensor rank.  Equivalent to `tf.rank(A)`.  Will equal `b + 2`.

    If this operator represents the batch matrix `A` with
    `A.shape = [N1,...,Nb, k, k]`, the `rank` is `b + 2`.

    Args:
      name: A name scope to use for ops added by this method.

    Returns:
      `int32` `Tensor`
    """
    # Derived classes get this "for free" once .shape() is implemented.
    with ops.name_scope(self.name):
      with ops.op_scope(self.inputs, name):
        return array_ops.shape(self.shape())[0]

  def batch_shape(self, name='batch_shape'):
    """Shape of batches associated with this operator.

    If this operator represents the batch matrix `A` with
    `A.shape = [N1,...,Nb, k, k]`, the `batch_shape` is `[N1,...,Nb]`.

    Args:
      name: A name scope to use for ops added by this method.

    Returns:
      `int32` `Tensor`
    """
    # Derived classes get this "for free" once .shape() is implemented.
    with ops.name_scope(self.name):
      with ops.op_scope(self.inputs, name):
        end = array_ops.pack([self.rank() - 2])
        return array_ops.slice(self.shape(), [0], end)

  def vector_shape(self, name='vector_shape'):
    """Shape of (batch) vectors that this (batch) matrix will multiply.

    If this operator represents the batch matrix `A` with
    `A.shape = [N1,...,Nb, k, k]`, the `vector_shape` is `[N1,...,Nb, k]`.

    Args:
      name: A name scope to use for ops added by this method.

    Returns:
      `int32` `Tensor`
    """
    # Derived classes get this "for free" once .shape() is implemented.
    with ops.name_scope(self.name):
      with ops.op_scope(self.inputs, name):
        return array_ops.slice(self.shape(), [0], [self.rank() - 1])

  def vector_space_dimension(self, name='vector_space_dimension'):
    """Dimension of vector space on which this acts.  The `k` in `R^k`.

    If this operator represents the batch matrix `A` with
    `A.shape = [N1,...,Nb, k, k]`, the `vector_space_dimension` is `k`.

    Args:
      name: A name scope to use for ops added by this method.

    Returns:
      `int32` `Tensor`
    """
    # Derived classes get this "for free" once .shape() is implemented.
    with ops.name_scope(self.name):
      with ops.op_scope(self.inputs, name):
        return array_ops.gather(self.shape(), self.rank() - 1)

  def matmul(self, x, name='matmul'):
    """Left multiply `x` by this operator.

    Args:
      x: Shape `[N1,...,Nb, k, L]` `Tensor` with same `dtype` as this operator.
      name: A name to give this `Op`.

    Returns:
      A result equivalent to `tf.batch_matmul(self.to_dense(), x)`.
    """
    # with ops.name_scope(self.name):
    #   with ops.op_scope([x] + self.inputs, name):
    #     # ... your code here
    raise NotImplementedError('This operator has no batch_matmul Op.')

  def to_dense(self, name='to_dense'):
    """Return a dense (batch) matrix representing this operator."""
    # with ops.name_scope(self.name):
    #   with ops.op_scope(self.inputs, name):
    #     # ... your code here
    raise NotImplementedError('This operator has no dense representation.')

  def to_dense_sqrt(self, name='to_dense_sqrt'):
    """Return a dense (batch) matrix representing sqrt of this operator."""
    # with ops.name_scope(self.name):
    #   with ops.op_scope(self.inputs, name):
    #     # ... your code here
    raise NotImplementedError('This operator has no dense sqrt representation.')
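
The base class above only ever needs `log_det`, `inv_quadratic_form`, and
`sqrt_matmul` from a subclass, which is what makes alternative covariance
representations cheap. A hypothetical sketch (plain NumPy, not part of this
commit) of a diagonal implementation of the same interface, where all three
operations drop to `O(k)`:

```python
# Hypothetical sketch: a diagonal covariance A = diag(d) in the same shape
# of interface.  Names and class are illustrative only.
import numpy as np

class DiagPDSketch(object):
  """NumPy stand-in for an OperatorPDBase-style diagonal operator."""

  def __init__(self, diag):
    self._diag = np.asarray(diag, dtype=float)  # positive entries, shape [k]

  def log_det(self):
    # log|A| = sum(log d_i)
    return np.sum(np.log(self._diag))

  def inv_quadratic_form(self, x):
    # x^T A^{-1} x = sum(x_i^2 / d_i)
    return np.sum(np.square(x) / self._diag)

  def sqrt_matmul(self, w):
    # S w with S = diag(sqrt(d)), so S S^T = A.
    return np.sqrt(self._diag)[:, np.newaxis] * w

op = DiagPDSketch([1.0, 4.0, 9.0])
print(op.log_det(), op.inv_quadratic_form(np.array([1.0, 2.0, 3.0])))
```
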
@ -0,0 +1,431 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Symmetric positive definite (PD) Operator defined by a Cholesky factor."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.distributions.python.ops import operator_pd
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops

__all__ = ['OperatorPDCholesky', 'batch_matrix_diag_transform']


class OperatorPDCholesky(operator_pd.OperatorPDBase):
  """Class representing a (batch) of positive definite matrices `A`.

  This class provides access to functions of a batch of symmetric positive
  definite (PD) matrices `A` in `R^{k x k}` defined by Cholesky factor(s).
  Determinants and solves are `O(k^2)`.

  In practice, this operator represents a (batch) matrix `A` with shape
  `[N1,...,Nb, k, k]` for some `b >= 0`.  The first `b` indices designate a
  batch member.  For every batch member `(n1,...,nb)`, `A[n1,...,nb, : :]` is
  a `k x k` matrix.

  Since `A` is (batch) positive definite, it has a (or several) square roots
  `S` such that `A = SS^T`.

  For example,

  ```python
  distributions = tf.contrib.distributions
  chol = [[1.0, 0.0], [1.0, 2.0]]
  operator = OperatorPDCholesky(chol)
  operator.log_det()

  # Compute the quadratic form x^T A^{-1} x for vector x.
  x = [1.0, 2.0]
  operator.inv_quadratic_form(x)

  # Matrix multiplication by the square root, S w.
  # If w is iid normal, S w has covariance A.
  w = [[1.0], [2.0]]
  operator.sqrt_matmul(w)
  ```

  The above three methods, `log_det`, `inv_quadratic_form`, and
  `sqrt_matmul` provide "all" that is necessary to use a covariance matrix
  in a multi-variate normal distribution.  See the class `MVNOperatorPD`.
  """

  def __init__(self, chol, verify_pd=True, name='OperatorPDCholesky'):
    """Initialize an OperatorPDCholesky.

    Args:
      chol: Shape `[N1,...,Nb, k, k]` tensor with `b >= 0`, `k >= 1`, and
        positive diagonal elements.  The strict upper triangle of `chol` is
        never used, and the user may set these elements to zero, or ignore
        them.
      verify_pd: Whether to check that `chol` has positive diagonal (this is
        equivalent to it being a Cholesky factor of a symmetric positive
        definite matrix).  If `verify_pd` is `False`, correct behavior is not
        guaranteed.
      name: A name to prepend to all ops created by this class.
    """
    self._verify_pd = verify_pd
    self._name = name
    with ops.name_scope(name):
      with ops.op_scope([chol], 'init'):
        self._diag = array_ops.batch_matrix_diag_part(chol)
        self._chol = self._check_chol(chol)

  @property
  def verify_pd(self):
    """Whether to verify that this `Operator` is positive definite."""
    return self._verify_pd

  @property
  def name(self):
    return self._name

  @property
  def dtype(self):
    return self._chol.dtype

  def inv_quadratic_form(self, x, name='inv_quadratic_form'):
    """Compute the induced vector norm (squared): ||x||^2 := x^T A^{-1} x.

    For every batch member, this is done in `O(k^2)` complexity.  The
    efficiency depends on the shape of `x`.
    * If `x.shape = [M1,...,Mm, N1,...,Nb, k]`, `m >= 0`, and
      `self.shape = [N1,...,Nb, k, k]`, `x` will be reshaped and the
      initialization matrix `chol` does not need to be copied.
    * Otherwise, data will be broadcast and copied.

    Args:
      x: `Tensor` with shape broadcastable to `[N1,...,Nb, k]` and same `dtype`
        as self.  If the batch dimensions of `x` do not match exactly with
        those of self, `x` and/or self's Cholesky factor will broadcast to
        match, and the resultant set of linear systems are solved
        independently.  This may result in inefficient operation.
      name: A name scope to use for ops added by this method.

    Returns:
      `Tensor` holding the square of the norm induced by the inverse of `A`,
      for every broadcast batch member.
    """
    with ops.name_scope(self.name):
      with ops.op_scope([x] + self.inputs, name):
        x = ops.convert_to_tensor(x, name='x')
        # Boolean, True iff x.shape = [M1,...,Mm] + chol.shape[:-1].
        should_flip = self._should_flip(x)
        x_whitened = control_flow_ops.cond(
            should_flip,
            lambda: self._x_whitened_if_should_flip(x),
            lambda: self._x_whitened_if_no_flip(x))

        # batch version of: || L^{-1} x ||^2
        x_whitened_norm = math_ops.reduce_sum(
            math_ops.square(x_whitened),
            reduction_indices=[-1])

        # Set the final shape by making a dummy tensor that will never be
        # evaluated.
        chol_without_final_dim = math_ops.reduce_sum(
            self._chol, reduction_indices=[-1])
        final_shape = (x + chol_without_final_dim).get_shape()[:-1]
        x_whitened_norm.set_shape(final_shape)

        return x_whitened_norm

  def _should_flip(self, x):
    """Return boolean tensor telling whether `x` should be flipped."""
    # We "flip" (see self._flip_front_dims_to_back) iff
    #   chol.shape = [N1,...,Nn, k, k]
    #   x.shape = [M1,...,Mm, N1,...,Nn, k]
    x_shape = array_ops.shape(x)
    x_rank = array_ops.rank(x)
    # If m <= 0, we should not flip.
    m = x_rank + 1 - self.rank()

    def result_if_m_positive():
      x_shape_right = array_ops.slice(x_shape, [m], [x_rank - m])
      return math_ops.reduce_all(
          math_ops.equal(x_shape_right, self.vector_shape()))

    return control_flow_ops.cond(
        m > 0,
        result_if_m_positive,
        lambda: ops.convert_to_tensor(False))

  def _x_whitened_if_no_flip(self, x):
    """x_whitened in the event of no flip."""
    # Tensors to use if x and chol have same shape, or a shape that must be
    # broadcast to match.
    chol_bcast, x_bcast = self._get_chol_and_x_compatible_shape(x)

    # batch version of: L^{-1} x
    # Note that here x_bcast has trailing dims of (k, 1), for "1" system of k
    # linear equations.  This is the form used by the solver.
    x_whitened_expanded = linalg_ops.batch_matrix_triangular_solve(
        chol_bcast, x_bcast)

    x_whitened = array_ops.squeeze(x_whitened_expanded, squeeze_dims=[-1])
    return x_whitened

  def _x_whitened_if_should_flip(self, x):
    # Tensor to use if x.shape = [M1,...,Mm] + chol.shape[:-1],
    # which is common if x was sampled.
    x_flipped = self._flip_front_dims_to_back(x)

    # batch version of: L^{-1} x
    x_whitened_expanded = linalg_ops.batch_matrix_triangular_solve(
        self._chol, x_flipped)

    return self._unflip_back_dims_to_front(
        x_whitened_expanded,
        array_ops.shape(x),
        x.get_shape())

  def _flip_front_dims_to_back(self, x):
    """Flip x to make x.shape = chol.shape[:-1] + [M1*...*Mm]."""
    # E.g. suppose
    #   chol.shape = [N1,...,Nn, k, k]
    #   x.shape = [M1,...,Mm, N1,...,Nn, k]
    # Then we want to return x_flipped where
    #   x_flipped.shape = [N1,...,Nn, k, M1*...*Mm].
    x_shape = array_ops.shape(x)
    x_rank = array_ops.rank(x)
    m = x_rank + 1 - self.rank()
    x_shape_left = array_ops.slice(x_shape, [0], [m])

    # Permutation corresponding to [N1,...,Nn, k, M1,...,Mm]
    perm = array_ops.concat(
        0, (math_ops.range(m, x_rank), math_ops.range(0, m)))
    x_permuted = array_ops.transpose(x, perm=perm)

    # Now that things are ordered correctly, condense the last dimensions.
    # condensed_shape = [M1*...*Mm]
    condensed_shape = array_ops.pack([math_ops.reduce_prod(x_shape_left)])
    new_shape = array_ops.concat(0, (self.vector_shape(), condensed_shape))

    return array_ops.reshape(x_permuted, new_shape)

  def _unflip_back_dims_to_front(self, x_flipped, x_shape, x_get_shape):
    # E.g. suppose that originally
    #   chol.shape = [N1,...,Nn, k, k]
    #   x.shape = [M1,...,Mm, N1,...,Nn, k]
    # Then we have flipped the dims so that
    #   x_flipped.shape = [N1,...,Nn, k, M1*...*Mm].
    # We want to return x with the original shape.
    rank = array_ops.rank(x_flipped)
    # Permutation corresponding to [M1*...*Mm, N1,...,Nn, k]
    perm = array_ops.concat(
        0, (math_ops.range(rank - 1, rank), math_ops.range(0, rank - 1)))
    x_with_end_at_beginning = array_ops.transpose(x_flipped, perm=perm)
    x = array_ops.reshape(x_with_end_at_beginning, x_shape)
    return x

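The flip/unflip bookkeeping in the two methods above is pure transpose-and-
reshape arithmetic. A small NumPy sketch with concrete (illustrative) shapes,
following the same steps:

```python
# Illustrative only: flip then unflip, with chol.shape = [4, 5, 5]
# (so n = 1, k = 5) and x.shape = [2, 3, 4, 5] (so m = 2).
import numpy as np

x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5).astype(float)
m = 2

# Flip: [2, 3, 4, 5] -> transpose to [4, 5, 2, 3] -> reshape to [4, 5, 6].
perm = tuple(range(m, x.ndim)) + tuple(range(m))
x_flipped = np.transpose(x, perm).reshape(4, 5, 6)

# Unflip: [4, 5, 6] -> move last dim to front -> reshape back to [2, 3, 4, 5].
x_back = np.moveaxis(x_flipped, -1, 0).reshape(2, 3, 4, 5)
np.testing.assert_array_equal(x, x_back)
```
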
  def _get_chol_and_x_compatible_shape(self, x):
    """Return self.chol and x, (possibly) broadcast to compatible shape."""
    # x and chol are "compatible" if their shape matches except that the last
    # two dimensions of chol are [k, k], and the last two of x are [k, 1].
    # E.g. x.shape = [A, B, k, 1], and chol.shape = [A, B, k, k]
    # This is required for the batch_triangular_solve, which does not
    # broadcast.

    # TODO(langmore) This broadcast replicates matrices unnecessarily!  In the
    # case where
    #   x.shape = [M1,...,Mr, N1,...,Nb, k], and chol.shape = [N1,...,Nb, k, k]
    # (which is common if x was sampled), the front dimensions of x can be
    # "flipped" to the end, making
    #   x_flipped.shape = [N1,...,Nb, k, M1*...*Mr],
    # and this can be handled by the linear solvers.  This is preferred,
    # because it does not replicate the matrix, or create any new data.

    # We assume x starts without the trailing singleton dimension, e.g.
    # x.shape = [B, k].
    chol = self._chol
    with ops.op_scope([x] + self.inputs, 'get_chol_and_x_compatible_shape'):
      # If we determine statically that shapes match, we're done.
      if x.get_shape() == chol.get_shape()[:-1]:
        x_expanded = array_ops.expand_dims(x, -1)
        return chol, x_expanded

      # Dynamic check if shapes match or not.
      vector_shape = self.vector_shape()  # Shape of chol minus last dim.
      are_same_rank = math_ops.equal(
          array_ops.rank(x), array_ops.rank(vector_shape))

      def shapes_match_if_same_rank():
        return math_ops.reduce_all(math_ops.equal(
            array_ops.shape(x), vector_shape))

      shapes_match = control_flow_ops.cond(are_same_rank,
                                           shapes_match_if_same_rank,
                                           lambda: ops.convert_to_tensor(False))

      # Make tensors (never instantiated) holding the broadcast shape.
      # matrix_bcast_dummy is the shape we will broadcast chol to.
      matrix_bcast_dummy = chol + array_ops.expand_dims(x, -1)
      # vector_bcast_dummy is the shape we will bcast x to, before we expand
      # it.
      chol_minus_last_dim = math_ops.reduce_sum(chol, reduction_indices=[-1])
      vector_bcast_dummy = x + chol_minus_last_dim

      chol_bcast = chol + array_ops.zeros_like(matrix_bcast_dummy)
      x_bcast = x + array_ops.zeros_like(vector_bcast_dummy)

      chol_result = control_flow_ops.cond(shapes_match, lambda: chol,
                                          lambda: chol_bcast)
      chol_result.set_shape(matrix_bcast_dummy.get_shape())
      x_result = control_flow_ops.cond(shapes_match, lambda: x,
                                       lambda: x_bcast)
      x_result.set_shape(vector_bcast_dummy.get_shape())

      x_expanded = array_ops.expand_dims(x_result, -1)

      return chol_result, x_expanded

  def log_det(self, name='log_det'):
    """Log determinant of every batch member."""
    with ops.name_scope(self.name):
      with ops.op_scope(self.inputs, name):
        det = 2.0 * math_ops.reduce_sum(
            math_ops.log(self._diag),
            reduction_indices=[-1])
        det.set_shape(self._chol.get_shape()[:-2])
        return det

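The `log_det` above relies on the identity
`log det(L L^T) = 2 * sum(log(diag(L)))` for lower-triangular `L`, which
avoids forming the dense matrix. A two-line check (illustrative only):

```python
# Illustrative only: log det via the diagonal of the Cholesky factor.
import numpy as np

L = np.array([[1.0, 0.0], [1.0, 2.0]])
a = L.dot(L.T)
np.testing.assert_allclose(
    np.log(np.linalg.det(a)), 2.0 * np.sum(np.log(np.diag(L))))
```
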
  @property
  def inputs(self):
    """List of tensors that were provided as initialization inputs."""
    return [self._chol]

  def sqrt_matmul(self, x, name='sqrt_matmul'):
    """Left (batch) matmul `x` by a sqrt of this matrix: `Sx` where `A = S S^T`.

    Args:
      x: `Tensor` with shape broadcastable to `[N1,...,Nb, k]` and same `dtype`
        as self.
      name: A name scope to use for ops added by this method.

    Returns:
      Shape `[N1,...,Nb, k]` `Tensor` holding the product `S x`.
    """
    with ops.name_scope(self.name):
      with ops.op_scope([x] + self.inputs, name):
        chol_lower = array_ops.batch_matrix_band_part(self._chol, -1, 0)
        return math_ops.batch_matmul(chol_lower, x)

  def get_shape(self):
    """`TensorShape` giving static shape."""
    return self._chol.get_shape()

  def shape(self, name='shape'):
    with ops.name_scope(self.name):
      with ops.op_scope(self.inputs, name):
        return array_ops.shape(self._chol)

  def _check_chol(self, chol):
    """Verify that `chol` is proper."""
    chol = ops.convert_to_tensor(chol, name='chol')
    if not self.verify_pd:
      return chol

    shape = array_ops.shape(chol)
    rank = array_ops.rank(chol)

    is_matrix = check_ops.assert_rank_at_least(chol, 2)
    is_square = check_ops.assert_equal(
        array_ops.gather(shape, rank - 2), array_ops.gather(shape, rank - 1))

    deps = [is_matrix, is_square]
    deps.append(check_ops.assert_positive(self._diag))

    return control_flow_ops.with_dependencies(deps, chol)

  def matmul(self, x, name='matmul'):
    """Left (batch) matrix multiplication of `x` by this operator."""
    chol = self._chol
    with ops.name_scope(self.name):
      with ops.op_scope(self.inputs, name):
        a_times_x = math_ops.batch_matmul(chol, x, adj_x=True)
        return math_ops.batch_matmul(chol, a_times_x)

  def to_dense_sqrt(self, name='to_dense_sqrt'):
    """Return a dense (batch) matrix representing sqrt of this covariance."""
    with ops.name_scope(self.name):
      with ops.op_scope(self.inputs, name):
        return array_ops.identity(self._chol)

  def to_dense(self, name='to_dense'):
    """Return a dense (batch) matrix representing this covariance."""
    chol = self._chol
    with ops.name_scope(self.name):
      with ops.op_scope(self.inputs, name):
        return math_ops.batch_matmul(chol, chol, adj_y=True)


def batch_matrix_diag_transform(matrix, transform=None, name=None):
  """Transform diagonal of [batch-]matrix, leave rest of matrix unchanged.

  Create a trainable covariance defined by a Cholesky factor:

  ```python
  # Transform network layer into 2 x 2 array.
  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))

  # Make the diagonal positive.  If the upper triangle was zero, this would be
  # a valid Cholesky factor.
  chol = batch_matrix_diag_transform(matrix, transform=tf.nn.softplus)

  # OperatorPDCholesky ignores the upper triangle.
  operator = OperatorPDCholesky(chol)
  ```

  Example of heteroskedastic 2-D linear regression.

  ```python
  # Get a trainable Cholesky factor.
  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))
  chol = batch_matrix_diag_transform(matrix, transform=tf.nn.softplus)

  # Get a trainable mean.
  mu = tf.contrib.layers.fully_connected(activations, 2)

  # This is a fully trainable multivariate normal!
  dist = tf.contrib.distributions.MultivariateNormalCholesky(mu, chol)

  # Standard log loss.  Minimizing this will "train" mu and chol, and then
  # dist will be a distribution predicting labels as multivariate Gaussians.
  loss = -1 * tf.reduce_mean(dist.log_pdf(labels))
  ```

  Args:
    matrix: Rank `R` `Tensor`, `R >= 2`, where the last two dimensions are
      equal.
    transform: Element-wise function mapping `Tensors` to `Tensors`.  To be
      applied to the diagonal of `matrix`.  If `None`, `matrix` is returned
      unchanged.  Defaults to `None`.
    name: A name to give created ops.  Defaults to
      "batch_matrix_diag_transform".

  Returns:
    A `Tensor` with same shape and `dtype` as `matrix`.
  """
  with ops.op_scope([matrix], name, 'batch_matrix_diag_transform'):
    matrix = ops.convert_to_tensor(matrix, name='matrix')
    if transform is None:
      return matrix
    # Replace the diag with transformed diag.
    diag = array_ops.batch_matrix_diag_part(matrix)
    transformed_diag = transform(diag)
    matrix += array_ops.batch_matrix_diag(transformed_diag - diag)

    return matrix
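
The transform works by adding `diag(transform(d) - d)` to the matrix, so only
the diagonal changes. A plain NumPy demonstration of the same arithmetic
(names illustrative, not part of the commit):

```python
# Illustrative only: apply a function to the diagonal, leave the rest alone.
import numpy as np

def softplus(t):
  return np.log1p(np.exp(t))

matrix = np.array([[-1.0, 2.0],
                   [3.0, -4.0]])
diag = np.diag(matrix)
result = matrix + np.diag(softplus(diag) - diag)

print(np.diag(result))                     # softplus(-1), softplus(-4) > 0
print(result - np.diag(np.diag(result)))   # off-diagonals unchanged
```
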
106  tensorflow/contrib/distributions/python/ops/operator_pd_full.py  Normal file
@ -0,0 +1,106 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Symmetric positive definite (PD) Operator defined by a full matrix."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.distributions.python.ops import operator_pd_cholesky
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops


__all__ = [
    'OperatorPDFull',
]


class OperatorPDFull(operator_pd_cholesky.OperatorPDCholesky):
  """Class representing a (batch) of positive definite matrices `A`.

  This class provides access to functions of a batch of symmetric positive
  definite (PD) matrices `A` in `R^{k x k}` defined by dense matrices.
  Determinants and solves are `O(k^3)`.

  In practice, this operator represents a (batch) matrix `A` with shape
  `[N1,...,Nb, k, k]` for some `b >= 0`.  The first `b` indices designate a
  batch member.  For every batch member `(n1,...,nb)`, `A[n1,...,nb, : :]` is
  a `k x k` matrix.

  Since `A` is (batch) positive definite, it has a (or several) square roots
  `S` such that `A = SS^T`.

  For example,

  ```python
  distributions = tf.contrib.distributions
  matrix = [[1.0, 0.5], [0.5, 2.0]]
  operator = OperatorPDFull(matrix)
  operator.log_det()

  # Compute the quadratic form x^T A^{-1} x for vector x.
  x = [1.0, 2.0]
  operator.inv_quadratic_form(x)

  # Matrix multiplication by the square root, S w.
  # If w is iid normal, S w has covariance A.
  w = [[1.0], [2.0]]
  operator.sqrt_matmul(w)
  ```

  The above three methods, `log_det`, `inv_quadratic_form`, and
  `sqrt_matmul` provide "all" that is necessary to use a covariance matrix
  in a multi-variate normal distribution.  See the class `MVNOperatorPD`.
  """

  def __init__(self, matrix, verify_pd=True, name='OperatorPDFull'):
    """Initialize an OperatorPDFull.

    Args:
      matrix: Shape `[N1,...,Nb, k, k]` tensor with `b >= 0`, `k >= 1`.  The
        last two dimensions should be `k x k` symmetric positive definite
        matrices.
      verify_pd: Whether to check that `matrix` is symmetric positive
        definite.  If `verify_pd` is `False`, correct behavior is not
        guaranteed.
      name: A name to prepend to all ops created by this class.
    """
    with ops.name_scope(name):
      with ops.op_scope([matrix], 'init'):
        matrix = ops.convert_to_tensor(matrix)
        # Check symmetric here.  Positivity will be verified by checking the
        # diagonal of the Cholesky factor inside the parent class.  The
        # Cholesky factorization .batch_cholesky() does not always fail for
        # non PSD matrices, so don't rely on that.
        if verify_pd:
          matrix = _check_symmetric(matrix)
        chol = linalg_ops.batch_cholesky(matrix)
        super(OperatorPDFull, self).__init__(chol, verify_pd=verify_pd)


def _check_symmetric(matrix):
  rank = array_ops.rank(matrix)
  # Create permutation to permute last two dimensions
  first_dims = math_ops.range(0, rank - 2)
  flipped_last_dims = array_ops.pack([rank - 1, rank - 2])
  perm = array_ops.concat(0, (first_dims, flipped_last_dims))
  matrix_t = array_ops.transpose(matrix, perm=perm)

  return control_flow_ops.with_dependencies(
      [check_ops.assert_equal(matrix, matrix_t)], matrix)
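
In effect, `OperatorPDFull` pays the `O(k^3)` factorization cost once at
construction and then inherits the cheaper Cholesky code paths. A NumPy
sketch of the construction-time work (illustrative only):

```python
# Illustrative only: symmetry check plus one-time Cholesky factorization.
import numpy as np

matrix = np.array([[1.0, 0.5],
                   [0.5, 2.0]])

# The symmetry check: compare against the transpose of the last two dims.
assert np.array_equal(matrix, matrix.T)

chol = np.linalg.cholesky(matrix)          # lower triangular, O(k^3)
np.testing.assert_allclose(chol.dot(chol.T), matrix)
log_det = 2.0 * np.sum(np.log(np.diag(chol)))
print(log_det)
```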