From ed4300da87a05be20adb8b11428a8e78f2fe828a Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 3 Aug 2016 16:45:38 -0800
Subject: [PATCH] Added Binomial and Multinomial distributions.

- Refactored some common asserts into a distribution_util library.
- Changed some documentation for distributions (in particular providing more helpful error messages, properly escaping values in comments, etc.).
Change: 129280447
---
 tensorflow/contrib/distributions/BUILD        |  30 +-
 tensorflow/contrib/distributions/__init__.py  |   4 +
 .../python/kernel_tests/bernoulli_test.py     |  11 +-
 .../python/kernel_tests/binomial_test.py      | 173 +++++++++
 .../dirichlet_multinomial_test.py             |   2 +-
 .../python/kernel_tests/multinomial_test.py   | 226 ++++++++++++
 .../distributions/python/ops/bernoulli.py     |  50 +--
 .../contrib/distributions/python/ops/beta.py  |  31 +-
 .../distributions/python/ops/binomial.py      | 340 +++++++++++++++++
 .../distributions/python/ops/categorical.py   |  16 +-
 .../contrib/distributions/python/ops/chi2.py  |  12 +-
 .../distributions/python/ops/dirichlet.py     |  49 +--
 .../python/ops/dirichlet_multinomial.py       | 123 +++----
 .../python/ops/distribution_util.py           | 177 +++++++++
 .../distributions/python/ops/exponential.py   |  15 +-
 .../contrib/distributions/python/ops/gamma.py |  21 +-
 .../distributions/python/ops/inverse_gamma.py |  28 +-
 .../python/ops/kullback_leibler.py            |   4 +-
 .../distributions/python/ops/laplace.py       |  17 +-
 .../distributions/python/ops/multinomial.py   | 343 ++++++++++++++++++
 .../contrib/distributions/python/ops/mvn.py   |  20 +-
 .../distributions/python/ops/normal.py        |  17 +-
 .../distributions/python/ops/student_t.py     |  36 +-
 .../python/ops/transformed_distribution.py    |   1 +
 .../distributions/python/ops/uniform.py       |  23 +-
 25 files changed, 1503 insertions(+), 266 deletions(-)
 create mode 100644 tensorflow/contrib/distributions/python/kernel_tests/binomial_test.py
 create mode 100644 tensorflow/contrib/distributions/python/kernel_tests/multinomial_test.py
 create mode 100644 tensorflow/contrib/distributions/python/ops/binomial.py
 create mode 100644 tensorflow/contrib/distributions/python/ops/distribution_util.py
 create mode 100644 tensorflow/contrib/distributions/python/ops/multinomial.py

diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 3fd428e1220..2d5a708bac6 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -99,7 +99,16 @@ cuda_py_tests(
     srcs = ["python/kernel_tests/beta_test.py"],
     additional_deps = [
         ":distributions_py",
-        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
+cuda_py_tests(
+    name = "binomial_test",
+    size = "small",
+    srcs = ["python/kernel_tests/binomial_test.py"],
+    additional_deps = [
+        ":distributions_py",
         "//tensorflow/python:platform_test",
     ],
     tags = ["notsan"],
@@ -179,9 +188,8 @@ cuda_py_tests(
 )
 
 cuda_py_tests(
-    name = "kullback_leibler_test",
-    size = "small",
-    srcs = ["python/kernel_tests/kullback_leibler_test.py"],
+    name = "laplace_test",
+    srcs = ["python/kernel_tests/laplace_test.py"],
     additional_deps = [
         ":distributions_py",
         "//tensorflow/python:framework_test_lib",
@@ -190,13 +198,14 @@ cuda_py_tests(
 )
 
 cuda_py_tests(
-    name = "laplace_test",
-    srcs = ["python/kernel_tests/laplace_test.py"],
+    name = "multinomial_test",
+    srcs = ["python/kernel_tests/multinomial_test.py"],
     additional_deps = [
         ":distributions_py",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:platform_test",
     ],
+    tags = ["notsan"],
 )
 
 cuda_py_tests(
@@ -239,6 +248,15 @@ cuda_py_tests(
     srcs = ["python/kernel_tests/uniform_test.py"],
     additional_deps = [
         ":distributions_py",
+        "//tensorflow/python:framework_test_lib",
+    ],
+)
+
+cuda_py_tests(
+    name = "kullback_leibler_test",
+    size = "small",
+    srcs = ["python/kernel_tests/kullback_leibler_test.py"],
+    additional_deps = [
         "//tensorflow/python:platform_test",
     ],
 )
diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py
index 2b32556f3eb..83719157761 100644
--- a/tensorflow/contrib/distributions/__init__.py
+++ b/tensorflow/contrib/distributions/__init__.py
@@ -25,6 +25,7 @@ initialized with parameters that define the distributions.
 
 ### Univariate (scalar) distributions
 
+@@Binomial
 @@Bernoulli
 @@Beta
 @@Categorical
@@ -50,6 +51,7 @@ initialized with parameters that define the distributions.
 
 @@Dirichlet
 @@DirichletMultinomial
+@@Multinomial
 
 ### Transformed distributions
 
@@ -79,6 +81,7 @@ from __future__ import print_function
 
 from tensorflow.contrib.distributions.python.ops.bernoulli import *
 from tensorflow.contrib.distributions.python.ops.beta import *
+from tensorflow.contrib.distributions.python.ops.binomial import *
 from tensorflow.contrib.distributions.python.ops.categorical import *
 from tensorflow.contrib.distributions.python.ops.chi2 import *
 from tensorflow.contrib.distributions.python.ops.dirichlet import *
@@ -89,6 +92,7 @@ from tensorflow.contrib.distributions.python.ops.gamma import *
 from tensorflow.contrib.distributions.python.ops.inverse_gamma import *
 from tensorflow.contrib.distributions.python.ops.kullback_leibler import *
 from tensorflow.contrib.distributions.python.ops.laplace import *
+from tensorflow.contrib.distributions.python.ops.multinomial import *
 from tensorflow.contrib.distributions.python.ops.mvn import *
 from tensorflow.contrib.distributions.python.ops.normal import *
 from tensorflow.contrib.distributions.python.ops.normal_conjugate_posteriors import *
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bernoulli_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bernoulli_test.py
index c636a4d060c..82f77fbfd1e 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bernoulli_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bernoulli_test.py
@@ -57,10 +57,17 @@ class BernoulliTest(tf.test.TestCase):
       self.assertAllClose(scipy.special.logit(p), dist.logits.eval())
 
   def testInvalidP(self):
-    invalid_ps = [1.01, -0.01, 2., -3.]
+    invalid_ps = [1.01, 2.]
     for p in invalid_ps:
       with self.test_session():
-        with self.assertRaisesOpError("x <= y"):
+        with self.assertRaisesOpError("p has components greater than 1"):
+          dist = tf.contrib.distributions.Bernoulli(p=p)
+          dist.p.eval()
+
+    invalid_ps = [-0.01, -3.]
+    for p in invalid_ps:
+      with self.test_session():
+        with self.assertRaisesOpError("Condition x >= 0"):
           dist = tf.contrib.distributions.Bernoulli(p=p)
           dist.p.eval()
 
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/binomial_test.py b/tensorflow/contrib/distributions/python/kernel_tests/binomial_test.py
new file mode 100644
index 00000000000..8b2520f8368
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/binomial_test.py
@@ -0,0 +1,173 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from scipy import stats
+import tensorflow as tf
+
+
+class BinomialTest(tf.test.TestCase):
+
+  def testSimpleShapes(self):
+    with self.test_session():
+      p = np.float32(np.random.beta(1, 1))
+      binom = tf.contrib.distributions.Binomial(n=1., p=p)
+      self.assertAllEqual([], binom.event_shape().eval())
+      self.assertAllEqual([], binom.batch_shape().eval())
+      self.assertEqual(tf.TensorShape([]), binom.get_event_shape())
+      self.assertEqual(tf.TensorShape([]), binom.get_batch_shape())
+
+  def testComplexShapes(self):
+    with self.test_session():
+      p = np.random.beta(1, 1, size=(3, 2)).astype(np.float32)
+      n = [[3., 2], [4, 5], [6, 7]]
+      binom = tf.contrib.distributions.Binomial(n=n, p=p)
+      self.assertAllEqual([], binom.event_shape().eval())
+      self.assertAllEqual([3, 2], binom.batch_shape().eval())
+      self.assertEqual(tf.TensorShape([]), binom.get_event_shape())
+      self.assertEqual(tf.TensorShape([3, 2]), binom.get_batch_shape())
+
+  def testNProperty(self):
+    p = [[0.1, 0.2, 0.7], [0.2, 0.3, 0.5]]
+    n = [[3.], [4]]
+    with self.test_session():
+      binom = tf.contrib.distributions.Binomial(n=n, p=p)
+      self.assertEqual((2, 1), binom.n.get_shape())
+      self.assertAllClose(n, binom.n.eval())
+
+  def testPProperty(self):
+    p = [[0.1, 0.2, 0.7]]
+    with self.test_session():
+      binom = tf.contrib.distributions.Binomial(n=3., p=p)
+      self.assertEqual((1, 3), binom.p.get_shape())
+      self.assertEqual((1, 3), binom.logits.get_shape())
+      self.assertAllClose(p, binom.p.eval())
+
+  def testLogitsProperty(self):
+    logits = [[0., 9., -0.5]]
+    with self.test_session():
+      binom = tf.contrib.distributions.Binomial(n=3., logits=logits)
+      self.assertEqual((1, 3), binom.p.get_shape())
+      self.assertEqual((1, 3), binom.logits.get_shape())
+      self.assertAllClose(logits, binom.logits.eval())
+
+  def testPmfNandCountsAgree(self):
+    p = [[0.1, 0.2, 0.7]]
+    n = [[5.]]
+    with self.test_session():
+      binom = tf.contrib.distributions.Binomial(n=n, p=p)
+      binom.pmf([2., 3, 2]).eval()
+      binom.pmf([3., 1, 2]).eval()
+      with self.assertRaisesOpError('Condition x >= 0.*'):
+        binom.pmf([-1., 4, 2]).eval()
+      with self.assertRaisesOpError('Condition x <= y.*'):
+        binom.pmf([7., 3, 0]).eval()
+
+  def testPmf_non_integer_counts(self):
+    p = [[0.1, 0.2, 0.7]]
+    n = [[5.]]
+    with self.test_session():
+      # No errors with integer n.
+      binom = tf.contrib.distributions.Binomial(n=n, p=p)
+      binom.pmf([2., 3, 2]).eval()
+      binom.pmf([3., 1, 2]).eval()
+      # Both equality and integer checking fail.
+      with self.assertRaisesOpError('Condition x == y.*'):
+        binom.pmf([1.0, 2.5, 1.5]).eval()
+
+      binom = tf.contrib.distributions.Binomial(n=n, p=p, validate_args=False)
+      binom.pmf([1., 2., 3.]).eval()
+      # Non-integer arguments work.
+      binom.pmf([1.0, 2.5, 1.5]).eval()
+
+  def testPmfBothZeroBatches(self):
+    with self.test_session():
+      # Both zero-batches.  No broadcast
+      p = 0.5
+      counts = 1.
+      pmf = tf.contrib.distributions.Binomial(n=1., p=p).pmf(counts)
+      self.assertAllClose(0.5, pmf.eval())
+      self.assertEqual((), pmf.get_shape())
+
+  def testPmfBothZeroBatchesNontrivialN(self):
+    with self.test_session():
+      # Both zero-batches.  No broadcast
+      p = 0.1
+      counts = 3.
+      binom = tf.contrib.distributions.Binomial(n=5., p=p)
+      pmf = binom.pmf(counts)
+      self.assertAllClose(stats.binom.pmf(counts, n=5., p=p), pmf.eval())
+      self.assertEqual((), pmf.get_shape())
+
+  def testPmfPStretchedInBroadcastWhenSameRank(self):
+    with self.test_session():
+      p = [[0.1, 0.9]]
+      counts = [[1., 2.]]
+      pmf = tf.contrib.distributions.Binomial(n=3., p=p).pmf(counts)
+      self.assertAllClose(stats.binom.pmf(counts, n=3., p=p), pmf.eval())
+      self.assertEqual((1, 2), pmf.get_shape())
+
+  def testPmfPStretchedInBroadcastWhenLowerRank(self):
+    with self.test_session():
+      p = [0.1, 0.4]
+      counts = [[1.], [0.]]
+      pmf = tf.contrib.distributions.Binomial(n=1., p=p).pmf(counts)
+      self.assertAllClose([[0.1, 0.4], [0.9, 0.6]], pmf.eval())
+      self.assertEqual((2, 2), pmf.get_shape())
+
+  def testBinomialMean(self):
+    with self.test_session():
+      n = 5.
+      p = [0.1, 0.2, 0.7]
+      binom = tf.contrib.distributions.Binomial(n=n, p=p)
+      expected_means = stats.binom.mean(n, p)
+      self.assertEqual((3,), binom.mean().get_shape())
+      self.assertAllClose(expected_means, binom.mean().eval())
+
+  def testBinomialVariance(self):
+    with self.test_session():
+      n = 5.
+      p = [0.1, 0.2, 0.7]
+      binom = tf.contrib.distributions.Binomial(n=n, p=p)
+      expected_variances = stats.binom.var(n, p)
+      self.assertEqual((3,), binom.variance().get_shape())
+      self.assertAllClose(expected_variances, binom.variance().eval())
+
+  def testBinomialMode(self):
+    with self.test_session():
+      n = 5.
+      p = [0.1, 0.2, 0.7]
+      binom = tf.contrib.distributions.Binomial(n=n, p=p)
+      expected_modes = [0., 1, 4]
+      self.assertEqual((3,), binom.mode().get_shape())
+      self.assertAllClose(expected_modes, binom.mode().eval())
+
+  def testBinomialMultipleMode(self):
+    with self.test_session():
+      n = 9.
+      p = [0.1, 0.2, 0.7]
+      binom = tf.contrib.distributions.Binomial(n=n, p=p)
+      # For the case where (n + 1) * p is an integer, the modes are:
+      # (n + 1) * p and (n + 1) * p - 1. In this case, we get back
+      # the larger of the two modes.
+      expected_modes = [1., 2, 7]
+      self.assertEqual((3,), binom.mode().get_shape())
+      self.assertAllClose(expected_modes, binom.mode().eval())
+
+if __name__ == '__main__':
+  tf.test.main()
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py b/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py
index 1a3f5eaf66c..23833a246b9 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py
@@ -65,7 +65,7 @@ class DirichletMultinomialTest(tf.test.TestCase):
       dist.pmf([3., 0, 2]).eval()
       with self.assertRaisesOpError('Condition x >= 0.*'):
         dist.pmf([-1., 4, 2]).eval()
-      with self.assertRaisesOpError('Condition x == y.*'):
+      with self.assertRaisesOpError('counts do not sum to n'):
         dist.pmf([3., 3, 0]).eval()
 
   def testPmf_non_integer_counts(self):
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/multinomial_test.py b/tensorflow/contrib/distributions/python/kernel_tests/multinomial_test.py
new file mode 100644
index 00000000000..55c7825bf3e
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/multinomial_test.py
@@ -0,0 +1,226 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+
+class MultinomialTest(tf.test.TestCase):
+
+  def testSimpleShapes(self):
+    with self.test_session():
+      p = [.1, .3, .6]
+      dist = tf.contrib.distributions.Multinomial(n=1., p=p)
+      self.assertEqual(3, dist.event_shape().eval())
+      self.assertAllEqual([], dist.batch_shape().eval())
+      self.assertEqual(tf.TensorShape([3]), dist.get_event_shape())
+      self.assertEqual(tf.TensorShape([]), dist.get_batch_shape())
+
+  def testComplexShapes(self):
+    with self.test_session():
+      p = 0.5 * np.ones([3, 2, 2], dtype=np.float32)
+      n = [[3., 2], [4, 5], [6, 7]]
+      dist = tf.contrib.distributions.Multinomial(n=n, p=p)
+      self.assertEqual(2, dist.event_shape().eval())
+      self.assertAllEqual([3, 2], dist.batch_shape().eval())
+      self.assertEqual(tf.TensorShape([2]), dist.get_event_shape())
+      self.assertEqual(tf.TensorShape([3, 2]), dist.get_batch_shape())
+
+  def testNProperty(self):
+    p = [[0.1, 0.2, 0.7], [0.2, 0.3, 0.5]]
+    n = [[3.], [4]]
+    with self.test_session():
+      dist = tf.contrib.distributions.Multinomial(n=n, p=p)
+      self.assertEqual((2, 1), dist.n.get_shape())
+      self.assertAllClose(n, dist.n.eval())
+
+  def testPProperty(self):
+    p = [[0.1, 0.2, 0.7]]
+    with self.test_session():
+      dist = tf.contrib.distributions.Multinomial(n=3., p=p)
+      self.assertEqual((1, 3), dist.p.get_shape())
+      self.assertEqual((1, 3), dist.logits.get_shape())
+      self.assertAllClose(p, dist.p.eval())
+
+  def testLogitsProperty(self):
+    logits = [[0., 9., -0.5]]
+    with self.test_session():
+      multinom = tf.contrib.distributions.Multinomial(n=3., logits=logits)
+      self.assertEqual((1, 3), multinom.p.get_shape())
+      self.assertEqual((1, 3), multinom.logits.get_shape())
+      self.assertAllClose(logits, multinom.logits.eval())
+
+  def testPmfNandCountsAgree(self):
+    p = [[0.1, 0.2, 0.7]]
+    n = [[5.]]
+    with self.test_session():
+      dist = tf.contrib.distributions.Multinomial(n=n, p=p)
+      dist.pmf([2., 3, 0]).eval()
+      dist.pmf([3., 0, 2]).eval()
+      with self.assertRaisesOpError('Condition x >= 0.*'):
+        dist.pmf([-1., 4, 2]).eval()
+      with self.assertRaisesOpError('counts do not sum to n'):
+        dist.pmf([3., 3, 0]).eval()
+
+  def testPmf_non_integer_counts(self):
+    p = [[0.1, 0.2, 0.7]]
+    n = [[5.]]
+    with self.test_session():
+      # No errors with integer n.
+      multinom = tf.contrib.distributions.Multinomial(n=n, p=p)
+      multinom.pmf([2., 1, 2]).eval()
+      multinom.pmf([3., 0, 2]).eval()
+      # Counts don't sum to n.
+      with self.assertRaisesOpError('counts do not sum to n'):
+        multinom.pmf([2., 3, 2]).eval()
+      # Counts are non-integers.
+      with self.assertRaisesOpError('Condition x == y.*'):
+        multinom.pmf([1.0, 2.5, 1.5]).eval()
+
+      multinom = tf.contrib.distributions.Multinomial(
+          n=n, p=p, validate_args=False)
+      multinom.pmf([1., 2., 2.]).eval()
+      # Non-integer arguments work.
+      multinom.pmf([1.0, 2.5, 1.5]).eval()
+
+  def testPmfBothZeroBatches(self):
+    with self.test_session():
+      # Both zero-batches.  No broadcast
+      p = [0.5, 0.5]
+      counts = [1., 0]
+      pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts)
+      self.assertAllClose(0.5, pmf.eval())
+      self.assertEqual((), pmf.get_shape())
+
+  def testPmfBothZeroBatchesNontrivialN(self):
+    with self.test_session():
+      # Both zero-batches.  No broadcast
+      p = [0.1, 0.9]
+      counts = [3., 2]
+      dist = tf.contrib.distributions.Multinomial(n=5., p=p)
+      pmf = dist.pmf(counts)
+      # 5 choose 3 = 5 choose 2 = 10. 10 * (.9)^2 * (.1)^3 = 81/10000.
+      self.assertAllClose(81./10000, pmf.eval())
+      self.assertEqual((), pmf.get_shape())
+
+  def testPmfPStretchedInBroadcastWhenSameRank(self):
+    with self.test_session():
+      p = [[0.1, 0.9]]
+      counts = [[1., 0], [0, 1]]
+      pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts)
+      self.assertAllClose([0.1, 0.9], pmf.eval())
+      self.assertEqual((2), pmf.get_shape())
+
+  def testPmfPStretchedInBroadcastWhenLowerRank(self):
+    with self.test_session():
+      p = [0.1, 0.9]
+      counts = [[1., 0], [0, 1]]
+      pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts)
+      self.assertAllClose([0.1, 0.9], pmf.eval())
+      self.assertEqual((2), pmf.get_shape())
+
+  def testPmfCountsStretchedInBroadcastWhenSameRank(self):
+    with self.test_session():
+      p = [[0.1, 0.9], [0.7, 0.3]]
+      counts = [[1., 0]]
+      pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts)
+      self.assertAllClose(pmf.eval(), [0.1, 0.7])
+      self.assertEqual((2), pmf.get_shape())
+
+  def testPmfCountsStretchedInBroadcastWhenLowerRank(self):
+    with self.test_session():
+      p = [[0.1, 0.9], [0.7, 0.3]]
+      counts = [1., 0]
+      pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts)
+      self.assertAllClose(pmf.eval(), [0.1, 0.7])
+      self.assertEqual(pmf.get_shape(), (2))
+
+  def testPmfShapeCountsStretched_N(self):
+    with self.test_session():
+      # [2, 2, 2]
+      p = [[[0.1, 0.9], [0.1, 0.9]], [[0.7, 0.3], [0.7, 0.3]]]
+      # [2, 2]
+      n = [[3., 3], [3, 3]]
+      # [2]
+      counts = [2., 1]
+      pmf = tf.contrib.distributions.Multinomial(n=n, p=p).pmf(counts)
+      pmf.eval()
+      self.assertEqual(pmf.get_shape(), (2, 2))
+
+  def testPmfShapeCountsPStretched_N(self):
+    with self.test_session():
+      p = [0.1, 0.9]
+      counts = [3., 2]
+      n = np.full([4, 3], 5., dtype=np.float32)
+      pmf = tf.contrib.distributions.Multinomial(n=n, p=p).pmf(counts)
+      pmf.eval()
+      self.assertEqual((4, 3), pmf.get_shape())
+
+  def testMultinomialMean(self):
+    with self.test_session():
+      n = 5.
+      p = [0.1, 0.2, 0.7]
+      dist = tf.contrib.distributions.Multinomial(n=n, p=p)
+      expected_means = 5 * np.array(p, dtype=np.float32)
+      self.assertEqual((3,), dist.mean().get_shape())
+      self.assertAllClose(expected_means, dist.mean().eval())
+
+  def testMultinomialVariance(self):
+    with self.test_session():
+      n = 5.
+      p = [0.1, 0.2, 0.7]
+      dist = tf.contrib.distributions.Multinomial(n=n, p=p)
+      expected_variances = [
+          [9./20, -1/10, -7/20], [-1/10, 4/5, -7/10], [-7/20, -7/10, 21/20]]
+      self.assertEqual((3, 3), dist.variance().get_shape())
+      self.assertAllClose(expected_variances, dist.variance().eval())
+
+  def testMultinomialVariance_batch(self):
+    with self.test_session():
+      # Shape [2]
+      n = [5.] * 2
+      # Shape [4, 1, 2]
+      p = [[[0.1, 0.9]], [[0.1, 0.9]]] * 2
+      dist = tf.contrib.distributions.Multinomial(n=n, p=p)
+      # Shape [2, 2]
+      inner_var = [[9./20, -9/20], [-9/20, 9/20]]
+      # Shape [4, 2, 2, 2]
+      expected_variances = [[inner_var, inner_var]] * 4
+      self.assertEqual((4, 2, 2, 2), dist.variance().get_shape())
+      self.assertAllClose(expected_variances, dist.variance().eval())
+
+  def testVariance_multidimensional(self):
+    # Shape [3, 5, 4]
+    p = np.random.dirichlet([.25, .25, .25, .25], [3, 5]).astype(np.float32)
+    # Shape [6, 3, 3]
+    p2 = np.random.dirichlet([.3, .3, .4], [6, 3]).astype(np.float32)
+
+    ns = np.random.randint(low=1, high=11, size=[3, 5]).astype(np.float32)
+    ns2 = np.random.randint(low=1, high=11, size=[6, 1]).astype(np.float32)
+
+    with self.test_session():
+      dist = tf.contrib.distributions.Multinomial(ns, p)
+      dist2 = tf.contrib.distributions.Multinomial(ns2, p2)
+
+      variance = dist.variance()
+      variance2 = dist2.variance()
+      self.assertEqual((3, 5, 4, 4), variance.get_shape())
+      self.assertEqual((6, 3, 3, 3), variance2.get_shape())
+
+if __name__ == '__main__':
+  tf.test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/bernoulli.py b/tensorflow/contrib/distributions/python/ops/bernoulli.py
index fe5826e491f..1db599b3fea 100644
--- a/tensorflow/contrib/distributions/python/ops/bernoulli.py
+++ b/tensorflow/contrib/distributions/python/ops/bernoulli.py
@@ -19,15 +19,13 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.contrib.distributions.python.ops import distribution
+from tensorflow.contrib.distributions.python.ops import distribution_util
 from tensorflow.contrib.distributions.python.ops import kullback_leibler  # pylint: disable=line-too-long
-from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
-from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
 from tensorflow.python.ops import random_ops
@@ -38,10 +36,6 @@ class Bernoulli(distribution.Distribution):
 
   The Bernoulli distribution is parameterized by p, the probability of a
   positive event.
-
-  Note, the following methods of the base class aren't implemented:
-    * cdf
-    * log_cdf
   """
 
   def __init__(self,
@@ -64,10 +58,10 @@ class Bernoulli(distribution.Distribution):
       dtype: dtype for samples.
       validate_args: Whether to assert that `0 <= p <= 1`. If not validate_args,
        `log_pmf` may return nans.
-      allow_nan_stats:  Boolean, default False.  If False, raise an exception if
-        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-        If True, batch members with valid parameters leading to undefined
-        statistics will return NaN for this statistic.
+      allow_nan_stats:  Boolean, default `False`.  If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member.  If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
       name: A name for this distribution.
 
     Raises:
@@ -77,27 +71,8 @@ class Bernoulli(distribution.Distribution):
     self._name = name
     self._dtype = dtype
     self._validate_args = validate_args
-    check_op = check_ops.assert_less_equal
-    if p is None and logits is None:
-      raise ValueError("Must pass p or logits.")
-    elif p is not None and logits is not None:
-      raise ValueError("Must pass either p or logits, not both.")
-    elif p is None:
-      with ops.op_scope([logits], name):
-        self._logits = array_ops.identity(logits, name="logits")
-      with ops.name_scope(name):
-        with ops.name_scope("p"):
-          self._p = math_ops.sigmoid(self._logits)
-    elif logits is None:
-      with ops.name_scope(name):
-        with ops.name_scope("p"):
-          p = array_ops.identity(p)
-          one = constant_op.constant(1., p.dtype)
-          zero = constant_op.constant(0., p.dtype)
-          self._p = control_flow_ops.with_dependencies(
-              [check_op(p, one), check_op(zero, p)] if validate_args else [], p)
-        with ops.name_scope("logits"):
-          self._logits = math_ops.log(self._p) - math_ops.log(1. - self._p)
+    self._logits, self._p = distribution_util.get_logits_and_prob(
+        name=name, logits=logits, p=p, validate_args=validate_args)
     with ops.name_scope(name):
       with ops.name_scope("q"):
         self._q = 1. - self._p
@@ -184,8 +159,12 @@ class Bernoulli(distribution.Distribution):
         event = ops.convert_to_tensor(event, name="event")
         event = math_ops.cast(event, self.logits.dtype)
         logits = self.logits
-        if ((event.get_shape().ndims is not None) or
-            (logits.get_shape().ndims is not None) or
+        # sigmoid_cross_entropy_with_logits doesn't broadcast shape,
+        # so we do this here.
+        # TODO(b/30637701): Check dynamic shape, and don't broadcast if the
+        # dynamic shapes are the same.
+        if (not event.get_shape().is_fully_defined() or
+            not logits.get_shape().is_fully_defined() or
             event.get_shape() != logits.get_shape()):
           logits = array_ops.ones_like(event) * logits
           event = array_ops.ones_like(logits) * event
@@ -206,8 +185,7 @@ class Bernoulli(distribution.Distribution):
     with ops.name_scope(self.name):
       with ops.op_scope([self.p, n], name):
         n = ops.convert_to_tensor(n, name="n")
-        new_shape = array_ops.concat(
-            0, [array_ops.expand_dims(n, 0), self.batch_shape()])
+        new_shape = array_ops.concat(0, ([n], self.batch_shape()))
         uniform = random_ops.random_uniform(
             new_shape, seed=seed, dtype=dtypes.float32)
         sample = math_ops.less(uniform, self.p)
diff --git a/tensorflow/contrib/distributions/python/ops/beta.py b/tensorflow/contrib/distributions/python/ops/beta.py
index 2bd64180682..fcf4a9056c3 100644
--- a/tensorflow/contrib/distributions/python/ops/beta.py
+++ b/tensorflow/contrib/distributions/python/ops/beta.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """The Beta distribution class."""
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -95,6 +96,7 @@ class Beta(distribution.Distribution):
   x = [.2, .3, .9]
   dist.pdf(x)  # Shape [2]
   ```
+
   """
 
   def __init__(self, a, b, validate_args=True, allow_nan_stats=False,
@@ -102,20 +104,20 @@ class Beta(distribution.Distribution):
     """Initialize a batch of Beta distributions.
 
     Args:
-      a:  Positive `float` or `double` tensor with shape broadcastable to
+      a:  Positive floating point tensor with shape broadcastable to
         `[N1,..., Nm]` `m >= 0`.  Defines this as a batch of `N1 x ... x Nm`
          different Beta distributions. This also defines the
          dtype of the distribution.
-      b:  Positive `float` or `double` tensor with shape broadcastable to
+      b:  Positive floating point tensor with shape broadcastable to
         `[N1,..., Nm]` `m >= 0`.  Defines this as a batch of `N1 x ... x Nm`
          different Beta distributions.
       validate_args: Whether to assert valid values for parameters `a` and `b`,
-        and `x` in `prob` and `log_prob`.  If False, correct behavior is not
+        and `x` in `prob` and `log_prob`.  If `False`, correct behavior is not
         guaranteed.
-      allow_nan_stats:  Boolean, default False.  If False, raise an exception if
-        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-        If True, batch members with valid parameters leading to undefined
-        statistics will return NaN for this statistic.
+      allow_nan_stats:  Boolean, default `False`.  If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member.  If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
       name: The name to prefix Ops created by this distribution class.
 
     Examples:
@@ -127,6 +129,7 @@ class Beta(distribution.Distribution):
     # Define a 2-batch.
     dist = Beta([1.0, 2.0], [4.0, 5.0])
     ```
+
     """
     with ops.op_scope([a, b], name):
       with ops.control_dependencies([
@@ -276,8 +279,14 @@ class Beta(distribution.Distribution):
                array_ops.ones_like(a_b_sum, dtype=self.dtype)))
         else:
           return control_flow_ops.with_dependencies([
-              check_ops.assert_less(one, a),
-              check_ops.assert_less(one, b)], mode)
+              check_ops.assert_less(
+                  one, a,
+                  message="mode not defined for components of a <= 1"
+              ),
+              check_ops.assert_less(
+                  one, b,
+                  message="mode not defined for components of b <= 1"
+              )], mode)
 
   def entropy(self, name="entropy"):
     """Entropy of the distribution in nats."""
