diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 5feac79ecb0..bfa31dbe1cd 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -32,6 +32,19 @@ cuda_py_tests(
     srcs = ["python/kernel_tests/gaussian_test.py"],
     additional_deps = [
         ":distributions_py",
+        "//third_party/py/scipy",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
+cuda_py_tests(
+    name = "mvn_test",
+    size = "small",
+    srcs = ["python/kernel_tests/mvn_test.py"],
+    additional_deps = [
+        ":distributions_py",
+        "//third_party/py/scipy",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:platform_test",
     ],
@@ -43,6 +56,7 @@ cuda_py_tests(
     srcs = ["python/kernel_tests/gaussian_conjugate_posteriors_test.py"],
     additional_deps = [
         ":distributions_py",
+        "//third_party/py/scipy",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:platform_test",
     ],
diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py
index 2f9b8fcafb1..54607a7379e 100644
--- a/tensorflow/contrib/distributions/__init__.py
+++ b/tensorflow/contrib/distributions/__init__.py
@@ -21,8 +21,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-# pylint: disable=unused-import,wildcard-import, line-too-long
+# pylint: disable=unused-import,wildcard-import,line-too-long
 from tensorflow.contrib.distributions.python.ops import gaussian_conjugate_posteriors
 from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import *
 from tensorflow.contrib.distributions.python.ops.gaussian import *
-# from tensorflow.contrib.distributions.python.ops.dirichlet import *  # pylint: disable=line-too-long
+from tensorflow.contrib.distributions.python.ops.mvn import *
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_test.py
new file mode 100644
index 00000000000..8b249c22362
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_test.py
@@ -0,0 +1,252 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for MultivariateNormal."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from scipy import stats
+import tensorflow as tf
+
+
+class MultivariateNormalTest(tf.test.TestCase):
+
+  def testNonmatchingMuSigmaFails(self):
+    with tf.Session():
+      mvn = tf.contrib.distributions.MultivariateNormal(
+          mu=[1.0, 2.0],
+          sigma=[[[1.0, 0.0],
+                  [0.0, 1.0]],
+                 [[1.0, 0.0],
+                  [0.0, 1.0]]])
+      with self.assertRaisesOpError(
+          r"Rank of mu should be one less than rank of sigma"):
+        mvn.mean.eval()
+
+      mvn = tf.contrib.distributions.MultivariateNormal(
+          mu=[[1.0], [2.0]],
+          sigma=[[[1.0, 0.0],
+                  [0.0, 1.0]],
+                 [[1.0, 0.0],
+                  [0.0, 1.0]]])
+      with self.assertRaisesOpError(
+          r"mu.shape and sigma.shape\[\:-1\] must match"):
+        mvn.mean.eval()
+
+  def testNotPositiveDefiniteSigmaFails(self):
+    with tf.Session():
+      mvn = tf.contrib.distributions.MultivariateNormal(
+          mu=[[1.0, 2.0], [1.0, 2.0]],
+          sigma=[[[1.0, 0.0],
+                  [0.0, 1.0]],
+                 [[1.0, 1.0],
+                  [1.0, 1.0]]])
+      with self.assertRaisesOpError(
+          r"LLT decomposition was not successful."):
+        mvn.mean.eval()
+      mvn = tf.contrib.distributions.MultivariateNormal(
+          mu=[[1.0, 2.0], [1.0, 2.0]],
+          sigma=[[[1.0, 0.0],
+                  [0.0, 1.0]],
+                 [[-1.0, 0.0],
+                  [0.0, 1.0]]])
+      with self.assertRaisesOpError(
+          r"LLT decomposition was not successful."):
+        mvn.mean.eval()
+      mvn = tf.contrib.distributions.MultivariateNormal(
+          mu=[[1.0, 2.0], [1.0, 2.0]],
+          sigma_chol=[[[1.0, 0.0],
+                       [0.0, 1.0]],
+                      [[-1.0, 0.0],
+                       [0.0, 1.0]]])
+      with self.assertRaisesOpError(
+          r"sigma_chol is not positive definite."):
+        mvn.mean.eval()
+
+  def testLogPDFScalar(self):
+    with tf.Session():
+      mu_v = np.array([-3.0, 3.0], dtype=np.float32)
+      mu = tf.constant(mu_v)
+      sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
+      sigma = tf.constant(sigma_v)
+      x = np.array([-2.5, 2.5], dtype=np.float32)
+      mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
+
+      log_pdf = mvn.log_pdf(x)
+
+      scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
+      expected_log_pdf = scipy_mvn.logpdf(x)
+      expected_pdf = scipy_mvn.pdf(x)
+      self.assertAllClose(expected_log_pdf, log_pdf.eval())
+
+      pdf = mvn.pdf(x)
+      self.assertAllClose(expected_pdf, pdf.eval())
+
+  def testLogPDFScalarSigmaHalf(self):
+    with tf.Session():
+      mu_v = np.array([-3.0, 3.0, 1.0], dtype=np.float32)
+      mu = tf.constant(mu_v)
+      sigma_v = np.array([[1.0, 0.1, 0.2],
+                          [0.1, 2.0, 0.05],
+                          [0.2, 0.05, 3.0]], dtype=np.float32)
+      sigma_chol_v = np.linalg.cholesky(sigma_v)
+      sigma_chol = tf.constant(sigma_chol_v)
+      x = np.array([-2.5, 2.5, 1.0], dtype=np.float32)
+      mvn = tf.contrib.distributions.MultivariateNormal(
+          mu=mu, sigma_chol=sigma_chol)
+
+      log_pdf = mvn.log_pdf(x)
+      sigma = mvn.sigma
+
+      scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
+      expected_log_pdf = scipy_mvn.logpdf(x)
+      expected_pdf = scipy_mvn.pdf(x)
+      self.assertEqual(sigma.get_shape(), (3, 3))
+      self.assertAllClose(sigma_v, sigma.eval())
+      self.assertAllClose(expected_log_pdf, log_pdf.eval())
+
+      pdf = mvn.pdf(x)
+      self.assertAllClose(expected_pdf, pdf.eval())
+
+  def testLogPDF(self):
+    with tf.Session():
+      mu_v = np.array([-3.0, 3.0], dtype=np.float32)
+      mu = tf.constant(mu_v)
+      sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
+      sigma = tf.constant(sigma_v)
+      x = np.array([[-2.5, 2.5], [4.0, 0.0], [-1.0, 2.0]], dtype=np.float32)
+      mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
+
+      log_pdf = mvn.log_pdf(x)
+
+      scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
+      expected_log_pdf = scipy_mvn.logpdf(x)
+      expected_pdf = scipy_mvn.pdf(x)
+      self.assertEqual(log_pdf.get_shape(), (3,))
+      self.assertAllClose(expected_log_pdf, log_pdf.eval())
+
+      pdf = mvn.pdf(x)
+      self.assertAllClose(expected_pdf, pdf.eval())
+
+  def testLogPDFMatchingDimension(self):
+    with tf.Session():
+      mu_v = np.array([-3.0, 3.0], dtype=np.float32)
+      mu = tf.constant(np.vstack(3 * [mu_v]))
+      sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
+      sigma = tf.constant(np.vstack(3 * [sigma_v[np.newaxis, :]]))
+      x = np.array([[-2.5, 2.5], [4.0, 0.0], [-1.0, 2.0]], dtype=np.float32)
+      mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
+
+      log_pdf = mvn.log_pdf(x)
+
+      scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
+      expected_log_pdf = scipy_mvn.logpdf(x)
+      expected_pdf = scipy_mvn.pdf(x)
+      self.assertEqual(log_pdf.get_shape(), (3,))
+      self.assertAllClose(expected_log_pdf, log_pdf.eval())
+
+      pdf = mvn.pdf(x)
+      self.assertAllClose(expected_pdf, pdf.eval())
+
+  def testLogPDFMultidimensional(self):
+    with tf.Session():
+      mu_v = np.array([-3.0, 3.0], dtype=np.float32)
+      mu = tf.constant(np.vstack(15 * [mu_v]).reshape(3, 5, 2))
+      sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
+      sigma = tf.constant(
+          np.vstack(15 * [sigma_v[np.newaxis, :]]).reshape(3, 5, 2, 2))
+      x = np.array([-2.5, 2.5], dtype=np.float32)
+      mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
+
+      log_pdf = mvn.log_pdf(x)
+
+      scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
+      expected_log_pdf = np.vstack(15 * [scipy_mvn.logpdf(x)]).reshape(3, 5)
+      expected_pdf = np.vstack(15 * [scipy_mvn.pdf(x)]).reshape(3, 5)
+      self.assertEqual(log_pdf.get_shape(), (3, 5))
+      self.assertAllClose(expected_log_pdf, log_pdf.eval())
+
+      pdf = mvn.pdf(x)
+      self.assertAllClose(expected_pdf, pdf.eval())
+
+  def testEntropy(self):
+    with tf.Session():
+      mu_v = np.array([-3.0, 3.0], dtype=np.float32)
+      mu = tf.constant(mu_v)
+      sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
+      sigma = tf.constant(sigma_v)
+      mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
+      entropy = mvn.entropy()
+
+      scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
+      expected_entropy = scipy_mvn.entropy()
+
+      self.assertEqual(entropy.get_shape(), ())
+      self.assertAllClose(expected_entropy, entropy.eval())
+
+  def testEntropyMultidimensional(self):
+    with tf.Session():
+      mu_v = np.array([-3.0, 3.0], dtype=np.float32)
+      mu = tf.constant(np.vstack(15 * [mu_v]).reshape(3, 5, 2))
+      sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
+      sigma = tf.constant(
+          np.vstack(15 * [sigma_v[np.newaxis, :]]).reshape(3, 5, 2, 2))
+      mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
+      entropy = mvn.entropy()
+
+      scipy_mvn = stats.multivariate_normal(mean=mu_v, cov=sigma_v)
+      expected_entropy = np.vstack(15 * [scipy_mvn.entropy()]).reshape(3, 5)
+
+      self.assertEqual(entropy.get_shape(), (3, 5))
+      self.assertAllClose(expected_entropy, entropy.eval())
+
+  def testSample(self):
+    with tf.Session():
+      mu_v = np.array([-3.0, 3.0], dtype=np.float32)
+      mu = tf.constant(mu_v)
+      sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
+      sigma = tf.constant(sigma_v)
+      n = tf.constant(100000)
+      mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
+      samples = mvn.sample(n, seed=137)
+      sample_values = samples.eval()
+      self.assertEqual(samples.get_shape(), (100000, 2))
+      self.assertAllClose(sample_values.mean(axis=0), mu_v, atol=1e-2)
+      self.assertAllClose(np.cov(sample_values, rowvar=0), sigma_v, atol=1e-1)
+
+  def testSampleMultiDimensional(self):
+    with tf.Session():
+      mu_v = np.array([-3.0, 3.0], dtype=np.float32)
+      mu = tf.constant(np.vstack(15 * [mu_v]).reshape(3, 5, 2))
+      sigma_v = np.array([[1.0, 0.5], [0.5, 1.0]], dtype=np.float32)
+      sigma = tf.constant(
+          np.vstack(15 * [sigma_v[np.newaxis, :]]).reshape(3, 5, 2, 2))
+      n = tf.constant(100000)
+      mvn = tf.contrib.distributions.MultivariateNormal(mu=mu, sigma=sigma)
+      samples = mvn.sample(n, seed=137)
+      sample_values = samples.eval()
+      self.assertEqual(samples.get_shape(), (100000, 3, 5, 2))
+      sample_values = sample_values.reshape(100000, 15, 2)
+      for i in range(15):
+        self.assertAllClose(
+            sample_values[:, i, :].mean(axis=0), mu_v, atol=1e-2)
+        self.assertAllClose(
+            np.cov(sample_values[:, i, :], rowvar=0), sigma_v, atol=1e-1)
+
+
+if __name__ == "__main__":
+  tf.test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/gaussian.py b/tensorflow/contrib/distributions/python/ops/gaussian.py
index b9dad502983..cbb98624d97 100644
--- a/tensorflow/contrib/distributions/python/ops/gaussian.py
+++ b/tensorflow/contrib/distributions/python/ops/gaussian.py
@@ -88,7 +88,7 @@ class Gaussian(object):
 
   @property
   def mean(self):
-    return self._mu
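+    # Broadcast mu against sigma so the mean has the full batch shape.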
+    return self._mu * array_ops.ones_like(self._sigma)
 
   def log_pdf(self, x, name=None):
     """Log pdf of observations in `x` under these Gaussian distribution(s).
@@ -170,7 +170,7 @@ class Gaussian(object):
       return 0.5 * math_ops.log(two_pi_e1 * math_ops.square(sigma))
 
   def sample(self, n, seed=None, name=None):
-    """Sample `n` observations the Gaussian Distributions.
+    """Sample `n` observations from the Gaussian Distributions.
 
     Args:
       n: `Scalar`, type int32, the number of observations to sample.
@@ -185,7 +185,7 @@ class Gaussian(object):
       broadcast_shape = (self._mu + self._sigma).get_shape()
       n = ops.convert_to_tensor(n)
       shape = array_ops.concat(
-          0, [array_ops.pack([n]), array_ops.shape(self._mu)])
+          0, [array_ops.pack([n]), array_ops.shape(self.mean)])
       sampled = random_ops.random_normal(
           shape=shape, mean=0, stddev=1, dtype=self._mu.dtype, seed=seed)
 
diff --git a/tensorflow/contrib/distributions/python/ops/mvn.py b/tensorflow/contrib/distributions/python/ops/mvn.py
new file mode 100644
index 00000000000..4ddd577d46b
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/mvn.py
@@ -0,0 +1,429 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The Multivariate Normal distribution class.
+
+@@MultivariateNormal
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util  # pylint: disable=line-too-long
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+
+
+def _assert_compatible_shapes(mu, sigma):
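+  """Returns a grouped assertion op checking that `mu` and `sigma` match.
+
+  Verifies that rank(mu) + 1 == rank(sigma), that the last two dimensions
+  of `sigma` are equal, and that `mu.shape == sigma.shape[:-1]`.
+  """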
+  r_mu = array_ops.rank(mu)
+  r_sigma = array_ops.rank(sigma)
+  sigma_shape = array_ops.shape(sigma)
+  sigma_rank = array_ops.rank(sigma)
+  mu_shape = array_ops.shape(mu)
+  return control_flow_ops.group(
+      logging_ops.Assert(
+          math_ops.equal(r_mu + 1, r_sigma),
+          ["Rank of mu should be one less than rank of sigma, but saw: ",
+           r_mu, " vs. ", r_sigma]),
+      logging_ops.Assert(
+          math_ops.equal(
+              array_ops.gather(sigma_shape, sigma_rank - 2),
+              array_ops.gather(sigma_shape, sigma_rank - 1)),
+          ["Last two dimensions of sigma (%s) must be equal: " % sigma.name,
+           sigma_shape]),
+      logging_ops.Assert(
+          math_ops.reduce_all(math_ops.equal(
+              mu_shape,
+              array_ops.slice(
+                  sigma_shape, [0], array_ops.pack([sigma_rank - 1])))),
+          ["mu.shape and sigma.shape[:-1] must match, but saw: ",
+           mu_shape, " vs. ", sigma_shape]))
+
+
+def _assert_batch_positive_definite(sigma_chol):
+  """Add assertions checking that the sigmas are all Positive Definite.
+
+  Given `sigma_chol == cholesky(sigma)`, it is sufficient to check that
+  `all(diag(sigma_chol) > 0)`.  This is because to check that a matrix is PD,
+  it is sufficient that its cholesky factorization is PD, and to check that a
+  triangular matrix is PD, it is sufficient to check that its diagonal
+  entries are positive.
+
+  Args:
+    sigma_chol: N-D.  The lower triangular cholesky decomposition of `sigma`.
+
+  Returns:
+    An assertion op to use with `control_dependencies`, verifying that
+    `sigma_chol` is positive definite.
+  """
+  sigma_batch_diag = array_ops.batch_matrix_diag_part(sigma_chol)
+  return logging_ops.Assert(
+      math_ops.reduce_all(sigma_batch_diag > 0),
+      ["sigma_chol is not positive definite.  batched diagonals: ",
+       sigma_batch_diag, " shaped: ", array_ops.shape(sigma_batch_diag)])
+
+
+def _determinant_from_sigma_chol(sigma_chol):
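+  """Computes the batch determinants of `sigma` from its cholesky factors.
+
+  Uses `det(sigma) = det(sigma_chol)^2`, where the determinant of the
+  triangular factor is the product of its diagonal entries.
+  """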
+  det_last_dim = array_ops.rank(sigma_chol) - 2
+  sigma_batch_diag = array_ops.batch_matrix_diag_part(sigma_chol)
+  det = math_ops.square(math_ops.reduce_prod(
+      sigma_batch_diag, reduction_indices=det_last_dim))
+  det.set_shape(sigma_chol.get_shape()[:-2])
+  return det
+
+
+class MultivariateNormal(object):
+  """The Multivariate Normal distribution on `R^k`.
+
+  The distribution is parameterized by a mean `mu` (1-D) and a covariance
+  matrix `sigma` (2-D), or alternatively by `mu` and the lower triangular
+  cholesky factor of the covariance, `sigma_chol`.
+
+  The PDF of this distribution is:
+
+  ```
+  f(x) = (2*pi)^(-k/2) |det(sigma)|^(-1/2) exp(-1/2*(x-mu)^*.sigma^{-1}.(x-mu))
+  ```
+
+  where `.` denotes the inner product on `R^k` and `^*` denotes transpose.
+
+  Alternatively, if `sigma` is positive definite, it can be represented in terms
+  of its lower triangular cholesky factorization
+
+  ```sigma = sigma_chol . sigma_chol^*```
+
+  and the terms in the pdf above can then be computed more simply:
+
+  ```
+  |det(sigma)| = reduce_prod(diag(sigma_chol))^2
+  x_whitened = sigma^{-1/2} . (x - mu) = tri_solve(sigma_chol, x - mu)
+  (x-mu)^* .sigma^{-1} . (x-mu) = x_whitened^* . x_whitened
+  ```
+
+  where `tri_solve()` solves a triangular system of equations.
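+
+  A minimal usage sketch (the class is exported as
+  `tf.contrib.distributions.MultivariateNormal`; the values mirror the unit
+  tests in this change):
+
+  ```python
+  mvn = tf.contrib.distributions.MultivariateNormal(
+      mu=[-3.0, 3.0], sigma=[[1.0, 0.5], [0.5, 1.0]])
+  pdf = mvn.pdf([-2.5, 2.5])            # pdf of a single observation in R^2
+  samples = mvn.sample(100, seed=137)   # Tensor of shape [100, 2]
+  ```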
+  """
+
+  def __init__(self, mu, sigma=None, sigma_chol=None, name=None):
+    """Multivariate Normal distributions on `R^k`.
+
+    User must provide means `mu`, which are tensors of rank `N+1` (`N >= 0`)
+    with the last dimension having length `k`.
+
+    User must provide exactly one of `sigma` (the covariance matrices) or
+    `sigma_chol` (the cholesky decompositions of the covariance matrices).
+    `sigma` or `sigma_chol` must be of rank `N+2`.  The last two dimensions
+    must both have length `k`.  The first `N` dimensions correspond to batch
+    indices.
+
+    If `sigma_chol` is not provided, the batch cholesky factorization of `sigma`
+    is calculated for you.
+
+    The shapes of `mu` and `sigma` must match for the first `N` dimensions.
+
+    Regardless of which parameter is provided, the covariance matrices must all
+    be **positive definite** (an error is raised if one of them is not).
+
+    Args:
+      mu: (N+1)-D.  `float` or `double` tensor, the means of the distributions.
+      sigma: (N+2)-D.  (optional) `float` or `double` tensor, the covariances
+        of the distribution(s).  The first `N+1` dimensions must match
+        those of `mu`.  Must be batch-positive-definite.
+      sigma_chol: (N+2)-D.  (optional) `float` or `double` tensor, a
+        lower-triangular factorization of `sigma`
+        (`sigma = sigma_chol . sigma_chol^*`).  The first `N+1` dimensions
+        must match those of `mu`.  The tensor itself need not be batch
+        lower triangular: we ignore the upper triangular part.  However,
+        the batch diagonals must be positive (i.e., sigma_chol must be
+        batch-positive-definite).
+      name: The name to give Ops created by the initializer.
+
+    Raises:
+      ValueError: if neither sigma nor sigma_chol is provided.
+      TypeError: if mu and sigma (resp. sigma_chol) are different dtypes.
+    """
+    if (sigma is None) == (sigma_chol is None):
+      raise ValueError("Exactly one of sigma and sigma_chol must be provided")
+
+    with ops.op_scope([mu, sigma, sigma_chol], name, "MultivariateNormal"):
+      sigma_or_half = sigma_chol if sigma is None else sigma
+
+      mu = ops.convert_to_tensor(mu)
+      sigma_or_half = ops.convert_to_tensor(sigma_or_half)
+
+      contrib_tensor_util.assert_same_float_dtype((mu, sigma_or_half))
+
+      with ops.control_dependencies([
+          _assert_compatible_shapes(mu, sigma_or_half)]):
+        mu = array_ops.identity(mu, name="mu")
+
+        # Store the dimensionality of the MVNs
+        self._k = array_ops.gather(array_ops.shape(mu), array_ops.rank(mu) - 1)
+
+        if sigma_chol is not None:
+          # Ensure we only keep the lower triangular part.
+          sigma_chol = array_ops.batch_matrix_band_part(
+              sigma_chol, num_lower=-1, num_upper=0)
+          sigma_det = _determinant_from_sigma_chol(sigma_chol)
+          with ops.control_dependencies([
+              _assert_batch_positive_definite(sigma_chol)]):
+            self._sigma = math_ops.batch_matmul(
+                sigma_chol, sigma_chol, adj_y=True, name="sigma")
+            self._sigma_chol = array_ops.identity(sigma_chol, "sigma_chol")
+            self._sigma_det = array_ops.identity(sigma_det, "sigma_det")
+            self._mu = array_ops.identity(mu, "mu")
+        else:  # sigma is not None
+          sigma_chol = linalg_ops.batch_cholesky(sigma)
+          sigma_det = _determinant_from_sigma_chol(sigma_chol)
+          # batch_cholesky fails if sigma is not positive definite, so no
+          # separate assertion is needed here.
+          with ops.control_dependencies([sigma_chol]):
+            self._sigma = array_ops.identity(sigma, "sigma")
+            self._sigma_chol = array_ops.identity(sigma_chol, "sigma_chol")
+            self._sigma_det = array_ops.identity(sigma_det, "sigma_det")
+            self._mu = array_ops.identity(mu, "mu")
+
+  @property
+  def dtype(self):
+    return self._mu.dtype
+
+  @property
+  def mu(self):
+    return self._mu
+
+  @property
+  def sigma(self):
+    return self._sigma
+
+  @property
+  def mean(self):
+    return self._mu
+
+  @property
+  def sigma_det(self):
+    return self._sigma_det
+
+  def log_pdf(self, x, name=None):
+    """Log pdf of observations `x` given these Multivariate Normals.
+
+    Args:
+      x: tensor of dtype `dtype`, must be broadcastable with `mu`.
+      name: The name to give this op.
+
+    Returns:
+      log_pdf: tensor of dtype `dtype`, the log-PDFs of `x`.
+    """
+    with ops.op_scope(
+        [self._mu, self._sigma_chol, x], name, "MultivariateNormalLogPdf"):
+      x = ops.convert_to_tensor(x)
+      contrib_tensor_util.assert_same_float_dtype((self._mu, x))
+
+      x_centered = x - self.mu
+
+      x_rank = array_ops.rank(x_centered)
+      sigma_rank = array_ops.rank(self._sigma_chol)
+
+      x_rank_vec = array_ops.pack([x_rank])
+      sigma_rank_vec = array_ops.pack([sigma_rank])
+      x_shape = array_ops.shape(x_centered)
+
+      # sigma_chol is shaped [D, E, F, ..., k, k]
+      # x_centered shape is one of:
+      #   [D, E, F, ..., k], or [F, ..., k], or
+      #   [A, B, C, D, E, F, ..., k]
+      # and we need to convert x_centered to shape:
+      #   [D, E, F, ..., k, A*B*C] (or 1 if A, B, C don't exist)
+      # then transpose and reshape x_whitened back to one of the shapes:
+      #   [D, E, F, ..., k], or [1, 1, F, ..., k], or
+      #   [A, B, C, D, E, F, ..., k]
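+      # For example, if sigma_chol is shaped [3, 5, 2, 2] and x is shaped
+      # [2], then x_centered broadcasts to [3, 5, 2] and is reshaped below
+      # to [3, 5, 2, 1] before the triangular solve.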
+
+      # This helper handles the case where rank(x_centered) < rank(sigma)
+      def _broadcast_x_not_higher_rank_than_sigma():
+        return array_ops.reshape(
+            x_centered,
+            array_ops.concat(
+                # Reshape to ones(deficient x rank) + x_shape + [1]
+                0, (array_ops.ones(array_ops.pack([sigma_rank - x_rank - 1]),
+                                   dtype=x_rank.dtype),
+                    x_shape,
+                    [1])))
+
+      # These helpers handle the case where rank(x_centered) >= rank(sigma)
+      def _broadcast_x_higher_rank_than_sigma():
+        x_shape_left = array_ops.slice(
+            x_shape, [0], sigma_rank_vec - 1)
+        x_shape_right = array_ops.slice(
+            x_shape, sigma_rank_vec - 1, x_rank_vec - 1)
+        x_shape_perm = array_ops.concat(
+            0, (math_ops.range(sigma_rank - 1, x_rank),
+                math_ops.range(0, sigma_rank - 1)))
+        return array_ops.reshape(
+            # Convert to [D, E, F, ..., k, B, C]
+            array_ops.transpose(
+                x_centered, perm=x_shape_perm),
+            # Reshape to [D, E, F, ..., k, B*C]
+            array_ops.concat(
+                0, (x_shape_right,
+                    array_ops.pack([
+                        math_ops.reduce_prod(x_shape_left, 0)]))))
+
+      def _unbroadcast_x_higher_rank_than_sigma():
+        x_shape_left = array_ops.slice(
+            x_shape, [0], sigma_rank_vec - 1)
+        x_shape_right = array_ops.slice(
+            x_shape, sigma_rank_vec - 1, x_rank_vec - 1)
+        x_shape_perm = array_ops.concat(
+            0, (math_ops.range(sigma_rank - 1, x_rank),
+                math_ops.range(0, sigma_rank - 1)))
+        return array_ops.transpose(
+            # [D, E, F, ..., k, B, C] => [B, C, D, E, F, ..., k]
+            array_ops.reshape(
+                # convert to [D, E, F, ..., k, B, C]
+                x_whitened_broadcast,
+                array_ops.concat(0, (x_shape_right, x_shape_left))),
+            perm=x_shape_perm)
+
+      # Step 1: reshape x_centered
+      x_centered_broadcast = control_flow_ops.cond(
+          # x_centered == [D, E, F, ..., k] => [D, E, F, ..., k, 1]
+          # or         == [F, ..., k] => [1, 1, F, ..., k, 1]
+          x_rank <= sigma_rank - 1,
+          _broadcast_x_not_higher_rank_than_sigma,
+          # x_centered == [B, C, D, E, F, ..., k] => [D, E, F, ..., k, B*C]
+          _broadcast_x_higher_rank_than_sigma)
+
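+      # Whiten: solve sigma_chol . x_whitened = x_centered, so that
+      # x_whitened = sigma_chol^{-1} . (x - mu).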
+      x_whitened_broadcast = linalg_ops.batch_matrix_triangular_solve(
+          self._sigma_chol, x_centered_broadcast)
+
+      # Reshape x_whitened_broadcast back to x_whitened
+      x_whitened = control_flow_ops.cond(
+          x_rank <= sigma_rank - 1,
+          lambda: array_ops.reshape(x_whitened_broadcast, x_shape),
+          _unbroadcast_x_higher_rank_than_sigma)
+
+      # Expand x_whitened to batches of column vectors: [..., k, 1].
+      x_whitened = array_ops.expand_dims(x_whitened, -1)
+      # The squared norm x_whitened^* . x_whitened is a batchwise scalar.
+      x_whitened_norm = math_ops.batch_matmul(
+          x_whitened, x_whitened, adj_x=True)
+      x_whitened_norm = control_flow_ops.cond(
+          x_rank <= sigma_rank - 1,
+          lambda: array_ops.squeeze(x_whitened_norm, [-2, -1]),
+          lambda: array_ops.squeeze(x_whitened_norm, [-1]))
+
+      log_two_pi = constant_op.constant(math.log(2 * math.pi), dtype=self.dtype)
+      k = math_ops.cast(self._k, self.dtype)
+      log_pdf_value = (
+          -math_ops.log(self._sigma_det) - k * log_two_pi - x_whitened_norm) / 2
+      final_shaped_value = control_flow_ops.cond(
+          x_rank <= sigma_rank - 1,
+          lambda: log_pdf_value,
+          lambda: array_ops.squeeze(log_pdf_value, [-1]))
+
+      output_static_shape = x_centered.get_shape()[:-1]
+      final_shaped_value.set_shape(output_static_shape)
+      return final_shaped_value
+
+  def pdf(self, x, name=None):
+    """The PDF of observations `x` under these Multivariate Normals.
+
+    Args:
+      x: tensor of dtype `dtype`, must be broadcastable with `mu` and `sigma`.
+      name: The name to give this op.
+
+    Returns:
+      pdf: tensor of dtype `dtype`, the pdf values of `x`.
+    """
+    with ops.op_scope(
+        [self._mu, self._sigma_chol, x], name, "MultivariateNormalPdf"):
+      return math_ops.exp(self.log_pdf(x))
+
+  def entropy(self, name=None):
+    """The entropies of these Multivariate Normals.
+
+    Args:
+      name: The name to give this op.
+
+    Returns:
+      entropy: tensor of dtype `dtype`, the entropies.
+    """
+    with ops.op_scope(
+        [self._mu, self._sigma_chol], name, "MultivariateNormalEntropy"):
+      one_plus_log_two_pi = constant_op.constant(
+          1 + math.log(2 * math.pi), dtype=self.dtype)
+
+      # entropy(N(mu, sigma)) = k/2 * (1 + log(2*pi)) + 1/2 * log(det(sigma)).
+      k = math_ops.cast(self._k, dtype=self.dtype)
+      entropy_value = (
+          k * one_plus_log_two_pi + math_ops.log(self._sigma_det)) / 2
+      entropy_value.set_shape(self._sigma_det.get_shape())
+      return entropy_value
+
+  def sample(self, n, seed=None, name=None):
+    """Sample `n` observations from the Multivariate Normal Distributions.
+
+    Args:
+      n: `Scalar`, type int32, the number of observations to sample.
+      seed: Python integer, the random seed.
+      name: The name to give this op.
+
+    Returns:
+      samples: `[n, ...]`, a `Tensor` of `n` samples for each
+        of the distributions determined by broadcasting the hyperparameters.
+    """
+    with ops.op_scope(
+        [self._mu, self._sigma_chol, n], name, "MultivariateNormalSample"):
+      # TODO(ebrevdo): Is there a better way to get broadcast_shape?
+      broadcast_shape = self.mu.get_shape()
+      n = ops.convert_to_tensor(n)
+      sigma_shape_left = array_ops.slice(
+          array_ops.shape(self._sigma_chol),
+          [0], array_ops.pack([array_ops.rank(self._sigma_chol) - 2]))
+
+      k_n = array_ops.pack([self._k, n])
+      shape = array_ops.concat(0, [sigma_shape_left, k_n])
+      white_samples = random_ops.random_normal(
+          shape=shape, mean=0, stddev=1, dtype=self._mu.dtype, seed=seed)
+
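+      # Color the white noise: if z ~ N(0, I), then sigma_chol . z has
+      # covariance sigma_chol . sigma_chol^* = sigma.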
+      correlated_samples = math_ops.batch_matmul(
+          self._sigma_chol, white_samples)
+
+      # Move the last dimension to the front
+      perm = array_ops.concat(
+          0,
+          (array_ops.pack([array_ops.rank(correlated_samples) - 1]),
+           math_ops.range(0, array_ops.rank(correlated_samples) - 1)))
+
+      # TODO(ebrevdo): Once we get a proper tensor contraction op,
+      # perform the inner product using that instead of batch_matmul
+      # and this slow transpose can go away!
+      correlated_samples = array_ops.transpose(correlated_samples, perm)
+
+      samples = correlated_samples + self.mu
+
+      # Provide some hints to shape inference
+      n_val = tensor_util.constant_value(n)
+      final_shape = tensor_shape.vector(n_val).concatenate(broadcast_shape)
+      samples.set_shape(final_shape)
+
+      return samples
diff --git a/tensorflow/core/kernels/batch_matmul_op.cc b/tensorflow/core/kernels/batch_matmul_op.cc
index f5a64e1f46e..922e9f63de5 100644
--- a/tensorflow/core/kernels/batch_matmul_op.cc
+++ b/tensorflow/core/kernels/batch_matmul_op.cc
@@ -234,8 +234,8 @@ class BatchMatMul : public OpKernel {
                                         in1.shape().DebugString()));
     const int ndims = in0.dims();
     OP_REQUIRES(
-        ctx, ndims >= 3,
-        errors::InvalidArgument("In[0] and In[1] ndims must be >= 3: ", ndims));
+        ctx, ndims >= 2,
+        errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ", ndims));
     TensorShape out_shape;
     for (int i = 0; i < ndims - 2; ++i) {
       OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i),
@@ -245,7 +245,7 @@ class BatchMatMul : public OpKernel {
                                           in1.shape().DebugString()));
       out_shape.AddDim(in0.dim_size(i));
     }
-    auto n = out_shape.num_elements();
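+    // With ndims == 2 there are no batch dimensions, so treat the input
+    // as a batch containing a single matrix.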
+    auto n = (ndims == 2) ? 1 : out_shape.num_elements();
     auto d0 = in0.dim_size(ndims - 2);
     auto d1 = in0.dim_size(ndims - 1);
     Tensor in0_reshaped;
diff --git a/tensorflow/python/kernel_tests/cholesky_op_test.py b/tensorflow/python/kernel_tests/cholesky_op_test.py
index c82a4249fc4..199b54512e0 100644
--- a/tensorflow/python/kernel_tests/cholesky_op_test.py
+++ b/tensorflow/python/kernel_tests/cholesky_op_test.py
@@ -25,19 +25,8 @@ import tensorflow as tf
 
 class CholeskyOpTest(tf.test.TestCase):
 
-  def _verifyCholesky(self, x):
-    with self.test_session() as sess:
-      # Verify that LL^T == x.
-      if x.ndim == 2:
-        chol = tf.cholesky(x)
-        verification = tf.matmul(chol,
-                                 chol,
-                                 transpose_a=False,
-                                 transpose_b=True)
-      else:
-        chol = tf.batch_cholesky(x)
-        verification = tf.batch_matmul(chol, chol, adj_x=False, adj_y=True)
-      chol_np, verification_np = sess.run([chol, verification])
+  def _verifyCholeskyBase(self, sess, x, chol, verification):
+    chol_np, verification_np = sess.run([chol, verification])
     self.assertAllClose(x, verification_np)
     self.assertShapeEqual(x, chol)
     # Check that the cholesky is lower triangular, and has positive diagonal
@@ -49,6 +38,20 @@ class CholeskyOpTest(tf.test.TestCase):
         self.assertAllClose(chol_matrix, np.tril(chol_matrix))
         self.assertTrue((np.diag(chol_matrix) > 0.0).all())
 
+  def _verifyCholesky(self, x):
+    # Verify that LL^T == x.
+    with self.test_session() as sess:
+      # Check the batch version, which works for ndim >= 2.
+      chol = tf.batch_cholesky(x)
+      verification = tf.batch_matmul(chol, chol, adj_x=False, adj_y=True)
+      self._verifyCholeskyBase(sess, x, chol, verification)
+
+      if x.ndim == 2:  # Check the simple form of cholesky
+        chol = tf.cholesky(x)
+        verification = tf.matmul(
+            chol, chol, transpose_a=False, transpose_b=True)
+        self._verifyCholeskyBase(sess, x, chol, verification)
+
   def testBasic(self):
     self._verifyCholesky(np.array([[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]))
 
diff --git a/tensorflow/python/kernel_tests/determinant_op_test.py b/tensorflow/python/kernel_tests/determinant_op_test.py
index 4355da8a05e..779d924ecf9 100644
--- a/tensorflow/python/kernel_tests/determinant_op_test.py
+++ b/tensorflow/python/kernel_tests/determinant_op_test.py
@@ -24,13 +24,8 @@ import tensorflow as tf
 
 class DeterminantOpTest(tf.test.TestCase):
 
-  def _compareDeterminant(self, matrix_x):
-    with self.test_session():
-      if matrix_x.ndim == 2:
-        tf_ans = tf.matrix_determinant(matrix_x)
-      else:
-        tf_ans = tf.batch_matrix_determinant(matrix_x)
-      out = tf_ans.eval()
+  def _compareDeterminantBase(self, matrix_x, tf_ans):
+    out = tf_ans.eval()
     shape = matrix_x.shape
     if shape[-1] == 0 and shape[-2] == 0:
       np_ans = np.ones(shape[:-2]).astype(matrix_x.dtype)
@@ -39,6 +34,15 @@ class DeterminantOpTest(tf.test.TestCase):
     self.assertAllClose(np_ans, out)
     self.assertShapeEqual(np_ans, tf_ans)
 
+  def _compareDeterminant(self, matrix_x):
+    with self.test_session():
+      # Check the batch version, which should work for ndim >= 2
+      self._compareDeterminantBase(
+          matrix_x, tf.batch_matrix_determinant(matrix_x))
+      if matrix_x.ndim == 2:
+        # Check the simple version
+        self._compareDeterminantBase(matrix_x, tf.matrix_determinant(matrix_x))
+
   def testBasic(self):
     # 2x2 matrices
     self._compareDeterminant(np.array([[2., 3.], [3., 4.]]).astype(np.float32))
diff --git a/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py b/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py
index 32e49328c16..d04020eac1d 100644
--- a/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py
@@ -67,11 +67,13 @@ class MatrixSolveLsOpTest(tf.test.TestCase):
       np_ans, _, _, _ = np.linalg.lstsq(a, b)
       for fast in [True, False]:
         with self.test_session():
-          tf_ans = tf.matrix_solve_ls(a, b, fast=fast).eval()
-        self.assertEqual(np_ans.shape, tf_ans.shape)
+          tf_ans = tf.matrix_solve_ls(a, b, fast=fast)
+          ans = tf_ans.eval()
+        self.assertEqual(np_ans.shape, tf_ans.get_shape())
+        self.assertEqual(np_ans.shape, ans.shape)
 
         # Check residual norm.
-        tf_r = b - BatchMatMul(a, tf_ans)
+        tf_r = b - BatchMatMul(a, ans)
         tf_r_norm = np.sum(tf_r * tf_r)
         np_r = b - BatchMatMul(a, np_ans)
         np_r_norm = np.sum(np_r * np_r)
@@ -83,7 +85,7 @@ class MatrixSolveLsOpTest(tf.test.TestCase):
           # slow path, because Eigen does not return a minimum norm solution.
           # TODO(rmlarsen): Enable this check for all paths if/when we fix
           # Eigen's solver.
-          self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5)
+          self.assertAllClose(np_ans, ans, atol=1e-5, rtol=1e-5)
 
   def _verifySolveBatch(self, x, y):
     # Since numpy.linalg.lsqr does not support batch solves, as opposed
@@ -122,20 +124,23 @@ class MatrixSolveLsOpTest(tf.test.TestCase):
       b = y.astype(np_type)
       np_ans = BatchRegularizedLeastSquares(a, b, l2_regularizer)
       with self.test_session():
-        tf_ans = tf.matrix_solve_ls(a,
-                                    b,
-                                    l2_regularizer=l2_regularizer,
-                                    fast=True).eval()
-      self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5)
+        # Test with the batch version of matrix_solve_ls on regular matrices
+        tf_ans = tf.batch_matrix_solve_ls(
+            a, b, l2_regularizer=l2_regularizer, fast=True).eval()
+        self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5)
+
+        # Test with the simple matrix_solve_ls on regular matrices
+        tf_ans = tf.matrix_solve_ls(
+            a, b, l2_regularizer=l2_regularizer, fast=True).eval()
+        self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5)
+
       # Test with a 2x3 batch of matrices.
       a = np.tile(x.astype(np_type), [2, 3, 1, 1])
       b = np.tile(y.astype(np_type), [2, 3, 1, 1])
       np_ans = BatchRegularizedLeastSquares(a, b, l2_regularizer)
       with self.test_session():
-        tf_ans = tf.batch_matrix_solve_ls(a,
-                                          b,
-                                          l2_regularizer=l2_regularizer,
-                                          fast=True).eval()
+        tf_ans = tf.batch_matrix_solve_ls(
+            a, b, l2_regularizer=l2_regularizer, fast=True).eval()
       self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5)
 
   def testSquare(self):
diff --git a/tensorflow/python/kernel_tests/matrix_solve_op_test.py b/tensorflow/python/kernel_tests/matrix_solve_op_test.py
index cffdf4e6884..a08d0f27501 100644
--- a/tensorflow/python/kernel_tests/matrix_solve_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_solve_op_test.py
@@ -37,15 +37,23 @@ class MatrixSolveOpTest(tf.test.TestCase):
           a = np.tile(a, batch_dims + [1, 1])
           a_np = np.tile(a_np, batch_dims + [1, 1])
           b = np.tile(b, batch_dims + [1, 1])
-        with self.test_session():
-          if a.ndim == 2:
-            tf_ans = tf.matrix_solve(a, b, adjoint=adjoint)
-          else:
-            tf_ans = tf.batch_matrix_solve(a, b, adjoint=adjoint)
-          out = tf_ans.eval()
+
         np_ans = np.linalg.solve(a_np, b)
-        self.assertEqual(np_ans.shape, out.shape)
-        self.assertAllClose(np_ans, out)
+        with self.test_session():
+          # Test the batch version, which works for ndim >= 2
+          tf_ans = tf.batch_matrix_solve(a, b, adjoint=adjoint)
+          out = tf_ans.eval()
+          self.assertEqual(tf_ans.get_shape(), out.shape)
+          self.assertEqual(np_ans.shape, out.shape)
+          self.assertAllClose(np_ans, out)
+
+          if a.ndim == 2:
+            # Test the simple version
+            tf_ans = tf.matrix_solve(a, b, adjoint=adjoint)
+            out = tf_ans.eval()
+            self.assertEqual(out.shape, tf_ans.get_shape())
+            self.assertEqual(np_ans.shape, out.shape)
+            self.assertAllClose(np_ans, out)
 
   def testSolve(self):
     # 2x2 matrices, 2x1 right-hand side.
diff --git a/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py b/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
index f4637fa628f..fba393d599a 100644
--- a/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
@@ -51,20 +51,27 @@ class MatrixTriangularSolveOpTest(tf.test.TestCase):
         a = np.tile(a, batch_dims + [1, 1])
         a_np = np.tile(a_np, batch_dims + [1, 1])
         b = np.tile(b, batch_dims + [1, 1])
+
       with self.test_session():
+        # Test the batch version, which works for ndim >= 2
+        tf_ans = tf.batch_matrix_triangular_solve(
+            a, b, lower=lower, adjoint=adjoint)
+        out = tf_ans.eval()
+
+        np_ans = np.linalg.solve(a_np, b)
+
+        self.assertEqual(np_ans.shape, tf_ans.get_shape())
+        self.assertEqual(np_ans.shape, out.shape)
+        self.assertAllClose(np_ans, out)
+
         if a.ndim == 2:
-          tf_ans = tf.matrix_triangular_solve(a,
-                                              b,
-                                              lower=lower,
-                                              adjoint=adjoint).eval()
-        else:
-          tf_ans = tf.batch_matrix_triangular_solve(a,
-                                                    b,
-                                                    lower=lower,
-                                                    adjoint=adjoint).eval()
-      np_ans = np.linalg.solve(a_np, b)
-      self.assertEqual(np_ans.shape, tf_ans.shape)
-      self.assertAllClose(np_ans, tf_ans)
+          # Test the simple version
+          tf_ans = tf.matrix_triangular_solve(
+              a, b, lower=lower, adjoint=adjoint)
+          out = tf_ans.eval()
+          self.assertEqual(np_ans.shape, tf_ans.get_shape())
+          self.assertEqual(np_ans.shape, out.shape)
+          self.assertAllClose(np_ans, out)
 
   def testSolve(self):
     # 2x2 matrices, single right-hand side.
diff --git a/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py b/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py
index e2c385c9dd7..d955ee1ad5e 100644
--- a/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py
+++ b/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py
@@ -71,14 +71,28 @@ class SelfAdjointEigOpTest(tf.test.TestCase):
     for i in xrange(dlist[0]):
       self._testEigs(x[i], d, tf_out[i])
 
+  def _compareBatchSelfAdjointEigRank2(self, x, use_gpu=False):
+    with self.test_session() as sess:
+      tf_eig = tf.batch_self_adjoint_eig(tf.constant(x))
+      tf_out = sess.run([tf_eig])[0]
+    dlist = x.shape
+    d = dlist[-2]
+
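+    # BatchSelfAdjointEig packs eigenvalues and eigenvectors into a
+    # [d+1, d] block, hence the d+1 leading dimension below.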
+    self.assertEqual(len(tf_eig.get_shape()), 2)
+    self.assertEqual([d+1, d], tf_eig.get_shape().dims[-2:])
+    self._testEigs(x, d, tf_out)
+
   def testBasic(self):
     self._compareSelfAdjointEig(
         np.array([[3., 0., 1.], [0., 2., -2.], [1., -2., 3.]]))
 
   def testBatch(self):
     simple_array = np.array([[[1., 0.], [0., 5.]]])  # shape (1, 2, 2)
+    simple_array_2d = simple_array[0]  # shape (2, 2)
     self._compareBatchSelfAdjointEigRank3(simple_array)
-    self._compareBatchSelfAdjointEigRank3(np.vstack((simple_array, simple_array)))
+    self._compareBatchSelfAdjointEigRank3(
+        np.vstack((simple_array, simple_array)))
+    self._compareBatchSelfAdjointEigRank2(simple_array_2d)
     odd_sized_array = np.array([[[3., 0., 1.], [0., 2., -2.], [1., -2., 3.]]])
     self._compareBatchSelfAdjointEigRank3(
         np.vstack((odd_sized_array, odd_sized_array)))
diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py
index 58bddb0b672..31fc2b28768 100644
--- a/tensorflow/python/ops/linalg_ops.py
+++ b/tensorflow/python/ops/linalg_ops.py
@@ -39,7 +39,7 @@ def _UnchangedSquare(op):
 @ops.RegisterShape("BatchCholesky")
 @ops.RegisterShape("BatchMatrixInverse")
 def _BatchUnchangedSquare(op):
-  input_shape = op.inputs[0].get_shape().with_rank_at_least(3)
+  input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
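+  # A rank-2 input is treated as a batch containing a single matrix.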
   # The matrices in the batch must be square.
   input_shape[-1].assert_is_compatible_with(input_shape[-2])
   return [input_shape]
@@ -61,7 +61,7 @@ def _MatrixDeterminantShape(op):
 
 @ops.RegisterShape("BatchMatrixDeterminant")
 def _BatchMatrixDeterminantShape(op):
-  input_shape = op.inputs[0].get_shape().with_rank_at_least(3)
+  input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
   # The matrices in the batch must be square.
   input_shape[-1].assert_is_compatible_with(input_shape[-2])
   if input_shape.ndims is not None:
@@ -82,7 +82,7 @@ def _SelfAdjointEigShape(op):
 
 @ops.RegisterShape("BatchSelfAdjointEig")
 def _BatchSelfAdjointEigShape(op):
-  input_shape = op.inputs[0].get_shape().with_rank_at_least(3)
+  input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
   # The matrices in the batch must be square.
   input_shape[-1].assert_is_compatible_with(input_shape[-2])
   dlist = input_shape.dims
@@ -106,8 +106,8 @@ def _SquareMatrixSolveShape(op):
 @ops.RegisterShape("BatchMatrixSolve")
 @ops.RegisterShape("BatchMatrixTriangularSolve")
 def _BatchSquareMatrixSolveShape(op):
-  lhs_shape = op.inputs[0].get_shape().with_rank_at_least(3)
-  rhs_shape = op.inputs[1].get_shape().with_rank_at_least(3)
+  lhs_shape = op.inputs[0].get_shape().with_rank_at_least(2)
+  rhs_shape = op.inputs[1].get_shape().with_rank_at_least(2)
   # The matrices must be square.
   lhs_shape[-1].assert_is_compatible_with(lhs_shape[-2])
   # The matrices and right-hand sides in the batch must have the same number of
@@ -127,8 +127,8 @@ def _MatrixSolveLsShape(op):
 
 @ops.RegisterShape("BatchMatrixSolveLs")
 def _BatchMatrixSolveLsShape(op):
-  lhs_shape = op.inputs[0].get_shape().with_rank_at_least(3)
-  rhs_shape = op.inputs[1].get_shape().with_rank_at_least(3)
+  lhs_shape = op.inputs[0].get_shape().with_rank_at_least(2)
+  rhs_shape = op.inputs[1].get_shape().with_rank_at_least(2)
   # The matrices and right-hand sides in the batch must have the same number of
   # rows.
   lhs_shape[-2].assert_is_compatible_with(rhs_shape[-2])