diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index ffdd415e637..da59e1fa3ad 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1497,6 +1497,7 @@ cc_library(
         "//tensorflow/core/kernels:ragged_ops",
         "//tensorflow/core/kernels:random_ops",
         "//tensorflow/core/kernels:stateful_random_ops",
+        "//tensorflow/core/kernels:random_binomial_op",
         "//tensorflow/core/kernels:random_poisson_op",
         "//tensorflow/core/kernels:remote_fused_graph_ops",
         "//tensorflow/core/kernels:required",
diff --git a/tensorflow/core/api_def/base_api/api_def_StatefulRandomBinomial.pbtxt b/tensorflow/core/api_def/base_api/api_def_StatefulRandomBinomial.pbtxt
new file mode 100644
index 00000000000..752c2ba48bc
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_StatefulRandomBinomial.pbtxt
@@ -0,0 +1,3 @@
+op {
+  graph_op_name: "StatefulRandomBinomial"
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StatefulRandomBinomial.pbtxt b/tensorflow/core/api_def/python_api/api_def_StatefulRandomBinomial.pbtxt
new file mode 100644
index 00000000000..cb371d5674f
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StatefulRandomBinomial.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "StatefulRandomBinomial"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 1785ba0973c..1188251f085 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -5222,6 +5222,38 @@ tf_cuda_cc_test(
     ],
 )
 
+tf_kernel_library(
+    name = "random_binomial_op",
+    prefix = "random_binomial_op",
+    deps = [
+        ":cwise_op",
+        ":random_ops",
+        ":resource_variable_ops",
+        ":stateful_random_ops",
+        ":training_op_helpers",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:random_ops_op_lib",
+    ],
+)
+
+tf_cuda_cc_test(
+    name = "random_binomial_op_test",
+    size = "small",
+    srcs = ["random_binomial_op_test.cc"],
+    deps = [
+        ":ops_util",
+        ":random_binomial_op",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
 tf_kernel_library(
     name = "random_poisson_op",
     prefix = "random_poisson_op",
diff --git a/tensorflow/core/kernels/random_binomial_op.cc b/tensorflow/core/kernels/random_binomial_op.cc
new file mode 100644
index 00000000000..6ed36605530
--- /dev/null
+++ b/tensorflow/core/kernels/random_binomial_op.cc
@@ -0,0 +1,447 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/random_ops.cc.
+// NOTE: If the algorithm is changed, please run the test
+// .../python/kernel_tests/random:random_binomial_test
+// commenting out the "tf.set_random_seed(seed)" lines, and using the
+// "--runs-per-test=1000" flag. This tests the statistical correctness of the
+// op results.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/random_binomial_op.h"
+
+#include <algorithm>
+#include <cmath>
+#include <memory>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h"
+#include "tensorflow/core/kernels/training_op_helpers.h"
+#include "tensorflow/core/lib/random/random_distributions.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/guarded_philox_random.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+#define UNIFORM(X)                                    \
+  if (uniform_remaining == 0) {                       \
+    uniform_remaining = Uniform::kResultElementCount; \
+    uniform_result = uniform(gen);                    \
+  }                                                   \
+  uniform_remaining--;                                \
+  double X = uniform_result[uniform_remaining]
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace {
+
+typedef random::UniformDistribution<random::PhiloxRandom, double> Uniform;
+
+// Binomial inversion. Given prob, sum geometric random variables until they
+// exceed count. The number of random variables used is binomially distributed.
+// This is also known as binomial inversion, as this is equivalent to inverting
+// the Binomial CDF.
+double binomial_inversion(double count, double prob,
+                          random::PhiloxRandom* gen) {
+  using Eigen::numext::ceil;
+  using Eigen::numext::log;
+  using Eigen::numext::log1p;
+
+  double geom_sum = 0;
+  int num_geom = 0;
+
+  Uniform uniform;
+  typename Uniform::ResultType uniform_result;
+  int16 uniform_remaining = 0;
+
+  while (true) {
+    UNIFORM(u);
+    double geom = ceil(log(u) / log1p(-prob));
+    geom_sum += geom;
+    if (geom_sum > count) {
+      break;
+    }
+    ++num_geom;
+  }
+  return num_geom;
+}
+
+double stirling_approx_tail(double k) {
+  static double kTailValues[] = {0.0810614667953272,  0.0413406959554092,
+                                 0.0276779256849983,  0.02079067210376509,
+                                 0.0166446911898211,  0.0138761288230707,
+                                 0.0118967099458917,  0.0104112652619720,
+                                 0.00925546218271273, 0.00833056343336287};
+  if (k <= 9) {
+    return kTailValues[static_cast<int>(k)];
+  }
+  double kp1sq = (k + 1) * (k + 1);
+  return (1 / 12 - (1 / 360 + 1 / 1260 / kp1sq) / kp1sq) / (k + 1);
+}
+
+// We use a transformation-rejection algorithm from
+// pairs of uniform random variables due to Hormann.
+// https://www.tandfonline.com/doi/abs/10.1080/00949659308811496
+double btrs(double count, double prob, random::PhiloxRandom* gen) {
+  using Eigen::numext::abs;
+  using Eigen::numext::floor;
+  using Eigen::numext::log;
+  using Eigen::numext::log1p;
+  using Eigen::numext::sqrt;
+
+  // This is spq in the paper.
+  const double stddev = sqrt(count * prob * (1 - prob));
+
+  // Other coefficients for Transformed Rejection sampling.
+  const double b = 1.15 + 2.53 * stddev;
+  const double a = -0.0873 + 0.0248 * b + 0.01 * prob;
+  const double c = count * prob + 0.5;
+  const double v_r = 0.92 - 4.2 / b;
+  const double r = prob / (1 - prob);
+
+  Uniform uniform;
+  typename Uniform::ResultType uniform_result;
+  int16 uniform_remaining = 0;
+
+  while (true) {
+    UNIFORM(u);
+    UNIFORM(v);
+    u = u - 0.5;
+    double us = 0.5 - abs(u);
+    double k = floor((2 * a / us + b) * u + c);
+
+    // Region for which the box is tight, and we
+    // can return our calculated value This should happen
+    // 0.86 * v_r times. In the limit as n * p is large,
+    // the acceptance rate converges to ~79% (and in the lower
+    // regime it is ~24%).
+    if (us >= 0.07 && v <= v_r) {
+      return k;
+    }
+    // Reject non-sensical answers.
+    if (k < 0 || k > count) {
+      continue;
+    }
+
+    double alpha = (2.83 + 5.1 / b) * stddev;
+    double m = floor((count + 1) * prob);
+    // This deviates from Hormann's BRTS algorithm, as there is a log missing.
+    // For all (u, v) pairs outside of the bounding box, this calculates the
+    // transformed-reject ratio.
+    v = log(v * alpha / (a / (us * us) + b));
+    double upperbound =
+        ((m + 0.5) * log((m + 1) / (r * (count - m + 1))) +
+         (count + 1) * log((count - m + 1) / (count - k + 1)) +
+         (k + 0.5) * log(r * (count - k + 1) / (k + 1)) +
+         stirling_approx_tail(m) + stirling_approx_tail(count - m) -
+         stirling_approx_tail(k) - stirling_approx_tail(count - k));
+    if (v <= upperbound) {
+      return k;
+    }
+  }
+}
+
+}  // namespace
+
+namespace functor {
+
+template <typename T, typename U>
+struct RandomBinomialFunctor<CPUDevice, T, U> {
+  void operator()(OpKernelContext* ctx, const CPUDevice& d, int64 num_batches,
+                  int64 samples_per_batch, int64 num_elements,
+                  typename TTypes<T>::ConstFlat counts,
+                  typename TTypes<T>::ConstFlat probs,
+                  const random::PhiloxRandom& gen,
+                  typename TTypes<U>::Flat output) {
+    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
+
+    auto DoWork = [samples_per_batch, num_elements, &counts, &probs, &gen,
+                   &output](int start_batch, int limit_batch) {
+      // Capturing "gen" by-value would only make a copy for the _shared_
+      // lambda.  Since we want to let each worker have its own copy, we pass
+      // "gen" by reference and explicitly do a copy assignment here.
+      random::PhiloxRandom gen_copy = gen;
+      // Skip takes units of 128 bytes.  +3 is so rounding doesn't lead to
+      // us using the same state in different batches.
+      // The sample from each iteration uses 2 random numbers.
+      gen_copy.Skip(start_batch * 2 * 3 * (samples_per_batch + 3) / 4);
+
+      // Vectorized intermediate calculations for uniform rejection sampling.
+      // We always generate at most 4 samples.
+      Eigen::array<T, 4> z;
+      Eigen::array<T, 4> g;
+
+      for (int64 b = start_batch; b < limit_batch; ++b) {
+        // We are passed a flat array for each of the parameter tensors.
+        // The input is either a scalar broadcasted to all batches or a vector
+        // with length num_batches, but the scalar becomes an array of length 1.
+        T count = counts((counts.dimension(0) == 1) ? 0 : b);
+        T prob = probs((probs.dimension(0) == 1) ? 0 : b);
+
+        // The last batch can be short, if we adjusted num_batches and
+        // samples_per_batch.
+        const int64 limit_sample =
+            std::min((b + 1) * samples_per_batch, num_elements);
+        int64 sample = b * samples_per_batch;
+
+        // Calculate normalized samples, then convert them.
+        // Determine the method to use.
+        double dcount = static_cast<double>(count);
+        if (prob <= T(0.5)) {
+          double dp = static_cast<double>(prob);
+          if (count * prob >= T(10)) {
+            while (sample < limit_sample) {
+              output(sample) = static_cast<U>(btrs(dcount, dp, &gen_copy));
+              sample++;
+            }
+          } else {
+            while (sample < limit_sample) {
+              output(sample) =
+                  static_cast<U>(binomial_inversion(dcount, dp, &gen_copy));
+              sample++;
+            }
+          }
+        } else {
+          T q = T(1) - prob;
+          double dcount = static_cast<double>(count);
+          double dq = static_cast<double>(q);
+          if (count * q >= T(10)) {
+            while (sample < limit_sample) {
+              output(sample) =
+                  static_cast<U>(dcount - btrs(dcount, dq, &gen_copy));
+              sample++;
+            }
+          } else {
+            while (sample < limit_sample) {
+              output(sample) = static_cast<U>(
+                  dcount - binomial_inversion(dcount, dq, &gen_copy));
+              sample++;
+            }
+          }
+        }
+      }
+    };
+
+    const int64 batch_init_cost =
+        // normMin, normMax
+        (Eigen::TensorOpCost::AddCost<T>() +
+         Eigen::TensorOpCost::MulCost<T>()) *
+            2
+        // sqrtFactor
+        + Eigen::TensorOpCost::AddCost<T>() +
+        Eigen::TensorOpCost::MulCost<T>() +
+        Eigen::internal::functor_traits<
+            Eigen::internal::scalar_sqrt_op<T>>::Cost
+        // cutoff
+        + Eigen::TensorOpCost::MulCost<T>() * 4 +
+        Eigen::internal::functor_traits<Eigen::internal::scalar_exp_op<T>>::Cost
+        // diff
+        + Eigen::TensorOpCost::AddCost<T>();
+    // This will depend on count * p (or count * q).
+    // For n * p < 10, on average, O(n * p) calls to uniform are
+    // needed, with that
+    // many multiplies. ~10 uniform calls on average with ~200 cost op calls.
+    //
+    // Very roughly, for rate >= 10, the four calls to log
+    // occur for ~72 percent of samples.
+    // 4 x 100 (64-bit cycles per log) * 0.72 = ~288
+    // Additionally, there are ~10 other ops (+, *, /, ...) at 3-6 cycles each:
+    // 40 * .72  = ~25.
+    //
+    // Finally, there are several other ops that are done every loop along with
+    // 2 uniform generations along with 5 other ops at 3-6 cycles each.
+    // ~15 / .89 = ~16
+    //
+    // In total this should be ~529 + 2 * Uniform::kElementCost.
+    // We assume that half the tensor has rate < 10, so on average 6
+    // uniform's
+    // will be needed. We will upper bound the other op cost by the one for
+    // rate > 10.
+    static const int kElementCost = 529 + 6 * Uniform::kElementCost +
+                                    6 * random::PhiloxRandom::kElementCost;
+    // Assume we use uniform sampling, and accept the 2nd sample on average.
+    const int64 batch_cost = batch_init_cost + kElementCost * samples_per_batch;
+    Shard(worker_threads.num_threads, worker_threads.workers, num_batches,
+          batch_cost, DoWork);
+  }
+};
+
+}  // namespace functor
+
+namespace {
+
+// Samples from a binomial distribution, using the given parameters.
+template <typename Device, typename T, typename U>
+class RandomBinomialOp : public OpKernel {
+  // Reshape batches so each batch is this size if possible.
+  static const int32 kDesiredBatchSize = 100;
+
+ public:
+  explicit RandomBinomialOp(OpKernelConstruction* context)
+      : OpKernel(context) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor& alg_tensor = ctx->input(1);
+    const Tensor& shape_tensor = ctx->input(2);
+    const Tensor& counts_tensor = ctx->input(3);
+    const Tensor& probs_tensor = ctx->input(4);
+
+    OP_REQUIRES(ctx, alg_tensor.dims() == 0,
+                errors::InvalidArgument("algorithm must be of shape [], not ",
+                                        alg_tensor.shape().DebugString()));
+    Algorithm alg = alg_tensor.flat<Algorithm>()(0);
+
+    OP_REQUIRES(
+        ctx, TensorShapeUtils::IsVector(shape_tensor.shape()),
+        errors::InvalidArgument("Input shape should be a vector, got shape: ",
+                                shape_tensor.shape().DebugString()));
+    int32 num_batches = shape_tensor.flat<int32>()(0);
+
+    int32 samples_per_batch = 1;
+    const int32 num_dims = shape_tensor.dim_size(0);
+    for (int32 i = 1; i < num_dims; i++) {
+      samples_per_batch *= shape_tensor.flat<int32>()(i);
+    }
+    const int32 num_elements = num_batches * samples_per_batch;
+
+    // Allocate the output before fudging num_batches and samples_per_batch.
+    auto shape_vec = shape_tensor.flat<int32>();
+    TensorShape tensor_shape;
+    OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(
+                            shape_vec.data(), shape_vec.size(), &tensor_shape));
+    Tensor* samples_tensor;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, tensor_shape, &samples_tensor));
+
+    // Parameters must be 0-d or 1-d.
+    OP_REQUIRES(ctx, counts_tensor.dims() <= 1,
+                errors::InvalidArgument(
+                    "Input counts should be a scalar or vector, got shape: ",
+                    counts_tensor.shape().DebugString()));
+    OP_REQUIRES(ctx, probs_tensor.dims() <= 1,
+                errors::InvalidArgument(
+                    "Input probs should be a scalar or vector, got shape: ",
+                    probs_tensor.shape().DebugString()));
+
+    if ((counts_tensor.dims() == 0 || counts_tensor.dim_size(0) == 1) &&
+        (probs_tensor.dims() == 0 || probs_tensor.dim_size(0) == 1)) {
+      // All batches have the same parameters, so we can update the batch size
+      // to a reasonable value to improve parallelism (ensure enough batches,
+      // and no very small batches which have high overhead).
+      int32 size = num_batches * samples_per_batch;
+      int32 adjusted_samples = kDesiredBatchSize;
+      // Ensure adjusted_batches * adjusted_samples >= size.
+      int32 adjusted_batches = Eigen::divup(size, adjusted_samples);
+      num_batches = adjusted_batches;
+      samples_per_batch = adjusted_samples;
+    } else {
+      // Parameters must be broadcastable to the shape [num_batches].
+      OP_REQUIRES(
+          ctx,
+          TensorShapeUtils::IsScalar(counts_tensor.shape()) ||
+              counts_tensor.dim_size(0) == 1 ||
+              counts_tensor.dim_size(0) == num_batches,
+          errors::InvalidArgument(
+              "Input counts should have length 1 or shape[0], got shape: ",
+              counts_tensor.shape().DebugString()));
+      OP_REQUIRES(
+          ctx,
+          TensorShapeUtils::IsScalar(probs_tensor.shape()) ||
+              probs_tensor.dim_size(0) == 1 ||
+              probs_tensor.dim_size(0) == num_batches,
+          errors::InvalidArgument(
+              "Input probs should have length 1 or shape[0], got shape: ",
+              probs_tensor.shape().DebugString()));
+    }
+    Var* var = nullptr;
+    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &var));
+
+    ScopedUnlockUnrefVar var_guard(var);
+    Tensor* var_tensor = var->tensor();
+    OP_REQUIRES(
+        ctx, var_tensor->dtype() == STATE_ELEMENT_DTYPE,
+        errors::InvalidArgument("dtype of RNG state variable must be ",
+                                DataTypeString(STATE_ELEMENT_DTYPE), ", not ",
+                                DataTypeString(var_tensor->dtype())));
+    OP_REQUIRES(ctx, var_tensor->dims() == 1,
+                errors::InvalidArgument(
+                    "RNG state must have one and only one dimension, not ",
+                    var_tensor->dims()));
+    auto var_tensor_flat = var_tensor->flat<StateElementType>();
+    OP_REQUIRES(ctx, alg == RNG_ALG_PHILOX,
+                errors::InvalidArgument("Unsupported algorithm id: ", alg));
+    static_assert(std::is_same<StateElementType, int64>::value,
+                  "StateElementType must be int64");
+    static_assert(std::is_same<PhiloxRandom::ResultElementType, uint32>::value,
+                  "PhiloxRandom::ResultElementType must be uint32");
+    OP_REQUIRES(ctx, var_tensor_flat.size() >= PHILOX_MIN_STATE_SIZE,
+                errors::InvalidArgument(
+                    "For Philox algorithm, the size of state must be at least ",
+                    PHILOX_MIN_STATE_SIZE, "; got ", var_tensor_flat.size()));
+
+    // Each worker has the fudge factor for samples_per_batch, so use it here.
+    OP_REQUIRES_OK(ctx, PrepareToUpdateVariable<Device, StateElementType>(
+                            ctx, var_tensor, var->copy_on_read_mode.load()));
+    auto var_data = var_tensor_flat.data();
+    auto philox = GetPhiloxRandomFromMem(var_data);
+    UpdateMemWithPhiloxRandom(
+        philox, num_batches * 2 * 100 * (samples_per_batch + 3) / 4, var_data);
+    var_guard.Release();
+
+    auto binomial_functor = functor::RandomBinomialFunctor<Device, T, U>();
+    binomial_functor(ctx, ctx->eigen_device<Device>(), num_batches,
+                     samples_per_batch, num_elements, counts_tensor.flat<T>(),
+                     probs_tensor.flat<T>(), philox, samples_tensor->flat<U>());
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(RandomBinomialOp);
+};
+
+}  // namespace
+
+#define REGISTER(RTYPE, TYPE)                                 \
+  REGISTER_KERNEL_BUILDER(Name("StatefulRandomBinomial")      \
+                              .Device(DEVICE_CPU)             \
+                              .HostMemory("resource")         \
+                              .HostMemory("algorithm")        \
+                              .HostMemory("shape")            \
+                              .HostMemory("counts")           \
+                              .HostMemory("probs")            \
+                              .TypeConstraint<RTYPE>("dtype") \
+                              .TypeConstraint<TYPE>("T"),     \
+                          RandomBinomialOp<CPUDevice, TYPE, RTYPE>)
+
+#define REGISTER_ALL(RTYPE)     \
+  REGISTER(RTYPE, Eigen::half); \
+  REGISTER(RTYPE, float);       \
+  REGISTER(RTYPE, double);
+
+REGISTER_ALL(Eigen::half);
+REGISTER_ALL(float);
+REGISTER_ALL(double);
+REGISTER_ALL(int32);
+REGISTER_ALL(int64);
+
+#undef REGISTER
+#undef REGISTER_ALL
+
+}  // end namespace tensorflow
diff --git a/tensorflow/core/kernels/random_binomial_op.h b/tensorflow/core/kernels/random_binomial_op.h
new file mode 100644
index 00000000000..05c489da83a
--- /dev/null
+++ b/tensorflow/core/kernels/random_binomial_op.h
@@ -0,0 +1,61 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_RANDOM_BINOMIAL_OP_H_
+#define TENSORFLOW_CORE_KERNELS_RANDOM_BINOMIAL_OP_H_
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/lib/random/random_distributions.h"
+
+namespace tensorflow {
+
+class OpKernelContext;
+
+namespace functor {
+
+// Sample a binomial random variable, with probs and counts for each batch.
+// Uses binomial inversion and a transformed rejection sampling method as
+// described in
+// https://pdfs.semanticscholar.org/471b/c2726e25bbf8801ef781630a2c13f654268e.pdf.
+// Two different algorithms are employed, depending on the size of
+// counts * probs (or counts * (1 - probs) if probs > 0.5.
+// If counts * probs < 10, we simply sum up Geometric random variables until
+// they exceed count, and the number we used is binomially distributed.
+// In expectation, this will take O(counts * probs) time, and requiring in
+// expectation the same number of random variates.
+// This can be much cheaper than summing bernoulli random variates, as we
+// will always need O(counts) bernoulli random variates (so this requires fewer
+// uniform r.v.s as well as can be faster).
+//
+// If counts * probs > 10, we use a transformed-rejection algorithm based on
+// pairs of uniform random variates due to Hormann.
+// https://pdfs.semanticscholar.org/471b/c2726e25bbf8801ef781630a2c13f654268e.pdf
+// This algorithm has higher acceptance rates for counts * probs large, as the
+// proposal distribution becomes quite tight, requiring approximately two
+// uniform random variates as counts * probs becomes large.
+template <typename Device, typename T, typename U>
+struct RandomBinomialFunctor {
+  void operator()(OpKernelContext* ctx, const Device& d, int64 num_batches,
+                  int64 samples_per_batch, int64 num_elements,
+                  typename TTypes<T>::ConstFlat counts,
+                  typename TTypes<T>::ConstFlat probs,
+                  const random::PhiloxRandom& gen,
+                  typename TTypes<U>::Flat output);
+};
+
+}  // namespace functor
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_RANDOM_BINOMIAL_OP_H_
diff --git a/tensorflow/core/kernels/random_binomial_op_test.cc b/tensorflow/core/kernels/random_binomial_op_test.cc
new file mode 100644
index 00000000000..9f8f47ef853
--- /dev/null
+++ b/tensorflow/core/kernels/random_binomial_op_test.cc
@@ -0,0 +1,107 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+
+static Graph* RandomBinomialGraph(double count, double prob, int num_batches,
+                                  int samples_per_batch) {
+  Graph* g = new Graph(OpRegistry::Global());
+  Tensor shape_t(DT_INT32, TensorShape({2}));
+  shape_t.flat<int32>().setValues({num_batches, samples_per_batch});
+
+  Tensor counts_t(DT_FLOAT, TensorShape({num_batches}));
+  counts_t.flat<float>().setConstant(count);
+  Tensor probs_t(DT_FLOAT, TensorShape({num_batches}));
+  probs_t.flat<float>().setConstant(prob);
+
+  Node* ret;
+  TF_CHECK_OK(NodeBuilder(g->NewName("randombinomial"), "RandomBinomial")
+                  .Input(test::graph::Constant(g, shape_t))
+                  .Input(test::graph::Constant(g, counts_t))
+                  .Input(test::graph::Constant(g, probs_t))
+                  .Attr("dtype", DT_FLOAT)
+                  .Finalize(g, &ret));
+  return g;
+}
+
+static Graph* RandomBinomialInv(int num_batches, int samples_per_batch) {
+  // Because counts * probs < 10, we are guaranteed to use inversion.
+  return RandomBinomialGraph(10., 0.3, num_batches, samples_per_batch);
+}
+
+static Graph* RandomBinomialRej(int num_batches, int samples_per_batch) {
+  // Because counts * probs > 10, we are guaranteed to use rejection.
+  return RandomBinomialGraph(100., 0.3, num_batches, samples_per_batch);
+}
+
+static Graph* RandomBinomialInvComplement(int num_batches,
+                                          int samples_per_batch) {
+  // Because counts * (1 - probs) < 10, we are guaranteed to use inversion.
+  return RandomBinomialGraph(10., 0.8, num_batches, samples_per_batch);
+}
+
+static Graph* RandomBinomialRejComplement(int num_batches,
+                                          int samples_per_batch) {
+  // Because counts * (1 - probs) > 10, we are guaranteed to use inversion.
+  return RandomBinomialGraph(100., 0.2, num_batches, samples_per_batch);
+}
+
+#define BM_RandomBinomialInv(DEVICE, B, S)                           \
+  static void BM_RandomBinomialInv_##DEVICE##_##B##_##S(int iters) { \
+    test::Benchmark(#DEVICE, RandomBinomialInv(B, S)).Run(iters);    \
+    testing::ItemsProcessed(static_cast<int64>(B) * S * iters);      \
+  }                                                                  \
+  BENCHMARK(BM_RandomBinomialInv_##DEVICE##_##B##_##S);
+
+#define BM_RandomBinomialRej(DEVICE, B, S)                           \
+  static void BM_RandomBinomialRej_##DEVICE##_##B##_##S(int iters) { \
+    test::Benchmark(#DEVICE, RandomBinomialRej(B, S)).Run(iters);    \
+    testing::ItemsProcessed(static_cast<int64>(B) * S * iters);      \
+  }                                                                  \
+  BENCHMARK(BM_RandomBinomialRej_##DEVICE##_##B##_##S);
+
+#define BM_RandomBinomialInvComplement(DEVICE, B, S)                           \
+  static void BM_RandomBinomialInvComplement_##DEVICE##_##B##_##S(int iters) { \
+    test::Benchmark(#DEVICE, RandomBinomialInvComplement(B, S)).Run(iters);    \
+    testing::ItemsProcessed(static_cast<int64>(B) * S * iters);                \
+  }                                                                            \
+  BENCHMARK(BM_RandomBinomialInvComplement_##DEVICE##_##B##_##S);
+
+#define BM_RandomBinomialRejComplement(DEVICE, B, S)                           \
+  static void BM_RandomBinomialRejComplement_##DEVICE##_##B##_##S(int iters) { \
+    test::Benchmark(#DEVICE, RandomBinomialRejComplement(B, S)).Run(iters);    \
+    testing::ItemsProcessed(static_cast<int64>(B) * S * iters);                \
+  }                                                                            \
+  BENCHMARK(BM_RandomBinomialRejComplement_##DEVICE##_##B##_##S);
+
+BM_RandomBinomialInv(cpu, 1000, 1000);
+BM_RandomBinomialRej(cpu, 1000, 1000);
+BM_RandomBinomialInvComplement(cpu, 1000, 1000);
+BM_RandomBinomialRejComplement(cpu, 1000, 1000);
+BM_RandomBinomialInv(gpu, 1000, 1000);
+BM_RandomBinomialRej(gpu, 1000, 1000);
+BM_RandomBinomialInvComplement(gpu, 1000, 1000);
+BM_RandomBinomialRejComplement(gpu, 1000, 1000);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/ops/stateful_random_ops.cc b/tensorflow/core/ops/stateful_random_ops.cc
index c351391580c..80e766cd617 100644
--- a/tensorflow/core/ops/stateful_random_ops.cc
+++ b/tensorflow/core/ops/stateful_random_ops.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/shape_inference.h"
 
@@ -82,6 +83,29 @@ REGISTER_OP("NonDeterministicInts")
       return Status::OK();
     });
 
+REGISTER_OP("StatefulRandomBinomial")
+    .Input("resource: resource")
+    .Input("algorithm: int64")
+    .Input("shape: S")
+    .Input("counts: T")
+    .Input("probs: T")
+    .Output("output: dtype")
+    .Attr("S: {int32, int64}")
+    .Attr("T: {half, float, double, int32, int64} = DT_DOUBLE")
+    .Attr("dtype: {half, float, double, int32, int64} = DT_INT64")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      using shape_inference::ShapeHandle;
+
+      ShapeHandle unused;
+      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(3), 1, &unused));
+      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &unused));
+
+      ShapeHandle out;
+      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &out));
+      c->set_output(0, out);
+      return Status::OK();
+    });
+
 // Register the depracated 'StatefulStandardNormal' op. This op is a short-lived
 // version where the 'resource' variable also contains the algorithm tag.
 // It is deprecated in favor of 'StatefulStandardNormalV2'.
diff --git a/tensorflow/python/kernel_tests/random/BUILD b/tensorflow/python/kernel_tests/random/BUILD
index 8452982a447..f6afae97791 100644
--- a/tensorflow/python/kernel_tests/random/BUILD
+++ b/tensorflow/python/kernel_tests/random/BUILD
@@ -155,6 +155,23 @@ cuda_py_test(
     xla_enable_strict_auto_jit = True,
 )
 
+cuda_py_test(
+    name = "random_binomial_test",
+    size = "medium",
+    srcs = ["random_binomial_test.py"],
+    additional_deps = [
+        ":util",
+        "//third_party/py/numpy",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:platform",
+        "//tensorflow/python:stateful_random_ops",
+    ],
+    xla_enable_strict_auto_jit = True,
+)
+
 cuda_py_test(
     name = "random_poisson_test",
     size = "medium",
@@ -169,5 +186,4 @@ cuda_py_test(
         "//tensorflow/python:platform",
         "//tensorflow/python:random_ops",
     ],
-    xla_enable_strict_auto_jit = True,
 )
diff --git a/tensorflow/python/kernel_tests/random/random_binomial_test.py b/tensorflow/python/kernel_tests/random/random_binomial_test.py
new file mode 100644
index 00000000000..7214d7ef3c9
--- /dev/null
+++ b/tensorflow/python/kernel_tests/random/random_binomial_test.py
@@ -0,0 +1,120 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.ops.stateful_random_ops.binomial."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
+from tensorflow.python.kernel_tests.random import util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import stateful_random_ops
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+# All supported dtypes for binomial().
+_SUPPORTED_DTYPES = (dtypes.float16, dtypes.float32, dtypes.float64,
+                     dtypes.int32, dtypes.int64)
+
+
+class RandomBinomialTest(test.TestCase):
+  """This is a large test due to the moments computation taking some time."""
+
+  def _Sampler(self, num, counts, probs, dtype, seed=None):
+
+    def func():
+      rng = stateful_random_ops.Generator(seed=seed).binomial(
+          shape=[10 * num], counts=counts, probs=probs, dtype=dtype)
+      ret = array_ops.reshape(rng, [10, num])
+      ret = self.evaluate(ret)
+      return ret
+
+    return func
+
+  @test_util.run_v2_only
+  def testMoments(self):
+    try:
+      from scipy import stats  # pylint: disable=g-import-not-at-top
+    except ImportError as e:
+      tf_logging.warn("Cannot test moments: %s", e)
+      return
+    # The moments test is a z-value test.  This is the largest z-value
+    # we want to tolerate. Since the z-test approximates a unit normal
+    # distribution, it should almost definitely never exceed 6.
+    z_limit = 6.0
+    for dt in _SUPPORTED_DTYPES:
+      # Test when n * p > 10, and n * p < 10
+      for stride in 0, 4, 10:
+        for counts in (1., 10., 22., 50.):
+          for prob in (0.1, 0.5, 0.8):
+            sampler = self._Sampler(int(1e5), counts, prob, dt, seed=12345)
+            z_scores = util.test_moment_matching(
+                # Use float64 samples.
+                sampler().astype(np.float64),
+                number_moments=6,
+                dist=stats.binom(counts, prob),
+                stride=stride,
+            )
+            self.assertAllLess(z_scores, z_limit)
+
+  @test_util.run_v2_only
+  def testSeed(self):
+    for dt in dtypes.float16, dtypes.float32, dtypes.float64:
+      sx = self._Sampler(1000, counts=10., probs=0.4, dtype=dt, seed=345)
+      sy = self._Sampler(1000, counts=10., probs=0.4, dtype=dt, seed=345)
+      self.assertAllEqual(sx(), sy())
+
+  def testZeroShape(self):
+    rnd = stateful_random_ops.Generator(seed=12345).binomial([0], [], [])
+    self.assertEqual([0], rnd.shape.as_list())
+
+  def testShape(self):
+    rng = stateful_random_ops.Generator(seed=12345)
+    # Scalar parameters.
+    rnd = rng.binomial(shape=[10], counts=np.float32(2.), probs=np.float32(0.5))
+    self.assertEqual([10], rnd.shape.as_list())
+
+    # Vector parameters.
+    rnd = rng.binomial(
+        shape=[10],
+        counts=array_ops.ones([10], dtype=np.float32),
+        probs=0.3 * array_ops.ones([10], dtype=np.float32))
+    self.assertEqual([10], rnd.shape.as_list())
+    rnd = rng.binomial(
+        shape=[2, 5],
+        counts=array_ops.ones([2], dtype=np.float32),
+        probs=0.4 * array_ops.ones([2], dtype=np.float32))
+    self.assertEqual([2, 5], rnd.shape.as_list())
+
+    # Scalar counts, vector probs.
+    rnd = rng.binomial(
+        shape=[10],
+        counts=np.float32(5.),
+        probs=0.8 * array_ops.ones([10], dtype=np.float32))
+    self.assertEqual([10], rnd.shape.as_list())
+
+    # Vector counts, scalar probs.
+    rnd = rng.binomial(
+        shape=[10],
+        counts=array_ops.ones([10], dtype=np.float32),
+        probs=np.float32(0.9))
+    self.assertEqual([10], rnd.shape.as_list())
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/python/ops/stateful_random_ops.py b/tensorflow/python/ops/stateful_random_ops.py
index ca92fe006f2..9f8884224e9 100644
--- a/tensorflow/python/ops/stateful_random_ops.py
+++ b/tensorflow/python/ops/stateful_random_ops.py
@@ -368,6 +368,51 @@ class Generator(tracking.AutoTrackable):
           self.state.handle, self.algorithm, shape=shape,
           dtype=dtype, name=name)
 