@@ -306,7 +315,7 @@ class Beta(distribution.Distribution):
     """`Log(P[counts])`, computed for every batch member.
 
     Args:
-      x:  Non-negative `float` or `double`, tensor whose shape can
+      x:  Non-negative floating point tensor whose shape can
         be broadcast with `self.a` and `self.b`.  For fixed leading
         dimensions, the last dimension represents counts for the corresponding
         Beta distribution in `self.a` and `self.b`. `x` is only legal if
@@ -334,7 +343,7 @@ class Beta(distribution.Distribution):
     """`P[x]`, computed for every batch member.
 
     Args:
-      x:  Non-negative `float`, `double` tensor whose shape can
+      x:  Non-negative floating point tensor whose shape can
         be broadcast with `self.a` and `self.b`.  For fixed leading
         dimensions, the last dimension represents x for the corresponding Beta
         distribution in `self.a` and `self.b`. `x` is only legal if is
diff --git a/tensorflow/contrib/distributions/python/ops/binomial.py b/tensorflow/contrib/distributions/python/ops/binomial.py
new file mode 100644
index 00000000000..9978d0ad613
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/binomial.py
@@ -0,0 +1,340 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The Binomial distribution class."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=line-too-long
+
+from tensorflow.contrib.distributions.python.ops import distribution
+from tensorflow.contrib.distributions.python.ops import distribution_util
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+
+# pylint: enable=line-too-long
+
+
+class Binomial(distribution.Distribution):
+  """Binomial distribution.
+
+  This distribution is parameterized by a vector `p` of probabilities and `n`,
+  the total counts.
+
+  #### Mathematical details
+
+  The Binomial is a distribution over the number of successes in `n` independent
+  trials, with each trial having the same probability of success `p`.
+  The probability mass function (pmf):
+
+  ```pmf(k) = n! / (k! * (n - k)!) * (p)^k * (1 - p)^(n - k)```
+
+  #### Examples
+
+  Create a single distribution, corresponding to 5 coin flips.
+
+  ```python
+  dist = Binomial(n=5., p=.5)
+  ```
+
+  Create a single distribution (using logits), corresponding to 5 coin flips.
+
+  ```python
+  dist = Binomial(n=5., logits=0.)
+  ```
+
+  Creates 3 distributions with the third distribution most likely to have
+  successes.
+
+  ```python
+  p = [.2, .3, .8]
+  # n will be broadcast to [4., 4., 4.], to match p.
+  dist = Binomial(n=4., p=p)
+  ```
+
+  The distribution functions can be evaluated on counts.
+
+  ```python
+  # counts same shape as p.
+  counts = [1., 2, 3]
+  dist.prob(counts)  # Shape [3]
+
+  # p will be broadcast to [[.2, .3, .8], [.2, .3, .8]] to match counts.
+  counts = [[1., 2, 1], [2, 2, 4]]
+  dist.prob(counts)  # Shape [2, 3]
+
+  # p will be broadcast to shape [5, 7, 3] to match counts.
+  counts = [[...]]  # Shape [5, 7, 3]
+  dist.prob(counts)  # Shape [5, 7, 3]
+  ```
+  """
+
+  def __init__(self,
+               n,
+               logits=None,
+               p=None,
+               validate_args=True,
+               allow_nan_stats=False,
+               name="Binomial"):
+    """Initialize a batch of Binomial distributions.
+
+    Args:
+      n:  Non-negative floating point tensor with shape broadcastable to
+        `[N1,..., Nm]` with `m >= 0` and the same dtype as `p` or `logits`.
+        Defines this as a batch of `N1 x ... x Nm` different Binomial
+        distributions. Its components should be equal to integer values.
+      logits: Floating point tensor representing the log-odds of a
+        positive event with shape broadcastable to `[N1,..., Nm]` `m >= 0`, and
+        the same dtype as `n`. Each entry represents logits for the probability
+        of success for independent Binomial distributions.
+      p:  Positive floating point tensor with shape broadcastable to
+        `[N1,..., Nm]` `m >= 0`, `p in [0, 1]`. Each entry represents the
+        probability of success for independent Binomial distributions.
+      validate_args: Whether to assert valid values for parameters `n` and `p`,
+        and `x` in `prob` and `log_prob`.  If `False`, correct behavior is not
+        guaranteed.
+      allow_nan_stats:  Boolean, default `False`.  If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member.  If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
+      name: The name to prefix Ops created by this distribution class.
+
+    Examples:
+
+    ```python
+    # Define 1-batch of a binomial distribution.
+    dist = Binomial(n=2., p=.9)
+
+    # Define a 2-batch.
+    dist = Binomial(n=[4., 5], p=[.1, .3])
+    ```
+
+    """
+
+    self._logits, self._p = distribution_util.get_logits_and_prob(
+        name=name, logits=logits, p=p, validate_args=validate_args)
+
+    with ops.op_scope([n], name):
+      with ops.control_dependencies([
+          check_ops.assert_non_negative(
+              n, message="n has negative components."),
+          distribution_util.assert_integer_form(
+              n, message="n has non-integer components."
+          )] if validate_args else []):
+        self._n = array_ops.identity(n, name="convert_n")
+
+        self._name = name
+        self._validate_args = validate_args
+        self._allow_nan_stats = allow_nan_stats
+
+        self._mean = self._n * self._p
+        self._get_batch_shape = self._mean.get_shape()
+        self._get_event_shape = tensor_shape.TensorShape([])
+
+  @property
+  def name(self):
+    """Name to prepend to all ops."""
+    return self._name
+
+  @property
+  def dtype(self):
+    """dtype of samples from this distribution."""
+    return self._p.dtype
+
+  @property
+  def validate_args(self):
+    """Boolean describing behavior on invalid input."""
+    return self._validate_args
+
+  @property
+  def allow_nan_stats(self):
+    """Boolean describing behavior when a stat is undefined for batch member."""
+    return self._allow_nan_stats
+
+  def batch_shape(self, name="batch_shape"):
+    """Batch dimensions of this instance as a 1-D int32 `Tensor`.
+
+    The product of the dimensions of the `batch_shape` is the number of
+    independent distributions of this kind the instance represents.
+
+    Args:
+      name: name to give to the op
+
+    Returns:
+      `Tensor` `batch_shape`
+    """
+    return array_ops.shape(self._mean)
+
+  def get_batch_shape(self):
+    """`TensorShape` available at graph construction time.
+
+    Same meaning as `batch_shape`. May be only partially defined.
+
+    Returns:
+      batch shape
+    """
+    return self._get_batch_shape
+
+  def event_shape(self, name="event_shape"):
+    """Shape of a sample from a single distribution as a 1-D int32 `Tensor`.
+
+    Args:
+      name: name to give to the op
+
+    Returns:
+      `Tensor` `event_shape`
+    """
+    with ops.name_scope(self.name):
+      with ops.op_scope([], name):
+        return constant_op.constant([], name=name, dtype=dtypes.int32)
+
+  def get_event_shape(self):
+    """`TensorShape` available at graph construction time.
+
+    Same meaning as `event_shape`. May be only partially defined.
+
+    Returns:
+      event shape
+    """
+    return self._get_event_shape
+
+  @property
+  def n(self):
+    """Number of trials."""
+    return self._n
+
+  @property
+  def logits(self):
+    """Log-odds."""
+    return self._logits
+
+  @property
+  def p(self):
+    """Probability of success."""
+    return self._p
+
+  def mean(self, name="mean"):
+    """Mean of the distribution."""
+    with ops.name_scope(self.name):
+      return array_ops.identity(self._mean, name=name)
+
+  def variance(self, name="variance"):
+    """Variance of the distribution."""
+    with ops.name_scope(self.name):
+      with ops.op_scope([self._n, self._p], name):
+        return self._n * self._p * (1 - self._p)
+
+  def std(self, name="std"):
+    """Standard deviation of the distribution."""
+    with ops.name_scope(self.name):
+      with ops.op_scope([self._n, self._p], name):
+        return math_ops.sqrt(self.variance())
+
+  def mode(self, name="mode"):
+    """Mode of the distribution.
+
+    Note that when `(n + 1) * p` is an integer, there are actually two modes.
+    Namely, `(n + 1) * p` and `(n + 1) * p - 1` are both modes. Here we return
+    only the larger of the two modes.
+
+    Args:
+      name: The name for this op.
+
+    Returns:
+      The mode of the Binomial distribution.
+    """
+    with ops.name_scope(self.name):
+      with ops.op_scope([self._n, self._p], name):
+        return math_ops.floor((self._n + 1) * self._p)
+
+  def log_prob(self, counts, name="log_prob"):
+    """`Log(P[counts])`, computed for every batch member.
+
+    For each batch member of counts `k`, `P[counts]` is the probability that
+    after sampling `n` draws from this Binomial distribution, the number of
+    successes is `k`.  Note that different sequences of draws can result in the
+    same counts, thus the probability includes a combinatorial coefficient.
+
+    Args:
+      counts:  Non-negative tensor with dtype `dtype` and whose shape can be
+        broadcast with `self.p` and `self.n`. `counts` is only legal if it is
+        less than or equal to `n` and its components are equal to integer
+        values.
+      name:  Name to give this Op, defaults to "log_prob".
+
+    Returns:
+      Log probabilities for each record, shape `[N1,...,Nm]`.
+    """
+    n = self._n
+    p = self._p
+    with ops.name_scope(self.name):
+      with ops.op_scope([self._n, self._p, counts], name):
+        counts = self._check_counts(counts)
+
+        prob_prob = counts * math_ops.log(p) + (
+            n - counts) * math_ops.log(1 - p)
+
+        combinations = math_ops.lgamma(n + 1) - math_ops.lgamma(
+            counts + 1) - math_ops.lgamma(n - counts + 1)
+        log_prob = prob_prob + combinations
+        return log_prob
+
+  def prob(self, counts, name="prob"):
+    """`P[counts]`, computed for every batch member.
+
+
+    For each batch member of counts `k`, `P[counts]` is the probability that
+    after sampling `n` draws from this Binomial distribution, the number of
+    successes is `k`.  Note that different sequences of draws can result in the
+    same counts, thus the probability includes a combinatorial coefficient.
+
+    Args:
+      counts:  Non-negative tensor with dtype `dtype` and whose shape can be
+        broadcast with `self.p` and `self.n`. `counts` is only legal if it is
+        less than or equal to `n` and its components are equal to integer
+        values.
+      name:  Name to give this Op, defaults to "prob".
+
+    Returns:
+      Probabilities for each record, shape `[N1,...,Nm]`.
+    """
+    return super(Binomial, self).prob(counts, name=name)
+
+  @property
+  def is_continuous(self):
+    return False
+
+  @property
+  def is_reparameterized(self):
+    return False
+
+  def _check_counts(self, counts):
+    """Check counts for proper shape, values, then return tensor version."""
+    counts = ops.convert_to_tensor(counts, name="counts_before_deps")
+    if not self.validate_args:
+      return counts
+    return control_flow_ops.with_dependencies([
+        check_ops.assert_non_negative(
+            counts, message="counts has negative components."),
+        check_ops.assert_less_equal(
+            counts, self._n, message="counts are not less than or equal to n."),
+        distribution_util.assert_integer_form(
+            counts, message="counts have non-integer components.")], counts)
diff --git a/tensorflow/contrib/distributions/python/ops/categorical.py b/tensorflow/contrib/distributions/python/ops/categorical.py
index 64572ed7885..e79a732a0c9 100644
--- a/tensorflow/contrib/distributions/python/ops/categorical.py
+++ b/tensorflow/contrib/distributions/python/ops/categorical.py
@@ -34,11 +34,6 @@ class Categorical(distribution.Distribution):
 
   The categorical distribution is parameterized by the log-probabilities
   of a set of classes.
