Adding covariance, covariance_cholesky, mvn_cov to distributions/

These allow alternative representations of covariance matrices, e.g. diagonal/sparse/functional, to be used in a multivariate normal.
Change: 126226644
This commit is contained in:
Ian Langmore 2016-06-29 12:31:47 -08:00 committed by TensorFlower Gardener
parent f2c3b2e702
commit 94cbf42a07
10 changed files with 1817 additions and 518 deletions
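As a quick orientation, the classes this commit adds can be exercised like the following minimal sketch (TF 0.9-era `tf.contrib` API; the numbers are illustrative, not taken from the commit):

```python
import tensorflow as tf

distributions = tf.contrib.distributions

# A batch of two 3-variate normals parameterized by Cholesky factors, so
# pdf evaluation and sampling cost O(k^2) rather than O(k^3).
mu = [[1., 2., 3.], [11., 22., 33.]]
chol = [[[1., 0., 0.], [1., 3., 0.], [1., 2., 3.]],
        [[2., 0., 0.], [0., 2., 0.], [0., 0., 2.]]]
mvn = distributions.MultivariateNormalCholesky(mu, chol)

with tf.Session():
  print(mvn.pdf([[-1., 0., 1.], [-11., 0., 11.]]).eval())  # Shape (2,).
```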


@@ -10,6 +10,39 @@ package(default_visibility = ["//tensorflow:__subpackages__"])
load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
cuda_py_tests(
name = "operator_pd_test",
size = "small",
srcs = ["python/kernel_tests/operator_pd_test.py"],
additional_deps = [
":distributions_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
cuda_py_tests(
name = "operator_pd_cholesky_test",
size = "small",
srcs = ["python/kernel_tests/operator_pd_cholesky_test.py"],
additional_deps = [
":distributions_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
cuda_py_tests(
name = "operator_pd_full_test",
size = "small",
srcs = ["python/kernel_tests/operator_pd_full_test.py"],
additional_deps = [
":distributions_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
py_library(
name = "distributions_py",
srcs = ["__init__.py"] + glob(["python/ops/*.py"]),


@@ -38,9 +38,26 @@ initialized with parameters that define the distributions.
### Multivariate distributions
@@MultivariateNormal
#### Multivariate normal
@@MultivariateNormalFull
@@MultivariateNormalCholesky
#### Other multivariate distributions
@@DirichletMultinomial
## Operators allowing for matrix-free methods
### Positive definite operators
A matrix is positive definite if it is symmetric with all positive eigenvalues.
@@OperatorPDBase
@@OperatorPDFull
@@OperatorPDCholesky
@@batch_matrix_diag_transform
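For example, a positive definite operator exposes determinants and products without necessarily materializing the dense matrix; a small sketch using the classes listed above (values illustrative):

```python
import tensorflow as tf

distributions = tf.contrib.distributions

chol = [[2., 0.], [1., 3.]]  # Lower triangular, positive diagonal.
operator = distributions.OperatorPDCholesky(chol)

with tf.Session():
  operator.log_det().eval()             # log det(chol . chol^T), from the diagonal.
  operator.matmul([[1.], [1.]]).eval()  # (chol . chol^T) . x, matrix-free.
  operator.to_dense().eval()            # Materialize only when needed.
```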
## Posterior inference with conjugate priors.
Functions that transform conjugate prior/likelihood pairs to distributions
@@ -61,7 +78,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,wildcard-import,line-too-long
# pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member
from tensorflow.contrib.distributions.python.ops.bernoulli import *
from tensorflow.contrib.distributions.python.ops.categorical import *
@@ -74,5 +91,8 @@ from tensorflow.contrib.distributions.python.ops.kullback_leibler import *
from tensorflow.contrib.distributions.python.ops.mvn import *
from tensorflow.contrib.distributions.python.ops.normal import *
from tensorflow.contrib.distributions.python.ops.normal_conjugate_posteriors import *
from tensorflow.contrib.distributions.python.ops.operator_pd import *
from tensorflow.contrib.distributions.python.ops.operator_pd_cholesky import *
from tensorflow.contrib.distributions.python.ops.operator_pd_full import *
from tensorflow.contrib.distributions.python.ops.student_t import *
from tensorflow.contrib.distributions.python.ops.uniform import *


@@ -19,249 +19,153 @@ from __future__ import division
from __future__ import print_function
import numpy as np
from scipy import stats
import tensorflow as tf
distributions = tf.contrib.distributions
class MultivariateNormalTest(tf.test.TestCase):
class MultivariateNormalCholeskyTest(tf.test.TestCase):
def setUp(self):
self._rng = np.random.RandomState(42)
def _random_chol(self, *shape):
mat = self._rng.rand(*shape)
chol = distributions.batch_matrix_diag_transform(
mat, transform=tf.nn.softplus)
chol = tf.batch_matrix_band_part(chol, -1, 0)
sigma = tf.batch_matmul(chol, chol, adj_y=True)
return chol.eval(), sigma.eval()
def testNonmatchingMuSigmaFails(self):
with tf.Session():
mvn = tf.contrib.distributions.MultivariateNormal(
mu=[1.0, 2.0],
sigma=[[[1.0, 0.0],
[0.0, 1.0]],
[[1.0, 0.0],
[0.0, 1.0]]])
with self.assertRaisesOpError(
r"Rank of mu should be one less than rank of sigma"):
mvn.mean.eval()
with self.test_session():
mu = self._rng.rand(2)
chol, _ = self._random_chol(2, 2, 2)
mvn = distributions.MultivariateNormalCholesky(mu, chol)
with self.assertRaisesOpError("mu should have rank 1 less than cov"):
mvn.mean().eval()
mvn = tf.contrib.distributions.MultivariateNormal(
mu=[[1.0], [2.0]],
sigma=[[[1.0, 0.0],
[0.0, 1.0]],
[[1.0, 0.0],
[0.0, 1.0]]])
with self.assertRaisesOpError(
r"mu.shape and sigma.shape\[\:-1\] must match"):
mvn.mean.eval()
mu = self._rng.rand(2, 1)
chol, _ = self._random_chol(2, 2, 2)
mvn = distributions.MultivariateNormalCholesky(mu, chol)
with self.assertRaisesOpError("mu.shape and cov.shape.*should match"):
mvn.mean().eval()
def testNotPositiveDefiniteSigmaFails(self):
with tf.Session():
mvn = tf.contrib.distributions.MultivariateNormal(
mu=[[1.0, 2.0], [1.0, 2.0]],
sigma=[[[1.0, 0.0],
[0.0, 1.0]],
[[1.0, 1.0],
[1.0, 1.0]]])
with self.assertRaisesOpError(
r"LLT decomposition was not successful."):
mvn.mean.eval()
mvn = tf.contrib.distributions.MultivariateNormal(
mu=[[1.0, 2.0], [1.0, 2.0]],
sigma=[[[1.0, 0.0],
[0.0, 1.0]],
[[-1.0, 0.0],
[0.0, 1.0]]])
with self.assertRaisesOpError(
r"LLT decomposition was not successful."):
mvn.mean.eval()
mvn = tf.contrib.distributions.MultivariateNormal(
mu=[[1.0, 2.0], [1.0, 2.0]],
sigma_chol=[[[1.0, 0.0],
[0.0, 1.0]],
[[-1.0, 0.0],
[0.0, 1.0]]])
with self.assertRaisesOpError(
r"sigma_chol is not positive definite."):
mvn.mean.eval()
def testLogPDFScalar(self):
with tf.Session():
mu_v = np.array([-3.0, 3.0], dtype=np.float32)
mu = tf.constant(mu_v)
sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
sigma = tf.constant(sigma_v)
x = np.array([-2.5, 2.5], dtype=np.float32)
mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
def testLogPDFScalarBatch(self):
with self.test_session():
mu = self._rng.rand(2)
chol, sigma = self._random_chol(2, 2)
mvn = distributions.MultivariateNormalCholesky(mu, chol)
x = self._rng.rand(2)
log_pdf = mvn.log_pdf(x)
pdf = mvn.pdf(x)
try:
from scipy import stats # pylint: disable=g-import-not-at-top
scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
expected_log_pdf = scipy_mvn.logpdf(x)
expected_pdf = scipy_mvn.pdf(x)
self.assertAllClose(expected_log_pdf, log_pdf.eval())
self.assertAllClose(expected_pdf, pdf.eval())
except ImportError as e:
tf.logging.warn("Cannot test stats functions: %s" % str(e))
scipy_mvn = stats.multivariate_normal(mean=mu, cov=sigma)
def testLogPDFScalarSigmaHalf(self):
with tf.Session():
mu_v = np.array([-3.0, 3.0, 1.0], dtype=np.float32)
mu = tf.constant(mu_v)
sigma_v = np.array([[1.0, 0.1, 0.2],
[0.1, 2.0, 0.05],
[0.2, 0.05, 3.0]], dtype=np.float32)
sigma_chol_v = np.linalg.cholesky(sigma_v)
sigma_chol = tf.constant(sigma_chol_v)
x = np.array([-2.5, 2.5, 1.0], dtype=np.float32)
mvn = tf.contrib.distributions.MultivariateNormal(
mu=mu, sigma_chol=sigma_chol)
log_pdf = mvn.log_pdf(x)
pdf = mvn.pdf(x)
sigma = mvn.sigma
expected_log_pdf = scipy_mvn.logpdf(x)
expected_pdf = scipy_mvn.pdf(x)
self.assertEqual((), log_pdf.get_shape())
self.assertEqual((), pdf.get_shape())
self.assertAllClose(expected_log_pdf, log_pdf.eval())
self.assertAllClose(expected_pdf, pdf.eval())
try:
from scipy import stats # pylint: disable=g-import-not-at-top
scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
expected_log_pdf = scipy_mvn.logpdf(x)
expected_pdf = scipy_mvn.pdf(x)
self.assertEqual(sigma.get_shape(), (3, 3))
self.assertAllClose(sigma_v, sigma.eval())
self.assertAllClose(expected_log_pdf, log_pdf.eval())
self.assertAllClose(expected_pdf, pdf.eval())
except ImportError as e:
tf.logging.warn("Cannot test stats functions: %s" % str(e))
def testLogPDFXIsHigherRank(self):
with self.test_session():
mu = self._rng.rand(2)
chol, sigma = self._random_chol(2, 2)
mvn = distributions.MultivariateNormalCholesky(mu, chol)
x = self._rng.rand(3, 2)
def testLogPDF(self):
with tf.Session():
mu_v = np.array([-3.0, 3.0], dtype=np.float32)
mu = tf.constant(mu_v)
sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
sigma = tf.constant(sigma_v)
x = np.array([[-2.5, 2.5], [4.0, 0.0], [-1.0, 2.0]], dtype=np.float32)
mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
log_pdf = mvn.log_pdf(x)
pdf = mvn.pdf(x)
try:
from scipy import stats # pylint: disable=g-import-not-at-top
scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
expected_log_pdf = scipy_mvn.logpdf(x)
expected_pdf = scipy_mvn.pdf(x)
self.assertEqual(log_pdf.get_shape(), (3,))
self.assertAllClose(expected_log_pdf, log_pdf.eval())
self.assertAllClose(expected_pdf, pdf.eval())
except ImportError as e:
tf.logging.warn("Cannot test stats functions: %s" % str(e))
scipy_mvn = stats.multivariate_normal(mean=mu, cov=sigma)
expected_log_pdf = scipy_mvn.logpdf(x)
expected_pdf = scipy_mvn.pdf(x)
self.assertEqual((3,), log_pdf.get_shape())
self.assertEqual((3,), pdf.get_shape())
self.assertAllClose(expected_log_pdf, log_pdf.eval())
self.assertAllClose(expected_pdf, pdf.eval())
def testLogPDFXLowerDimension(self):
with self.test_session():
mu = self._rng.rand(3, 2)
chol, sigma = self._random_chol(3, 2, 2)
mvn = distributions.MultivariateNormalCholesky(mu, chol)
x = self._rng.rand(2)
def testLogPDFMatchingDimension(self):
with tf.Session():
mu_v = np.array([-3.0, 3.0], dtype=np.float32)
mu = tf.constant(np.vstack(3 * [mu_v]))
sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
sigma = tf.constant(np.vstack(3 * [sigma_v[np.newaxis, :]]))
x = np.array([[-2.5, 2.5], [4.0, 0.0], [-1.0, 2.0]], dtype=np.float32)
mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
log_pdf = mvn.log_pdf(x)
pdf = mvn.pdf(x)
try:
from scipy import stats # pylint: disable=g-import-not-at-top
scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
expected_log_pdf = scipy_mvn.logpdf(x)
expected_pdf = scipy_mvn.pdf(x)
self.assertEqual(log_pdf.get_shape(), (3,))
self.assertAllClose(expected_log_pdf, log_pdf.eval())
self.assertAllClose(expected_pdf, pdf.eval())
except ImportError as e:
tf.logging.warn("Cannot test stats functions: %s" % str(e))
self.assertEqual((3,), log_pdf.get_shape())
self.assertEqual((3,), pdf.get_shape())
def testLogPDFMultidimensional(self):
with tf.Session():
mu_v = np.array([-3.0, 3.0], dtype=np.float32)
mu = tf.constant(np.vstack(15 * [mu_v]).reshape(3, 5, 2))
sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
sigma = tf.constant(
np.vstack(15 * [sigma_v[np.newaxis, :]]).reshape(3, 5, 2, 2))
x = np.array([-2.5, 2.5], dtype=np.float32)
mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
log_pdf = mvn.log_pdf(x)
pdf = mvn.pdf(x)
# scipy can't do batches, so just test one of them.
scipy_mvn = stats.multivariate_normal(mean=mu[1, :], cov=sigma[1, :, :])
expected_log_pdf = scipy_mvn.logpdf(x)
expected_pdf = scipy_mvn.pdf(x)
try:
from scipy import stats # pylint: disable=g-import-not-at-top
scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
expected_log_pdf = np.vstack(15 * [scipy_mvn.logpdf(x)]).reshape(3, 5)
expected_pdf = np.vstack(15 * [scipy_mvn.pdf(x)]).reshape(3, 5)
self.assertEqual(log_pdf.get_shape(), (3, 5))
self.assertAllClose(expected_log_pdf, log_pdf.eval())
self.assertAllClose(expected_pdf, pdf.eval())
except ImportError as e:
tf.logging.warn("Cannot test stats functions: %s" % str(e))
self.assertAllClose(expected_log_pdf, log_pdf.eval()[1])
self.assertAllClose(expected_pdf, pdf.eval()[1])
def testEntropy(self):
with tf.Session():
mu_v = np.array([-3.0, 3.0], dtype=np.float32)
mu = tf.constant(mu_v)
sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
sigma = tf.constant(sigma_v)
mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
with self.test_session():
mu = self._rng.rand(2)
chol, sigma = self._random_chol(2, 2)
mvn = distributions.MultivariateNormalCholesky(mu, chol)
entropy = mvn.entropy()
try:
from scipy import stats # pylint: disable=g-import-not-at-top
scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
expected_entropy = scipy_mvn.entropy()
self.assertEqual(entropy.get_shape(), ())
self.assertAllClose(expected_entropy, entropy.eval())
except ImportError as e:
tf.logging.warn("Cannot test stats functions: %s" % str(e))
scipy_mvn = stats.multivariate_normal(mean=mu, cov=sigma)
expected_entropy = scipy_mvn.entropy()
self.assertEqual(entropy.get_shape(), ())
self.assertAllClose(expected_entropy, entropy.eval())
def testEntropyMultidimensional(self):
with tf.Session():
mu_v = np.array([-3.0, 3.0], dtype=np.float32)
mu = tf.constant(np.vstack(15 * [mu_v]).reshape(3, 5, 2))
sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
sigma = tf.constant(
np.vstack(15 * [sigma_v[np.newaxis, :]]).reshape(3, 5, 2, 2))
mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
with self.test_session():
mu = self._rng.rand(3, 5, 2)
chol, sigma = self._random_chol(3, 5, 2, 2)
mvn = distributions.MultivariateNormalCholesky(mu, chol)
entropy = mvn.entropy()
try:
from scipy import stats # pylint: disable=g-import-not-at-top
scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
expected_entropy = np.vstack(15 * [scipy_mvn.entropy()]).reshape(3, 5)
self.assertEqual(entropy.get_shape(), (3, 5))
self.assertAllClose(expected_entropy, entropy.eval())
except ImportError as e:
tf.logging.warn("Cannot test stats functions: %s" % str(e))
# Scipy doesn't do batches, so test one of them.
expected_entropy = stats.multivariate_normal(
mean=mu[1, 1, :], cov=sigma[1, 1, :, :]).entropy()
self.assertEqual(entropy.get_shape(), (3, 5))
self.assertAllClose(expected_entropy, entropy.eval()[1, 1])
def testSample(self):
with tf.Session():
mu_v = np.array([-3.0, 3.0], dtype=np.float32)
mu = tf.constant(mu_v)
sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
sigma = tf.constant(sigma_v)
with self.test_session():
mu = self._rng.rand(2)
chol, sigma = self._random_chol(2, 2)
n = tf.constant(100000)
mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
mvn = distributions.MultivariateNormalCholesky(mu, chol)
samples = mvn.sample(n, seed=137)
sample_values = samples.eval()
self.assertEqual(samples.get_shape(), (100000, 2))
self.assertAllClose(sample_values.mean(axis=0), mu_v, atol=1e-2)
self.assertAllClose(np.cov(sample_values, rowvar=0), sigma_v, atol=1e-1)
self.assertAllClose(sample_values.mean(axis=0), mu, atol=1e-2)
self.assertAllClose(np.cov(sample_values, rowvar=0), sigma, atol=1e-1)
def testSampleMultiDimensional(self):
with tf.Session():
mu_v = np.array([-3.0, 3.0], dtype=np.float32)
mu = tf.constant(np.vstack(15 * [mu_v]).reshape(3, 5, 2))
sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
sigma = tf.constant(
np.vstack(15 * [sigma_v[np.newaxis, :]]).reshape(3, 5, 2, 2))
with self.test_session():
mu = self._rng.rand(3, 5, 2)
chol, sigma = self._random_chol(3, 5, 2, 2)
n = tf.constant(100000)
mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
mvn = distributions.MultivariateNormalCholesky(mu, chol)
samples = mvn.sample(n, seed=137)
sample_values = samples.eval()
self.assertEqual(samples.get_shape(), (100000, 3, 5, 2))
sample_values = sample_values.reshape(100000, 15, 2)
for i in range(15):
self.assertAllClose(
sample_values[:, i, :].mean(axis=0), mu_v, atol=1e-2)
self.assertAllClose(
np.cov(sample_values[:, i, :], rowvar=0), sigma_v, atol=1e-1)
self.assertAllClose(
sample_values[:, 1, 1, :].mean(axis=0),
mu[1, 1, :], atol=0.05)
self.assertAllClose(
np.cov(sample_values[:, 1, 1, :], rowvar=0),
sigma[1, 1, :, :], atol=1e-1)
if __name__ == "__main__":
tf.test.main()


@@ -0,0 +1,327 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
distributions = tf.contrib.distributions
def softplus(x):
return np.log(1 + np.exp(x))
class OperatorPDCholeskyTest(tf.test.TestCase):
def setUp(self):
self._rng = np.random.RandomState(42)
def _random_cholesky_array(self, shape):
mat = self._rng.rand(*shape)
chol = distributions.batch_matrix_diag_transform(mat,
transform=tf.nn.softplus)
# Zero the upper triangle because we're using this as a true Cholesky factor
# in our tests.
return tf.batch_matrix_band_part(chol, -1, 0).eval()
def _numpy_inv_quadratic_form(self, chol, x):
# Numpy works with batches now (calls them "stacks").
x_expanded = np.expand_dims(x, -1)
whitened = np.linalg.solve(chol, x_expanded)
return (whitened**2).sum(axis=-1).sum(axis=-1)
def test_inv_quadratic_form_x_rank_same_as_broadcast_rank(self):
with self.test_session():
for batch_shape in [(), (2,)]:
for k in [1, 3]:
x_shape = batch_shape + (k,)
x = self._rng.randn(*x_shape)
chol_shape = batch_shape + (k, k)
chol = self._random_cholesky_array(chol_shape)
operator = distributions.OperatorPDCholesky(chol)
qf = operator.inv_quadratic_form(x)
self.assertEqual(batch_shape, qf.get_shape())
numpy_qf = self._numpy_inv_quadratic_form(chol, x)
self.assertAllClose(numpy_qf, qf.eval())
def test_inv_quadratic_form_x_and_chol_batch_shape_dont_match(self):
# In this case, chol will have to be stretched to match x.
with self.test_session():
k = 3
x_shape = (2, k)
chol_shape = (1, k, k)
broadcast_batch_shape = (2,)
x = self._rng.randn(*x_shape)
chol = self._random_cholesky_array(chol_shape)
operator = distributions.OperatorPDCholesky(chol)
qf = operator.inv_quadratic_form(x)
self.assertEqual(broadcast_batch_shape, qf.get_shape())
numpy_qf = self._numpy_inv_quadratic_form(chol, x)
self.assertAllClose(numpy_qf, qf.eval())
def test_inv_quadratic_form_x_rank_less_than_broadcast_rank(self):
with self.test_session():
for batch_shape in [(2,), (2, 3)]:
for k in [1, 4]:
# x will not have the leading dimension.
x_shape = batch_shape[1:] + (k,)
x = self._rng.randn(*x_shape)
chol_shape = batch_shape + (k, k)
chol = self._random_cholesky_array(chol_shape)
operator = distributions.OperatorPDCholesky(chol)
qf = operator.inv_quadratic_form(x)
self.assertEqual(batch_shape, qf.get_shape())
x_upshaped = x + np.zeros(chol.shape[:-1])
numpy_qf = self._numpy_inv_quadratic_form(chol, x_upshaped)
numpy_qf = numpy_qf.reshape(batch_shape)
self.assertAllClose(numpy_qf, qf.eval())
def test_inv_quadratic_form_x_rank_greater_than_broadcast_rank(self):
with self.test_session():
for batch_shape in [(2,), (2, 3)]:
for k in [1, 4]:
x_shape = batch_shape + (k,)
x = self._rng.randn(*x_shape)
# chol will not have the leading dimension.
chol_shape = batch_shape[1:] + (k, k)
chol = self._random_cholesky_array(chol_shape)
operator = distributions.OperatorPDCholesky(chol)
qf = operator.inv_quadratic_form(x)
numpy_qf = self._numpy_inv_quadratic_form(chol, x)
self.assertEqual(batch_shape, qf.get_shape())
self.assertAllClose(numpy_qf, qf.eval())
def test_inv_quadratic_form_x_rank_two_greater_than_broadcast_rank(self):
with self.test_session():
for batch_shape in [(2, 3), (2, 3, 4), (2, 3, 4, 5)]:
for k in [1, 4]:
x_shape = batch_shape + (k,)
x = self._rng.randn(*x_shape)
# chol will not have the leading two dimensions.
chol_shape = batch_shape[2:] + (k, k)
chol = self._random_cholesky_array(chol_shape)
operator = distributions.OperatorPDCholesky(chol)
qf = operator.inv_quadratic_form(x)
numpy_qf = self._numpy_inv_quadratic_form(chol, x)
self.assertEqual(batch_shape, qf.get_shape())
self.assertAllClose(numpy_qf, qf.eval())
def test_log_det(self):
with self.test_session():
batch_shape = ()
for k in [1, 4]:
chol_shape = batch_shape + (k, k)
chol = self._random_cholesky_array(chol_shape)
operator = distributions.OperatorPDCholesky(chol)
log_det = operator.log_det()
expected_log_det = np.log(np.prod(np.diag(chol))**2)
self.assertEqual(batch_shape, log_det.get_shape())
self.assertAllClose(expected_log_det, log_det.eval())
def test_log_det_batch_matrix(self):
with self.test_session():
batch_shape = (2, 3)
for k in [1, 4]:
chol_shape = batch_shape + (k, k)
chol = self._random_cholesky_array(chol_shape)
operator = distributions.OperatorPDCholesky(chol)
log_det = operator.log_det()
self.assertEqual(batch_shape, log_det.get_shape())
# Test the log-determinant of the [1, 1] matrix.
chol_11 = chol[1, 1, :, :]
expected_log_det = np.log(np.prod(np.diag(chol_11))**2)
self.assertAllClose(expected_log_det, log_det.eval()[1, 1])
def test_sqrt_matmul_single_matrix(self):
with self.test_session():
batch_shape = ()
for k in [1, 4]:
x_shape = batch_shape + (k, 3)
x = self._rng.rand(*x_shape)
chol_shape = batch_shape + (k, k)
chol = self._random_cholesky_array(chol_shape)
operator = distributions.OperatorPDCholesky(chol)
sqrt_operator_times_x = operator.sqrt_matmul(x)
expected = tf.batch_matmul(chol, x)
self.assertEqual(expected.get_shape(),
sqrt_operator_times_x.get_shape())
self.assertAllClose(expected.eval(), sqrt_operator_times_x.eval())
def test_sqrt_matmul_batch_matrix(self):
with self.test_session():
batch_shape = (2, 3)
for k in [1, 4]:
x_shape = batch_shape + (k, 5)
x = self._rng.rand(*x_shape)
chol_shape = batch_shape + (k, k)
chol = self._random_cholesky_array(chol_shape)
operator = distributions.OperatorPDCholesky(chol)
sqrt_operator_times_x = operator.sqrt_matmul(x)
expected = tf.batch_matmul(chol, x)
self.assertEqual(expected.get_shape(),
sqrt_operator_times_x.get_shape())
self.assertAllClose(expected.eval(), sqrt_operator_times_x.eval())
def test_matmul_batch_matrix(self):
with self.test_session():
batch_shape = (2, 3)
for k in [1, 4]:
x_shape = batch_shape + (k, 5)
x = self._rng.rand(*x_shape)
chol_shape = batch_shape + (k, k)
chol = self._random_cholesky_array(chol_shape)
operator = distributions.OperatorPDCholesky(chol)
chol_times_x = tf.batch_matmul(chol, x, adj_x=True)
expected = tf.batch_matmul(chol, chol_times_x)
self.assertEqual(expected.get_shape(), operator.matmul(x).get_shape())
self.assertAllClose(expected.eval(), operator.matmul(x).eval())
def test_shape(self):
# All other shapes are defined by the abstractmethod shape, so we only need
# to test this.
with self.test_session():
for shape in [(3, 3), (2, 3, 3), (1, 2, 3, 3)]:
chol = self._random_cholesky_array(shape)
operator = distributions.OperatorPDCholesky(chol)
self.assertAllEqual(shape, operator.shape().eval())
def test_to_dense(self):
with self.test_session():
chol = self._random_cholesky_array((3, 3))
operator = distributions.OperatorPDCholesky(chol)
self.assertAllClose(chol.dot(chol.T), operator.to_dense().eval())
def test_to_dense_sqrt(self):
with self.test_session():
chol = self._random_cholesky_array((2, 3, 3))
operator = distributions.OperatorPDCholesky(chol)
self.assertAllClose(chol, operator.to_dense_sqrt().eval())
def test_non_positive_definite_matrix_raises(self):
# Singular matrix with one positive eigenvalue and one zero eigenvalue.
with self.test_session():
lower_mat = [[1.0, 0.0], [2.0, 0.0]]
operator = distributions.OperatorPDCholesky(lower_mat)
with self.assertRaisesOpError('x > 0 did not hold'):
operator.to_dense().eval()
def test_non_positive_definite_matrix_does_not_raise_if_not_verify_pd(self):
# Singular matrix with one positive eigenvalue and one zero eigenvalue.
with self.test_session():
lower_mat = [[1.0, 0.0], [2.0, 0.0]]
operator = distributions.OperatorPDCholesky(lower_mat, verify_pd=False)
operator.to_dense().eval() # Should not raise.
def test_not_having_two_identical_last_dims_raises(self):
# Unless the last two dims are equal, this cannot represent a matrix, and it
# should raise.
with self.test_session():
batch_vec = [[1.0], [2.0]] # shape 2 x 1
with self.assertRaisesRegexp(ValueError, '.*Dimensions.*'):
operator = distributions.OperatorPDCholesky(batch_vec)
operator.to_dense().eval()
class BatchMatrixDiagTransformTest(tf.test.TestCase):
def setUp(self):
self._rng = np.random.RandomState(0)
def check_off_diagonal_same(self, m1, m2):
"""Check the lower triangular part, not upper or diag."""
self.assertAllClose(np.tril(m1, k=-1), np.tril(m2, k=-1))
self.assertAllClose(np.triu(m1, k=1), np.triu(m2, k=1))
def test_non_batch_matrix_with_transform(self):
mat = self._rng.rand(4, 4)
with self.test_session():
chol = distributions.batch_matrix_diag_transform(mat,
transform=tf.nn.softplus)
self.assertEqual((4, 4), chol.get_shape())
self.check_off_diagonal_same(mat, chol.eval())
self.assertAllClose(softplus(np.diag(mat)), np.diag(chol.eval()))
def test_non_batch_matrix_no_transform(self):
mat = self._rng.rand(4, 4)
with self.test_session():
# Default is no transform.
chol = distributions.batch_matrix_diag_transform(mat)
self.assertEqual((4, 4), chol.get_shape())
self.assertAllClose(mat, chol.eval())
def test_batch_matrix_with_transform(self):
mat = self._rng.rand(2, 4, 4)
mat_0 = mat[0, :, :]
with self.test_session():
chol = distributions.batch_matrix_diag_transform(mat,
transform=tf.nn.softplus)
self.assertEqual((2, 4, 4), chol.get_shape())
chol_0 = chol.eval()[0, :, :]
self.check_off_diagonal_same(mat_0, chol_0)
self.assertAllClose(softplus(np.diag(mat_0)), np.diag(chol_0))
def test_batch_matrix_no_transform(self):
mat = self._rng.rand(2, 4, 4)
with self.test_session():
# Default is no transform.
chol = distributions.batch_matrix_diag_transform(mat)
self.assertEqual((2, 4, 4), chol.get_shape())
self.assertAllClose(mat, chol.eval())
if __name__ == '__main__':
tf.test.main()


@@ -0,0 +1,62 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
distributions = tf.contrib.distributions
class OperatorPDFullTest(tf.test.TestCase):
# The only method that needs checking (because it isn't part of the parent
# class) is the check for symmetry.
def setUp(self):
self._rng = np.random.RandomState(42)
def _random_positive_def_array(self, *shape):
matrix = self._rng.rand(*shape)
return tf.batch_matmul(matrix, matrix, adj_y=True).eval()
def test_positive_definite_matrix_doesnt_raise(self):
with self.test_session():
matrix = self._random_positive_def_array(2, 3, 3)
operator = distributions.OperatorPDFull(matrix, verify_pd=True)
operator.to_dense().eval() # Should not raise
def test_negative_definite_matrix_raises(self):
with self.test_session():
matrix = -1 * self._random_positive_def_array(3, 2, 2)
operator = distributions.OperatorPDFull(matrix, verify_pd=True)
# Could fail inside Cholesky decomposition, or later when we test the
# diag.
with self.assertRaisesOpError('x > 0|LLT'):
operator.to_dense().eval()
def test_non_symmetric_matrix_raises(self):
with self.test_session():
matrix = self._random_positive_def_array(3, 2, 2)
matrix[0, 0, 1] += 0.001
operator = distributions.OperatorPDFull(matrix, verify_pd=True)
with self.assertRaisesOpError('x == y'):
operator.to_dense().eval()
if __name__ == '__main__':
tf.test.main()


@@ -0,0 +1,83 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
distributions = tf.contrib.distributions
class BaseCovarianceTest(tf.test.TestCase):
def test_all_shapes_methods_defined_by_the_one_abstractproperty_shape(self):
class OperatorShape(distributions.OperatorPDBase):
"""Operator implements the ABC method .shape."""
def __init__(self, shape):
self._shape = shape
@property
def verify_pd(self):
return True
def get_shape(self):
return tf.TensorShape(self._shape)
def shape(self, name='shape'):
return tf.shape(np.random.rand(*self._shape))
@property
def name(self):
return 'OperatorShape'
def dtype(self):
return tf.int32
def inv_quadratic_form(
self, x, name='inv_quadratic_form'):
return x
def log_det(self, name='log_det'):
raise NotImplementedError()
@property
def inputs(self):
return []
def sqrt_matmul(self, x, name='sqrt_matmul'):
return x
shape = (1, 2, 3, 3)
with self.test_session():
operator = OperatorShape(shape)
self.assertAllEqual(shape, operator.shape().eval())
self.assertAllEqual(4, operator.rank().eval())
self.assertAllEqual((1, 2), operator.batch_shape().eval())
self.assertAllEqual((1, 2, 3), operator.vector_shape().eval())
self.assertAllEqual(3, operator.vector_space_dimension().eval())
self.assertEqual(shape, operator.get_shape())
self.assertEqual((1, 2), operator.get_batch_shape())
self.assertEqual((1, 2, 3), operator.get_vector_shape())
if __name__ == '__main__':
tf.test.main()


@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Multivariate Normal distribution class."""
"""Multivariate Normal distribution classes."""
from __future__ import absolute_import
from __future__ import division
@@ -20,83 +20,33 @@ from __future__ import print_function
import math
from tensorflow.contrib.distributions.python.ops import distribution # pylint: disable=line-too-long
from tensorflow.contrib.distributions.python.ops import operator_pd_cholesky # pylint: disable=line-too-long
from tensorflow.contrib.distributions.python.ops import operator_pd_full # pylint: disable=line-too-long
from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util # pylint: disable=line-too-long
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
def _assert_compatible_shapes(mu, sigma):
r_mu = array_ops.rank(mu)
r_sigma = array_ops.rank(sigma)
sigma_shape = array_ops.shape(sigma)
sigma_rank = array_ops.rank(sigma)
mu_shape = array_ops.shape(mu)
return control_flow_ops.group(
logging_ops.Assert(
math_ops.equal(r_mu + 1, r_sigma),
["Rank of mu should be one less than rank of sigma, but saw: ",
r_mu, " vs. ", r_sigma]),
logging_ops.Assert(
math_ops.equal(
array_ops.gather(sigma_shape, sigma_rank - 2),
array_ops.gather(sigma_shape, sigma_rank - 1)),
["Last two dimensions of sigma (%s) must be equal: " % sigma.name,
sigma_shape]),
logging_ops.Assert(
math_ops.reduce_all(math_ops.equal(
mu_shape,
array_ops.slice(
sigma_shape, [0], array_ops.pack([sigma_rank - 1])))),
["mu.shape and sigma.shape[:-1] must match, but saw: ",
mu_shape, " vs. ", sigma_shape]))
__all__ = [
"MultivariateNormalCholesky",
"MultivariateNormalFull",
]
def _assert_batch_positive_definite(sigma_chol):
"""Add assertions checking that the sigmas are all Positive Definite.
class MultivariateNormalOperatorPD(distribution.ContinuousDistribution):
"""The multivariate normal distribution on `R^k`.
Given `sigma_chol == cholesky(sigma)`, it is sufficient to check that
`all(diag(sigma_chol) > 0)`. This is because to check that a matrix is PD,
it is sufficient that its cholesky factorization is PD, and to check that a
triangular matrix is PD, it is sufficient to check that its diagonal
entries are positive.
Args:
sigma_chol: N-D. The lower triangular cholesky decomposition of `sigma`.
Returns:
An assertion op to use with `control_dependencies`, verifying that
`sigma_chol` is positive definite.
"""
sigma_batch_diag = array_ops.batch_matrix_diag_part(sigma_chol)
return logging_ops.Assert(
math_ops.reduce_all(sigma_batch_diag > 0),
["sigma_chol is not positive definite. batched diagonals: ",
sigma_batch_diag, " shaped: ", array_ops.shape(sigma_batch_diag)])
def _log_determinant_from_sigma_chol(sigma_chol):
det_last_dim = array_ops.rank(sigma_chol) - 2
sigma_batch_diag = array_ops.batch_matrix_diag_part(sigma_chol)
log_det = 2.0 * math_ops.reduce_sum(
math_ops.log(sigma_batch_diag), reduction_indices=det_last_dim)
log_det.set_shape(sigma_chol.get_shape()[:-2])
return log_det
class MultivariateNormal(object):
"""The Multivariate Normal distribution on `R^k`.
The distribution has mean and covariance parameters mu (1-D), sigma (2-D),
or alternatively mean `mu` and factored covariance (cholesky decomposed
`sigma`) called `sigma_chol`.
This distribution is defined by a 1-D mean `mu` and an instance of
`OperatorPDBase`, which provides access to a symmetric positive definite
operator, which defines the covariance.
#### Mathematical details
@@ -108,21 +58,6 @@ class MultivariateNormal(object):
where `.` denotes the inner product on `R^k` and `^*` denotes transpose.
Alternatively, if `sigma` is positive definite, it can be represented in terms
of its lower triangular cholesky factorization
```sigma = sigma_chol . sigma_chol^*```
and the pdf above allows simpler computation:
```
|det(sigma)| = reduce_prod(diag(sigma_chol))^2
x_whitened = sigma^{-1/2} . (x - mu) = tri_solve(sigma_chol, x - mu)
(x-mu)^* .sigma^{-1} . (x-mu) = x_whitened^* . x_whitened
```
where `tri_solve()` solves a triangular system of equations.
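As a numerical check of those two identities (a NumPy/SciPy sketch, not part of the commit):

```python
import numpy as np
from scipy.linalg import solve_triangular

rng = np.random.RandomState(0)
k = 3
sigma_chol = np.tril(rng.rand(k, k)) + np.eye(k)  # Positive diagonal.
sigma = sigma_chol.dot(sigma_chol.T)
x_centered = rng.randn(k)  # Stands in for x - mu.

# |det(sigma)| equals the squared product of the Cholesky diagonal.
np.testing.assert_allclose(np.linalg.det(sigma),
                           np.prod(np.diag(sigma_chol))**2)

# The quadratic form via one triangular solve, with no explicit inverse.
x_whitened = solve_triangular(sigma_chol, x_centered, lower=True)
np.testing.assert_allclose(
    x_centered.dot(np.linalg.solve(sigma, x_centered)),
    x_whitened.dot(x_whitened))
```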
#### Examples
A single multi-variate Gaussian distribution is defined by a vector of means
@@ -131,128 +66,177 @@ class MultivariateNormal(object):
Extra leading dimensions, if provided, allow for batches.
```python
# Initialize a single 3-variate Gaussian with diagonal covariance.
# Initialize a single 3-variate Gaussian.
mu = [1, 2, 3]
sigma = [[1, 0, 0], [0, 3, 0], [0, 0, 2]]
dist = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
chol = [[1, 0, 0.], [1, 3, 0], [1, 2, 3]]
cov = tf.contrib.distributions.OperatorPDCholesky(chol)
dist = tf.contrib.distributions.MultivariateNormalOperatorPD(mu, cov)
# Evaluate this on an observation in R^3, returning a scalar.
dist.pdf([-1, 0, 1])
dist.pdf([-1, 0, 1.])
# Initialize a batch of two 3-variate Gaussians.
mu = [[1, 2, 3], [11, 22, 33]]
sigma = ... # shape 2 x 3 x 3
dist = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
mu = [[1, 2, 3], [11, 22, 33.]]
chol = ... # shape 2 x 3 x 3, lower triangular, positive diagonal.
cov = tf.contrib.distributions.OperatorPDCholesky(chol)
dist = tf.contrib.distributions.MultivariateNormalOperatorPD(mu, cov)
# Evaluate this on two observations, each in R^3, returning a length two
# tensor.
x = [[-1, 0, 1], [-11, 0, 11]] # Shape 2 x 3.
x = [[-1, 0, 1], [-11, 0, 11.]] # Shape 2 x 3.
dist.pdf(x)
```
"""
def __init__(self, mu, sigma=None, sigma_chol=None, name=None):
def __init__(
self,
mu,
cov,
allow_nan=False,
strict=True,
strict_statistics=True,
name="MultivariateNormalCov"):
"""Multivariate Normal distributions on `R^k`.
User must provide means `mu`, which are tensors of rank `N+1` (`N >= 0`)
with the last dimension having length `k`.
User must provide exactly one of `sigma` (the covariance matrices) or
`sigma_chol` (the cholesky decompositions of the covariance matrices).
`sigma` or `sigma_chol` must be of rank `N+2`. The last two dimensions
must both have length `k`. The first `N` dimensions correspond to batch
indices.
If `sigma_chol` is not provided, the batch cholesky factorization of `sigma`
is calculated for you.
The shapes of `mu` and `sigma` must match for the first `N` dimensions.
Regardless of which parameter is provided, the covariance matrices must all
be **positive definite** (an error is raised if one of them is not).
User must provide means `mu`, and an instance of `OperatorPDBase`, `cov`,
which determines the covariance.
Args:
mu: (N+1)-D. `float` or `double` tensor, the means of the distributions.
sigma: (N+2)-D. (optional) `float` or `double` tensor, the covariances
of the distribution(s). The first `N+1` dimensions must match
those of `mu`. Must be batch-positive-definite.
sigma_chol: (N+2)-D. (optional) `float` or `double` tensor, a
lower-triangular factorization of `sigma`
(`sigma = sigma_chol . sigma_chol^*`). The first `N+1` dimensions
must match those of `mu`. The tensor itself need not be batch
lower triangular: we ignore the upper triangular part. However,
the batch diagonals must be positive (i.e., sigma_chol must be
batch-positive-definite).
mu: `float` or `double` tensor with shape `[N1,...,Nb, k]`, `b >= 0`.
cov: `float` or `double` instance of `OperatorPDBase` with same `dtype`
as `mu` and shape `[N1,...,Nb, k, k]`.
allow_nan: Boolean, default False. If False, raise an exception if
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
If True, batch members with valid parameters leading to undefined
statistics will return NaN for this statistic.
strict: Whether to validate input with asserts. If `strict` is `False`,
and the inputs are invalid, correct behavior is not guaranteed.
strict_statistics: Boolean, default True. If True, raise an exception if
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
If False, batch members with valid parameters leading to undefined
statistics will return NaN for this statistic.
name: The name to give Ops created by the initializer.
Raises:
ValueError: if neither sigma nor sigma_chol is provided.
TypeError: if mu and sigma (resp. sigma_chol) are different dtypes.
TypeError: If `mu` and `cov` are different dtypes.
"""
if (sigma is None) == (sigma_chol is None):
raise ValueError("Exactly one of sigma and sigma_chol must be provided")
self._strict_statistics = strict_statistics
self._strict = strict
with ops.name_scope(name):
with ops.op_scope([mu] + cov.inputs, "init"):
self._cov = cov
self._mu = self._check_mu(mu)
self._name = name
with ops.op_scope([mu, sigma, sigma_chol], name, "MultivariateNormal"):
sigma_or_half = sigma_chol if sigma is None else sigma
def _check_mu(self, mu):
"""Return `mu` after validity checks and possibly with assertations."""
mu = ops.convert_to_tensor(mu)
cov = self._cov
sigma_or_half = ops.convert_to_tensor(sigma_or_half)
if mu.dtype != cov.dtype:
raise TypeError(
"mu and cov must have the same dtype. Found mu.dtype = %s, "
"cov.dtype = %s"
% (mu.dtype, cov.dtype))
if not self.strict:
return mu
else:
assert_compatible_shapes = control_flow_ops.group(
check_ops.assert_equal(
array_ops.rank(mu) + 1,
cov.rank(),
data=["mu should have rank 1 less than cov. Found: rank(mu) = ",
array_ops.rank(mu), " rank(cov) = ", cov.rank()],
),
check_ops.assert_equal(
array_ops.shape(mu),
cov.vector_shape(),
data=["mu.shape and cov.shape[:-1] should match. "
"Found: shape(mu) = "
, array_ops.shape(mu), " shape(cov) = ", cov.shape()],
),
)
return control_flow_ops.with_dependencies([assert_compatible_shapes], mu)
contrib_tensor_util.assert_same_float_dtype((mu, sigma_or_half))
@property
def strict(self):
"""Boolean describing behavior on invalid input."""
return self._strict
with ops.control_dependencies([
_assert_compatible_shapes(mu, sigma_or_half)]):
mu = array_ops.identity(mu, name="mu")
# Store the dimensionality of the MVNs
self._k = array_ops.gather(array_ops.shape(mu), array_ops.rank(mu) - 1)
if sigma_chol is not None:
# Ensure we only keep the lower triangular part.
sigma_chol = array_ops.batch_matrix_band_part(
sigma_chol, num_lower=-1, num_upper=0)
log_sigma_det = _log_determinant_from_sigma_chol(sigma_chol)
with ops.control_dependencies([
_assert_batch_positive_definite(sigma_chol)]):
self._sigma = math_ops.batch_matmul(
sigma_chol, sigma_chol, adj_y=True, name="sigma")
self._sigma_chol = array_ops.identity(sigma_chol, "sigma_chol")
self._log_sigma_det = array_ops.identity(
log_sigma_det, "log_sigma_det")
self._mu = array_ops.identity(mu, "mu")
else: # sigma is not None
sigma_chol = linalg_ops.batch_cholesky(sigma)
log_sigma_det = _log_determinant_from_sigma_chol(sigma_chol)
# batch_cholesky checks for PSD; so we can just use it here.
with ops.control_dependencies([sigma_chol]):
self._sigma = array_ops.identity(sigma, "sigma")
self._sigma_chol = array_ops.identity(sigma_chol, "sigma_chol")
self._log_sigma_det = array_ops.identity(
log_sigma_det, "log_sigma_det")
self._mu = array_ops.identity(mu, "mu")
@property
def strict_statistics(self):
"""Boolean describing behavior when a stat is undefined for batch member."""
return self._strict_statistics
@property
def dtype(self):
return self._mu.dtype
def get_event_shape(self):
"""`TensorShape` available at graph construction time."""
# Recall _check_mu ensures mu and self._cov have same batch shape.
return self._cov.get_shape()[-1:]
def event_shape(self, name="event_shape"):
"""Shape of a sample from a single distribution as a 1-D int32 `Tensor`."""
# Recall _check_mu ensures mu and self._cov have same batch shape.
with ops.name_scope(self.name):
with ops.op_scope(self._cov.inputs, name):
return array_ops.pack([self._cov.vector_space_dimension()])
def batch_shape(self, name="batch_shape"):
"""Batch dimensions of this instance as a 1-D int32 `Tensor`."""
# Recall _check_mu ensures mu and self._cov have same batch shape.
with ops.name_scope(self.name):
with ops.op_scope(self._cov.inputs, name):
return self._cov.batch_shape()
def get_batch_shape(self):
"""`TensorShape` available at graph construction time."""
# Recall _check_mu ensures mu and self._cov have same batch shape.
return self._cov.get_batch_shape()
@property
def mu(self):
return self._mu
@property
def sigma(self):
return self._sigma
"""Dense (batch) covariance matrix, if available."""
with ops.name_scope(self.name):
return self._cov.to_dense()
@property
def mean(self):
return self._mu
def mean(self, name="mean"):
"""Mean of each batch member."""
with ops.name_scope(self.name):
with ops.op_scope([self._mu], name):
return array_ops.identity(self._mu)
@property
def sigma_det(self):
return math_ops.exp(self._log_sigma_det)
def mode(self, name="mode"):
"""Mode of each batch member."""
with ops.name_scope(self.name):
with ops.op_scope([self._mu], name):
return array_ops.identity(self._mu)
def log_pdf(self, x, name=None):
def variance(self, name="variance"):
"""Variance of each batch member."""
with ops.name_scope(self.name):
return self.sigma
def log_sigma_det(self, name="log_sigma_det"):
"""Log of determinant of covariance matrix."""
with ops.name_scope(self.name):
with ops.op_scope(self._cov.inputs, name):
return self._cov.log_det()
def sigma_det(self, name="sigma_det"):
"""Determinant of covariance matrix."""
with ops.name_scope(self.name):
with ops.op_scope(self._cov.inputs, name):
return math_ops.exp(self._cov.log_det())
def log_pdf(self, x, name="log_pdf"):
"""Log pdf of observations `x` given these Multivariate Normals.
Args:
@@ -262,120 +246,25 @@ class MultivariateNormal(object):
Returns:
log_pdf: tensor of dtype `dtype`, the log-PDFs of `x`.
"""
with ops.op_scope(
[self._mu, self._sigma_chol, x], name, "MultivariateNormalLogPdf"):
x = ops.convert_to_tensor(x)
contrib_tensor_util.assert_same_float_dtype((self._mu, x))
with ops.name_scope(self.name):
with ops.op_scope([self._mu, x] + self._cov.inputs, name):
x = ops.convert_to_tensor(x)
contrib_tensor_util.assert_same_float_dtype((self._mu, x))
x_centered = x - self.mu
x_whitened_norm = self._cov.inv_quadratic_form(x_centered)
log_sigma_det = self.log_sigma_det()
x_rank = array_ops.rank(x_centered)
sigma_rank = array_ops.rank(self._sigma_chol)
log_two_pi = constant_op.constant(
math.log(2 * math.pi), dtype=self.dtype)
k = math_ops.cast(self._cov.vector_space_dimension(), self.dtype)
log_pdf_value = -(log_sigma_det + k * log_two_pi + x_whitened_norm) / 2
x_rank_vec = array_ops.pack([x_rank])
sigma_rank_vec = array_ops.pack([sigma_rank])
x_shape = array_ops.shape(x_centered)
output_static_shape = x_centered.get_shape()[:-1]
log_pdf_value.set_shape(output_static_shape)
return log_pdf_value
# sigma_chol is shaped [D, E, F, ..., k, k]
# x_centered shape is one of:
# [D, E, F, ..., k], or [F, ..., k], or
# [A, B, C, D, E, F, ..., k]
# and we need to convert x_centered to shape:
# [D, E, F, ..., k, A*B*C] (or 1 if A, B, C don't exist)
# then transpose and reshape x_whitened back to one of the shapes:
# [D, E, F, ..., k], or [1, 1, F, ..., k], or
# [A, B, C, D, E, F, ..., k]
# Note that if rank(x) <= rank(sigma) - 1, the first dimensions of
# x_centered will match sigma exactly because x_centered = x - self.mu.
# This helper handles the case where rank(x_centered) < rank(sigma)
def _broadcast_x_not_higher_rank_than_sigma():
return array_ops.reshape(
x_centered,
array_ops.concat(
# Reshape to ones(deficient x rank) + x_shape + [1]
0, (array_ops.ones(array_ops.pack([sigma_rank - x_rank - 1]),
dtype=x_rank.dtype),
x_shape,
[1])))
# These helpers handle the case where rank(x_centered) >= rank(sigma)
def _broadcast_x_higher_rank_than_sigma():
x_shape_left = array_ops.slice(
x_shape, [0], sigma_rank_vec - 1)
x_shape_right = array_ops.slice(
x_shape, sigma_rank_vec - 1, x_rank_vec - 1)
x_shape_perm = array_ops.concat(
0, (math_ops.range(sigma_rank - 1, x_rank),
math_ops.range(0, sigma_rank - 1)))
return array_ops.reshape(
# Convert to [D, E, F, ..., k, B, C]
array_ops.transpose(
x_centered, perm=x_shape_perm),
# Reshape to [D, E, F, ..., k, B*C]
array_ops.concat(
0, (x_shape_right,
array_ops.pack([
math_ops.reduce_prod(x_shape_left, 0)]))))
def _unbroadcast_x_higher_rank_than_sigma():
x_shape_left = array_ops.slice(
x_shape, [0], sigma_rank_vec - 1)
x_shape_right = array_ops.slice(
x_shape, sigma_rank_vec - 1, x_rank_vec - 1)
x_shape_perm = array_ops.concat(
0, (math_ops.range(sigma_rank - 1, x_rank),
math_ops.range(0, sigma_rank - 1)))
return array_ops.transpose(
# [D, E, F, ..., k, B, C] => [B, C, D, E, F, ..., k]
array_ops.reshape(
# convert to [D, E, F, ..., k, B, C]
x_whitened_broadcast,
array_ops.concat(0, (x_shape_right, x_shape_left))),
perm=x_shape_perm)
# Step 1: reshape x_centered
x_centered_broadcast = control_flow_ops.cond(
# x_centered == [D, E, F, ..., k] => [D, E, F, ..., k, 1]
# or == [F, ..., k] => [1, 1, F, ..., k, 1]
x_rank <= sigma_rank - 1,
_broadcast_x_not_higher_rank_than_sigma,
# x_centered == [B, C, D, E, F, ..., k] => [D, E, F, ..., k, B*C]
_broadcast_x_higher_rank_than_sigma)
x_whitened_broadcast = linalg_ops.batch_matrix_triangular_solve(
self._sigma_chol, x_centered_broadcast)
# Reshape x_whitened_broadcast back to x_whitened
x_whitened = control_flow_ops.cond(
x_rank <= sigma_rank - 1,
lambda: array_ops.reshape(x_whitened_broadcast, x_shape),
_unbroadcast_x_higher_rank_than_sigma)
x_whitened = array_ops.expand_dims(x_whitened, -1)
# Reshape x_whitened to contain row vectors
# Returns a batchwise scalar
x_whitened_norm = math_ops.batch_matmul(
x_whitened, x_whitened, adj_x=True)
x_whitened_norm = control_flow_ops.cond(
x_rank <= sigma_rank - 1,
lambda: array_ops.squeeze(x_whitened_norm, [-2, -1]),
lambda: array_ops.squeeze(x_whitened_norm, [-1]))
log_two_pi = constant_op.constant(math.log(2 * math.pi), dtype=self.dtype)
k = math_ops.cast(self._k, self.dtype)
log_pdf_value = (
-self._log_sigma_det -(k * log_two_pi) - x_whitened_norm) / 2
final_shaped_value = control_flow_ops.cond(
x_rank <= sigma_rank - 1,
lambda: log_pdf_value,
lambda: array_ops.squeeze(log_pdf_value, [-1]))
output_static_shape = x_centered.get_shape()[:-1]
final_shaped_value.set_shape(output_static_shape)
return final_shaped_value
def pdf(self, x, name=None):
def pdf(self, x, name="pdf"):
"""The PDF of observations `x` under these Multivariate Normals.
Args:
@@ -385,11 +274,11 @@ class MultivariateNormal(object):
Returns:
pdf: tensor of dtype `dtype`, the pdf values of `x`.
"""
with ops.op_scope(
[self._mu, self._sigma_chol, x], name, "MultivariateNormalPdf"):
return math_ops.exp(self.log_pdf(x))
with ops.name_scope(self.name):
with ops.op_scope([self._mu, x] + self._cov.inputs, name):
return math_ops.exp(self.log_pdf(x))
def entropy(self, name=None):
def entropy(self, name="entropy"):
"""The entropies of these Multivariate Normals.
Args:
@@ -398,19 +287,19 @@ class MultivariateNormal(object):
Returns:
entropy: tensor of dtype `dtype`, the entropies.
"""
with ops.op_scope(
[self._mu, self._sigma_chol], name, "MultivariateNormalEntropy"):
one_plus_log_two_pi = constant_op.constant(
1 + math.log(2 * math.pi), dtype=self.dtype)
with ops.name_scope(self.name):
with ops.op_scope([self._mu] + self._cov.inputs, name):
log_sigma_det = self.log_sigma_det()
one_plus_log_two_pi = constant_op.constant(1 + math.log(2 * math.pi),
dtype=self.dtype)
# Use broadcasting rules to calculate the full broadcast sigma.
k = math_ops.cast(self._k, dtype=self.dtype)
entropy_value = (
k * one_plus_log_two_pi + self._log_sigma_det) / 2
entropy_value.set_shape(self._log_sigma_det.get_shape())
return entropy_value
# Use broadcasting rules to calculate the full broadcast sigma.
k = math_ops.cast(self._cov.vector_space_dimension(), dtype=self.dtype)
entropy_value = (k * one_plus_log_two_pi + log_sigma_det) / 2
entropy_value.set_shape(log_sigma_det.get_shape())
return entropy_value
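The value computed above is the usual closed form `H = (k * (1 + log(2*pi)) + log|sigma|) / 2`; a quick NumPy/SciPy cross-check (illustrative, not from the commit):

```python
import numpy as np
from scipy import stats

sigma = np.array([[1.0, 0.5], [0.5, 1.0]])
k = sigma.shape[0]
entropy = (k * (1 + np.log(2 * np.pi)) + np.log(np.linalg.det(sigma))) / 2
np.testing.assert_allclose(
    entropy,
    stats.multivariate_normal(mean=np.zeros(k), cov=sigma).entropy())
```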
def sample(self, n, seed=None, name=None):
def sample(self, n, seed=None, name="sample"):
"""Sample `n` observations from the Multivariate Normal Distributions.
Args:
@@ -422,43 +311,203 @@ class MultivariateNormal(object):
samples: `[n, ...]`, a `Tensor` of `n` samples for each
of the distributions determined by broadcasting the hyperparameters.
"""
with ops.op_scope(
[self._mu, self._sigma_chol, n], name, "MultivariateNormalSample"):
# TODO(ebrevdo): Is there a better way to get broadcast_shape?
broadcast_shape = self.mu.get_shape()
n = ops.convert_to_tensor(n)
sigma_shape_left = array_ops.slice(
array_ops.shape(self._sigma_chol),
[0], array_ops.pack([array_ops.rank(self._sigma_chol) - 2]))
with ops.name_scope(self.name):
with ops.op_scope([self._mu, n] + self._cov.inputs, name):
# Recall _check_mu ensures mu and self._cov have same batch shape.
broadcast_shape = self.mu.get_shape()
n = ops.convert_to_tensor(n)
k_n = array_ops.pack([self._k, n])
shape = array_ops.concat(0, [sigma_shape_left, k_n])
white_samples = random_ops.random_normal(
shape=shape, mean=0, stddev=1, dtype=self._mu.dtype, seed=seed)
shape = array_ops.concat(0, [self._cov.vector_shape(), [n]])
white_samples = random_ops.random_normal(shape=shape,
mean=0,
stddev=1,
dtype=self.dtype,
seed=seed)
correlated_samples = math_ops.batch_matmul(
self._sigma_chol, white_samples)
correlated_samples = self._cov.sqrt_matmul(white_samples)
# Move the last dimension to the front
perm = array_ops.concat(0, (
array_ops.pack([array_ops.rank(correlated_samples) - 1]),
math_ops.range(0, array_ops.rank(correlated_samples) - 1)))
# TODO(ebrevdo): Once we get a proper tensor contraction op,
# perform the inner product using that instead of batch_matmul
# and this slow transpose can go away!
correlated_samples = array_ops.transpose(correlated_samples, perm)
samples = correlated_samples + self.mu
# Provide some hints to shape inference
n_val = tensor_util.constant_value(n)
final_shape = tensor_shape.vector(n_val).concatenate(broadcast_shape)
samples.set_shape(final_shape)
return samples
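The sampling path above is the usual reparameterization `x = mu + S z`, where `z ~ N(0, I)` is white noise and `S` is the square root applied by `cov.sqrt_matmul`; in NumPy terms (an illustrative sketch, not commit code):

```python
import numpy as np

rng = np.random.RandomState(0)
mu = np.array([1., 2.])
chol = np.array([[1., 0.], [0.5, 2.]])  # S, so the covariance is S S^T.

z = rng.randn(100000, 2)       # White samples.
samples = mu + z.dot(chol.T)   # x = mu + S z.

print(samples.mean(axis=0))        # ~ mu.
print(np.cov(samples, rowvar=0))   # ~ chol.dot(chol.T).
```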
@property
def is_reparameterized(self):
return True
@property
def name(self):
return self._name
class MultivariateNormalCholesky(MultivariateNormalOperatorPD):
"""The multivariate normal distribution on `R^k`.
This distribution is defined by a 1-D mean `mu` and a Cholesky factor `chol`.
Providing the Cholesky factor allows for `O(k^2)` pdf evaluation and sampling,
and requires `O(k^2)` storage.
#### Mathematical details
The PDF of this distribution is:
```
f(x) = (2*pi)^(-k/2) |det(sigma)|^(-1/2) exp(-1/2*(x-mu)^*.sigma^{-1}.(x-mu))
```
where `.` denotes the inner product on `R^k` and `^*` denotes transpose.
#### Examples
A single multi-variate Gaussian distribution is defined by a vector of means
of length `k`, and a covariance matrix of shape `k x k`.
Extra leading dimensions, if provided, allow for batches.
```python
# Initialize a single 3-variate Gaussian with diagonal covariance.
mu = [1, 2, 3.]
chol = [[1, 0, 0], [0, 3, 0], [0, 0, 2.]]
dist = tf.contrib.distributions.MultivariateNormalCholesky(mu, chol)
# Evaluate this on an observation in R^3, returning a scalar.
dist.pdf([-1, 0, 1.])
# Initialize a batch of two 3-variate Gaussians.
mu = [[1, 2, 3], [11, 22, 33.]]
chol = ... # shape 2 x 3 x 3, lower triangular, positive diagonal.
dist = tf.contrib.distributions.MultivariateNormalCholesky(mu, chol)
# Evaluate this on two observations, each in R^3, returning a length two
# tensor.
x = [[-1, 0, 1], [-11, 0, 11.]]  # Shape 2 x 3.
dist.pdf(x)
```
Trainable (batch) Cholesky matrices can be created with
`tf.contrib.distributions.batch_matrix_diag_transform()`.
"""
def __init__(
self,
mu,
chol,
strict=True,
strict_statistics=True,
name="MultivariateNormalCholesky"):
"""Multivariate Normal distributions on `R^k`.
User must provide means `mu` and `chol`, which holds the (batch) Cholesky
factors `S`, such that the covariance of each batch member is `S S^*`.
Args:
mu: `(N+1)-D` `float` or `double` tensor with shape `[N1,...,Nb, k]`,
`b >= 0`.
chol: `(N+2)-D` `Tensor` with same `dtype` as `mu` and shape
`[N1,...,Nb, k, k]`.
strict: Whether to validate input with asserts. If `strict` is `False`,
and the inputs are invalid, correct behavior is not guaranteed.
strict_statistics: Boolean, default True. If True, raise an exception if
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
If False, batch members with valid parameters leading to undefined
statistics will return NaN for this statistic.
name: The name to give Ops created by the initializer.
Raises:
TypeError: If `mu` and `chol` are different dtypes.
"""
cov = operator_pd_cholesky.OperatorPDCholesky(chol, verify_pd=strict)
super(MultivariateNormalCholesky, self).__init__(
mu, cov, strict_statistics=strict_statistics, strict=strict, name=name)
class MultivariateNormalFull(MultivariateNormalOperatorPD):
"""The multivariate normal distribution on `R^k`.
This distribution is defined by a 1-D mean `mu` and covariance matrix `sigma`.
Evaluation of the pdf, determinant, and sampling are all `O(k^3)` operations.
#### Mathematical details
The PDF of this distribution is:
```
f(x) = (2*pi)^(-k/2) |det(sigma)|^(-1/2) exp(-1/2*(x-mu)^*.sigma^{-1}.(x-mu))
```
where `.` denotes the inner product on `R^k` and `^*` denotes transpose.
#### Examples
A single multi-variate Gaussian distribution is defined by a vector of means
of length `k`, and a covariance matrix of shape `k x k`.
Extra leading dimensions, if provided, allow for batches.
```python
# Initialize a single 3-variate Gaussian with diagonal covariance.
mu = [1, 2, 3.]
sigma = [[1, 0, 0], [0, 3, 0], [0, 0, 2.]]
dist = tf.contrib.distributions.MultivariateNormalFull(mu, sigma)
# Evaluate this on an observation in R^3, returning a scalar.
dist.pdf([-1, 0, 1])
# Initialize a batch of two 3-variate Gaussians.
mu = [[1, 2, 3], [11, 22, 33.]]
sigma = ... # shape 2 x 3 x 3, positive definite.
dist = tf.contrib.distributions.MultivariateNormalFull(mu, sigma)
# Evaluate this on two observations, each in R^3, returning a length-two
# tensor.
x = [[-1, 0, 1], [-11, 0, 11.]] # Shape 2 x 3.
dist.pdf(x)
```
"""
def __init__(
self,
mu,
sigma,
strict=True,
strict_statistics=True,
name="MultivariateNormalFull"):
"""Multivariate Normal distributions on `R^k`.
User must provide `mu` and `sigma`, the mean and covariance.
Args:
mu: `(N+1)-D` `float` or `double` tensor with shape `[N1,...,Nb, k]`,
`b >= 0`.
sigma: `(N+2)-D` `Tensor` with same `dtype` as `mu` and shape
`[N1,...,Nb, k, k]`.
strict: Whether to validate input with asserts. If `strict` is `False`,
and the inputs are invalid, correct behavior is not guaranteed.
strict_statistics: Boolean, default True. If True, raise an exception if
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
If False, batch members with valid parameters leading to undefined
statistics will return NaN for this statistic.
name: The name to give Ops created by the initializer.
Raises:
TypeError: If `mu` and `sigma` are different dtypes.
"""
cov = operator_pd_full.OperatorPDFull(sigma, verify_pd=strict)
super(MultivariateNormalFull, self).__init__(
mu, cov, strict_statistics=strict_statistics, strict=strict, name=name)
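# --- Editor's note (not part of this commit): a minimal sanity-check sketch.
# It assumes the 2016-era `tf.Session` API; variable names are illustrative.
# With sigma = chol chol^T, the Cholesky and Full parameterizations describe
# the same distribution, so their log_pdf values should agree to float error.
import numpy as np
import tensorflow as tf

mu = np.array([1., 2., 3.], dtype=np.float32)
chol = np.array([[1.0, 0.0, 0.0],
                 [0.5, 2.0, 0.0],
                 [0.1, 0.3, 1.5]], dtype=np.float32)
sigma = chol.dot(chol.T)  # Covariance shared by both parameterizations.

dist_chol = tf.contrib.distributions.MultivariateNormalCholesky(mu, chol)
dist_full = tf.contrib.distributions.MultivariateNormalFull(mu, sigma)

x = np.array([0., 1., 2.], dtype=np.float32)
with tf.Session() as sess:
  lp_chol, lp_full = sess.run([dist_chol.log_pdf(x), dist_full.log_pdf(x)])
np.testing.assert_allclose(lp_chol, lp_full, rtol=1e-5)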

View File

@@ -0,0 +1,284 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for symmetric positive definite operator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import six
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
__all__ = [
'OperatorPDBase',
]
@six.add_metaclass(abc.ABCMeta)
class OperatorPDBase(object):
"""Class representing a (batch) of positive definite matrices `A`.
This class provides access to functions of a (batch) symmetric positive
definite (PD) matrix, without the need to materialize them. In other words,
this provides means to do "matrix free" computations.
For example, `my_operator.matmul(x)` computes the result of matrix
multiplication, and this class is free to do this computation with or without
ever materializing a matrix.
In practice, this operator represents a (batch) matrix `A` with shape
`[N1,...,Nb, k, k]` for some `b >= 0`. The first `b` indices index a
batch member. For every batch index `(n1,...,nb)`, `A[n1,...,nb, : :]` is
a `k x k` matrix. Again, this matrix `A` may not be materialized, but for
purposes of broadcasting this shape will be relevant.
Since `A` is (batch) positive definite, it has one (or several) square roots
`S` such that `A = SS^T`.
For example, if `MyOperator` inherits from `OperatorPDBase`, the user can do
```python
operator = MyOperator(...) # Initialize with some tensors.
operator.log_det()
# Compute the quadratic form x^T A^{-1} x for vector x.
x = ... # some shape [..., k] tensor
operator.inv_quadratic_form(x)
# Matrix multiplication by the square root, S w.
# If w is iid normal, S w has covariance A.
w = ... # some shape [..., k, L] tensor, L >= 1
operator.sqrt_matmul(w)
```
The above three methods, `log_det`, `inv_quadratic_form`, and `sqrt_matmul`,
provide "all" that is necessary to use a covariance matrix in a multivariate
normal distribution. See the class `MultivariateNormalOperatorPD`.
"""
@abc.abstractproperty
def name(self):
"""String name identifying this `Operator`."""
# return self._name
pass
@abc.abstractproperty
def verify_pd(self):
"""Whether to verify that this `Operator` is positive definite."""
# return self._verify_pd
pass
@abc.abstractproperty
def dtype(self):
"""Data type of matrix elements of `A`."""
pass
@abc.abstractmethod
def inv_quadratic_form(self, x, name='inv_quadratic_form'):
"""Compute the quadratic form: x^T A^{-1} x.
Args:
x: `Tensor` with shape broadcastable to `[N1,...,Nb, k]` and same `dtype`
as self.
name: A name scope to use for ops added by this method.
Returns:
`Tensor` holding the square of the norm induced by inverse of `A`, for
every broadcast batch member.
"""
# with ops.name_scope(self.name):
# with ops.op_scope([x] + self.inputs, name):
# # ... your code here
pass
def det(self, name='det'):
"""Determinant for every batch member.
Args:
name: A name scope to use for ops added by this method.
Returns:
Determinant for every batch member.
"""
# Derived classes are encouraged to implement log_det() (since it is
# usually more stable), and then det() comes for free.
with ops.name_scope(self.name):
with ops.op_scope(self.inputs, name):
return math_ops.exp(self.log_det())
@abc.abstractmethod
def log_det(self, name='log_det'):
"""Log of the determinant for every batch member.
Args:
name: A name scope to use for ops added by this method.
Returns:
Logarithm of determinant for every batch member.
"""
# with ops.name_scope(self.name):
# with ops.op_scope(self.inputs, name):
# # ... your code here
pass
@abc.abstractproperty
def inputs(self):
"""List of tensors that were provided as initialization inputs."""
pass
@abc.abstractmethod
def sqrt_matmul(self, x, name='sqrt_matmul'):
"""Left (batch) matmul `x` by a sqrt of this matrix: `Sx` where `A = S S^T`.
Args:
x: Shape `[N1,...,Nb, k, L]` `Tensor` with same `dtype` as self.
name: A name scope to use for ops added by this method.
Returns:
Shape `[N1,...,Nb, k, L]` `Tensor` holding the product `S x`.
"""
# with ops.name_scope(self.name):
# with ops.op_scope([x] + self.inputs, name):
# # ... your code here
pass
@abc.abstractmethod
def get_shape(self):
"""`TensorShape` giving static shape."""
pass
def get_batch_shape(self):
"""`TensorShape` with batch shape."""
return self.get_shape()[:-2]
def get_vector_shape(self):
"""`TensorShape` of vectors this operator will work with."""
return self.get_shape()[:-1]
@abc.abstractmethod
def shape(self, name='shape'):
"""Equivalent to `tf.shape(A).` Equal to `[N1,...,Nb, k, k]`, `b >= 0`.
Args:
name: A name scope to use for ops added by this method.
Returns:
`int32` `Tensor`
"""
# with ops.name_scope(self.name):
# with ops.op_scope(self.inputs, name):
# # ... your code here
pass
def rank(self, name='rank'):
"""Tensor rank. Equivalent to `tf.rank(A)`. Will equal `b + 2`.
If this operator represents the batch matrix `A` with
`A.shape = [N1,...,Nb, k, k]`, the `rank` is `b + 2`.
Args:
name: A name scope to use for ops added by this method.
Returns:
`int32` `Tensor`
"""
# Derived classes get this "for free" once .shape() is implemented.
with ops.name_scope(self.name):
with ops.op_scope(self.inputs, name):
return array_ops.shape(self.shape())[0]
def batch_shape(self, name='batch_shape'):
"""Shape of batches associated with this operator.
If this operator represents the batch matrix `A` with
`A.shape = [N1,...,Nb, k, k]`, the `batch_shape` is `[N1,...,Nb]`.
Args:
name: A name scope to use for ops added by this method.
Returns:
`int32` `Tensor`
"""
# Derived classes get this "for free" once .shape() is implemented.
with ops.name_scope(self.name):
with ops.op_scope(self.inputs, name):
end = array_ops.pack([self.rank() - 2])
return array_ops.slice(self.shape(), [0], end)
def vector_shape(self, name='vector_shape'):
"""Shape of (batch) vectors that this (batch) matrix will multiply.
If this operator represents the batch matrix `A` with
`A.shape = [N1,...,Nb, k, k]`, the `vector_shape` is `[N1,...,Nb, k]`.
Args:
name: A name scope to use for ops added by this method.
Returns:
`int32` `Tensor`
"""
# Derived classes get this "for free" once .shape() is implemented.
with ops.name_scope(self.name):
with ops.op_scope(self.inputs, name):
return array_ops.slice(self.shape(), [0], [self.rank() - 1])
def vector_space_dimension(self, name='vector_space_dimension'):
"""Dimension of vector space on which this acts. The `k` in `R^k`.
If this operator represents the batch matrix `A` with
`A.shape = [N1,...,Nb, k, k]`, the `vector_space_dimension` is `k`.
Args:
name: A name scope to use for ops added by this method.
Returns:
`int32` `Tensor`
"""
# Derived classes get this "for free" once .shape() is implemented.
with ops.name_scope(self.name):
with ops.op_scope(self.inputs, name):
return array_ops.gather(self.shape(), self.rank() - 1)
def matmul(self, x, name='matmul'):
"""Left multiply `x` by this operator.
Args:
x: Shape `[N1,...,Nb, k, L]` `Tensor` with same `dtype` as this operator
name: A name to give this `Op`.
Returns:
A result equivalent to `tf.batch_matmul(self.to_dense(), x)`.
"""
# with ops.name_scope(self.name):
# with ops.op_scope([x] + self.inputs, name):
# # ... your code here
raise NotImplementedError('This operator has no batch_matmul Op.')
def to_dense(self, name='to_dense'):
"""Return a dense (batch) matrix representing this operator."""
# with ops.name_scope(self.name):
# with ops.op_scope(self.inputs, name):
# # ... your code here
raise NotImplementedError('This operator has no dense representation.')
def to_dense_sqrt(self, name='to_dense_sqrt'):
"""Return a dense (batch) matrix representing sqrt of this operator."""
# with ops.name_scope(self.name):
# with ops.op_scope(self.inputs, name):
# # ... your code here
raise NotImplementedError('This operator has no dense sqrt representation.')
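# --- Editor's note (not part of this commit): a hypothetical sketch of a
# minimal concrete subclass, for a *diagonal* PD matrix A = diag(d), d > 0 --
# the kind of alternative covariance representation this base class enables.
# The class name and its internals are illustrative, not library API.
import tensorflow as tf

class OperatorPDDiagSketch(OperatorPDBase):
  """A = diag(d) with d > 0, so S = diag(sqrt(d)) satisfies A = S S^T."""

  def __init__(self, diag, verify_pd=True, name='OperatorPDDiagSketch'):
    # For brevity this sketch omits assert_positive checks on `diag`.
    self._d = tf.convert_to_tensor(diag, name='diag')
    self._verify_pd = verify_pd
    self._name = name

  @property
  def name(self):
    return self._name

  @property
  def verify_pd(self):
    return self._verify_pd

  @property
  def dtype(self):
    return self._d.dtype

  @property
  def inputs(self):
    return [self._d]

  def get_shape(self):
    # diag has static shape [N1,...,Nb, k]; A has [N1,...,Nb, k, k].
    return self._d.get_shape().concatenate(self._d.get_shape()[-1:])

  def shape(self, name='shape'):
    d_shape = tf.shape(self._d)
    k = tf.gather(d_shape, tf.rank(self._d) - 1)
    return tf.concat(0, (d_shape, tf.pack([k])))

  def inv_quadratic_form(self, x, name='inv_quadratic_form'):
    # x^T A^{-1} x = sum_i x_i^2 / d_i, without materializing A.
    return tf.reduce_sum(tf.square(x) / self._d, reduction_indices=[-1])

  def log_det(self, name='log_det'):
    # log det diag(d) = sum_i log d_i.
    return tf.reduce_sum(tf.log(self._d), reduction_indices=[-1])

  def sqrt_matmul(self, x, name='sqrt_matmul'):
    # S x with S = diag(sqrt(d)): scale the rows of x, never forming S.
    return tf.expand_dims(tf.sqrt(self._d), -1) * x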

View File

@@ -0,0 +1,431 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Symmetric positive definite (PD) Operator defined by a Cholesky factor."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import operator_pd
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
__all__ = ['OperatorPDCholesky', 'batch_matrix_diag_transform']
class OperatorPDCholesky(operator_pd.OperatorPDBase):
"""Class representing a (batch) of positive definite matrices `A`.
This class provides access to functions of a batch of symmetric positive
definite (PD) matrices `A` in `R^{k x k}` defined by Cholesky factor(s).
Determinants and solves are `O(k^2)`.
In practice, this operator represents a (batch) matrix `A` with shape
`[N1,...,Nb, k, k]` for some `b >= 0`. The first `b` indices designate a
batch member. For every batch member `(n1,...,nb)`, `A[n1,...,nb, : :]` is
a `k x k` matrix.
Since `A` is (batch) positive definite, it has one (or several) square roots
`S` such that `A = SS^T`.
For example,
```python
distributions = tf.contrib.distributions
chol = [[1.0, 0.0], [1.0, 2.0]]
operator = OperatorPDCholesky(chol)
operator.log_det()
# Compute the quadratic form x^T A^{-1} x for vector x.
x = [1.0, 2.0]
operator.inv_quadratic_form(x)
# Matrix multiplication by the square root, S w.
# If w is iid normal, S w has covariance A.
w = [[1.0], [2.0]]
operator.sqrt_matmul(w)
```
The above three methods, `log_det`, `inv_quadratic_form`, and `sqrt_matmul`,
provide "all" that is necessary to use a covariance matrix in a multivariate
normal distribution. See the class `MultivariateNormalOperatorPD`.
"""
def __init__(self, chol, verify_pd=True, name='OperatorPDCholesky'):
"""Initialize an OperatorPDCholesky.
Args:
chol: Shape `[N1,...,Nb, k, k]` tensor with `b >= 0`, `k >= 1`, and
positive diagonal elements. The strict upper triangle of `chol` is
never used, and the user may set these elements to zero, or ignore them.
verify_pd: Whether to check that `chol` has positive diagonal (this is
equivalent to it being a Cholesky factor of a symmetric positive
definite matrix). If `verify_pd` is `False`, correct behavior is not
guaranteed.
name: A name to prepend to all ops created by this class.
"""
self._verify_pd = verify_pd
self._name = name
with ops.name_scope(name):
with ops.op_scope([chol], 'init'):
self._diag = array_ops.batch_matrix_diag_part(chol)
self._chol = self._check_chol(chol)
@property
def verify_pd(self):
"""Whether to verify that this `Operator` is positive definite."""
return self._verify_pd
@property
def name(self):
return self._name
@property
def dtype(self):
return self._chol.dtype
def inv_quadratic_form(self, x, name='inv_quadratic_form'):
"""Compute the induced vector norm (squared): ||x||^2 := x^T A^{-1} x.
For every batch member, this is done in `O(k^2)` complexity. The efficiency
depends on the shape of `x`.
* If `x.shape = [M1,...,Mm, N1,...,Nb, k]`, `m >= 0`, and
`self.shape = [N1,...,Nb, k, k]`, `x` will be reshaped and the
initialization matrix `chol` does not need to be copied.
* Otherwise, data will be broadcast and copied.
Args:
x: `Tensor` with shape broadcastable to `[N1,...,Nb, k]` and same `dtype`
as self. If the batch dimensions of `x` do not match exactly with those
of self, `x` and/or self's Cholesky factor will broadcast to match, and
the resultant set of linear systems are solved independently. This may
result in inefficient operation.
name: A name scope to use for ops added by this method.
Returns:
`Tensor` holding the square of the norm induced by inverse of `A`, for
every broadcast batch member.
"""
with ops.name_scope(self.name):
with ops.op_scope([x] + self.inputs, name):
x = ops.convert_to_tensor(x, name='x')
# Boolean, True iff x.shape = [M1,...,Mm] + chol.shape[:-1].
should_flip = self._should_flip(x)
x_whitened = control_flow_ops.cond(
should_flip,
lambda: self._x_whitened_if_should_flip(x),
lambda: self._x_whitened_if_no_flip(x))
# batch version of: || L^{-1} x||^2
x_whitened_norm = math_ops.reduce_sum(
math_ops.square(x_whitened),
reduction_indices=[-1])
# Set the final shape by making a dummy tensor that will never be
# evaluated.
chol_without_final_dim = math_ops.reduce_sum(
self._chol, reduction_indices=[-1])
final_shape = (x + chol_without_final_dim).get_shape()[:-1]
x_whitened_norm.set_shape(final_shape)
return x_whitened_norm
def _should_flip(self, x):
"""Return boolean tensor telling whether `x` should be flipped."""
# We "flip" (see self._flip_front_dims_to_back) iff
# chol.shape = [N1,...,Nn, k, k]
# x.shape = [M1,...,Mm, N1,...,Nn, k]
x_shape = array_ops.shape(x)
x_rank = array_ops.rank(x)
# If m <= 0, we should not flip.
m = x_rank + 1 - self.rank()
def result_if_m_positive():
x_shape_right = array_ops.slice(x_shape, [m], [x_rank - m])
return math_ops.reduce_all(
math_ops.equal(x_shape_right, self.vector_shape()))
return control_flow_ops.cond(
m > 0,
result_if_m_positive,
lambda: ops.convert_to_tensor(False))
def _x_whitened_if_no_flip(self, x):
"""x_whitened in the event of no flip."""
# Tensors to use if x and chol have same shape, or a shape that must be
# broadcast to match.
chol_bcast, x_bcast = self._get_chol_and_x_compatible_shape(x)
# batch version of: L^{-1} x
# Note that here x_bcast has trailing dims of (k, 1), for "1" system of k
# linear equations. This is the form used by the solver.
x_whitened_expanded = linalg_ops.batch_matrix_triangular_solve(
chol_bcast, x_bcast)
x_whitened = array_ops.squeeze(x_whitened_expanded, squeeze_dims=[-1])
return x_whitened
def _x_whitened_if_should_flip(self, x):
# Tensor to use if x.shape = [M1,...,Mm] + chol.shape[:-1],
# which is common if x was sampled.
x_flipped = self._flip_front_dims_to_back(x)
# batch version of: L^{-1} x
x_whitened_expanded = linalg_ops.batch_matrix_triangular_solve(
self._chol, x_flipped)
return self._unflip_back_dims_to_front(
x_whitened_expanded,
array_ops.shape(x),
x.get_shape())
def _flip_front_dims_to_back(self, x):
"""Flip x to make x.shape = chol.shape[:-1] + [M1*...*Mr]."""
# E.g. suppose
# chol.shape = [N1,...,Nn, k, k]
# x.shape = [M1,...,Mm, N1,...,Nn, k]
# Then we want to return x_flipped where
# x_flipped.shape = [N1,...,Nn, k, M1*...*Mm].
x_shape = array_ops.shape(x)
x_rank = array_ops.rank(x)
m = x_rank + 1 - self.rank()
x_shape_left = array_ops.slice(x_shape, [0], [m])
# Permutation corresponding to [N1,...,Nn, k, M1,...,Mm]
perm = array_ops.concat(
0, (math_ops.range(m, x_rank), math_ops.range(0, m)))
x_permuted = array_ops.transpose(x, perm=perm)
# Now that things are ordered correctly, condense the last dimensions.
# condensed_shape = [M1*...*Mm]
condensed_shape = array_ops.pack([math_ops.reduce_prod(x_shape_left)])
new_shape = array_ops.concat(0, (self.vector_shape(), condensed_shape))
return array_ops.reshape(x_permuted, new_shape)
def _unflip_back_dims_to_front(self, x_flipped, x_shape, x_get_shape):
# E.g. suppose that originally
# chol.shape = [N1,...,Nn, k, k]
# x.shape = [M1,...,Mm, N1,...,Nn, k]
# Then we have flipped the dims so that
# x_flipped.shape = [N1,...,Nn, k, M1*...*Mm].
# We want to return x with the original shape.
rank = array_ops.rank(x_flipped)
# Permutation corresponding to [M1*...*Mm, N1,...,Nn, k]
perm = array_ops.concat(
0, (math_ops.range(rank - 1, rank), math_ops.range(0, rank - 1)))
x_with_end_at_beginning = array_ops.transpose(x_flipped, perm=perm)
x = array_ops.reshape(x_with_end_at_beginning, x_shape)
return x
def _get_chol_and_x_compatible_shape(self, x):
"""Return self.chol and x, (possibly) broadcast to compatible shape."""
# x and chol are "compatible" if their shapes match, except that the last two
# dimensions of chol are [k, k] and the last two of x are [k, 1].
# E.g. x.shape = [A, B, k, 1], and chol.shape = [A, B, k, k]
# This is required for the batch_triangular_solve, which does not broadcast.
# TODO(langmore) This broadcast replicates matrices unnecessarily! In the
# case where
# x.shape = [M1,...,Mr, N1,...,Nb, k], and chol.shape = [N1,...,Nb, k, k]
# (which is common if x was sampled), the front dimensions of x can be
# "flipped" to the end, making
# x_flipped.shape = [N1,...,Nb, k, M1*...*Mr],
# and this can be handled by the linear solvers. This is preferred, because
# it does not replicate the matrix, or create any new data.
# We assume x starts without the trailing singleton dimension, e.g.
# x.shape = [B, k].
chol = self._chol
with ops.op_scope([x] + self.inputs, 'get_chol_and_x_compatible_shape'):
# If we determine statically that shapes match, we're done.
if x.get_shape() == chol.get_shape()[:-1]:
x_expanded = array_ops.expand_dims(x, -1)
return chol, x_expanded
# Dynamic check if shapes match or not.
vector_shape = self.vector_shape() # Shape of chol minus last dim.
are_same_rank = math_ops.equal(
array_ops.rank(x), array_ops.rank(vector_shape))
def shapes_match_if_same_rank():
return math_ops.reduce_all(math_ops.equal(
array_ops.shape(x), vector_shape))
shapes_match = control_flow_ops.cond(are_same_rank,
shapes_match_if_same_rank,
lambda: ops.convert_to_tensor(False))
# Make tensors (never instantiated) holding the broadcast shape.
# matrix_broadcast_dummy is the shape we will broadcast chol to.
matrix_bcast_dummy = chol + array_ops.expand_dims(x, -1)
# vector_bcast_dummy is the shape we will bcast x to, before we expand it.
chol_minus_last_dim = math_ops.reduce_sum(chol, reduction_indices=[-1])
vector_bcast_dummy = x + chol_minus_last_dim
chol_bcast = chol + array_ops.zeros_like(matrix_bcast_dummy)
x_bcast = x + array_ops.zeros_like(vector_bcast_dummy)
chol_result = control_flow_ops.cond(shapes_match, lambda: chol,
lambda: chol_bcast)
chol_result.set_shape(matrix_bcast_dummy.get_shape())
x_result = control_flow_ops.cond(shapes_match, lambda: x, lambda: x_bcast)
x_result.set_shape(vector_bcast_dummy.get_shape())
x_expanded = array_ops.expand_dims(x_result, -1)
return chol_result, x_expanded
def log_det(self, name='log_det'):
"""Log determinant of every batch member."""
with ops.name_scope(self.name):
with ops.op_scope(self.inputs, name):
log_det = 2.0 * math_ops.reduce_sum(
math_ops.log(self._diag),
reduction_indices=[-1])
log_det.set_shape(self._chol.get_shape()[:-2])
return log_det
@property
def inputs(self):
"""List of tensors that were provided as initialization inputs."""
return [self._chol]
def sqrt_matmul(self, x, name='sqrt_matmul'):
"""Left (batch) matmul `x` by a sqrt of this matrix: `Sx` where `A = S S^T.
Args:
x: `Tensor` with shape broadcastable to `[N1,...,Nb, k]` and same `dtype`
as self.
name: A name scope to use for ops added by this method.
Returns:
Shape `[N1,...,Nb, k]` `Tensor` holding the product `S x`.
"""
with ops.name_scope(self.name):
with ops.op_scope([x] + self.inputs, name):
chol_lower = array_ops.batch_matrix_band_part(self._chol, -1, 0)
return math_ops.batch_matmul(chol_lower, x)
def get_shape(self):
"""`TensorShape` giving static shape."""
return self._chol.get_shape()
def shape(self, name='shape'):
with ops.name_scope(self.name):
with ops.op_scope(self.inputs, name):
return array_ops.shape(self._chol)
def _check_chol(self, chol):
"""Verify that `chol` is proper."""
chol = ops.convert_to_tensor(chol, name='chol')
if not self.verify_pd:
return chol
shape = array_ops.shape(chol)
rank = array_ops.rank(chol)
is_matrix = check_ops.assert_rank_at_least(chol, 2)
is_square = check_ops.assert_equal(
array_ops.gather(shape, rank - 2), array_ops.gather(shape, rank - 1))
deps = [is_matrix, is_square]
deps.append(check_ops.assert_positive(self._diag))
return control_flow_ops.with_dependencies(deps, chol)
def matmul(self, x, name='matmul'):
"""Left (batch) matrix multiplication of `x` by this operator."""
chol = self._chol
with ops.name_scope(self.name):
with ops.op_scope(self.inputs, name):
# A x = L (L^T x), computed as two triangular matmuls.
chol_t_times_x = math_ops.batch_matmul(chol, x, adj_x=True)
return math_ops.batch_matmul(chol, chol_t_times_x)
def to_dense_sqrt(self, name='to_dense_sqrt'):
"""Return a dense (batch) matrix representing sqrt of this covariance."""
with ops.name_scope(self.name):
with ops.op_scope(self.inputs, name):
return array_ops.identity(self._chol)
def to_dense(self, name='to_dense'):
"""Return a dense (batch) matrix representing this covariance."""
chol = self._chol
with ops.name_scope(self.name):
with ops.op_scope(self.inputs, name):
return math_ops.batch_matmul(chol, chol, adj_y=True)
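# --- Editor's note (not part of this commit): a quick Monte Carlo numpy check
# of the sampling identity quoted in the class docstring above: if w is iid
# standard normal, then S w has covariance A = S S^T. The tolerance is loose
# because the empirical covariance is noisy.
import numpy as np

rng = np.random.RandomState(0)
S = np.array([[1.0, 0.0], [1.0, 2.0]])  # The `chol` from the docstring.
w = rng.randn(2, 100000)                # Columns are iid N(0, I) draws.
sample_cov = np.cov(S.dot(w))           # Empirical covariance of S w.
assert np.allclose(sample_cov, S.dot(S.T), atol=0.1)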
def batch_matrix_diag_transform(matrix, transform=None, name=None):
"""Transform diagonal of [batch-]matrix, leave rest of matrix unchanged.
Create a trainable covariance defined by a Cholesky factor:
```python
# Transform network layer into 2 x 2 array.
matrix_values = tf.contrib.layers.fully_connected(activations, 4)
matrix = tf.reshape(matrix_values, (batch_size, 2, 2))
# Make the diagonal positive. If the upper triangle was zero, this would be a
# valid Cholesky factor.
chol = batch_matrix_diag_transform(matrix, transform=tf.nn.softplus)
# OperatorPDCholesky ignores the upper triangle.
operator = OperatorPDCholesky(chol)
```
Example of heteroskedastic 2-D linear regression.
```python
# Get a trainable Cholesky factor.
matrix_values = tf.contrib.layers.fully_connected(activations, 4)
matrix = tf.reshape(matrix_values, (batch_size, 2, 2))
chol = batch_matrix_diag_transform(matrix, transform=tf.nn.softplus)
# Get a trainable mean.
mu = tf.contrib.layers.fully_connected(activations, 2)
# This is a fully trainable multivariate normal!
dist = tf.contrib.distributions.MultivariateNormalCholesky(mu, chol)
# Standard log loss. Minimizing this will "train" mu and chol, and then dist
# will be a distribution predicting labels as multivariate Gaussians.
loss = -1 * tf.reduce_mean(dist.log_pdf(labels))
```
Args:
matrix: Rank `R` `Tensor`, `R >= 2`, where the last two dimensions are
equal.
transform: Element-wise function mapping `Tensors` to `Tensors`. To
be applied to the diagonal of `matrix`. If `None`, `matrix` is returned
unchanged. Defaults to `None`.
name: A name to give created ops.
Defaults to "batch_matrix_diag_transform".
Returns:
A `Tensor` with same shape and `dtype` as `matrix`.
"""
with ops.op_scope([matrix], name, 'batch_matrix_diag_transform'):
matrix = ops.convert_to_tensor(matrix, name='matrix')
if transform is None:
return matrix
# Replace the diag with transformed diag.
diag = array_ops.batch_matrix_diag_part(matrix)
transformed_diag = transform(diag)
matrix += array_ops.batch_matrix_diag(transformed_diag - diag)
return matrix
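# --- Editor's note (not part of this commit): a small numpy check of the
# identity behind OperatorPDCholesky.log_det above:
#   A = L L^T  =>  log det A = 2 * sum_i log(L[i, i]).
import numpy as np

L = np.array([[1.0, 0.0], [1.0, 2.0]])
A = L.dot(L.T)
assert np.isclose(np.log(np.linalg.det(A)), 2.0 * np.log(np.diag(L)).sum())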

View File

@@ -0,0 +1,106 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Symmetric positive definite (PD) Operator defined by a full matrix."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import operator_pd_cholesky
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
__all__ = [
'OperatorPDFull',
]
class OperatorPDFull(operator_pd_cholesky.OperatorPDCholesky):
"""Class representing a (batch) of positive definite matrices `A`.
This class provides access to functions of a batch of symmetric positive
definite (PD) matrices `A` in `R^{k x k}` defined by dense matrices.
Determinants and solves are `O(k^3)`.
In practice, this operator represents a (batch) matrix `A` with shape
`[N1,...,Nb, k, k]` for some `b >= 0`. The first `b` indices designate a
batch member. For every batch member `(n1,...,nb)`, `A[n1,...,nb, : :]` is
a `k x k` matrix.
Since `A` is (batch) positive definite, it has one (or several) square roots
`S` such that `A = SS^T`.
For example,
```python
distributions = tf.contrib.distributions
matrix = [[1.0, 0.5], [0.5, 2.0]]
operator = OperatorPDFull(matrix)
operator.log_det()
# Compute the quadratic form x^T A^{-1} x for vector x.
x = [1.0, 2.0]
operator.inv_quadratic_form(x)
# Matrix multiplication by the square root, S w.
# If w is iid normal, S w has covariance A.
w = [[1.0], [2.0]]
operator.sqrt_matmul(w)
```
The above three methods, `log_det`, `inv_quadratic_form`, and `sqrt_matmul`,
provide "all" that is necessary to use a covariance matrix in a multivariate
normal distribution. See the class `MultivariateNormalOperatorPD`.
"""
def __init__(self, matrix, verify_pd=True, name='OperatorPDFull'):
"""Initialize an OperatorPDFull.
Args:
matrix: Shape `[N1,...,Nb, k, k]` tensor with `b >= 0`, `k >= 1`. The
last two dimensions should be `k x k` symmetric positive definite
matrices.
verify_pd: Whether to check that `matrix` is symmetric positive definite.
If `verify_pd` is `False`, correct behavior is not guaranteed.
name: A name to prepend to all ops created by this class.
"""
with ops.name_scope(name):
with ops.op_scope([matrix], 'init'):
matrix = ops.convert_to_tensor(matrix)
# Check symmetry here. Positivity will be verified by checking the
# diagonal of the Cholesky factor inside the parent class. The Cholesky
# factorization `batch_cholesky()` does not always fail for non-PSD
# matrices, so don't rely on that.
if verify_pd:
matrix = _check_symmetric(matrix)
chol = linalg_ops.batch_cholesky(matrix)
super(OperatorPDFull, self).__init__(chol, verify_pd=verify_pd)
def _check_symmetric(matrix):
rank = array_ops.rank(matrix)
# Create permutation to permute last two dimensions
first_dims = math_ops.range(0, rank - 2)
flipped_last_dims = array_ops.pack([rank - 1, rank - 2])
perm = array_ops.concat(0, (first_dims, flipped_last_dims))
matrix_t = array_ops.transpose(matrix, perm=perm)
return control_flow_ops.with_dependencies(
[check_ops.assert_equal(matrix, matrix_t)], matrix)
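# --- Editor's note (not part of this commit): a minimal consistency sketch.
# OperatorPDFull factors `matrix` with batch_cholesky and then delegates to
# OperatorPDCholesky, so the two operators below should agree on log_det.
# Assumes the 2016-era `tf.Session` API; a batch of two 2x2 matrices is used.
import numpy as np
import tensorflow as tf

chol = np.array([[[1.0, 0.0], [1.0, 2.0]],
                 [[2.0, 0.0], [0.5, 1.0]]], dtype=np.float32)
matrix = np.matmul(chol, np.transpose(chol, (0, 2, 1)))  # SPD by construction.

op_full = OperatorPDFull(matrix)
op_chol = operator_pd_cholesky.OperatorPDCholesky(chol)
with tf.Session() as sess:
  ld_full, ld_chol = sess.run([op_full.log_det(), op_chol.log_det()])
np.testing.assert_allclose(ld_full, ld_chol, rtol=1e-5)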