diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 3a8c9f2321c..c9cfc922079 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -27,6 +27,33 @@ cuda_py_tests(
     ],
 )
 
+cuda_py_tests(
+    name = "gamma_test",
+    srcs = ["python/kernel_tests/gamma_test.py"],
+    additional_deps = [
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
+cuda_py_tests(
+    name = "chi2_test",
+    srcs = ["python/kernel_tests/chi2_test.py"],
+    additional_deps = [
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
+cuda_py_tests(
+    name = "exponential_test",
+    srcs = ["python/kernel_tests/exponential_test.py"],
+    additional_deps = [
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
 cuda_py_tests(
     name = "gaussian_test",
     size = "small",
@@ -65,7 +92,6 @@ cuda_py_tests(
     srcs = ["python/kernel_tests/gaussian_conjugate_posteriors_test.py"],
     additional_deps = [
         ":distributions_py",
-        "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:platform_test",
     ],
 )
diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py
index 5b4bbac8270..74cedaa251e 100644
--- a/tensorflow/contrib/distributions/__init__.py
+++ b/tensorflow/contrib/distributions/__init__.py
@@ -27,6 +27,9 @@ initialized with parameters that define the distributions.
 
 ### Univariate (scalar) distributions
 
+@@Chi2
+@@Exponential
+@@Gamma
 @@Gaussian
 @@Uniform
 
@@ -50,8 +53,12 @@ from __future__ import division
 from __future__ import print_function
 
 # pylint: disable=unused-import,wildcard-import,line-too-long
+
+from tensorflow.contrib.distributions.python.ops.chi2 import *
 from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import *
 from tensorflow.contrib.distributions.python.ops.distribution import *
+from tensorflow.contrib.distributions.python.ops.exponential import *
+from tensorflow.contrib.distributions.python.ops.gamma import *
 from tensorflow.contrib.distributions.python.ops.gaussian import *
 from tensorflow.contrib.distributions.python.ops.gaussian_conjugate_posteriors import *
 from tensorflow.contrib.distributions.python.ops.mvn import *
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/chi2_test.py b/tensorflow/contrib/distributions/python/kernel_tests/chi2_test.py
new file mode 100644
index 00000000000..84763735637
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/chi2_test.py
@@ -0,0 +1,85 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for initializers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from scipy import stats
+import tensorflow as tf
+
+
+class Chi2Test(tf.test.TestCase):
+
+  def testChi2LogPDF(self):
+    with tf.Session():
+      batch_size = 6
+      df = tf.constant([2.0] * batch_size, dtype=np.float64)
+      df_v = 2.0
+      x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float64)
+      chi2 = tf.contrib.distributions.Chi2(df=df)
+      expected_log_pdf = stats.chi2.logpdf(x, df_v)
+
+      log_pdf = chi2.log_pdf(x)
+      self.assertEqual(log_pdf.get_shape(), (6,))
+      self.assertAllClose(log_pdf.eval(), expected_log_pdf)
+
+      pdf = chi2.pdf(x)
+      self.assertEqual(pdf.get_shape(), (6,))
+      self.assertAllClose(pdf.eval(), np.exp(expected_log_pdf))
+
+  def testChi2CDF(self):
+    with tf.Session():
+      batch_size = 6
+      df = tf.constant([2.0] * batch_size, dtype=np.float64)
+      df_v = 2.0
+      x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float64)
+
+      chi2 = tf.contrib.distributions.Chi2(df=df)
+      expected_cdf = stats.chi2.cdf(x, df_v)
+
+      cdf = chi2.cdf(x)
+      self.assertEqual(cdf.get_shape(), (6,))
+      self.assertAllClose(cdf.eval(), expected_cdf)
+
+  def testChi2Mean(self):
+    with tf.Session():
+      df_v = np.array([1., 3, 5], dtype=np.float64)
+      expected_mean = stats.chi2.mean(df_v)
+      chi2 = tf.contrib.distributions.Chi2(df=df_v)
+      self.assertEqual(chi2.mean.get_shape(), (3,))
+      self.assertAllClose(chi2.mean.eval(), expected_mean)
+
+  def testChi2Variance(self):
+    with tf.Session():
+      df_v = np.array([1., 3, 5], np.float64)
+      expected_variances = stats.chi2.var(df_v)
+      chi2 = tf.contrib.distributions.Chi2(df=df_v)
+      self.assertEqual(chi2.variance.get_shape(), (3,))
+      self.assertAllClose(chi2.variance.eval(), expected_variances)
+
+  def testChi2Entropy(self):
+    with tf.Session():
+      df_v = np.array([1., 3, 5], dtype=np.float64)
+      expected_entropy = stats.chi2.entropy(df_v)
+      chi2 = tf.contrib.distributions.Chi2(df=df_v)
+      self.assertEqual(chi2.entropy().get_shape(), (3,))
+      self.assertAllClose(chi2.entropy().eval(), expected_entropy)
+
+
+if __name__ == '__main__':
+  tf.test.main()
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/exponential_test.py b/tensorflow/contrib/distributions/python/kernel_tests/exponential_test.py
new file mode 100644
index 00000000000..3113034b985
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/exponential_test.py
@@ -0,0 +1,85 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for initializers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from scipy import stats
+import tensorflow as tf
+
+
+class ExponentialTest(tf.test.TestCase):
+
+  def testExponentialLogPDF(self):
+    with tf.Session():
+      batch_size = 6
+      lam = tf.constant([2.0] * batch_size)
+      lam_v = 2.0
+      x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+      exponential = tf.contrib.distributions.Exponential(lam=lam)
+      expected_log_pdf = stats.expon.logpdf(x, scale=1 / lam_v)
+
+      log_pdf = exponential.log_pdf(x)
+      self.assertEqual(log_pdf.get_shape(), (6,))
+      self.assertAllClose(log_pdf.eval(), expected_log_pdf)
+
+      pdf = exponential.pdf(x)
+      self.assertEqual(pdf.get_shape(), (6,))
+      self.assertAllClose(pdf.eval(), np.exp(expected_log_pdf))
+
+  def testExponentialCDF(self):
+    with tf.Session():
+      batch_size = 6
+      lam = tf.constant([2.0] * batch_size)
+      lam_v = 2.0
+      x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+
+      exponential = tf.contrib.distributions.Exponential(lam=lam)
+      expected_cdf = stats.expon.cdf(x, scale=1 / lam_v)
+
+      cdf = exponential.cdf(x)
+      self.assertEqual(cdf.get_shape(), (6,))
+      self.assertAllClose(cdf.eval(), expected_cdf)
+
+  def testExponentialMean(self):
+    with tf.Session():
+      lam_v = np.array([1.0, 4.0, 2.5])
+      expected_mean = stats.expon.mean(scale=1 / lam_v)
+      exponential = tf.contrib.distributions.Exponential(lam=lam_v)
+      self.assertEqual(exponential.mean.get_shape(), (3,))
+      self.assertAllClose(exponential.mean.eval(), expected_mean)
+
+  def testExponentialVariance(self):
+    with tf.Session():
+      lam_v = np.array([1.0, 4.0, 2.5])
+      expected_variance = stats.expon.var(scale=1 / lam_v)
+      exponential = tf.contrib.distributions.Exponential(lam=lam_v)
+      self.assertEqual(exponential.variance.get_shape(), (3,))
+      self.assertAllClose(exponential.variance.eval(), expected_variance)
+
+  def testExponentialEntropy(self):
+    with tf.Session():
+      lam_v = np.array([1.0, 4.0, 2.5])
+      expected_entropy = stats.expon.entropy(scale=1 / lam_v)
+      exponential = tf.contrib.distributions.Exponential(lam=lam_v)
+      self.assertEqual(exponential.entropy().get_shape(), (3,))
+      self.assertAllClose(exponential.entropy().eval(), expected_entropy)
+
+
+if __name__ == '__main__':
+  tf.test.main()
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py b/tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py
new file mode 100644
index 00000000000..22f44aeaf46
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py
@@ -0,0 +1,142 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for initializers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from scipy import stats
+import tensorflow as tf
+
+
+class GammaTest(tf.test.TestCase):
+
+  def testGammaShape(self):
+    with tf.Session():
+      alpha = tf.constant([3.0] * 5)
+      beta = tf.constant(11.0)
+      gamma = tf.contrib.distributions.Gamma(alpha=alpha, beta=beta)
+
+      self.assertEqual(gamma.batch_shape().eval(), (5,))
+      self.assertEqual(gamma.get_batch_shape(), tf.TensorShape([5]))
+      self.assertEqual(gamma.event_shape().eval(), 1)
+      self.assertEqual(gamma.get_event_shape(), tf.TensorShape([]))
+
+  def testGammaLogPDF(self):
+    with tf.Session():
+      batch_size = 6
+      alpha = tf.constant([2.0] * batch_size)
+      beta = tf.constant([3.0] * batch_size)
+      alpha_v = 2.0
+      beta_v = 3.0
+      x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+      gamma = tf.contrib.distributions.Gamma(alpha=alpha, beta=beta)
+      expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
+      log_pdf = gamma.log_pdf(x)
+      self.assertEqual(log_pdf.get_shape(), (6,))
+      self.assertAllClose(log_pdf.eval(), expected_log_pdf)
+
+      pdf = gamma.pdf(x)
+      self.assertEqual(pdf.get_shape(), (6,))
+      self.assertAllClose(pdf.eval(), np.exp(expected_log_pdf))
+
+  def testGammaLogPDFMultidimensional(self):
+    with tf.Session():
+      batch_size = 6
+      alpha = tf.constant([[2.0, 4.0]] * batch_size)
+      beta = tf.constant([[3.0, 4.0]] * batch_size)
+      alpha_v = np.array([2.0, 4.0])
+      beta_v = np.array([3.0, 4.0])
+      x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
+      gamma = tf.contrib.distributions.Gamma(alpha=alpha, beta=beta)
+      expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
+      log_pdf = gamma.log_pdf(x)
+      log_pdf_values = log_pdf.eval()
+      self.assertEqual(log_pdf.get_shape(), (6, 2))
+      self.assertAllClose(log_pdf_values, expected_log_pdf)
+
+      pdf = gamma.pdf(x)
+      pdf_values = pdf.eval()
+      self.assertEqual(pdf.get_shape(), (6, 2))
+      self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
+
+  def testGammaLogPDFMultidimensionalBroadcasting(self):
+    with tf.Session():
+      batch_size = 6
+      alpha = tf.constant([[2.0, 4.0]] * batch_size)
+      beta = tf.constant(3.0)
+      alpha_v = np.array([2.0, 4.0])
+      beta_v = 3.0
+      x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
+      gamma = tf.contrib.distributions.Gamma(alpha=alpha, beta=beta)
+      expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
+      log_pdf = gamma.log_pdf(x)
+      log_pdf_values = log_pdf.eval()
+      self.assertEqual(log_pdf.get_shape(), (6, 2))
+      self.assertAllClose(log_pdf_values, expected_log_pdf)
+
+      pdf = gamma.pdf(x)
+      pdf_values = pdf.eval()
+      self.assertEqual(pdf.get_shape(), (6, 2))
+      self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
+
+  def testGammaCDF(self):
+    with tf.Session():
+      batch_size = 6
+      alpha = tf.constant([2.0] * batch_size)
+      beta = tf.constant([3.0] * batch_size)
+      alpha_v = 2.0
+      beta_v = 3.0
+      x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+
+      gamma = tf.contrib.distributions.Gamma(alpha=alpha, beta=beta)
+      expected_cdf = stats.gamma.cdf(x, alpha_v, scale=1 / beta_v)
+
+      cdf = gamma.cdf(x)
+      self.assertEqual(cdf.get_shape(), (6,))
+      self.assertAllClose(cdf.eval(), expected_cdf)
+
+  def testGammaMean(self):
+    with tf.Session():
+      alpha_v = np.array([1.0, 3.0, 2.5])
+      beta_v = np.array([1.0, 4.0, 5.0])
+      gamma = tf.contrib.distributions.Gamma(alpha=alpha_v, beta=beta_v)
+      expected_means = stats.gamma.mean(alpha_v, scale=1 / beta_v)
+      self.assertEqual(gamma.mean.get_shape(), (3,))
+      self.assertAllClose(gamma.mean.eval(), expected_means)
+
+  def testGammaVariance(self):
+    with tf.Session():
+      alpha_v = np.array([1.0, 3.0, 2.5])
+      beta_v = np.array([1.0, 4.0, 5.0])
+      gamma = tf.contrib.distributions.Gamma(alpha=alpha_v, beta=beta_v)
+      expected_variances = stats.gamma.var(alpha_v, scale=1 / beta_v)
+      self.assertEqual(gamma.variance.get_shape(), (3,))
+      self.assertAllClose(gamma.variance.eval(), expected_variances)
+
+  def testGammaEntropy(self):
+    with tf.Session():
+      alpha_v = np.array([1.0, 3.0, 2.5])
+      beta_v = np.array([1.0, 4.0, 5.0])
+      expected_entropy = stats.gamma.entropy(alpha_v, scale=1 / beta_v)
+      gamma = tf.contrib.distributions.Gamma(alpha=alpha_v, beta=beta_v)
+      self.assertEqual(gamma.entropy().get_shape(), (3,))
+      self.assertAllClose(gamma.entropy().eval(), expected_entropy)
+
+
+if __name__ == '__main__':
+  tf.test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/chi2.py b/tensorflow/contrib/distributions/python/ops/chi2.py
new file mode 100644
index 00000000000..cdcb5620f20
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/chi2.py
@@ -0,0 +1,46 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The Chi2 distribution class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.distributions.python.ops import gamma
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+
+
+class Chi2(gamma.Gamma):
+  """The Chi2 distribution with degrees of freedom df.
+
+  The PDF of this distribution is:
+
+  ```pdf(x) = (x^(df/2 - 1)e^(-x/2))/(2^(k/2)Gamma(k/2)), x > 0```
+
+  Note that the Chi2 distribution is a special case of the Gamma distribution,
+  with Chi2(df) = Gamma(df/2, 1/2).
+  """
+
+  def __init__(self, df, name="Chi2"):
+    with ops.op_scope([df], name, "init"):
+      df = ops.convert_to_tensor(df)
+      self._df = df
+      super(Chi2, self).__init__(alpha=df / 2,
+                                 beta=math_ops.cast(0.5, dtype=df.dtype))
+
+  @property
+  def df(self):
+    return self._df
diff --git a/tensorflow/contrib/distributions/python/ops/exponential.py b/tensorflow/contrib/distributions/python/ops/exponential.py
new file mode 100644
index 00000000000..4652e6b3ec7
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/exponential.py
@@ -0,0 +1,47 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The Exponential distribution class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.distributions.python.ops import gamma
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+
+
+class Exponential(gamma.Gamma):
+  """The Exponential distribution with rate parameter lam.
+
+  The PDF of this distribution is:
+
+  ```pdf(x) = (lam * e^(-lam * x)), x > 0```
+
+  Note that the Exponential distribution is a special case of the Gamma
+  distribution, with Exponential(lam) = Gamma(1, lam).
+  """
+
+  def __init__(self, lam, name="Exponential"):
+    with ops.op_scope([lam], name, "init"):
+      lam = ops.convert_to_tensor(lam)
+      self._lam = lam
+      super(Exponential, self).__init__(
+          alpha=math_ops.cast(1.0, dtype=lam.dtype),
+          beta=lam)
+
+  @property
+  def lam(self):
+    return self._lam
diff --git a/tensorflow/contrib/distributions/python/ops/gamma.py b/tensorflow/contrib/distributions/python/ops/gamma.py
new file mode 100644
index 00000000000..2c445a3f12d
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/gamma.py
@@ -0,0 +1,208 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The Gamma distribution class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.distributions.python.ops.distribution import ContinuousDistribution  # pylint: disable=line-too-long
+from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util  # pylint: disable=line-too-long
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+
+
+class Gamma(ContinuousDistribution):
+  """The `Gamma` distribution with parameter alpha and beta.
+
+  The parameters are the shape and inverse scale parameters alpha, beta.
+
+  The PDF of this distribution is:
+
+  ```pdf(x) = (beta^alpha)(x^(alpha-1))e^(-x*beta)/Gamma(alpha), x > 0```
+
+  and the CDF of this distribution is:
+
+  ```cdf(x) =  GammaInc(alpha, beta * x) / Gamma(alpha), x > 0```
+
+  where GammaInc is the incomplete lower Gamma function.
+
+  Examples:
+
+  ```python
+  dist = Gamma(alpha=3.0, beta=2.0)
+  dist2 = Gamma(alpha=[3.0, 4.0], beta=[2.0, 3.0])
+  ```
+
+  """
+
+  def __init__(self, alpha, beta, name="Gamma"):
+    """Construct Gamma distributions with parameters `alpha` and `beta`.
+
+    The parameters `alpha` and `beta` must be shaped in a way that supports
+    broadcasting (e.g. `alpha + beta` is a valid operation).
+
+    Args:
+      alpha: `float` or `double` tensor, the shape params of the
+        distribution(s).
+        alpha must contain only positive values.
+      beta: `float` or `double` tensor, the inverse scale params of the
+        distribution(s).
+        beta must contain only positive values.
+      name: The name to give Ops created by the initializer.
+
+    Raises:
+      TypeError: if `alpha` and `beta` are different dtypes.
+    """
+    with ops.op_scope([alpha, beta], name):
+      alpha = ops.convert_to_tensor(alpha, name="alpha_before_dependencies")
+      beta = ops.convert_to_tensor(beta, name="beta_before_dependencies")
+      contrib_tensor_util.assert_same_float_dtype((alpha, beta))
+      with ops.control_dependencies([
+          check_ops.assert_positive(alpha), check_ops.assert_positive(beta)
+      ]):
+        self._alpha = alpha
+        self._beta = beta
+        self._name = name
+
+    with ops.op_scope([self._alpha, self._beta], name, "mean"):
+      self._mean = self._alpha / self._beta
+      self._batch_shape = self._mean.get_shape()
+
+    with ops.op_scope([self._alpha, self._beta], name, "variance"):
+      self._variance = self._alpha / math_ops.square(self._beta)
+
+    self._event_shape = tensor_shape.TensorShape([])
+
+  @property
+  def name(self):
+    return self._name
+
+  @property
+  def dtype(self):
+    return self._alpha.dtype
+
+  @property
+  def alpha(self):
+    return self._alpha
+
+  @property
+  def beta(self):
+    return self._beta
+
+  def batch_shape(self, name="batch_shape"):
+    with ops.name_scope(self.name):
+      return array_ops.shape(self._mean, name=name)
+
+  def get_batch_shape(self):
+    return self._batch_shape
+
+  def event_shape(self, name="event_shape"):
+    with ops.name_scope(self.name):
+      return constant_op.constant(1, name=name)
+
+  def get_event_shape(self):
+    return self._event_shape
+
+  @property
+  def mean(self):
+    return self._mean
+
+  @property
+  def variance(self):
+    return self._variance
+
+  def log_pdf(self, x, name="log_pdf"):
+    """Log pdf of observations in `x` under these Gamma distribution(s).
+
+    Args:
+      x: tensor of dtype `dtype`, must be broadcastable with `alpha` and `beta`.
+      name: The name to give this op.
+
+    Returns:
+      log_pdf: tensor of dtype `dtype`, the log-PDFs of `x`.
+    Raises:
+      TypeError: if `x` and `alpha` are different dtypes.
+    """
+    with ops.op_scope([self._alpha, self._beta, x], self.name):
+      with ops.name_scope(name):
+        alpha = self._alpha
+        beta = self._beta
+        x = ops.convert_to_tensor(x)
+        x = control_flow_ops.with_dependencies(
+            [check_ops.assert_positive(x)], x)
+        contrib_tensor_util.assert_same_float_dtype(tensors=[x,],
+                                                    dtype=self.dtype)
+
+        return (alpha * math_ops.log(beta) + (alpha - 1) * math_ops.log(x) -
+                beta * x - math_ops.lgamma(self._alpha))
+
+  def pdf(self, x, name="pdf"):
+    with ops.name_scope(name):
+      return math_ops.exp(self.log_pdf(x, name))
+
+  def log_cdf(self, x, name="log_cdf"):
+    """Log CDF of observations `x` under these Gamma distribution(s).
+
+    Args:
+      x: tensor of dtype `dtype`, must be broadcastable with `alpha` and `beta`.
+      name: The name to give this op.
+
+    Returns:
+      log_cdf: tensor of dtype `dtype`, the log-CDFs of `x`.
+    """
+    with ops.op_scope([self._alpha, self._beta, x], self.name):
+      with ops.name_scope(name):
+        x = ops.convert_to_tensor(x)
+        x = control_flow_ops.with_dependencies(
+            [check_ops.assert_positive(x)], x)
+        contrib_tensor_util.assert_same_float_dtype(tensors=[x,],
+                                                    dtype=self.dtype)
+        # Note that igamma returns the regularized incomplete gamma function,
+        # which is what we want for the CDF.
+        return math_ops.log(math_ops.igamma(self._alpha, self._beta * x))
+
+  def cdf(self, x, name="cdf"):
+    with ops.op_scope([self._alpha, self._beta, x], self.name):
+      with ops.name_scope(name):
+        return math_ops.igamma(self._alpha, self._beta * x)
+
+  def entropy(self, name="entropy"):
+    """The entropy of Gamma distribution(s).
+
+    This is defined to be
+
+    ```entropy = alpha - log(beta) + log(Gamma(alpha))
+                 + (1-alpha)digamma(alpha)```
+
+    where digamma(alpha) is the digamma function.
+
+    Args:
+      name: The name to give this op.
+
+    Returns:
+      entropy: tensor of dtype `dtype`, the entropy.
+    """
+    with ops.op_scope([self.alpha, self._beta], self.name):
+      with ops.name_scope(name):
+        alpha = self._alpha
+        beta = self._beta
+        return (alpha - math_ops.log(beta) + math_ops.lgamma(alpha) +
+                (1 - alpha) * math_ops.digamma(alpha))