-
-  Note, the following methods of the base class aren't implemented:
-    * mean
-    * cdf
-    * log_cdf
   """
 
   def __init__(
@@ -57,10 +52,10 @@ class Categorical(distribution.Distribution):
           indexes into the classes.
       dtype: The type of the event samples (default: int32).
       validate_args: Unused in this distribution.
-      allow_nan_stats:  Boolean, default False.  If False, raise an exception if
-        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-        If True, batch members with valid parameters leading to undefined
-        statistics will return NaN for this statistic.
+      allow_nan_stats:  Boolean, default `False`.  If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member.  If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
       name: A name for this distribution (optional).
     """
     self._allow_nan_stats = allow_nan_stats
@@ -177,8 +172,7 @@ class Categorical(distribution.Distribution):
         samples = math_ops.cast(samples, self._dtype)
         ret = array_ops.reshape(
             array_ops.transpose(samples),
-            array_ops.concat(
-                0, [array_ops.expand_dims(n, 0), self.batch_shape()]))
+            array_ops.concat(0, ([n], self.batch_shape())))
         ret.set_shape(tensor_shape.vector(tensor_util.constant_value(n))
                       .concatenate(self.get_batch_shape()))
         return ret
diff --git a/tensorflow/contrib/distributions/python/ops/chi2.py b/tensorflow/contrib/distributions/python/ops/chi2.py
index 65840373f12..e09ef6324b8 100644
--- a/tensorflow/contrib/distributions/python/ops/chi2.py
+++ b/tensorflow/contrib/distributions/python/ops/chi2.py
@@ -42,15 +42,15 @@ class Chi2(gamma.Gamma):
     """Construct Chi2 distributions with parameter `df`.
 
     Args:
-      df: `float` or `double` tensor, the degrees of freedom of the
+      df: Floating point tensor, the degrees of freedom of the
         distribution(s).  `df` must contain only positive values.
       validate_args: Whether to assert that `df > 0`, and that `x > 0` in the
-        methods `prob(x)` and `log_prob(x)`. If `validate_args` is False
+        methods `prob(x)` and `log_prob(x)`. If `validate_args` is `False`
         and the inputs are invalid, correct behavior is not guaranteed.
-      allow_nan_stats:  Boolean, default False.  If False, raise an exception if
-        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-        If True, batch members with valid parameters leading to undefined
-        statistics will return NaN for this statistic.
+      allow_nan_stats:  Boolean, default `False`.  If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member.  If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
       name: The name to prepend to all ops created by this distribution.
     """
     # Even though all stats of chi2 are defined for valid parameters, this is
diff --git a/tensorflow/contrib/distributions/python/ops/dirichlet.py b/tensorflow/contrib/distributions/python/ops/dirichlet.py
index b4f59d5bd8c..25aee5cf03e 100644
--- a/tensorflow/contrib/distributions/python/ops/dirichlet.py
+++ b/tensorflow/contrib/distributions/python/ops/dirichlet.py
@@ -19,9 +19,8 @@ from __future__ import print_function
 
 # pylint: disable=line-too-long
 
-import numpy as np
-
 from tensorflow.contrib.distributions.python.ops import distribution
+from tensorflow.contrib.distributions.python.ops import distribution_util
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
@@ -29,7 +28,6 @@ from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import logging_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import special_math_ops
@@ -37,24 +35,6 @@ from tensorflow.python.ops import special_math_ops
 # pylint: enable=line-too-long
 
 
-def _assert_close(x, y, data=None, summarize=None, name=None):
-  if x.dtype.is_integer:
-    return check_ops.assert_equal(
-        x, y, data=data, summarize=summarize, name=name)
-
-  with ops.op_scope([x, y, data], name, "assert_close"):
-    x = ops.convert_to_tensor(x, name="x")
-    y = ops.convert_to_tensor(y, name="y")
-    tol = np.finfo(x.dtype.as_numpy_dtype).resolution
-    if data is None:
-      data = [
-          "Condition x ~= y did not hold element-wise: x = ", x.name, x, "y = ",
-          y.name, y
-      ]
-    condition = math_ops.reduce_all(math_ops.less_equal(math_ops.abs(x-y), tol))
-    return logging_ops.Assert(condition, data, summarize=summarize)
-
-
 class Dirichlet(distribution.Distribution):
   """Dirichlet distribution.
 
@@ -117,6 +97,7 @@ class Dirichlet(distribution.Distribution):
   x = [.2, .3, .5]
   dist.prob(x)  # Shape [2]
   ```
+
   """
 
   def __init__(self,
@@ -127,16 +108,16 @@ class Dirichlet(distribution.Distribution):
     """Initialize a batch of Dirichlet distributions.
 
     Args:
-      alpha:  Positive `float` or `double` tensor with shape broadcastable to
+      alpha:  Positive floating point tensor with shape broadcastable to
         `[N1,..., Nm, k]` `m >= 0`.  Defines this as a batch of `N1 x ... x Nm`
          different `k` class Dirichlet distributions.
       validate_args: Whether to assert valid values for parameters `alpha` and
-        `x` in `prob` and `log_prob`.  If False, correct behavior is not
+        `x` in `prob` and `log_prob`.  If `False`, correct behavior is not
         guaranteed.
-      allow_nan_stats:  Boolean, default False.  If False, raise an exception if
-        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-        If True, batch members with valid parameters leading to undefined
-        statistics will return NaN for this statistic.
+      allow_nan_stats:  Boolean, default `False`.  If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member.  If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
       name: The name to prefix Ops created by this distribution class.
 
     Examples:
@@ -149,6 +130,7 @@ class Dirichlet(distribution.Distribution):
     # Define a 2-batch of 3-class distributions.
     dist = Dirichlet([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
     ```
+
     """
     with ops.op_scope([alpha], name):
       alpha = ops.convert_to_tensor(alpha, name="alpha_before_deps")
@@ -302,7 +284,9 @@ class Dirichlet(distribution.Distribution):
                array_ops.ones_like(self._alpha, dtype=self.dtype)))
         else:
           return control_flow_ops.with_dependencies([
-              check_ops.assert_less(one, self._alpha)
+              check_ops.assert_less(
+                  one, self._alpha,
+                  message="mode not defined for components of alpha <= 1")
           ], mode)
 
   def entropy(self, name="entropy"):
@@ -334,7 +318,7 @@ class Dirichlet(distribution.Distribution):
     """`Log(P[counts])`, computed for every batch member.
 
     Args:
-      x:  Non-negative `float` or `double`, tensor whose shape can
+      x:  Non-negative tensor with dtype `dtype` and whose shape can
         be broadcast with `self.alpha`.  For fixed leading dimensions, the last
         dimension represents counts for the corresponding Dirichlet distribution
         in `self.alpha`. `x` is only legal if it sums up to one.
@@ -359,7 +343,7 @@ class Dirichlet(distribution.Distribution):
     """`P[x]`, computed for every batch member.
 
     Args:
-      x:  Non-negative `float`, `double` tensor whose shape can
+      x:  Non-negative tensor with dtype `dtype` and whose shape can
         be broadcast with `self.alpha`.  For fixed leading dimensions, the last
         dimension represents x for the corresponding Dirichlet distribution in
         `self.alpha` and `self.beta`. `x` is only legal if it sums up to one.
@@ -407,7 +391,8 @@ class Dirichlet(distribution.Distribution):
     x = ops.convert_to_tensor(x, name="x_before_deps")
     candidate_one = math_ops.reduce_sum(x, reduction_indices=[-1])
     one = constant_op.constant(1., self.dtype)
