Add MultivariateNormal to tf.contrib.distributions.
Also fix overly stringent constraints on batchwise linalg ops & batch_matmul. Change: 120279428
parent 3198d7ef30
commit 912ab39d93
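For orientation before the diff itself, here is a minimal usage sketch of the class this commit adds. It simply mirrors the calls exercised in mvn_test.py below (construct from mu and sigma, then evaluate log_pdf and sample) in the TF-0.8-era session/eval style used throughout the change; it is an illustration, not code from the commit.

    import numpy as np
    import tensorflow as tf

    mu = np.array([-3.0, 3.0], dtype=np.float32)
    sigma = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
    # Exactly one of sigma / sigma_chol may be passed (see mvn.py below).
    mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)

    with tf.Session():
      x = np.array([-2.5, 2.5], dtype=np.float32)
      print(mvn.log_pdf(x).eval())                  # scalar log density at x
      print(mvn.sample(1000, seed=0).eval().shape)  # (1000, 2)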
@@ -32,6 +32,19 @@ cuda_py_tests(
     srcs = ["python/kernel_tests/gaussian_test.py"],
     additional_deps = [
         ":distributions_py",
         "//third_party/py/scipy",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:platform_test",
     ],
 )
 
+cuda_py_tests(
+    name = "mvn_test",
+    size = "small",
+    srcs = ["python/kernel_tests/mvn_test.py"],
+    additional_deps = [
+        ":distributions_py",
+        "//third_party/py/scipy",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
@@ -43,6 +56,7 @@ cuda_py_tests(
     srcs = ["python/kernel_tests/gaussian_conjugate_posteriors_test.py"],
     additional_deps = [
         ":distributions_py",
+        "//third_party/py/scipy",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:platform_test",
     ],
@@ -21,8 +21,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-# pylint: disable=unused-import,wildcard-import, line-too-long
+# pylint: disable=unused-import,wildcard-import,line-too-long
 from tensorflow.contrib.distributions.python.ops import gaussian_conjugate_posteriors
 from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import *
 from tensorflow.contrib.distributions.python.ops.gaussian import *
-# from tensorflow.contrib.distributions.python.ops.dirichlet import *  # pylint: disable=line-too-long
+from tensorflow.contrib.distributions.python.ops.mvn import *
tensorflow/contrib/distributions/python/kernel_tests/mvn_test.py (new file, 252 lines)
@@ -0,0 +1,252 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for MultivariateNormal."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from scipy import stats
import tensorflow as tf


class MultivariateNormalTest(tf.test.TestCase):

  def testNonmatchingMuSigmaFails(self):
    with tf.Session():
      mvn = tf.contrib.distributions.MultivariateNormal(
          mu=[1.0, 2.0],
          sigma=[[[1.0, 0.0],
                  [0.0, 1.0]],
                 [[1.0, 0.0],
                  [0.0, 1.0]]])
      with self.assertRaisesOpError(
          r"Rank of mu should be one less than rank of sigma"):
        mvn.mean.eval()

      mvn = tf.contrib.distributions.MultivariateNormal(
          mu=[[1.0], [2.0]],
          sigma=[[[1.0, 0.0],
                  [0.0, 1.0]],
                 [[1.0, 0.0],
                  [0.0, 1.0]]])
      with self.assertRaisesOpError(
          r"mu.shape and sigma.shape\[\:-1\] must match"):
        mvn.mean.eval()

  def testNotPositiveDefiniteSigmaFails(self):
    with tf.Session():
      mvn = tf.contrib.distributions.MultivariateNormal(
          mu=[[1.0, 2.0], [1.0, 2.0]],
          sigma=[[[1.0, 0.0],
                  [0.0, 1.0]],
                 [[1.0, 1.0],
                  [1.0, 1.0]]])
      with self.assertRaisesOpError(
          r"LLT decomposition was not successful."):
        mvn.mean.eval()
      mvn = tf.contrib.distributions.MultivariateNormal(
          mu=[[1.0, 2.0], [1.0, 2.0]],
          sigma=[[[1.0, 0.0],
                  [0.0, 1.0]],
                 [[-1.0, 0.0],
                  [0.0, 1.0]]])
      with self.assertRaisesOpError(
          r"LLT decomposition was not successful."):
        mvn.mean.eval()
      mvn = tf.contrib.distributions.MultivariateNormal(
          mu=[[1.0, 2.0], [1.0, 2.0]],
          sigma_chol=[[[1.0, 0.0],
                       [0.0, 1.0]],
                      [[-1.0, 0.0],
                       [0.0, 1.0]]])
      with self.assertRaisesOpError(
          r"sigma_chol is not positive definite."):
        mvn.mean.eval()

  def testLogPDFScalar(self):
    with tf.Session():
      mu_v = np.array([-3.0, 3.0], dtype=np.float32)
      mu = tf.constant(mu_v)
      sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
      sigma = tf.constant(sigma_v)
      x = np.array([-2.5, 2.5], dtype=np.float32)
      mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)

      log_pdf = mvn.log_pdf(x)

      scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
      expected_log_pdf = scipy_mvn.logpdf(x)
      expected_pdf = scipy_mvn.pdf(x)
      self.assertAllClose(expected_log_pdf, log_pdf.eval())

      pdf = mvn.pdf(x)
      self.assertAllClose(expected_pdf, pdf.eval())

  def testLogPDFScalarSigmaHalf(self):
    with tf.Session():
      mu_v = np.array([-3.0, 3.0, 1.0], dtype=np.float32)
      mu = tf.constant(mu_v)
      sigma_v = np.array([[1.0, 0.1, 0.2],
                          [0.1, 2.0, 0.05],
                          [0.2, 0.05, 3.0]], dtype=np.float32)
      sigma_chol_v = np.linalg.cholesky(sigma_v)
      sigma_chol = tf.constant(sigma_chol_v)
      x = np.array([-2.5, 2.5, 1.0], dtype=np.float32)
      mvn = tf.contrib.distributions.MultivariateNormal(
          mu=mu, sigma_chol=sigma_chol)

      log_pdf = mvn.log_pdf(x)
      sigma = mvn.sigma

      scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
      expected_log_pdf = scipy_mvn.logpdf(x)
      expected_pdf = scipy_mvn.pdf(x)
      self.assertEqual(sigma.get_shape(), (3, 3))
      self.assertAllClose(sigma_v, sigma.eval())
      self.assertAllClose(expected_log_pdf, log_pdf.eval())

      pdf = mvn.pdf(x)
      self.assertAllClose(expected_pdf, pdf.eval())

  def testLogPDF(self):
    with tf.Session():
      mu_v = np.array([-3.0, 3.0], dtype=np.float32)
      mu = tf.constant(mu_v)
      sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
      sigma = tf.constant(sigma_v)
      x = np.array([[-2.5, 2.5], [4.0, 0.0], [-1.0, 2.0]], dtype=np.float32)
      mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)

      log_pdf = mvn.log_pdf(x)

      scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
      expected_log_pdf = scipy_mvn.logpdf(x)
      expected_pdf = scipy_mvn.pdf(x)
      self.assertEqual(log_pdf.get_shape(), (3,))
      self.assertAllClose(expected_log_pdf, log_pdf.eval())

      pdf = mvn.pdf(x)
      self.assertAllClose(expected_pdf, pdf.eval())

  def testLogPDFMatchingDimension(self):
    with tf.Session():
      mu_v = np.array([-3.0, 3.0], dtype=np.float32)
      mu = tf.constant(np.vstack(3 * [mu_v]))
      sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
      sigma = tf.constant(np.vstack(3 * [sigma_v[np.newaxis, :]]))
      x = np.array([[-2.5, 2.5], [4.0, 0.0], [-1.0, 2.0]], dtype=np.float32)
      mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)

      log_pdf = mvn.log_pdf(x)

      scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
      expected_log_pdf = scipy_mvn.logpdf(x)
      expected_pdf = scipy_mvn.pdf(x)
      self.assertEqual(log_pdf.get_shape(), (3,))
      self.assertAllClose(expected_log_pdf, log_pdf.eval())

      pdf = mvn.pdf(x)
      self.assertAllClose(expected_pdf, pdf.eval())

  def testLogPDFMultidimensional(self):
    with tf.Session():
      mu_v = np.array([-3.0, 3.0], dtype=np.float32)
      mu = tf.constant(np.vstack(15 * [mu_v]).reshape(3, 5, 2))
      sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
      sigma = tf.constant(
          np.vstack(15 * [sigma_v[np.newaxis, :]]).reshape(3, 5, 2, 2))
      x = np.array([-2.5, 2.5], dtype=np.float32)
      mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)

      log_pdf = mvn.log_pdf(x)

      scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
      expected_log_pdf = np.vstack(15 * [scipy_mvn.logpdf(x)]).reshape(3, 5)
      expected_pdf = np.vstack(15 * [scipy_mvn.pdf(x)]).reshape(3, 5)
      self.assertEqual(log_pdf.get_shape(), (3, 5))
      self.assertAllClose(expected_log_pdf, log_pdf.eval())

      pdf = mvn.pdf(x)
      self.assertAllClose(expected_pdf, pdf.eval())

  def testEntropy(self):
    with tf.Session():
      mu_v = np.array([-3.0, 3.0], dtype=np.float32)
      mu = tf.constant(mu_v)
      sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
      sigma = tf.constant(sigma_v)
      mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
      entropy = mvn.entropy()

      scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
      expected_entropy = scipy_mvn.entropy()

      self.assertEqual(entropy.get_shape(), ())
      self.assertAllClose(expected_entropy, entropy.eval())

  def testEntropyMultidimensional(self):
    with tf.Session():
      mu_v = np.array([-3.0, 3.0], dtype=np.float32)
      mu = tf.constant(np.vstack(15 * [mu_v]).reshape(3, 5, 2))
      sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
      sigma = tf.constant(
          np.vstack(15 * [sigma_v[np.newaxis, :]]).reshape(3, 5, 2, 2))
      mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
      entropy = mvn.entropy()

      scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
      expected_entropy = np.vstack(15 * [scipy_mvn.entropy()]).reshape(3, 5)

      self.assertEqual(entropy.get_shape(), (3, 5))
      self.assertAllClose(expected_entropy, entropy.eval())

  def testSample(self):
    with tf.Session():
      mu_v = np.array([-3.0, 3.0], dtype=np.float32)
      mu = tf.constant(mu_v)
      sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
      sigma = tf.constant(sigma_v)
      n = tf.constant(100000)
      mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
      samples = mvn.sample(n, seed=137)
      sample_values = samples.eval()
      self.assertEqual(samples.get_shape(), (100000, 2))
      self.assertAllClose(sample_values.mean(axis=0), mu_v, atol=1e-2)
      self.assertAllClose(np.cov(sample_values, rowvar=0), sigma_v, atol=1e-1)

  def testSampleMultiDimensional(self):
    with tf.Session():
      mu_v = np.array([-3.0, 3.0], dtype=np.float32)
      mu = tf.constant(np.vstack(15 * [mu_v]).reshape(3, 5, 2))
      sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
      sigma = tf.constant(
          np.vstack(15 * [sigma_v[np.newaxis, :]]).reshape(3, 5, 2, 2))
      n = tf.constant(100000)
      mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
      samples = mvn.sample(n, seed=137)
      sample_values = samples.eval()
      self.assertEqual(samples.get_shape(), (100000, 3, 5, 2))
      sample_values = sample_values.reshape(100000, 15, 2)
      for i in range(15):
        self.assertAllClose(
            sample_values[:, i, :].mean(axis=0), mu_v, atol=1e-2)
        self.assertAllClose(
            np.cov(sample_values[:, i, :], rowvar=0), sigma_v, atol=1e-1)


if __name__ == "__main__":
  tf.test.main()
@@ -88,7 +88,7 @@ class Gaussian(object):
 
   @property
   def mean(self):
-    return self._mu
+    return self._mu * array_ops.ones_like(self._sigma)
 
   def log_pdf(self, x, name=None):
     """Log pdf of observations in `x` under these Gaussian distribution(s).
@@ -170,7 +170,7 @@ class Gaussian(object):
       return 0.5 * math_ops.log(two_pi_e1 * math_ops.square(sigma))
 
   def sample(self, n, seed=None, name=None):
-    """Sample `n` observations the Gaussian Distributions.
+    """Sample `n` observations from the Gaussian Distributions.
 
     Args:
       n: `Scalar`, type int32, the number of observations to sample.
@@ -185,7 +185,7 @@ class Gaussian(object):
       broadcast_shape = (self._mu + self._sigma).get_shape()
       n = ops.convert_to_tensor(n)
       shape = array_ops.concat(
-          0, [array_ops.pack([n]), array_ops.shape(self._mu)])
+          0, [array_ops.pack([n]), array_ops.shape(self.mean)])
       sampled = random_ops.random_normal(
           shape=shape, mean=0, stddev=1, dtype=self._mu.dtype, seed=seed)
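A sketch of the observable effect of the `mean` change above: `Gaussian.mean` now broadcasts `mu` against `sigma`, and `sample` sizes its output from `self.mean` rather than raw `self._mu`. This is my illustration, not code from the commit, and it assumes the `Gaussian(mu, sigma)` constructor of this contrib module with list/scalar inputs converted to tensors.

    import tensorflow as tf

    # mu is a scalar, sigma has shape (2,); mean broadcasts mu to sigma's shape.
    g = tf.contrib.distributions.Gaussian(mu=1.0, sigma=[1.0, 2.0])
    with tf.Session():
      print(g.mean.eval())                     # [1.0, 1.0]
      # sample now draws per-sigma, so the output is (5, 2), not (5,):
      print(g.sample(5, seed=0).eval().shape)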
tensorflow/contrib/distributions/python/ops/mvn.py (new file, 429 lines)
@@ -0,0 +1,429 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Multivariate Normal distribution class.

@@MultivariateNormal
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util  # pylint: disable=line-too-long
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops


def _assert_compatible_shapes(mu, sigma):
  r_mu = array_ops.rank(mu)
  r_sigma = array_ops.rank(sigma)
  sigma_shape = array_ops.shape(sigma)
  sigma_rank = array_ops.rank(sigma)
  mu_shape = array_ops.shape(mu)
  return control_flow_ops.group(
      logging_ops.Assert(
          math_ops.equal(r_mu + 1, r_sigma),
          ["Rank of mu should be one less than rank of sigma, but saw: ",
           r_mu, " vs. ", r_sigma]),
      logging_ops.Assert(
          math_ops.equal(
              array_ops.gather(sigma_shape, sigma_rank - 2),
              array_ops.gather(sigma_shape, sigma_rank - 1)),
          ["Last two dimensions of sigma (%s) must be equal: " % sigma.name,
           sigma_shape]),
      logging_ops.Assert(
          math_ops.reduce_all(math_ops.equal(
              mu_shape,
              array_ops.slice(
                  sigma_shape, [0], array_ops.pack([sigma_rank - 1])))),
          ["mu.shape and sigma.shape[:-1] must match, but saw: ",
           mu_shape, " vs. ", sigma_shape]))


def _assert_batch_positive_definite(sigma_chol):
  """Add assertions checking that the sigmas are all Positive Definite.

  Given `sigma_chol == cholesky(sigma)`, it is sufficient to check that
  `all(diag(sigma_chol) > 0)`.  This is because to check that a matrix is PD,
  it is sufficient that its cholesky factorization is PD, and to check that a
  triangular matrix is PD, it is sufficient to check that its diagonal
  entries are positive.

  Args:
    sigma_chol: N-D.  The lower triangular cholesky decomposition of `sigma`.

  Returns:
    An assertion op to use with `control_dependencies`, verifying that
    `sigma_chol` is positive definite.
  """
  sigma_batch_diag = array_ops.batch_matrix_diag_part(sigma_chol)
  return logging_ops.Assert(
      math_ops.reduce_all(sigma_batch_diag > 0),
      ["sigma_chol is not positive definite.  batched diagonals: ",
       sigma_batch_diag, " shaped: ", array_ops.shape(sigma_batch_diag)])


def _determinant_from_sigma_chol(sigma_chol):
  det_last_dim = array_ops.rank(sigma_chol) - 2
  sigma_batch_diag = array_ops.batch_matrix_diag_part(sigma_chol)
  det = math_ops.square(math_ops.reduce_prod(
      sigma_batch_diag, reduction_indices=det_last_dim))
  det.set_shape(sigma_chol.get_shape()[:-2])
  return det


class MultivariateNormal(object):
  """The Multivariate Normal distribution on `R^k`.

  The distribution has mean and covariance parameters mu (1-D), sigma (2-D),
  or alternatively mean `mu` and factored covariance (cholesky decomposed
  `sigma`) called `sigma_chol`.

  The PDF of this distribution is:

  ```
  f(x) = (2*pi)^(-k/2) |det(sigma)|^(-1/2) exp(-1/2*(x-mu)^*.sigma^{-1}.(x-mu))
  ```

  where `.` denotes the inner product on `R^k` and `^*` denotes transpose.

  Alternatively, if `sigma` is positive definite, it can be represented in
  terms of its lower triangular cholesky factorization

  ```sigma = sigma_chol . sigma_chol^*```

  and the pdf above allows simpler computation:

  ```
  |det(sigma)| = reduce_prod(diag(sigma_chol))^2
  x_whitened = sigma^{-1/2} . (x - mu) = tri_solve(sigma_chol, x - mu)
  (x-mu)^* . sigma^{-1} . (x-mu) = x_whitened^* . x_whitened
  ```

  where `tri_solve()` solves a triangular system of equations.
  """

  def __init__(self, mu, sigma=None, sigma_chol=None, name=None):
    """Multivariate Normal distributions on `R^k`.

    User must provide means `mu`, which are tensors of rank `N+1` (`N >= 0`)
    with the last dimension having length `k`.

    User must provide exactly one of `sigma` (the covariance matrices) or
    `sigma_chol` (the cholesky decompositions of the covariance matrices).
    `sigma` or `sigma_chol` must be of rank `N+2`.  The last two dimensions
    must both have length `k`.  The first `N` dimensions correspond to batch
    indices.

    If `sigma_chol` is not provided, the batch cholesky factorization of
    `sigma` is calculated for you.

    The shapes of `mu` and `sigma` must match for the first `N` dimensions.

    Regardless of which parameter is provided, the covariance matrices must
    all be **positive definite** (an error is raised if one of them is not).

    Args:
      mu: (N+1)-D.  `float` or `double` tensor, the means of the
        distributions.
      sigma: (N+2)-D.  (optional) `float` or `double` tensor, the covariances
        of the distribution(s).  The first `N+1` dimensions must match
        those of `mu`.  Must be batch-positive-definite.
      sigma_chol: (N+2)-D.  (optional) `float` or `double` tensor, a
        lower-triangular factorization of `sigma`
        (`sigma = sigma_chol . sigma_chol^*`).  The first `N+1` dimensions
        must match those of `mu`.  The tensor itself need not be batch
        lower triangular: we ignore the upper triangular part.  However,
        the batch diagonals must be positive (i.e., sigma_chol must be
        batch-positive-definite).
      name: The name to give Ops created by the initializer.

    Raises:
      ValueError: if neither sigma nor sigma_chol is provided.
      TypeError: if mu and sigma (resp. sigma_chol) are different dtypes.
    """
    if (sigma is None) == (sigma_chol is None):
      raise ValueError("Exactly one of sigma and sigma_chol must be provided")

    with ops.op_scope([mu, sigma, sigma_chol], name, "MultivariateNormal"):
      sigma_or_half = sigma_chol if sigma is None else sigma

      mu = ops.convert_to_tensor(mu)
      sigma_or_half = ops.convert_to_tensor(sigma_or_half)

      contrib_tensor_util.assert_same_float_dtype((mu, sigma_or_half))

      with ops.control_dependencies([
          _assert_compatible_shapes(mu, sigma_or_half)]):
        mu = array_ops.identity(mu, name="mu")

        # Store the dimensionality of the MVNs
        self._k = array_ops.gather(array_ops.shape(mu), array_ops.rank(mu) - 1)

        if sigma_chol is not None:
          # Ensure we only keep the lower triangular part.
          sigma_chol = array_ops.batch_matrix_band_part(
              sigma_chol, num_lower=-1, num_upper=0)
          sigma_det = _determinant_from_sigma_chol(sigma_chol)
          with ops.control_dependencies([
              _assert_batch_positive_definite(sigma_chol)]):
            self._sigma = math_ops.batch_matmul(
                sigma_chol, sigma_chol, adj_y=True, name="sigma")
            self._sigma_chol = array_ops.identity(sigma_chol, "sigma_chol")
            self._sigma_det = array_ops.identity(sigma_det, "sigma_det")
            self._mu = array_ops.identity(mu, "mu")
        else:  # sigma is not None
          sigma_chol = linalg_ops.batch_cholesky(sigma)
          sigma_det = _determinant_from_sigma_chol(sigma_chol)
          # batch_cholesky checks for PSD; so we can just use it here.
          with ops.control_dependencies([sigma_chol]):
            self._sigma = array_ops.identity(sigma, "sigma")
            self._sigma_chol = array_ops.identity(sigma_chol, "sigma_chol")
            self._sigma_det = array_ops.identity(sigma_det, "sigma_det")
            self._mu = array_ops.identity(mu, "mu")

  @property
  def dtype(self):
    return self._mu.dtype

  @property
  def mu(self):
    return self._mu

  @property
  def sigma(self):
    return self._sigma

  @property
  def mean(self):
    return self._mu

  @property
  def sigma_det(self):
    return self._sigma_det

  def log_pdf(self, x, name=None):
    """Log pdf of observations `x` given these Multivariate Normals.

    Args:
      x: tensor of dtype `dtype`, must be broadcastable with `mu`.
      name: The name to give this op.

    Returns:
      log_pdf: tensor of dtype `dtype`, the log-PDFs of `x`.
    """
    with ops.op_scope(
        [self._mu, self._sigma_chol, x], name, "MultivariateNormalLogPdf"):
      x = ops.convert_to_tensor(x)
      contrib_tensor_util.assert_same_float_dtype((self._mu, x))

      x_centered = x - self.mu

      x_rank = array_ops.rank(x_centered)
      sigma_rank = array_ops.rank(self._sigma_chol)

      x_rank_vec = array_ops.pack([x_rank])
      sigma_rank_vec = array_ops.pack([sigma_rank])
      x_shape = array_ops.shape(x_centered)

      # sigma_chol is shaped [D, E, F, ..., k, k]
      # x_centered shape is one of:
      #   [D, E, F, ..., k], or [F, ..., k], or
      #   [A, B, C, D, E, F, ..., k]
      # and we need to convert x_centered to shape:
      #   [D, E, F, ..., k, A*B*C] (or 1 if A, B, C don't exist)
      # then transpose and reshape x_whitened back to one of the shapes:
      #   [D, E, F, ..., k], or [1, 1, F, ..., k], or
      #   [A, B, C, D, E, F, ..., k]

      # This helper handles the case where rank(x_centered) < rank(sigma)
      def _broadcast_x_not_higher_rank_than_sigma():
        return array_ops.reshape(
            x_centered,
            array_ops.concat(
                # Reshape to ones(deficient x rank) + x_shape + [1]
                0, (array_ops.ones(array_ops.pack([sigma_rank - x_rank - 1]),
                                   dtype=x_rank.dtype),
                    x_shape,
                    [1])))

      # These helpers handle the case where rank(x_centered) >= rank(sigma)
      def _broadcast_x_higher_rank_than_sigma():
        x_shape_left = array_ops.slice(
            x_shape, [0], sigma_rank_vec - 1)
        x_shape_right = array_ops.slice(
            x_shape, sigma_rank_vec - 1, x_rank_vec - 1)
        x_shape_perm = array_ops.concat(
            0, (math_ops.range(sigma_rank - 1, x_rank),
                math_ops.range(0, sigma_rank - 1)))
        return array_ops.reshape(
            # Convert to [D, E, F, ..., k, B, C]
            array_ops.transpose(
                x_centered, perm=x_shape_perm),
            # Reshape to [D, E, F, ..., k, B*C]
            array_ops.concat(
                0, (x_shape_right,
                    array_ops.pack([
                        math_ops.reduce_prod(x_shape_left, 0)]))))

      def _unbroadcast_x_higher_rank_than_sigma():
        x_shape_left = array_ops.slice(
            x_shape, [0], sigma_rank_vec - 1)
        x_shape_right = array_ops.slice(
            x_shape, sigma_rank_vec - 1, x_rank_vec - 1)
        x_shape_perm = array_ops.concat(
            0, (math_ops.range(sigma_rank - 1, x_rank),
                math_ops.range(0, sigma_rank - 1)))
        return array_ops.transpose(
            # [D, E, F, ..., k, B, C] => [B, C, D, E, F, ..., k]
            array_ops.reshape(
                # convert to [D, E, F, ..., k, B, C]
                x_whitened_broadcast,
                array_ops.concat(0, (x_shape_right, x_shape_left))),
            perm=x_shape_perm)

      # Step 1: reshape x_centered
      x_centered_broadcast = control_flow_ops.cond(
          # x_centered == [D, E, F, ..., k] => [D, E, F, ..., k, 1]
          # or         == [F, ..., k]       => [1, 1, F, ..., k, 1]
          x_rank <= sigma_rank - 1,
          _broadcast_x_not_higher_rank_than_sigma,
          # x_centered == [B, C, D, E, F, ..., k] => [D, E, F, ..., k, B*C]
          _broadcast_x_higher_rank_than_sigma)

      x_whitened_broadcast = linalg_ops.batch_matrix_triangular_solve(
          self._sigma_chol, x_centered_broadcast)

      # Reshape x_whitened_broadcast back to x_whitened
      x_whitened = control_flow_ops.cond(
          x_rank <= sigma_rank - 1,
          lambda: array_ops.reshape(x_whitened_broadcast, x_shape),
          _unbroadcast_x_higher_rank_than_sigma)

      x_whitened = array_ops.expand_dims(x_whitened, -1)
      # Reshape x_whitened to contain row vectors
      # Returns a batchwise scalar
      x_whitened_norm = math_ops.batch_matmul(
          x_whitened, x_whitened, adj_x=True)
      x_whitened_norm = control_flow_ops.cond(
          x_rank <= sigma_rank - 1,
          lambda: array_ops.squeeze(x_whitened_norm, [-2, -1]),
          lambda: array_ops.squeeze(x_whitened_norm, [-1]))

      log_two_pi = constant_op.constant(math.log(2 * math.pi), dtype=self.dtype)
      k = math_ops.cast(self._k, self.dtype)
      log_pdf_value = (
          -math_ops.log(self._sigma_det) - k * log_two_pi - x_whitened_norm) / 2
      final_shaped_value = control_flow_ops.cond(
          x_rank <= sigma_rank - 1,
          lambda: log_pdf_value,
          lambda: array_ops.squeeze(log_pdf_value, [-1]))

      output_static_shape = x_centered.get_shape()[:-1]
      final_shaped_value.set_shape(output_static_shape)
      return final_shaped_value

  def pdf(self, x, name=None):
    """The PDF of observations `x` under these Multivariate Normals.

    Args:
      x: tensor of dtype `dtype`, must be broadcastable with `mu` and `sigma`.
      name: The name to give this op.

    Returns:
      pdf: tensor of dtype `dtype`, the pdf values of `x`.
    """
    with ops.op_scope(
        [self._mu, self._sigma_chol, x], name, "MultivariateNormalPdf"):
      return math_ops.exp(self.log_pdf(x))

  def entropy(self, name=None):
    """The entropies of these Multivariate Normals.

    Args:
      name: The name to give this op.

    Returns:
      entropy: tensor of dtype `dtype`, the entropies.
    """
    with ops.op_scope(
        [self._mu, self._sigma_chol], name, "MultivariateNormalEntropy"):
      one_plus_log_two_pi = constant_op.constant(
          1 + math.log(2 * math.pi), dtype=self.dtype)

      # Use broadcasting rules to calculate the full broadcast sigma.
      k = math_ops.cast(self._k, dtype=self.dtype)
      entropy_value = (
          k * one_plus_log_two_pi + math_ops.log(self._sigma_det)) / 2
      entropy_value.set_shape(self._sigma_det.get_shape())
      return entropy_value

  def sample(self, n, seed=None, name=None):
    """Sample `n` observations from the Multivariate Normal Distributions.

    Args:
      n: `Scalar`, type int32, the number of observations to sample.
      seed: Python integer, the random seed.
      name: The name to give this op.

    Returns:
      samples: `[n, ...]`, a `Tensor` of `n` samples for each
        of the distributions determined by broadcasting the hyperparameters.
    """
    with ops.op_scope(
        [self._mu, self._sigma_chol, n], name, "MultivariateNormalSample"):
      # TODO(ebrevdo): Is there a better way to get broadcast_shape?
      broadcast_shape = self.mu.get_shape()
      n = ops.convert_to_tensor(n)
      sigma_shape_left = array_ops.slice(
          array_ops.shape(self._sigma_chol),
          [0], array_ops.pack([array_ops.rank(self._sigma_chol) - 2]))

      k_n = array_ops.pack([self._k, n])
      shape = array_ops.concat(0, [sigma_shape_left, k_n])
      white_samples = random_ops.random_normal(
          shape=shape, mean=0, stddev=1, dtype=self._mu.dtype, seed=seed)

      correlated_samples = math_ops.batch_matmul(
          self._sigma_chol, white_samples)

      # Move the last dimension to the front
      perm = array_ops.concat(
          0,
          (array_ops.pack([array_ops.rank(correlated_samples) - 1]),
           math_ops.range(0, array_ops.rank(correlated_samples) - 1)))

      # TODO(ebrevdo): Once we get a proper tensor contraction op,
      # perform the inner product using that instead of batch_matmul
      # and this slow transpose can go away!
      correlated_samples = array_ops.transpose(correlated_samples, perm)

      samples = correlated_samples + self.mu

      # Provide some hints to shape inference
      n_val = tensor_util.constant_value(n)
      final_shape = tensor_shape.vector(n_val).concatenate(broadcast_shape)
      samples.set_shape(final_shape)

      return samples
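The docstring identities `log_pdf` relies on above (determinant from the Cholesky diagonal, whitening via a triangular solve) are easy to cross-check in plain numpy/scipy. An illustrative check, not code from the commit:

    import numpy as np
    from scipy import stats

    mu = np.array([-3.0, 3.0])
    sigma = np.array([[1.0, 0.5], [0.5, 1.0]])
    x = np.array([-2.5, 2.5])

    chol = np.linalg.cholesky(sigma)            # sigma = chol . chol^T
    det = np.prod(np.diag(chol)) ** 2           # |det(sigma)| = prod(diag(chol))^2
    x_whitened = np.linalg.solve(chol, x - mu)  # tri_solve(sigma_chol, x - mu)
    quad = np.dot(x_whitened, x_whitened)       # (x-mu)^T . sigma^{-1} . (x-mu)
    k = len(mu)
    log_pdf = -0.5 * (np.log(det) + k * np.log(2 * np.pi) + quad)
    assert np.isclose(log_pdf, stats.multivariate_normal(mu, sigma).logpdf(x))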
@@ -234,8 +234,8 @@ class BatchMatMul : public OpKernel {
                                         in1.shape().DebugString()));
     const int ndims = in0.dims();
     OP_REQUIRES(
-        ctx, ndims >= 3,
-        errors::InvalidArgument("In[0] and In[1] ndims must be >= 3: ", ndims));
+        ctx, ndims >= 2,
+        errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ", ndims));
     TensorShape out_shape;
     for (int i = 0; i < ndims - 2; ++i) {
       OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i),
@@ -245,7 +245,7 @@ class BatchMatMul : public OpKernel {
                                          in1.shape().DebugString()));
       out_shape.AddDim(in0.dim_size(i));
     }
-    auto n = out_shape.num_elements();
+    auto n = (ndims == 2) ? 1 : out_shape.num_elements();
     auto d0 = in0.dim_size(ndims - 2);
     auto d1 = in0.dim_size(ndims - 1);
     Tensor in0_reshaped;
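With the `ndims >= 2` relaxation, the batch kernel degenerates cleanly to a single matrix product on rank-2 inputs (the batch loop runs zero times and `n` becomes 1). A small sketch of what this newly permits, assuming the relaxed constraint is the only gate:

    import tensorflow as tf

    a = tf.ones([2, 3])
    b = tf.ones([3, 4])
    with tf.Session():
      # Rank-2 inputs were previously rejected with "ndims must be >= 3";
      # batch_matmul now behaves like plain tf.matmul here.
      print(tf.batch_matmul(a, b).eval().shape)  # (2, 4)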
@@ -25,19 +25,8 @@ import tensorflow as tf
 
 class CholeskyOpTest(tf.test.TestCase):
 
-  def _verifyCholesky(self, x):
-    with self.test_session() as sess:
-      # Verify that LL^T == x.
-      if x.ndim == 2:
-        chol = tf.cholesky(x)
-        verification = tf.matmul(chol,
-                                 chol,
-                                 transpose_a=False,
-                                 transpose_b=True)
-      else:
-        chol = tf.batch_cholesky(x)
-        verification = tf.batch_matmul(chol, chol, adj_x=False, adj_y=True)
-      chol_np, verification_np = sess.run([chol, verification])
+  def _verifyCholeskyBase(self, sess, x, chol, verification):
+    chol_np, verification_np = sess.run([chol, verification])
     self.assertAllClose(x, verification_np)
     self.assertShapeEqual(x, chol)
     # Check that the cholesky is lower triangular, and has positive diagonal
@@ -49,6 +38,20 @@ class CholeskyOpTest(tf.test.TestCase):
       self.assertAllClose(chol_matrix, np.tril(chol_matrix))
       self.assertTrue((np.diag(chol_matrix) > 0.0).all())
 
+  def _verifyCholesky(self, x):
+    # Verify that LL^T == x.
+    with self.test_session() as sess:
+      # Check the batch version, which works for ndim >= 2.
+      chol = tf.batch_cholesky(x)
+      verification = tf.batch_matmul(chol, chol, adj_x=False, adj_y=True)
+      self._verifyCholeskyBase(sess, x, chol, verification)
+
+      if x.ndim == 2:  # Check the simple form of cholesky
+        chol = tf.cholesky(x)
+        verification = tf.matmul(
+            chol, chol, transpose_a=False, transpose_b=True)
+        self._verifyCholeskyBase(sess, x, chol, verification)
+
   def testBasic(self):
     self._verifyCholesky(np.array([[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]))
@@ -24,13 +24,8 @@ import tensorflow as tf
 
 class DeterminantOpTest(tf.test.TestCase):
 
-  def _compareDeterminant(self, matrix_x):
-    with self.test_session():
-      if matrix_x.ndim == 2:
-        tf_ans = tf.matrix_determinant(matrix_x)
-      else:
-        tf_ans = tf.batch_matrix_determinant(matrix_x)
-      out = tf_ans.eval()
+  def _compareDeterminantBase(self, matrix_x, tf_ans):
+    out = tf_ans.eval()
     shape = matrix_x.shape
     if shape[-1] == 0 and shape[-2] == 0:
       np_ans = np.ones(shape[:-2]).astype(matrix_x.dtype)
@@ -39,6 +34,15 @@ class DeterminantOpTest(tf.test.TestCase):
     self.assertAllClose(np_ans, out)
     self.assertShapeEqual(np_ans, tf_ans)
 
+  def _compareDeterminant(self, matrix_x):
+    with self.test_session():
+      # Check the batch version, which should work for ndim >= 2
+      self._compareDeterminantBase(
+          matrix_x, tf.batch_matrix_determinant(matrix_x))
+      if matrix_x.ndim == 2:
+        # Check the simple version
+        self._compareDeterminantBase(matrix_x, tf.matrix_determinant(matrix_x))
+
   def testBasic(self):
     # 2x2 matrices
     self._compareDeterminant(np.array([[2., 3.], [3., 4.]]).astype(np.float32))
@@ -67,11 +67,13 @@ class MatrixSolveLsOpTest(tf.test.TestCase):
     np_ans, _, _, _ = np.linalg.lstsq(a, b)
     for fast in [True, False]:
       with self.test_session():
-        tf_ans = tf.matrix_solve_ls(a, b, fast=fast).eval()
-        self.assertEqual(np_ans.shape, tf_ans.shape)
+        tf_ans = tf.matrix_solve_ls(a, b, fast=fast)
+        ans = tf_ans.eval()
+        self.assertEqual(np_ans.shape, tf_ans.get_shape())
+        self.assertEqual(np_ans.shape, ans.shape)
 
       # Check residual norm.
-      tf_r = b - BatchMatMul(a, tf_ans)
+      tf_r = b - BatchMatMul(a, ans)
       tf_r_norm = np.sum(tf_r * tf_r)
       np_r = b - BatchMatMul(a, np_ans)
       np_r_norm = np.sum(np_r * np_r)
@@ -83,7 +85,7 @@ class MatrixSolveLsOpTest(tf.test.TestCase):
       # slow path, because Eigen does not return a minimum norm solution.
       # TODO(rmlarsen): Enable this check for all paths if/when we fix
       # Eigen's solver.
-      self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5)
+      self.assertAllClose(np_ans, ans, atol=1e-5, rtol=1e-5)
 
   def _verifySolveBatch(self, x, y):
     # Since numpy.linalg.lsqr does not support batch solves, as opposed
@@ -122,20 +124,23 @@ class MatrixSolveLsOpTest(tf.test.TestCase):
       b = y.astype(np_type)
       np_ans = BatchRegularizedLeastSquares(a, b, l2_regularizer)
       with self.test_session():
-        tf_ans = tf.matrix_solve_ls(a,
-                                    b,
-                                    l2_regularizer=l2_regularizer,
-                                    fast=True).eval()
-      self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5)
+        # Test with the batch version of matrix_solve_ls on regular matrices
+        tf_ans = tf.batch_matrix_solve_ls(
+            a, b, l2_regularizer=l2_regularizer, fast=True).eval()
+        self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5)
+
+        # Test with the simple matrix_solve_ls on regular matrices
+        tf_ans = tf.matrix_solve_ls(
+            a, b, l2_regularizer=l2_regularizer, fast=True).eval()
+        self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5)
 
       # Test with a 2x3 batch of matrices.
       a = np.tile(x.astype(np_type), [2, 3, 1, 1])
       b = np.tile(y.astype(np_type), [2, 3, 1, 1])
       np_ans = BatchRegularizedLeastSquares(a, b, l2_regularizer)
       with self.test_session():
-        tf_ans = tf.batch_matrix_solve_ls(a,
-                                          b,
-                                          l2_regularizer=l2_regularizer,
-                                          fast=True).eval()
+        tf_ans = tf.batch_matrix_solve_ls(
+            a, b, l2_regularizer=l2_regularizer, fast=True).eval()
       self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5)
 
   def testSquare(self):
@@ -37,15 +37,23 @@ class MatrixSolveOpTest(tf.test.TestCase):
       a = np.tile(a, batch_dims + [1, 1])
       a_np = np.tile(a_np, batch_dims + [1, 1])
       b = np.tile(b, batch_dims + [1, 1])
-    with self.test_session():
-      if a.ndim == 2:
-        tf_ans = tf.matrix_solve(a, b, adjoint=adjoint)
-      else:
-        tf_ans = tf.batch_matrix_solve(a, b, adjoint=adjoint)
-      out = tf_ans.eval()
 
     np_ans = np.linalg.solve(a_np, b)
-    self.assertEqual(np_ans.shape, out.shape)
-    self.assertAllClose(np_ans, out)
+    with self.test_session():
+      # Test the batch version, which works for ndim >= 2
+      tf_ans = tf.batch_matrix_solve(a, b, adjoint=adjoint)
+      out = tf_ans.eval()
+      self.assertEqual(tf_ans.get_shape(), out.shape)
+      self.assertEqual(np_ans.shape, out.shape)
+      self.assertAllClose(np_ans, out)
+
+      if a.ndim == 2:
+        # Test the simple version
+        tf_ans = tf.matrix_solve(a, b, adjoint=adjoint)
+        out = tf_ans.eval()
+        self.assertEqual(out.shape, tf_ans.get_shape())
+        self.assertEqual(np_ans.shape, out.shape)
+        self.assertAllClose(np_ans, out)
 
   def testSolve(self):
     # 2x2 matrices, 2x1 right-hand side.
@@ -51,20 +51,27 @@ class MatrixTriangularSolveOpTest(tf.test.TestCase):
       a = np.tile(a, batch_dims + [1, 1])
       a_np = np.tile(a_np, batch_dims + [1, 1])
       b = np.tile(b, batch_dims + [1, 1])
 
     with self.test_session():
+      # Test the batch version, which works for ndim >= 2
+      tf_ans = tf.batch_matrix_triangular_solve(
+          a, b, lower=lower, adjoint=adjoint)
+      out = tf_ans.eval()
+
+      np_ans = np.linalg.solve(a_np, b)
+
+      self.assertEqual(np_ans.shape, tf_ans.get_shape())
+      self.assertEqual(np_ans.shape, out.shape)
+      self.assertAllClose(np_ans, out)
+
       if a.ndim == 2:
-        tf_ans = tf.matrix_triangular_solve(a,
-                                            b,
-                                            lower=lower,
-                                            adjoint=adjoint).eval()
-      else:
-        tf_ans = tf.batch_matrix_triangular_solve(a,
-                                                  b,
-                                                  lower=lower,
-                                                  adjoint=adjoint).eval()
-      np_ans = np.linalg.solve(a_np, b)
-      self.assertEqual(np_ans.shape, tf_ans.shape)
-      self.assertAllClose(np_ans, tf_ans)
+        # Test the simple version
+        tf_ans = tf.matrix_triangular_solve(
+            a, b, lower=lower, adjoint=adjoint)
+        out = tf_ans.eval()
+        self.assertEqual(np_ans.shape, tf_ans.get_shape())
+        self.assertEqual(np_ans.shape, out.shape)
+        self.assertAllClose(np_ans, out)
 
   def testSolve(self):
     # 2x2 matrices, single right-hand side.
@@ -71,14 +71,28 @@ class SelfAdjointEigOpTest(tf.test.TestCase):
       for i in xrange(dlist[0]):
        self._testEigs(x[i], d, tf_out[i])
 
+  def _compareBatchSelfAdjointEigRank2(self, x, use_gpu=False):
+    with self.test_session() as sess:
+      tf_eig = tf.batch_self_adjoint_eig(tf.constant(x))
+      tf_out = sess.run([tf_eig])[0]
+      dlist = x.shape
+      d = dlist[-2]
+
+      self.assertEqual(len(tf_eig.get_shape()), 2)
+      self.assertEqual([d+1, d], tf_eig.get_shape().dims[-2:])
+      self._testEigs(x, d, tf_out)
+
   def testBasic(self):
     self._compareSelfAdjointEig(
         np.array([[3., 0., 1.], [0., 2., -2.], [1., -2., 3.]]))
 
   def testBatch(self):
     simple_array = np.array([[[1., 0.], [0., 5.]]])  # shape (1, 2, 2)
+    simple_array_2d = simple_array[0]  # shape (2, 2)
     self._compareBatchSelfAdjointEigRank3(simple_array)
-    self._compareBatchSelfAdjointEigRank3(np.vstack((simple_array, simple_array)))
+    self._compareBatchSelfAdjointEigRank3(
+        np.vstack((simple_array, simple_array)))
+    self._compareBatchSelfAdjointEigRank2(simple_array_2d)
     odd_sized_array = np.array([[[3., 0., 1.], [0., 2., -2.], [1., -2., 3.]]])
     self._compareBatchSelfAdjointEigRank3(
         np.vstack((odd_sized_array, odd_sized_array)))
@@ -39,7 +39,7 @@ def _UnchangedSquare(op):
 @ops.RegisterShape("BatchCholesky")
 @ops.RegisterShape("BatchMatrixInverse")
 def _BatchUnchangedSquare(op):
-  input_shape = op.inputs[0].get_shape().with_rank_at_least(3)
+  input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
   # The matrices in the batch must be square.
   input_shape[-1].assert_is_compatible_with(input_shape[-2])
   return [input_shape]
@@ -61,7 +61,7 @@ def _MatrixDeterminantShape(op):
 
 @ops.RegisterShape("BatchMatrixDeterminant")
 def _BatchMatrixDeterminantShape(op):
-  input_shape = op.inputs[0].get_shape().with_rank_at_least(3)
+  input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
   # The matrices in the batch must be square.
   input_shape[-1].assert_is_compatible_with(input_shape[-2])
   if input_shape.ndims is not None:
@@ -82,7 +82,7 @@ def _SelfAdjointEigShape(op):
 
 @ops.RegisterShape("BatchSelfAdjointEig")
 def _BatchSelfAdjointEigShape(op):
-  input_shape = op.inputs[0].get_shape().with_rank_at_least(3)
+  input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
   # The matrices in the batch must be square.
   input_shape[-1].assert_is_compatible_with(input_shape[-2])
   dlist = input_shape.dims
@@ -106,8 +106,8 @@ def _SquareMatrixSolveShape(op):
 @ops.RegisterShape("BatchMatrixSolve")
 @ops.RegisterShape("BatchMatrixTriangularSolve")
 def _BatchSquareMatrixSolveShape(op):
-  lhs_shape = op.inputs[0].get_shape().with_rank_at_least(3)
-  rhs_shape = op.inputs[1].get_shape().with_rank_at_least(3)
+  lhs_shape = op.inputs[0].get_shape().with_rank_at_least(2)
+  rhs_shape = op.inputs[1].get_shape().with_rank_at_least(2)
   # The matrices must be square.
   lhs_shape[-1].assert_is_compatible_with(lhs_shape[-2])
   # The matrices and right-hand sides in the batch must have the same number of
@@ -127,8 +127,8 @@ def _MatrixSolveLsShape(op):
 
 @ops.RegisterShape("BatchMatrixSolveLs")
 def _BatchMatrixSolveLsShape(op):
-  lhs_shape = op.inputs[0].get_shape().with_rank_at_least(3)
-  rhs_shape = op.inputs[1].get_shape().with_rank_at_least(3)
+  lhs_shape = op.inputs[0].get_shape().with_rank_at_least(2)
+  rhs_shape = op.inputs[1].get_shape().with_rank_at_least(2)
   # The matrices and right-hand sides in the batch must have the same number of
   # rows.
   lhs_shape[-2].assert_is_compatible_with(rhs_shape[-2])
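The relaxed shape functions mean the batch ops also accept single matrices at graph-construction time, so callers (like mvn.py above) can use the batch ops uniformly whether or not a batch dimension is present. An illustrative sketch:

    import tensorflow as tf

    x2 = tf.placeholder(tf.float32, shape=[3, 3])     # one matrix
    x3 = tf.placeholder(tf.float32, shape=[7, 3, 3])  # a batch of matrices
    print(tf.batch_cholesky(x2).get_shape())  # (3, 3): rank 2 is now accepted
    print(tf.batch_cholesky(x3).get_shape())  # (7, 3, 3)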