+  def binomial(self, shape, counts, probs, dtype=dtypes.int32, name=None):
+    """Outputs random values from a binomial distribution.
+
+    The generated values follow a binomial distribution with specified count and
+    probability of success parameters.
+
+    Example:
+
+    ```python
+    counts = [10., 20.]
+    # Probability of success.
+    probs = [0.8, 0.9]
+
+    rng = tf.random.experimental.Generator(seed=234)
+    binomial_samples = rng.binomial(shape=[2], counts=counts, probs=probs)
+    ```
+
+
+    Args:
+      shape: A 1-D integer Tensor or Python array. The shape of the output
+        tensor.
+      counts: A 0/1-D Tensor or Python value`. The counts of the binomial
+        distribution.
+      probs: A 0/1-D Tensor or Python value`. The probability of success for the
+        binomial distribution.
+      dtype: The type of the output. Default: tf.int32
+      name: A name for the operation (optional).
+
+    Returns:
+      A tensor of the specified shape filled with random binomial values.
+    """
+    dtype = dtypes.as_dtype(dtype)
+    with ops.name_scope(name, "binomial", [shape, counts, probs]) as name:
+      counts = ops.convert_to_tensor(counts, name="counts")
+      probs = ops.convert_to_tensor(probs, name="probs")
+      shape_tensor = _shape_tensor(shape)
+      return gen_stateful_random_ops.stateful_random_binomial(
+          self.state.handle,
+          self.algorithm,
+          shape=shape_tensor,
+          counts=counts,
+          probs=probs,
+          dtype=dtype,
+          name=name)
+
   # TODO(wangpeng): implement other distributions
 
   def _make_int64_keys(self, shape=()):
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.random.experimental.-generator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.random.experimental.-generator.pbtxt
index 012e4a87079..98b3e8220ab 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.random.experimental.-generator.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.random.experimental.-generator.pbtxt
@@ -16,6 +16,10 @@ tf_class {
     name: "__init__"
     argspec: "args=[\'self\', \'copy_from\', \'seed\', \'algorithm\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
   }
