From d8dc9415b0c63a8024ac69142f9e40c0b4c58737 Mon Sep 17 00:00:00 2001
From: Brian Patton <bjp@google.com>
Date: Fri, 7 Feb 2020 09:58:08 -0800
Subject: [PATCH] Adds a `tf.random.stateless_poisson` sampler for CPU.

PiperOrigin-RevId: 293834809
Change-Id: I719e218b43f8aecbd74d1472f1291748a61979b8
---
 .../api_def_StatelessRandomPoisson.pbtxt      | 41 ++++++++++++
 tensorflow/core/kernels/BUILD                 |  3 +
 tensorflow/core/kernels/random_poisson_op.cc  | 18 ++----
 tensorflow/core/kernels/random_poisson_op.h   |  9 ++-
 .../core/kernels/stateless_random_ops.cc      | 58 ++++++++++++++++-
 tensorflow/core/ops/stateless_random_ops.cc   | 11 ++++
 .../eager/pywrap_gradient_exclusions.cc       |  2 +
 .../random/stateless_random_ops_test.py       | 30 +++++++--
 tensorflow/python/ops/stateless_random_ops.py | 62 +++++++++++++++++++
 .../api/golden/v1/tensorflow.random.pbtxt     |  4 ++
 .../api/golden/v1/tensorflow.raw_ops.pbtxt    |  4 ++
 .../api/golden/v2/tensorflow.random.pbtxt     |  4 ++
 .../api/golden/v2/tensorflow.raw_ops.pbtxt    |  4 ++
 13 files changed, 231 insertions(+), 19 deletions(-)
 create mode 100644 tensorflow/core/api_def/base_api/api_def_StatelessRandomPoisson.pbtxt

