diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 5feac79ecb0..bfa31dbe1cd 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -32,6 +32,19 @@ cuda_py_tests( srcs = ["python/kernel_tests/gaussian_test.py"], additional_deps = [ ":distributions_py", + "//third_party/py/scipy", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + +cuda_py_tests( + name = "mvn_test", + size = "small", + srcs = ["python/kernel_tests/mvn_test.py"], + additional_deps = [ + ":distributions_py", + "//third_party/py/scipy", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", ], @@ -43,6 +56,7 @@ cuda_py_tests( srcs = ["python/kernel_tests/gaussian_conjugate_posteriors_test.py"], additional_deps = [ ":distributions_py", + "//third_party/py/scipy", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", ], diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index 2f9b8fcafb1..54607a7379e 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -21,8 +21,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: disable=unused-import,wildcard-import, line-too-long +# pylint: disable=unused-import,wildcard-import,line-too-long from tensorflow.contrib.distributions.python.ops import gaussian_conjugate_posteriors from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import * from tensorflow.contrib.distributions.python.ops.gaussian import * -# from tensorflow.contrib.distributions.python.ops.dirichlet import * # pylint: disable=line-too-long +from tensorflow.contrib.distributions.python.ops.mvn import * diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_test.py new file mode 100644 index 00000000000..8b249c22362 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_test.py @@ -0,0 +1,252 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for MultivariateNormal.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from scipy import stats +import tensorflow as tf + + +class MultivariateNormalTest(tf.test.TestCase): + + def testNonmatchingMuSigmaFails(self): + with tf.Session(): + mvn = tf.contrib.distributions.MultivariateNormal( + mu=[1.0, 2.0], + sigma=[[[1.0, 0.0], + [0.0, 1.0]], + [[1.0, 0.0], + [0.0, 1.0]]]) + with self.assertRaisesOpError( + r"Rank of mu should be one less than rank of sigma"): + mvn.mean.eval() + + mvn = tf.contrib.distributions.MultivariateNormal( + mu=[[1.0], [2.0]], + sigma=[[[1.0, 0.0], + [0.0, 1.0]], + [[1.0, 0.0], + [0.0, 1.0]]]) + with self.assertRaisesOpError( + r"mu.shape and sigma.shape\[\:-1\] must match"): + mvn.mean.eval() + + def testNotPositiveDefiniteSigmaFails(self): + with tf.Session(): + mvn = tf.contrib.distributions.MultivariateNormal( + mu=[[1.0, 2.0], [1.0, 2.0]], + sigma=[[[1.0, 0.0], + [0.0, 1.0]], + [[1.0, 1.0], + [1.0, 1.0]]]) + with self.assertRaisesOpError( + r"LLT decomposition was not successful."): + mvn.mean.eval() + mvn = tf.contrib.distributions.MultivariateNormal( + mu=[[1.0, 2.0], [1.0, 2.0]], + sigma=[[[1.0, 0.0], + [0.0, 1.0]], + [[-1.0, 0.0], + [0.0, 1.0]]]) + with self.assertRaisesOpError( + r"LLT decomposition was not successful."): + mvn.mean.eval() + mvn = tf.contrib.distributions.MultivariateNormal( + mu=[[1.0, 2.0], [1.0, 2.0]], + sigma_chol=[[[1.0, 0.0], + [0.0, 1.0]], + [[-1.0, 0.0], + [0.0, 1.0]]]) + with self.assertRaisesOpError( + r"sigma_chol is not positive definite."): + mvn.mean.eval() + + def testLogPDFScalar(self): + with tf.Session(): + mu_v = np.array([-3.0, 3.0], dtype=np.float32) + mu = tf.constant(mu_v) + sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32) + sigma = tf.constant(sigma_v) + x = np.array([-2.5, 2.5], dtype=np.float32) + mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma) + + log_pdf = mvn.log_pdf(x) + + scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v) + expected_log_pdf = scipy_mvn.logpdf(x) + expected_pdf = scipy_mvn.pdf(x) + self.assertAllClose(expected_log_pdf, log_pdf.eval()) + + pdf = mvn.pdf(x) + self.assertAllClose(expected_pdf, pdf.eval()) + + def testLogPDFScalarSigmaHalf(self): + with tf.Session(): + mu_v = np.array([-3.0, 3.0, 1.0], dtype=np.float32) + mu = tf.constant(mu_v) + sigma_v = np.array([[1.0, 0.1, 0.2], + [0.1, 2.0, 0.05], + [0.2, 0.05, 3.0]], dtype=np.float32) + sigma_chol_v = np.linalg.cholesky(sigma_v) + sigma_chol = tf.constant(sigma_chol_v) + x = np.array([-2.5, 2.5, 1.0], dtype=np.float32) + mvn = tf.contrib.distributions.MultivariateNormal( + mu=mu, sigma_chol=sigma_chol) + + log_pdf = mvn.log_pdf(x) + sigma = mvn.sigma + + scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v) + expected_log_pdf = scipy_mvn.logpdf(x) + expected_pdf = scipy_mvn.pdf(x) + self.assertEqual(sigma.get_shape(), (3, 3)) + self.assertAllClose(sigma_v, sigma.eval()) + self.assertAllClose(expected_log_pdf, log_pdf.eval()) + + pdf = mvn.pdf(x) + self.assertAllClose(expected_pdf, pdf.eval()) + + def testLogPDF(self): + with tf.Session(): + mu_v = np.array([-3.0, 3.0], dtype=np.float32) + mu = tf.constant(mu_v) + sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32) + sigma = tf.constant(sigma_v) + x = np.array([[-2.5, 2.5], [4.0, 0.0], [-1.0, 2.0]], dtype=np.float32) + mvn = 
tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma) + + log_pdf = mvn.log_pdf(x) + + scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v) + expected_log_pdf = scipy_mvn.logpdf(x) + expected_pdf = scipy_mvn.pdf(x) + self.assertEqual(log_pdf.get_shape(), (3,)) + self.assertAllClose(expected_log_pdf, log_pdf.eval()) + + pdf = mvn.pdf(x) + self.assertAllClose(expected_pdf, pdf.eval()) + + def testLogPDFMatchingDimension(self): + with tf.Session(): + mu_v = np.array([-3.0, 3.0], dtype=np.float32) + mu = tf.constant(np.vstack(3 * [mu_v])) + sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32) + sigma = tf.constant(np.vstack(3 * [sigma_v[np.newaxis, :]])) + x = np.array([[-2.5, 2.5], [4.0, 0.0], [-1.0, 2.0]], dtype=np.float32) + mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma) + + log_pdf = mvn.log_pdf(x) + + scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v) + expected_log_pdf = scipy_mvn.logpdf(x) + expected_pdf = scipy_mvn.pdf(x) + self.assertEqual(log_pdf.get_shape(), (3,)) + self.assertAllClose(expected_log_pdf, log_pdf.eval()) + + pdf = mvn.pdf(x) + self.assertAllClose(expected_pdf, pdf.eval()) + + def testLogPDFMultidimensional(self): + with tf.Session(): + mu_v = np.array([-3.0, 3.0], dtype=np.float32) + mu = tf.constant(np.vstack(15 * [mu_v]).reshape(3, 5, 2)) + sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32) + sigma = tf.constant( + np.vstack(15 * [sigma_v[np.newaxis, :]]).reshape(3, 5, 2, 2)) + x = np.array([-2.5, 2.5], dtype=np.float32) + mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma) + + log_pdf = mvn.log_pdf(x) + + scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v) + expected_log_pdf = np.vstack(15 * [scipy_mvn.logpdf(x)]).reshape(3, 5) + expected_pdf = np.vstack(15 * [scipy_mvn.pdf(x)]).reshape(3, 5) + self.assertEqual(log_pdf.get_shape(), (3, 5)) + self.assertAllClose(expected_log_pdf, log_pdf.eval()) + + pdf = mvn.pdf(x) + self.assertAllClose(expected_pdf, pdf.eval()) + + def testEntropy(self): + with tf.Session(): + mu_v = np.array([-3.0, 3.0], dtype=np.float32) + mu = tf.constant(mu_v) + sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32) + sigma = tf.constant(sigma_v) + mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma) + entropy = mvn.entropy() + + scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v) + expected_entropy = scipy_mvn.entropy() + + self.assertEqual(entropy.get_shape(), ()) + self.assertAllClose(expected_entropy, entropy.eval()) + + def testEntropyMultidimensional(self): + with tf.Session(): + mu_v = np.array([-3.0, 3.0], dtype=np.float32) + mu = tf.constant(np.vstack(15 * [mu_v]).reshape(3, 5, 2)) + sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32) + sigma = tf.constant( + np.vstack(15 * [sigma_v[np.newaxis, :]]).reshape(3, 5, 2, 2)) + mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma) + entropy = mvn.entropy() + + scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v) + expected_entropy = np.vstack(15 * [scipy_mvn.entropy()]).reshape(3, 5) + + self.assertEqual(entropy.get_shape(), (3, 5)) + self.assertAllClose(expected_entropy, entropy.eval()) + + def testSample(self): + with tf.Session(): + mu_v = np.array([-3.0, 3.0], dtype=np.float32) + mu = tf.constant(mu_v) + sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32) + sigma = tf.constant(sigma_v) + n = tf.constant(100000) + mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma) + samples = 
mvn.sample(n, seed=137) + sample_values = samples.eval() + self.assertEqual(samples.get_shape(), (100000, 2)) + self.assertAllClose(sample_values.mean(axis=0), mu_v, atol=1e-2) + self.assertAllClose(np.cov(sample_values, rowvar=0), sigma_v, atol=1e-1) + + def testSampleMultiDimensional(self): + with tf.Session(): + mu_v = np.array([-3.0, 3.0], dtype=np.float32) + mu = tf.constant(np.vstack(15 * [mu_v]).reshape(3, 5, 2)) + sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32) + sigma = tf.constant( + np.vstack(15 * [sigma_v[np.newaxis, :]]).reshape(3, 5, 2, 2)) + n = tf.constant(100000) + mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma) + samples = mvn.sample(n, seed=137) + sample_values = samples.eval() + self.assertEqual(samples.get_shape(), (100000, 3, 5, 2)) + sample_values = sample_values.reshape(100000, 15, 2) + for i in range(15): + self.assertAllClose( + sample_values[:, i, :].mean(axis=0), mu_v, atol=1e-2) + self.assertAllClose( + np.cov(sample_values[:, i, :], rowvar=0), sigma_v, atol=1e-1) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/contrib/distributions/python/ops/gaussian.py b/tensorflow/contrib/distributions/python/ops/gaussian.py index b9dad502983..cbb98624d97 100644 --- a/tensorflow/contrib/distributions/python/ops/gaussian.py +++ b/tensorflow/contrib/distributions/python/ops/gaussian.py @@ -88,7 +88,7 @@ class Gaussian(object): @property def mean(self): - return self._mu + return self._mu * array_ops.ones_like(self._sigma) def log_pdf(self, x, name=None): """Log pdf of observations in `x` under these Gaussian distribution(s). @@ -170,7 +170,7 @@ class Gaussian(object): return 0.5 * math_ops.log(two_pi_e1 * math_ops.square(sigma)) def sample(self, n, seed=None, name=None): - """Sample `n` observations the Gaussian Distributions. + """Sample `n` observations from the Gaussian Distributions. Args: n: `Scalar`, type int32, the number of observations to sample. @@ -185,7 +185,7 @@ class Gaussian(object): broadcast_shape = (self._mu + self._sigma).get_shape() n = ops.convert_to_tensor(n) shape = array_ops.concat( - 0, [array_ops.pack([n]), array_ops.shape(self._mu)]) + 0, [array_ops.pack([n]), array_ops.shape(self.mean)]) sampled = random_ops.random_normal( shape=shape, mean=0, stddev=1, dtype=self._mu.dtype, seed=seed) diff --git a/tensorflow/contrib/distributions/python/ops/mvn.py b/tensorflow/contrib/distributions/python/ops/mvn.py new file mode 100644 index 00000000000..4ddd577d46b --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/mvn.py @@ -0,0 +1,429 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The Multivariate Normal distribution class. 
+ +@@MultivariateNormal +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util # pylint: disable=line-too-long +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import constant_op +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import logging_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops + + +def _assert_compatible_shapes(mu, sigma): + r_mu = array_ops.rank(mu) + r_sigma = array_ops.rank(sigma) + sigma_shape = array_ops.shape(sigma) + sigma_rank = array_ops.rank(sigma) + mu_shape = array_ops.shape(mu) + return control_flow_ops.group( + logging_ops.Assert( + math_ops.equal(r_mu + 1, r_sigma), + ["Rank of mu should be one less than rank of sigma, but saw: ", + r_mu, " vs. ", r_sigma]), + logging_ops.Assert( + math_ops.equal( + array_ops.gather(sigma_shape, sigma_rank - 2), + array_ops.gather(sigma_shape, sigma_rank - 1)), + ["Last two dimensions of sigma (%s) must be equal: " % sigma.name, + sigma_shape]), + logging_ops.Assert( + math_ops.reduce_all(math_ops.equal( + mu_shape, + array_ops.slice( + sigma_shape, [0], array_ops.pack([sigma_rank - 1])))), + ["mu.shape and sigma.shape[:-1] must match, but saw: ", + mu_shape, " vs. ", sigma_shape])) + + +def _assert_batch_positive_definite(sigma_chol): + """Add assertions checking that the sigmas are all Positive Definite. + + Given `sigma_chol == cholesky(sigma)`, it is sufficient to check that + `all(diag(sigma_chol) > 0)`. This is because to check that a matrix is PD, + it is sufficient that its cholesky factorization is PD, and to check that a + triangular matrix is PD, it is sufficient to check that its diagonal + entries are positive. + + Args: + sigma_chol: N-D. The lower triangular cholesky decomposition of `sigma`. + + Returns: + An assertion op to use with `control_dependencies`, verifying that + `sigma_chol` is positive definite. + """ + sigma_batch_diag = array_ops.batch_matrix_diag_part(sigma_chol) + return logging_ops.Assert( + math_ops.reduce_all(sigma_batch_diag > 0), + ["sigma_chol is not positive definite. batched diagonals: ", + sigma_batch_diag, " shaped: ", array_ops.shape(sigma_batch_diag)]) + + +def _determinant_from_sigma_chol(sigma_chol): + det_last_dim = array_ops.rank(sigma_chol) - 2 + sigma_batch_diag = array_ops.batch_matrix_diag_part(sigma_chol) + det = math_ops.square(math_ops.reduce_prod( + sigma_batch_diag, reduction_indices=det_last_dim)) + det.set_shape(sigma_chol.get_shape()[:-2]) + return det + + +class MultivariateNormal(object): + """The Multivariate Normal distribution on `R^k`. + + The distribution has mean and covariance parameters mu (1-D), sigma (2-D), + or alternatively mean `mu` and factored covariance (cholesky decomposed + `sigma`) called `sigma_chol`. + + The PDF of this distribution is: + + ``` + f(x) = (2*pi)^(-k/2) |det(sigma)|^(-1/2) exp(-1/2*(x-mu)^*.sigma^{-1}.(x-mu)) + ``` + + where `.` denotes the inner product on `R^k` and `^*` denotes transpose. + + Alternatively, if `sigma` is positive definite, it can be represented in terms + of its lower triangular cholesky factorization + + ```sigma = sigma_chol . 
sigma_chol^*``` + + and the pdf above allows simpler computation: + + ``` + |det(sigma)| = reduce_prod(diag(sigma_chol))^2 + x_whitened = sigma^{-1/2} . (x - mu) = tri_solve(sigma_chol, x - mu) + (x-mu)^* .sigma^{-1} . (x-mu) = x_whitened^* . x_whitened + ``` + + where `tri_solve()` solves a triangular system of equations. + """ + + def __init__(self, mu, sigma=None, sigma_chol=None, name=None): + """Multivariate Normal distributions on `R^k`. + + User must provide means `mu`, which are tensors of rank `N+1` (`N >= 0`) + with the last dimension having length `k`. + + User must provide exactly one of `sigma` (the covariance matrices) or + `sigma_chol` (the cholesky decompositions of the covariance matrices). + `sigma` or `sigma_chol` must be of rank `N+2`. The last two dimensions + must both have length `k`. The first `N` dimensions correspond to batch + indices. + + If `sigma_chol` is not provided, the batch cholesky factorization of `sigma` + is calculated for you. + + The shapes of `mu` and `sigma` must match for the first `N` dimensions. + + Regardless of which parameter is provided, the covariance matrices must all + be **positive definite** (an error is raised if one of them is not). + + Args: + mu: (N+1)-D. `float` or `double` tensor, the means of the distributions. + sigma: (N+2)-D. (optional) `float` or `double` tensor, the covariances + of the distribution(s). The first `N+1` dimensions must match + those of `mu`. Must be batch-positive-definite. + sigma_chol: (N+2)-D. (optional) `float` or `double` tensor, a + lower-triangular factorization of `sigma` + (`sigma = sigma_chol . sigma_chol^*`). The first `N+1` dimensions + must match those of `mu`. The tensor itself need not be batch + lower triangular: we ignore the upper triangular part. However, + the batch diagonals must be positive (i.e., sigma_chol must be + batch-positive-definite). + name: The name to give Ops created by the initializer. + + Raises: + ValueError: if neither sigma nor sigma_chol is provided. + TypeError: if mu and sigma (resp. sigma_chol) are different dtypes. + """ + if (sigma is None) == (sigma_chol is None): + raise ValueError("Exactly one of sigma and sigma_chol must be provided") + + with ops.op_scope([mu, sigma, sigma_chol], name, "MultivariateNormal"): + sigma_or_half = sigma_chol if sigma is None else sigma + + mu = ops.convert_to_tensor(mu) + sigma_or_half = ops.convert_to_tensor(sigma_or_half) + + contrib_tensor_util.assert_same_float_dtype((mu, sigma_or_half)) + + with ops.control_dependencies([ + _assert_compatible_shapes(mu, sigma_or_half)]): + mu = array_ops.identity(mu, name="mu") + + # Store the dimensionality of the MVNs + self._k = array_ops.gather(array_ops.shape(mu), array_ops.rank(mu) - 1) + + if sigma_chol is not None: + # Ensure we only keep the lower triangular part. + sigma_chol = array_ops.batch_matrix_band_part( + sigma_chol, num_lower=-1, num_upper=0) + sigma_det = _determinant_from_sigma_chol(sigma_chol) + with ops.control_dependencies([ + _assert_batch_positive_definite(sigma_chol)]): + self._sigma = math_ops.batch_matmul( + sigma_chol, sigma_chol, adj_y=True, name="sigma") + self._sigma_chol = array_ops.identity(sigma_chol, "sigma_chol") + self._sigma_det = array_ops.identity(sigma_det, "sigma_det") + self._mu = array_ops.identity(mu, "mu") + else: # sigma is not None + sigma_chol = linalg_ops.batch_cholesky(sigma) + sigma_det = _determinant_from_sigma_chol(sigma_chol) + # batch_cholesky checks for PSD; so we can just use it here. 
+ with ops.control_dependencies([sigma_chol]): + self._sigma = array_ops.identity(sigma, "sigma") + self._sigma_chol = array_ops.identity(sigma_chol, "sigma_chol") + self._sigma_det = array_ops.identity(sigma_det, "sigma_det") + self._mu = array_ops.identity(mu, "mu") + + @property + def dtype(self): + return self._mu.dtype + + @property + def mu(self): + return self._mu + + @property + def sigma(self): + return self._sigma + + @property + def mean(self): + return self._mu + + @property + def sigma_det(self): + return self._sigma_det + + def log_pdf(self, x, name=None): + """Log pdf of observations `x` given these Multivariate Normals. + + Args: + x: tensor of dtype `dtype`, must be broadcastable with `mu`. + name: The name to give this op. + + Returns: + log_pdf: tensor of dtype `dtype`, the log-PDFs of `x`. + """ + with ops.op_scope( + [self._mu, self._sigma_chol, x], name, "MultivariateNormalLogPdf"): + x = ops.convert_to_tensor(x) + contrib_tensor_util.assert_same_float_dtype((self._mu, x)) + + x_centered = x - self.mu + + x_rank = array_ops.rank(x_centered) + sigma_rank = array_ops.rank(self._sigma_chol) + + x_rank_vec = array_ops.pack([x_rank]) + sigma_rank_vec = array_ops.pack([sigma_rank]) + x_shape = array_ops.shape(x_centered) + + # sigma_chol is shaped [D, E, F, ..., k, k] + # x_centered shape is one of: + # [D, E, F, ..., k], or [F, ..., k], or + # [A, B, C, D, E, F, ..., k] + # and we need to convert x_centered to shape: + # [D, E, F, ..., k, A*B*C] (or 1 if A, B, C don't exist) + # then transpose and reshape x_whitened back to one of the shapes: + # [D, E, F, ..., k], or [1, 1, F, ..., k], or + # [A, B, C, D, E, F, ..., k] + + # This helper handles the case where rank(x_centered) < rank(sigma) + def _broadcast_x_not_higher_rank_than_sigma(): + return array_ops.reshape( + x_centered, + array_ops.concat( + # Reshape to ones(deficient x rank) + x_shape + [1] + 0, (array_ops.ones(array_ops.pack([sigma_rank - x_rank - 1]), + dtype=x_rank.dtype), + x_shape, + [1]))) + + # These helpers handle the case where rank(x_centered) >= rank(sigma) + def _broadcast_x_higher_rank_than_sigma(): + x_shape_left = array_ops.slice( + x_shape, [0], sigma_rank_vec - 1) + x_shape_right = array_ops.slice( + x_shape, sigma_rank_vec - 1, x_rank_vec - 1) + x_shape_perm = array_ops.concat( + 0, (math_ops.range(sigma_rank - 1, x_rank), + math_ops.range(0, sigma_rank - 1))) + return array_ops.reshape( + # Convert to [D, E, F, ..., k, B, C] + array_ops.transpose( + x_centered, perm=x_shape_perm), + # Reshape to [D, E, F, ..., k, B*C] + array_ops.concat( + 0, (x_shape_right, + array_ops.pack([ + math_ops.reduce_prod(x_shape_left, 0)])))) + + def _unbroadcast_x_higher_rank_than_sigma(): + x_shape_left = array_ops.slice( + x_shape, [0], sigma_rank_vec - 1) + x_shape_right = array_ops.slice( + x_shape, sigma_rank_vec - 1, x_rank_vec - 1) + x_shape_perm = array_ops.concat( + 0, (math_ops.range(sigma_rank - 1, x_rank), + math_ops.range(0, sigma_rank - 1))) + return array_ops.transpose( + # [D, E, F, ..., k, B, C] => [B, C, D, E, F, ..., k] + array_ops.reshape( + # convert to [D, E, F, ..., k, B, C] + x_whitened_broadcast, + array_ops.concat(0, (x_shape_right, x_shape_left))), + perm=x_shape_perm) + + # Step 1: reshape x_centered + x_centered_broadcast = control_flow_ops.cond( + # x_centered == [D, E, F, ..., k] => [D, E, F, ..., k, 1] + # or == [F, ..., k] => [1, 1, F, ..., k, 1] + x_rank <= sigma_rank - 1, + _broadcast_x_not_higher_rank_than_sigma, + # x_centered == [B, C, D, E, F, ..., k] => [D, E, F, ..., k, 
B*C] + _broadcast_x_higher_rank_than_sigma) + + x_whitened_broadcast = linalg_ops.batch_matrix_triangular_solve( + self._sigma_chol, x_centered_broadcast) + + # Reshape x_whitened_broadcast back to x_whitened + x_whitened = control_flow_ops.cond( + x_rank <= sigma_rank - 1, + lambda: array_ops.reshape(x_whitened_broadcast, x_shape), + _unbroadcast_x_higher_rank_than_sigma) + + x_whitened = array_ops.expand_dims(x_whitened, -1) + # Reshape x_whitened to contain row vectors + # Returns a batchwise scalar + x_whitened_norm = math_ops.batch_matmul( + x_whitened, x_whitened, adj_x=True) + x_whitened_norm = control_flow_ops.cond( + x_rank <= sigma_rank - 1, + lambda: array_ops.squeeze(x_whitened_norm, [-2, -1]), + lambda: array_ops.squeeze(x_whitened_norm, [-1])) + + log_two_pi = constant_op.constant(math.log(2 * math.pi), dtype=self.dtype) + k = math_ops.cast(self._k, self.dtype) + log_pdf_value = ( + -math_ops.log(self._sigma_det) -k * log_two_pi - x_whitened_norm) / 2 + final_shaped_value = control_flow_ops.cond( + x_rank <= sigma_rank - 1, + lambda: log_pdf_value, + lambda: array_ops.squeeze(log_pdf_value, [-1])) + + output_static_shape = x_centered.get_shape()[:-1] + final_shaped_value.set_shape(output_static_shape) + return final_shaped_value + + def pdf(self, x, name=None): + """The PDF of observations `x` under these Multivariate Normals. + + Args: + x: tensor of dtype `dtype`, must be broadcastable with `mu` and `sigma`. + name: The name to give this op. + + Returns: + pdf: tensor of dtype `dtype`, the pdf values of `x`. + """ + with ops.op_scope( + [self._mu, self._sigma_chol, x], name, "MultivariateNormalPdf"): + return math_ops.exp(self.log_pdf(x)) + + def entropy(self, name=None): + """The entropies of these Multivariate Normals. + + Args: + name: The name to give this op. + + Returns: + entropy: tensor of dtype `dtype`, the entropies. + """ + with ops.op_scope( + [self._mu, self._sigma_chol], name, "MultivariateNormalEntropy"): + one_plus_log_two_pi = constant_op.constant( + 1 + math.log(2 * math.pi), dtype=self.dtype) + + # Use broadcasting rules to calculate the full broadcast sigma. + k = math_ops.cast(self._k, dtype=self.dtype) + entropy_value = ( + k * one_plus_log_two_pi + math_ops.log(self._sigma_det)) / 2 + entropy_value.set_shape(self._sigma_det.get_shape()) + return entropy_value + + def sample(self, n, seed=None, name=None): + """Sample `n` observations from the Multivariate Normal Distributions. + + Args: + n: `Scalar`, type int32, the number of observations to sample. + seed: Python integer, the random seed. + name: The name to give this op. + + Returns: + samples: `[n, ...]`, a `Tensor` of `n` samples for each + of the distributions determined by broadcasting the hyperparameters. + """ + with ops.op_scope( + [self._mu, self._sigma_chol, n], name, "MultivariateNormalSample"): + # TODO(ebrevdo): Is there a better way to get broadcast_shape? 
+ broadcast_shape = self.mu.get_shape() + n = ops.convert_to_tensor(n) + sigma_shape_left = array_ops.slice( + array_ops.shape(self._sigma_chol), + [0], array_ops.pack([array_ops.rank(self._sigma_chol) - 2])) + + k_n = array_ops.pack([self._k, n]) + shape = array_ops.concat(0, [sigma_shape_left, k_n]) + white_samples = random_ops.random_normal( + shape=shape, mean=0, stddev=1, dtype=self._mu.dtype, seed=seed) + + correlated_samples = math_ops.batch_matmul( + self._sigma_chol, white_samples) + + # Move the last dimension to the front + perm = array_ops.concat( + 0, + (array_ops.pack([array_ops.rank(correlated_samples) - 1]), + math_ops.range(0, array_ops.rank(correlated_samples) - 1))) + + # TODO(ebrevdo): Once we get a proper tensor contraction op, + # perform the inner product using that instead of batch_matmul + # and this slow transpose can go away! + correlated_samples = array_ops.transpose(correlated_samples, perm) + + samples = correlated_samples + self.mu + + # Provide some hints to shape inference + n_val = tensor_util.constant_value(n) + final_shape = tensor_shape.vector(n_val).concatenate(broadcast_shape) + samples.set_shape(final_shape) + + return samples diff --git a/tensorflow/core/kernels/batch_matmul_op.cc b/tensorflow/core/kernels/batch_matmul_op.cc index f5a64e1f46e..922e9f63de5 100644 --- a/tensorflow/core/kernels/batch_matmul_op.cc +++ b/tensorflow/core/kernels/batch_matmul_op.cc @@ -234,8 +234,8 @@ class BatchMatMul : public OpKernel { in1.shape().DebugString())); const int ndims = in0.dims(); OP_REQUIRES( - ctx, ndims >= 3, - errors::InvalidArgument("In[0] and In[1] ndims must be >= 3: ", ndims)); + ctx, ndims >= 2, + errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ", ndims)); TensorShape out_shape; for (int i = 0; i < ndims - 2; ++i) { OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i), @@ -245,7 +245,7 @@ class BatchMatMul : public OpKernel { in1.shape().DebugString())); out_shape.AddDim(in0.dim_size(i)); } - auto n = out_shape.num_elements(); + auto n = (ndims == 2) ? 1 : out_shape.num_elements(); auto d0 = in0.dim_size(ndims - 2); auto d1 = in0.dim_size(ndims - 1); Tensor in0_reshaped; diff --git a/tensorflow/python/kernel_tests/cholesky_op_test.py b/tensorflow/python/kernel_tests/cholesky_op_test.py index c82a4249fc4..199b54512e0 100644 --- a/tensorflow/python/kernel_tests/cholesky_op_test.py +++ b/tensorflow/python/kernel_tests/cholesky_op_test.py @@ -25,19 +25,8 @@ import tensorflow as tf class CholeskyOpTest(tf.test.TestCase): - def _verifyCholesky(self, x): - with self.test_session() as sess: - # Verify that LL^T == x. - if x.ndim == 2: - chol = tf.cholesky(x) - verification = tf.matmul(chol, - chol, - transpose_a=False, - transpose_b=True) - else: - chol = tf.batch_cholesky(x) - verification = tf.batch_matmul(chol, chol, adj_x=False, adj_y=True) - chol_np, verification_np = sess.run([chol, verification]) + def _verifyCholeskyBase(self, sess, x, chol, verification): + chol_np, verification_np = sess.run([chol, verification]) self.assertAllClose(x, verification_np) self.assertShapeEqual(x, chol) # Check that the cholesky is lower triangular, and has positive diagonal @@ -49,6 +38,20 @@ class CholeskyOpTest(tf.test.TestCase): self.assertAllClose(chol_matrix, np.tril(chol_matrix)) self.assertTrue((np.diag(chol_matrix) > 0.0).all()) + def _verifyCholesky(self, x): + # Verify that LL^T == x. + with self.test_session() as sess: + # Check the batch version, which works for ndim >= 2. 
+ chol = tf.batch_cholesky(x) + verification = tf.batch_matmul(chol, chol, adj_x=False, adj_y=True) + self._verifyCholeskyBase(sess, x, chol, verification) + + if x.ndim == 2: # Check the simple form of cholesky + chol = tf.cholesky(x) + verification = tf.matmul( + chol, chol, transpose_a=False, transpose_b=True) + self._verifyCholeskyBase(sess, x, chol, verification) + def testBasic(self): self._verifyCholesky(np.array([[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]])) diff --git a/tensorflow/python/kernel_tests/determinant_op_test.py b/tensorflow/python/kernel_tests/determinant_op_test.py index 4355da8a05e..779d924ecf9 100644 --- a/tensorflow/python/kernel_tests/determinant_op_test.py +++ b/tensorflow/python/kernel_tests/determinant_op_test.py @@ -24,13 +24,8 @@ import tensorflow as tf class DeterminantOpTest(tf.test.TestCase): - def _compareDeterminant(self, matrix_x): - with self.test_session(): - if matrix_x.ndim == 2: - tf_ans = tf.matrix_determinant(matrix_x) - else: - tf_ans = tf.batch_matrix_determinant(matrix_x) - out = tf_ans.eval() + def _compareDeterminantBase(self, matrix_x, tf_ans): + out = tf_ans.eval() shape = matrix_x.shape if shape[-1] == 0 and shape[-2] == 0: np_ans = np.ones(shape[:-2]).astype(matrix_x.dtype) @@ -39,6 +34,15 @@ class DeterminantOpTest(tf.test.TestCase): self.assertAllClose(np_ans, out) self.assertShapeEqual(np_ans, tf_ans) + def _compareDeterminant(self, matrix_x): + with self.test_session(): + # Check the batch version, which should work for ndim >= 2 + self._compareDeterminantBase( + matrix_x, tf.batch_matrix_determinant(matrix_x)) + if matrix_x.ndim == 2: + # Check the simple version + self._compareDeterminantBase(matrix_x, tf.matrix_determinant(matrix_x)) + def testBasic(self): # 2x2 matrices self._compareDeterminant(np.array([[2., 3.], [3., 4.]]).astype(np.float32)) diff --git a/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py b/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py index 32e49328c16..d04020eac1d 100644 --- a/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py +++ b/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py @@ -67,11 +67,13 @@ class MatrixSolveLsOpTest(tf.test.TestCase): np_ans, _, _, _ = np.linalg.lstsq(a, b) for fast in [True, False]: with self.test_session(): - tf_ans = tf.matrix_solve_ls(a, b, fast=fast).eval() - self.assertEqual(np_ans.shape, tf_ans.shape) + tf_ans = tf.matrix_solve_ls(a, b, fast=fast) + ans = tf_ans.eval() + self.assertEqual(np_ans.shape, tf_ans.get_shape()) + self.assertEqual(np_ans.shape, ans.shape) # Check residual norm. - tf_r = b - BatchMatMul(a, tf_ans) + tf_r = b - BatchMatMul(a, ans) tf_r_norm = np.sum(tf_r * tf_r) np_r = b - BatchMatMul(a, np_ans) np_r_norm = np.sum(np_r * np_r) @@ -83,7 +85,7 @@ class MatrixSolveLsOpTest(tf.test.TestCase): # slow path, because Eigen does not return a minimum norm solution. # TODO(rmlarsen): Enable this check for all paths if/when we fix # Eigen's solver. 
- self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5) + self.assertAllClose(np_ans, ans, atol=1e-5, rtol=1e-5) def _verifySolveBatch(self, x, y): # Since numpy.linalg.lsqr does not support batch solves, as opposed @@ -122,20 +124,23 @@ class MatrixSolveLsOpTest(tf.test.TestCase): b = y.astype(np_type) np_ans = BatchRegularizedLeastSquares(a, b, l2_regularizer) with self.test_session(): - tf_ans = tf.matrix_solve_ls(a, - b, - l2_regularizer=l2_regularizer, - fast=True).eval() - self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5) + # Test with the batch version of matrix_solve_ls on regular matrices + tf_ans = tf.batch_matrix_solve_ls( + a, b, l2_regularizer=l2_regularizer, fast=True).eval() + self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5) + + # Test with the simple matrix_solve_ls on regular matrices + tf_ans = tf.matrix_solve_ls( + a, b, l2_regularizer=l2_regularizer, fast=True).eval() + self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5) + # Test with a 2x3 batch of matrices. a = np.tile(x.astype(np_type), [2, 3, 1, 1]) b = np.tile(y.astype(np_type), [2, 3, 1, 1]) np_ans = BatchRegularizedLeastSquares(a, b, l2_regularizer) with self.test_session(): - tf_ans = tf.batch_matrix_solve_ls(a, - b, - l2_regularizer=l2_regularizer, - fast=True).eval() + tf_ans = tf.batch_matrix_solve_ls( + a, b, l2_regularizer=l2_regularizer, fast=True).eval() self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5) def testSquare(self): diff --git a/tensorflow/python/kernel_tests/matrix_solve_op_test.py b/tensorflow/python/kernel_tests/matrix_solve_op_test.py index cffdf4e6884..a08d0f27501 100644 --- a/tensorflow/python/kernel_tests/matrix_solve_op_test.py +++ b/tensorflow/python/kernel_tests/matrix_solve_op_test.py @@ -37,15 +37,23 @@ class MatrixSolveOpTest(tf.test.TestCase): a = np.tile(a, batch_dims + [1, 1]) a_np = np.tile(a_np, batch_dims + [1, 1]) b = np.tile(b, batch_dims + [1, 1]) - with self.test_session(): - if a.ndim == 2: - tf_ans = tf.matrix_solve(a, b, adjoint=adjoint) - else: - tf_ans = tf.batch_matrix_solve(a, b, adjoint=adjoint) - out = tf_ans.eval() + np_ans = np.linalg.solve(a_np, b) - self.assertEqual(np_ans.shape, out.shape) - self.assertAllClose(np_ans, out) + with self.test_session(): + # Test the batch version, which works for ndim >= 2 + tf_ans = tf.batch_matrix_solve(a, b, adjoint=adjoint) + out = tf_ans.eval() + self.assertEqual(tf_ans.get_shape(), out.shape) + self.assertEqual(np_ans.shape, out.shape) + self.assertAllClose(np_ans, out) + + if a.ndim == 2: + # Test the simple version + tf_ans = tf.matrix_solve(a, b, adjoint=adjoint) + out = tf_ans.eval() + self.assertEqual(out.shape, tf_ans.get_shape()) + self.assertEqual(np_ans.shape, out.shape) + self.assertAllClose(np_ans, out) def testSolve(self): # 2x2 matrices, 2x1 right-hand side. 
diff --git a/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py b/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py index f4637fa628f..fba393d599a 100644 --- a/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py +++ b/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py @@ -51,20 +51,27 @@ class MatrixTriangularSolveOpTest(tf.test.TestCase): a = np.tile(a, batch_dims + [1, 1]) a_np = np.tile(a_np, batch_dims + [1, 1]) b = np.tile(b, batch_dims + [1, 1]) + with self.test_session(): + # Test the batch version, which works for ndim >= 2 + tf_ans = tf.batch_matrix_triangular_solve( + a, b, lower=lower, adjoint=adjoint) + out = tf_ans.eval() + + np_ans = np.linalg.solve(a_np, b) + + self.assertEqual(np_ans.shape, tf_ans.get_shape()) + self.assertEqual(np_ans.shape, out.shape) + self.assertAllClose(np_ans, out) + if a.ndim == 2: - tf_ans = tf.matrix_triangular_solve(a, - b, - lower=lower, - adjoint=adjoint).eval() - else: - tf_ans = tf.batch_matrix_triangular_solve(a, - b, - lower=lower, - adjoint=adjoint).eval() - np_ans = np.linalg.solve(a_np, b) - self.assertEqual(np_ans.shape, tf_ans.shape) - self.assertAllClose(np_ans, tf_ans) + # Test the simple version + tf_ans = tf.matrix_triangular_solve( + a, b, lower=lower, adjoint=adjoint) + out = tf_ans.eval() + self.assertEqual(np_ans.shape, tf_ans.get_shape()) + self.assertEqual(np_ans.shape, out.shape) + self.assertAllClose(np_ans, out) def testSolve(self): # 2x2 matrices, single right-hand side. diff --git a/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py b/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py index e2c385c9dd7..d955ee1ad5e 100644 --- a/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py +++ b/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py @@ -71,14 +71,28 @@ class SelfAdjointEigOpTest(tf.test.TestCase): for i in xrange(dlist[0]): self._testEigs(x[i], d, tf_out[i]) + def _compareBatchSelfAdjointEigRank2(self, x, use_gpu=False): + with self.test_session() as sess: + tf_eig = tf.batch_self_adjoint_eig(tf.constant(x)) + tf_out = sess.run([tf_eig])[0] + dlist = x.shape + d = dlist[-2] + + self.assertEqual(len(tf_eig.get_shape()), 2) + self.assertEqual([d+1, d], tf_eig.get_shape().dims[-2:]) + self._testEigs(x, d, tf_out) + def testBasic(self): self._compareSelfAdjointEig( np.array([[3., 0., 1.], [0., 2., -2.], [1., -2., 3.]])) def testBatch(self): simple_array = np.array([[[1., 0.], [0., 5.]]]) # shape (1, 2, 2) + simple_array_2d = simple_array[0] # shape (2, 2) self._compareBatchSelfAdjointEigRank3(simple_array) - self._compareBatchSelfAdjointEigRank3(np.vstack((simple_array, simple_array))) + self._compareBatchSelfAdjointEigRank3( + np.vstack((simple_array, simple_array))) + self._compareBatchSelfAdjointEigRank2(simple_array_2d) odd_sized_array = np.array([[[3., 0., 1.], [0., 2., -2.], [1., -2., 3.]]]) self._compareBatchSelfAdjointEigRank3( np.vstack((odd_sized_array, odd_sized_array))) diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py index 58bddb0b672..31fc2b28768 100644 --- a/tensorflow/python/ops/linalg_ops.py +++ b/tensorflow/python/ops/linalg_ops.py @@ -39,7 +39,7 @@ def _UnchangedSquare(op): @ops.RegisterShape("BatchCholesky") @ops.RegisterShape("BatchMatrixInverse") def _BatchUnchangedSquare(op): - input_shape = op.inputs[0].get_shape().with_rank_at_least(3) + input_shape = op.inputs[0].get_shape().with_rank_at_least(2) # The matrices in the batch must be square. 
input_shape[-1].assert_is_compatible_with(input_shape[-2]) return [input_shape] @@ -61,7 +61,7 @@ def _MatrixDeterminantShape(op): @ops.RegisterShape("BatchMatrixDeterminant") def _BatchMatrixDeterminantShape(op): - input_shape = op.inputs[0].get_shape().with_rank_at_least(3) + input_shape = op.inputs[0].get_shape().with_rank_at_least(2) # The matrices in the batch must be square. input_shape[-1].assert_is_compatible_with(input_shape[-2]) if input_shape.ndims is not None: @@ -82,7 +82,7 @@ def _SelfAdjointEigShape(op): @ops.RegisterShape("BatchSelfAdjointEig") def _BatchSelfAdjointEigShape(op): - input_shape = op.inputs[0].get_shape().with_rank_at_least(3) + input_shape = op.inputs[0].get_shape().with_rank_at_least(2) # The matrices in the batch must be square. input_shape[-1].assert_is_compatible_with(input_shape[-2]) dlist = input_shape.dims @@ -106,8 +106,8 @@ def _SquareMatrixSolveShape(op): @ops.RegisterShape("BatchMatrixSolve") @ops.RegisterShape("BatchMatrixTriangularSolve") def _BatchSquareMatrixSolveShape(op): - lhs_shape = op.inputs[0].get_shape().with_rank_at_least(3) - rhs_shape = op.inputs[1].get_shape().with_rank_at_least(3) + lhs_shape = op.inputs[0].get_shape().with_rank_at_least(2) + rhs_shape = op.inputs[1].get_shape().with_rank_at_least(2) # The matrices must be square. lhs_shape[-1].assert_is_compatible_with(lhs_shape[-2]) # The matrices and right-hand sides in the batch must have the same number of @@ -127,8 +127,8 @@ def _MatrixSolveLsShape(op): @ops.RegisterShape("BatchMatrixSolveLs") def _BatchMatrixSolveLsShape(op): - lhs_shape = op.inputs[0].get_shape().with_rank_at_least(3) - rhs_shape = op.inputs[1].get_shape().with_rank_at_least(3) + lhs_shape = op.inputs[0].get_shape().with_rank_at_least(2) + rhs_shape = op.inputs[1].get_shape().with_rank_at_least(2) # The matrices and right-hand sides in the batch must have the same number of # rows. lhs_shape[-2].assert_is_compatible_with(rhs_shape[-2])
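
The new `mvn_test.py` cases above double as usage documentation for the class this patch adds. Below is a minimal end-to-end sketch of that API, assuming a build that includes this patch and the contrib-era graph-mode workflow the tests themselves use (`tf.Session`, `Tensor.eval()`), plus `numpy` and `scipy`; the constants are taken from `testLogPDFScalar`, and the sample size and seed are arbitrary.

```python
# Minimal usage sketch of tf.contrib.distributions.MultivariateNormal,
# mirroring testLogPDFScalar / testEntropy / testSample above.
import numpy as np
from scipy import stats
import tensorflow as tf

mu_v = np.array([-3.0, 3.0], dtype=np.float32)
sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
x = np.array([-2.5, 2.5], dtype=np.float32)

with tf.Session():
  mvn = tf.contrib.distributions.MultivariateNormal(
      mu=tf.constant(mu_v), sigma=tf.constant(sigma_v))
  log_pdf = mvn.log_pdf(x).eval()             # scalar log density at x
  entropy = mvn.entropy().eval()              # scalar
  samples = mvn.sample(10000, seed=42).eval() # shape (10000, 2)

assert samples.shape == (10000, 2)
# scipy agrees on the analytic quantities, up to float32 precision:
scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
assert np.allclose(log_pdf, scipy_mvn.logpdf(x), atol=1e-5)
assert np.allclose(entropy, scipy_mvn.entropy(), atol=1e-5)
```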
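The `MultivariateNormal` docstring in `mvn.py` justifies computing everything from the Cholesky factor: `|det(sigma)| = reduce_prod(diag(sigma_chol))^2`, and the quadratic form `(x-mu)^* . sigma^{-1} . (x-mu)` equals `x_whitened^* . x_whitened` where `x_whitened = tri_solve(sigma_chol, x - mu)`. Here is a self-contained NumPy/SciPy check of those two identities (random positive-definite matrix; seed and sizes arbitrary), assembling the log density the same way `log_pdf()` does:

```python
# Verify the two Cholesky identities log_pdf() relies on, then build the
# log density from them and compare against scipy's reference value.
import numpy as np
from scipy import stats
from scipy.linalg import solve_triangular

rng = np.random.RandomState(0)
a = rng.randn(3, 3)
sigma = a.dot(a.T) + 3.0 * np.eye(3)  # positive definite by construction
mu = rng.randn(3)
x = rng.randn(3)

chol = np.linalg.cholesky(sigma)      # lower triangular; sigma = chol . chol^T
det_via_chol = np.prod(np.diag(chol)) ** 2
assert np.allclose(det_via_chol, np.linalg.det(sigma))

x_whitened = solve_triangular(chol, x - mu, lower=True)
quad = (x - mu).dot(np.linalg.solve(sigma, x - mu))
assert np.allclose(x_whitened.dot(x_whitened), quad)

k = sigma.shape[-1]
log_pdf = -0.5 * (np.log(det_via_chol) + k * np.log(2 * np.pi)
                  + x_whitened.dot(x_whitened))
assert np.allclose(
    log_pdf, stats.multivariate_normal(mean=mu, cov=sigma).logpdf(x))
```

The triangular solve is what makes the whitening cheap: no explicit inverse of `sigma` is ever formed, which is the same design choice `log_pdf()` makes via `batch_matrix_triangular_solve`.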
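`sample()` uses the same factor in the reverse direction: draw standard normals shaped `[..., k, n]`, `batch_matmul` them with `sigma_chol`, transpose `n` to the front, and add `mu`. A self-contained NumPy sketch of that whiten-then-correlate recipe follows; the sample size is arbitrary and the tolerances are loose Monte Carlo bounds, chosen like the `atol` values in `testSample` above.

```python
# Whiten-then-correlate sampling, as in MultivariateNormal.sample():
# correlated = sigma_chol . white + mu has mean mu and covariance sigma.
import numpy as np

rng = np.random.RandomState(0)
a = rng.randn(2, 2)
sigma = a.dot(a.T) + 2.0 * np.eye(2)  # positive definite
mu = np.array([-3.0, 3.0])
chol = np.linalg.cholesky(sigma)

n = 200000
white = rng.randn(2, n)               # k x n standard normals
correlated = chol.dot(white).T + mu   # n x k correlated samples
assert np.allclose(correlated.mean(axis=0), mu, atol=3e-2)
assert np.allclose(np.cov(correlated, rowvar=False), sigma, atol=1e-1)
```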