+  member_method {
+    name: "binomial"
+    argspec: "args=[\'self\', \'shape\', \'counts\', \'probs\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
+  }
   member_method {
     name: "make_seeds"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], "
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index b3da1e0af23..3902aa5fe25 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -3708,6 +3708,10 @@ tf_module {
     name: "StatefulPartitionedCall"
     argspec: "args=[\'args\', \'Tout\', \'f\', \'config\', \'config_proto\', \'executor_type\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'\', \'None\'], "
   }
+  member_method {
+    name: "StatefulRandomBinomial"
+    argspec: "args=[\'resource\', \'algorithm\', \'shape\', \'counts\', \'probs\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
+  }
   member_method {
     name: "StatefulStandardNormal"
     argspec: "args=[\'resource\', \'shape\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.random.experimental.-generator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.random.experimental.-generator.pbtxt
index 012e4a87079..98b3e8220ab 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.random.experimental.-generator.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.random.experimental.-generator.pbtxt
@@ -16,6 +16,10 @@ tf_class {
     name: "__init__"
     argspec: "args=[\'self\', \'copy_from\', \'seed\', \'algorithm\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
   }
+  member_method {
+    name: "binomial"
+    argspec: "args=[\'self\', \'shape\', \'counts\', \'probs\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
+  }
   member_method {
     name: "make_seeds"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index b3da1e0af23..3902aa5fe25 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -3708,6 +3708,10 @@ tf_module {
     name: "StatefulPartitionedCall"
     argspec: "args=[\'args\', \'Tout\', \'f\', \'config\', \'config_proto\', \'executor_type\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'\', \'None\'], "
   }
+  member_method {
+    name: "StatefulRandomBinomial"
+    argspec: "args=[\'resource\', \'algorithm\', \'shape\', \'counts\', \'probs\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
+  }
   member_method {
     name: "StatefulStandardNormal"
     argspec: "args=[\'resource\', \'shape\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "