From e88f6739a7a93c009dd8dbdc1d65cc2548553185 Mon Sep 17 00:00:00 2001
From: Alexey Radul <axch@google.com>
Date: Wed, 7 Oct 2020 15:33:40 -0700
Subject: [PATCH] random_poisson(rate=+inf) should return +inf, not nan.

Leaving the status quo for integer output dtypes, which is determined by how non-finite floats cast to said integers.

PiperOrigin-RevId: 335966603
Change-Id: I5bfb53a17de32e02a10595d0393f037f982b5194
---
 tensorflow/core/kernels/random_poisson_op.cc           | 10 ++++++++++
 .../python/kernel_tests/random/random_poisson_test.py  |  5 +++++
 2 files changed, 15 insertions(+)

diff --git a/tensorflow/core/kernels/random_poisson_op.cc b/tensorflow/core/kernels/random_poisson_op.cc
index dcb7d6b0f0e..2c898d95749 100644
--- a/tensorflow/core/kernels/random_poisson_op.cc
+++ b/tensorflow/core/kernels/random_poisson_op.cc
@@ -150,6 +150,16 @@ struct PoissonFunctor<CPUDevice, T, U> {
           }
           continue;
         }
+        if (Eigen::numext::isinf(rate) && rate > CT(0)) {
+          // Fill the rest of the samples for the current rate value.
+          for (int64 sample_idx = output_idx % num_samples;
+               sample_idx < num_samples && output_idx < limit_output;
+               sample_idx++, output_idx++) {
+            U k = Eigen::NumTraits<U>::infinity();
+            samples_rate_output[sample_idx * num_rate] = k;
+          }
+          continue;
+        }
         // Transformed rejection due to Hormann.
         //
         // Given a CDF F(x), and G(x), a dominating distribution chosen such
diff --git a/tensorflow/python/kernel_tests/random/random_poisson_test.py b/tensorflow/python/kernel_tests/random/random_poisson_test.py
index 51dd4cb47ca..eafa1d9382c 100644
--- a/tensorflow/python/kernel_tests/random/random_poisson_test.py
+++ b/tensorflow/python/kernel_tests/random/random_poisson_test.py
@@ -171,6 +171,11 @@ class RandomPoissonTest(test.TestCase):
               constant_op.constant([1], dtype=lam_dt), [10],
               dtype=out_dt).eval()
 
+  @test_util.run_deprecated_v1
+  def testInfRate(self):
+    sample = random_ops.random_poisson(shape=[2], lam=np.inf)
+    self.assertAllEqual([np.inf, np.inf], self.evaluate(sample))
+
 
 if __name__ == "__main__":
   test.main()