From 2242803599eaaded61ee71e4df883d5a28606ca1 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <nobody@tensorflow.org>
Date: Tue, 19 Apr 2016 16:37:09 -0800
Subject: [PATCH] Rollback of "Add MultivariateNormal to
 tf.contrib.distributions."

Also fix overly stringent constraints on batchwise linalg ops & batch_matmul.
Change: 120289843
---
 tensorflow/contrib/distributions/BUILD        | 14 ---------
 tensorflow/contrib/distributions/__init__.py  |  4 +--
 .../distributions/python/ops/gaussian.py      |  6 ++--
 tensorflow/core/kernels/batch_matmul_op.cc    |  6 ++--
 .../python/kernel_tests/cholesky_op_test.py   | 29 ++++++++---------
 .../kernel_tests/determinant_op_test.py       | 18 +++++------
 .../kernel_tests/matrix_solve_ls_op_test.py   | 31 ++++++++-----------
 .../kernel_tests/matrix_solve_op_test.py      | 20 ++++--------
 .../matrix_triangular_solve_op_test.py        | 31 +++++++------------
 .../kernel_tests/self_adjoint_eig_op_test.py  | 16 +---------
 tensorflow/python/ops/linalg_ops.py           | 14 ++++-----
 11 files changed, 67 insertions(+), 122 deletions(-)

diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index bfa31dbe1cd..5feac79ecb0 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -32,19 +32,6 @@ cuda_py_tests(
     srcs = ["python/kernel_tests/gaussian_test.py"],
     additional_deps = [
         ":distributions_py",
-        "//third_party/py/scipy",
-        "//tensorflow/python:framework_test_lib",
-        "//tensorflow/python:platform_test",
-    ],
-)
-
-cuda_py_tests(
-    name = "mvn_test",
-    size = "small",
-    srcs = ["python/kernel_tests/mvn_test.py"],
-    additional_deps = [
-        ":distributions_py",
-        "//third_party/py/scipy",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:platform_test",
     ],
@@ -56,7 +43,6 @@ cuda_py_tests(
     srcs = ["python/kernel_tests/gaussian_conjugate_posteriors_test.py"],
     additional_deps = [
         ":distributions_py",
-        "//third_party/py/scipy",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:platform_test",
     ],
diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py
index 54607a7379e..2f9b8fcafb1 100644
--- a/tensorflow/contrib/distributions/__init__.py
+++ b/tensorflow/contrib/distributions/__init__.py
@@ -21,8 +21,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-# pylint: disable=unused-import,wildcard-import,line-too-long
+# pylint: disable=unused-import,wildcard-import, line-too-long
 from tensorflow.contrib.distributions.python.ops import gaussian_conjugate_posteriors
 from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import *
 from tensorflow.contrib.distributions.python.ops.gaussian import *
-from tensorflow.contrib.distributions.python.ops.mvn import *
+# from tensorflow.contrib.distributions.python.ops.dirichlet import *  # pylint: disable=line-too-long
diff --git a/tensorflow/contrib/distributions/python/ops/gaussian.py b/tensorflow/contrib/distributions/python/ops/gaussian.py
index cbb98624d97..b9dad502983 100644
--- a/tensorflow/contrib/distributions/python/ops/gaussian.py
+++ b/tensorflow/contrib/distributions/python/ops/gaussian.py
@@ -88,7 +88,7 @@ class Gaussian(object):
 
   @property
   def mean(self):
-    return self._mu * array_ops.ones_like(self._sigma)
+    return self._mu
 
   def log_pdf(self, x, name=None):
     """Log pdf of observations in `x` under these Gaussian distribution(s).
@@ -170,7 +170,7 @@ class Gaussian(object):
       return 0.5 * math_ops.log(two_pi_e1 * math_ops.square(sigma))
 
   def sample(self, n, seed=None, name=None):
-    """Sample `n` observations from the Gaussian Distributions.
+    """Sample `n` observations the Gaussian Distributions.
 
     Args:
       n: `Scalar`, type int32, the number of observations to sample.
@@ -185,7 +185,7 @@ class Gaussian(object):
       broadcast_shape = (self._mu + self._sigma).get_shape()
       n = ops.convert_to_tensor(n)
       shape = array_ops.concat(
-          0, [array_ops.pack([n]), array_ops.shape(self.mean)])
+          0, [array_ops.pack([n]), array_ops.shape(self._mu)])
       sampled = random_ops.random_normal(
           shape=shape, mean=0, stddev=1, dtype=self._mu.dtype, seed=seed)
 
diff --git a/tensorflow/core/kernels/batch_matmul_op.cc b/tensorflow/core/kernels/batch_matmul_op.cc
index 922e9f63de5..f5a64e1f46e 100644
--- a/tensorflow/core/kernels/batch_matmul_op.cc
+++ b/tensorflow/core/kernels/batch_matmul_op.cc
@@ -234,8 +234,8 @@ class BatchMatMul : public OpKernel {
                                         in1.shape().DebugString()));
     const int ndims = in0.dims();
     OP_REQUIRES(
-        ctx, ndims >= 2,
-        errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ", ndims));
+        ctx, ndims >= 3,
+        errors::InvalidArgument("In[0] and In[1] ndims must be >= 3: ", ndims));
     TensorShape out_shape;
     for (int i = 0; i < ndims - 2; ++i) {
       OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i),
@@ -245,7 +245,7 @@ class BatchMatMul : public OpKernel {
                                           in1.shape().DebugString()));
       out_shape.AddDim(in0.dim_size(i));
     }
-    auto n = (ndims == 2) ? 1 : out_shape.num_elements();
+    auto n = out_shape.num_elements();
     auto d0 = in0.dim_size(ndims - 2);
     auto d1 = in0.dim_size(ndims - 1);
     Tensor in0_reshaped;
diff --git a/tensorflow/python/kernel_tests/cholesky_op_test.py b/tensorflow/python/kernel_tests/cholesky_op_test.py
index 199b54512e0..c82a4249fc4 100644
--- a/tensorflow/python/kernel_tests/cholesky_op_test.py
+++ b/tensorflow/python/kernel_tests/cholesky_op_test.py
@@ -25,8 +25,19 @@ import tensorflow as tf
 
 class CholeskyOpTest(tf.test.TestCase):
 
-  def _verifyCholeskyBase(self, sess, x, chol, verification):
-    chol_np, verification_np = sess.run([chol, verification])
+  def _verifyCholesky(self, x):
+    with self.test_session() as sess:
+      # Verify that LL^T == x.
+      if x.ndim == 2:
+        chol = tf.cholesky(x)
+        verification = tf.matmul(chol,
+                                 chol,
+                                 transpose_a=False,
+                                 transpose_b=True)
+      else:
+        chol = tf.batch_cholesky(x)
+        verification = tf.batch_matmul(chol, chol, adj_x=False, adj_y=True)
+      chol_np, verification_np = sess.run([chol, verification])
     self.assertAllClose(x, verification_np)
     self.assertShapeEqual(x, chol)
     # Check that the cholesky is lower triangular, and has positive diagonal
@@ -38,20 +49,6 @@ class CholeskyOpTest(tf.test.TestCase):
         self.assertAllClose(chol_matrix, np.tril(chol_matrix))
         self.assertTrue((np.diag(chol_matrix) > 0.0).all())
 
-  def _verifyCholesky(self, x):
-    # Verify that LL^T == x.
-    with self.test_session() as sess:
-      # Check the batch version, which works for ndim >= 2.
-      chol = tf.batch_cholesky(x)
-      verification = tf.batch_matmul(chol, chol, adj_x=False, adj_y=True)
-      self._verifyCholeskyBase(sess, x, chol, verification)
-
-      if x.ndim == 2:  # Check the simple form of cholesky
-        chol = tf.cholesky(x)
-        verification = tf.matmul(
-            chol, chol, transpose_a=False, transpose_b=True)
-        self._verifyCholeskyBase(sess, x, chol, verification)
-
   def testBasic(self):
     self._verifyCholesky(np.array([[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]))
 
diff --git a/tensorflow/python/kernel_tests/determinant_op_test.py b/tensorflow/python/kernel_tests/determinant_op_test.py
index 779d924ecf9..4355da8a05e 100644
--- a/tensorflow/python/kernel_tests/determinant_op_test.py
+++ b/tensorflow/python/kernel_tests/determinant_op_test.py
@@ -24,8 +24,13 @@ import tensorflow as tf
 
 class DeterminantOpTest(tf.test.TestCase):
 
-  def _compareDeterminantBase(self, matrix_x, tf_ans):
-    out = tf_ans.eval()
+  def _compareDeterminant(self, matrix_x):
+    with self.test_session():
+      if matrix_x.ndim == 2:
+        tf_ans = tf.matrix_determinant(matrix_x)
+      else:
+        tf_ans = tf.batch_matrix_determinant(matrix_x)
+      out = tf_ans.eval()
     shape = matrix_x.shape
     if shape[-1] == 0 and shape[-2] == 0:
       np_ans = np.ones(shape[:-2]).astype(matrix_x.dtype)
@@ -34,15 +39,6 @@ class DeterminantOpTest(tf.test.TestCase):
     self.assertAllClose(np_ans, out)
     self.assertShapeEqual(np_ans, tf_ans)
 
-  def _compareDeterminant(self, matrix_x):
-    with self.test_session():
-      # Check the batch version, which should work for ndim >= 2
-      self._compareDeterminantBase(
-          matrix_x, tf.batch_matrix_determinant(matrix_x))
-      if matrix_x.ndim == 2:
-        # Check the simple version
-        self._compareDeterminantBase(matrix_x, tf.matrix_determinant(matrix_x))
-
   def testBasic(self):
     # 2x2 matrices
     self._compareDeterminant(np.array([[2., 3.], [3., 4.]]).astype(np.float32))
diff --git a/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py b/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py
index d04020eac1d..32e49328c16 100644
--- a/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py
@@ -67,13 +67,11 @@ class MatrixSolveLsOpTest(tf.test.TestCase):
       np_ans, _, _, _ = np.linalg.lstsq(a, b)
       for fast in [True, False]:
         with self.test_session():
-          tf_ans = tf.matrix_solve_ls(a, b, fast=fast)
-          ans = tf_ans.eval()
-        self.assertEqual(np_ans.shape, tf_ans.get_shape())
-        self.assertEqual(np_ans.shape, ans.shape)
+          tf_ans = tf.matrix_solve_ls(a, b, fast=fast).eval()
+        self.assertEqual(np_ans.shape, tf_ans.shape)
 
         # Check residual norm.
-        tf_r = b - BatchMatMul(a, ans)
+        tf_r = b - BatchMatMul(a, tf_ans)
         tf_r_norm = np.sum(tf_r * tf_r)
         np_r = b - BatchMatMul(a, np_ans)
         np_r_norm = np.sum(np_r * np_r)
@@ -85,7 +83,7 @@ class MatrixSolveLsOpTest(tf.test.TestCase):
           # slow path, because Eigen does not return a minimum norm solution.
           # TODO(rmlarsen): Enable this check for all paths if/when we fix
           # Eigen's solver.
-          self.assertAllClose(np_ans, ans, atol=1e-5, rtol=1e-5)
+          self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5)
 
   def _verifySolveBatch(self, x, y):
     # Since numpy.linalg.lsqr does not support batch solves, as opposed
@@ -124,23 +122,20 @@ class MatrixSolveLsOpTest(tf.test.TestCase):
       b = y.astype(np_type)
       np_ans = BatchRegularizedLeastSquares(a, b, l2_regularizer)
       with self.test_session():
-        # Test with the batch version of  matrix_solve_ls on regular matrices
-        tf_ans = tf.batch_matrix_solve_ls(
-            a, b, l2_regularizer=l2_regularizer, fast=True).eval()
-        self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5)
-
-        # Test with the simple matrix_solve_ls on regular matrices
-        tf_ans = tf.matrix_solve_ls(
-            a, b, l2_regularizer=l2_regularizer, fast=True).eval()
-        self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5)
-
+        tf_ans = tf.matrix_solve_ls(a,
+                                    b,
+                                    l2_regularizer=l2_regularizer,
+                                    fast=True).eval()
+      self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5)
       # Test with a 2x3 batch of matrices.
       a = np.tile(x.astype(np_type), [2, 3, 1, 1])
       b = np.tile(y.astype(np_type), [2, 3, 1, 1])
       np_ans = BatchRegularizedLeastSquares(a, b, l2_regularizer)
       with self.test_session():
-        tf_ans = tf.batch_matrix_solve_ls(
-            a, b, l2_regularizer=l2_regularizer, fast=True).eval()
+        tf_ans = tf.batch_matrix_solve_ls(a,
+                                          b,
+                                          l2_regularizer=l2_regularizer,
+                                          fast=True).eval()
       self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5)
 
   def testSquare(self):
diff --git a/tensorflow/python/kernel_tests/matrix_solve_op_test.py b/tensorflow/python/kernel_tests/matrix_solve_op_test.py
index a08d0f27501..cffdf4e6884 100644
--- a/tensorflow/python/kernel_tests/matrix_solve_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_solve_op_test.py
@@ -37,23 +37,15 @@ class MatrixSolveOpTest(tf.test.TestCase):
           a = np.tile(a, batch_dims + [1, 1])
           a_np = np.tile(a_np, batch_dims + [1, 1])
           b = np.tile(b, batch_dims + [1, 1])
-
-        np_ans = np.linalg.solve(a_np, b)
         with self.test_session():
-          # Test the batch version, which works for ndim >= 2
-          tf_ans = tf.batch_matrix_solve(a, b, adjoint=adjoint)
-          out = tf_ans.eval()
-          self.assertEqual(tf_ans.get_shape(), out.shape)
-          self.assertEqual(np_ans.shape, out.shape)
-          self.assertAllClose(np_ans, out)
-
           if a.ndim == 2:
-            # Test the simple version
             tf_ans = tf.matrix_solve(a, b, adjoint=adjoint)
-            out = tf_ans.eval()
-            self.assertEqual(out.shape, tf_ans.get_shape())
-            self.assertEqual(np_ans.shape, out.shape)
-            self.assertAllClose(np_ans, out)
+          else:
+            tf_ans = tf.batch_matrix_solve(a, b, adjoint=adjoint)
+          out = tf_ans.eval()
+        np_ans = np.linalg.solve(a_np, b)
+        self.assertEqual(np_ans.shape, out.shape)
+        self.assertAllClose(np_ans, out)
 
   def testSolve(self):
     # 2x2 matrices, 2x1 right-hand side.
diff --git a/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py b/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
index fba393d599a..f4637fa628f 100644
--- a/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
@@ -51,27 +51,20 @@ class MatrixTriangularSolveOpTest(tf.test.TestCase):
         a = np.tile(a, batch_dims + [1, 1])
         a_np = np.tile(a_np, batch_dims + [1, 1])
         b = np.tile(b, batch_dims + [1, 1])
-
       with self.test_session():
-        # Test the batch version, which works for ndim >= 2
-        tf_ans = tf.batch_matrix_triangular_solve(
-            a, b, lower=lower, adjoint=adjoint)
-        out = tf_ans.eval()
-
-        np_ans = np.linalg.solve(a_np, b)
-
-        self.assertEqual(np_ans.shape, tf_ans.get_shape())
-        self.assertEqual(np_ans.shape, out.shape)
-        self.assertAllClose(np_ans, out)
-
         if a.ndim == 2:
-          # Test the simple version
-          tf_ans = tf.matrix_triangular_solve(
-              a, b, lower=lower, adjoint=adjoint)
-          out = tf_ans.eval()
-          self.assertEqual(np_ans.shape, tf_ans.get_shape())
-          self.assertEqual(np_ans.shape, out.shape)
-          self.assertAllClose(np_ans, out)
+          tf_ans = tf.matrix_triangular_solve(a,
+                                              b,
+                                              lower=lower,
+                                              adjoint=adjoint).eval()
+        else:
+          tf_ans = tf.batch_matrix_triangular_solve(a,
+                                                    b,
+                                                    lower=lower,
+                                                    adjoint=adjoint).eval()
+      np_ans = np.linalg.solve(a_np, b)
+      self.assertEqual(np_ans.shape, tf_ans.shape)
+      self.assertAllClose(np_ans, tf_ans)
 
   def testSolve(self):
     # 2x2 matrices, single right-hand side.
diff --git a/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py b/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py
index d955ee1ad5e..e2c385c9dd7 100644
--- a/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py
+++ b/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py
@@ -71,28 +71,14 @@ class SelfAdjointEigOpTest(tf.test.TestCase):
     for i in xrange(dlist[0]):
       self._testEigs(x[i], d, tf_out[i])
 
-  def _compareBatchSelfAdjointEigRank2(self, x, use_gpu=False):
-    with self.test_session() as sess:
-      tf_eig = tf.batch_self_adjoint_eig(tf.constant(x))
-      tf_out = sess.run([tf_eig])[0]
-    dlist = x.shape
-    d = dlist[-2]
-
-    self.assertEqual(len(tf_eig.get_shape()), 2)
-    self.assertEqual([d+1, d], tf_eig.get_shape().dims[-2:])
-    self._testEigs(x, d, tf_out)
-
   def testBasic(self):
     self._compareSelfAdjointEig(
         np.array([[3., 0., 1.], [0., 2., -2.], [1., -2., 3.]]))
 
   def testBatch(self):
     simple_array = np.array([[[1., 0.], [0., 5.]]])  # shape (1, 2, 2)
-    simple_array_2d = simple_array[0]  # shape (2, 2)
     self._compareBatchSelfAdjointEigRank3(simple_array)
-    self._compareBatchSelfAdjointEigRank3(
-        np.vstack((simple_array, simple_array)))
-    self._compareBatchSelfAdjointEigRank2(simple_array_2d)
+    self._compareBatchSelfAdjointEigRank3(np.vstack((simple_array, simple_array)))
     odd_sized_array = np.array([[[3., 0., 1.], [0., 2., -2.], [1., -2., 3.]]])
     self._compareBatchSelfAdjointEigRank3(
         np.vstack((odd_sized_array, odd_sized_array)))
diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py
index 31fc2b28768..58bddb0b672 100644
--- a/tensorflow/python/ops/linalg_ops.py
+++ b/tensorflow/python/ops/linalg_ops.py
@@ -39,7 +39,7 @@ def _UnchangedSquare(op):
 @ops.RegisterShape("BatchCholesky")
 @ops.RegisterShape("BatchMatrixInverse")
 def _BatchUnchangedSquare(op):
-  input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
+  input_shape = op.inputs[0].get_shape().with_rank_at_least(3)
   # The matrices in the batch must be square.
   input_shape[-1].assert_is_compatible_with(input_shape[-2])
   return [input_shape]
@@ -61,7 +61,7 @@ def _MatrixDeterminantShape(op):
 
 @ops.RegisterShape("BatchMatrixDeterminant")
 def _BatchMatrixDeterminantShape(op):
-  input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
+  input_shape = op.inputs[0].get_shape().with_rank_at_least(3)
   # The matrices in the batch must be square.
   input_shape[-1].assert_is_compatible_with(input_shape[-2])
   if input_shape.ndims is not None:
@@ -82,7 +82,7 @@ def _SelfAdjointEigShape(op):
 
 @ops.RegisterShape("BatchSelfAdjointEig")
 def _BatchSelfAdjointEigShape(op):
-  input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
+  input_shape = op.inputs[0].get_shape().with_rank_at_least(3)
   # The matrices in the batch must be square.
   input_shape[-1].assert_is_compatible_with(input_shape[-2])
   dlist = input_shape.dims
@@ -106,8 +106,8 @@ def _SquareMatrixSolveShape(op):
 @ops.RegisterShape("BatchMatrixSolve")
 @ops.RegisterShape("BatchMatrixTriangularSolve")
 def _BatchSquareMatrixSolveShape(op):
-  lhs_shape = op.inputs[0].get_shape().with_rank_at_least(2)
-  rhs_shape = op.inputs[1].get_shape().with_rank_at_least(2)
+  lhs_shape = op.inputs[0].get_shape().with_rank_at_least(3)
+  rhs_shape = op.inputs[1].get_shape().with_rank_at_least(3)
   # The matrices must be square.
   lhs_shape[-1].assert_is_compatible_with(lhs_shape[-2])
   # The matrices and right-hand sides in the batch must have the same number of
@@ -127,8 +127,8 @@ def _MatrixSolveLsShape(op):
 
 @ops.RegisterShape("BatchMatrixSolveLs")
 def _BatchMatrixSolveLsShape(op):
-  lhs_shape = op.inputs[0].get_shape().with_rank_at_least(2)
-  rhs_shape = op.inputs[1].get_shape().with_rank_at_least(2)
+  lhs_shape = op.inputs[0].get_shape().with_rank_at_least(3)
+  rhs_shape = op.inputs[1].get_shape().with_rank_at_least(3)
   # The matrices and right-hand sides in the batch must have the same number of
   # rows.
   lhs_shape[-2].assert_is_compatible_with(rhs_shape[-2])