diff --git a/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py b/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py
index 0adaf7d816d..e1644d548da 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py
@@ -170,6 +170,56 @@ class PoissonTest(test.TestCase):
       self.assertEqual((6,), poisson.mode().get_shape())
       self.assertAllClose(lam_v, poisson.mode().eval())
 
+  def testPoissonSample(self):
+    with self.test_session():
+      lam_v = 4.0
+      lam = constant_op.constant(lam_v)
+      # Choosing `n >= (k/rtol)**2` roughly ensures our sample mean will be
+      # within `k` std. deviations of the true mean, up to `rtol` precision.
+      n = int(100e3)
+      poisson = poisson_lib.Poisson(rate=lam)
+      samples = poisson.sample(n, seed=123456)
+      sample_values = samples.eval()
+      self.assertEqual(samples.get_shape(), (n,))
+      self.assertEqual(sample_values.shape, (n,))
+      self.assertAllClose(
+          sample_values.mean(), stats.poisson.mean(lam_v), rtol=.01)
+      self.assertAllClose(
+          sample_values.var(), stats.poisson.var(lam_v), rtol=.01)
+
+  def testPoissonSampleMultidimensionalMean(self):
+    with self.test_session():
+      lam_v = np.array([np.arange(1, 51, dtype=np.float32)])  # 1 x 50
+      poisson = poisson_lib.Poisson(rate=lam_v)
+      # Choosing `n >= (k/rtol)**2` roughly ensures our sample mean will be
+      # within `k` std. deviations of the true mean, up to `rtol` precision.
+      n = int(100e3)
+      samples = poisson.sample(n, seed=123456)
+      sample_values = samples.eval()
+      self.assertEqual(samples.get_shape(), (n, 1, 50))
+      self.assertEqual(sample_values.shape, (n, 1, 50))
+      self.assertAllClose(
+          sample_values.mean(axis=0),
+          stats.poisson.mean(lam_v),
+          rtol=.01,
+          atol=0)
+
+  def testPoissonSampleMultidimensionalVariance(self):
+    with self.test_session():
+      lam_v = np.array([np.arange(5, 15, dtype=np.float32)])  # 1 x 10
+      poisson = poisson_lib.Poisson(rate=lam_v)
+      # Choosing `n >= 2 * lam * (k/rtol)**2` roughly ensures our sample
+      # variance will be within `k` std. deviations of the true variance, up
+      # to `rtol` precision.
+      n = int(300e3)
+      samples = poisson.sample(n, seed=123456)
+      sample_values = samples.eval()
+      self.assertEqual(samples.get_shape(), (n, 1, 10))
+      self.assertEqual(sample_values.shape, (n, 1, 10))
+
+      self.assertAllClose(
+          sample_values.var(axis=0), stats.poisson.var(lam_v), rtol=.03, atol=0)
+
 
 if __name__ == "__main__":
   test.main()
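Aside: the `n >= (k/rtol)**2` heuristic in the test comments above is just the CLT. A minimal sketch of the arithmetic, assuming the reading that the tolerated absolute error `rtol * mean` should correspond to `k` standard errors of the sample mean (`lam`, `rtol`, and `n` below mirror `testPoissonSample`):

```python
import numpy as np

# For X ~ Poisson(lam), mean = var = lam, so the sample mean of n draws
# has standard error sqrt(lam / n). The test tolerates an absolute error
# of rtol * lam, which corresponds to k standard errors:
lam, rtol, n = 4.0, 0.01, int(100e3)
k = rtol * lam / np.sqrt(lam / n)
print(k)  # ~6.3 standard errors, so a false failure is vanishingly unlikely
```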
diff --git a/tensorflow/contrib/distributions/python/ops/poisson.py b/tensorflow/contrib/distributions/python/ops/poisson.py
index 799796ace0c..e1ddc9a0e18 100644
--- a/tensorflow/contrib/distributions/python/ops/poisson.py
+++ b/tensorflow/contrib/distributions/python/ops/poisson.py
@@ -28,7 +28,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
-
+from tensorflow.python.ops import random_ops
 
 __all__ = [
     "Poisson",
@@ -148,6 +148,10 @@ class Poisson(distribution.Distribution):
   def _mode(self):
     return math_ops.floor(self.rate)
 
+  def _sample_n(self, n, seed=None):
+    return random_ops.random_poisson(
+        self.rate, [n], dtype=self.dtype, seed=seed)
+
   def _assert_valid_sample(self, x, check_integer=True):
     if not self.validate_args:
       return x
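With `_sample_n` delegating to `random_ops.random_poisson`, sampling from the contrib distribution becomes a one-liner. A hedged usage sketch, assuming the `tf.contrib.distributions.Poisson` export path at this revision:

```python
import tensorflow as tf

# Illustrative only; assumes the contrib export path at this revision.
poisson = tf.contrib.distributions.Poisson(rate=[1.0, 10.0])
samples = poisson.sample(5, seed=42)  # shape (5, 2)
with tf.Session() as sess:
    print(sess.run(samples))
```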
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 2c75353ffed..4dd9bffe80a 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -635,6 +635,7 @@ cc_library(
         "//tensorflow/core/kernels:parameterized_truncated_normal_op",
         "//tensorflow/core/kernels:parsing",
         "//tensorflow/core/kernels:random_ops",
+        "//tensorflow/core/kernels:random_poisson_op",
         "//tensorflow/core/kernels:remote_fused_graph_ops",
         "//tensorflow/core/kernels:required",
         "//tensorflow/core/kernels:resource_variable_ops",
diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc
index ef4dd047875..f0ab5520f11 100644
--- a/tensorflow/core/graph/testlib.cc
+++ b/tensorflow/core/graph/testlib.cc
@@ -219,6 +219,16 @@ Node* RandomGamma(Graph* g, Node* shape, Node* alpha) {
   return ret;
 }
 
+Node* RandomPoisson(Graph* g, Node* shape, Node* lam) {
+  Node* ret;
+  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "RandomPoisson")
+                  .Input(shape)
+                  .Input(lam)
+                  .Attr("seed", 0)
+                  .Finalize(g, &ret));
+  return ret;
+}
+
 Node* Unary(Graph* g, const string& func, Node* input, int index) {
   Node* ret;
   TF_CHECK_OK(NodeBuilder(g->NewName("n"), func, g->op_registry())
diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h
index 7a23b20c2c8..d508f65ada5 100644
--- a/tensorflow/core/graph/testlib.h
+++ b/tensorflow/core/graph/testlib.h
@@ -113,6 +113,10 @@ Node* RandomGaussian(Graph* g, Node* input, DataType dtype);
 // Output dtype determined by alpha.
 Node* RandomGamma(Graph* g, Node* shape, Node* alpha);
 
+// Generates random Poisson samples with the given shape and lam value(s).
+// Output dtype determined by lam.
+Node* RandomPoisson(Graph* g, Node* shape, Node* lam);
+
 // Generates random parameters from the truncated standard normal distribution
 // of the input shape
 Node* TruncatedNormal(Graph* g, Node* input, DataType dtype);
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index ab5923b0e7b..06a11d31ab0 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3361,6 +3361,33 @@ tf_cuda_cc_test(
     ],
 )
 
+tf_kernel_library(
+    name = "random_poisson_op",
+    prefix = "random_poisson_op",
+    deps = [
+        ":random_ops",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:random_ops_op_lib",
+    ],
+)
+
+tf_cuda_cc_test(
+    name = "random_poisson_op_test",
+    size = "small",
+    srcs = ["random_poisson_op_test.cc"],
+    deps = [
+        ":ops_util",
+        ":random_poisson_op",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
 tf_kernel_library(
     name = "word2vec_kernels",
     prefix = "word2vec_kernels",
diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc
index 0a1de111627..f3c7e0f26b1 100644
--- a/tensorflow/core/kernels/random_op.cc
+++ b/tensorflow/core/kernels/random_op.cc
@@ -541,7 +541,7 @@ TF_CALL_int64(REGISTER_INT);
       PhiloxRandomOp<                                               \
           GPUDevice,                                                \
           random::TruncatedNormalDistribution<                      \
-              random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >)
+              random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >);
 
 #define REGISTER_INT(IntType)                                   \
   REGISTER_KERNEL_BUILDER(Name("RandomUniformInt")              \
diff --git a/tensorflow/core/kernels/random_poisson_op.cc b/tensorflow/core/kernels/random_poisson_op.cc
new file mode 100644
index 00000000000..553a4a7f939
--- /dev/null
+++ b/tensorflow/core/kernels/random_poisson_op.cc
@@ -0,0 +1,357 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/random_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/random_poisson_op.h"
+
+#include <algorithm>
+#include <cmath>
+#include <memory>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/lib/random/random_distributions.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/util/guarded_philox_random.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+#if EIGEN_COMP_GNUC && __cplusplus > 199711L
+#define DISABLE_FLOAT_EQUALITY_WARNING \
+  _Pragma("GCC diagnostic push")       \
+      _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
+#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
+#else
+#define DISABLE_FLOAT_EQUALITY_WARNING
+#define ENABLE_FLOAT_EQUALITY_WARNING
+#endif
+
+#define UNIFORM(X)                                    \
+  if (uniform_remaining == 0) {                       \
+    uniform_remaining = Uniform::kResultElementCount; \
+    uniform_result = uniform(&gen);                   \
+  }                                                   \
+  uniform_remaining--;                                \
+  CT X = uniform_result[uniform_remaining]
+
+namespace tensorflow {
+namespace {
+
+static constexpr int kReservedSamplesPerOutput = 256;
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+// We will compute half-precision Poisson samples with float precision
+// intermediate calculations.
+template <typename T>
+struct PoissonComputeType {
+  typedef T ComputeType;
+};
+
+template <>
+struct PoissonComputeType<Eigen::half> {
+  typedef float ComputeType;
+};
+
+}  // namespace
+
+namespace functor {
+
+template <typename Device, typename T>
+struct PoissonFunctor {
+  void operator()(OpKernelContext* ctx, const Device& d, const T* rate_flat,
+                  int num_rate, int num_samples,
+                  const random::PhiloxRandom& rng, T* samples_flat);
+};
+
+template <typename T>
+struct PoissonFunctor<CPUDevice, T> {
+  void operator()(OpKernelContext* ctx, const CPUDevice& d, const T* rate_flat,
+                  int num_rate, int num_samples,
+                  const random::PhiloxRandom& rng, T* samples_flat) {
+    // Two different algorithms are employed, depending on the size of rate.
+    //
+    // If rate < 10, we use an algorithm attributed to Knuth:
+    // Seminumerical Algorithms. Art of Computer Programming, Volume 2.
+    // This algorithm runs in O(rate) time and requires O(rate) uniform
+    // variates.
+    //
+    // If rate >= 10, we use a transformation-rejection algorithm that draws
+    // pairs of uniform random variables, due to Hormann.
+    // http://www.sciencedirect.com/science/article/pii/0167668793909974
+    //
+    // That algorithm has an acceptance rate of ~75% for the smallest rate
+    // (~10), rising toward ~89% for larger rates, so runtime is
+    // O(NumRate * NumSamples * k) with k <~ 1 / 0.75.
+    //
+    // We partition the work first across rates and then across
+    // samples-per-rate so that per-rate computations can be hoisted out of
+    // the per-sample loop.
+
+    typedef random::UniformDistribution<random::PhiloxRandom, CT> Uniform;
+
+    auto DoWork = [num_samples, num_rate, &rng, samples_flat, rate_flat](
+        int start_output, int limit_output) {
+      // Capturing "rng" by value would only make a copy for the _shared_
+      // lambda.  Since we want to let each worker have its own copy, we pass
+      // "rng" by reference and explicitly do a copy assignment.
+
+      Uniform uniform;
+      typename Uniform::ResultType uniform_result;
+      for (int64 output_idx = start_output; output_idx < limit_output;
+           /* output_idx incremented within inner loop below */) {
+        const int64 rate_idx = output_idx / num_samples;
+
+        // Several calculations can be done on a per-rate basis.
+        const CT rate = CT(rate_flat[rate_idx]);
+
+        auto samples_rate_output = samples_flat + rate_idx;
+
+        if (rate < CT(10)) {
+          // Knuth's algorithm for generating Poisson random variates.
+          // In a Poisson process with rate lambda, the time between events
+          // is distributed Exp(lambda). If X ~ Uniform(0, 1), then
+          // Y = -log(X) / lambda is distributed Exp(lambda). Thus we can
+          // simulate a draw N ~ Poisson(lambda) by drawing X_i ~ Exp(lambda)
+          // until their running sum exceeds 1; N is one less than the number
+          // of draws needed.
+          const CT exp_neg_rate = Eigen::numext::exp(-rate);
+
+          // Compute the rest of the samples for the current rate value.
+          for (int64 sample_idx = output_idx % num_samples;
+               sample_idx < num_samples && output_idx < limit_output;
+               sample_idx++, output_idx++) {
+            random::PhiloxRandom gen = rng;
+            gen.Skip(kReservedSamplesPerOutput * output_idx);
+            int16 uniform_remaining = 0;
+
+            CT prod = 1;
+            CT x = 0;
+
+            // Multiply uniforms until the running product falls below
+            // e^(-rate). This takes expected time proportional to rate.
+            while (true) {
+              UNIFORM(u);
+              prod = prod * u;
+              if (prod <= exp_neg_rate) {
+                samples_rate_output[sample_idx * num_rate] = T(x);
+                break;
+              }
+              x += 1;
+            }
+          }
+          continue;
+        }
+        // Transformed rejection due to Hormann.
+        //
+        // Given a CDF F(x) and a transformation G(u) chosen to be close to
+        // the inverse CDF F^-1(u), compute the following steps:
+        //
+        // 1) Generate U and V, two independent random variates. Set U = U - 0.5
+        // (this step isn't strictly necessary, but is done to make some
+        // calculations symmetric and convenient. Henceforth, G is defined on
+        // [-0.5, 0.5]).
+        //
+        // 2) If V <= alpha * F'(G(U)) * G'(U), return floor(G(U)), else return
+        // to step 1. alpha is the acceptance probability of the rejection
+        // algorithm.
+        //
+        // For more details on transformed rejection, see:
+        // http://citeseer.ist.psu.edu/viewdoc/citations;jsessionid=1BEB35946CC807879F55D42512E5490C?doi=10.1.1.48.3054.
+        //
+        // The dominating distribution in this case:
+        //
+        // G(u) = (2 * a / (2 - |u|) + b) * u + c
+
+        using Eigen::numext::log;
+        const CT log_rate = log(rate);
+
+        // Constants used to define the dominating distribution. Names taken
+        // from Hormann's paper. Constants were chosen to define the tightest
+        // G(u) for the inverse Poisson CDF.
+        const CT b = CT(0.931) + CT(2.53) * Eigen::numext::sqrt(rate);
+        const CT a = CT(-0.059) + CT(0.02483) * b;
+
+        // This is the inverse acceptance rate. At a minimum (when rate = 10),
+        // this corresponds to ~75% acceptance. As the rate becomes larger, this
+        // approaches ~89%.
+        const CT inv_alpha = CT(1.1239) + CT(1.1328) / (b - CT(3.4));
+
+        // Compute the rest of the samples for the current rate value.
+        for (int64 sample_idx = output_idx % num_samples;
+             sample_idx < num_samples && output_idx < limit_output;
+             sample_idx++, output_idx++) {
+          random::PhiloxRandom gen = rng;
+          gen.Skip(kReservedSamplesPerOutput * output_idx);
+          int16 uniform_remaining = 0;
+
+          while (true) {
+            UNIFORM(u);
+            u -= CT(0.5);
+            UNIFORM(v);
+
+            CT u_shifted = CT(0.5) - Eigen::numext::abs(u);
+            CT k = Eigen::numext::floor((CT(2) * a / u_shifted + b) * u + rate +
+                                        CT(0.43));
+
+            // When alpha * f(G(U)) * G'(U) is close to 1, it is possible to
+            // find a rectangle (-u_r, u_r) x (0, v_r) under the curve, such
+            // that if v <= v_r and |u| <= u_r, then we can accept.
+            // Here v_r = 0.9277 - 3.6224 / (b - 2) and u_r = 0.43.
+            if (u_shifted >= CT(0.07) &&
+                v <= CT(0.9277) - CT(3.6224) / (b - CT(2))) {
+              samples_rate_output[sample_idx * num_rate] = T(k);
+              break;
+            }
+
+            if (k < 0 || (u_shifted < CT(0.013) && v > u_shifted)) {
+              continue;
+            }
+
+            // The expression below is equivalent to the computation of step 2)
+            // in transformed rejection (v <= alpha * F'(G(u)) * G'(u)).
+            CT s = log(v * inv_alpha / (a / (u_shifted * u_shifted) + b));
+            CT t = -rate + k * log_rate - Eigen::numext::lgamma(k + 1);
+            if (s <= t) {
+              samples_rate_output[sample_idx * num_rate] = T(k);
+              break;
+            }
+          }
+        }
+      }
+    };
+
+    // The per-output cost depends on rate.
+    //
+    // For rate < 10, on average O(rate) uniform calls are needed, with as
+    // many multiplies: ~10 uniform calls on average, plus ~25 cycles of
+    // other ops.
+    //
+    // Very roughly, for rate >= 10, the single call each to log and lgamma
+    // occurs for ~60 percent of samples:
+    // 2 x 100 (64-bit cycles per log) * 0.62 = ~124.
+    // Additionally, there are ~10 other ops (+, *, /, ...) at 3-6 cycles
+    // each: 40 * .62 = ~25.
+    //
+    // Finally, each rejection-loop iteration does 2 uniform generations
+    // plus ~5 other ops at 3-6 cycles each: ~15 / .89 = ~16.
+    //
+    // In total this should be ~165 cycles plus the cost of the uniform
+    // generations. We assume that half the tensor has rate < 10, so on
+    // average ~6 uniforms are needed per output. We upper bound the other
+    // op cost by the one for rate >= 10.
+    static const int kElementCost = 165 + 6 * Uniform::kElementCost +
+                                    6 * random::PhiloxRandom::kElementCost;
+    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
+    Shard(worker_threads.num_threads, worker_threads.workers,
+          num_rate * num_samples, kElementCost, DoWork);
+  }
+
+ private:
+  typedef typename PoissonComputeType<T>::ComputeType CT;
+};
+
+}  // namespace functor
+
+namespace {
+
+// Samples from one or more Poisson distributions.
+template <typename T>
+class RandomPoissonOp : public OpKernel {
+ public:
+  explicit RandomPoissonOp(OpKernelConstruction* context) : OpKernel(context) {
+    OP_REQUIRES_OK(context, generator_.Init(context));
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor& shape_t = ctx->input(0);
+    const Tensor& rate_t = ctx->input(1);
+
+    OP_REQUIRES(ctx,
+                TensorShapeUtils::IsVector(shape_t.shape()) &&
+                    (shape_t.dtype() == DataType::DT_INT32 ||
+                     shape_t.dtype() == DataType::DT_INT64),
+                errors::InvalidArgument(
+                    "shape must be a vector of {int32,int64}, got shape: ",
+                    shape_t.DebugString()));
+    TensorShape samples_shape;
+    if (shape_t.dtype() == DataType::DT_INT32) {
+      auto vec = shape_t.flat<int32>();
+      OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(vec.data(), vec.size(),
+                                                      &samples_shape));
+    } else if (shape_t.dtype() == DataType::DT_INT64) {
+      auto vec = shape_t.flat<int64>();
+      OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(vec.data(), vec.size(),
+                                                      &samples_shape));
+    }
+    const int64 num_samples = samples_shape.num_elements();
+    OP_REQUIRES(ctx, num_samples > 0,
+                errors::InvalidArgument(
+                    "Input shape should have non-zero element count, got: ",
+                    num_samples));
+
+    samples_shape.AppendShape(rate_t.shape());
+    // Allocate output samples.
+    Tensor* samples_t = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, samples_shape, &samples_t));
+
+    const auto rate_flat = rate_t.flat<T>().data();
+    const int64 num_rate = rate_t.NumElements();
+    OP_REQUIRES(
+        ctx, num_rate > 0,
+        errors::InvalidArgument(
+            "Input rate should have non-zero element count, got: ", num_rate));
+    auto samples_flat = samples_t->flat<T>().data();
+    random::PhiloxRandom rng = generator_.ReserveRandomOutputs(
+        num_samples * num_rate, kReservedSamplesPerOutput);
+
+    functor::PoissonFunctor<CPUDevice, T>()(ctx, ctx->eigen_device<CPUDevice>(),
+                                            rate_flat, num_rate, num_samples,
+                                            rng, samples_flat);
+  }
+
+ private:
+  GuardedPhiloxRandom generator_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(RandomPoissonOp);
+};
+}  // namespace
+
+#undef UNIFORM
+
+#define REGISTER(TYPE)                                                        \
+  REGISTER_KERNEL_BUILDER(                                                    \
+      Name("RandomPoisson").Device(DEVICE_CPU).TypeConstraint<TYPE>("dtype"), \
+      RandomPoissonOp<TYPE>);
+
+TF_CALL_half(REGISTER);
+TF_CALL_float(REGISTER);
+TF_CALL_double(REGISTER);
+
+#undef REGISTER
+
+}  // end namespace tensorflow
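For readers who would rather not trace the templated kernel, here is a minimal Python sketch of the same two-regime sampler, using the same constants and branch structure as the C++ above. It is illustrative only: the kernel draws from a per-output Philox stream (skipping `kReservedSamplesPerOutput` counters per output) and uses float intermediates for half precision, neither of which is modeled here.

```python
import math
import random

def sample_poisson(rate, rng=random.random):
    """Sketch of the kernel's sampler; illustrative, not the TF implementation."""
    if rate < 10.0:
        # Knuth: multiply uniforms until the product falls below e^(-rate).
        exp_neg_rate = math.exp(-rate)
        prod, x = 1.0, 0
        while True:
            prod *= rng()
            if prod <= exp_neg_rate:
                return x
            x += 1
    # Hormann's transformed rejection, constants as in the kernel.
    log_rate = math.log(rate)
    b = 0.931 + 2.53 * math.sqrt(rate)
    a = -0.059 + 0.02483 * b
    inv_alpha = 1.1239 + 1.1328 / (b - 3.4)
    while True:
        u = rng() - 0.5
        v = rng()
        u_shifted = 0.5 - abs(u)
        k = math.floor((2.0 * a / u_shifted + b) * u + rate + 0.43)
        # Fast accept: a rectangle fully under the acceptance curve.
        if u_shifted >= 0.07 and v <= 0.9277 - 3.6224 / (b - 2.0):
            return k
        # Fast reject.
        if k < 0 or (u_shifted < 0.013 and v > u_shifted):
            continue
        # Full test: v <= alpha * F'(G(u)) * G'(u), evaluated in log space.
        s = math.log(v * inv_alpha / (a / (u_shifted * u_shifted) + b))
        t = -rate + k * log_rate - math.lgamma(k + 1.0)
        if s <= t:
            return k
```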
diff --git a/tensorflow/core/kernels/random_poisson_op.h b/tensorflow/core/kernels/random_poisson_op.h
new file mode 100644
index 00000000000..6c49acc8007
--- /dev/null
+++ b/tensorflow/core/kernels/random_poisson_op.h
@@ -0,0 +1,31 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_KERNELS_RANDOM_POISSON_OP_H_
+#define TENSORFLOW_KERNELS_RANDOM_POISSON_OP_H_
+
+namespace tensorflow {
+
+namespace functor {
+
+// Generic helper functor for the Random Poisson Op.
+template <typename Device, typename T>
+struct PoissonFunctor;
+
+}  // namespace functor
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_KERNELS_RANDOM_POISSON_OP_H_
diff --git a/tensorflow/core/kernels/random_poisson_op_test.cc b/tensorflow/core/kernels/random_poisson_op_test.cc
new file mode 100644
index 00000000000..bccdbf6c7f5
--- /dev/null
+++ b/tensorflow/core/kernels/random_poisson_op_test.cc
@@ -0,0 +1,82 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <random>
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+Tensor VecShape(int64 v) {
+  if (v >= std::numeric_limits<int32>::max()) {
+    Tensor shape(DT_INT64, TensorShape({1}));
+    shape.vec<int64>()(0) = v;
+    return shape;
+  } else {
+    Tensor shape(DT_INT32, TensorShape({1}));
+    shape.vec<int32>()(0) = v;
+    return shape;
+  }
+}
+
+Tensor VecLam32(int64 n, int magnitude) {
+  std::mt19937 gen(0x12345);
+  std::uniform_real_distribution<float> dist(0.0, 1.0);
+  Tensor lams(DT_FLOAT, TensorShape({n}));
+  for (int i = 0; i < n; i++) {
+    // Generate in range (magnitude, 2 * magnitude)
+    lams.vec<float>()(i) = magnitude * (1 + dist(gen));
+  }
+  return lams;
+}
+
+Tensor VecLam64(int64 n, int magnitude) {
+  std::mt19937 gen(0x12345);
+  std::uniform_real_distribution<double> dist(0.0, 1.0);
+  Tensor lams(DT_DOUBLE, TensorShape({n}));
+  for (int i = 0; i < n; i++) {
+    // Generate in range (magnitude, 2 * magnitude)
+    lams.vec<double>()(i) = magnitude * (1 + dist(gen));
+  }
+  return lams;
+}
+
+#define BM_Poisson(DEVICE, BITS, MAGNITUDE)                            \
+  static void BM_##DEVICE##_RandomPoisson_lam_##MAGNITUDE##_##BITS(    \
+      int iters, int nsamp, int nlam) {                                \
+    testing::ItemsProcessed(static_cast<int64>(iters) * nsamp * nlam); \
+    Graph* g = new Graph(OpRegistry::Global());                        \
+    test::graph::RandomPoisson(                                        \
+        g, test::graph::Constant(g, VecShape(nsamp)),                  \
+        test::graph::Constant(g, VecLam##BITS(nlam, MAGNITUDE)));      \
+    test::Benchmark(#DEVICE, g).Run(iters);                            \
+  }                                                                    \
+  BENCHMARK(BM_##DEVICE##_RandomPoisson_lam_##MAGNITUDE##_##BITS)      \
+      ->RangePair(1, 64, 2, 50);
+
+BM_Poisson(cpu, 32, 1);
+BM_Poisson(cpu, 32, 8);
+BM_Poisson(cpu, 32, 32);
+
+BM_Poisson(cpu, 64, 1);
+BM_Poisson(cpu, 64, 8);
+BM_Poisson(cpu, 64, 32);
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/core/ops/random_ops.cc b/tensorflow/core/ops/random_ops.cc
index 776523f33fb..7b2da9d8e6d 100644
--- a/tensorflow/core/ops/random_ops.cc
+++ b/tensorflow/core/ops/random_ops.cc
@@ -276,4 +276,48 @@ output: A tensor with shape `shape + shape(alpha)`. Each slice
   `alpha[i0, i1, ...iN]`. The dtype of the output matches the dtype of alpha.
 )doc");
 
+REGISTER_OP("RandomPoisson")
+    .SetIsStateful()
+    .Input("shape: S")
+    .Input("rate: dtype")
+    .Output("output: dtype")
+    .Attr("seed: int = 0")
+    .Attr("seed2: int = 0")
+    .Attr("S: {int32, int64}")
+    .Attr("dtype: {half, float, double}")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle out;
+      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
+      TF_RETURN_IF_ERROR(c->Concatenate(out, c->input(1), &out));
+      c->set_output(0, out);
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Outputs random values from the Poisson distribution(s) described by rate.
+
+This op uses two algorithms, depending on rate. If rate >= 10, then
+the algorithm by Hormann is used to acquire samples via
+transformation-rejection.
+See http://www.sciencedirect.com/science/article/pii/0167668793909974.
+
+Otherwise, Knuth's algorithm is used to acquire samples via multiplying uniform
+random variables.
+See Donald E. Knuth (1969). Seminumerical Algorithms. The Art of Computer
+Programming, Volume 2. Addison Wesley.
+
+shape: 1-D integer tensor. Shape of independent samples to draw from each
+  distribution described by the rate parameters given in `rate`.
+rate: A tensor in which each scalar is a "rate" parameter describing the
+  associated Poisson distribution.
+seed: If either `seed` or `seed2` are set to be non-zero, the random number
+  generator is seeded by the given seed.  Otherwise, it is seeded by a
+  random seed.
+seed2: A second seed to avoid seed collision.
+
+output: A tensor with shape `shape + shape(rate)`. Each slice
+  `[:, ..., :, i0, i1, ...iN]` contains the samples drawn for
+  `rate[i0, i1, ...iN]`. The dtype of the output matches the dtype of
+  rate.
+)doc");
+
 }  // namespace tensorflow
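The shape function above concatenates the requested sample shape with the shape of `rate`, matching the `shape + shape(rate)` contract in the doc string. A quick sketch of the resulting static shapes, assuming a TF 1.x-era build with this PR applied:

```python
import tensorflow as tf

# Samples are stacked along the leading dimensions: the output shape is
# the requested sample shape followed by the shape of `rate`.
samples = tf.random_poisson(lam=tf.ones([2, 3]), shape=[10])
print(samples.get_shape())  # (10, 2, 3)
```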
diff --git a/tensorflow/core/ops/random_ops_test.cc b/tensorflow/core/ops/random_ops_test.cc
index 524e1079981..b0aa565485b 100644
--- a/tensorflow/core/ops/random_ops_test.cc
+++ b/tensorflow/core/ops/random_ops_test.cc
@@ -53,4 +53,19 @@ TEST(RandomOpsTest, RandomGamma_ShapeFn) {
   INFER_OK(op, "[3];[]", "[1,2,3]");
 }
 
+TEST(RandomOpsTest, RandomPoisson_ShapeFn) {
+  ShapeInferenceTestOp op("RandomPoisson");
+  op.input_tensors.resize(2);
+
+  INFER_OK(op, "?;?", "?");
+  INFER_OK(op, "?;[3]", "?");
+  INFER_OK(op, "[1];?", "?");
+  INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[1,2];[3,4]");
+  Tensor shape = test::AsTensor<int32>({1, 2, 3});
+  op.input_tensors[0] = &shape;
+  INFER_OK(op, "[3];[4,?]", "[1,2,3,d1_0,d1_1]");
+  INFER_OK(op, "[3];[4,5]", "[1,2,3,d1_0,d1_1]");
+  INFER_OK(op, "[3];[]", "[1,2,3]");
+}
+
 }  // end namespace tensorflow
diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py
index 3bcc5377797..05a520850e1 100644
--- a/tensorflow/python/framework/constant_op.py
+++ b/tensorflow/python/framework/constant_op.py
@@ -96,6 +96,7 @@ print(sess.run(var))
 @@random_crop
 @@multinomial
 @@random_gamma
+@@random_poisson
 @@set_random_seed
 """
 
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 119cffe0e9b..2ecbd089922 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -2112,6 +2112,21 @@ cuda_py_test(
     ],
 )
 
+cuda_py_test(
+    name = "random_poisson_test",
+    size = "medium",
+    srcs = ["random_poisson_test.py"],
+    additional_deps = [
+        "//third_party/py/numpy",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:platform",
+        "//tensorflow/python:random_ops",
+    ],
+)
+
 cuda_py_test(
     name = "rnn_test",
     size = "medium",
diff --git a/tensorflow/python/kernel_tests/random_poisson_test.py b/tensorflow/python/kernel_tests/random_poisson_test.py
new file mode 100644
index 00000000000..01281b7bd03
--- /dev/null
+++ b/tensorflow/python/kernel_tests/random_poisson_test.py
@@ -0,0 +1,178 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.ops.random_ops.random_poisson."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from six.moves import xrange  # pylint: disable=redefined-builtin
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+
+class RandomPoissonTest(test.TestCase):
+  """This is a large test due to the moments computation taking some time."""
+
+  def _Sampler(self, num, lam, dtype, use_gpu, seed=None):
+
+    def func():
+      with self.test_session(use_gpu=use_gpu, graph=ops.Graph()) as sess:
+        rng = random_ops.random_poisson(lam, [num], dtype=dtype, seed=seed)
+        ret = np.empty([10, num])
+        for i in xrange(10):
+          ret[i, :] = sess.run(rng)
+      return ret
+
+    return func
+
+  # TODO(srvasude): Factor this out along with the corresponding moment testing
+  # method in random_gamma_test into a single library.
+  def testMoments(self):
+    try:
+      from scipy import stats  # pylint: disable=g-import-not-at-top
+    except ImportError as e:
+      tf_logging.warn("Cannot test moments: %s", e)
+      return
+    # The moments test is a z-value test.  This is the largest z-value
+    # we want to tolerate. Since the z-test approximates a unit normal
+    # distribution, it should almost definitely never exceed 6.
+    z_limit = 6.0
+    for dt in dtypes.float16, dtypes.float32, dtypes.float64:
+      # Test when lam < 10 and when lam >= 10
+      for stride in 0, 4, 10:
+        for lam in (3., 20):
+          max_moment = 5
+          sampler = self._Sampler(10000, lam, dt, use_gpu=False, seed=12345)
+          moments = [0] * (max_moment + 1)
+          moments_sample_count = [0] * (max_moment + 1)
+          x = np.array(sampler().flat)  # sampler does 10x samples
+          for k in range(len(x)):
+            moment = 1.
+            for i in range(max_moment + 1):
+              index = k + i * stride
+              if index >= len(x):
+                break
+              moments[i] += moment
+              moments_sample_count[i] += 1
+              moment *= x[index]
+          for i in range(max_moment + 1):
+            moments[i] /= moments_sample_count[i]
+          for i in range(1, max_moment + 1):
+            g = stats.poisson(lam)
+            if stride == 0:
+              moments_i_mean = g.moment(i)
+              moments_i_squared = g.moment(2 * i)
+            else:
+              moments_i_mean = pow(g.moment(1), i)
+              moments_i_squared = pow(g.moment(2), i)
+            moments_i_var = (
+                moments_i_squared - moments_i_mean * moments_i_mean)
+            # Assume every operation has a small numerical error.
+            # It takes i multiplications to calculate one i-th moment.
+            error_per_moment = i * 1e-6
+            total_variance = (
+                moments_i_var / moments_sample_count[i] + error_per_moment)
+            if not total_variance:
+              total_variance = 1e-10
+            # z_test is approximately a unit normal distribution.
+            z_test = abs(
+                (moments[i] - moments_i_mean) / np.sqrt(total_variance))
+            self.assertLess(z_test, z_limit)
+
+  # Checks that the CPU and GPU implementations return the same results,
+  # given the same random seed.
+  def testCPUGPUMatch(self):
+    for dt in dtypes.float16, dtypes.float32, dtypes.float64:
+      results = {}
+      for use_gpu in [False, True]:
+        sampler = self._Sampler(1000, 1.0, dt, use_gpu=use_gpu, seed=12345)
+        results[use_gpu] = sampler()
+      if dt == dtypes.float16:
+        self.assertAllClose(results[False], results[True], rtol=1e-3, atol=1e-3)
+      else:
+        self.assertAllClose(results[False], results[True], rtol=1e-6, atol=1e-6)
+
+  def testSeed(self):
+    for dt in dtypes.float16, dtypes.float32, dtypes.float64:
+      sx = self._Sampler(1000, 1.0, dt, use_gpu=True, seed=345)
+      sy = self._Sampler(1000, 1.0, dt, use_gpu=True, seed=345)
+      self.assertAllEqual(sx(), sy())
+
+  def testNoCSE(self):
+    """CSE = constant subexpression eliminator.
+
+    SetIsStateful() should prevent two identical random ops from getting
+    merged.
+    """
+    for dtype in dtypes.float16, dtypes.float32, dtypes.float64:
+      with self.test_session(use_gpu=True):
+        rnd1 = random_ops.random_poisson(2.0, [24], dtype=dtype)
+        rnd2 = random_ops.random_poisson(2.0, [24], dtype=dtype)
+        diff = rnd2 - rnd1
+        # Since these are all positive integers, the norm will
+        # be at least 1 if they are different.
+        self.assertGreaterEqual(np.linalg.norm(diff.eval()), 1)
+
+  def testShape(self):
+    # Fully known shape.
+    rnd = random_ops.random_poisson(2.0, [150], seed=12345)
+    self.assertEqual([150], rnd.get_shape().as_list())
+    rnd = random_ops.random_poisson(
+        lam=array_ops.ones([1, 2, 3]),
+        shape=[150],
+        seed=12345)
+    self.assertEqual([150, 1, 2, 3], rnd.get_shape().as_list())
+    rnd = random_ops.random_poisson(
+        lam=array_ops.ones([1, 2, 3]),
+        shape=[20, 30],
+        seed=12345)
+    self.assertEqual([20, 30, 1, 2, 3], rnd.get_shape().as_list())
+    rnd = random_ops.random_poisson(
+        lam=array_ops.placeholder(dtypes.float32, shape=(2,)),
+        shape=[12],
+        seed=12345)
+    self.assertEqual([12, 2], rnd.get_shape().as_list())
+    # Partially known shape.
+    rnd = random_ops.random_poisson(
+        lam=array_ops.ones([7, 3]),
+        shape=array_ops.placeholder(dtypes.int32, shape=(1,)),
+        seed=12345)
+    self.assertEqual([None, 7, 3], rnd.get_shape().as_list())
+    rnd = random_ops.random_poisson(
+        lam=array_ops.ones([9, 6]),
+        shape=array_ops.placeholder(dtypes.int32, shape=(3,)),
+        seed=12345)
+    self.assertEqual([None, None, None, 9, 6], rnd.get_shape().as_list())
+    # Unknown shape.
+    rnd = random_ops.random_poisson(
+        lam=array_ops.placeholder(dtypes.float32),
+        shape=array_ops.placeholder(dtypes.int32),
+        seed=12345)
+    self.assertIs(None, rnd.get_shape().ndims)
+    rnd = random_ops.random_poisson(
+        lam=array_ops.placeholder(dtypes.float32),
+        shape=[50],
+        seed=12345)
+    self.assertIs(None, rnd.get_shape().ndims)
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt
index dab15976b58..4937f1a50a8 100644
--- a/tensorflow/python/ops/hidden_ops.txt
+++ b/tensorflow/python/ops/hidden_ops.txt
@@ -282,6 +282,7 @@ ParseSingleSequenceExample
 
 # random_ops
 RandomGamma
+RandomPoisson
 RandomUniform
 RandomUniformInt
 RandomShuffle
diff --git a/tensorflow/python/ops/random_ops.py b/tensorflow/python/ops/random_ops.py
index 34b4d361021..5a753ae7a1e 100644
--- a/tensorflow/python/ops/random_ops.py
+++ b/tensorflow/python/ops/random_ops.py
@@ -71,10 +71,8 @@ def random_normal(shape,
     mean_tensor = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
     stddev_tensor = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
     seed1, seed2 = random_seed.get_seed(seed)
-    rnd = gen_random_ops._random_standard_normal(shape_tensor,
-                                                 dtype,
-                                                 seed=seed1,
-                                                 seed2=seed2)
+    rnd = gen_random_ops._random_standard_normal(
+        shape_tensor, dtype, seed=seed1, seed2=seed2)
     mul = rnd * stddev_tensor
     value = math_ops.add(mul, mean_tensor, name=name)
     return value
@@ -125,13 +123,14 @@ def parameterized_truncated_normal(shape,
     minvals_tensor = ops.convert_to_tensor(minvals, dtype=dtype, name="minvals")
     maxvals_tensor = ops.convert_to_tensor(maxvals, dtype=dtype, name="maxvals")
     seed1, seed2 = random_seed.get_seed(seed)
-    rnd = gen_random_ops._parameterized_truncated_normal(shape_tensor,
-                                                         means_tensor,
-                                                         stddevs_tensor,
-                                                         minvals_tensor,
-                                                         maxvals_tensor,
-                                                         seed=seed1,
-                                                         seed2=seed2)
+    rnd = gen_random_ops._parameterized_truncated_normal(
+        shape_tensor,
+        means_tensor,
+        stddevs_tensor,
+        minvals_tensor,
+        maxvals_tensor,
+        seed=seed1,
+        seed2=seed2)
     return rnd
 
 
@@ -168,10 +167,8 @@ def truncated_normal(shape,
     mean_tensor = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
     stddev_tensor = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
     seed1, seed2 = random_seed.get_seed(seed)
-    rnd = gen_random_ops._truncated_normal(shape_tensor,
-                                           dtype,
-                                           seed=seed1,
-                                           seed2=seed2)
+    rnd = gen_random_ops._truncated_normal(
+        shape_tensor, dtype, seed=seed1, seed2=seed2)
     mul = rnd * stddev_tensor
     value = math_ops.add(mul, mean_tensor, name=name)
     return value
@@ -232,17 +229,11 @@ def random_uniform(shape,
     maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
     seed1, seed2 = random_seed.get_seed(seed)
     if dtype.is_integer:
-      return gen_random_ops._random_uniform_int(shape,
-                                                minval,
-                                                maxval,
-                                                seed=seed1,
-                                                seed2=seed2,
-                                                name=name)
+      return gen_random_ops._random_uniform_int(
+          shape, minval, maxval, seed=seed1, seed2=seed2, name=name)
     else:
-      rnd = gen_random_ops._random_uniform(shape,
-                                           dtype,
-                                           seed=seed1,
-                                           seed2=seed2)
+      rnd = gen_random_ops._random_uniform(
+          shape, dtype, seed=seed1, seed2=seed2)
       return math_ops.add(rnd * (maxval - minval), minval, name=name)
 
 
@@ -275,10 +266,8 @@ def random_shuffle(value, seed=None, name=None):
     dimension.
   """
   seed1, seed2 = random_seed.get_seed(seed)
-  return gen_random_ops._random_shuffle(value,
-                                        seed=seed1,
-                                        seed2=seed2,
-                                        name=name)
+  return gen_random_ops._random_shuffle(
+      value, seed=seed1, seed2=seed2, name=name)
 
 
 def random_crop(value, size, seed=None, name=None):
@@ -349,10 +338,8 @@ def multinomial(logits, num_samples, seed=None, name=None):
   with ops.name_scope(name, "multinomial", [logits]):
     logits = ops.convert_to_tensor(logits, name="logits")
     seed1, seed2 = random_seed.get_seed(seed)
-    return gen_random_ops.multinomial(logits,
-                                      num_samples,
-                                      seed=seed1,
-                                      seed2=seed2)
+    return gen_random_ops.multinomial(
+        logits, num_samples, seed=seed1, seed2=seed2)
 
 
 ops.NotDifferentiable("Multinomial")
@@ -426,15 +413,52 @@ def random_gamma(shape,
   with ops.name_scope(name, "random_gamma", [shape, alpha, beta]):
     shape = ops.convert_to_tensor(shape, name="shape", dtype=dtypes.int32)
     alpha = ops.convert_to_tensor(alpha, name="alpha", dtype=dtype)
-    beta = ops.convert_to_tensor(beta if beta is not None else 1,
-                                 name="beta",
-                                 dtype=dtype)
+    beta = ops.convert_to_tensor(
+        beta if beta is not None else 1, name="beta", dtype=dtype)
     alpha_broadcast = alpha + array_ops.zeros_like(beta)
     seed1, seed2 = random_seed.get_seed(seed)
-    return gen_random_ops._random_gamma(shape,
-                                        alpha_broadcast,
-                                        seed=seed1,
-                                        seed2=seed2) / beta
+    return gen_random_ops._random_gamma(
+        shape, alpha_broadcast, seed=seed1, seed2=seed2) / beta
 
 
 ops.NotDifferentiable("RandomGamma")
+
+
+def random_poisson(lam, shape, dtype=dtypes.float32, seed=None, name=None):
+  """Draws `shape` samples from each of the given Poisson distribution(s).
+
+  `lam` is the rate parameter describing the distribution(s).
+
+  Example:
+
+    samples = tf.random_poisson([0.5, 1.5], [10])
+    # samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents
+    # the samples drawn from each distribution
+
+    samples = tf.random_poisson([12.2, 3.3], [7, 5])
+    # samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1]
+    # represents the 7x5 samples drawn from each of the two distributions
+
+  Args:
+    lam: A Tensor or Python value or N-D array of type `dtype`.
+      `lam` provides the rate parameter(s) describing the poisson
+      distribution(s) to sample.
+    shape: A 1-D integer Tensor or Python array. The shape of the output samples
+      to be drawn per "rate"-parameterized distribution.
+    dtype: The type of `lam` and the output: `float16`, `float32`, or
+      `float64`.
+    seed: A Python integer. Used to create a random seed for the distributions.
+      See
+      [`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed)
+      for behavior.
+    name: Optional name for the operation.
+
+  Returns:
+    samples: a `Tensor` of shape `shape + shape(lam)`, with values of type
+      `dtype`.
+  """
+  with ops.name_scope(name, "random_poisson", [lam, shape]):
+    lam = ops.convert_to_tensor(lam, name="lam", dtype=dtype)
+    shape = ops.convert_to_tensor(shape, name="shape", dtype=dtypes.int32)
+    seed1, seed2 = random_seed.get_seed(seed)
+    return gen_random_ops._random_poisson(shape, lam, seed=seed1, seed2=seed2)