-    dependencies = [check_ops.assert_positive(x), check_ops.assert_less(x, one),
-                    _assert_close(one, candidate_one)
+    dependencies = [check_ops.assert_positive(x), check_ops.assert_less(
+        x, one, message="x has components greater than or equal to 1"),
+                    distribution_util.assert_close(one, candidate_one)
                    ] if self.validate_args else []
     return control_flow_ops.with_dependencies(dependencies, x)
diff --git a/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py b/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py
index 7c779fff065..67cdd566c67 100644
--- a/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py
+++ b/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py
@@ -13,13 +13,15 @@
 # limitations under the License.
 # ==============================================================================
 """The Dirichlet Multinomial distribution class."""
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
 # pylint: disable=line-too-long
 
-from tensorflow.contrib.distributions.python.ops import distribution  # pylint: disable=line-too-long
+from tensorflow.contrib.distributions.python.ops import distribution
+from tensorflow.contrib.distributions.python.ops import distribution_util
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
@@ -30,34 +32,6 @@ from tensorflow.python.ops import special_math_ops
 # pylint: enable=line-too-long
 
 
-def _assert_integer_form(x):
-  """Check x for integer components (or floats that are equal to integers)."""
-  x = ops.convert_to_tensor(x, name='x')
-  casted_x = math_ops.to_int64(x)
-  return check_ops.assert_equal(x, math_ops.cast(
-      math_ops.round(casted_x), x.dtype))
-
-
-def _log_combinations(n, counts, name='log_combinations'):
-  """Log number of ways counts could have come in."""
-  # First a bit about the number of ways counts could have come in:
-  # E.g. if counts = [1, 2], then this is 3 choose 2.
-  # In general, this is (sum counts)! / sum(counts!)
-  # The sum should be along the last dimension of counts.  This is the
-  # "distribution" dimension. Here n a priori represents the sum of counts.
-  with ops.op_scope([counts], name):
-    # To compute factorials, use the fact that Gamma(n + 1) = n!
-    # Compute two terms, each a sum over counts.  Compute each for each
-    # batch member.
-    # Log Gamma((sum counts) + 1) = Log((sum counts)!)
-    total_permutations = math_ops.lgamma(n + 1)
-    # sum(Log Gamma(counts + 1)) = Log sum(counts!)
-    counts_factorial = math_ops.lgamma(counts + 1)
-    redundant_permutations = math_ops.reduce_sum(counts_factorial,
-                                                 reduction_indices=[-1])
-    return total_permutations - redundant_permutations
-
-
 class DirichletMultinomial(distribution.Distribution):
   """DirichletMultinomial mixture distribution.
 
@@ -126,6 +100,7 @@ class DirichletMultinomial(distribution.Distribution):
   counts = [2, 1, 0]
   dist.pmf(counts)  # Shape [2]
   ```
+
   """
 
   # TODO(b/27419586) Change docstring for dtype of alpha once int allowed.
@@ -134,26 +109,26 @@ class DirichletMultinomial(distribution.Distribution):
                alpha,
                validate_args=True,
                allow_nan_stats=False,
-               name='DirichletMultinomial'):
+               name="DirichletMultinomial"):
     """Initialize a batch of DirichletMultinomial distributions.
 
     Args:
-      n:  Non-negative `float` or `double` tensor, whose dtype is the same as
+      n:  Non-negative floating point tensor, whose dtype is the same as
         `alpha`. The shape is broadcastable to `[N1,..., Nm]` with `m >= 0`.
         Defines this as a batch of `N1 x ... x Nm` different Dirichlet
-        multinomial distributions. Its components should be equal to integral
+        multinomial distributions. Its components should be equal to integer
         values.
-      alpha:  Positive `float` or `double` tensor, whose dtype is the same as
+      alpha: Positive floating point tensor, whose dtype is the same as
         `n` with shape broadcastable to `[N1,..., Nm, k]` `m >= 0`.  Defines
         this as a batch of `N1 x ... x Nm` different `k` class Dirichlet
         multinomial distributions.
       validate_args: Whether to assert valid values for parameters `alpha` and
-        `n`, and `x` in `prob` and `log_prob`.  If False, correct behavior is
+        `n`, and `x` in `prob` and `log_prob`.  If `False`, correct behavior is
         not guaranteed.
-      allow_nan_stats:  Boolean, default False.  If False, raise an exception if
-        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-        If True, batch members with valid parameters leading to undefined
-        statistics will return NaN for this statistic.
+      allow_nan_stats:  Boolean, default `False`.  If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member.  If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
       name: The name to prefix Ops created by this distribution class.
 
     Examples:
@@ -166,6 +141,7 @@ class DirichletMultinomial(distribution.Distribution):
     # Define a 2-batch of 3-class distributions.
     dist = DirichletMultinomial([3., 4], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
     ```
+
     """
     self._allow_nan_stats = allow_nan_stats
     self._validate_args = validate_args
@@ -221,7 +197,7 @@ class DirichletMultinomial(distribution.Distribution):
     """dtype of samples from this distribution."""
     return self._alpha.dtype
 
-  def mean(self, name='mean'):
+  def mean(self, name="mean"):
     """Class means for every batch member."""
     alpha = self._alpha
     alpha_sum = self._alpha_sum
@@ -231,7 +207,7 @@ class DirichletMultinomial(distribution.Distribution):
         mean_no_n = alpha / array_ops.expand_dims(alpha_sum, -1)
         return array_ops.expand_dims(n, -1) * mean_no_n
 
-  def variance(self, name='mean'):
+  def variance(self, name="mean"):
     """Class variances for every batch member.
 
     The variance for each batch member is defined as the following:
@@ -273,7 +249,7 @@ class DirichletMultinomial(distribution.Distribution):
         variance *= array_ops.expand_dims(shared_factor, -1)
         return variance
 
-  def batch_shape(self, name='batch_shape'):
+  def batch_shape(self, name="batch_shape"):
     """Batch dimensions of this instance as a 1-D int32 `Tensor`.
 
     The product of the dimensions of the `batch_shape` is the number of
@@ -299,7 +275,7 @@ class DirichletMultinomial(distribution.Distribution):
     """
     return self._get_batch_shape
 
-  def event_shape(self, name='event_shape'):
+  def event_shape(self, name="event_shape"):
     """Shape of a sample from a single distribution as a 1-D int32 `Tensor`.
 
     Args:
@@ -322,15 +298,15 @@ class DirichletMultinomial(distribution.Distribution):
     """
     return self._get_event_shape
 
-  def cdf(self, x, name='cdf'):
+  def cdf(self, x, name="cdf"):
     raise NotImplementedError(
-        'DirichletMultinomial does not have a well-defined cdf.')
+        "DirichletMultinomial does not have a well-defined cdf.")
 
-  def log_cdf(self, x, name='log_cdf'):
+  def log_cdf(self, x, name="log_cdf"):
     raise NotImplementedError(
-        'DirichletMultinomial does not have a well-defined cdf.')
+        "DirichletMultinomial does not have a well-defined cdf.")
 
-  def log_prob(self, counts, name='log_prob'):
+  def log_prob(self, counts, name="log_prob"):
     """`Log(P[counts])`, computed for every batch member.
 
     For each batch of counts `[n_1,...,n_k]`, `P[counts]` is the probability
@@ -340,12 +316,11 @@ class DirichletMultinomial(distribution.Distribution):
     probability includes a combinatorial coefficient.
 
     Args:
-      counts:  Non-negative `float` or `double` tensor whose dtype is the same
-        `self` and whose shape can be broadcast with `self.alpha`.  For fixed
-        leading dimensions, the last dimension represents counts for the
-        corresponding Dirichlet Multinomial distribution in `self.alpha`.
-        `counts` is only legal if it sums up to `n` and its components are
-        equal to integral values.
+      counts:  Non-negative tensor with dtype `dtype` and whose shape can be
+        broadcast with `self.alpha`.  For fixed leading dimensions, the last
+        dimension represents counts for the corresponding Dirichlet Multinomial
+        distribution in `self.alpha`. `counts` is only legal if it sums up to
+        `n` and its components are equal to integer values.
       name:  Name to give this Op, defaults to "log_prob".
 
     Returns:
@@ -359,20 +334,11 @@ class DirichletMultinomial(distribution.Distribution):
 
         ordered_prob = (special_math_ops.lbeta(alpha + counts) -
                         special_math_ops.lbeta(alpha))
-        log_prob = ordered_prob + _log_combinations(n, counts)
-        # If alpha = counts = [[]], ordered_prob carries the right shape, which
-        # is [].  However, since reduce_sum([[]]) = [0], log_combinations = [0],
-        # which is not correct.  Luckily, [] + [0] = [], so the sum is fine, but
-        # shape must be inferred from ordered_prob. We must also make this
-        # broadcastable with n, so this is multiplied by n to ensure the shape
-        # is correctly inferred.
-        # Note also that tf.constant([]).get_shape() =
-        # TensorShape([Dimension(0)])
-        broadcasted_tensor = ordered_prob * n
-        log_prob.set_shape(broadcasted_tensor.get_shape())
+        log_prob = ordered_prob + distribution_util.log_combinations(
+            n, counts)
         return log_prob
 
-  def prob(self, counts, name='prob'):
+  def prob(self, counts, name="prob"):
     """`P[counts]`, computed for every batch member.
 
     For each batch of counts `[c_1,...,c_k]`, `P[counts]` is the probability
@@ -382,12 +348,11 @@ class DirichletMultinomial(distribution.Distribution):
     probability includes a combinatorial coefficient.
 
     Args:
-      counts:  Non-negative `float` or `double` tensor whose dtype is the same
-        `self` and whose shape can be broadcast with `self.alpha`.  For fixed
-        leading dimensions, the last dimension represents counts for the
-        corresponding Dirichlet Multinomial distribution in `self.alpha`.
-        `counts` is only legal if it sums up to `n` and its components are
-        equal to integral values.
+      counts:  Non-negative tensor with dtype `dtype` and whose shape can be
+        broadcast with `self.alpha`.  For fixed leading dimensions, the last
+        dimension represents counts for the corresponding Dirichlet Multinomial
+        distribution in `self.alpha`. `counts` is only legal if it sums up to
+        `n` and its components are equal to integer values.
       name:  Name to give this Op, defaults to "prob".
 
     Returns:
@@ -397,18 +362,21 @@ class DirichletMultinomial(distribution.Distribution):
 
   def _check_counts(self, counts):
     """Check counts for proper shape, values, then return tensor version."""
-    counts = ops.convert_to_tensor(counts, name='counts')
+    counts = ops.convert_to_tensor(counts, name="counts")
     if not self.validate_args:
       return counts
     candidate_n = math_ops.reduce_sum(counts, reduction_indices=[-1])
 
     return control_flow_ops.with_dependencies([
         check_ops.assert_non_negative(counts),
-        check_ops.assert_equal(self._n, candidate_n),
-        _assert_integer_form(counts)], counts)
+        check_ops.assert_equal(
+            self._n, candidate_n,
+            message="counts do not sum to n"
+        ),
+        distribution_util.assert_integer_form(counts)], counts)
 
   def _check_alpha(self, alpha):
-    alpha = ops.convert_to_tensor(alpha, name='alpha')
+    alpha = ops.convert_to_tensor(alpha, name="alpha")
     if not self.validate_args:
       return alpha
     return control_flow_ops.with_dependencies(
@@ -416,11 +384,12 @@ class DirichletMultinomial(distribution.Distribution):
          check_ops.assert_positive(alpha)], alpha)
 
   def _check_n(self, n):
-    n = ops.convert_to_tensor(n, name='n')
+    n = ops.convert_to_tensor(n, name="n")
     if not self.validate_args:
       return n
     return control_flow_ops.with_dependencies(
-        [check_ops.assert_non_negative(n), _assert_integer_form(n)], n)
+        [check_ops.assert_non_negative(n),
+         distribution_util.assert_integer_form(n)], n)
 
   @property
   def is_continuous(self):
diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py
new file mode 100644
index 00000000000..9c751270032
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py
@@ -0,0 +1,177 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for probability distributions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import math_ops
+
+
+def assert_close(
+    x, y, data=None, summarize=None, message=None, name="assert_close"):
+  """Assert that that x and y are within machine epsilon of each other.
+
+  Args:
+    x: Numeric `Tensor`
+    y: Numeric `Tensor`
+    data: The tensors to print out if the condition is `False`. Defaults to
+      error message and first few entries of `x` and `y`.
+    summarize: Print this many entries of each tensor.
+    message: A string to prefix to the default message.
+    name: A name for this operation (optional).
+
+  Returns:
+    Op raising `InvalidArgumentError` if |x - y| > machine epsilon.
+  """
+  message = message or ""
+  x = ops.convert_to_tensor(x, name="x")
+  y = ops.convert_to_tensor(y, name="y")
+
+  if x.dtype.is_integer:
+    return check_ops.assert_equal(
+        x, y, data=data, summarize=summarize, message=message, name=name)
+
+  with ops.op_scope([x, y, data], name, "assert_close"):
+    tol = np.finfo(x.dtype.as_numpy_dtype).resolution
+    if data is None:
+      data = [
+          message,
+          "Condition x ~= y did not hold element-wise: x = ", x.name, x, "y = ",
+          y.name, y
+      ]
+    condition = math_ops.reduce_all(math_ops.less_equal(math_ops.abs(x-y), tol))
+    return logging_ops.Assert(
+        condition, data, summarize=summarize)
+
+
+def assert_integer_form(
+    x, data=None, summarize=None, message=None, name="assert_integer_form"):
+  """Assert that x has integer components (or floats equal to integers).
+
+  Args:
+    x: Numeric `Tensor`
+    data: The tensors to print out if the condition is `False`. Defaults to
+      error message and first few entries of `x` and `y`.
+    summarize: Print this many entries of each tensor.
+    message: A string to prefix to the default message.
+    name: A name for this operation (optional).
+
+  Returns:
+    Op raising `InvalidArgumentError` if round(x) != x.
+  """
+
+  message = message or "x has non-integer components"
+  x = ops.convert_to_tensor(x, name="x")
+  casted_x = math_ops.to_int64(x)
+  return check_ops.assert_equal(
+      x, math_ops.cast(math_ops.round(casted_x), x.dtype),
+      data=data, summarize=summarize, message=message, name=name)
+
+
+def get_logits_and_prob(
+    logits=None, p=None, multidimensional=False, validate_args=True, name=None):
+  """Converts logits to probabilities and vice-versa, and returns both.
+
+  Args:
+    logits: Numeric `Tensor` representing log-odds.
+    p: Numeric `Tensor` representing probabilities.
+    multidimensional: Given `p` a [N1, N2, ... k] dimensional tensor,
+      whether the last dimension represents the probability between k classes.
+      This will additionally assert that the values in the last dimension
+      sum to one. If `False`, will instead assert that each value is in
+      `[0, 1]`.
+    validate_args: Whether to assert `0 <= p <= 1` if multidimensional is
+      `False`, otherwise that the last dimension of `p` sums to one.
+    name: A name for this operation (optional).
+
+  Returns:
+    Tuple with `logits` and `p`. If `p` has an entry that is `0` or `1`, then
+    the corresponding entry in the returned logits will be `-Inf` and `Inf`
+    respectively.
+
+  Raises:
+    ValueError: if neither `p` nor `logits` were passed in, or both were.
+  """
+  if p is None and logits is None:
+    raise ValueError("Must pass p or logits.")
+  elif p is not None and logits is not None:
+    raise ValueError("Must pass either p or logits, not both.")
+  elif p is None:
+    with ops.op_scope([logits], name):
+      logits = array_ops.identity(logits, name="logits")
+    with ops.name_scope(name):
+      with ops.name_scope("p"):
+        p = math_ops.sigmoid(logits)
+  elif logits is None:
+    with ops.name_scope(name):
+      with ops.name_scope("p"):
+        p = array_ops.identity(p)
+        if validate_args:
+          one = constant_op.constant(1., p.dtype)
+          dependencies = [check_ops.assert_non_negative(p)]
+          if multidimensional:
+            dependencies += [assert_close(
+                math_ops.reduce_sum(p, reduction_indices=[-1]),
+                one, message="p does not sum to 1.")]
+          else:
+            dependencies += [check_ops.assert_less_equal(
+                p, one, message="p has components greater than 1.")]
+          p = control_flow_ops.with_dependencies(dependencies, p)
+      with ops.name_scope("logits"):
+        logits = math_ops.log(p) - math_ops.log(1. - p)
+  return (logits, p)
+
+
+def log_combinations(n, counts, name="log_combinations"):
+  """Multinomial coefficient.
+
+  Given `n` and `counts`, where `counts` has last dimension `k`, we compute
+  the multinomial coefficient as:
+
+  ```n! / sum_i n_i!```
+
+  where `i` runs over all `k` classes.
+
+  Args:
+    n: Numeric `Tensor` broadcastable with `counts`. This represents `n`
+      outcomes.
+    counts: Numeric `Tensor` broadcastable with `n`. This represents counts
+      in `k` classes, where `k` is the last dimension of the tensor.
+    name: A name for this operation (optional).
+
+  Returns:
+    `Tensor` representing the multinomial coefficient between `n` and `counts`.
+  """
+  # First a bit about the number of ways counts could have come in:
+  # E.g. if counts = [1, 2], then this is 3 choose 2.
+  # In general, this is (sum counts)! / sum(counts!)
+  # The sum should be along the last dimension of counts.  This is the
+  # "distribution" dimension. Here n a priori represents the sum of counts.
+  with ops.op_scope([n, counts], name):
+    total_permutations = math_ops.lgamma(n + 1)
+    counts_factorial = math_ops.lgamma(counts + 1)
+    redundant_permutations = math_ops.reduce_sum(counts_factorial,
+                                                 reduction_indices=[-1])
+    return total_permutations - redundant_permutations
diff --git a/tensorflow/contrib/distributions/python/ops/exponential.py b/tensorflow/contrib/distributions/python/ops/exponential.py
index c49b3eeba8d..c1a7eb025ef 100644
--- a/tensorflow/contrib/distributions/python/ops/exponential.py
+++ b/tensorflow/contrib/distributions/python/ops/exponential.py
@@ -46,15 +46,15 @@ class Exponential(gamma.Gamma):
     """Construct Exponential distribution with parameter `lam`.
 
     Args:
-      lam: `float` or `double` tensor, the rate of the distribution(s).
+      lam: Floating point tensor, the rate of the distribution(s).
         `lam` must contain only positive values.
       validate_args: Whether to assert that `lam > 0`, and that `x > 0` in the
-        methods `prob(x)` and `log_prob(x)`.  If `validate_args` is False
+        methods `prob(x)` and `log_prob(x)`.  If `validate_args` is `False`
         and the inputs are invalid, correct behavior is not guaranteed.
-      allow_nan_stats:  Boolean, default False.  If False, raise an exception if
-        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-        If True, batch members with valid parameters leading to undefined
-        statistics will return NaN for this statistic.
+      allow_nan_stats:  Boolean, default `False`.  If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member. If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
       name: The name to prepend to all ops created by this distribution.
     """
     # Even though all statistics of are defined for valid inputs, this is not
@@ -95,8 +95,7 @@ class Exponential(gamma.Gamma):
     broadcast_shape = self._lam.get_shape()
     with ops.op_scope([self.lam, n], name, "ExponentialSample"):
       n = ops.convert_to_tensor(n, name="n")
-      shape = array_ops.concat(
-          0, [array_ops.pack([n]), array_ops.shape(self._lam)])
+      shape = array_ops.concat(0, ([n], array_ops.shape(self._lam)))
       # Sample uniformly-at-random from the open-interval (0, 1).
       sampled = random_ops.random_uniform(
           shape, minval=np.nextafter(
diff --git a/tensorflow/contrib/distributions/python/ops/gamma.py b/tensorflow/contrib/distributions/python/ops/gamma.py
index 1f733ceda16..6bd93877613 100644
--- a/tensorflow/contrib/distributions/python/ops/gamma.py
+++ b/tensorflow/contrib/distributions/python/ops/gamma.py
@@ -69,19 +69,19 @@ class Gamma(distribution.Distribution):
     broadcasting (e.g. `alpha + beta` is a valid operation).
 
     Args:
-      alpha: `float` or `double` tensor, the shape params of the
+      alpha: Floating point tensor, the shape params of the
         distribution(s).
         alpha must contain only positive values.
-      beta: `float` or `double` tensor, the inverse scale params of the
+      beta: Floating point tensor, the inverse scale params of the
         distribution(s).
         beta must contain only positive values.
       validate_args: Whether to assert that `a > 0, b > 0`, and that `x > 0` in
-        the methods `prob(x)` and `log_prob(x)`.  If `validate_args` is False
+        the methods `prob(x)` and `log_prob(x)`.  If `validate_args` is `False`
         and the inputs are invalid, correct behavior is not guaranteed.
-      allow_nan_stats:  Boolean, default False.  If False, raise an exception if
-        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-        If True, batch members with valid parameters leading to undefined
-        statistics will return NaN for this statistic.
+      allow_nan_stats:  Boolean, default `False`.  If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member.  If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
       name: The name to prepend to all ops created by this distribution.
 
     Raises:
@@ -213,9 +213,12 @@ class Gamma(distribution.Distribution):
           nan = np.nan * self._ones()
           return math_ops.select(alpha_ge_1, mode_if_defined, nan)
         else:
-          one = ops.convert_to_tensor(1.0, dtype=self.dtype)
+          one = constant_op.constant(1.0, dtype=self.dtype)
           return control_flow_ops.with_dependencies(
-              [check_ops.assert_less(one, alpha)], mode_if_defined)
+              [check_ops.assert_less(
+                  one, alpha,
+                  message="mode not defined for components of alpha <= 1"
+              )], mode_if_defined)
 
   def variance(self, name="variance"):
     """Variance of each batch member."""
diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
index a23f6df5717..d78e82a7524 100644
--- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
+++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
@@ -69,18 +69,18 @@ class InverseGamma(distribution.Distribution):
     broadcasting (e.g. `alpha + beta` is a valid operation).
 
     Args:
-      alpha: `float` or `double` tensor, the shape params of the
+      alpha: Floating point tensor, the shape params of the
         distribution(s).
         alpha must contain only positive values.
-      beta: `float` or `double` tensor, the scale params of the distribution(s).
+      beta: Floating point tensor, the scale params of the distribution(s).
         beta must contain only positive values.
       validate_args: Whether to assert that `a > 0, b > 0`, and that `x > 0` in
-        the methods `prob(x)` and `log_prob(x)`.  If `validate_args` is False
+        the methods `prob(x)` and `log_prob(x)`.  If `validate_args` is `False`
         and the inputs are invalid, correct behavior is not guaranteed.
-      allow_nan_stats:  Boolean, default False.  If False, raise an exception if
-        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-        If True, batch members with valid parameters leading to undefined
-        statistics will return NaN for this statistic.
+      allow_nan_stats:  Boolean, default `False`.  If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member.  If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
       name: The name to prepend to all ops created by this distribution.
 
     Raises:
@@ -206,9 +206,12 @@ class InverseGamma(distribution.Distribution):
           nan = np.nan * self._ones()
           return math_ops.select(alpha_gt_1, mean_if_defined, nan)
         else:
-          one = ops.convert_to_tensor(1.0, dtype=self.dtype)
+          one = constant_op.constant(1.0, dtype=self.dtype)
           return control_flow_ops.with_dependencies(
-              [check_ops.assert_less(one, alpha)], mean_if_defined)
+              [check_ops.assert_less(
+                  one, alpha,
+                  message="mean not defined for components of alpha <= 1")],
+              mean_if_defined)
 
   def mode(self, name="mode"):
     """Mode of each batch member.
@@ -250,9 +253,12 @@ class InverseGamma(distribution.Distribution):
           nan = np.nan * self._ones()
           return math_ops.select(alpha_gt_2, var_if_defined, nan)
         else:
-          two = ops.convert_to_tensor(2.0, dtype=self.dtype)
+          two = constant_op.constant(2.0, dtype=self.dtype)
           return control_flow_ops.with_dependencies(
-              [check_ops.assert_less(two, alpha)], var_if_defined)
+              [check_ops.assert_less(
+                  two, alpha,
+                  message="variance not defined for components of alpha <= 2")],
+              var_if_defined)
 
   def log_prob(self, x, name="log_prob"):
     """Log prob of observations in `x` under these InverseGamma distribution(s).
diff --git a/tensorflow/contrib/distributions/python/ops/kullback_leibler.py b/tensorflow/contrib/distributions/python/ops/kullback_leibler.py
index c134ca2cbfd..c1e0b2d2398 100644
--- a/tensorflow/contrib/distributions/python/ops/kullback_leibler.py
+++ b/tensorflow/contrib/distributions/python/ops/kullback_leibler.py
@@ -34,9 +34,9 @@ def kl(dist_a, dist_b, allow_nan=False, name=None):
   Args:
     dist_a: instance of distributions.Distribution.
     dist_b: instance of distributions.Distribution.
-    allow_nan: If False (default), a runtime error is raised
+    allow_nan: If `False` (default), a runtime error is raised
       if the KL returns NaN values for any batch entry of the given
-      distributions.  If True, the KL may return a NaN for the given entry.
+      distributions.  If `True`, the KL may return a NaN for the given entry.
     name: (optional) Name scope to use for created operations.
 
   Returns:
diff --git a/tensorflow/contrib/distributions/python/ops/laplace.py b/tensorflow/contrib/distributions/python/ops/laplace.py
index ee6aa81c0f4..a03a80d4ece 100644
--- a/tensorflow/contrib/distributions/python/ops/laplace.py
+++ b/tensorflow/contrib/distributions/python/ops/laplace.py
@@ -60,17 +60,17 @@ class Laplace(distribution.Distribution):
     broadcasting (e.g., `loc / scale` is a valid operation).
 
     Args:
-      loc: `float` or `double` tensor which characterizes the location (center)
+      loc: Floating point tensor which characterizes the location (center)
         of the distribution.
-      scale: `float` or `double`, positive-valued tensor which characterzes the
-        spread of the distribution.
+      scale: Positive floating point tensor which characterizes the spread of
+        the distribution.
       validate_args: Whether to validate input with asserts.  If `validate_args`
         is `False`, and the inputs are invalid, correct behavior is not
         guaranteed.
-      allow_nan_stats:  Boolean, default False.  If False, raise an exception if
-        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-        If True, batch members with valid parameters leading to undefined
-        statistics will return NaN for this statistic.
+      allow_nan_stats:  Boolean, default `False`.  If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member.  If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
       name: The name to give Ops created by the initializer.
 
     Raises:
@@ -294,8 +294,7 @@ class Laplace(distribution.Distribution):
       with ops.op_scope([self._loc, self._scale, n], name):
         n = ops.convert_to_tensor(n)
         n_val = tensor_util.constant_value(n)
-        shape = array_ops.concat(
-            0, [array_ops.pack([n]), self.batch_shape()])
+        shape = array_ops.concat(0, ([n], self.batch_shape()))
         # Sample uniformly-at-random from the open-interval (-1, 1).
         uniform_samples = random_ops.random_uniform(
             shape=shape,
diff --git a/tensorflow/contrib/distributions/python/ops/multinomial.py b/tensorflow/contrib/distributions/python/ops/multinomial.py
new file mode 100644
index 00000000000..477dd06673e
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/multinomial.py
@@ -0,0 +1,343 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The Multinomial distribution class."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=line-too-long
+
+from tensorflow.contrib.distributions.python.ops import distribution
+from tensorflow.contrib.distributions.python.ops import distribution_util
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+
+# pylint: enable=line-too-long
+
+
+class Multinomial(distribution.Distribution):
+  """Multinomial distribution.
+
+  This distribution is parameterized by a vector `p` of probability
+  parameters for `k` classes and `n`, the counts per each class..
+
+  #### Mathematical details
+
+  The Multinomial is a distribution over k-class count data, meaning
+  for each k-tuple of non-negative integer `counts = [n_1,...,n_k]`, we have a
+  probability of these draws being made from the distribution.  The distribution
+  has hyperparameters `p = (p_1,...,p_k)`, and probability mass
+  function (pmf):
+
+  ```pmf(counts) = n! / (n_1!...n_k!) * (p_1)^n_1*(p_2)^n_2*...(p_k)^n_k```
+
+  where above `n = sum_j n_j`, `n!` is `n` factorial.
+
+  #### Examples
+
+  Create a 3-class distribution, with the 3rd class is most likely to be drawn,
+  using logits..
+
+  ```python
+  logits = [-50., -43, 0]
+  dist = Multinomial(n=4., logits=logits)
+  ```
+
+  Create a 3-class distribution, with the 3rd class is most likely to be drawn.
+
+  ```python
+  p = [.2, .3, .5]
+  dist = Multinomial(n=4., p=p)
+  ```
+
+  The distribution functions can be evaluated on counts.
+
+  ```python
+  # counts same shape as p.
+  counts = [1., 0, 3]
+  dist.prob(counts)  # Shape []
+
+  # p will be broadcast to [[.2, .3, .5], [.2, .3, .5]] to match counts.
+  counts = [[1., 2, 1], [2, 2, 0]]
+  dist.prob(counts)  # Shape [2]
+
+  # p will be broadcast to shape [5, 7, 3] to match counts.
+  counts = [[...]]  # Shape [5, 7, 3]
+  dist.prob(counts)  # Shape [5, 7]
+  ```
+
+  Create a 2-batch of 3-class distributions.
+
+  ```python
+  p = [[.1, .2, .7], [.3, .3, .4]]  # Shape [2, 3]
+  dist = Multinomial(n=[4., 5], p=p)
+
+  counts = [[2., 1, 1], [3, 1, 1]]
+  dist.prob(counts)  # Shape [2]
+  ```
+  """
+
+  def __init__(self,
+               n,
+               logits=None,
+               p=None,
+               validate_args=True,
+               allow_nan_stats=False,
+               name="Multinomial"):
+    """Initialize a batch of Multinomial distributions.
+
+    Args:
+      n:  Non-negative floating point tensor with shape broadcastable to
+        `[N1,..., Nm]` with `m >= 0`. Defines this as a batch of
+        `N1 x ... x Nm` different Multinomial distributions.  Its components
+        should be equal to integer values.
+      logits: Floating point tensor representing the log-odds of a
+        positive event with shape broadcastable to `[N1,..., Nm, k], m >= 0`,
+        and the same dtype as `n`. Defines this as a batch of `N1 x ... x Nm`
+        different `k` class Multinomial distributions.
+      p:  Positive floating point tensor with shape broadcastable to
+        `[N1,..., Nm, k]` `m >= 0` and same dtype as `n`.  Defines this as
+        a batch of `N1 x ... x Nm` different `k` class Multinomial
+        distributions. `p`'s components in the last portion of its shape should
+        sum up to 1.
+      validate_args: Whether to assert valid values for parameters `n` and `p`,
+        and `x` in `prob` and `log_prob`.  If `False`, correct behavior is not
+        guaranteed.
+      allow_nan_stats:  Boolean, default `False`.  If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member.  If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
+      name: The name to prefix Ops created by this distribution class.
+
+    Examples:
+
+    ```python
+    # Define 1-batch of 2-class multinomial distribution,
+    # also known as a Binomial distribution.
+    dist = Multinomial(n=2., p=[.1, .9])
+
+    # Define a 2-batch of 3-class distributions.
+    dist = Multinomial(n=[4., 5], p=[[.1, .3, .6], [.4, .05, .55]])
+    ```
+
+    """
+
+    self._logits, self._p = distribution_util.get_logits_and_prob(
+        name=name, logits=logits, p=p, validate_args=validate_args,
+        multidimensional=True)
+    with ops.op_scope([n, self._p], name):
+      with ops.control_dependencies([
+          check_ops.assert_non_negative(
+              n, message="n has negative components."),
+          distribution_util.assert_integer_form(
+              n, message="n has non-integer components."
+          )] if validate_args else []):
+        self._n = array_ops.identity(n, name="convert_n")
+        self._name = name
+
+        self._validate_args = validate_args
+        self._allow_nan_stats = allow_nan_stats
+
+        self._mean = array_ops.expand_dims(n, -1) * self._p
+        # Only used for inferring shape.
+        self._broadcast_shape = math_ops.reduce_sum(self._mean,
+                                                    reduction_indices=[-1],
+                                                    keep_dims=False)
+
+        self._get_batch_shape = self._broadcast_shape.get_shape()
+        self._get_event_shape = (
+            self._mean.get_shape().with_rank_at_least(1)[-1:])
+
+  @property
+  def n(self):
+    """Number of trials."""
+    return self._n
+
+  @property
+  def p(self):
+    """Event probabilities."""
+    return self._p
+
+  @property
+  def logits(self):
+    """Log-odds."""
+    return self._logits
+
+  @property
+  def name(self):
+    """Name to prepend to all ops."""
+    return self._name
+
+  @property
+  def dtype(self):
+    """dtype of samples from this distribution."""
+    return self._p.dtype
+
+  @property
+  def validate_args(self):
+    """Boolean describing behavior on invalid input."""
+    return self._validate_args
+
+  @property
+  def allow_nan_stats(self):
+    """Boolean describing behavior when a stat is undefined for batch member."""
+    return self._allow_nan_stats
+
+  def batch_shape(self, name="batch_shape"):
+    """Batch dimensions of this instance as a 1-D int32 `Tensor`.
+
+    The product of the dimensions of the `batch_shape` is the number of
+    independent distributions of this kind the instance represents.
+
+    Args:
+      name: name to give to the op
+
+    Returns:
+      `Tensor` `batch_shape`
+    """
+    with ops.name_scope(self.name):
+      with ops.op_scope([self._broadcast_shape], name):
+        return array_ops.shape(self._broadcast_shape)
+
+  def get_batch_shape(self):
+    """`TensorShape` available at graph construction time.
+
+    Same meaning as `batch_shape`. May be only partially defined.
+
+    Returns:
+      batch shape
+    """
+    return self._get_batch_shape
+
+  def event_shape(self, name="event_shape"):
+    """Shape of a sample from a single distribution as a 1-D int32 `Tensor`.
+
+    Args:
+      name: name to give to the op
+
+    Returns:
+      `Tensor` `event_shape`
+    """
+    with ops.name_scope(self.name):
+      with ops.op_scope([self._mean], name):
+        return array_ops.gather(array_ops.shape(self._mean),
+                                [array_ops.rank(self._mean) - 1])
+
+  def get_event_shape(self):
+    """`TensorShape` available at graph construction time.
+
+    Same meaning as `event_shape`. May be only partially defined.
+
+    Returns:
+      event shape
+    """
+    return self._get_event_shape
+
+  def mean(self, name="mean"):
+    """Mean of the distribution."""
+    with ops.name_scope(self.name):
+      return array_ops.identity(self._mean, name=name)
+
+  def variance(self, name="variance"):
+    """Variance of the distribution."""
+    with ops.name_scope(self.name):
+      with ops.op_scope([self._n, self._p, self._mean], name):
+        p = array_ops.expand_dims(
+            self._p * array_ops.expand_dims(
+                array_ops.ones_like(self._n), -1), -1)
+        variance = -math_ops.batch_matmul(
+            array_ops.expand_dims(self._mean, -1), p, adj_y=True)
+        variance += array_ops.batch_matrix_diag(self._mean)
+        return variance
+
+  def log_prob(self, counts, name="log_prob"):
+    """`Log(P[counts])`, computed for every batch member.
+
+    For each batch of counts `[n_1,...,n_k]`, `P[counts]` is the probability
+    that after sampling `n` draws from this Multinomial distribution, the
+    number of draws falling in class `j` is `n_j`.  Note that different
+    sequences of draws can result in the same counts, thus the probability
+    includes a combinatorial coefficient.
+
+    Args:
+      counts:  Non-negative tensor with dtype `dtype` and whose shape can
+        be broadcast with `self.p` and `self.n`.  For fixed leading dimensions,
+        the last dimension represents counts for the corresponding Multinomial
+        distribution in `self.p`. `counts` is only legal if it sums up to `n`
+        and its components are equal to integer values.
+      name:  Name to give this Op, defaults to "log_prob".
+
+    Returns:
+      Log probabilities for each record, shape `[N1,...,Nm]`.
+    """
+    n = self._n
+    p = self._p
+    with ops.name_scope(self.name):
+      with ops.op_scope([n, p, counts], name):
+        counts = self._check_counts(counts)
+
+        prob_prob = math_ops.reduce_sum(counts * math_ops.log(self._p),
+                                        reduction_indices=[-1])
+        log_prob = prob_prob + distribution_util.log_combinations(
+            n, counts)
+        return log_prob
+
+  def prob(self, counts, name="prob"):
+    """`P[counts]`, computed for every batch member.
+
+    For each batch of counts `[n_1,...,n_k]`, `P[counts]` is the probability
+    that after sampling `n` draws from this Multinomial distribution, the
+    number of draws falling in class `j` is `n_j`.  Note that different
+    sequences of draws can result in the same counts, thus the probability
+    includes a combinatorial coefficient.
+
+    Args:
+      counts:  Non-negative tensor with dtype `dtype` and whose shape can
+        be broadcast with `self.p` and `self.n`.  For fixed leading dimensions,
+        the last dimension represents counts for the corresponding Multinomial
+        distribution in `self.p`. `counts` is only legal if it sums up to `n`
+        and its components are equal to integer values.
+      name:  Name to give this Op, defaults to "prob".
+
+    Returns:
+      Probabilities for each record, shape `[N1,...,Nm]`.
+    """
+    return super(Multinomial, self).prob(counts, name=name)
+
+  @property
+  def is_continuous(self):
+    return False
+
+  @property
+  def is_reparameterized(self):
+    return False
+
+  def _check_counts(self, counts):
+    """Check counts for proper shape, values, then return tensor version."""
+    counts = ops.convert_to_tensor(counts, name="counts_before_deps")
+    candidate_n = math_ops.reduce_sum(counts, reduction_indices=[-1])
+    if not self.validate_args:
+      return counts
+
+    return control_flow_ops.with_dependencies([
+        check_ops.assert_non_negative(
+            counts, message="counts has negative components."),
+        check_ops.assert_equal(
+            self._n, candidate_n, message="counts do not sum to n."),
+        distribution_util.assert_integer_form(
+            counts, message="counts have non-integer components.")], counts)
diff --git a/tensorflow/contrib/distributions/python/ops/mvn.py b/tensorflow/contrib/distributions/python/ops/mvn.py
index dafddc0faac..8936594dfac 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn.py
@@ -105,9 +105,9 @@ class MultivariateNormalOperatorPD(distribution.Distribution):
     which determines the covariance.
 
     Args:
-      mu: `float` or `double` tensor with shape `[N1,...,Nb, k]`, `b >= 0`.
-      cov: `float` or `double` instance of `OperatorPDBase` with same `dtype`
-        as `mu` and shape `[N1,...,Nb, k, k]`.
+      mu: Floating point tensor with shape `[N1,...,Nb, k]`, `b >= 0`.
+      cov: Instance of `OperatorPDBase` with same `dtype` as `mu` and shape
+        `[N1,...,Nb, k, k]`.
       validate_args: Whether to validate input with asserts.  If `validate_args`
         is `False`, and the inputs are invalid, correct behavior is not
         guaranteed.
@@ -466,7 +466,7 @@ class MultivariateNormalDiag(MultivariateNormalOperatorPD):
     The mean of `X_i` is `mu[i]`, and the standard deviation is `diag_stdev[i]`.
 
     Args:
-      mu:  Rank `N + 1` `float` or `double` tensor with shape `[N1,...,Nb, k]`,
+      mu:  Rank `N + 1` floating point tensor with shape `[N1,...,Nb, k]`,
         `b >= 0`.
       diag_stdev: Rank `N + 1` `Tensor` with same `dtype` and shape as `mu`,
         representing the standard deviations.  Must be positive.
@@ -581,13 +581,13 @@ class MultivariateNormalDiagPlusVDVT(MultivariateNormalOperatorPD):
     ```
 
     Args:
-      mu:  Rank `n + 1` `float` or `double` tensor with shape `[N1,...,Nn, k]`,
+      mu:  Rank `n + 1` floating point tensor with shape `[N1,...,Nn, k]`,
         `n >= 0`.  The means.
-      diag_large:  Optional rank `n + 1` `float` or `double` tensor, shape
+      diag_large:  Optional rank `n + 1` floating point tensor, shape
         `[N1,...,Nn, k]` `n >= 0`.  Defines the diagonal matrix `M`.
-      v:  Rank `n + 1` `float` or `double` tensor, shape `[N1,...,Nn, k, r]`
+      v:  Rank `n + 1` floating point tensor, shape `[N1,...,Nn, k, r]`
         `n >= 0`.  Defines the matrix `V`.
-      diag_small:  Rank `n + 1` `float` or `double` tensor, shape
+      diag_small:  Rank `n + 1` floating point tensor, shape
         `[N1,...,Nn, k]` `n >= 0`.  Defines the diagonal matrix `D`.  Default
         is `None`, which means `D` will be the identity matrix.
       validate_args: Whether to validate input with asserts.  If `validate_args`
@@ -670,7 +670,7 @@ class MultivariateNormalCholesky(MultivariateNormalOperatorPD):
     factors, such that the covariance of each batch member is `chol chol^T`.
 
     Args:
-      mu: `(N+1)-D`  `float` or `double` tensor with shape `[N1,...,Nb, k]`,
+      mu: `(N+1)-D` floating point tensor with shape `[N1,...,Nb, k]`,
         `b >= 0`.
       chol: `(N+2)-D` `Tensor` with same `dtype` as `mu` and shape
         `[N1,...,Nb, k, k]`.  The upper triangular part is ignored (treated as
@@ -750,7 +750,7 @@ class MultivariateNormalFull(MultivariateNormalOperatorPD):
     User must provide means `mu` and `sigma`, the mean and covariance.
 
     Args:
-      mu: `(N+1)-D`  `float` or `double` tensor with shape `[N1,...,Nb, k]`,
+      mu: `(N+1)-D` floating point tensor with shape `[N1,...,Nb, k]`,
         `b >= 0`.
       sigma: `(N+2)-D` `Tensor` with same `dtype` as `mu` and shape
         `[N1,...,Nb, k, k]`.  Each batch member must be positive definite.
diff --git a/tensorflow/contrib/distributions/python/ops/normal.py b/tensorflow/contrib/distributions/python/ops/normal.py
index dff8c7fdbbe..182afa31f7f 100644
--- a/tensorflow/contrib/distributions/python/ops/normal.py
+++ b/tensorflow/contrib/distributions/python/ops/normal.py
@@ -92,15 +92,15 @@ class Normal(distribution.Distribution):
     broadcasting (e.g. `mu + sigma` is a valid operation).
 
     Args:
-      mu: `float` or `double` tensor, the means of the distribution(s).
-      sigma: `float` or `double` tensor, the stddevs of the distribution(s).
+      mu: Floating point tensor, the means of the distribution(s).
+      sigma: Floating point tensor, the stddevs of the distribution(s).
         sigma must contain only positive values.
       validate_args: Whether to assert that `sigma > 0`. If `validate_args` is
-        False, correct output is not guaranteed when input is invalid.
-      allow_nan_stats:  Boolean, default False.  If False, raise an exception if
-        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-        If True, batch members with valid parameters leading to undefined
-        statistics will return NaN for this statistic.
+        `False`, correct output is not guaranteed when input is invalid.
+      allow_nan_stats:  Boolean, default `False`.  If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member.  If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
       name: The name to give Ops created by the initializer.
 
     Raises:
@@ -321,8 +321,7 @@ class Normal(distribution.Distribution):
       with ops.op_scope([self._mu, self._sigma, n], name):
         broadcast_shape = (self._mu + self._sigma).get_shape()
         n = ops.convert_to_tensor(n)
-        shape = array_ops.concat(
-            0, [array_ops.pack([n]), array_ops.shape(self.mean())])
+        shape = array_ops.concat(0, ([n], array_ops.shape(self.mean())))
         sampled = random_ops.random_normal(
             shape=shape, mean=0, stddev=1, dtype=self._mu.dtype, seed=seed)
 
diff --git a/tensorflow/contrib/distributions/python/ops/student_t.py b/tensorflow/contrib/distributions/python/ops/student_t.py
index e5fa624ddc4..8e43c95b6db 100644
--- a/tensorflow/contrib/distributions/python/ops/student_t.py
+++ b/tensorflow/contrib/distributions/python/ops/student_t.py
@@ -82,6 +82,7 @@ class StudentT(distribution.Distribution):
   # returning a length 2 tensor.
   dist.pdf(3.0)
   ```
+
   """
 
   def __init__(self,
@@ -99,19 +100,19 @@ class StudentT(distribution.Distribution):
     broadcasting (e.g. `df + mu + sigma` is a valid operation).
 
     Args:
-      df: `float` or `double` tensor, the degrees of freedom of the
+      df: Floating point tensor, the degrees of freedom of the
         distribution(s). `df` must contain only positive values.
-      mu: `float` or `double` tensor, the means of the distribution(s).
-      sigma: `float` or `double` tensor, the scaling factor for the
+      mu: Floating point tensor, the means of the distribution(s).
+      sigma: Floating point tensor, the scaling factor for the
         distribution(s). `sigma` must contain only positive values.
         Note that `sigma` is not the standard deviation of this distribution.
       validate_args: Whether to assert that `df > 0, sigma > 0`. If
-        `validate_args` is False and inputs are invalid, correct behavior is not
-        guaranteed.
-      allow_nan_stats:  Boolean, default False.  If False, raise an exception if
-        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-        If True, batch members with valid parameters leading to undefined
-        statistics will return NaN for this statistic.
+        `validate_args` is `False` and inputs are invalid, correct behavior is
+        not guaranteed.
+      allow_nan_stats:  Boolean, default `False`.  If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member.  If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
       name: The name to give Ops created by the initializer.
 
     Raises:
@@ -185,9 +186,12 @@ class StudentT(distribution.Distribution):
           nan = np.nan + self._zeros()
           return math_ops.select(df_gt_1, result_if_defined, nan)
         else:
-          one = ops.convert_to_tensor(1.0, dtype=self.dtype)
+          one = constant_op.constant(1.0, dtype=self.dtype)
           return control_flow_ops.with_dependencies(
-              [check_ops.assert_less(one, self._df)], result_if_defined)
+              [check_ops.assert_less(
+                  one, self._df,
+                  message="mean not defined for components of df <= 1"
+              )], result_if_defined)
 
   def mode(self, name="mode"):
     with ops.name_scope(self.name):
@@ -232,9 +236,12 @@ class StudentT(distribution.Distribution):
               result_where_defined,
               self._zeros() + np.nan)
         else:
-          one = ops.convert_to_tensor(1.0, self.dtype)
+          one = constant_op.constant(1.0, dtype=self.dtype)
           return control_flow_ops.with_dependencies(
-              [check_ops.assert_less(one, self._df)], result_where_defined)
+              [check_ops.assert_less(
+                  one, self._df,
+                  message="variance not defined for components of df <= 1"
+              )], result_where_defined)
 
   def std(self, name="std"):
     with ops.name_scope(self.name):
@@ -348,8 +355,7 @@ class StudentT(distribution.Distribution):
         # Let X = R*cos(theta), and let Y = R*sin(theta).
         # Then X ~ t_df and Y ~ t_df.
         # The variates X and Y are not independent.
-        shape = array_ops.concat(0, [array_ops.pack([2, n]),
-                                     self.batch_shape()])
+        shape = array_ops.concat(0, ([2, n], self.batch_shape()))
         uniform = random_ops.random_uniform(shape=shape,
                                             dtype=self.dtype,
                                             seed=seed)
diff --git a/tensorflow/contrib/distributions/python/ops/transformed_distribution.py b/tensorflow/contrib/distributions/python/ops/transformed_distribution.py
index 185741b2176..82971301560 100644
--- a/tensorflow/contrib/distributions/python/ops/transformed_distribution.py
+++ b/tensorflow/contrib/distributions/python/ops/transformed_distribution.py
@@ -57,6 +57,7 @@ class TransformedDistribution(distribution.Distribution):
     name="LogitNormalTransformedDistribution"
   )
   ```
+
   """
 
   def __init__(self,
diff --git a/tensorflow/contrib/distributions/python/ops/uniform.py b/tensorflow/contrib/distributions/python/ops/uniform.py
index eb196a3ea91..09437d36d16 100644
--- a/tensorflow/contrib/distributions/python/ops/uniform.py
+++ b/tensorflow/contrib/distributions/python/ops/uniform.py
@@ -67,14 +67,14 @@ class Uniform(distribution.Distribution):
     ```
 
     Args:
-      a: `float` or `double` tensor, the minimum endpoint.
-      b: `float` or `double` tensor, the maximum endpoint. Must be > `a`.
-      validate_args: Whether to assert that `a > b`. If `validate_args` is False
-        and inputs are invalid, correct behavior is not guaranteed.
-      allow_nan_stats:  Boolean, default False.  If False, raise an exception if
-        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-        If True, batch members with valid parameters leading to undefined
-        statistics will return NaN for this statistic.
+      a: Floating point tensor, the minimum endpoint.
+      b: Floating point tensor, the maximum endpoint. Must be > `a`.
+      validate_args: Whether to assert that `a > b`. If `validate_args` is
+        `False` and inputs are invalid, correct behavior is not guaranteed.
+      allow_nan_stats:  Boolean, default `False`.  If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member.  If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
       name: The name to prefix Ops created by this distribution class.
 
     Raises:
@@ -83,8 +83,9 @@ class Uniform(distribution.Distribution):
     self._allow_nan_stats = allow_nan_stats
     self._validate_args = validate_args
     with ops.op_scope([a, b], name):
-      with ops.control_dependencies([check_ops.assert_less(a, b)] if
-                                    validate_args else []):
+      with ops.control_dependencies([check_ops.assert_less(
+          a, b, message="uniform not defined when a > b.")] if validate_args
+                                    else []):
         a = array_ops.identity(a, name="a")
         b = array_ops.identity(b, name="b")
 
@@ -228,7 +229,7 @@ class Uniform(distribution.Distribution):
         n = ops.convert_to_tensor(n, name="n")
         n_val = tensor_util.constant_value(n)
 
-        shape = array_ops.concat(0, [array_ops.pack([n]), self.batch_shape()])
+        shape = array_ops.concat(0, ([n], self.batch_shape()))
         samples = random_ops.random_uniform(shape=shape,
                                             dtype=self.dtype,
                                             seed=seed)