Add MultivariateNormal to tf.contrib.distributions.

Also fix overly stringent constraints on batchwise linalg ops & batch_matmul.
Change: 120279428
Authored by Eugene Brevdo on 2016-04-19 14:48:31 -08:00; committed by TensorFlower Gardener
parent 3198d7ef30
commit 912ab39d93
13 changed files with 805 additions and 69 deletions
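
At a glance, the new class follows the existing Gaussian pattern: construct from mu plus either sigma or its Cholesky factor, then query log_pdf, entropy, and sample. A minimal usage sketch against the API added below (the constant values are illustrative, not taken from the commit):

```python
import numpy as np
import tensorflow as tf

mu = tf.constant([1.0, 2.0])
sigma = tf.constant([[1.0, 0.5],
                     [0.5, 2.0]])  # must be positive definite
mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)

with tf.Session() as sess:
    log_p, ent, draws = sess.run(
        [mvn.log_pdf(np.array([1.5, 1.0], dtype=np.float32)),
         mvn.entropy(),
         mvn.sample(4, seed=0)])
    print(log_p, ent, draws.shape)  # scalar, scalar, (4, 2)
```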

View File

@@ -32,6 +32,19 @@ cuda_py_tests(
srcs = ["python/kernel_tests/gaussian_test.py"],
additional_deps = [
":distributions_py",
"//third_party/py/scipy",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
cuda_py_tests(
name = "mvn_test",
size = "small",
srcs = ["python/kernel_tests/mvn_test.py"],
additional_deps = [
":distributions_py",
"//third_party/py/scipy",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
@@ -43,6 +56,7 @@ cuda_py_tests(
srcs = ["python/kernel_tests/gaussian_conjugate_posteriors_test.py"],
additional_deps = [
":distributions_py",
"//third_party/py/scipy",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],

View File

@@ -21,8 +21,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,wildcard-import, line-too-long
# pylint: disable=unused-import,wildcard-import,line-too-long
from tensorflow.contrib.distributions.python.ops import gaussian_conjugate_posteriors
from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import *
from tensorflow.contrib.distributions.python.ops.gaussian import *
# from tensorflow.contrib.distributions.python.ops.dirichlet import * # pylint: disable=line-too-long
from tensorflow.contrib.distributions.python.ops.mvn import *

View File

@@ -0,0 +1,252 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for MultivariateNormal."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from scipy import stats
import tensorflow as tf
class MultivariateNormalTest(tf.test.TestCase):
def testNonmatchingMuSigmaFails(self):
with tf.Session():
mvn = tf.contrib.distributions.MultivariateNormal(
mu=[1.0, 2.0],
sigma=[[[1.0, 0.0],
[0.0, 1.0]],
[[1.0, 0.0],
[0.0, 1.0]]])
with self.assertRaisesOpError(
r"Rank of mu should be one less than rank of sigma"):
mvn.mean.eval()
mvn = tf.contrib.distributions.MultivariateNormal(
mu=[[1.0], [2.0]],
sigma=[[[1.0, 0.0],
[0.0, 1.0]],
[[1.0, 0.0],
[0.0, 1.0]]])
with self.assertRaisesOpError(
r"mu.shape and sigma.shape\[\:-1\] must match"):
mvn.mean.eval()
def testNotPositiveDefiniteSigmaFails(self):
with tf.Session():
mvn = tf.contrib.distributions.MultivariateNormal(
mu=[[1.0, 2.0], [1.0, 2.0]],
sigma=[[[1.0, 0.0],
[0.0, 1.0]],
[[1.0, 1.0],
[1.0, 1.0]]])
with self.assertRaisesOpError(
r"LLT decomposition was not successful."):
mvn.mean.eval()
mvn = tf.contrib.distributions.MultivariateNormal(
mu=[[1.0, 2.0], [1.0, 2.0]],
sigma=[[[1.0, 0.0],
[0.0, 1.0]],
[[-1.0, 0.0],
[0.0, 1.0]]])
with self.assertRaisesOpError(
r"LLT decomposition was not successful."):
mvn.mean.eval()
mvn = tf.contrib.distributions.MultivariateNormal(
mu=[[1.0, 2.0], [1.0, 2.0]],
sigma_chol=[[[1.0, 0.0],
[0.0, 1.0]],
[[-1.0, 0.0],
[0.0, 1.0]]])
with self.assertRaisesOpError(
r"sigma_chol is not positive definite."):
mvn.mean.eval()
def testLogPDFScalar(self):
with tf.Session():
mu_v = np.array([-3.0, 3.0], dtype=np.float32)
mu = tf.constant(mu_v)
sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
sigma = tf.constant(sigma_v)
x = np.array([-2.5, 2.5], dtype=np.float32)
mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
log_pdf = mvn.log_pdf(x)
scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
expected_log_pdf = scipy_mvn.logpdf(x)
expected_pdf = scipy_mvn.pdf(x)
self.assertAllClose(expected_log_pdf, log_pdf.eval())
pdf = mvn.pdf(x)
self.assertAllClose(expected_pdf, pdf.eval())
def testLogPDFScalarSigmaHalf(self):
with tf.Session():
mu_v = np.array([-3.0, 3.0, 1.0], dtype=np.float32)
mu = tf.constant(mu_v)
sigma_v = np.array([[1.0, 0.1, 0.2],
[0.1, 2.0, 0.05],
[0.2, 0.05, 3.0]], dtype=np.float32)
sigma_chol_v = np.linalg.cholesky(sigma_v)
sigma_chol = tf.constant(sigma_chol_v)
x = np.array([-2.5, 2.5, 1.0], dtype=np.float32)
mvn = tf.contrib.distributions.MultivariateNormal(
mu=mu, sigma_chol=sigma_chol)
log_pdf = mvn.log_pdf(x)
sigma = mvn.sigma
scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
expected_log_pdf = scipy_mvn.logpdf(x)
expected_pdf = scipy_mvn.pdf(x)
self.assertEqual(sigma.get_shape(), (3, 3))
self.assertAllClose(sigma_v, sigma.eval())
self.assertAllClose(expected_log_pdf, log_pdf.eval())
pdf = mvn.pdf(x)
self.assertAllClose(expected_pdf, pdf.eval())
def testLogPDF(self):
with tf.Session():
mu_v = np.array([-3.0, 3.0], dtype=np.float32)
mu = tf.constant(mu_v)
sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
sigma = tf.constant(sigma_v)
x = np.array([[-2.5, 2.5], [4.0, 0.0], [-1.0, 2.0]], dtype=np.float32)
mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
log_pdf = mvn.log_pdf(x)
scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
expected_log_pdf = scipy_mvn.logpdf(x)
expected_pdf = scipy_mvn.pdf(x)
self.assertEqual(log_pdf.get_shape(), (3,))
self.assertAllClose(expected_log_pdf, log_pdf.eval())
pdf = mvn.pdf(x)
self.assertAllClose(expected_pdf, pdf.eval())
def testLogPDFMatchingDimension(self):
with tf.Session():
mu_v = np.array([-3.0, 3.0], dtype=np.float32)
mu = tf.constant(np.vstack(3 * [mu_v]))
sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
sigma = tf.constant(np.vstack(3 * [sigma_v[np.newaxis, :]]))
x = np.array([[-2.5, 2.5], [4.0, 0.0], [-1.0, 2.0]], dtype=np.float32)
mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
log_pdf = mvn.log_pdf(x)
scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
expected_log_pdf = scipy_mvn.logpdf(x)
expected_pdf = scipy_mvn.pdf(x)
self.assertEqual(log_pdf.get_shape(), (3,))
self.assertAllClose(expected_log_pdf, log_pdf.eval())
pdf = mvn.pdf(x)
self.assertAllClose(expected_pdf, pdf.eval())
def testLogPDFMultidimensional(self):
with tf.Session():
mu_v = np.array([-3.0, 3.0], dtype=np.float32)
mu = tf.constant(np.vstack(15 * [mu_v]).reshape(3, 5, 2))
sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
sigma = tf.constant(
np.vstack(15 * [sigma_v[np.newaxis, :]]).reshape(3, 5, 2, 2))
x = np.array([-2.5, 2.5], dtype=np.float32)
mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
log_pdf = mvn.log_pdf(x)
scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
expected_log_pdf = np.vstack(15 * [scipy_mvn.logpdf(x)]).reshape(3, 5)
expected_pdf = np.vstack(15 * [scipy_mvn.pdf(x)]).reshape(3, 5)
self.assertEqual(log_pdf.get_shape(), (3, 5))
self.assertAllClose(expected_log_pdf, log_pdf.eval())
pdf = mvn.pdf(x)
self.assertAllClose(expected_pdf, pdf.eval())
def testEntropy(self):
with tf.Session():
mu_v = np.array([-3.0, 3.0], dtype=np.float32)
mu = tf.constant(mu_v)
sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
sigma = tf.constant(sigma_v)
mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
entropy = mvn.entropy()
scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
expected_entropy = scipy_mvn.entropy()
self.assertEqual(entropy.get_shape(), ())
self.assertAllClose(expected_entropy, entropy.eval())
def testEntropyMultidimensional(self):
with tf.Session():
mu_v = np.array([-3.0, 3.0], dtype=np.float32)
mu = tf.constant(np.vstack(15 * [mu_v]).reshape(3, 5, 2))
sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
sigma = tf.constant(
np.vstack(15 * [sigma_v[np.newaxis, :]]).reshape(3, 5, 2, 2))
mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
entropy = mvn.entropy()
scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
expected_entropy = np.vstack(15 * [scipy_mvn.entropy()]).reshape(3, 5)
self.assertEqual(entropy.get_shape(), (3, 5))
self.assertAllClose(expected_entropy, entropy.eval())
def testSample(self):
with tf.Session():
mu_v = np.array([-3.0, 3.0], dtype=np.float32)
mu = tf.constant(mu_v)
sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
sigma = tf.constant(sigma_v)
n = tf.constant(100000)
mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
samples = mvn.sample(n, seed=137)
sample_values = samples.eval()
self.assertEqual(samples.get_shape(), (100000, 2))
self.assertAllClose(sample_values.mean(axis=0), mu_v, atol=1e-2)
self.assertAllClose(np.cov(sample_values, rowvar=0), sigma_v, atol=1e-1)
def testSampleMultiDimensional(self):
with tf.Session():
mu_v = np.array([-3.0, 3.0], dtype=np.float32)
mu = tf.constant(np.vstack(15 * [mu_v]).reshape(3, 5, 2))
sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
sigma = tf.constant(
np.vstack(15 * [sigma_v[np.newaxis, :]]).reshape(3, 5, 2, 2))
n = tf.constant(100000)
mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
samples = mvn.sample(n, seed=137)
sample_values = samples.eval()
self.assertEqual(samples.get_shape(), (100000, 3, 5, 2))
sample_values = sample_values.reshape(100000, 15, 2)
for i in range(15):
self.assertAllClose(
sample_values[:, i, :].mean(axis=0), mu_v, atol=1e-2)
self.assertAllClose(
np.cov(sample_values[:, i, :], rowvar=0), sigma_v, atol=1e-1)
if __name__ == "__main__":
tf.test.main()
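
For reference, the closed form the entropy tests rely on is entropy = (k * (1 + log(2*pi)) + log(det(sigma))) / 2, matching entropy() in the mvn.py added later in this commit. A quick scipy cross-check, independent of TensorFlow, using the same sigma as the tests:

```python
import numpy as np
from scipy import stats

sigma = np.array([[1.0, 0.5],
                  [0.5, 1.0]])
k = sigma.shape[0]
entropy = (k * (1.0 + np.log(2.0 * np.pi)) + np.log(np.linalg.det(sigma))) / 2.0
# Entropy of a multivariate normal does not depend on the mean.
assert np.allclose(entropy, stats.multivariate_normal(cov=sigma).entropy())
```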

View File

@@ -88,7 +88,7 @@ class Gaussian(object):
@property
def mean(self):
return self._mu
return self._mu * array_ops.ones_like(self._sigma)
def log_pdf(self, x, name=None):
"""Log pdf of observations in `x` under these Gaussian distribution(s).
@@ -170,7 +170,7 @@ class Gaussian(object):
return 0.5 * math_ops.log(two_pi_e1 * math_ops.square(sigma))
def sample(self, n, seed=None, name=None):
"""Sample `n` observations the Gaussian Distributions.
"""Sample `n` observations from the Gaussian Distributions.
Args:
n: `Scalar`, type int32, the number of observations to sample.
@@ -185,7 +185,7 @@ class Gaussian(object):
broadcast_shape = (self._mu + self._sigma).get_shape()
n = ops.convert_to_tensor(n)
shape = array_ops.concat(
0, [array_ops.pack([n]), array_ops.shape(self._mu)])
0, [array_ops.pack([n]), array_ops.shape(self.mean)])
sampled = random_ops.random_normal(
shape=shape, mean=0, stddev=1, dtype=self._mu.dtype, seed=seed)
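
The sample-shape change above matters when mu and sigma broadcast to a larger shape: mean is now mu * ones_like(sigma), so the sample shape tracks the full broadcast rather than mu alone. A numpy sketch of the intended shape behavior (values here are illustrative):

```python
import numpy as np

mu = np.array([0.0, 1.0])            # shape (2,)
sigma = np.ones((3, 2))              # broadcasts mu to shape (3, 2)
mean = mu * np.ones_like(sigma)      # mirrors the new Gaussian.mean
n = 4
samples = np.random.normal(size=(n,) + mean.shape) * sigma + mean
print(samples.shape)                 # (4, 3, 2), not (4, 2)
```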

View File

@@ -0,0 +1,429 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Multivariate Normal distribution class.
@@MultivariateNormal
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util # pylint: disable=line-too-long
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
def _assert_compatible_shapes(mu, sigma):
r_mu = array_ops.rank(mu)
r_sigma = array_ops.rank(sigma)
sigma_shape = array_ops.shape(sigma)
sigma_rank = array_ops.rank(sigma)
mu_shape = array_ops.shape(mu)
return control_flow_ops.group(
logging_ops.Assert(
math_ops.equal(r_mu + 1, r_sigma),
["Rank of mu should be one less than rank of sigma, but saw: ",
r_mu, " vs. ", r_sigma]),
logging_ops.Assert(
math_ops.equal(
array_ops.gather(sigma_shape, sigma_rank - 2),
array_ops.gather(sigma_shape, sigma_rank - 1)),
["Last two dimensions of sigma (%s) must be equal: " % sigma.name,
sigma_shape]),
logging_ops.Assert(
math_ops.reduce_all(math_ops.equal(
mu_shape,
array_ops.slice(
sigma_shape, [0], array_ops.pack([sigma_rank - 1])))),
["mu.shape and sigma.shape[:-1] must match, but saw: ",
mu_shape, " vs. ", sigma_shape]))
def _assert_batch_positive_definite(sigma_chol):
"""Add assertions checking that the sigmas are all Positive Definite.
Given `sigma_chol == cholesky(sigma)`, it is sufficient to check that
`all(diag(sigma_chol) > 0)`. This is because to check that a matrix is PD,
it is sufficient that its cholesky factorization is PD, and to check that a
triangular matrix is PD, it is sufficient to check that its diagonal
entries are positive.
Args:
sigma_chol: N-D. The lower triangular cholesky decomposition of `sigma`.
Returns:
An assertion op to use with `control_dependencies`, verifying that
`sigma_chol` is positive definite.
"""
sigma_batch_diag = array_ops.batch_matrix_diag_part(sigma_chol)
return logging_ops.Assert(
math_ops.reduce_all(sigma_batch_diag > 0),
["sigma_chol is not positive definite. batched diagonals: ",
sigma_batch_diag, " shaped: ", array_ops.shape(sigma_batch_diag)])
def _determinant_from_sigma_chol(sigma_chol):
det_last_dim = array_ops.rank(sigma_chol) - 2
sigma_batch_diag = array_ops.batch_matrix_diag_part(sigma_chol)
det = math_ops.square(math_ops.reduce_prod(
sigma_batch_diag, reduction_indices=det_last_dim))
det.set_shape(sigma_chol.get_shape()[:-2])
return det
class MultivariateNormal(object):
"""The Multivariate Normal distribution on `R^k`.
The distribution has mean and covariance parameters mu (1-D), sigma (2-D),
or alternatively mean `mu` and factored covariance (cholesky decomposed
`sigma`) called `sigma_chol`.
The PDF of this distribution is:
```
f(x) = (2*pi)^(-k/2) |det(sigma)|^(-1/2) exp(-1/2*(x-mu)^*.sigma^{-1}.(x-mu))
```
where `.` denotes the inner product on `R^k` and `^*` denotes transpose.
Alternatively, if `sigma` is positive definite, it can be represented in terms
of its lower triangular cholesky factorization
```sigma = sigma_chol . sigma_chol^*```
and the pdf above allows simpler computation:
```
|det(sigma)| = reduce_prod(diag(sigma_chol))^2
x_whitened = sigma^{-1/2} . (x - mu) = tri_solve(sigma_chol, x - mu)
(x-mu)^* .sigma^{-1} . (x-mu) = x_whitened^* . x_whitened
```
where `tri_solve()` solves a triangular system of equations.
"""
def __init__(self, mu, sigma=None, sigma_chol=None, name=None):
"""Multivariate Normal distributions on `R^k`.
User must provide means `mu`, which are tensors of rank `N+1` (`N >= 0`)
with the last dimension having length `k`.
User must provide exactly one of `sigma` (the covariance matrices) or
`sigma_chol` (the cholesky decompositions of the covariance matrices).
`sigma` or `sigma_chol` must be of rank `N+2`. The last two dimensions
must both have length `k`. The first `N` dimensions correspond to batch
indices.
If `sigma_chol` is not provided, the batch cholesky factorization of `sigma`
is calculated for you.
The shapes of `mu` and `sigma` must match for the first `N` dimensions.
Regardless of which parameter is provided, the covariance matrices must all
be **positive definite** (an error is raised if one of them is not).
Args:
mu: (N+1)-D. `float` or `double` tensor, the means of the distributions.
sigma: (N+2)-D. (optional) `float` or `double` tensor, the covariances
of the distribution(s). The first `N+1` dimensions must match
those of `mu`. Must be batch-positive-definite.
sigma_chol: (N+2)-D. (optional) `float` or `double` tensor, a
lower-triangular factorization of `sigma`
(`sigma = sigma_chol . sigma_chol^*`). The first `N+1` dimensions
must match those of `mu`. The tensor itself need not be batch
lower triangular: we ignore the upper triangular part. However,
the batch diagonals must be positive (i.e., sigma_chol must be
batch-positive-definite).
name: The name to give Ops created by the initializer.
Raises:
ValueError: if neither sigma nor sigma_chol is provided.
TypeError: if mu and sigma (resp. sigma_chol) are different dtypes.
"""
if (sigma is None) == (sigma_chol is None):
raise ValueError("Exactly one of sigma and sigma_chol must be provided")
with ops.op_scope([mu, sigma, sigma_chol], name, "MultivariateNormal"):
sigma_or_half = sigma_chol if sigma is None else sigma
mu = ops.convert_to_tensor(mu)
sigma_or_half = ops.convert_to_tensor(sigma_or_half)
contrib_tensor_util.assert_same_float_dtype((mu, sigma_or_half))
with ops.control_dependencies([
_assert_compatible_shapes(mu, sigma_or_half)]):
mu = array_ops.identity(mu, name="mu")
# Store the dimensionality of the MVNs
self._k = array_ops.gather(array_ops.shape(mu), array_ops.rank(mu) - 1)
if sigma_chol is not None:
# Ensure we only keep the lower triangular part.
sigma_chol = array_ops.batch_matrix_band_part(
sigma_chol, num_lower=-1, num_upper=0)
sigma_det = _determinant_from_sigma_chol(sigma_chol)
with ops.control_dependencies([
_assert_batch_positive_definite(sigma_chol)]):
self._sigma = math_ops.batch_matmul(
sigma_chol, sigma_chol, adj_y=True, name="sigma")
self._sigma_chol = array_ops.identity(sigma_chol, "sigma_chol")
self._sigma_det = array_ops.identity(sigma_det, "sigma_det")
self._mu = array_ops.identity(mu, "mu")
else: # sigma is not None
sigma_chol = linalg_ops.batch_cholesky(sigma)
sigma_det = _determinant_from_sigma_chol(sigma_chol)
# batch_cholesky checks for PSD; so we can just use it here.
with ops.control_dependencies([sigma_chol]):
self._sigma = array_ops.identity(sigma, "sigma")
self._sigma_chol = array_ops.identity(sigma_chol, "sigma_chol")
self._sigma_det = array_ops.identity(sigma_det, "sigma_det")
self._mu = array_ops.identity(mu, "mu")
@property
def dtype(self):
return self._mu.dtype
@property
def mu(self):
return self._mu
@property
def sigma(self):
return self._sigma
@property
def mean(self):
return self._mu
@property
def sigma_det(self):
return self._sigma_det
def log_pdf(self, x, name=None):
"""Log pdf of observations `x` given these Multivariate Normals.
Args:
x: tensor of dtype `dtype`, must be broadcastable with `mu`.
name: The name to give this op.
Returns:
log_pdf: tensor of dtype `dtype`, the log-PDFs of `x`.
"""
with ops.op_scope(
[self._mu, self._sigma_chol, x], name, "MultivariateNormalLogPdf"):
x = ops.convert_to_tensor(x)
contrib_tensor_util.assert_same_float_dtype((self._mu, x))
x_centered = x - self.mu
x_rank = array_ops.rank(x_centered)
sigma_rank = array_ops.rank(self._sigma_chol)
x_rank_vec = array_ops.pack([x_rank])
sigma_rank_vec = array_ops.pack([sigma_rank])
x_shape = array_ops.shape(x_centered)
# sigma_chol is shaped [D, E, F, ..., k, k]
# x_centered shape is one of:
# [D, E, F, ..., k], or [F, ..., k], or
# [A, B, C, D, E, F, ..., k]
# and we need to convert x_centered to shape:
# [D, E, F, ..., k, A*B*C] (or 1 if A, B, C don't exist)
# then transpose and reshape x_whitened back to one of the shapes:
# [D, E, F, ..., k], or [1, 1, F, ..., k], or
# [A, B, C, D, E, F, ..., k]
# This helper handles the case where rank(x_centered) < rank(sigma)
def _broadcast_x_not_higher_rank_than_sigma():
return array_ops.reshape(
x_centered,
array_ops.concat(
# Reshape to ones(deficient x rank) + x_shape + [1]
0, (array_ops.ones(array_ops.pack([sigma_rank - x_rank - 1]),
dtype=x_rank.dtype),
x_shape,
[1])))
# These helpers handle the case where rank(x_centered) >= rank(sigma)
def _broadcast_x_higher_rank_than_sigma():
x_shape_left = array_ops.slice(
x_shape, [0], sigma_rank_vec - 1)
x_shape_right = array_ops.slice(
x_shape, sigma_rank_vec - 1, x_rank_vec - 1)
x_shape_perm = array_ops.concat(
0, (math_ops.range(sigma_rank - 1, x_rank),
math_ops.range(0, sigma_rank - 1)))
return array_ops.reshape(
# Convert to [D, E, F, ..., k, B, C]
array_ops.transpose(
x_centered, perm=x_shape_perm),
# Reshape to [D, E, F, ..., k, B*C]
array_ops.concat(
0, (x_shape_right,
array_ops.pack([
math_ops.reduce_prod(x_shape_left, 0)]))))
def _unbroadcast_x_higher_rank_than_sigma():
x_shape_left = array_ops.slice(
x_shape, [0], sigma_rank_vec - 1)
x_shape_right = array_ops.slice(
x_shape, sigma_rank_vec - 1, x_rank_vec - 1)
x_shape_perm = array_ops.concat(
0, (math_ops.range(sigma_rank - 1, x_rank),
math_ops.range(0, sigma_rank - 1)))
return array_ops.transpose(
# [D, E, F, ..., k, B, C] => [B, C, D, E, F, ..., k]
array_ops.reshape(
# convert to [D, E, F, ..., k, B, C]
x_whitened_broadcast,
array_ops.concat(0, (x_shape_right, x_shape_left))),
perm=x_shape_perm)
# Step 1: reshape x_centered
x_centered_broadcast = control_flow_ops.cond(
# x_centered == [D, E, F, ..., k] => [D, E, F, ..., k, 1]
# or == [F, ..., k] => [1, 1, F, ..., k, 1]
x_rank <= sigma_rank - 1,
_broadcast_x_not_higher_rank_than_sigma,
# x_centered == [B, C, D, E, F, ..., k] => [D, E, F, ..., k, B*C]
_broadcast_x_higher_rank_than_sigma)
x_whitened_broadcast = linalg_ops.batch_matrix_triangular_solve(
self._sigma_chol, x_centered_broadcast)
# Reshape x_whitened_broadcast back to x_whitened
x_whitened = control_flow_ops.cond(
x_rank <= sigma_rank - 1,
lambda: array_ops.reshape(x_whitened_broadcast, x_shape),
_unbroadcast_x_higher_rank_than_sigma)
x_whitened = array_ops.expand_dims(x_whitened, -1)
# Reshape x_whitened to contain row vectors
# Returns a batchwise scalar
x_whitened_norm = math_ops.batch_matmul(
x_whitened, x_whitened, adj_x=True)
x_whitened_norm = control_flow_ops.cond(
x_rank <= sigma_rank - 1,
lambda: array_ops.squeeze(x_whitened_norm, [-2, -1]),
lambda: array_ops.squeeze(x_whitened_norm, [-1]))
log_two_pi = constant_op.constant(math.log(2 * math.pi), dtype=self.dtype)
k = math_ops.cast(self._k, self.dtype)
log_pdf_value = (
-math_ops.log(self._sigma_det) -k * log_two_pi - x_whitened_norm) / 2
final_shaped_value = control_flow_ops.cond(
x_rank <= sigma_rank - 1,
lambda: log_pdf_value,
lambda: array_ops.squeeze(log_pdf_value, [-1]))
output_static_shape = x_centered.get_shape()[:-1]
final_shaped_value.set_shape(output_static_shape)
return final_shaped_value
def pdf(self, x, name=None):
"""The PDF of observations `x` under these Multivariate Normals.
Args:
x: tensor of dtype `dtype`, must be broadcastable with `mu` and `sigma`.
name: The name to give this op.
Returns:
pdf: tensor of dtype `dtype`, the pdf values of `x`.
"""
with ops.op_scope(
[self._mu, self._sigma_chol, x], name, "MultivariateNormalPdf"):
return math_ops.exp(self.log_pdf(x))
def entropy(self, name=None):
"""The entropies of these Multivariate Normals.
Args:
name: The name to give this op.
Returns:
entropy: tensor of dtype `dtype`, the entropies.
"""
with ops.op_scope(
[self._mu, self._sigma_chol], name, "MultivariateNormalEntropy"):
one_plus_log_two_pi = constant_op.constant(
1 + math.log(2 * math.pi), dtype=self.dtype)
# Use broadcasting rules to calculate the full broadcast sigma.
k = math_ops.cast(self._k, dtype=self.dtype)
entropy_value = (
k * one_plus_log_two_pi + math_ops.log(self._sigma_det)) / 2
entropy_value.set_shape(self._sigma_det.get_shape())
return entropy_value
def sample(self, n, seed=None, name=None):
"""Sample `n` observations from the Multivariate Normal Distributions.
Args:
n: `Scalar`, type int32, the number of observations to sample.
seed: Python integer, the random seed.
name: The name to give this op.
Returns:
samples: `[n, ...]`, a `Tensor` of `n` samples for each
of the distributions determined by broadcasting the hyperparameters.
"""
with ops.op_scope(
[self._mu, self._sigma_chol, n], name, "MultivariateNormalSample"):
# TODO(ebrevdo): Is there a better way to get broadcast_shape?
broadcast_shape = self.mu.get_shape()
n = ops.convert_to_tensor(n)
sigma_shape_left = array_ops.slice(
array_ops.shape(self._sigma_chol),
[0], array_ops.pack([array_ops.rank(self._sigma_chol) - 2]))
k_n = array_ops.pack([self._k, n])
shape = array_ops.concat(0, [sigma_shape_left, k_n])
white_samples = random_ops.random_normal(
shape=shape, mean=0, stddev=1, dtype=self._mu.dtype, seed=seed)
correlated_samples = math_ops.batch_matmul(
self._sigma_chol, white_samples)
# Move the last dimension to the front
perm = array_ops.concat(
0,
(array_ops.pack([array_ops.rank(correlated_samples) - 1]),
math_ops.range(0, array_ops.rank(correlated_samples) - 1)))
# TODO(ebrevdo): Once we get a proper tensor contraction op,
# perform the inner product using that instead of batch_matmul
# and this slow transpose can go away!
correlated_samples = array_ops.transpose(correlated_samples, perm)
samples = correlated_samples + self.mu
# Provide some hints to shape inference
n_val = tensor_util.constant_value(n)
final_shape = tensor_shape.vector(n_val).concatenate(broadcast_shape)
samples.set_shape(final_shape)
return samples
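
One identity worth flagging from _determinant_from_sigma_chol: for positive definite sigma with lower-triangular Cholesky factor L, det(sigma) = prod(diag(L))**2, which avoids a separate determinant op. A quick numpy check using the same sigma as testLogPDFScalarSigmaHalf:

```python
import numpy as np

sigma = np.array([[1.0, 0.1, 0.2],
                  [0.1, 2.0, 0.05],
                  [0.2, 0.05, 3.0]])
chol = np.linalg.cholesky(sigma)             # lower-triangular factor L
det_from_chol = np.prod(np.diag(chol)) ** 2  # det(L) * det(L^T)
assert np.allclose(det_from_chol, np.linalg.det(sigma))
```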

View File

@@ -234,8 +234,8 @@ class BatchMatMul : public OpKernel {
in1.shape().DebugString()));
const int ndims = in0.dims();
OP_REQUIRES(
ctx, ndims >= 3,
errors::InvalidArgument("In[0] and In[1] ndims must be >= 3: ", ndims));
ctx, ndims >= 2,
errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ", ndims));
TensorShape out_shape;
for (int i = 0; i < ndims - 2; ++i) {
OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i),
@@ -245,7 +245,7 @@ class BatchMatMul : public OpKernel {
in1.shape().DebugString()));
out_shape.AddDim(in0.dim_size(i));
}
auto n = out_shape.num_elements();
auto n = (ndims == 2) ? 1 : out_shape.num_elements();
auto d0 = in0.dim_size(ndims - 2);
auto d1 = in0.dim_size(ndims - 1);
Tensor in0_reshaped;
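
The relaxed check means a rank-2 input is treated as a batch of one (n == 1 when ndims == 2, so the flattened batch loop degenerates to a single matmul). The intended equivalence, sketched in numpy:

```python
import numpy as np

a = np.random.randn(3, 4)
b = np.random.randn(4, 5)

# With ndims == 2 the batch count n is 1, so batching a single matrix
# must reproduce the plain matrix product.
batched = np.matmul(a[np.newaxis, :, :], b[np.newaxis, :, :])[0]
assert np.allclose(batched, np.dot(a, b))
```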

View File

@@ -25,19 +25,8 @@ import tensorflow as tf
class CholeskyOpTest(tf.test.TestCase):
def _verifyCholesky(self, x):
with self.test_session() as sess:
# Verify that LL^T == x.
if x.ndim == 2:
chol = tf.cholesky(x)
verification = tf.matmul(chol,
chol,
transpose_a=False,
transpose_b=True)
else:
chol = tf.batch_cholesky(x)
verification = tf.batch_matmul(chol, chol, adj_x=False, adj_y=True)
chol_np, verification_np = sess.run([chol, verification])
def _verifyCholeskyBase(self, sess, x, chol, verification):
chol_np, verification_np = sess.run([chol, verification])
self.assertAllClose(x, verification_np)
self.assertShapeEqual(x, chol)
# Check that the cholesky is lower triangular, and has positive diagonal
@@ -49,6 +38,20 @@ class CholeskyOpTest(tf.test.TestCase):
self.assertAllClose(chol_matrix, np.tril(chol_matrix))
self.assertTrue((np.diag(chol_matrix) > 0.0).all())
def _verifyCholesky(self, x):
# Verify that LL^T == x.
with self.test_session() as sess:
# Check the batch version, which works for ndim >= 2.
chol = tf.batch_cholesky(x)
verification = tf.batch_matmul(chol, chol, adj_x=False, adj_y=True)
self._verifyCholeskyBase(sess, x, chol, verification)
if x.ndim == 2: # Check the simple form of cholesky
chol = tf.cholesky(x)
verification = tf.matmul(
chol, chol, transpose_a=False, transpose_b=True)
self._verifyCholeskyBase(sess, x, chol, verification)
def testBasic(self):
self._verifyCholesky(np.array([[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]))

View File

@@ -24,13 +24,8 @@ import tensorflow as tf
class DeterminantOpTest(tf.test.TestCase):
def _compareDeterminant(self, matrix_x):
with self.test_session():
if matrix_x.ndim == 2:
tf_ans = tf.matrix_determinant(matrix_x)
else:
tf_ans = tf.batch_matrix_determinant(matrix_x)
out = tf_ans.eval()
def _compareDeterminantBase(self, matrix_x, tf_ans):
out = tf_ans.eval()
shape = matrix_x.shape
if shape[-1] == 0 and shape[-2] == 0:
np_ans = np.ones(shape[:-2]).astype(matrix_x.dtype)
@@ -39,6 +34,15 @@ class DeterminantOpTest(tf.test.TestCase):
self.assertAllClose(np_ans, out)
self.assertShapeEqual(np_ans, tf_ans)
def _compareDeterminant(self, matrix_x):
with self.test_session():
# Check the batch version, which should work for ndim >= 2
self._compareDeterminantBase(
matrix_x, tf.batch_matrix_determinant(matrix_x))
if matrix_x.ndim == 2:
# Check the simple version
self._compareDeterminantBase(matrix_x, tf.matrix_determinant(matrix_x))
def testBasic(self):
# 2x2 matrices
self._compareDeterminant(np.array([[2., 3.], [3., 4.]]).astype(np.float32))

View File

@@ -67,11 +67,13 @@ class MatrixSolveLsOpTest(tf.test.TestCase):
np_ans, _, _, _ = np.linalg.lstsq(a, b)
for fast in [True, False]:
with self.test_session():
tf_ans = tf.matrix_solve_ls(a, b, fast=fast).eval()
self.assertEqual(np_ans.shape, tf_ans.shape)
tf_ans = tf.matrix_solve_ls(a, b, fast=fast)
ans = tf_ans.eval()
self.assertEqual(np_ans.shape, tf_ans.get_shape())
self.assertEqual(np_ans.shape, ans.shape)
# Check residual norm.
tf_r = b - BatchMatMul(a, tf_ans)
tf_r = b - BatchMatMul(a, ans)
tf_r_norm = np.sum(tf_r * tf_r)
np_r = b - BatchMatMul(a, np_ans)
np_r_norm = np.sum(np_r * np_r)
@@ -83,7 +85,7 @@ class MatrixSolveLsOpTest(tf.test.TestCase):
# slow path, because Eigen does not return a minimum norm solution.
# TODO(rmlarsen): Enable this check for all paths if/when we fix
# Eigen's solver.
self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5)
self.assertAllClose(np_ans, ans, atol=1e-5, rtol=1e-5)
def _verifySolveBatch(self, x, y):
# Since numpy.linalg.lsqr does not support batch solves, as opposed
@@ -122,20 +124,23 @@ class MatrixSolveLsOpTest(tf.test.TestCase):
b = y.astype(np_type)
np_ans = BatchRegularizedLeastSquares(a, b, l2_regularizer)
with self.test_session():
tf_ans = tf.matrix_solve_ls(a,
b,
l2_regularizer=l2_regularizer,
fast=True).eval()
self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5)
# Test with the batch version of matrix_solve_ls on regular matrices
tf_ans = tf.batch_matrix_solve_ls(
a, b, l2_regularizer=l2_regularizer, fast=True).eval()
self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5)
# Test with the simple matrix_solve_ls on regular matrices
tf_ans = tf.matrix_solve_ls(
a, b, l2_regularizer=l2_regularizer, fast=True).eval()
self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5)
# Test with a 2x3 batch of matrices.
a = np.tile(x.astype(np_type), [2, 3, 1, 1])
b = np.tile(y.astype(np_type), [2, 3, 1, 1])
np_ans = BatchRegularizedLeastSquares(a, b, l2_regularizer)
with self.test_session():
tf_ans = tf.batch_matrix_solve_ls(a,
b,
l2_regularizer=l2_regularizer,
fast=True).eval()
tf_ans = tf.batch_matrix_solve_ls(
a, b, l2_regularizer=l2_regularizer, fast=True).eval()
self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5)
def testSquare(self):

View File

@@ -37,15 +37,23 @@ class MatrixSolveOpTest(tf.test.TestCase):
a = np.tile(a, batch_dims + [1, 1])
a_np = np.tile(a_np, batch_dims + [1, 1])
b = np.tile(b, batch_dims + [1, 1])
with self.test_session():
if a.ndim == 2:
tf_ans = tf.matrix_solve(a, b, adjoint=adjoint)
else:
tf_ans = tf.batch_matrix_solve(a, b, adjoint=adjoint)
out = tf_ans.eval()
np_ans = np.linalg.solve(a_np, b)
self.assertEqual(np_ans.shape, out.shape)
self.assertAllClose(np_ans, out)
with self.test_session():
# Test the batch version, which works for ndim >= 2
tf_ans = tf.batch_matrix_solve(a, b, adjoint=adjoint)
out = tf_ans.eval()
self.assertEqual(tf_ans.get_shape(), out.shape)
self.assertEqual(np_ans.shape, out.shape)
self.assertAllClose(np_ans, out)
if a.ndim == 2:
# Test the simple version
tf_ans = tf.matrix_solve(a, b, adjoint=adjoint)
out = tf_ans.eval()
self.assertEqual(out.shape, tf_ans.get_shape())
self.assertEqual(np_ans.shape, out.shape)
self.assertAllClose(np_ans, out)
def testSolve(self):
# 2x2 matrices, 2x1 right-hand side.

View File

@@ -51,20 +51,27 @@ class MatrixTriangularSolveOpTest(tf.test.TestCase):
a = np.tile(a, batch_dims + [1, 1])
a_np = np.tile(a_np, batch_dims + [1, 1])
b = np.tile(b, batch_dims + [1, 1])
with self.test_session():
# Test the batch version, which works for ndim >= 2
tf_ans = tf.batch_matrix_triangular_solve(
a, b, lower=lower, adjoint=adjoint)
out = tf_ans.eval()
np_ans = np.linalg.solve(a_np, b)
self.assertEqual(np_ans.shape, tf_ans.get_shape())
self.assertEqual(np_ans.shape, out.shape)
self.assertAllClose(np_ans, out)
if a.ndim == 2:
tf_ans = tf.matrix_triangular_solve(a,
b,
lower=lower,
adjoint=adjoint).eval()
else:
tf_ans = tf.batch_matrix_triangular_solve(a,
b,
lower=lower,
adjoint=adjoint).eval()
np_ans = np.linalg.solve(a_np, b)
self.assertEqual(np_ans.shape, tf_ans.shape)
self.assertAllClose(np_ans, tf_ans)
# Test the simple version
tf_ans = tf.matrix_triangular_solve(
a, b, lower=lower, adjoint=adjoint)
out = tf_ans.eval()
self.assertEqual(np_ans.shape, tf_ans.get_shape())
self.assertEqual(np_ans.shape, out.shape)
self.assertAllClose(np_ans, out)
def testSolve(self):
# 2x2 matrices, single right-hand side.

View File

@@ -71,14 +71,28 @@ class SelfAdjointEigOpTest(tf.test.TestCase):
for i in xrange(dlist[0]):
self._testEigs(x[i], d, tf_out[i])
def _compareBatchSelfAdjointEigRank2(self, x, use_gpu=False):
with self.test_session() as sess:
tf_eig = tf.batch_self_adjoint_eig(tf.constant(x))
tf_out = sess.run([tf_eig])[0]
dlist = x.shape
d = dlist[-2]
self.assertEqual(len(tf_eig.get_shape()), 2)
self.assertEqual([d+1, d], tf_eig.get_shape().dims[-2:])
self._testEigs(x, d, tf_out)
def testBasic(self):
self._compareSelfAdjointEig(
np.array([[3., 0., 1.], [0., 2., -2.], [1., -2., 3.]]))
def testBatch(self):
simple_array = np.array([[[1., 0.], [0., 5.]]]) # shape (1, 2, 2)
simple_array_2d = simple_array[0] # shape (2, 2)
self._compareBatchSelfAdjointEigRank3(simple_array)
self._compareBatchSelfAdjointEigRank3(np.vstack((simple_array, simple_array)))
self._compareBatchSelfAdjointEigRank3(
np.vstack((simple_array, simple_array)))
self._compareBatchSelfAdjointEigRank2(simple_array_2d)
odd_sized_array = np.array([[[3., 0., 1.], [0., 2., -2.], [1., -2., 3.]]])
self._compareBatchSelfAdjointEigRank3(
np.vstack((odd_sized_array, odd_sized_array)))
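
The [d+1, d] shape assertion in the new rank-2 test reflects the packed output of the eig op in this era: eigenvalues in the first row, eigenvectors below. A numpy sketch of one such packing (shape only; the exact row/column convention of the op is not restated in the diff):

```python
import numpy as np

x = np.array([[3., 0., 1.], [0., 2., -2.], [1., -2., 3.]])
e, v = np.linalg.eigh(x)
packed = np.vstack([e[np.newaxis, :], v])  # eigenvalues stacked above eigenvectors
print(packed.shape)                        # (4, 3), i.e. (d + 1, d) with d = 3
```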

View File

@@ -39,7 +39,7 @@ def _UnchangedSquare(op):
@ops.RegisterShape("BatchCholesky")
@ops.RegisterShape("BatchMatrixInverse")
def _BatchUnchangedSquare(op):
input_shape = op.inputs[0].get_shape().with_rank_at_least(3)
input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
# The matrices in the batch must be square.
input_shape[-1].assert_is_compatible_with(input_shape[-2])
return [input_shape]
@@ -61,7 +61,7 @@ def _MatrixDeterminantShape(op):
@ops.RegisterShape("BatchMatrixDeterminant")
def _BatchMatrixDeterminantShape(op):
input_shape = op.inputs[0].get_shape().with_rank_at_least(3)
input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
# The matrices in the batch must be square.
input_shape[-1].assert_is_compatible_with(input_shape[-2])
if input_shape.ndims is not None:
@@ -82,7 +82,7 @@ def _SelfAdjointEigShape(op):
@ops.RegisterShape("BatchSelfAdjointEig")
def _BatchSelfAdjointEigShape(op):
input_shape = op.inputs[0].get_shape().with_rank_at_least(3)
input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
# The matrices in the batch must be square.
input_shape[-1].assert_is_compatible_with(input_shape[-2])
dlist = input_shape.dims
@@ -106,8 +106,8 @@ def _SquareMatrixSolveShape(op):
@ops.RegisterShape("BatchMatrixSolve")
@ops.RegisterShape("BatchMatrixTriangularSolve")
def _BatchSquareMatrixSolveShape(op):
lhs_shape = op.inputs[0].get_shape().with_rank_at_least(3)
rhs_shape = op.inputs[1].get_shape().with_rank_at_least(3)
lhs_shape = op.inputs[0].get_shape().with_rank_at_least(2)
rhs_shape = op.inputs[1].get_shape().with_rank_at_least(2)
# The matrices must be square.
lhs_shape[-1].assert_is_compatible_with(lhs_shape[-2])
# The matrices and right-hand sides in the batch must have the same number of
@@ -127,8 +127,8 @@ def _MatrixSolveLsShape(op):
@ops.RegisterShape("BatchMatrixSolveLs")
def _BatchMatrixSolveLsShape(op):
lhs_shape = op.inputs[0].get_shape().with_rank_at_least(3)
rhs_shape = op.inputs[1].get_shape().with_rank_at_least(3)
lhs_shape = op.inputs[0].get_shape().with_rank_at_least(2)
rhs_shape = op.inputs[1].get_shape().with_rank_at_least(2)
# The matrices and right-hand sides in the batch must have the same number of
# rows.
lhs_shape[-2].assert_is_compatible_with(rhs_shape[-2])
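
Net effect of these shape-function changes: static shape inference for the batch ops now accepts plain matrices (rank 2) instead of insisting on an explicit batch dimension, matching the kernel-side relaxation above. A sketch against the 0.x-era API of this commit (assuming tf.batch_cholesky as exposed at the time):

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=(5, 5))
chol = tf.batch_cholesky(x)   # rank 2 previously failed with_rank_at_least(3)
print(chol.get_shape())       # (5, 5)
```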