diff --git a/tensorflow/core/api_def/base_api/api_def_StatelessRandomPoisson.pbtxt b/tensorflow/core/api_def/base_api/api_def_StatelessRandomPoisson.pbtxt
new file mode 100644
index 00000000000..60228cd5a58
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_StatelessRandomPoisson.pbtxt
@@ -0,0 +1,41 @@
+op {
+  graph_op_name: "StatelessRandomPoisson"
+  visibility: HIDDEN
+  in_arg {
+    name: "shape"
+    description: <<END
+The shape of the output tensor.
+END
+  }
+  in_arg {
+    name: "seed"
+    description: <<END
+2 seeds (shape [2]).
+END
+  }
+  in_arg {
+    name: "lam"
+    description: <<END
+The rate of the Poisson distribution. Shape must match the rightmost dimensions
+of `shape`.
+END
+  }
+  out_arg {
+    name: "output"
+    description: <<END
+Random values with specified shape.
+END
+  }
+  attr {
+    name: "dtype"
+    description: <<END
+The type of the output.
+END
+  }
+  summary: "Outputs deterministic pseudorandom random numbers from a Poisson distribution."
+  description: <<END
+Outputs random values from a Poisson distribution.
+
+The outputs are a deterministic function of `shape`, `seed`, and `lam`.
+END
+}
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index a3d5d5d9435..be5b215eaa9 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -5133,6 +5133,7 @@ tf_kernel_library(
     deps = [
         ":bounds_check",
         ":random_op",
+        ":random_poisson_op",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
     ],
@@ -6465,6 +6466,7 @@ filegroup(
         "pad_op.h",
         "pooling_ops_3d.h",
         "random_op.h",
+        "random_poisson_op.h",
         "reduction_ops.h",
         "reduction_ops_common.h",
         "relu_op.h",
@@ -6658,6 +6660,7 @@ filegroup(
         "queue_ops.cc",
         "random_op.cc",
         "random_op_cpu.h",
+        "random_poisson_op.cc",
         "reduction_ops_all.cc",
         "reduction_ops_any.cc",
         "reduction_ops_common.cc",
diff --git a/tensorflow/core/kernels/random_poisson_op.cc b/tensorflow/core/kernels/random_poisson_op.cc
index 7069f896f07..aa9a0bfe214 100644
--- a/tensorflow/core/kernels/random_poisson_op.cc
+++ b/tensorflow/core/kernels/random_poisson_op.cc
@@ -68,13 +68,6 @@ struct PoissonComputeType {
 
 namespace functor {
 
-template <typename Device, typename T, typename U>
-struct PoissonFunctor {
-  void operator()(OpKernelContext* ctx, const Device& d, const T* rate_flat,
-                  int num_rate, int num_samples,
-                  const random::PhiloxRandom& rng, U* samples_flat);
-};
-
 template <typename T, typename U>
 struct PoissonFunctor<CPUDevice, T, U> {
   void operator()(OpKernelContext* ctx, const CPUDevice& d, const T* rate_flat,
@@ -329,11 +322,12 @@ TF_CALL_half(REGISTER);
 TF_CALL_float(REGISTER);
 TF_CALL_double(REGISTER);
 
-#define REGISTER_V2(RTYPE, OTYPE)                              \
-  REGISTER_KERNEL_BUILDER(Name("RandomPoissonV2")              \
-                              .Device(DEVICE_CPU)              \
-                              .TypeConstraint<RTYPE>("R")      \
-                              .TypeConstraint<OTYPE>("dtype"), \
+#define REGISTER_V2(RTYPE, OTYPE)                                   \
+  template struct functor::PoissonFunctor<CPUDevice, RTYPE, OTYPE>; \
+  REGISTER_KERNEL_BUILDER(Name("RandomPoissonV2")                   \
+                              .Device(DEVICE_CPU)                   \
+                              .TypeConstraint<RTYPE>("R")           \
+                              .TypeConstraint<OTYPE>("dtype"),      \
                           RandomPoissonOp<RTYPE, OTYPE>);
 
 #define REGISTER_ALL(RTYPE)        \
diff --git a/tensorflow/core/kernels/random_poisson_op.h b/tensorflow/core/kernels/random_poisson_op.h
index 62ae01c16c4..d7a73cc8e22 100644
--- a/tensorflow/core/kernels/random_poisson_op.h
+++ b/tensorflow/core/kernels/random_poisson_op.h
@@ -16,13 +16,20 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_KERNELS_RANDOM_POISSON_OP_H_
 #define TENSORFLOW_CORE_KERNELS_RANDOM_POISSON_OP_H_
 
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+
 namespace tensorflow {
 
 namespace functor {
 
 // Generic helper functor for the Random Poisson Op.
 template <typename Device, typename T /* rate */, typename U /* output */>
-struct PoissonFunctor;
+struct PoissonFunctor {
+  void operator()(OpKernelContext* ctx, const Device& d, const T* rate_flat,
+                  int num_rate, int num_samples,
+                  const random::PhiloxRandom& rng, U* samples_flat);
+};
 
 }  // namespace functor
 
diff --git a/tensorflow/core/kernels/stateless_random_ops.cc b/tensorflow/core/kernels/stateless_random_ops.cc
index 94c550126ff..dd451dbc2d5 100644
--- a/tensorflow/core/kernels/stateless_random_ops.cc
+++ b/tensorflow/core/kernels/stateless_random_ops.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_util.h"
 #include "tensorflow/core/kernels/random_op.h"
+#include "tensorflow/core/kernels/random_poisson_op.h"
 #include "tensorflow/core/lib/random/random_distributions.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/util/work_sharder.h"
@@ -162,6 +163,35 @@ class StatelessRandomUniformIntOp : public StatelessRandomOpBase {
   }
 };
 
+// Samples from one or more Poisson distributions.
+template <typename T, typename U>
+class StatelessRandomPoissonOp : public StatelessRandomOpBase {
+ public:
+  using StatelessRandomOpBase::StatelessRandomOpBase;
+
+  void Fill(OpKernelContext* ctx, random::PhiloxRandom random,
+            Tensor* output) override {
+    const Tensor& rate_t = ctx->input(2);
+
+    TensorShape samples_shape = output->shape();
+    OP_REQUIRES(ctx, TensorShapeUtils::EndsWith(samples_shape, rate_t.shape()),
+                errors::InvalidArgument(
+                    "Shape passed in must end with broadcasted shape."));
+
+    const int64 num_rate = rate_t.NumElements();
+    const int64 samples_per_rate = samples_shape.num_elements() / num_rate;
+    const auto rate_flat = rate_t.flat<T>().data();
+    auto samples_flat = output->flat<U>().data();
+
+    functor::PoissonFunctor<CPUDevice, T, U>()(
+        ctx, ctx->eigen_device<CPUDevice>(), rate_flat, num_rate,
+        samples_per_rate, random, samples_flat);
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomPoissonOp);
+};
+
 template <typename Device, typename T>
 class StatelessRandomGammaOp : public StatelessRandomOpBase {
  public:
@@ -354,7 +384,7 @@ class StatelessRandomGammaOp : public StatelessRandomOpBase {
                               .HostMemory("minval")           \
                               .HostMemory("maxval")           \
                               .TypeConstraint<TYPE>("dtype"), \
-                          StatelessRandomUniformIntOp<DEVICE##Device, TYPE>);
+                          StatelessRandomUniformIntOp<DEVICE##Device, TYPE>)
 
 #define REGISTER_CPU(TYPE) REGISTER(CPU, TYPE)
 #define REGISTER_GPU(TYPE) REGISTER(GPU, TYPE)
@@ -368,6 +398,32 @@ TF_CALL_double(REGISTER_CPU);
 TF_CALL_int32(REGISTER_INT_CPU);
 TF_CALL_int64(REGISTER_INT_CPU);
 
+#define REGISTER_POISSON(RATE_TYPE, OUT_TYPE)                     \
+  REGISTER_KERNEL_BUILDER(Name("StatelessRandomPoisson")          \
+                              .Device(DEVICE_CPU)                 \
+                              .HostMemory("shape")                \
+                              .HostMemory("seed")                 \
+                              .HostMemory("lam")                  \
+                              .TypeConstraint<RATE_TYPE>("Rtype") \
+                              .TypeConstraint<OUT_TYPE>("dtype"), \
+                          StatelessRandomPoissonOp<RATE_TYPE, OUT_TYPE>)
+
+#define REGISTER_ALL_POISSON(RATE_TYPE)     \
+  REGISTER_POISSON(RATE_TYPE, Eigen::half); \
+  REGISTER_POISSON(RATE_TYPE, float);       \
+  REGISTER_POISSON(RATE_TYPE, double);      \
+  REGISTER_POISSON(RATE_TYPE, int32);       \
+  REGISTER_POISSON(RATE_TYPE, int64)
+
+TF_CALL_half(REGISTER_ALL_POISSON);
+TF_CALL_float(REGISTER_ALL_POISSON);
+TF_CALL_double(REGISTER_ALL_POISSON);
+TF_CALL_int32(REGISTER_ALL_POISSON);
+TF_CALL_int64(REGISTER_ALL_POISSON);
+
+#undef REGISTER_ALL_POISSON
+#undef REGISTER_POISSON
+
 #define REGISTER_GAMMA(TYPE)                                  \
   REGISTER_KERNEL_BUILDER(Name("StatelessRandomGammaV2")      \
                               .Device(DEVICE_CPU)             \
diff --git a/tensorflow/core/ops/stateless_random_ops.cc b/tensorflow/core/ops/stateless_random_ops.cc
index db3da7020ed..27d0b71cf44 100644
--- a/tensorflow/core/ops/stateless_random_ops.cc
+++ b/tensorflow/core/ops/stateless_random_ops.cc
@@ -105,6 +105,17 @@ REGISTER_OP("StatelessRandomBinomial")
     .Attr("dtype: {half, float, double, int32, int64} = DT_INT64")
     .SetShapeFn(StatelessShape);
 
+REGISTER_OP("StatelessRandomPoisson")
+    .Input("shape: T")
+    .Input("seed: Tseed")
+    .Input("lam: Rtype")
+    .Output("output: dtype")
+    .Attr("Rtype: {float16, float32, float64, int32, int64}")
+    .Attr("dtype: {float16, float32, float64, int32, int64}")
+    .Attr("T: {int32, int64}")
+    .Attr("Tseed: {int32, int64} = DT_INT64")
+    .SetShapeFn(StatelessShape);
+
 REGISTER_OP("StatelessRandomGammaV2")
     .Input("shape: T")
     .Input("seed: Tseed")
diff --git a/tensorflow/python/eager/pywrap_gradient_exclusions.cc b/tensorflow/python/eager/pywrap_gradient_exclusions.cc
index 51b37e3ec14..5de648bad8c 100644
--- a/tensorflow/python/eager/pywrap_gradient_exclusions.cc
+++ b/tensorflow/python/eager/pywrap_gradient_exclusions.cc
@@ -310,6 +310,7 @@ bool OpGradientDoesntRequireInputIndices(
           {"StatelessRandomBinomial", {true, {}}},
           {"StatelessRandomGammaV2", {false, {1}}},
           {"StatelessRandomNormal", {true, {}}},
+          {"StatelessRandomPoisson", {true, {}}},
           {"StatelessRandomUniform", {true, {}}},
           {"StatelessRandomUniformInt", {true, {}}},
           {"StatelessTruncatedNormal", {true, {}}},
@@ -765,6 +766,7 @@ bool OpGradientDoesntRequireOutputIndices(
           {"StatelessMultinomial", {true, {}}},
           {"StatelessRandomBinomial", {true, {}}},
           {"StatelessRandomNormal", {true, {}}},
+          {"StatelessRandomPoisson", {true, {}}},
           {"StatelessRandomUniform", {true, {}}},
           {"StatelessRandomUniformInt", {true, {}}},
           {"StatelessTruncatedNormal", {true, {}}},
diff --git a/tensorflow/python/kernel_tests/random/stateless_random_ops_test.py b/tensorflow/python/kernel_tests/random/stateless_random_ops_test.py
index f94d0b6c8f7..0cd32a5b046 100644
--- a/tensorflow/python/kernel_tests/random/stateless_random_ops_test.py
+++ b/tensorflow/python/kernel_tests/random/stateless_random_ops_test.py
@@ -133,11 +133,23 @@ class StatelessOpsTest(test.TestCase):
     for dtype in np.float16, np.float32, np.float64:
       for alpha in ([[.5, 1., 2.]], [[0.5, 0.5], [0.8, 0.2], [0.25, 0.75]]):
         kwds = dict(alpha=constant_op.constant(alpha, dtype=dtype), dtype=dtype)
-        yield (functools.partial(
-            stateless.stateless_random_gamma,
-            shape=(10,) + tuple(np.shape(alpha)),
-            **kwds),
-               functools.partial(random_ops.random_gamma, shape=(10,), **kwds))
+        yield (
+            functools.partial(stateless.stateless_random_gamma,
+                              shape=(10,) + tuple(np.shape(alpha)), **kwds),
+            functools.partial(random_ops.random_gamma, shape=(10,), **kwds))
+
+  def _poisson_cases(self):
+    for lam_dtype in np.float16, np.float32, np.float64, np.int32, np.int64:
+      for out_dtype in np.float16, np.float32, np.float64, np.int32, np.int64:
+        for lam in ([[5.5, 1., 2.]], [[7.5, 10.5], [3.8, 8.2], [1.25, 9.75]]):
+          kwds = dict(
+              lam=constant_op.constant(lam_dtype(lam), dtype=lam_dtype),
+              dtype=out_dtype)
+          yield (
+              functools.partial(stateless.stateless_random_poisson,
+                                shape=(10,) + tuple(np.shape(lam)),
+                                **kwds),
+              functools.partial(random_ops.random_poisson, shape=(10,), **kwds))
 
   @test_util.run_deprecated_v1
   def testMatchFloat(self):
@@ -155,6 +167,10 @@ class StatelessOpsTest(test.TestCase):
   def testMatchGamma(self):
     self._test_match(self._gamma_cases())
 
+  @test_util.run_deprecated_v1
+  def testMatchPoisson(self):
+    self._test_match(self._poisson_cases())
+
   @test_util.run_deprecated_v1
   def testDeterminismFloat(self):
     self._test_determinism(
@@ -173,6 +189,10 @@ class StatelessOpsTest(test.TestCase):
   def testDeterminismGamma(self):
     self._test_determinism(self._gamma_cases())
 
+  @test_util.run_deprecated_v1
+  def testDeterminismPoisson(self):
+    self._test_determinism(self._poisson_cases())
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/ops/stateless_random_ops.py b/tensorflow/python/ops/stateless_random_ops.py
index 9e051754894..94c80ea002b 100644
--- a/tensorflow/python/ops/stateless_random_ops.py
+++ b/tensorflow/python/ops/stateless_random_ops.py
@@ -32,6 +32,7 @@ from tensorflow.python.util.tf_export import tf_export
 ops.NotDifferentiable("StatelessMultinomial")
 ops.NotDifferentiable("StatelessRandomBinomial")
 ops.NotDifferentiable("StatelessRandomNormal")
+ops.NotDifferentiable("StatelessRandomPoisson")
 ops.NotDifferentiable("StatelessRandomUniform")
 ops.NotDifferentiable("StatelessRandomUniformInt")
 ops.NotDifferentiable("StatelessTruncatedNormal")
@@ -271,6 +272,67 @@ def stateless_random_gamma(shape,
     return result
 
 
+@tf_export("random.stateless_poisson")
+def stateless_random_poisson(shape,
+                             seed,
+                             lam,
+                             dtype=dtypes.int32,
+                             name=None):
+  """Outputs deterministic pseudorandom values from a Poisson distribution.
+
+  The generated values follow a Poisson distribution with specified rate
+  parameter.
+
+  This is a stateless version of `tf.random.poisson`: if run twice with the same
+  seeds, it will produce the same pseudorandom numbers. The output is consistent
+  across multiple runs on the same hardware (and between CPU and GPU), but may
+  change between versions of TensorFlow or on non-CPU/GPU hardware.
+
+  A slight difference exists in the interpretation of the `shape` parameter
+  between `stateless_poisson` and `poisson`: in `poisson`, the `shape` is always
+  prepended to the shape of `rate`; whereas in `stateless_poisson` the shape of
+  `rate` must match the trailing dimensions of `shape`.
+
+  Example:
+
+  ```python
+  samples = tf.random.stateless_poisson([10, 2], seed=[12, 34], lam=[5, 15])
+  # samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents
+  # the samples drawn from each distribution
+
+  samples = tf.random.stateless_poisson([7, 5, 2], seed=[12, 34], lam=[5, 15])
+  # samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1]
+  # represents the 7x5 samples drawn from each of the two distributions
+
+  rate = tf.constant([[1.], [3.], [5.]])
+  samples = tf.random.stateless_poisson([30, 3, 1], seed=[12, 34], lam=rate)
+  # samples has shape [30, 3, 1], with 30 samples each of 3x1 distributions.
+  ```
+
+  Args:
+    shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
+    seed: A shape [2] integer Tensor of seeds to the random number generator.
+    lam: Tensor. The rate parameter "lambda" of the Poisson distribution. Shape
+      must match the rightmost dimensions of `shape`.
+    dtype: Dtype of the samples (int or float dtypes are permissible, as samples
+      are discrete). Default: int32.
+    name: A name for the operation (optional).
+
+  Returns:
+    samples: A Tensor of the specified shape filled with random Poisson values.
+      For each i, each `samples[..., i]` is an independent draw from the Poisson
+      distribution with rate `lam[i]`.
+
+  """
+  with ops.name_scope(name, "stateless_random_poisson",
+                      [shape, seed, lam]) as name:
+    shape = tensor_util.shape_tensor(shape)
+    result = gen_stateless_random_ops.stateless_random_poisson(
+        shape, seed=seed, lam=lam, dtype=dtype)
+    tensor_util.maybe_set_static_shape(result, shape)
+    return result
+
+
 @tf_export("random.stateless_normal")
 def stateless_random_normal(shape,
                             seed,
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.random.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.random.pbtxt
index f39f053d2ab..9c6fa7154a3 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.random.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.random.pbtxt
@@ -92,6 +92,10 @@ tf_module {
     name: "stateless_normal"
     argspec: "args=[\'shape\', \'seed\', \'mean\', \'stddev\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \"<dtype: \'float32\'>\", \'None\'], "
   }
+  member_method {
+    name: "stateless_poisson"
+    argspec: "args=[\'shape\', \'seed\', \'lam\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
+  }
   member_method {
     name: "stateless_truncated_normal"
     argspec: "args=[\'shape\', \'seed\', \'mean\', \'stddev\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \"<dtype: \'float32\'>\", \'None\'], "
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index 6c5c6c1c311..c292a1dbf17 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -4292,6 +4292,10 @@ tf_module {
     name: "StatelessRandomNormal"
     argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
   }
+  member_method {
+    name: "StatelessRandomPoisson"
+    argspec: "args=[\'shape\', \'seed\', \'lam\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "StatelessRandomUniform"
     argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.random.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.random.pbtxt
index 37b3a3129e1..e3a11ee4610 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.random.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.random.pbtxt
@@ -80,6 +80,10 @@ tf_module {
     name: "stateless_normal"
     argspec: "args=[\'shape\', \'seed\', \'mean\', \'stddev\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \"<dtype: \'float32\'>\", \'None\'], "
   }
+  member_method {
+    name: "stateless_poisson"
+    argspec: "args=[\'shape\', \'seed\', \'lam\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
+  }
   member_method {
     name: "stateless_truncated_normal"
     argspec: "args=[\'shape\', \'seed\', \'mean\', \'stddev\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \"<dtype: \'float32\'>\", \'None\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index 6c5c6c1c311..c292a1dbf17 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -4292,6 +4292,10 @@ tf_module {
     name: "StatelessRandomNormal"
     argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
   }
+  member_method {
+    name: "StatelessRandomPoisson"
+    argspec: "args=[\'shape\', \'seed\', \'lam\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "StatelessRandomUniform"
     argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "