From 94cbf42a074128f9c46d980f9fddb47fba3602a6 Mon Sep 17 00:00:00 2001 From: Ian Langmore Date: Wed, 29 Jun 2016 12:31:47 -0800 Subject: [PATCH] Adding covariance, covariance_cholesky, mvn_cov to distributions/ These allow alternative representations of covariance matrices, e.g. diagonal/sparse/functional, to be used in a multivariate-normal. Change: 126226644 --- tensorflow/contrib/distributions/BUILD | 33 + tensorflow/contrib/distributions/__init__.py | 24 +- .../python/kernel_tests/mvn_test.py | 304 +++----- .../kernel_tests/operator_pd_cholesky_test.py | 327 +++++++++ .../kernel_tests/operator_pd_full_test.py | 62 ++ .../python/kernel_tests/operator_pd_test.py | 83 +++ .../contrib/distributions/python/ops/mvn.py | 681 ++++++++++-------- .../distributions/python/ops/operator_pd.py | 284 ++++++++ .../python/ops/operator_pd_cholesky.py | 431 +++++++++++ .../python/ops/operator_pd_full.py | 106 +++ 10 files changed, 1817 insertions(+), 518 deletions(-) create mode 100644 tensorflow/contrib/distributions/python/kernel_tests/operator_pd_cholesky_test.py create mode 100644 tensorflow/contrib/distributions/python/kernel_tests/operator_pd_full_test.py create mode 100644 tensorflow/contrib/distributions/python/kernel_tests/operator_pd_test.py create mode 100644 tensorflow/contrib/distributions/python/ops/operator_pd.py create mode 100644 tensorflow/contrib/distributions/python/ops/operator_pd_cholesky.py create mode 100644 tensorflow/contrib/distributions/python/ops/operator_pd_full.py diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 5cdf4b92e43..8f1e5b860a4 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -10,6 +10,39 @@ package(default_visibility = ["//tensorflow:__subpackages__"]) load("//tensorflow:tensorflow.bzl", "cuda_py_tests") +cuda_py_tests( + name = "operator_pd_test", + size = "small", + srcs = ["python/kernel_tests/operator_pd_test.py"], + additional_deps = [ 
+ ":distributions_py", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + +cuda_py_tests( + name = "operator_pd_cholesky_test", + size = "small", + srcs = ["python/kernel_tests/operator_pd_cholesky_test.py"], + additional_deps = [ + ":distributions_py", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + +cuda_py_tests( + name = "operator_pd_full_test", + size = "small", + srcs = ["python/kernel_tests/operator_pd_full_test.py"], + additional_deps = [ + ":distributions_py", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + py_library( name = "distributions_py", srcs = ["__init__.py"] + glob(["python/ops/*.py"]), diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index edc8c78e099..d04693ce983 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -38,9 +38,26 @@ initialized with parameters that define the distributions. ### Multivariate distributions -@@MultivariateNormal +#### Multivariate normal + +@@MultivariateNormalFull +@@MultivariateNormalCholesky + +#### Other multivariate distributions + @@DirichletMultinomial +## Operators allowing for matrix-free methods + +### Positive definite operators + +A matrix is positive definite if it is symmetric with all positive eigenvalues. + +@@OperatorPDBase +@@OperatorPDFull +@@OperatorPDCholesky +@@batch_matrix_diag_transform + ## Posterior inference with conjugate priors. 
Functions that transform conjugate prior/likelihood pairs to distributions @@ -61,7 +78,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: disable=unused-import,wildcard-import,line-too-long +# pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member from tensorflow.contrib.distributions.python.ops.bernoulli import * from tensorflow.contrib.distributions.python.ops.categorical import * @@ -74,5 +91,8 @@ from tensorflow.contrib.distributions.python.ops.kullback_leibler import * from tensorflow.contrib.distributions.python.ops.mvn import * from tensorflow.contrib.distributions.python.ops.normal import * from tensorflow.contrib.distributions.python.ops.normal_conjugate_posteriors import * +from tensorflow.contrib.distributions.python.ops.operator_pd import * +from tensorflow.contrib.distributions.python.ops.operator_pd_cholesky import * +from tensorflow.contrib.distributions.python.ops.operator_pd_full import * from tensorflow.contrib.distributions.python.ops.student_t import * from tensorflow.contrib.distributions.python.ops.uniform import * diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_test.py index 2645e792426..7e5f8a5b6f4 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_test.py @@ -19,249 +19,153 @@ from __future__ import division from __future__ import print_function import numpy as np +from scipy import stats import tensorflow as tf +distributions = tf.contrib.distributions -class MultivariateNormalTest(tf.test.TestCase): + +class MultivariateNormalCholeskyTest(tf.test.TestCase): + + def setUp(self): + self._rng = np.random.RandomState(42) + + def _random_chol(self, *shape): + mat = self._rng.rand(*shape) + chol = distributions.batch_matrix_diag_transform( + mat, transform=tf.nn.softplus) 
+ chol = tf.batch_matrix_band_part(chol, -1, 0) + sigma = tf.batch_matmul(chol, chol, adj_y=True) + return chol.eval(), sigma.eval() def testNonmatchingMuSigmaFails(self): - with tf.Session(): - mvn = tf.contrib.distributions.MultivariateNormal( - mu=[1.0, 2.0], - sigma=[[[1.0, 0.0], - [0.0, 1.0]], - [[1.0, 0.0], - [0.0, 1.0]]]) - with self.assertRaisesOpError( - r"Rank of mu should be one less than rank of sigma"): - mvn.mean.eval() + with self.test_session(): + mu = self._rng.rand(2) + chol, _ = self._random_chol(2, 2, 2) + mvn = distributions.MultivariateNormalCholesky(mu, chol) + with self.assertRaisesOpError("mu should have rank 1 less than cov"): + mvn.mean().eval() - mvn = tf.contrib.distributions.MultivariateNormal( - mu=[[1.0], [2.0]], - sigma=[[[1.0, 0.0], - [0.0, 1.0]], - [[1.0, 0.0], - [0.0, 1.0]]]) - with self.assertRaisesOpError( - r"mu.shape and sigma.shape\[\:-1\] must match"): - mvn.mean.eval() + mu = self._rng.rand(2, 1) + chol, _ = self._random_chol(2, 2, 2) + mvn = distributions.MultivariateNormalCholesky(mu, chol) + with self.assertRaisesOpError("mu.shape and cov.shape.*should match"): + mvn.mean().eval() - def testNotPositiveDefiniteSigmaFails(self): - with tf.Session(): - mvn = tf.contrib.distributions.MultivariateNormal( - mu=[[1.0, 2.0], [1.0, 2.0]], - sigma=[[[1.0, 0.0], - [0.0, 1.0]], - [[1.0, 1.0], - [1.0, 1.0]]]) - with self.assertRaisesOpError( - r"LLT decomposition was not successful."): - mvn.mean.eval() - mvn = tf.contrib.distributions.MultivariateNormal( - mu=[[1.0, 2.0], [1.0, 2.0]], - sigma=[[[1.0, 0.0], - [0.0, 1.0]], - [[-1.0, 0.0], - [0.0, 1.0]]]) - with self.assertRaisesOpError( - r"LLT decomposition was not successful."): - mvn.mean.eval() - mvn = tf.contrib.distributions.MultivariateNormal( - mu=[[1.0, 2.0], [1.0, 2.0]], - sigma_chol=[[[1.0, 0.0], - [0.0, 1.0]], - [[-1.0, 0.0], - [0.0, 1.0]]]) - with self.assertRaisesOpError( - r"sigma_chol is not positive definite."): - mvn.mean.eval() - - def testLogPDFScalar(self): - 
with tf.Session(): - mu_v = np.array([-3.0, 3.0], dtype=np.float32) - mu = tf.constant(mu_v) - sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32) - sigma = tf.constant(sigma_v) - x = np.array([-2.5, 2.5], dtype=np.float32) - mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma) + def testLogPDFScalarBatch(self): + with self.test_session(): + mu = self._rng.rand(2) + chol, sigma = self._random_chol(2, 2) + mvn = distributions.MultivariateNormalCholesky(mu, chol) + x = self._rng.rand(2) log_pdf = mvn.log_pdf(x) pdf = mvn.pdf(x) - try: - from scipy import stats # pylint: disable=g-import-not-at-top - scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v) - expected_log_pdf = scipy_mvn.logpdf(x) - expected_pdf = scipy_mvn.pdf(x) - self.assertAllClose(expected_log_pdf, log_pdf.eval()) - self.assertAllClose(expected_pdf, pdf.eval()) - except ImportError as e: - tf.logging.warn("Cannot test stats functions: %s" % str(e)) + scipy_mvn = stats.multivariate_normal(mean=mu, cov=sigma) - def testLogPDFScalarSigmaHalf(self): - with tf.Session(): - mu_v = np.array([-3.0, 3.0, 1.0], dtype=np.float32) - mu = tf.constant(mu_v) - sigma_v = np.array([[1.0, 0.1, 0.2], - [0.1, 2.0, 0.05], - [0.2, 0.05, 3.0]], dtype=np.float32) - sigma_chol_v = np.linalg.cholesky(sigma_v) - sigma_chol = tf.constant(sigma_chol_v) - x = np.array([-2.5, 2.5, 1.0], dtype=np.float32) - mvn = tf.contrib.distributions.MultivariateNormal( - mu=mu, sigma_chol=sigma_chol) - log_pdf = mvn.log_pdf(x) - pdf = mvn.pdf(x) - sigma = mvn.sigma + expected_log_pdf = scipy_mvn.logpdf(x) + expected_pdf = scipy_mvn.pdf(x) + self.assertEqual((), log_pdf.get_shape()) + self.assertEqual((), pdf.get_shape()) + self.assertAllClose(expected_log_pdf, log_pdf.eval()) + self.assertAllClose(expected_pdf, pdf.eval()) - try: - from scipy import stats # pylint: disable=g-import-not-at-top - scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v) - expected_log_pdf = scipy_mvn.logpdf(x) - 
expected_pdf = scipy_mvn.pdf(x) - self.assertEqual(sigma.get_shape(), (3, 3)) - self.assertAllClose(sigma_v, sigma.eval()) - self.assertAllClose(expected_log_pdf, log_pdf.eval()) - self.assertAllClose(expected_pdf, pdf.eval()) - except ImportError as e: - tf.logging.warn("Cannot test stats functions: %s" % str(e)) + def testLogPDFXIsHigherRank(self): + with self.test_session(): + mu = self._rng.rand(2) + chol, sigma = self._random_chol(2, 2) + mvn = distributions.MultivariateNormalCholesky(mu, chol) + x = self._rng.rand(3, 2) - def testLogPDF(self): - with tf.Session(): - mu_v = np.array([-3.0, 3.0], dtype=np.float32) - mu = tf.constant(mu_v) - sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32) - sigma = tf.constant(sigma_v) - x = np.array([[-2.5, 2.5], [4.0, 0.0], [-1.0, 2.0]], dtype=np.float32) - mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma) log_pdf = mvn.log_pdf(x) pdf = mvn.pdf(x) - try: - from scipy import stats # pylint: disable=g-import-not-at-top - scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v) - expected_log_pdf = scipy_mvn.logpdf(x) - expected_pdf = scipy_mvn.pdf(x) - self.assertEqual(log_pdf.get_shape(), (3,)) - self.assertAllClose(expected_log_pdf, log_pdf.eval()) - self.assertAllClose(expected_pdf, pdf.eval()) - except ImportError as e: - tf.logging.warn("Cannot test stats functions: %s" % str(e)) + scipy_mvn = stats.multivariate_normal(mean=mu, cov=sigma) + + expected_log_pdf = scipy_mvn.logpdf(x) + expected_pdf = scipy_mvn.pdf(x) + self.assertEqual((3,), log_pdf.get_shape()) + self.assertEqual((3,), pdf.get_shape()) + self.assertAllClose(expected_log_pdf, log_pdf.eval()) + self.assertAllClose(expected_pdf, pdf.eval()) + + def testLogPDFXLowerDimension(self): + with self.test_session(): + mu = self._rng.rand(3, 2) + chol, sigma = self._random_chol(3, 2, 2) + mvn = distributions.MultivariateNormalCholesky(mu, chol) + x = self._rng.rand(2) - def testLogPDFMatchingDimension(self): - with tf.Session(): - 
mu_v = np.array([-3.0, 3.0], dtype=np.float32) - mu = tf.constant(np.vstack(3 * [mu_v])) - sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32) - sigma = tf.constant(np.vstack(3 * [sigma_v[np.newaxis, :]])) - x = np.array([[-2.5, 2.5], [4.0, 0.0], [-1.0, 2.0]], dtype=np.float32) - mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma) log_pdf = mvn.log_pdf(x) pdf = mvn.pdf(x) - try: - from scipy import stats # pylint: disable=g-import-not-at-top - scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v) - expected_log_pdf = scipy_mvn.logpdf(x) - expected_pdf = scipy_mvn.pdf(x) - self.assertEqual(log_pdf.get_shape(), (3,)) - self.assertAllClose(expected_log_pdf, log_pdf.eval()) - self.assertAllClose(expected_pdf, pdf.eval()) - except ImportError as e: - tf.logging.warn("Cannot test stats functions: %s" % str(e)) + self.assertEqual((3,), log_pdf.get_shape()) + self.assertEqual((3,), pdf.get_shape()) - def testLogPDFMultidimensional(self): - with tf.Session(): - mu_v = np.array([-3.0, 3.0], dtype=np.float32) - mu = tf.constant(np.vstack(15 * [mu_v]).reshape(3, 5, 2)) - sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32) - sigma = tf.constant( - np.vstack(15 * [sigma_v[np.newaxis, :]]).reshape(3, 5, 2, 2)) - x = np.array([-2.5, 2.5], dtype=np.float32) - mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma) - log_pdf = mvn.log_pdf(x) - pdf = mvn.pdf(x) + # scipy can't do batches, so just test one of them. 
+ scipy_mvn = stats.multivariate_normal(mean=mu[1, :], cov=sigma[1, :, :]) + expected_log_pdf = scipy_mvn.logpdf(x) + expected_pdf = scipy_mvn.pdf(x) - try: - from scipy import stats # pylint: disable=g-import-not-at-top - scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v) - expected_log_pdf = np.vstack(15 * [scipy_mvn.logpdf(x)]).reshape(3, 5) - expected_pdf = np.vstack(15 * [scipy_mvn.pdf(x)]).reshape(3, 5) - self.assertEqual(log_pdf.get_shape(), (3, 5)) - self.assertAllClose(expected_log_pdf, log_pdf.eval()) - self.assertAllClose(expected_pdf, pdf.eval()) - except ImportError as e: - tf.logging.warn("Cannot test stats functions: %s" % str(e)) + self.assertAllClose(expected_log_pdf, log_pdf.eval()[1]) + self.assertAllClose(expected_pdf, pdf.eval()[1]) def testEntropy(self): - with tf.Session(): - mu_v = np.array([-3.0, 3.0], dtype=np.float32) - mu = tf.constant(mu_v) - sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32) - sigma = tf.constant(sigma_v) - mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma) + with self.test_session(): + mu = self._rng.rand(2) + chol, sigma = self._random_chol(2, 2) + mvn = distributions.MultivariateNormalCholesky(mu, chol) entropy = mvn.entropy() - try: - from scipy import stats # pylint: disable=g-import-not-at-top - scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v) - expected_entropy = scipy_mvn.entropy() - self.assertEqual(entropy.get_shape(), ()) - self.assertAllClose(expected_entropy, entropy.eval()) - except ImportError as e: - tf.logging.warn("Cannot test stats functions: %s" % str(e)) + scipy_mvn = stats.multivariate_normal(mean=mu, cov=sigma) + expected_entropy = scipy_mvn.entropy() + self.assertEqual(entropy.get_shape(), ()) + self.assertAllClose(expected_entropy, entropy.eval()) def testEntropyMultidimensional(self): - with tf.Session(): - mu_v = np.array([-3.0, 3.0], dtype=np.float32) - mu = tf.constant(np.vstack(15 * [mu_v]).reshape(3, 5, 2)) - sigma_v = np.array([[1.0, 
0.5], [0.5, 1.0]], dtype=np.float32) - sigma = tf.constant( - np.vstack(15 * [sigma_v[np.newaxis, :]]).reshape(3, 5, 2, 2)) - mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma) + with self.test_session(): + mu = self._rng.rand(3, 5, 2) + chol, sigma = self._random_chol(3, 5, 2, 2) + mvn = distributions.MultivariateNormalCholesky(mu, chol) entropy = mvn.entropy() - try: - from scipy import stats # pylint: disable=g-import-not-at-top - scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v) - expected_entropy = np.vstack(15 * [scipy_mvn.entropy()]).reshape(3, 5) - self.assertEqual(entropy.get_shape(), (3, 5)) - self.assertAllClose(expected_entropy, entropy.eval()) - except ImportError as e: - tf.logging.warn("Cannot test stats functions: %s" % str(e)) + # Scipy doesn't do batches, so test one of them. + expected_entropy = stats.multivariate_normal( + mean=mu[1, 1, :], cov=sigma[1, 1, :, :]).entropy() + self.assertEqual(entropy.get_shape(), (3, 5)) + self.assertAllClose(expected_entropy, entropy.eval()[1, 1]) def testSample(self): - with tf.Session(): - mu_v = np.array([-3.0, 3.0], dtype=np.float32) - mu = tf.constant(mu_v) - sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32) - sigma = tf.constant(sigma_v) + with self.test_session(): + mu = self._rng.rand(2) + chol, sigma = self._random_chol(2, 2) + n = tf.constant(100000) - mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma) + mvn = distributions.MultivariateNormalCholesky(mu, chol) samples = mvn.sample(n, seed=137) sample_values = samples.eval() self.assertEqual(samples.get_shape(), (100000, 2)) - self.assertAllClose(sample_values.mean(axis=0), mu_v, atol=1e-2) - self.assertAllClose(np.cov(sample_values, rowvar=0), sigma_v, atol=1e-1) + self.assertAllClose(sample_values.mean(axis=0), mu, atol=1e-2) + self.assertAllClose(np.cov(sample_values, rowvar=0), sigma, atol=1e-1) def testSampleMultiDimensional(self): - with tf.Session(): - mu_v = np.array([-3.0, 3.0], 
dtype=np.float32) - mu = tf.constant(np.vstack(15 * [mu_v]).reshape(3, 5, 2)) - sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32) - sigma = tf.constant( - np.vstack(15 * [sigma_v[np.newaxis, :]]).reshape(3, 5, 2, 2)) + with self.test_session(): + mu = self._rng.rand(3, 5, 2) + chol, sigma = self._random_chol(3, 5, 2, 2) + n = tf.constant(100000) - mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma) + mvn = distributions.MultivariateNormalCholesky(mu, chol) samples = mvn.sample(n, seed=137) sample_values = samples.eval() + self.assertEqual(samples.get_shape(), (100000, 3, 5, 2)) - sample_values = sample_values.reshape(100000, 15, 2) - for i in range(15): - self.assertAllClose( - sample_values[:, i, :].mean(axis=0), mu_v, atol=1e-2) - self.assertAllClose( - np.cov(sample_values[:, i, :], rowvar=0), sigma_v, atol=1e-1) + self.assertAllClose( + sample_values[:, 1, 1, :].mean(axis=0), + mu[1, 1, :], atol=0.05) + self.assertAllClose( + np.cov(sample_values[:, 1, 1, :], rowvar=0), + sigma[1, 1, :, :], atol=1e-1) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_cholesky_test.py b/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_cholesky_test.py new file mode 100644 index 00000000000..4ad76a826fe --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_cholesky_test.py @@ -0,0 +1,327 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + +distributions = tf.contrib.distributions + + +def softplus(x): + return np.log(1 + np.exp(x)) + + +class OperatorPDCholeskyTest(tf.test.TestCase): + + def setUp(self): + self._rng = np.random.RandomState(42) + + def _random_cholesky_array(self, shape): + mat = self._rng.rand(*shape) + chol = distributions.batch_matrix_diag_transform(mat, + transform=tf.nn.softplus) + # Zero the upper triangle because we're using this as a true Cholesky factor + # in our tests. + return tf.batch_matrix_band_part(chol, -1, 0).eval() + + def _numpy_inv_quadratic_form(self, chol, x): + # Numpy works with batches now (calls them "stacks"). + x_expanded = np.expand_dims(x, -1) + whitened = np.linalg.solve(chol, x_expanded) + return (whitened**2).sum(axis=-1).sum(axis=-1) + + def test_inv_quadratic_form_x_rank_same_as_broadcast_rank(self): + with self.test_session(): + for batch_shape in [(), (2,)]: + for k in [1, 3]: + + x_shape = batch_shape + (k,) + x = self._rng.randn(*x_shape) + + chol_shape = batch_shape + (k, k) + chol = self._random_cholesky_array(chol_shape) + operator = distributions.OperatorPDCholesky(chol) + qf = operator.inv_quadratic_form(x) + + self.assertEqual(batch_shape, qf.get_shape()) + + numpy_qf = self._numpy_inv_quadratic_form(chol, x) + self.assertAllClose(numpy_qf, qf.eval()) + + def test_inv_quadratic_form_x_and_chol_batch_shape_dont_match(self): + # In this case, chol will have to be stretched to match x. 
+ with self.test_session(): + k = 3 + x_shape = (2, k) + chol_shape = (1, k, k) + broadcast_batch_shape = (2,) + + x = self._rng.randn(*x_shape) + chol = self._random_cholesky_array(chol_shape) + + operator = distributions.OperatorPDCholesky(chol) + qf = operator.inv_quadratic_form(x) + + self.assertEqual(broadcast_batch_shape, qf.get_shape()) + + numpy_qf = self._numpy_inv_quadratic_form(chol, x) + self.assertAllClose(numpy_qf, qf.eval()) + + def test_inv_quadratic_form_x_rank_less_than_broadcast_rank(self): + with self.test_session(): + for batch_shape in [(2,), (2, 3)]: + for k in [1, 4]: + + # x will not have the leading dimension. + x_shape = batch_shape[1:] + (k,) + x = self._rng.randn(*x_shape) + + chol_shape = batch_shape + (k, k) + chol = self._random_cholesky_array(chol_shape) + operator = distributions.OperatorPDCholesky(chol) + qf = operator.inv_quadratic_form(x) + + self.assertEqual(batch_shape, qf.get_shape()) + + x_upshaped = x + np.zeros(chol.shape[:-1]) + numpy_qf = self._numpy_inv_quadratic_form(chol, x_upshaped) + numpy_qf = numpy_qf.reshape(batch_shape) + self.assertAllClose(numpy_qf, qf.eval()) + + def test_inv_quadratic_form_x_rank_greater_than_broadcast_rank(self): + with self.test_session(): + for batch_shape in [(2,), (2, 3)]: + for k in [1, 4]: + + x_shape = batch_shape + (k,) + x = self._rng.randn(*x_shape) + + # chol will not have the leading dimension. 
+ chol_shape = batch_shape[1:] + (k, k) + chol = self._random_cholesky_array(chol_shape) + operator = distributions.OperatorPDCholesky(chol) + qf = operator.inv_quadratic_form(x) + numpy_qf = self._numpy_inv_quadratic_form(chol, x) + + self.assertEqual(batch_shape, qf.get_shape()) + self.assertAllClose(numpy_qf, qf.eval()) + + def test_inv_quadratic_form_x_rank_two_greater_than_broadcast_rank(self): + with self.test_session(): + for batch_shape in [(2, 3), (2, 3, 4), (2, 3, 4, 5)]: + for k in [1, 4]: + + x_shape = batch_shape + (k,) + x = self._rng.randn(*x_shape) + + # chol will not have the leading two dimensions. + chol_shape = batch_shape[2:] + (k, k) + chol = self._random_cholesky_array(chol_shape) + operator = distributions.OperatorPDCholesky(chol) + qf = operator.inv_quadratic_form(x) + numpy_qf = self._numpy_inv_quadratic_form(chol, x) + + self.assertEqual(batch_shape, qf.get_shape()) + self.assertAllClose(numpy_qf, qf.eval()) + + def test_log_det(self): + with self.test_session(): + batch_shape = () + for k in [1, 4]: + chol_shape = batch_shape + (k, k) + chol = self._random_cholesky_array(chol_shape) + operator = distributions.OperatorPDCholesky(chol) + log_det = operator.log_det() + expected_log_det = np.log(np.prod(np.diag(chol))**2) + + self.assertEqual(batch_shape, log_det.get_shape()) + self.assertAllClose(expected_log_det, log_det.eval()) + + def test_log_det_batch_matrix(self): + with self.test_session(): + batch_shape = (2, 3) + for k in [1, 4]: + chol_shape = batch_shape + (k, k) + chol = self._random_cholesky_array(chol_shape) + operator = distributions.OperatorPDCholesky(chol) + log_det = operator.log_det() + + self.assertEqual(batch_shape, log_det.get_shape()) + + # Test the log-determinant of the [1, 1] matrix. 
+ chol_11 = chol[1, 1, :, :] + expected_log_det = np.log(np.prod(np.diag(chol_11))**2) + self.assertAllClose(expected_log_det, log_det.eval()[1, 1]) + + def test_sqrt_matmul_single_matrix(self): + with self.test_session(): + batch_shape = () + for k in [1, 4]: + x_shape = batch_shape + (k, 3) + x = self._rng.rand(*x_shape) + chol_shape = batch_shape + (k, k) + chol = self._random_cholesky_array(chol_shape) + + operator = distributions.OperatorPDCholesky(chol) + + sqrt_operator_times_x = operator.sqrt_matmul(x) + expected = tf.batch_matmul(chol, x) + + self.assertEqual(expected.get_shape(), + sqrt_operator_times_x.get_shape()) + self.assertAllClose(expected.eval(), sqrt_operator_times_x.eval()) + + def test_sqrt_matmul_batch_matrix(self): + with self.test_session(): + batch_shape = (2, 3) + for k in [1, 4]: + x_shape = batch_shape + (k, 5) + x = self._rng.rand(*x_shape) + chol_shape = batch_shape + (k, k) + chol = self._random_cholesky_array(chol_shape) + + operator = distributions.OperatorPDCholesky(chol) + + sqrt_operator_times_x = operator.sqrt_matmul(x) + expected = tf.batch_matmul(chol, x) + + self.assertEqual(expected.get_shape(), + sqrt_operator_times_x.get_shape()) + self.assertAllClose(expected.eval(), sqrt_operator_times_x.eval()) + + def test_matmul_batch_matrix(self): + with self.test_session(): + batch_shape = (2, 3) + for k in [1, 4]: + x_shape = batch_shape + (k, 5) + x = self._rng.rand(*x_shape) + chol_shape = batch_shape + (k, k) + chol = self._random_cholesky_array(chol_shape) + + operator = distributions.OperatorPDCholesky(chol) + + chol_times_x = tf.batch_matmul(chol, x, adj_x=True) + expected = tf.batch_matmul(chol, chol_times_x) + + self.assertEqual(expected.get_shape(), operator.matmul(x).get_shape()) + self.assertAllClose(expected.eval(), operator.matmul(x).eval()) + + def test_shape(self): + # All other shapes are defined by the abstractmethod shape, so we only need + # to test this. 
+ with self.test_session(): + for shape in [(3, 3), (2, 3, 3), (1, 2, 3, 3)]: + chol = self._random_cholesky_array(shape) + operator = distributions.OperatorPDCholesky(chol) + self.assertAllEqual(shape, operator.shape().eval()) + + def test_to_dense(self): + with self.test_session(): + chol = self._random_cholesky_array((3, 3)) + operator = distributions.OperatorPDCholesky(chol) + self.assertAllClose(chol.dot(chol.T), operator.to_dense().eval()) + + def test_to_dense_sqrt(self): + with self.test_session(): + chol = self._random_cholesky_array((2, 3, 3)) + operator = distributions.OperatorPDCholesky(chol) + self.assertAllClose(chol, operator.to_dense_sqrt().eval()) + + def test_non_positive_definite_matrix_raises(self): + # Singular matrix with one positive eigenvalue and one zero eigenvalue. + with self.test_session(): + lower_mat = [[1.0, 0.0], [2.0, 0.0]] + operator = distributions.OperatorPDCholesky(lower_mat) + with self.assertRaisesOpError('x > 0 did not hold'): + operator.to_dense().eval() + + def test_non_positive_definite_matrix_does_not_raise_if_not_verify_pd(self): + # Singular matrix with one positive eigenvalue and one zero eigenvalue. + with self.test_session(): + lower_mat = [[1.0, 0.0], [2.0, 0.0]] + operator = distributions.OperatorPDCholesky(lower_mat, verify_pd=False) + operator.to_dense().eval() # Should not raise. + + def test_not_having_two_identical_last_dims_raises(self): + # Unless the last two dims are equal, this cannot represent a matrix, and it + # should raise. 
+ with self.test_session(): + batch_vec = [[1.0], [2.0]] # shape 2 x 1 + with self.assertRaisesRegexp(ValueError, '.*Dimensions.*'): + operator = distributions.OperatorPDCholesky(batch_vec) + operator.to_dense().eval() + + +class BatchMatrixDiagTransformTest(tf.test.TestCase): + + def setUp(self): + self._rng = np.random.RandomState(0) + + def check_off_diagonal_same(self, m1, m2): + """Check the strictly lower and upper triangular parts, not the diag.""" + self.assertAllClose(np.tril(m1, k=-1), np.tril(m2, k=-1)) + self.assertAllClose(np.triu(m1, k=1), np.triu(m2, k=1)) + + def test_non_batch_matrix_with_transform(self): + mat = self._rng.rand(4, 4) + with self.test_session(): + chol = distributions.batch_matrix_diag_transform(mat, + transform=tf.nn.softplus) + self.assertEqual((4, 4), chol.get_shape()) + + self.check_off_diagonal_same(mat, chol.eval()) + self.assertAllClose(softplus(np.diag(mat)), np.diag(chol.eval())) + + def test_non_batch_matrix_no_transform(self): + mat = self._rng.rand(4, 4) + with self.test_session(): + # Default is no transform. + chol = distributions.batch_matrix_diag_transform(mat) + self.assertEqual((4, 4), chol.get_shape()) + self.assertAllClose(mat, chol.eval()) + + def test_batch_matrix_with_transform(self): + mat = self._rng.rand(2, 4, 4) + mat_0 = mat[0, :, :] + with self.test_session(): + chol = distributions.batch_matrix_diag_transform(mat, + transform=tf.nn.softplus) + + self.assertEqual((2, 4, 4), chol.get_shape()) + + chol_0 = chol.eval()[0, :, :] + + self.check_off_diagonal_same(mat_0, chol_0) + self.assertAllClose(softplus(np.diag(mat_0)), np.diag(chol_0)) + + self.check_off_diagonal_same(mat_0, chol_0) + self.assertAllClose(softplus(np.diag(mat_0)), np.diag(chol_0)) + + def test_batch_matrix_no_transform(self): + mat = self._rng.rand(2, 4, 4) + with self.test_session(): + # Default is no transform. 
+ chol = distributions.batch_matrix_diag_transform(mat) + + self.assertEqual((2, 4, 4), chol.get_shape()) + self.assertAllClose(mat, chol.eval()) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_full_test.py b/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_full_test.py new file mode 100644 index 00000000000..bcf43e80fd3 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_full_test.py @@ -0,0 +1,62 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + +distributions = tf.contrib.distributions + + +class OperatorPDFullTest(tf.test.TestCase): + # The only method needing checked (because it isn't part of the parent class) + # is the check for symmetry. 
+ + def setUp(self): + self._rng = np.random.RandomState(42) + + def _random_positive_def_array(self, *shape): + matrix = self._rng.rand(*shape) + return tf.batch_matmul(matrix, matrix, adj_y=True).eval() + + def test_positive_definite_matrix_doesnt_raise(self): + with self.test_session(): + matrix = self._random_positive_def_array(2, 3, 3) + operator = distributions.OperatorPDFull(matrix, verify_pd=True) + operator.to_dense().eval() # Should not raise + + def test_negative_definite_matrix_raises(self): + with self.test_session(): + matrix = -1 * self._random_positive_def_array(3, 2, 2) + operator = distributions.OperatorPDFull(matrix, verify_pd=True) + # Could fail inside Cholesky decomposition, or later when we test the + # diag. + with self.assertRaisesOpError('x > 0|LLT'): + operator.to_dense().eval() + + def test_non_symmetric_matrix_raises(self): + with self.test_session(): + matrix = self._random_positive_def_array(3, 2, 2) + matrix[0, 0, 1] += 0.001 + operator = distributions.OperatorPDFull(matrix, verify_pd=True) + with self.assertRaisesOpError('x == y'): + operator.to_dense().eval() + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_test.py b/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_test.py new file mode 100644 index 00000000000..c99d16fe4a5 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_test.py @@ -0,0 +1,83 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
class BaseCovarianceTest(tf.test.TestCase):
  """Exercises the shape helpers a minimal `OperatorPDBase` subclass exposes.

  The stub operator below only supplies `shape()` / `get_shape()` (plus
  trivial versions of the remaining members); `rank`, `batch_shape`,
  `vector_shape`, `vector_space_dimension` and their static counterparts are
  not defined on the stub, so the assertions check whatever the base class
  provides for them.
  """

  def test_all_shapes_methods_defined_by_the_one_abstractproperty_shape(self):

    class OperatorShape(distributions.OperatorPDBase):
      """Operator implements the ABC method .shape."""

      def __init__(self, shape):
        self._shape = shape

      @property
      def verify_pd(self):
        return True

      def get_shape(self):
        return tf.TensorShape(self._shape)

      def shape(self, name='shape'):
        # Only the shape of the array matters; its (random) values are unused.
        return tf.shape(np.random.rand(*self._shape))

      @property
      def name(self):
        return 'OperatorShape'

      @property
      def dtype(self):
        # Declared as a property to match the abstract property on the base
        # class, and returned rather than raised: `raise tf.int32` would fail
        # with a TypeError because a DType is not an exception.
        return tf.int32

      def inv_quadratic_form(
          self, x, name='inv_quadratic_form'):
        return x

      def log_det(self, name='log_det'):
        raise NotImplementedError()

      @property
      def inputs(self):
        return []

      def sqrt_matmul(self, x, name='sqrt_matmul'):
        return x

    shape = (1, 2, 3, 3)
    with self.test_session():
      operator = OperatorShape(shape)

      # Dynamic (graph-evaluated) shape helpers.
      self.assertAllEqual(shape, operator.shape().eval())
      self.assertAllEqual(4, operator.rank().eval())
      self.assertAllEqual((1, 2), operator.batch_shape().eval())
      self.assertAllEqual((1, 2, 3), operator.vector_shape().eval())
      self.assertAllEqual(3, operator.vector_space_dimension().eval())

      # Static (graph-construction time) shape helpers.
      self.assertEqual(shape, operator.get_shape())
      self.assertEqual((1, 2), operator.get_batch_shape())
      self.assertEqual((1, 2, 3), operator.get_vector_shape())
a/tensorflow/contrib/distributions/python/ops/mvn.py +++ b/tensorflow/contrib/distributions/python/ops/mvn.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""The Multivariate Normal distribution class.""" +"""Multivariate Normal distribution classes.""" from __future__ import absolute_import from __future__ import division @@ -20,83 +20,33 @@ from __future__ import print_function import math +from tensorflow.contrib.distributions.python.ops import distribution # pylint: disable=line-too-long +from tensorflow.contrib.distributions.python.ops import operator_pd_cholesky # pylint: disable=line-too-long +from tensorflow.contrib.distributions.python.ops import operator_pd_full # pylint: disable=line-too-long from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util # pylint: disable=line-too-long from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops -def _assert_compatible_shapes(mu, sigma): - r_mu = array_ops.rank(mu) - r_sigma = array_ops.rank(sigma) - sigma_shape = array_ops.shape(sigma) - sigma_rank = array_ops.rank(sigma) - mu_shape = array_ops.shape(mu) - return control_flow_ops.group( - logging_ops.Assert( - math_ops.equal(r_mu + 1, r_sigma), - ["Rank of mu should be one less than rank of sigma, but saw: ", - r_mu, " vs. 
__all__ = [
    "MultivariateNormalCholesky",
    "MultivariateNormalFull",
]


class MultivariateNormalOperatorPD(distribution.ContinuousDistribution):
  """The multivariate normal distribution on `R^k`.

  This distribution is defined by a 1-D mean `mu` and an instance of
  `OperatorPDBase`, which provides access to a symmetric positive definite
  operator, which defines the covariance.

  #### Mathematical details

  The PDF of this distribution is:

  ```
  f(x) = (2*pi)^(-k/2) |det(sigma)|^(-1/2) exp(-1/2*(x-mu)^*.sigma^{-1}.(x-mu))
  ```

  where `.` denotes the inner product on `R^k` and `^*` denotes transpose.

  #### Examples

  A single multi-variate Gaussian distribution is defined by a vector of means
  of length `k`, and a covariance matrix of shape `k x k`.

  Extra leading dimensions, if provided, allow for batches.

  ```python
  # Initialize a single 3-variate Gaussian.
  # Note: `mu`, `cov` and observations must all share one float dtype.
  mu = [1, 2, 3.]
  chol = [[1, 0, 0.], [1, 3, 0], [1, 2, 3]]
  cov = tf.contrib.distributions.OperatorPDCholesky(chol)
  dist = tf.contrib.distributions.MultivariateNormalOperatorPD(mu, cov)

  # Evaluate this on an observation in R^3, returning a scalar.
  dist.pdf([-1, 0, 1.])

  # Initialize a batch of two 3-variate Gaussians.
  mu = [[1, 2, 3], [11, 22, 33.]]
  chol = ...  # shape 2 x 3 x 3, lower triangular, positive diagonal.
  cov = tf.contrib.distributions.OperatorPDCholesky(chol)
  dist = tf.contrib.distributions.MultivariateNormalOperatorPD(mu, cov)

  # Evaluate this on two observations, each in R^3, returning a length two
  # tensor.
  x = [[-1, 0, 1], [-11, 0, 11.]]  # Shape 2 x 3.
  dist.pdf(x)
  ```

  """

  def __init__(
      self,
      mu,
      cov,
      allow_nan=False,
      strict=True,
      strict_statistics=True,
      name="MultivariateNormalCov"):
    """Multivariate Normal distributions on `R^k`.

    User must provide means `mu`, and an instance of `OperatorPDBase`, `cov`,
    which determines the covariance.

    Args:
      mu: `float` or `double` tensor with shape `[N1,...,Nb, k]`, `b >= 0`.
      cov: `float` or `double` instance of `OperatorPDBase` with same `dtype`
        as `mu` and shape `[N1,...,Nb, k, k]`.
      allow_nan: Unused. The value is neither stored nor consulted; the
        argument is kept only so existing callers keep working. Use
        `strict_statistics` to control behavior for undefined statistics.
      strict: Whether to validate input with asserts. If `strict` is `False`,
        and the inputs are invalid, correct behavior is not guaranteed.
      strict_statistics: Boolean, default True. If True, raise an exception if
        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
        If False, batch members with valid parameters leading to undefined
        statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.

    Raises:
      TypeError: If `mu` and `cov` are different dtypes.
    """
    # These flags must be set before `_check_mu` runs, since it reads
    # `self.strict`.
    self._strict_statistics = strict_statistics
    self._strict = strict
    with ops.name_scope(name):
      with ops.op_scope([mu] + cov.inputs, "init"):
        self._cov = cov
        self._mu = self._check_mu(mu)
        self._name = name

  def _check_mu(self, mu):
    """Return `mu` after validity checks and possibly with assertions.

    Args:
      mu: Value convertible to a `Tensor`; the candidate mean.

    Returns:
      `mu` as a `Tensor`. When `self.strict` is True the returned tensor
      carries control dependencies on rank and shape assertions against
      `self._cov`.

    Raises:
      TypeError: If `mu` and `self._cov` have different dtypes.
    """
    mu = ops.convert_to_tensor(mu)
    cov = self._cov

    if mu.dtype != cov.dtype:
      raise TypeError(
          "mu and cov must have the same dtype.  Found mu.dtype = %s, "
          "cov.dtype = %s"
          % (mu.dtype, cov.dtype))
    if not self.strict:
      return mu

    # Build each rank/shape op once and reuse it in both the comparison and
    # the assertion's diagnostic data.
    mu_rank = array_ops.rank(mu)
    mu_shape = array_ops.shape(mu)
    cov_rank = cov.rank()
    assert_compatible_shapes = control_flow_ops.group(
        check_ops.assert_equal(
            mu_rank + 1,
            cov_rank,
            data=["mu should have rank 1 less than cov.  Found: rank(mu) = ",
                  mu_rank, " rank(cov) = ", cov_rank],
        ),
        check_ops.assert_equal(
            mu_shape,
            cov.vector_shape(),
            data=["mu.shape and cov.shape[:-1] should match.  "
                  "Found: shape(mu) = "
                  , mu_shape, " shape(cov) = ", cov.shape()],
        ),
    )
    return control_flow_ops.with_dependencies([assert_compatible_shapes], mu)

  @property
  def strict(self):
    """Boolean describing behavior on invalid input."""
    return self._strict

  @property
  def strict_statistics(self):
    """Boolean describing behavior when a stat is undefined for batch member."""
    return self._strict_statistics

  @property
  def dtype(self):
    """The `dtype` of `mu` (checked at construction to equal `cov.dtype`)."""
    return self._mu.dtype

  def get_event_shape(self):
    """`TensorShape` available at graph construction time."""
    # Taken from the trailing dimension of the covariance operator's shape.
    return self._cov.get_shape()[-1:]

  def event_shape(self, name="event_shape"):
    """Shape of a sample from a single distribution as a 1-D int32 `Tensor`."""
    with ops.name_scope(self.name):
      with ops.op_scope(self._cov.inputs, name):
        return array_ops.pack([self._cov.vector_space_dimension()])

  def batch_shape(self, name="batch_shape"):
    """Batch dimensions of this instance as a 1-D int32 `Tensor`."""
    # When `strict` is True, _check_mu asserts mu's shape matches the
    # operator's vector shape, so the operator's batch shape is used directly.
    with ops.name_scope(self.name):
      with ops.op_scope(self._cov.inputs, name):
        return self._cov.batch_shape()

  def get_batch_shape(self):
    """`TensorShape` available at graph construction time."""
    return self._cov.get_batch_shape()

  @property
  def mu(self):
    """The (validated) mean tensor."""
    return self._mu

  @property
  def sigma(self):
    """Dense (batch) covariance matrix, if available."""
    with ops.name_scope(self.name):
      return self._cov.to_dense()

  def mean(self, name="mean"):
    """Mean of each batch member."""
    with ops.name_scope(self.name):
      with ops.op_scope([self._mu], name):
        return array_ops.identity(self._mu)

  def mode(self, name="mode"):
    """Mode of each batch member."""
    with ops.name_scope(self.name):
      with ops.op_scope([self._mu], name):
        return array_ops.identity(self._mu)

  def variance(self, name="variance"):
    """Variance of each batch member: the dense (batch) covariance matrix.

    Args:
      name: A name scope to use for ops added by this method.

    Returns:
      The result of the covariance operator's `to_dense()`.
    """
    # Honor `name`, matching `mean` and `mode`; the returned value is the same
    # dense covariance that the `sigma` property produces.
    with ops.name_scope(self.name):
      with ops.op_scope(self._cov.inputs, name):
        return self._cov.to_dense()

  def log_sigma_det(self, name="log_sigma_det"):
    """Log of determinant of covariance matrix."""
    with ops.name_scope(self.name):
      with ops.op_scope(self._cov.inputs, name):
        return self._cov.log_det()

  def sigma_det(self, name="sigma_det"):
    """Determinant of covariance matrix."""
    with ops.name_scope(self.name):
      with ops.op_scope(self._cov.inputs, name):
        return math_ops.exp(self._cov.log_det())

  def log_pdf(self, x, name="log_pdf"):
    """Log pdf of observations `x` given these Multivariate Normals.

    Args:
      x: Value convertible to a `Tensor` with the same float dtype as `mu`;
        `mu` is subtracted from it, so the shapes must be compatible for that.
      name: The name to give this op.

    Returns:
      log_pdf: tensor of dtype `dtype`, the log-PDFs of `x`.
    """
    with ops.name_scope(self.name):
      with ops.op_scope([self._mu, x] + self._cov.inputs, name):
        x = ops.convert_to_tensor(x)
        contrib_tensor_util.assert_same_float_dtype((self._mu, x))

        x_centered = x - self.mu
        # (x - mu)^* sigma^{-1} (x - mu), delegated to the operator.
        x_whitened_norm = self._cov.inv_quadratic_form(x_centered)
        log_sigma_det = self.log_sigma_det()

        log_two_pi = constant_op.constant(
            math.log(2 * math.pi), dtype=self.dtype)
        k = math_ops.cast(self._cov.vector_space_dimension(), self.dtype)
        log_pdf_value = -(log_sigma_det + k * log_two_pi + x_whitened_norm) / 2

        # One log-density per vector: drop the trailing event dimension.
        output_static_shape = x_centered.get_shape()[:-1]
        log_pdf_value.set_shape(output_static_shape)
        return log_pdf_value

  def pdf(self, x, name="pdf"):
    """The PDF of observations `x` under these Multivariate Normals.

    Args:
      x: Value accepted by `log_pdf`.
      name: The name to give this op.

    Returns:
      pdf: tensor of dtype `dtype`, the pdf values of `x`.
    """
    with ops.name_scope(self.name):
      with ops.op_scope([self._mu, x] + self._cov.inputs, name):
        return math_ops.exp(self.log_pdf(x))

  def entropy(self, name="entropy"):
    """The entropies of these Multivariate Normals.

    Args:
      name: The name to give this op.

    Returns:
      entropy: tensor of dtype `dtype`, the entropies.
    """
    with ops.name_scope(self.name):
      with ops.op_scope([self._mu] + self._cov.inputs, name):
        log_sigma_det = self.log_sigma_det()
        one_plus_log_two_pi = constant_op.constant(1 + math.log(2 * math.pi),
                                                   dtype=self.dtype)

        # entropy = (k * (1 + log(2 pi)) + log|det(sigma)|) / 2
        k = math_ops.cast(self._cov.vector_space_dimension(), dtype=self.dtype)
        entropy_value = (k * one_plus_log_two_pi + log_sigma_det) / 2
        entropy_value.set_shape(log_sigma_det.get_shape())
        return entropy_value

  def sample(self, n, seed=None, name="sample"):
    """Sample `n` observations from the Multivariate Normal Distributions.

    Args:
      n: Scalar convertible to a `Tensor`; the number of observations to
        sample.
      seed: Seed forwarded to `random_normal`.
      name: The name to give this op.

    Returns:
      samples: `[n, ...]`, a `Tensor` of `n` samples for each of the
        distributions determined by broadcasting the hyperparameters.
    """
    with ops.name_scope(self.name):
      with ops.op_scope([self._mu, n] + self._cov.inputs, name):
        # The static shape of mu is used below as the per-sample shape hint.
        broadcast_shape = self.mu.get_shape()
        n = ops.convert_to_tensor(n)

        # White noise shaped [N1,...,Nb, k, n] so the operator can left
        # multiply by its square root.
        shape = array_ops.concat(0, [self._cov.vector_shape(), [n]])
        white_samples = random_ops.random_normal(shape=shape,
                                                 mean=0,
                                                 stddev=1,
                                                 dtype=self.dtype,
                                                 seed=seed)

        correlated_samples = self._cov.sqrt_matmul(white_samples)

        # Move the last dimension to the front
        perm = array_ops.concat(0, (
            array_ops.pack([array_ops.rank(correlated_samples) - 1]),
            math_ops.range(0, array_ops.rank(correlated_samples) - 1)))

        # TODO(ebrevdo): Once we get a proper tensor contraction op,
        # perform the inner product using that instead of batch_matmul
        # and this slow transpose can go away!
        correlated_samples = array_ops.transpose(correlated_samples, perm)

        samples = correlated_samples + self.mu

        # Provide some hints to shape inference
        n_val = tensor_util.constant_value(n)
        final_shape = tensor_shape.vector(n_val).concatenate(broadcast_shape)
        samples.set_shape(final_shape)

        return samples

  @property
  def is_reparameterized(self):
    return True

  @property
  def name(self):
    """Name used as the outer name scope for ops created by this instance."""
    return self._name


class MultivariateNormalCholesky(MultivariateNormalOperatorPD):
  """The multivariate normal distribution on `R^k`.

  This distribution is defined by a 1-D mean `mu` and a Cholesky factor `chol`.
  Providing the Cholesky factor allows for `O(k^2)` pdf evaluation and sampling,
  and requires `O(k^2)` storage.

  #### Mathematical details

  The PDF of this distribution is:

  ```
  f(x) = (2*pi)^(-k/2) |det(sigma)|^(-1/2) exp(-1/2*(x-mu)^*.sigma^{-1}.(x-mu))
  ```

  where `.` denotes the inner product on `R^k` and `^*` denotes transpose.

  #### Examples

  A single multi-variate Gaussian distribution is defined by a vector of means
  of length `k`, and a covariance matrix of shape `k x k`.

  Extra leading dimensions, if provided, allow for batches.

  ```python
  # Initialize a single 3-variate Gaussian with diagonal covariance.
  # Note: `mu`, `chol` and observations must all share one float dtype.
  mu = [1, 2, 3.]
  chol = [[1, 0, 0], [0, 3, 0], [0, 0, 2.]]
  dist = tf.contrib.distributions.MultivariateNormalCholesky(mu, chol)

  # Evaluate this on an observation in R^3, returning a scalar.
  dist.pdf([-1, 0, 1.])

  # Initialize a batch of two 3-variate Gaussians.
  mu = [[1, 2, 3], [11, 22, 33.]]
  chol = ...  # shape 2 x 3 x 3, lower triangular, positive diagonal.
  dist = tf.contrib.distributions.MultivariateNormalCholesky(mu, chol)

  # Evaluate this on two observations, each in R^3, returning a length two
  # tensor.
  x = [[-1, 0, 1], [-11, 0, 11.]]  # Shape 2 x 3.
  dist.pdf(x)
  ```

  Trainable (batch) Cholesky matrices can be created with
  `tf.contrib.distributions.batch_matrix_diag_transform()`

  """

  def __init__(
      self,
      mu,
      chol,
      strict=True,
      strict_statistics=True,
      name="MultivariateNormalCholesky"):
    """Multivariate Normal distributions on `R^k`.

    User must provide means `mu` and `chol` which holds the (batch) Cholesky
    factors `S`, such that the covariance of each batch member is `S S^*`.

    Args:
      mu: `(N+1)-D`  `float` or `double` tensor with shape `[N1,...,Nb, k]`,
        `b >= 0`.
      chol: `(N+2)-D`  `Tensor` with same `dtype` as `mu` and shape
        `[N1,...,Nb, k, k]`.
      strict: Whether to validate input with asserts. If `strict` is `False`,
        and the inputs are invalid, correct behavior is not guaranteed.
      strict_statistics: Boolean, default True. If True, raise an exception if
        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
        If False, batch members with valid parameters leading to undefined
        statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.

    Raises:
      TypeError: If `mu` and `chol` are different dtypes.
    """
    # `strict` also controls whether the operator verifies its own input.
    cov = operator_pd_cholesky.OperatorPDCholesky(chol, verify_pd=strict)
    super(MultivariateNormalCholesky, self).__init__(
        mu, cov, strict_statistics=strict_statistics, strict=strict, name=name)


class MultivariateNormalFull(MultivariateNormalOperatorPD):
  """The multivariate normal distribution on `R^k`.

  This distribution is defined by a 1-D mean `mu` and covariance matrix `sigma`.
  Evaluation of the pdf, determinant, and sampling are all `O(k^3)` operations.

  #### Mathematical details

  The PDF of this distribution is:

  ```
  f(x) = (2*pi)^(-k/2) |det(sigma)|^(-1/2) exp(-1/2*(x-mu)^*.sigma^{-1}.(x-mu))
  ```

  where `.` denotes the inner product on `R^k` and `^*` denotes transpose.

  #### Examples

  A single multi-variate Gaussian distribution is defined by a vector of means
  of length `k`, and a covariance matrix of shape `k x k`.

  Extra leading dimensions, if provided, allow for batches.

  ```python
  # Initialize a single 3-variate Gaussian with diagonal covariance.
  # Note: `mu`, `sigma` and observations must all share one float dtype.
  mu = [1, 2, 3.]
  sigma = [[1, 0, 0], [0, 3, 0], [0, 0, 2.]]
  dist = tf.contrib.distributions.MultivariateNormalFull(mu, sigma)

  # Evaluate this on an observation in R^3, returning a scalar.
  dist.pdf([-1, 0, 1.])

  # Initialize a batch of two 3-variate Gaussians.
  mu = [[1, 2, 3], [11, 22, 33.]]
  sigma = ...  # shape 2 x 3 x 3, positive definite.
  dist = tf.contrib.distributions.MultivariateNormalFull(mu, sigma)

  # Evaluate this on two observations, each in R^3, returning a length two
  # tensor.
  x = [[-1, 0, 1], [-11, 0, 11.]]  # Shape 2 x 3.
  dist.pdf(x)
  ```

  """

  def __init__(
      self,
      mu,
      sigma,
      strict=True,
      strict_statistics=True,
      name="MultivariateNormalFull"):
    """Multivariate Normal distributions on `R^k`.

    User must provide `mu` and `sigma`, the mean and covariance.

    Args:
      mu: `(N+1)-D`  `float` or `double` tensor with shape `[N1,...,Nb, k]`,
        `b >= 0`.
      sigma: `(N+2)-D`  `Tensor` with same `dtype` as `mu` and shape
        `[N1,...,Nb, k, k]`.
      strict: Whether to validate input with asserts. If `strict` is `False`,
        and the inputs are invalid, correct behavior is not guaranteed.
      strict_statistics: Boolean, default True. If True, raise an exception if
        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
        If False, batch members with valid parameters leading to undefined
        statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.

    Raises:
      TypeError: If `mu` and `sigma` are different dtypes.
    """
    # `strict` also controls whether the operator verifies its own input.
    cov = operator_pd_full.OperatorPDFull(sigma, verify_pd=strict)
    super(MultivariateNormalFull, self).__init__(
        mu, cov, strict_statistics=strict_statistics, strict=strict, name=name)
@six.add_metaclass(abc.ABCMeta)
class OperatorPDBase(object):
  """Class representing a (batch) of positive definite matrices `A`.

  This class provides access to functions of a (batch) symmetric positive
  definite (PD) matrix, without the need to materialize them.  In other words,
  this provides means to do "matrix free" computations.

  For example, `my_operator.matmul(x)` computes the result of matrix
  multiplication, and this class is free to do this computation with or without
  ever materializing a matrix.

  In practice, this operator represents a (batch) matrix `A` with shape
  `[N1,...,Nb, k, k]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(n1,...,nb)`, `A[n1,...,nb, :, :]` is
  a `k x k` matrix.  Again, this matrix `A` may not be materialized, but for
  purposes of broadcasting this shape will be relevant.

  Since `A` is (batch) positive definite, it has a (or several) square roots `S`
  such that `A = SS^T`.

  For example, if `MyOperator` inherits from `OperatorPDBase`, the user can do

  ```python
  operator = MyOperator(...)  # Initialize with some tensors.
  operator.log_det()

  # Compute the quadratic form x^T A^{-1} x for vector x.
  x = ... # some shape [..., k] tensor
  operator.inv_quadratic_form(x)

  # Matrix multiplication by the square root, S w.
  # If w is iid normal, S w has covariance A.
  w = ... # some shape [..., k, L] tensor, L >= 1
  operator.sqrt_matmul(w)
  ```

  The above three methods, `log_det`, `inv_quadratic_form`, and
  `sqrt_matmul` provide "all" that is necessary to use a covariance matrix
  in a multi-variate normal distribution.  See the class `MVNOperatorPD`.

  Subclasses must override the abstract members (`name`, `verify_pd`, `dtype`,
  `inputs`, `get_shape`, `shape`).  The three core methods above, as well as
  `matmul`, `to_dense` and `to_dense_sqrt`, raise `NotImplementedError` until
  a subclass overrides them.
  """

  @abc.abstractproperty
  def name(self):
    """String name identifying this `Operator`."""
    # return self._name
    pass

  @abc.abstractproperty
  def verify_pd(self):
    """Whether to verify that this `Operator` is positive definite."""
    # return self._verify_pd
    pass

  @abc.abstractproperty
  def dtype(self):
    """Data type of matrix elements of `A`."""
    pass

  def inv_quadratic_form(self, x, name='inv_quadratic_form'):
    """Compute the quadratic form: x^T A^{-1} x.

    Args:
      x:  `Tensor` with shape broadcastable to `[N1,...,Nb, k]` and same `dtype`
        as self.
      name:  A name scope to use for ops added by this method.

    Returns:
      `Tensor` holding the square of the norm induced by inverse of `A`.  For
      every broadcast batch member.

    Raises:
      NotImplementedError: If the subclass does not override this method.
    """
    # with ops.name_scope(self.name):
    #   with ops.op_scope([x] + self.inputs, name):
    #     # ... your code here
    # Fail loudly instead of silently returning None.
    raise NotImplementedError('This operator has no inv_quadratic_form Op.')

  def det(self, name='det'):
    """Determinant for every batch member.

    Args:
      name:  A name scope to use for ops added by this method.

    Returns:
      Determinant for every batch member.
    """
    # Derived classes are encouraged to implement log_det() (since it is
    # usually more stable), and then det() comes for free.
    with ops.name_scope(self.name):
      with ops.op_scope(self.inputs, name):
        return math_ops.exp(self.log_det())

  def log_det(self, name='log_det'):
    """Log of the determinant for every batch member.

    Args:
      name:  A name scope to use for ops added by this method.

    Returns:
      Logarithm of determinant for every batch member.

    Raises:
      NotImplementedError: If the subclass does not override this method.
    """
    # with ops.name_scope(self.name):
    #   with ops.op_scope(self.inputs, name):
    #     # ... your code here
    # Fail loudly: det() feeds this result straight into exp().
    raise NotImplementedError('This operator has no log_det Op.')

  @abc.abstractproperty
  def inputs(self):
    """List of tensors that were provided as initialization inputs."""
    pass

  def sqrt_matmul(self, x, name='sqrt_matmul'):
    """Left (batch) matmul `x` by a sqrt of this matrix: `Sx` where `A = S S^T`.

    Args:
      x:  `Tensor` with shape `[N1,...,Nb, k, L]`, `L >= 1` (see the class
        docstring example) and same `dtype` as self.
      name:  A name scope to use for ops added by this method.

    Returns:
      `Tensor` holding the product `S x`.

    Raises:
      NotImplementedError: If the subclass does not override this method.
    """
    # with ops.name_scope(self.name):
    #   with ops.op_scope([x] + self.inputs, name):
    #     # ... your code here
    # Fail loudly instead of silently returning None.
    raise NotImplementedError('This operator has no sqrt_matmul Op.')

  @abc.abstractmethod
  def get_shape(self):
    """`TensorShape` giving static shape."""
    pass

  def get_batch_shape(self):
    """`TensorShape` with batch shape (static shape minus last two dims)."""
    return self.get_shape()[:-2]

  def get_vector_shape(self):
    """`TensorShape` of vectors this operator will work with."""
    return self.get_shape()[:-1]

  @abc.abstractmethod
  def shape(self, name='shape'):
    """Equivalent to `tf.shape(A).`  Equal to `[N1,...,Nb, k, k]`, `b >= 0`.

    Args:
      name:  A name scope to use for ops added by this method.

    Returns:
      `int32` `Tensor`
    """
    # with ops.name_scope(self.name):
    #   with ops.op_scope(self.inputs, name):
    #     # ... your code here
    pass

  def rank(self, name='rank'):
    """Tensor rank.  Equivalent to `tf.rank(A)`.  Will equal `b + 2`.

    If this operator represents the batch matrix `A` with
    `A.shape = [N1,...,Nb, k, k]`, the `rank` is `b + 2`.

    Args:
      name:  A name scope to use for ops added by this method.

    Returns:
      `int32` `Tensor`
    """
    # Derived classes get this "for free" once .shape() is implemented.
    with ops.name_scope(self.name):
      with ops.op_scope(self.inputs, name):
        # The length of the (1-D) shape tensor is the rank of `A`.
        return array_ops.shape(self.shape())[0]

  def batch_shape(self, name='batch_shape'):
    """Shape of batches associated with this operator.

    If this operator represents the batch matrix `A` with
    `A.shape = [N1,...,Nb, k, k]`, the `batch_shape` is `[N1,...,Nb]`.

    Args:
      name:  A name scope to use for ops added by this method.

    Returns:
      `int32` `Tensor`
    """
    # Derived classes get this "for free" once .shape() is implemented.
    with ops.name_scope(self.name):
      with ops.op_scope(self.inputs, name):
        end = array_ops.pack([self.rank() - 2])
        return array_ops.slice(self.shape(), [0], end)

  def vector_shape(self, name='vector_shape'):
    """Shape of (batch) vectors that this (batch) matrix will multiply.

    If this operator represents the batch matrix `A` with
    `A.shape = [N1,...,Nb, k, k]`, the `vector_shape` is `[N1,...,Nb, k]`.

    Args:
      name:  A name scope to use for ops added by this method.

    Returns:
      `int32` `Tensor`
    """
    # Derived classes get this "for free" once .shape() is implemented.
    with ops.name_scope(self.name):
      with ops.op_scope(self.inputs, name):
        # Build the slice size the same way batch_shape() does.
        end = array_ops.pack([self.rank() - 1])
        return array_ops.slice(self.shape(), [0], end)

  def vector_space_dimension(self, name='vector_space_dimension'):
    """Dimension of vector space on which this acts.  The `k` in `R^k`.

    If this operator represents the batch matrix `A` with
    `A.shape = [N1,...,Nb, k, k]`, the `vector_space_dimension` is `k`.

    Args:
      name:  A name scope to use for ops added by this method.

    Returns:
      `int32` `Tensor`
    """
    # Derived classes get this "for free" once .shape() is implemented.
    with ops.name_scope(self.name):
      with ops.op_scope(self.inputs, name):
        # Last entry of the shape tensor.
        return array_ops.gather(self.shape(), self.rank() - 1)

  def matmul(self, x, name='matmul'):
    """Left multiply `x` by this operator.

    Args:
      x: Shape `[N1,...,Nb, k, L]` `Tensor` with same `dtype` as this operator
      name:  A name to give this `Op`.

    Returns:
      A result equivalent to `tf.batch_matmul(self.to_dense(), x)`.

    Raises:
      NotImplementedError: If the subclass does not override this method.
    """
    # with ops.name_scope(self.name):
    #   with ops.op_scope([x] + self.inputs, name):
    #     # ... your code here
    raise NotImplementedError('This operator has no batch_matmul Op.')

  def to_dense(self, name='to_dense'):
    """Return a dense (batch) matrix representing this operator.

    Raises:
      NotImplementedError: If the subclass does not override this method.
    """
    # with ops.name_scope(self.name):
    #   with ops.op_scope(self.inputs, name):
    #     # ... your code here
    raise NotImplementedError('This operator has no dense representation.')

  def to_dense_sqrt(self, name='to_dense_sqrt'):
    """Return a dense (batch) matrix representing sqrt of this operator.

    Raises:
      NotImplementedError: If the subclass does not override this method.
    """
    # with ops.name_scope(self.name):
    #   with ops.op_scope(self.inputs, name):
    #     # ... your code here
    raise NotImplementedError('This operator has no dense sqrt representation.')
class OperatorPDCholesky(operator_pd.OperatorPDBase):
  """Class representing a (batch) of positive definite matrices `A`.

  This class provides access to functions of a batch of symmetric positive
  definite (PD) matrices `A` in `R^{k x k}` defined by Cholesky factor(s).
  Determinants and solves are `O(k^2)`.

  In practice, this operator represents a (batch) matrix `A` with shape
  `[N1,...,Nb, k, k]` for some `b >= 0`.  The first `b` indices designate a
  batch member.  For every batch member `(n1,...,nb)`, `A[n1,...,nb, :, :]` is
  a `k x k` matrix.

  Since `A` is (batch) positive definite, it has a (or several) square roots `S`
  such that `A = SS^T`.  Here `S` is the lower triangle of the provided `chol`.

  For example,

  ```python
  distributions = tf.contrib.distributions
  chol = [[1.0, 0.0], [1.0, 2.0]]
  operator = OperatorPDCholesky(chol)
  operator.log_det()

  # Compute the quadratic form x^T A^{-1} x for vector x.
  x = [1.0, 2.0]
  operator.inv_quadratic_form(x)

  # Matrix multiplication by the square root, S w.
  # If w is iid normal, S w has covariance A.
  w = [[1.0], [2.0]]
  operator.sqrt_matmul(w)
  ```

  The above three methods, `log_det`, `inv_quadratic_form`, and
  `sqrt_matmul` provide "all" that is necessary to use a covariance matrix
  in a multi-variate normal distribution.  See the class `MVNOperatorPD`.
  """

  def __init__(self, chol, verify_pd=True, name='OperatorPDCholesky'):
    """Initialize an OperatorPDCholesky.

    Args:
      chol:  Shape `[N1,...,Nb, k, k]` tensor with `b >= 0`, `k >= 1`, and
        positive diagonal elements.  The strict upper triangle of `chol` is
        never used, and the user may set these elements to zero, or ignore them.
      verify_pd: Whether to check that `chol` has positive diagonal (this is
        equivalent to it being a Cholesky factor of a symmetric positive
        definite matrix).  If `verify_pd` is `False`, correct behavior is not
        guaranteed.
      name:  A name to prepend to all ops created by this class.
    """
    self._verify_pd = verify_pd
    self._name = name
    with ops.name_scope(name):
      with ops.op_scope([chol], 'init'):
        self._chol = self._check_chol(chol)
        # Read the diagonal off the *checked* factor, so that everything
        # derived from it (e.g. log_det) also triggers the assertions.
        self._diag = array_ops.batch_matrix_diag_part(self._chol)

  @property
  def verify_pd(self):
    """Whether to verify that this `Operator` is positive definite."""
    return self._verify_pd

  @property
  def name(self):
    """String name identifying this `Operator`."""
    return self._name

  @property
  def dtype(self):
    """Data type of the Cholesky factor."""
    return self._chol.dtype

  def inv_quadratic_form(self, x, name='inv_quadratic_form'):
    """Compute the induced vector norm (squared): ||x||^2 := x^T A^{-1} x.

    For every batch member, this is done in `O(k^2)` complexity.  The efficiency
    depends on the shape of `x`.
    * If `x.shape = [M1,...,Mm, N1,...,Nb, k]`, `m >= 0`, and
      `self.shape = [N1,...,Nb, k, k]`, `x` will be reshaped and the
      initialization matrix `chol` does not need to be copied.
    * Otherwise, data will be broadcast and copied.

    Args:
      x: `Tensor` with shape broadcastable to `[N1,...,Nb, k]` and same `dtype`
        as self.  If the batch dimensions of `x` do not match exactly with those
        of self, `x` and/or self's Cholesky factor will broadcast to match, and
        the resultant set of linear systems are solved independently.  This may
        result in inefficient operation.
      name:  A name scope to use for ops added by this method.

    Returns:
      `Tensor` holding the square of the norm induced by inverse of `A`.  For
      every broadcast batch member.
    """
    with ops.name_scope(self.name):
      with ops.op_scope([x] + self.inputs, name):
        x = ops.convert_to_tensor(x, name='x')
        # Boolean, True iff x.shape = [M1,...,Mm] + chol.shape[:-1].
        should_flip = self._should_flip(x)
        x_whitened = control_flow_ops.cond(
            should_flip,
            lambda: self._x_whitened_if_should_flip(x),
            lambda: self._x_whitened_if_no_flip(x))

        # batch version of: || L^{-1} x||^2
        x_whitened_norm = math_ops.reduce_sum(
            math_ops.square(x_whitened),
            reduction_indices=[-1])

        # Set the final shape by making a dummy tensor that will never be
        # evaluated.
        chol_without_final_dim = math_ops.reduce_sum(
            self._chol, reduction_indices=[-1])
        final_shape = (x + chol_without_final_dim).get_shape()[:-1]
        x_whitened_norm.set_shape(final_shape)

        return x_whitened_norm

  def _should_flip(self, x):
    """Return boolean tensor telling whether `x` should be flipped."""
    # We "flip" (see self._flip_front_dims_to_back) iff
    #   chol.shape =                 [N1,...,Nn, k, k]
    #   x.shape =         [M1,...,Mm, N1,...,Nn, k]
    x_shape = array_ops.shape(x)
    x_rank = array_ops.rank(x)
    # If m <= 0, we should not flip.
    m = x_rank + 1 - self.rank()
    def result_if_m_positive():
      x_shape_right = array_ops.slice(x_shape, [m], [x_rank - m])
      return math_ops.reduce_all(
          math_ops.equal(x_shape_right, self.vector_shape()))
    return control_flow_ops.cond(
        m > 0,
        result_if_m_positive,
        lambda: ops.convert_to_tensor(False))

  def _x_whitened_if_no_flip(self, x):
    """x_whitened in the event of no flip."""
    # Tensors to use if x and chol have same shape, or a shape that must be
    # broadcast to match.
    chol_bcast, x_bcast = self._get_chol_and_x_compatible_shape(x)

    # batch version of: L^{-1} x
    # Note that here x_bcast has trailing dims of (k, 1), for "1" system of k
    # linear equations.  This is the form used by the solver.
    x_whitened_expanded = linalg_ops.batch_matrix_triangular_solve(
        chol_bcast, x_bcast)

    x_whitened = array_ops.squeeze(x_whitened_expanded, squeeze_dims=[-1])
    return x_whitened

  def _x_whitened_if_should_flip(self, x):
    """x_whitened when x.shape = [M1,...,Mm] + chol.shape[:-1]."""
    # This case is common if x was sampled.
    x_flipped = self._flip_front_dims_to_back(x)

    # batch version of: L^{-1} x
    x_whitened_expanded = linalg_ops.batch_matrix_triangular_solve(
        self._chol, x_flipped)

    return self._unflip_back_dims_to_front(
        x_whitened_expanded, array_ops.shape(x))

  def _flip_front_dims_to_back(self, x):
    """Flip x to make x.shape = chol.shape[:-1] + [M1*...*Mm]."""
    # E.g. suppose
    #   chol.shape =                 [N1,...,Nn, k, k]
    #   x.shape =         [M1,...,Mm, N1,...,Nn, k]
    # Then we want to return x_flipped where
    #   x_flipped.shape = [N1,...,Nn, k, M1*...*Mm].
    x_shape = array_ops.shape(x)
    x_rank = array_ops.rank(x)
    m = x_rank + 1 - self.rank()
    x_shape_left = array_ops.slice(x_shape, [0], [m])

    # Permutation corresponding to [N1,...,Nn, k, M1,...,Mm]
    perm = array_ops.concat(
        0, (math_ops.range(m, x_rank), math_ops.range(0, m)))
    x_permuted = array_ops.transpose(x, perm=perm)

    # Now that things are ordered correctly, condense the last dimensions.
    # condensed_shape = [M1*...*Mm]
    condensed_shape = array_ops.pack([math_ops.reduce_prod(x_shape_left)])
    new_shape = array_ops.concat(0, (self.vector_shape(), condensed_shape))

    return array_ops.reshape(x_permuted, new_shape)

  def _unflip_back_dims_to_front(self, x_flipped, x_shape):
    """Undo `_flip_front_dims_to_back`, restoring the (dynamic) `x_shape`."""
    # E.g. suppose that originally
    #   chol.shape =                 [N1,...,Nn, k, k]
    #   x.shape =         [M1,...,Mm, N1,...,Nn, k]
    # Then we have flipped the dims so that
    #   x_flipped.shape = [N1,...,Nn, k, M1*...*Mm].
    # We want to return x with the original shape.
    rank = array_ops.rank(x_flipped)
    # Permutation corresponding to [M1*...*Mm, N1,...,Nn, k]
    perm = array_ops.concat(
        0, (math_ops.range(rank - 1, rank), math_ops.range(0, rank - 1)))
    x_with_end_at_beginning = array_ops.transpose(x_flipped, perm=perm)
    return array_ops.reshape(x_with_end_at_beginning, x_shape)

  def _get_chol_and_x_compatible_shape(self, x):
    """Return self.chol and x, (possibly) broadcast to compatible shape."""
    # x and chol are "compatible" if their shape matches except for the last two
    # dimensions of chol are [k, k], and the last two of x are [k, 1].
    # E.g. x.shape = [A, B, k, 1], and chol.shape = [A, B, k, k]
    # This is required for the batch_triangular_solve, which does not broadcast.

    # NOTE: This broadcast replicates matrices.  The common case
    #   x.shape = [M1,...,Mm, N1,...,Nb, k], and chol.shape = [N1,...,Nb, k, k]
    # never reaches this method: inv_quadratic_form routes it through
    # _x_whitened_if_should_flip, which does not replicate the matrix.

    # We assume x starts without the trailing singleton dimension, e.g.
    # x.shape = [B, k].
    chol = self._chol
    with ops.op_scope([x] + self.inputs, 'get_chol_and_x_compatible_shape'):
      # If we determine statically that shapes match, we're done.
      if x.get_shape() == chol.get_shape()[:-1]:
        x_expanded = array_ops.expand_dims(x, -1)
        return chol, x_expanded

      # Dynamic check if shapes match or not.
      vector_shape = self.vector_shape()  # Shape of chol minus last dim.
      # vector_shape is a 1-D shape tensor, so its *size* (not its rank, which
      # is always 1) is the rank that x must have for the shapes to match.
      are_same_rank = math_ops.equal(
          array_ops.rank(x), array_ops.size(vector_shape))

      def shapes_match_if_same_rank():
        return math_ops.reduce_all(math_ops.equal(
            array_ops.shape(x), vector_shape))

      shapes_match = control_flow_ops.cond(are_same_rank,
                                           shapes_match_if_same_rank,
                                           lambda: ops.convert_to_tensor(False))

      # Make tensors (never instantiated) holding the broadcast shape.
      # matrix_broadcast_dummy is the shape we will broadcast chol to.
      matrix_bcast_dummy = chol + array_ops.expand_dims(x, -1)
      # vector_bcast_dummy is the shape we will bcast x to, before we expand it.
      chol_minus_last_dim = math_ops.reduce_sum(chol, reduction_indices=[-1])
      vector_bcast_dummy = x + chol_minus_last_dim

      chol_bcast = chol + array_ops.zeros_like(matrix_bcast_dummy)
      x_bcast = x + array_ops.zeros_like(vector_bcast_dummy)

      chol_result = control_flow_ops.cond(shapes_match, lambda: chol,
                                          lambda: chol_bcast)
      chol_result.set_shape(matrix_bcast_dummy.get_shape())
      x_result = control_flow_ops.cond(shapes_match, lambda: x, lambda: x_bcast)
      x_result.set_shape(vector_bcast_dummy.get_shape())

      x_expanded = array_ops.expand_dims(x_result, -1)

      return chol_result, x_expanded

  def log_det(self, name='log_det'):
    """Log determinant of every batch member."""
    with ops.name_scope(self.name):
      with ops.op_scope(self.inputs, name):
        # log det(A) = 2 * sum(log(diag(S))) since A = S S^T, S triangular.
        log_det = 2.0 * math_ops.reduce_sum(
            math_ops.log(self._diag),
            reduction_indices=[-1])
        log_det.set_shape(self._chol.get_shape()[:-2])
        return log_det

  @property
  def inputs(self):
    """List of tensors that were provided as initialization inputs."""
    return [self._chol]

  def sqrt_matmul(self, x, name='sqrt_matmul'):
    """Left (batch) matmul `x` by a sqrt of this matrix: `Sx` where `A = S S^T`.

    Args:
      x: `Tensor` with shape `[N1,...,Nb, k, L]`, `L >= 1` (it is passed to
        `batch_matmul` as the right operand) and same `dtype` as self.
      name:  A name scope to use for ops added by this method.

    Returns:
      `Tensor` holding the product `S x`.
    """
    with ops.name_scope(self.name):
      with ops.op_scope([x] + self.inputs, name):
        chol_lower = array_ops.batch_matrix_band_part(self._chol, -1, 0)
        return math_ops.batch_matmul(chol_lower, x)

  def get_shape(self):
    """`TensorShape` giving static shape."""
    return self._chol.get_shape()

  def shape(self, name='shape'):
    """Dynamic shape of the Cholesky factor, as an `int32` `Tensor`."""
    with ops.name_scope(self.name):
      with ops.op_scope(self.inputs, name):
        return array_ops.shape(self._chol)

  def _check_chol(self, chol):
    """Verify that `chol` is proper.

    Args:
      chol:  Candidate Cholesky factor; converted to a `Tensor`.

    Returns:
      `chol` as a `Tensor`.  When `verify_pd` is set, the tensor carries
      control dependencies asserting rank >= 2, square trailing dimensions
      and a positive diagonal.
    """
    chol = ops.convert_to_tensor(chol, name='chol')
    if not self.verify_pd:
      return chol

    shape = array_ops.shape(chol)
    rank = array_ops.rank(chol)

    is_matrix = check_ops.assert_rank_at_least(chol, 2)
    is_square = check_ops.assert_equal(
        array_ops.gather(shape, rank - 2), array_ops.gather(shape, rank - 1))
    has_positive_diag = check_ops.assert_positive(
        array_ops.batch_matrix_diag_part(chol))

    return control_flow_ops.with_dependencies(
        [is_matrix, is_square, has_positive_diag], chol)

  def matmul(self, x, name='matmul'):
    """Left (batch) matrix multiplication of `x` by this operator."""
    with ops.name_scope(self.name):
      with ops.op_scope([x] + self.inputs, name):
        # Only the lower triangle defines the operator; mask the rest.
        chol_lower = array_ops.batch_matrix_band_part(self._chol, -1, 0)
        # A x = S (S^T x).  adj_x transposes the first argument.
        sqrt_t_times_x = math_ops.batch_matmul(chol_lower, x, adj_x=True)
        return math_ops.batch_matmul(chol_lower, sqrt_t_times_x)

  def to_dense_sqrt(self, name='to_dense_sqrt'):
    """Return a dense (batch) matrix representing sqrt of this covariance."""
    with ops.name_scope(self.name):
      with ops.op_scope(self.inputs, name):
        # Zero the (ignored) strict upper triangle so the result is truly `S`.
        return array_ops.batch_matrix_band_part(self._chol, -1, 0)

  def to_dense(self, name='to_dense'):
    """Return a dense (batch) matrix representing this covariance."""
    with ops.name_scope(self.name):
      with ops.op_scope(self.inputs, name):
        chol_lower = array_ops.batch_matrix_band_part(self._chol, -1, 0)
        # A = S S^T
        return math_ops.batch_matmul(chol_lower, chol_lower, adj_y=True)
def batch_matrix_diag_transform(matrix, transform=None, name=None):
  """Apply `transform` to the diagonal of a [batch-]matrix only.

  Off-diagonal entries are passed through untouched.

  Typical use: build a trainable covariance defined by a Cholesky factor.

  ```python
  # Transform network layer into 2 x 2 array.
  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))

  # Make the diagonal positive.  If the upper triangle was zero, this would be a
  # valid Cholesky factor.
  chol = batch_matrix_diag_transform(matrix, transform=tf.nn.softplus)

  # OperatorPDCholesky ignores the upper triangle.
  operator = OperatorPDCholesky(chol)
  ```

  Heteroskedastic 2-D linear regression with the same ingredients:

  ```python
  # Get a trainable Cholesky factor.
  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))
  chol = batch_matrix_diag_transform(matrix, transform=tf.nn.softplus)

  # Get a trainable mean.
  mu = tf.contrib.layers.fully_connected(activations, 2)

  # This is a fully trainable multivariate normal!
  # NOTE(review): confirm the exported class name (`MVNCholesky`).
  dist = tf.contrib.distributions.MVNCholesky(mu, chol)

  # Standard log loss.  Minimizing this will "train" mu and chol, and then dist
  # will be a distribution predicting labels as multivariate Gaussians.
  loss = -1 * tf.reduce_mean(dist.log_pdf(labels))
  ```

  Args:
    matrix:  Rank `R` `Tensor`, `R >= 2`, where the last two dimensions are
      equal.
    transform:  Element-wise function mapping `Tensors` to `Tensors`.  Applied
      to the diagonal of `matrix`.  With the default `None`, `matrix` is
      returned unchanged (after conversion to a `Tensor`).
    name:  A name to give created ops.
      Defaults to "batch_matrix_diag_transform".

  Returns:
    A `Tensor` with same shape and `dtype` as `matrix`.
  """
  with ops.op_scope([matrix], name, 'batch_matrix_diag_transform'):
    matrix = ops.convert_to_tensor(matrix, name='matrix')
    if transform is not None:
      # Swap the diagonal by adding (new_diag - old_diag) on the diagonal.
      old_diag = array_ops.batch_matrix_diag_part(matrix)
      new_diag = transform(old_diag)
      diag_correction = array_ops.batch_matrix_diag(new_diag - old_diag)
      matrix = matrix + diag_correction
    return matrix
class OperatorPDFull(operator_pd_cholesky.OperatorPDCholesky):
  """Class representing a (batch) of positive definite matrices `A`.

  This class provides access to functions of a batch of symmetric positive
  definite (PD) matrices `A` in `R^{k x k}` defined by dense matrices.
  Determinants and solves are `O(k^3)`.

  In practice, this operator represents a (batch) matrix `A` with shape
  `[N1,...,Nb, k, k]` for some `b >= 0`.  The first `b` indices designate a
  batch member.  For every batch member `(n1,...,nb)`, `A[n1,...,nb, :, :]` is
  a `k x k` matrix.

  Since `A` is (batch) positive definite, it has a (or several) square roots `S`
  such that `A = SS^T`.

  For example,

  ```python
  distributions = tf.contrib.distributions
  matrix = [[1.0, 0.5], [0.5, 2.0]]  # Must be symmetric positive definite.
  operator = OperatorPDFull(matrix)
  operator.log_det()

  # Compute the quadratic form x^T A^{-1} x for vector x.
  x = [1.0, 2.0]
  operator.inv_quadratic_form(x)

  # Matrix multiplication by the square root, S w.
  # If w is iid normal, S w has covariance A.
  w = [[1.0], [2.0]]
  operator.sqrt_matmul(w)
  ```

  The above three methods, `log_det`, `inv_quadratic_form`, and
  `sqrt_matmul` provide "all" that is necessary to use a covariance matrix
  in a multi-variate normal distribution.  See the class `MVNOperatorPD`.
  """

  def __init__(self, matrix, verify_pd=True, name='OperatorPDFull'):
    """Initialize an OperatorPDFull.

    Args:
      matrix:  Shape `[N1,...,Nb, k, k]` tensor with `b >= 0`, `k >= 1`.  The
        last two dimensions should be `k x k` symmetric positive definite
        matrices.
      verify_pd: Whether to check that `matrix` is symmetric positive definite.
        If `verify_pd` is `False`, correct behavior is not guaranteed.
      name:  A name to prepend to all ops created by this class.
    """
    with ops.name_scope(name):
      with ops.op_scope([matrix], 'init'):
        matrix = ops.convert_to_tensor(matrix)
        # Check symmetric here.  Positivity will be verified by checking the
        # diagonal of the Cholesky factor inside the parent class.  The Cholesky
        # factorization .batch_cholesky() does not always fail for non PSD
        # matrices, so don't rely on that.
        if verify_pd:
          matrix = _check_symmetric(matrix)
        chol = linalg_ops.batch_cholesky(matrix)
        # Forward `name`; otherwise the parent falls back to its own default
        # and `self.name` (hence every method's op scope) would read
        # 'OperatorPDCholesky' regardless of what the caller asked for.
        super(OperatorPDFull, self).__init__(
            chol, verify_pd=verify_pd, name=name)


def _check_symmetric(matrix):
  """Return `matrix` with an assertion that it equals its (batch) transpose.

  The comparison is exact (`assert_equal`), elementwise, between `matrix` and
  the tensor obtained by swapping its last two dimensions.

  Args:
    matrix:  `Tensor` of rank at least 2.

  Returns:
    `matrix`, with a control dependency on the symmetry assertion.
  """
  rank = array_ops.rank(matrix)
  # Create permutation to permute last two dimensions
  first_dims = math_ops.range(0, rank - 2)
  flipped_last_dims = array_ops.pack([rank - 1, rank - 2])
  perm = array_ops.concat(0, (first_dims, flipped_last_dims))
  matrix_t = array_ops.transpose(matrix, perm=perm)

  return control_flow_ops.with_dependencies(
      [check_ops.assert_equal(matrix, matrix_t)], matrix)