From 8b07605e45f55c942d1436116fd5b0cc83a29e1d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Tue, 7 Feb 2017 18:08:21 -0800 Subject: [PATCH] Add tf.random_poisson(shape, lam) to tf core. Fixes #6798 Change: 146861107 --- .../python/kernel_tests/poisson_test.py | 50 +++ .../distributions/python/ops/poisson.py | 6 +- tensorflow/core/BUILD | 1 + tensorflow/core/graph/testlib.cc | 10 + tensorflow/core/graph/testlib.h | 4 + tensorflow/core/kernels/BUILD | 27 ++ tensorflow/core/kernels/random_op.cc | 2 +- tensorflow/core/kernels/random_poisson_op.cc | 357 ++++++++++++++++++ tensorflow/core/kernels/random_poisson_op.h | 31 ++ .../core/kernels/random_poisson_op_test.cc | 82 ++++ tensorflow/core/ops/random_ops.cc | 44 +++ tensorflow/core/ops/random_ops_test.cc | 15 + tensorflow/python/framework/constant_op.py | 1 + tensorflow/python/kernel_tests/BUILD | 15 + .../kernel_tests/random_poisson_test.py | 178 +++++++++ tensorflow/python/ops/hidden_ops.txt | 1 + tensorflow/python/ops/random_ops.py | 104 +++-- 17 files changed, 886 insertions(+), 42 deletions(-) create mode 100644 tensorflow/core/kernels/random_poisson_op.cc create mode 100644 tensorflow/core/kernels/random_poisson_op.h create mode 100644 tensorflow/core/kernels/random_poisson_op_test.cc create mode 100644 tensorflow/python/kernel_tests/random_poisson_test.py diff --git a/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py b/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py index 0adaf7d816d..e1644d548da 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py @@ -170,6 +170,56 @@ class PoissonTest(test.TestCase): self.assertEqual((6,), poisson.mode().get_shape()) self.assertAllClose(lam_v, poisson.mode().eval()) + def testPoissonSample(self): + with self.test_session(): + lam_v = 4.0 + lam = constant_op.constant(lam_v) + # Choosing `n >= (k/rtol)**2, roughly ensures our sample mean should be + # within `k` std. deviations of actual up to rtol precision. + n = int(100e3) + poisson = poisson_lib.Poisson(rate=lam) + samples = poisson.sample(n, seed=123456) + sample_values = samples.eval() + self.assertEqual(samples.get_shape(), (n,)) + self.assertEqual(sample_values.shape, (n,)) + self.assertAllClose( + sample_values.mean(), stats.poisson.mean(lam_v), rtol=.01) + self.assertAllClose( + sample_values.var(), stats.poisson.var(lam_v), rtol=.01) + + def testPoissonSampleMultidimensionalMean(self): + with self.test_session(): + lam_v = np.array([np.arange(1, 51, dtype=np.float32)]) # 1 x 50 + poisson = poisson_lib.Poisson(rate=lam_v) + # Choosing `n >= (k/rtol)**2, roughly ensures our sample mean should be + # within `k` std. deviations of actual up to rtol precision. + n = int(100e3) + samples = poisson.sample(n, seed=123456) + sample_values = samples.eval() + self.assertEqual(samples.get_shape(), (n, 1, 50)) + self.assertEqual(sample_values.shape, (n, 1, 50)) + self.assertAllClose( + sample_values.mean(axis=0), + stats.poisson.mean(lam_v), + rtol=.01, + atol=0) + + def testPoissonSampleMultidimensionalVariance(self): + with self.test_session(): + lam_v = np.array([np.arange(5, 15, dtype=np.float32)]) # 1 x 10 + poisson = poisson_lib.Poisson(rate=lam_v) + # Choosing `n >= 2 * lam * (k/rtol)**2, roughly ensures our sample + # variance should be within `k` std. deviations of actual up to rtol + # precision. + n = int(300e3) + samples = poisson.sample(n, seed=123456) + sample_values = samples.eval() + self.assertEqual(samples.get_shape(), (n, 1, 10)) + self.assertEqual(sample_values.shape, (n, 1, 10)) + + self.assertAllClose( + sample_values.var(axis=0), stats.poisson.var(lam_v), rtol=.03, atol=0) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/ops/poisson.py b/tensorflow/contrib/distributions/python/ops/poisson.py index 799796ace0c..e1ddc9a0e18 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson.py +++ b/tensorflow/contrib/distributions/python/ops/poisson.py @@ -28,7 +28,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops - +from tensorflow.python.ops import random_ops __all__ = [ "Poisson", @@ -148,6 +148,10 @@ class Poisson(distribution.Distribution): def _mode(self): return math_ops.floor(self.rate) + def _sample_n(self, n, seed=None): + return random_ops.random_poisson( + self.rate, [n], dtype=self.dtype, seed=seed) + def _assert_valid_sample(self, x, check_integer=True): if not self.validate_args: return x diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 2c75353ffed..4dd9bffe80a 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -635,6 +635,7 @@ cc_library( "//tensorflow/core/kernels:parameterized_truncated_normal_op", "//tensorflow/core/kernels:parsing", "//tensorflow/core/kernels:random_ops", + "//tensorflow/core/kernels:random_poisson_op", "//tensorflow/core/kernels:remote_fused_graph_ops", "//tensorflow/core/kernels:required", "//tensorflow/core/kernels:resource_variable_ops", diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc index ef4dd047875..f0ab5520f11 100644 --- a/tensorflow/core/graph/testlib.cc +++ b/tensorflow/core/graph/testlib.cc @@ -219,6 +219,16 @@ Node* RandomGamma(Graph* g, Node* shape, Node* alpha) { return ret; } +Node* RandomPoisson(Graph* g, Node* shape, Node* lam) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "RandomPoisson") + .Input(shape) + .Input(lam) + .Attr("seed", 0) + .Finalize(g, &ret)); + return ret; +} + Node* Unary(Graph* g, const string& func, Node* input, int index) { Node* ret; TF_CHECK_OK(NodeBuilder(g->NewName("n"), func, g->op_registry()) diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h index 7a23b20c2c8..d508f65ada5 100644 --- a/tensorflow/core/graph/testlib.h +++ b/tensorflow/core/graph/testlib.h @@ -113,6 +113,10 @@ Node* RandomGaussian(Graph* g, Node* input, DataType dtype); // Output dtype determined by alpha. Node* RandomGamma(Graph* g, Node* shape, Node* alpha); +// Generates random poisson distribution with the given shape and lam[s]. +// Output dtype determined by lam. +Node* RandomPoisson(Graph* g, Node* shape, Node* lam); + // Generates random parameters from the truncated standard normal distribution // of the nput shape Node* TruncatedNormal(Graph* g, Node* input, DataType dtype); diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index ab5923b0e7b..06a11d31ab0 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -3361,6 +3361,33 @@ tf_cuda_cc_test( ], ) +tf_kernel_library( + name = "random_poisson_op", + prefix = "random_poisson_op", + deps = [ + ":random_ops", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:random_ops_op_lib", + ], +) + +tf_cuda_cc_test( + name = "random_poisson_op_test", + size = "small", + srcs = ["random_poisson_op_test.cc"], + deps = [ + ":ops_util", + ":random_poisson_op", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + tf_kernel_library( name = "word2vec_kernels", prefix = "word2vec_kernels", diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc index 0a1de111627..f3c7e0f26b1 100644 --- a/tensorflow/core/kernels/random_op.cc +++ b/tensorflow/core/kernels/random_op.cc @@ -541,7 +541,7 @@ TF_CALL_int64(REGISTER_INT); PhiloxRandomOp< \ GPUDevice, \ random::TruncatedNormalDistribution< \ - random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >) + random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >); #define REGISTER_INT(IntType) \ REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \ diff --git a/tensorflow/core/kernels/random_poisson_op.cc b/tensorflow/core/kernels/random_poisson_op.cc new file mode 100644 index 00000000000..553a4a7f939 --- /dev/null +++ b/tensorflow/core/kernels/random_poisson_op.cc @@ -0,0 +1,357 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/random_ops.cc. + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/random_poisson_op.h" + +#include <algorithm> +#include <cmath> +#include <memory> + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/random/random_distributions.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/util/guarded_philox_random.h" +#include "tensorflow/core/util/work_sharder.h" + +#if EIGEN_COMP_GNUC && __cplusplus > 199711L +#define DISABLE_FLOAT_EQUALITY_WARNING \ + _Pragma("GCC diagnostic push") \ + _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"") +#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop") +#else +#define DISABLE_FLOAT_EQUALITY_WARNING +#define ENABLE_FLOAT_EQUALITY_WARNING +#endif + +#define UNIFORM(X) \ + if (uniform_remaining == 0) { \ + uniform_remaining = Uniform::kResultElementCount; \ + uniform_result = uniform(&gen); \ + } \ + uniform_remaining--; \ + CT X = uniform_result[uniform_remaining] + +namespace tensorflow { +namespace { + +static constexpr int kReservedSamplesPerOutput = 256; + +typedef Eigen::ThreadPoolDevice CPUDevice; + +// We will compute half-precision Poisson samples with float precision +// intermediate calculations. +template <typename T> +struct PoissonComputeType { + typedef T ComputeType; +}; + +template <> +struct PoissonComputeType<Eigen::half> { + typedef float ComputeType; +}; + +} // namespace + +namespace functor { + +template <typename Device, typename T> +struct PoissonFunctor { + void operator()(OpKernelContext* ctx, const Device& d, const T* rate_flat, + int num_rate, int num_samples, + const random::PhiloxRandom& rng, T* samples_flat); +}; + +template <typename T> +struct PoissonFunctor<CPUDevice, T> { + void operator()(OpKernelContext* ctx, const CPUDevice& d, const T* rate_flat, + int num_rate, int num_samples, + const random::PhiloxRandom& rng, T* samples_flat) { + // Two different algorithms are employed, depending on the size of + // rate. + // If rate < 10, we use an algorithm attributed to Knuth: + // Seminumerical Algorithms. Art of Computer Programming, Volume 2. + // + // This algorithm runs in O(rate) time, and will require O(rate) + // uniform + // variates. + // + // If rate >= 10 we use a transformation-rejection algorithm from + // pairs + // of uniform random variables due to Hormann. + // http://www.sciencedirect.com/science/article/pii/0167668793909974 + // + // The algorithm has an acceptance rate of ~89% for the smallest rate + // (~10), + // and higher accept rates for higher rate, so runtime is + // O(NumRate * NumSamples * k) with k ~ 1 / 0.89. + // + // We partition work first across rates then across + // samples-per-rate to + // avoid a couple flops which can be done on a per-rate basis. + + typedef random::UniformDistribution<random::PhiloxRandom, CT> Uniform; + + auto DoWork = [num_samples, num_rate, &rng, samples_flat, rate_flat]( + int start_output, int limit_output) { + // Capturing "rng" by value would only make a copy for the _shared_ + // lambda. Since we want to let each worker have its own copy, we pass + // "rng" by reference and explicitly do a copy assignment. + + Uniform uniform; + typename Uniform::ResultType uniform_result; + for (int64 output_idx = start_output; output_idx < limit_output; + /* output_idx incremented within inner loop below */) { + const int64 rate_idx = output_idx / num_samples; + + // Several calculations can be done on a per-rate basis. + const CT rate = CT(rate_flat[rate_idx]); + + auto samples_rate_output = samples_flat + rate_idx; + + if (rate < CT(10)) { + // Knuth's algorithm for generating Poisson random variates. + // Given a Poisson process, the time between events is exponentially + // distributed. If we have a Poisson process with rate lambda, then, + // the time between events is distributed Exp(lambda). If X ~ + // Uniform(0, 1), then Y ~ Exp(lambda), where Y = -log(X) / lambda. + // Thus to simulate a Poisson draw, we can draw X_i ~ Exp(lambda), + // and N ~ Poisson(lambda), where N is the least number such that + // \sum_i^N X_i > 1. + const CT exp_neg_rate = Eigen::numext::exp(-rate); + + // Compute the rest of the samples for the current rate value. + for (int64 sample_idx = output_idx % num_samples; + sample_idx < num_samples && output_idx < limit_output; + sample_idx++, output_idx++) { + random::PhiloxRandom gen = rng; + gen.Skip(kReservedSamplesPerOutput * output_idx); + int16 uniform_remaining = 0; + + CT prod = 1; + CT x = 0; + + // Keep trying until we surpass e^(-rate). This will take + // expected time proportional to rate. + while (true) { + UNIFORM(u); + prod = prod * u; + if (prod <= exp_neg_rate) { + samples_rate_output[sample_idx * num_rate] = T(x); + break; + } + x += 1; + } + } + continue; + } + // Transformed rejection due to Hormann. + // + // Given a CDF F(x), and G(x), a dominating distribution chosen such + // that it is close to the inverse CDF F^-1(x), compute the following + // steps: + // + // 1) Generate U and V, two independent random variates. Set U = U - 0.5 + // (this step isn't strictly necessary, but is done to make some + // calculations symmetric and convenient. Henceforth, G is defined on + // [-0.5, 0.5]). + // + // 2) If V <= alpha * F'(G(U)) * G'(U), return floor(G(U)), else return + // to step 1. alpha is the acceptance probability of the rejection + // algorithm. + // + // For more details on transformed rejection, see: + // http://citeseer.ist.psu.edu/viewdoc/citations;jsessionid=1BEB35946CC807879F55D42512E5490C?doi=10.1.1.48.3054. + // + // The dominating distribution in this case: + // + // G(u) = (2 * a / (2 - |u|) + b) * u + c + + using Eigen::numext::log; + const CT log_rate = log(rate); + + // Constants used to define the dominating distribution. Names taken + // from Hormann's paper. Constants were chosen to define the tightest + // G(u) for the inverse Poisson CDF. + const CT b = CT(0.931) + CT(2.53) * Eigen::numext::sqrt(rate); + const CT a = CT(-0.059) + CT(0.02483) * b; + + // This is the inverse acceptance rate. At a minimum (when rate = 10), + // this corresponds to ~75% acceptance. As the rate becomes larger, this + // approaches ~89%. + const CT inv_alpha = CT(1.1239) + CT(1.1328) / (b - CT(3.4)); + + // Compute the rest of the samples for the current rate value. + for (int64 sample_idx = output_idx % num_samples; + sample_idx < num_samples && output_idx < limit_output; + sample_idx++, output_idx++) { + random::PhiloxRandom gen = rng; + gen.Skip(kReservedSamplesPerOutput * output_idx); + int16 uniform_remaining = 0; + + while (true) { + UNIFORM(u); + u -= CT(0.5); + UNIFORM(v); + + CT u_shifted = CT(0.5) - Eigen::numext::abs(u); + CT k = Eigen::numext::floor((CT(2) * a / u_shifted + b) * u + rate + + CT(0.43)); + + // When alpha * f(G(U)) * G'(U) is close to 1, it is possible to + // find a rectangle (-u_r, u_r) x (0, v_r) under the curve, such + // that if v <= v_r and |u| <= u_r, then we can accept. + // Here v_r = 0.9227 - 3.6224 / (b - 2) and u_r = 0.43. + if (u_shifted >= CT(0.07) && + v <= CT(0.9277) - CT(3.6224) / (b - CT(2))) { + samples_rate_output[sample_idx * num_rate] = T(k); + break; + } + + if (k < 0 || (u_shifted < CT(0.013) && v > u_shifted)) { + continue; + } + + // The expression below is equivalent to the computation of step 2) + // in transformed rejection (v <= alpha * F'(G(u)) * G'(u)). + CT s = log(v * inv_alpha / (a / (u_shifted * u_shifted) + b)); + CT t = -rate + k * log_rate - Eigen::numext::lgamma(k + 1); + if (s <= t) { + samples_rate_output[sample_idx * num_rate] = T(k); + break; + } + } + } + } + }; + + // This will depend on rate. + // For rate < 10, on average, O(rate) calls to uniform are + // needed, with that + // many multiplies. ~10 uniform calls on average with ~25 cost op calls. + // + // Very roughly, for rate >= 10, the single call to log + call to + // lgamma + // occur for ~60 percent of samples. + // 2 x 100 (64-bit cycles per log) * 0.62 = ~124 + // Additionally, there are ~10 other ops (+, *, /, ...) at 3-6 cycles each: + // 40 * .62 = ~25. + // + // Finally, there are several other ops that are done every loop along with + // 2 uniform generations along with 5 other ops at 3-6 cycles each. + // ~15 / .89 = ~16 + // + // In total this should be ~165 + 2 * Uniform::kElementCost. + // We assume that half the tensor has rate < 10, so on average 6 + // uniform's + // will be needed. We will upper bound the other op cost by the one for + // rate > 10. + static const int kElementCost = 165 + 6 * Uniform::kElementCost + + 6 * random::PhiloxRandom::kElementCost; + auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads()); + Shard(worker_threads.num_threads, worker_threads.workers, + num_rate * num_samples, kElementCost, DoWork); + } + + private: + typedef typename PoissonComputeType<T>::ComputeType CT; +}; + +} // namespace functor + +namespace { + +// Samples from one or more Poisson distributions. +template <typename T> +class RandomPoissonOp : public OpKernel { + public: + explicit RandomPoissonOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, generator_.Init(context)); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor& shape_t = ctx->input(0); + const Tensor& rate_t = ctx->input(1); + + OP_REQUIRES(ctx, + TensorShapeUtils::IsVector(shape_t.shape()) && + (shape_t.dtype() == DataType::DT_INT32 || + shape_t.dtype() == DataType::DT_INT64), + errors::InvalidArgument( + "shape must be a vector of {int32,int64}, got shape: ", + shape_t.DebugString())); + TensorShape samples_shape; + if (shape_t.dtype() == DataType::DT_INT32) { + auto vec = shape_t.flat<int32>(); + OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(vec.data(), vec.size(), + &samples_shape)); + } else if (shape_t.dtype() == DataType::DT_INT64) { + auto vec = shape_t.flat<int64>(); + OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(vec.data(), vec.size(), + &samples_shape)); + } + const int64 num_samples = samples_shape.num_elements(); + OP_REQUIRES(ctx, num_samples > 0, + errors::InvalidArgument( + "Input shape should have non-zero element count, got: ", + num_samples)); + + samples_shape.AppendShape(rate_t.shape()); + // Allocate output samples. + Tensor* samples_t = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, samples_shape, &samples_t)); + + const auto rate_flat = rate_t.flat<T>().data(); + const int64 num_rate = rate_t.NumElements(); + OP_REQUIRES( + ctx, num_rate > 0, + errors::InvalidArgument( + "Input rate should have non-zero element count, got: ", num_rate)); + auto samples_flat = samples_t->flat<T>().data(); + random::PhiloxRandom rng = generator_.ReserveRandomOutputs( + num_samples * num_rate, kReservedSamplesPerOutput); + + functor::PoissonFunctor<CPUDevice, T>()(ctx, ctx->eigen_device<CPUDevice>(), + rate_flat, num_rate, num_samples, + rng, samples_flat); + } + + private: + GuardedPhiloxRandom generator_; + + TF_DISALLOW_COPY_AND_ASSIGN(RandomPoissonOp); +}; +} // namespace + +#undef UNIFORM + +#define REGISTER(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("RandomPoisson").Device(DEVICE_CPU).TypeConstraint<TYPE>("dtype"), \ + RandomPoissonOp<TYPE>); + +TF_CALL_half(REGISTER); +TF_CALL_float(REGISTER); +TF_CALL_double(REGISTER); + +#undef REGISTER + +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/random_poisson_op.h b/tensorflow/core/kernels/random_poisson_op.h new file mode 100644 index 00000000000..6c49acc8007 --- /dev/null +++ b/tensorflow/core/kernels/random_poisson_op.h @@ -0,0 +1,31 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_KERNELS_RANDOM_POISSON_OP_H_ +#define TENSORFLOW_KERNELS_RANDOM_POISSON_OP_H_ + +namespace tensorflow { + +namespace functor { + +// Generic helper functor for the Random Poisson Op. +template <typename Device, typename T> +struct PoissonFunctor; + +} // namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_RANDOM_POISSON_OP_H_ diff --git a/tensorflow/core/kernels/random_poisson_op_test.cc b/tensorflow/core/kernels/random_poisson_op_test.cc new file mode 100644 index 00000000000..bccdbf6c7f5 --- /dev/null +++ b/tensorflow/core/kernels/random_poisson_op_test.cc @@ -0,0 +1,82 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <random> + +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { +namespace { + +Tensor VecShape(int64 v) { + if (v >= std::numeric_limits<int32>::max()) { + Tensor shape(DT_INT64, TensorShape({1})); + shape.vec<int64>()(0) = v; + return shape; + } else { + Tensor shape(DT_INT32, TensorShape({1})); + shape.vec<int32>()(0) = v; + return shape; + } +} + +Tensor VecLam32(int64 n, int magnitude) { + std::mt19937 gen(0x12345); + std::uniform_real_distribution<float> dist(0.0, 1.0); + Tensor lams(DT_FLOAT, TensorShape({n})); + for (int i = 0; i < n; i++) { + // Generate in range (magnitude, 2 * magnitude) + lams.vec<float>()(i) = magnitude * (1 + dist(gen)); + } + return lams; +} + +Tensor VecLam64(int64 n, int magnitude) { + std::mt19937 gen(0x12345); + std::uniform_real_distribution<double> dist(0.0, 1.0); + Tensor lams(DT_DOUBLE, TensorShape({n})); + for (int i = 0; i < n; i++) { + // Generate in range (magnitude, 2 * magnitude) + lams.vec<double>()(i) = magnitude * (1 + dist(gen)); + } + return lams; +} + +#define BM_Poisson(DEVICE, BITS, MAGNITUDE) \ + static void BM_##DEVICE##_RandomPoisson_lam_##MAGNITUDE##_##BITS( \ + int iters, int nsamp, int nlam) { \ + testing::ItemsProcessed(static_cast<int64>(iters) * nsamp * nlam); \ + Graph* g = new Graph(OpRegistry::Global()); \ + test::graph::RandomPoisson( \ + g, test::graph::Constant(g, VecShape(nsamp)), \ + test::graph::Constant(g, VecLam##BITS(nlam, MAGNITUDE))); \ + test::Benchmark(#DEVICE, g).Run(iters); \ + } \ + BENCHMARK(BM_##DEVICE##_RandomPoisson_lam_##MAGNITUDE##_##BITS) \ + ->RangePair(1, 64, 2, 50); + +BM_Poisson(cpu, 32, 1); +BM_Poisson(cpu, 32, 8); +BM_Poisson(cpu, 32, 32); + +BM_Poisson(cpu, 64, 1); +BM_Poisson(cpu, 64, 8); +BM_Poisson(cpu, 64, 32); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/ops/random_ops.cc b/tensorflow/core/ops/random_ops.cc index 776523f33fb..7b2da9d8e6d 100644 --- a/tensorflow/core/ops/random_ops.cc +++ b/tensorflow/core/ops/random_ops.cc @@ -276,4 +276,48 @@ output: A tensor with shape `shape + shape(alpha)`. Each slice `alpha[i0, i1, ...iN]`. The dtype of the output matches the dtype of alpha. )doc"); +REGISTER_OP("RandomPoisson") + .SetIsStateful() + .Input("shape: S") + .Input("rate: dtype") + .Output("output: dtype") + .Attr("seed: int = 0") + .Attr("seed2: int = 0") + .Attr("S: {int32, int64}") + .Attr("dtype: {half, float, double}") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle out; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); + TF_RETURN_IF_ERROR(c->Concatenate(out, c->input(1), &out)); + c->set_output(0, out); + return Status::OK(); + }) + .Doc(R"doc( +Outputs random values from the Poisson distribution(s) described by rate. + +This op uses two algorithms, depending on rate. If rate >= 10, then +the algorithm by Hormann is used to acquire samples via +transformation-rejection. +See http://www.sciencedirect.com/science/article/pii/0167668793909974. + +Otherwise, Knuth's algorithm is used to acquire samples via multiplying uniform +random variables. +See Donald E. Knuth (1969). Seminumerical Algorithms. The Art of Computer +Programming, Volume 2. Addison Wesley + +shape: 1-D integer tensor. Shape of independent samples to draw from each + distribution described by the shape parameters given in rate. +rate: A tensor in which each scalar is a "rate" parameter describing the + associated poisson distribution. +seed: If either `seed` or `seed2` are set to be non-zero, the random number + generator is seeded by the given seed. Otherwise, it is seeded by a + random seed. +seed2: A second seed to avoid seed collision. + +output: A tensor with shape `shape + shape(rate)`. Each slice + `[:, ..., :, i0, i1, ...iN]` contains the samples drawn for + `rate[i0, i1, ...iN]`. The dtype of the output matches the dtype of + rate. +)doc"); + } // namespace tensorflow diff --git a/tensorflow/core/ops/random_ops_test.cc b/tensorflow/core/ops/random_ops_test.cc index 524e1079981..b0aa565485b 100644 --- a/tensorflow/core/ops/random_ops_test.cc +++ b/tensorflow/core/ops/random_ops_test.cc @@ -53,4 +53,19 @@ TEST(RandomOpsTest, RandomGamma_ShapeFn) { INFER_OK(op, "[3];[]", "[1,2,3]"); } +TEST(RandomOpsTest, RandomPoisson_ShapeFn) { + ShapeInferenceTestOp op("RandomPoisson"); + op.input_tensors.resize(2); + + INFER_OK(op, "?;?", "?"); + INFER_OK(op, "?;[3]", "?"); + INFER_OK(op, "[1];?", "?"); + INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[1,2];[3,4]"); + Tensor shape = test::AsTensor<int32>({1, 2, 3}); + op.input_tensors[0] = &shape; + INFER_OK(op, "[3];[4,?]", "[1,2,3,d1_0,d1_1]"); + INFER_OK(op, "[3];[4,5]", "[1,2,3,d1_0,d1_1]"); + INFER_OK(op, "[3];[]", "[1,2,3]"); +} + } // end namespace tensorflow diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py index 3bcc5377797..05a520850e1 100644 --- a/tensorflow/python/framework/constant_op.py +++ b/tensorflow/python/framework/constant_op.py @@ -96,6 +96,7 @@ print(sess.run(var)) @@random_crop @@multinomial @@random_gamma +@@random_poisson @@set_random_seed """ diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 119cffe0e9b..2ecbd089922 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -2112,6 +2112,21 @@ cuda_py_test( ], ) +cuda_py_test( + name = "random_poisson_test", + size = "medium", + srcs = ["random_poisson_test.py"], + additional_deps = [ + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform", + "//tensorflow/python:random_ops", + ], +) + cuda_py_test( name = "rnn_test", size = "medium", diff --git a/tensorflow/python/kernel_tests/random_poisson_test.py b/tensorflow/python/kernel_tests/random_poisson_test.py new file mode 100644 index 00000000000..01281b7bd03 --- /dev/null +++ b/tensorflow/python/kernel_tests/random_poisson_test.py @@ -0,0 +1,178 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.random_ops.random_poisson.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging + + +class RandomPoissonTest(test.TestCase): + """This is a large test due to the moments computation taking some time.""" + + def _Sampler(self, num, lam, dtype, use_gpu, seed=None): + + def func(): + with self.test_session(use_gpu=use_gpu, graph=ops.Graph()) as sess: + rng = random_ops.random_poisson(lam, [num], dtype=dtype, seed=seed) + ret = np.empty([10, num]) + for i in xrange(10): + ret[i, :] = sess.run(rng) + return ret + + return func + + # TODO(srvasude): Factor this out along with the corresponding moment testing + # method in random_gamma_test into a single library. + def testMoments(self): + try: + from scipy import stats # pylint: disable=g-import-not-at-top + except ImportError as e: + tf_logging.warn("Cannot test moments: %s", e) + return + # The moments test is a z-value test. This is the largest z-value + # we want to tolerate. Since the z-test approximates a unit normal + # distribution, it should almost definitely never exceed 6. + z_limit = 6.0 + for dt in dtypes.float16, dtypes.float32, dtypes.float64: + # Test when lam < 10 and when lam >= 10 + for stride in 0, 4, 10: + for lam in (3., 20): + max_moment = 5 + sampler = self._Sampler(10000, lam, dt, use_gpu=False, seed=12345) + moments = [0] * (max_moment + 1) + moments_sample_count = [0] * (max_moment + 1) + x = np.array(sampler().flat) # sampler does 10x samples + for k in range(len(x)): + moment = 1. + for i in range(max_moment + 1): + index = k + i * stride + if index >= len(x): + break + moments[i] += moment + moments_sample_count[i] += 1 + moment *= x[index] + for i in range(max_moment + 1): + moments[i] /= moments_sample_count[i] + for i in range(1, max_moment + 1): + g = stats.poisson(lam) + if stride == 0: + moments_i_mean = g.moment(i) + moments_i_squared = g.moment(2 * i) + else: + moments_i_mean = pow(g.moment(1), i) + moments_i_squared = pow(g.moment(2), i) + moments_i_var = ( + moments_i_squared - moments_i_mean * moments_i_mean) + # Assume every operation has a small numerical error. + # It takes i multiplications to calculate one i-th moment. + error_per_moment = i * 1e-6 + total_variance = ( + moments_i_var / moments_sample_count[i] + error_per_moment) + if not total_variance: + total_variance = 1e-10 + # z_test is approximately a unit normal distribution. + z_test = abs( + (moments[i] - moments_i_mean) / np.sqrt(total_variance)) + self.assertLess(z_test, z_limit) + + # Checks that the CPU and GPU implementation returns the same results, + # given the same random seed + def testCPUGPUMatch(self): + for dt in dtypes.float16, dtypes.float32, dtypes.float64: + results = {} + for use_gpu in [False, True]: + sampler = self._Sampler(1000, 1.0, dt, use_gpu=use_gpu, seed=12345) + results[use_gpu] = sampler() + if dt == dtypes.float16: + self.assertAllClose(results[False], results[True], rtol=1e-3, atol=1e-3) + else: + self.assertAllClose(results[False], results[True], rtol=1e-6, atol=1e-6) + + def testSeed(self): + for dt in dtypes.float16, dtypes.float32, dtypes.float64: + sx = self._Sampler(1000, 1.0, dt, use_gpu=True, seed=345) + sy = self._Sampler(1000, 1.0, dt, use_gpu=True, seed=345) + self.assertAllEqual(sx(), sy()) + + def testNoCSE(self): + """CSE = constant subexpression eliminator. + + SetIsStateful() should prevent two identical random ops from getting + merged. + """ + for dtype in dtypes.float16, dtypes.float32, dtypes.float64: + with self.test_session(use_gpu=True): + rnd1 = random_ops.random_poisson(2.0, [24], dtype=dtype) + rnd2 = random_ops.random_poisson(2.0, [24], dtype=dtype) + diff = rnd2 - rnd1 + # Since these are all positive integers, the norm will + # be at least 1 if they are different. + self.assertGreaterEqual(np.linalg.norm(diff.eval()), 1) + + def testShape(self): + # Fully known shape. + rnd = random_ops.random_poisson(2.0, [150], seed=12345) + self.assertEqual([150], rnd.get_shape().as_list()) + rnd = random_ops.random_poisson( + lam=array_ops.ones([1, 2, 3]), + shape=[150], + seed=12345) + self.assertEqual([150, 1, 2, 3], rnd.get_shape().as_list()) + rnd = random_ops.random_poisson( + lam=array_ops.ones([1, 2, 3]), + shape=[20, 30], + seed=12345) + self.assertEqual([20, 30, 1, 2, 3], rnd.get_shape().as_list()) + rnd = random_ops.random_poisson( + lam=array_ops.placeholder(dtypes.float32, shape=(2,)), + shape=[12], + seed=12345) + self.assertEqual([12, 2], rnd.get_shape().as_list()) + # Partially known shape. + rnd = random_ops.random_poisson( + lam=array_ops.ones([7, 3]), + shape=array_ops.placeholder(dtypes.int32, shape=(1,)), + seed=12345) + self.assertEqual([None, 7, 3], rnd.get_shape().as_list()) + rnd = random_ops.random_poisson( + lam=array_ops.ones([9, 6]), + shape=array_ops.placeholder(dtypes.int32, shape=(3,)), + seed=12345) + self.assertEqual([None, None, None, 9, 6], rnd.get_shape().as_list()) + # Unknown shape. + rnd = random_ops.random_poisson( + lam=array_ops.placeholder(dtypes.float32), + shape=array_ops.placeholder(dtypes.int32), + seed=12345) + self.assertIs(None, rnd.get_shape().ndims) + rnd = random_ops.random_poisson( + lam=array_ops.placeholder(dtypes.float32), + shape=[50], + seed=12345) + self.assertIs(None, rnd.get_shape().ndims) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt index dab15976b58..4937f1a50a8 100644 --- a/tensorflow/python/ops/hidden_ops.txt +++ b/tensorflow/python/ops/hidden_ops.txt @@ -282,6 +282,7 @@ ParseSingleSequenceExample # random_ops RandomGamma +RandomPoisson RandomUniform RandomUniformInt RandomShuffle diff --git a/tensorflow/python/ops/random_ops.py b/tensorflow/python/ops/random_ops.py index 34b4d361021..5a753ae7a1e 100644 --- a/tensorflow/python/ops/random_ops.py +++ b/tensorflow/python/ops/random_ops.py @@ -71,10 +71,8 @@ def random_normal(shape, mean_tensor = ops.convert_to_tensor(mean, dtype=dtype, name="mean") stddev_tensor = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev") seed1, seed2 = random_seed.get_seed(seed) - rnd = gen_random_ops._random_standard_normal(shape_tensor, - dtype, - seed=seed1, - seed2=seed2) + rnd = gen_random_ops._random_standard_normal( + shape_tensor, dtype, seed=seed1, seed2=seed2) mul = rnd * stddev_tensor value = math_ops.add(mul, mean_tensor, name=name) return value @@ -125,13 +123,14 @@ def parameterized_truncated_normal(shape, minvals_tensor = ops.convert_to_tensor(minvals, dtype=dtype, name="minvals") maxvals_tensor = ops.convert_to_tensor(maxvals, dtype=dtype, name="maxvals") seed1, seed2 = random_seed.get_seed(seed) - rnd = gen_random_ops._parameterized_truncated_normal(shape_tensor, - means_tensor, - stddevs_tensor, - minvals_tensor, - maxvals_tensor, - seed=seed1, - seed2=seed2) + rnd = gen_random_ops._parameterized_truncated_normal( + shape_tensor, + means_tensor, + stddevs_tensor, + minvals_tensor, + maxvals_tensor, + seed=seed1, + seed2=seed2) return rnd @@ -168,10 +167,8 @@ def truncated_normal(shape, mean_tensor = ops.convert_to_tensor(mean, dtype=dtype, name="mean") stddev_tensor = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev") seed1, seed2 = random_seed.get_seed(seed) - rnd = gen_random_ops._truncated_normal(shape_tensor, - dtype, - seed=seed1, - seed2=seed2) + rnd = gen_random_ops._truncated_normal( + shape_tensor, dtype, seed=seed1, seed2=seed2) mul = rnd * stddev_tensor value = math_ops.add(mul, mean_tensor, name=name) return value @@ -232,17 +229,11 @@ def random_uniform(shape, maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max") seed1, seed2 = random_seed.get_seed(seed) if dtype.is_integer: - return gen_random_ops._random_uniform_int(shape, - minval, - maxval, - seed=seed1, - seed2=seed2, - name=name) + return gen_random_ops._random_uniform_int( + shape, minval, maxval, seed=seed1, seed2=seed2, name=name) else: - rnd = gen_random_ops._random_uniform(shape, - dtype, - seed=seed1, - seed2=seed2) + rnd = gen_random_ops._random_uniform( + shape, dtype, seed=seed1, seed2=seed2) return math_ops.add(rnd * (maxval - minval), minval, name=name) @@ -275,10 +266,8 @@ def random_shuffle(value, seed=None, name=None): dimension. """ seed1, seed2 = random_seed.get_seed(seed) - return gen_random_ops._random_shuffle(value, - seed=seed1, - seed2=seed2, - name=name) + return gen_random_ops._random_shuffle( + value, seed=seed1, seed2=seed2, name=name) def random_crop(value, size, seed=None, name=None): @@ -349,10 +338,8 @@ def multinomial(logits, num_samples, seed=None, name=None): with ops.name_scope(name, "multinomial", [logits]): logits = ops.convert_to_tensor(logits, name="logits") seed1, seed2 = random_seed.get_seed(seed) - return gen_random_ops.multinomial(logits, - num_samples, - seed=seed1, - seed2=seed2) + return gen_random_ops.multinomial( + logits, num_samples, seed=seed1, seed2=seed2) ops.NotDifferentiable("Multinomial") @@ -426,15 +413,52 @@ def random_gamma(shape, with ops.name_scope(name, "random_gamma", [shape, alpha, beta]): shape = ops.convert_to_tensor(shape, name="shape", dtype=dtypes.int32) alpha = ops.convert_to_tensor(alpha, name="alpha", dtype=dtype) - beta = ops.convert_to_tensor(beta if beta is not None else 1, - name="beta", - dtype=dtype) + beta = ops.convert_to_tensor( + beta if beta is not None else 1, name="beta", dtype=dtype) alpha_broadcast = alpha + array_ops.zeros_like(beta) seed1, seed2 = random_seed.get_seed(seed) - return gen_random_ops._random_gamma(shape, - alpha_broadcast, - seed=seed1, - seed2=seed2) / beta + return gen_random_ops._random_gamma( + shape, alpha_broadcast, seed=seed1, seed2=seed2) / beta ops.NotDifferentiable("RandomGamma") + + +def random_poisson(lam, shape, dtype=dtypes.float32, seed=None, name=None): + """Draws `shape` samples from each of the given Poisson distribution(s). + + `lam` is the rate parameter describing the distribution(s). + + Example: + + samples = tf.random_poisson([0.5, 1.5], [10]) + # samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents + # the samples drawn from each distribution + + samples = tf.random_poisson([12.2, 3.3], [7, 5]) + # samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1] + # represents the 7x5 samples drawn from each of the two distributions + + Args: + lam: A Tensor or Python value or N-D array of type `dtype`. + `lam` provides the rate parameter(s) describing the poisson + distribution(s) to sample. + shape: A 1-D integer Tensor or Python array. The shape of the output samples + to be drawn per "rate"-parameterized distribution. + dtype: The type of `lam` and the output: `float16`, `float32`, or + `float64`. + seed: A Python integer. Used to create a random seed for the distributions. + See + [`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed) + for behavior. + name: Optional name for the operation. + + Returns: + samples: a `Tensor` of shape `tf.concat(shape, tf.shape(lam))` with + values of type `dtype`. + """ + with ops.name_scope(name, "random_poisson", [lam, shape]): + lam = ops.convert_to_tensor(lam, name="lam", dtype=dtype) + shape = ops.convert_to_tensor(shape, name="shape", dtype=dtypes.int32) + seed1, seed2 = random_seed.get_seed(seed) + return gen_random_ops._random_poisson(shape, lam, seed=seed1, seed2=seed2)