From 1f6cd6fbb0b1c10326f9e56637449b142ab208c2 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <nobody@tensorflow.org>
Date: Thu, 12 May 2016 13:02:03 -0800
Subject: [PATCH] Base classes for probability distributions and uniform
 distribution Change: 122194730

---
 tensorflow/contrib/distributions/BUILD        |  10 +
 tensorflow/contrib/distributions/__init__.py  |  11 +-
 .../python/kernel_tests/uniform_test.py       | 220 +++++++++++++++
 .../distributions/python/ops/distribution.py  | 256 ++++++++++++++++++
 .../distributions/python/ops/uniform.py       | 240 ++++++++++++++++
 5 files changed, 736 insertions(+), 1 deletion(-)
 create mode 100644 tensorflow/contrib/distributions/python/kernel_tests/uniform_test.py
 create mode 100644 tensorflow/contrib/distributions/python/ops/distribution.py
 create mode 100644 tensorflow/contrib/distributions/python/ops/uniform.py

diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 87e500081e9..3a8c9f2321c 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -38,6 +38,16 @@ cuda_py_tests(
     ],
 )
 
+cuda_py_tests(
+    name = "uniform_test",
+    size = "small",
+    srcs = ["python/kernel_tests/uniform_test.py"],
+    additional_deps = [
+        ":distributions_py",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
 cuda_py_tests(
     name = "mvn_test",
     size = "small",
diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py
index f3263ff7858..5b4bbac8270 100644
--- a/tensorflow/contrib/distributions/__init__.py
+++ b/tensorflow/contrib/distributions/__init__.py
@@ -12,16 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Classes representing statistical distributions.  Ops for working with them.
+"""Classes representing statistical distributions and ops for working with them.
 
 ## Classes for statistical distributions.
 
 Classes that represent batches of statistical distributions.  Each class is
 initialized with parameters that define the distributions.
 
+### Base classes
+
+@@BaseDistribution
+@@ContinuousDistribution
+@@DiscreteDistribution
+
 ### Univariate (scalar) distributions
 
 @@Gaussian
+@@Uniform
 
 ### Multivariate distributions
 
@@ -44,6 +51,8 @@ from __future__ import print_function
 
 # pylint: disable=unused-import,wildcard-import,line-too-long
 from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import *
+from tensorflow.contrib.distributions.python.ops.distribution import *
 from tensorflow.contrib.distributions.python.ops.gaussian import *
 from tensorflow.contrib.distributions.python.ops.gaussian_conjugate_posteriors import *
 from tensorflow.contrib.distributions.python.ops.mvn import *
+from tensorflow.contrib.distributions.python.ops.uniform import *
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/uniform_test.py b/tensorflow/contrib/distributions/python/kernel_tests/uniform_test.py
new file mode 100644
index 00000000000..4fefa69e048
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/uniform_test.py
@@ -0,0 +1,220 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Uniform distribution."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+
+class UniformTest(tf.test.TestCase):
+
+  def testUniformRange(self):
+    with self.test_session():
+      a = 3.0
+      b = 10.0
+      uniform = tf.contrib.distributions.Uniform(a=a, b=b)
+      self.assertAllClose(a, uniform.a.eval())
+      self.assertAllClose(b, uniform.b.eval())
+      self.assertAllClose(b - a, uniform.range.eval())
+
+  def testUniformPDF(self):
+    with self.test_session():
+      a = tf.constant([-3.0] * 5 + [15.0])
+      b = tf.constant([11.0] * 5 + [20.0])
+      uniform = tf.contrib.distributions.Uniform(a=a, b=b)
+
+      a_v = -3.0
+      b_v = 11.0
+      x = np.array([-10.5, 4.0, 0.0, 10.99, 11.3, 17.0], dtype=np.float32)
+
+      def _expected_pdf():
+        pdf = np.zeros_like(x) + 1.0 / (b_v - a_v)
+        pdf[x > b_v] = 0.0
+        pdf[x < a_v] = 0.0
+        pdf[5] = 1.0 / (20.0 - 15.0)
+        return pdf
+
+      expected_pdf = _expected_pdf()
+
+      pdf = uniform.pdf(x)
+      self.assertAllClose(expected_pdf, pdf.eval())
+
+      log_pdf = uniform.log_pdf(x)
+      self.assertAllClose(np.log(expected_pdf), log_pdf.eval())
+
+  def testUniformShape(self):
+    with self.test_session():
+      a = tf.constant([-3.0] * 5)
+      b = tf.constant(11.0)
+      uniform = tf.contrib.distributions.Uniform(a=a, b=b)
+
+      self.assertEqual(uniform.batch_shape().eval(), (5,))
+      self.assertEqual(uniform.get_batch_shape(), tf.TensorShape([5]))
+      self.assertEqual(uniform.event_shape().eval(), 1)
+      self.assertEqual(uniform.get_event_shape(), tf.TensorShape([]))
+
+  def testUniformPDFWithScalarEndpoint(self):
+    with self.test_session():
+      a = tf.constant([0.0, 5.0])
+      b = tf.constant(10.0)
+      uniform = tf.contrib.distributions.Uniform(a=a, b=b)
+
+      x = np.array([0.0, 8.0], dtype=np.float32)
+      expected_pdf = np.array([1.0 / (10.0 - 0.0), 1.0 / (10.0 - 5.0)])
+
+      pdf = uniform.pdf(x)
+      self.assertAllClose(expected_pdf, pdf.eval())
+
+  def testUniformCDF(self):
+    with self.test_session():
+      batch_size = 6
+      a = tf.constant([1.0] * batch_size)
+      b = tf.constant([11.0] * batch_size)
+      a_v = 1.0
+      b_v = 11.0
+      x = np.array([-2.5, 2.5, 4.0, 0.0, 10.99, 12.0], dtype=np.float32)
+
+      uniform = tf.contrib.distributions.Uniform(a=a, b=b)
+
+      def _expected_cdf():
+        cdf = (x - a_v) / (b_v - a_v)
+        cdf[x >= b_v] = 1
+        cdf[x < a_v] = 0
+        return cdf
+
+      cdf = uniform.cdf(x)
+      self.assertAllClose(_expected_cdf(), cdf.eval())
+
+      log_cdf = uniform.log_cdf(x)
+      self.assertAllClose(np.log(_expected_cdf()), log_cdf.eval())
+
+  def testUniformEntropy(self):
+    with self.test_session():
+      a_v = np.array([1.0, 1.0, 1.0])
+      b_v = np.array([[1.5, 2.0, 3.0]])
+      uniform = tf.contrib.distributions.Uniform(a=a_v, b=b_v)
+
+      expected_entropy = np.log(b_v - a_v)
+      self.assertAllClose(expected_entropy, uniform.entropy().eval())
+
+  def testUniformAssertMaxGtMin(self):
+    with self.test_session():
+      a_v = np.array([1.0, 1.0, 1.0], dtype=np.float32)
+      b_v = np.array([1.0, 2.0, 3.0], dtype=np.float32)
+      uniform = tf.contrib.distributions.Uniform(a=a_v, b=b_v)
+
+      with self.assertRaisesWithPredicateMatch(tf.errors.InvalidArgumentError,
+                                               "x < y"):
+        uniform.a.eval()
+
+  def testUniformSample(self):
+    with self.test_session():
+      a = tf.constant([3.0, 4.0])
+      b = tf.constant(13.0)
+      a1_v = 3.0
+      a2_v = 4.0
+      b_v = 13.0
+      n = tf.constant(100000)
+      uniform = tf.contrib.distributions.Uniform(a=a, b=b)
+
+      samples = uniform.sample(n, seed=137)
+      sample_values = samples.eval()
+      self.assertEqual(sample_values.shape, (100000, 2))
+      self.assertAllClose(sample_values[::, 0].mean(), (b_v + a1_v) / 2,
+                          atol=1e-2)
+      self.assertAllClose(sample_values[::, 1].mean(), (b_v + a2_v) / 2,
+                          atol=1e-2)
+      self.assertFalse(np.any(sample_values[::, 0] < a1_v) or np.any(
+          sample_values >= b_v))
+      self.assertFalse(np.any(sample_values[::, 1] < a2_v) or np.any(
+          sample_values >= b_v))
+
+  def testUniformSampleMultiDimensional(self):
+    with self.test_session():
+      batch_size = 2
+      a_v = [3.0, 22.0]
+      b_v = [13.0, 35.0]
+      a = tf.constant([a_v] * batch_size)
+      b = tf.constant([b_v] * batch_size)
+
+      uniform = tf.contrib.distributions.Uniform(a=a, b=b)
+
+      n_v = 100000
+      n = tf.constant(n_v)
+      samples = uniform.sample(n, seed=138)
+      self.assertEqual(samples.get_shape(), (n_v, batch_size, 2))
+
+      sample_values = samples.eval()
+
+      self.assertFalse(np.any(sample_values[:, 0, 0] < a_v[0]) or np.any(
+          sample_values[:, 0, 0] >= b_v[0]))
+      self.assertFalse(np.any(sample_values[:, 0, 1] < a_v[1]) or np.any(
+          sample_values[:, 0, 1] >= b_v[1]))
+
+      self.assertAllClose(sample_values[:, 0, 0].mean(), (a_v[0] + b_v[0]) / 2,
+                          atol=1e-2)
+      self.assertAllClose(sample_values[:, 0, 1].mean(), (a_v[1] + b_v[1]) / 2,
+                          atol=1e-2)
+
+  def testUniformMeanAndVariance(self):
+    with self.test_session():
+      a = 10.0
+      b = 100.0
+      uniform = tf.contrib.distributions.Uniform(a=a, b=b)
+      self.assertAllClose(uniform.variance.eval(), (b - a)**2 / 12)
+      self.assertAllClose(uniform.mean.eval(), (b + a) / 2)
+
+  def testUniformNans(self):
+    with self.test_session():
+      a = 10.0
+      b = [11.0, 100.0]
+      uniform = tf.contrib.distributions.Uniform(a=a, b=b)
+
+      no_nans = tf.constant(1.0)
+      nans = tf.constant(0.0) / tf.constant(0.0)
+      self.assertTrue(tf.is_nan(nans).eval())
+      with_nans = tf.pack([no_nans, nans])
+
+      pdf = uniform.pdf(with_nans)
+
+      is_nan = tf.is_nan(pdf).eval()
+      print(pdf.eval())
+      self.assertFalse(is_nan[0])
+      self.assertTrue(is_nan[1])
+
+  def testUniformSamplePdf(self):
+    with self.test_session():
+      a = 10.0
+      b = [11.0, 100.0]
+      uniform = tf.contrib.distributions.Uniform(a, b)
+      self.assertTrue(tf.reduce_all(uniform.pdf(uniform.sample(10)) > 0).eval())
+
+  def testUniformBroadcasting(self):
+    with self.test_session():
+      a = 10.0
+      b = [11.0, 20.0]
+      uniform = tf.contrib.distributions.Uniform(a, b)
+
+      pdf = uniform.pdf([[10.5, 11.5], [9.0, 19.0], [10.5, 21.0]])
+      expected_pdf = np.array([[1.0, 0.1], [0.0, 0.1], [1.0, 0.0]])
+      self.assertAllClose(expected_pdf, pdf.eval())
+
+
+if __name__ == "__main__":
+  tf.test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/distribution.py b/tensorflow/contrib/distributions/python/ops/distribution.py
new file mode 100644
index 00000000000..16056102d15
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/distribution.py
@@ -0,0 +1,256 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base classes for probability distributions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import six
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+
+
+@six.add_metaclass(abc.ABCMeta)
+class BaseDistribution(object):
+  """Abstract base class for probability distributions.
+
+  This class, along with `ContinuousDistribution` and `DiscreteDistribution`,
+  defines the API for probability distributions.
+
+  Users will never instantiate a `BaseDistribution`, but will instead
+  instantiate subclasses of either `ContinuousDistribution` or
+  `DiscreteDistribution`.
+
+  Developers of new distributions should prefer to subclass
+  `ContinuousDistribution` or `DiscreteDistribution`.
+
+  ### API
+
+  The key methods for probability distributions are defined here. The likelihood
+  functions (`pdf`, `log_pdf`) and (`pmf`, `log_pmf`) are defined in
+  `ContinuousDistribution` and `DiscreteDistribution`, respectively.
+
+  To keep ops generated by the distribution tied together by name, subclasses
+  should override `name` and use it to preprend names of ops in other methods
+  (see `cdf` for an example).
+
+  Subclasses that wish to support `cdf` and `log_cdf` can override `log_cdf`
+  and use the base class's implementation for `cdf`.
+
+  ### Broadcasting, batching, and shapes
+
+  All distributions support batches of independent distributions of that type.
+  The batch shape is determined by broadcasting together the parameters.
+
+  The shape of arguments to `__init__`, `cdf`, `log_cdf`, and the likelihood
+  functions defined in `ContinuousDistribution` and `DiscreteDistribution`
+  reflect this broadcasting, as does the return value of `sample`.
+
+  `sample_shape = (n,) + batch_shape + event_shape`, where `sample_shape` is the
+  shape of the `Tensor` returned from `sample`, `n` is the number of samples,
+  `batch_shape` defines how many independent distributions there are, and
+  `event_shape` defines the shape of samples from each of those independent
+  distributions. Samples are independent along the `batch_shape` dimensions,
+  but not necessarily so along the `event_shape` dimensions (dependending on
+  the particulars of the underlying distribution).
+
+  Using the `Uniform` distribution as an example:
+
+  ```python
+  minval = 3.0
+  maxval = [[4.0, 6.0],
+            [10.0, 12.0]]
+
+  # Broadcasting:
+  # This instance represents 4 Uniform distributions. Each has a lower bound at
+  # 3.0 as the `minval` parameter was broadcasted to match `maxval`'s shape.
+  u = Uniform(minval, maxval)
+
+  # `event_shape` is `TensorShape([])`.
+  event_shape = u.get_event_shape()
+  # `event_shape_t` is a `Tensor` which will evaluate to a scalar 1.
+  event_shape_t = u.event_shape
+
+  # Sampling returns a sample per distribution.  `samples` has shape
+  # (5, 2, 2), which is (n,) + batch_shape + event_shape, where n=5,
+  # batch_shape=(2, 2), and event_shape=().
+  samples = u.sample(5)
+
+  # The broadcasting holds across methods. Here we use `cdf` as an example. The
+  # same holds for `log_cdf` and the likelihood functions.
+
+  # `cum_prob` has shape (2, 2) as the `value` argument was broadcasted to the
+  # shape of the `Uniform` instance.
+  cum_prob_broadcast = u.cdf(4.0)
+
+  # `cum_prob`'s shape is (2, 2), one per distribution. No broadcasting
+  # occurred.
+  cum_prob_per_dist = u.cdf([[4.0, 5.0],
+                             [6.0, 7.0]])
+
+  # INVALID as the `value` argument is not broadcastable to the distribution's
+  # shape.
+  cum_prob_invalid = u.cdf([4.0, 5.0, 6.0])
+  ```
+  """
+
+  @abc.abstractproperty
+  def name(self):
+    """Name to prepend to all ops."""
+    pass
+
+  @abc.abstractproperty
+  def dtype(self):
+    """dtype of samples from this distribution."""
+    pass
+
+  @abc.abstractmethod
+  def event_shape(self, name=None):
+    """Shape of a sample from a single distribution as a 1-D int32 `Tensor`.
+
+    Args:
+      name: name to give to the op
+
+    Returns:
+      `Tensor` `event_shape`
+    """
+    pass
+
+  @abc.abstractmethod
+  def get_event_shape(self):
+    """`TensorShape` available at graph construction time.
+
+    Same meaning as `event_shape`. May be only partially defined.
+    """
+    pass
+
+  @abc.abstractmethod
+  def batch_shape(self, name=None):
+    """Batch dimensions of this instance as a 1-D int32 `Tensor`.
+
+    The product of the dimensions of the `batch_shape` is the number of
+    independent distributions of this kind the instance represents.
+
+    Args:
+      name: name to give to the op
+
+    Returns:
+      `Tensor` `batch_shape`
+    """
+    pass
+
+  @abc.abstractmethod
+  def get_batch_shape(self):
+    """`TensorShape` available at graph construction time.
+
+    Same meaning as `batch_shape`. May be only partially defined.
+    """
+    pass
+
+  def sample(self, n, seed=None, name=None):
+    """Generate `n` samples.
+
+    Args:
+      n: scalar. Number of samples to draw from each distribution.
+      seed: Python integer seed for RNG
+      name: name to give to the op.
+
+    Returns:
+      samples: a `Tensor` of shape `(n,) + self.batch_shape + self.event_shape`
+          with values of type `self.dtype`.
+    """
+    raise NotImplementedError("sample not implemented")
+
+  def cdf(self, value, name="cdf"):
+    """Cumulative distribution function."""
+    value = ops.convert_to_tensor(value)
+    with ops.op_scope([value], self.name):
+      with ops.name_scope(name):
+        return math_ops.exp(self.log_cdf(value))
+
+  def log_cdf(self, value, name="log_cdf"):
+    """Log CDF."""
+    raise NotImplementedError("log_cdf is not implemented")
+
+  def entropy(self, name=None):
+    """Entropy of the distribution in nats."""
+    raise NotImplementedError("entropy not implemented")
+
+  @property
+  def mean(self):
+    raise NotImplementedError("mean not implemented")
+
+
+class ContinuousDistribution(BaseDistribution):
+  """Base class for continuous probability distributions.
+
+  `ContinuousDistribution` defines the API for the likelihood functions `pdf`
+  and `log_pdf` of continuous probability distributions.
+
+  Subclasses must override both `pdf` and `log_pdf` but one can call this base
+  class's implementation.
+
+  See `BaseDistribution` for more information on the API for probability
+  distributions.
+  """
+
+  @abc.abstractmethod
+  def pdf(self, value, name="pdf"):
+    """Probability density function."""
+    value = ops.convert_to_tensor(value)
+    with ops.op_scope([value], self.name):
+      with ops.name_scope(name):
+        return math_ops.exp(self.log_pdf(value))
+
+  @abc.abstractmethod
+  def log_pdf(self, value, name="log_pdf"):
+    """Log of the probability density function."""
+    value = ops.convert_to_tensor(value)
+    with ops.op_scope([value], self.name):
+      with ops.name_scope(name):
+        return math_ops.log(self.pdf(value))
+
+
+class DiscreteDistribution(BaseDistribution):
+  """Base class for discrete probability distributions.
+
+  `DiscreteDistribution` defines the API for the likelihood functions `pmf` and
+  `log_pmf` of discrete probability distributions.
+
+  Subclasses must override both `pmf` and `log_pmf` but one can call this base
+  class's implementation.
+
+  See `BaseDistribution` for more information on the API for probability
+  distributions.
+  """
+
+  @abc.abstractmethod
+  def pmf(self, value, name="pmf"):
+    """Probability mass function."""
+    value = ops.convert_to_tensor(value)
+    with ops.op_scope([value], self.name):
+      with ops.name_scope(name):
+        return math_ops.exp(self.log_pmf(value))
+
+  @abc.abstractmethod
+  def log_pmf(self, value, name="log_pmf"):
+    """Log of the probability mass function."""
+    value = ops.convert_to_tensor(value)
+    with ops.op_scope([value], self.name):
+      with ops.name_scope(name):
+        return math_ops.log(self.pmf(value))
diff --git a/tensorflow/contrib/distributions/python/ops/uniform.py b/tensorflow/contrib/distributions/python/ops/uniform.py
new file mode 100644
index 00000000000..707976640ee
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/uniform.py
@@ -0,0 +1,240 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The Uniform distribution class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.distributions.python.ops.distribution import ContinuousDistribution  # pylint: disable=line-too-long
+from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util  # pylint: disable=line-too-long
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+
+
+class Uniform(ContinuousDistribution):
+  """Uniform distribution with `a` and `b` parameters.
+
+  The PDF of this distribution is constant between [`a`, `b`], and 0 elsewhere.
+  """
+
+  def __init__(self, a=0.0, b=1.0, name="Uniform"):
+    """Construct Uniform distributions with `a` and `b`.
+
+    The parameters `a` and `b` must be shaped in a way that supports
+    broadcasting (e.g. `b - a` is a valid operation).
+
+    Here are examples without broadcasting:
+
+    ```python
+    # Without broadcasting
+    u1 = Uniform(3.0, 4.0)  # a single uniform distribution [3, 4]
+    u2 = Uniform([1.0, 2.0], [3.0, 4.0])  # 2 distributions [1, 3], [2, 4]
+    u3 = Uniform([[1.0, 2.0],
+                  [3.0, 4.0]],
+                 [[1.5, 2.5],
+                  [3.5, 4.5]])  # 4 distributions
+    ```
+
+    And with broadcasting:
+
+    ```python
+    u1 = Uniform(3.0, [5.0, 6.0, 7.0])  # 3 distributions
+    ```
+
+    Args:
+      a: `float` or `double` tensor, the minimum endpoint.
+      b: `float` or `double` tensor, the maximum endpoint. Must be > `a`.
+      name: The name to prefix Ops created by this distribution class.
+
+    Raises:
+      InvalidArgumentError: if `a >= b`.
+    """
+    with ops.op_scope([a, b], name):
+      with ops.control_dependencies([check_ops.assert_less(a, b)]):
+        a = ops.convert_to_tensor(a, name="a")
+        b = ops.convert_to_tensor(b, name="b")
+        if a.dtype != b.dtype:
+          raise TypeError("Input x dtype does not match dtype: %s vs. %s" %
+                          (a.dtype, b.dtype))
+
+    self._a = a
+    self._b = b
+    self._name = name
+    self._batch_shape = self._ones().get_shape()
+    self._event_shape = tensor_shape.TensorShape([])
+
+    contrib_tensor_util.assert_same_float_dtype((a, b))
+
+  @property
+  def name(self):
+    return self._name
+
+  @property
+  def dtype(self):
+    return self.a.dtype
+
+  def batch_shape(self, name="batch_shape"):
+    with ops.name_scope(self.name):
+      return array_ops.shape(self._ones(), name=name)
+
+  def get_batch_shape(self):
+    return self._batch_shape
+
+  def event_shape(self, name="event_shape"):
+    with ops.name_scope(self.name):
+      return constant_op.constant(1, name=name)
+
+  def get_event_shape(self):
+    return self._event_shape
+
+  @property
+  def a(self):
+    return self._a
+
+  @property
+  def b(self):
+    return self._b
+
+  def pdf(self, x, name="pdf"):
+    """The PDF of observations in `x` under these Uniform distribution(s).
+
+    Args:
+      x: tensor of dtype `dtype`, must be broadcastable with `a` and `b`.
+      name: The name to give this op.
+
+    Returns:
+      pdf: tensor of dtype `dtype`, the pdf values of `x`. If `x` is `nan`, will
+          return `nan`.
+    """
+    with ops.op_scope([self.a, self.b, x], self.name):
+      with ops.name_scope(name):
+        x = ops.convert_to_tensor(x, name="x")
+        if x.dtype != self.dtype:
+          raise TypeError("Input x dtype does not match dtype: %s vs. %s" %
+                          (x.dtype, self.dtype))
+
+        broadcasted_x = x * self._ones()
+        return math_ops.select(
+            math_ops.is_nan(broadcasted_x), broadcasted_x, math_ops.select(
+                math_ops.logical_or(broadcasted_x < self.a,
+                                    broadcasted_x > self.b),
+                array_ops.zeros_like(broadcasted_x),
+                (1.0 / self.range) * array_ops.ones_like(broadcasted_x)))
+
+  def log_pdf(self, x, name="log_pdf"):
+    return super(Uniform, self).log_pdf(x, name)
+
+  def cdf(self, x, name="cdf"):
+    """CDF of observations in `x` under these Uniform distribution(s).
+
+    Args:
+      x: tensor of dtype `dtype`, must be broadcastable with `a` and `b`.
+      name: The name to give this op.
+
+    Returns:
+      cdf: tensor of dtype `dtype`, the CDFs of `x`. If `x` is `nan`, will
+          return `nan`.
+    """
+    with ops.op_scope([self.a, self.b, x], self.name):
+      with ops.name_scope(name):
+        x = ops.convert_to_tensor(x, name="x")
+        if x.dtype != self.dtype:
+          raise TypeError("Input x dtype does not match dtype: %s vs. %s" %
+                          (x.dtype, self.dtype))
+
+    broadcasted_x = x * self._ones()
+    return math_ops.select(broadcasted_x < self.a,
+                           array_ops.zeros_like(broadcasted_x),
+                           math_ops.select(broadcasted_x >= self.b,
+                                           array_ops.ones_like(broadcasted_x),
+                                           (broadcasted_x - self.a) /
+                                           self.range))
+
+  def log_cdf(self, x, name="log_cdf"):
+    with ops.op_scope([self.a, self.b, x], self.name):
+      with ops.name_scope(name):
+        x = ops.convert_to_tensor(x, name="x")
+        return math_ops.log(self.cdf(x))
+
+  def entropy(self, name="entropy"):
+    """The entropy of Uniform distribution(s).
+
+    Args:
+      name: The name to give this op.
+
+    Returns:
+      entropy: tensor of dtype `dtype`, the entropy.
+    """
+    with ops.op_scope([self.a, self.b], self.name):
+      with ops.name_scope(name):
+        return math_ops.log(self.range)
+
+  def sample(self, n, seed=None, name="sample"):
+    """Sample `n` observations from the Uniform Distributions.
+
+    Args:
+      n: `Scalar`, type int32, the number of observations to sample.
+      seed: Python integer, the random seed.
+      name: The name to give this op.
+
+    Returns:
+      samples: a `Tensor` of shape `(n,) + self.batch_shape + self.event_shape`
+          with values of type `self.dtype`.
+    """
+    with ops.op_scope([self.a, self.b, n], self.name):
+      with ops.name_scope(name):
+        n = ops.convert_to_tensor(n, name="n")
+        n_val = tensor_util.constant_value(n)
+
+        shape = array_ops.concat(0, [array_ops.pack([n]), self.batch_shape()])
+        samples = random_ops.random_uniform(shape=shape,
+                                            dtype=self.dtype,
+                                            seed=seed)
+
+        # Provide some hints to shape inference
+        inferred_shape = tensor_shape.vector(n_val).concatenate(
+            self.get_batch_shape())
+        samples.set_shape(inferred_shape)
+
+        return (array_ops.expand_dims(self.a, 0) + array_ops.expand_dims(
+            self.range, 0) * samples)
+
+  @property
+  def mean(self):
+    return (self.a + self.b) / 2
+
+  @property
+  def variance(self):
+    return math_ops.square(self.range) / 12
+
+  @property
+  def range(self):
+    """`b - a`."""
+    return self.b - self.a
+
+  # TODO(rsepassi): Find a more efficient way of doing the broadcasting in_ones
+  # and _zeros.
+  def _ones(self):
+    return array_ops.ones_like(self.a + self.b)
+
+  def _zeros(self):
+    return array_ops.zeros_like(self.a + self.b)