diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index ffdd415e637..da59e1fa3ad 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1497,6 +1497,7 @@ cc_library( "//tensorflow/core/kernels:ragged_ops", "//tensorflow/core/kernels:random_ops", "//tensorflow/core/kernels:stateful_random_ops", + "//tensorflow/core/kernels:random_binomial_op", "//tensorflow/core/kernels:random_poisson_op", "//tensorflow/core/kernels:remote_fused_graph_ops", "//tensorflow/core/kernels:required", diff --git a/tensorflow/core/api_def/base_api/api_def_StatefulRandomBinomial.pbtxt b/tensorflow/core/api_def/base_api/api_def_StatefulRandomBinomial.pbtxt new file mode 100644 index 00000000000..752c2ba48bc --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_StatefulRandomBinomial.pbtxt @@ -0,0 +1,3 @@ +op { + graph_op_name: "StatefulRandomBinomial" +} diff --git a/tensorflow/core/api_def/python_api/api_def_StatefulRandomBinomial.pbtxt b/tensorflow/core/api_def/python_api/api_def_StatefulRandomBinomial.pbtxt new file mode 100644 index 00000000000..cb371d5674f --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_StatefulRandomBinomial.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "StatefulRandomBinomial" + visibility: HIDDEN +} diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 1785ba0973c..1188251f085 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -5222,6 +5222,38 @@ tf_cuda_cc_test( ], ) +tf_kernel_library( + name = "random_binomial_op", + prefix = "random_binomial_op", + deps = [ + ":cwise_op", + ":random_ops", + ":resource_variable_ops", + ":stateful_random_ops", + ":training_op_helpers", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:random_ops_op_lib", + ], +) + +tf_cuda_cc_test( + name = "random_binomial_op_test", + size = "small", + srcs = ["random_binomial_op_test.cc"], + deps = [ + ":ops_util", + ":random_binomial_op", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + tf_kernel_library( name = "random_poisson_op", prefix = "random_poisson_op", diff --git a/tensorflow/core/kernels/random_binomial_op.cc b/tensorflow/core/kernels/random_binomial_op.cc new file mode 100644 index 00000000000..6ed36605530 --- /dev/null +++ b/tensorflow/core/kernels/random_binomial_op.cc @@ -0,0 +1,447 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/random_ops.cc. +// NOTE: If the algorithm is changed, please run the test +// .../python/kernel_tests/random:random_binomial_test +// commenting out the "tf.set_random_seed(seed)" lines, and using the +// "--runs-per-test=1000" flag. This tests the statistical correctness of the +// op results. + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/random_binomial_op.h" + +#include <algorithm> +#include <cmath> +#include <memory> + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h" +#include "tensorflow/core/kernels/training_op_helpers.h" +#include "tensorflow/core/lib/random/random_distributions.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/guarded_philox_random.h" +#include "tensorflow/core/util/work_sharder.h" + +#define UNIFORM(X) \ + if (uniform_remaining == 0) { \ + uniform_remaining = Uniform::kResultElementCount; \ + uniform_result = uniform(gen); \ + } \ + uniform_remaining--; \ + double X = uniform_result[uniform_remaining] + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +namespace { + +typedef random::UniformDistribution<random::PhiloxRandom, double> Uniform; + +// Binomial inversion. Given prob, sum geometric random variables until they +// exceed count. The number of random variables used is binomially distributed. +// This is also known as binomial inversion, as this is equivalent to inverting +// the Binomial CDF. +double binomial_inversion(double count, double prob, + random::PhiloxRandom* gen) { + using Eigen::numext::ceil; + using Eigen::numext::log; + using Eigen::numext::log1p; + + double geom_sum = 0; + int num_geom = 0; + + Uniform uniform; + typename Uniform::ResultType uniform_result; + int16 uniform_remaining = 0; + + while (true) { + UNIFORM(u); + double geom = ceil(log(u) / log1p(-prob)); + geom_sum += geom; + if (geom_sum > count) { + break; + } + ++num_geom; + } + return num_geom; +} + +double stirling_approx_tail(double k) { + static double kTailValues[] = {0.0810614667953272, 0.0413406959554092, + 0.0276779256849983, 0.02079067210376509, + 0.0166446911898211, 0.0138761288230707, + 0.0118967099458917, 0.0104112652619720, + 0.00925546218271273, 0.00833056343336287}; + if (k <= 9) { + return kTailValues[static_cast<int>(k)]; + } + double kp1sq = (k + 1) * (k + 1); + return (1 / 12 - (1 / 360 + 1 / 1260 / kp1sq) / kp1sq) / (k + 1); +} + +// We use a transformation-rejection algorithm from +// pairs of uniform random variables due to Hormann. +// https://www.tandfonline.com/doi/abs/10.1080/00949659308811496 +double btrs(double count, double prob, random::PhiloxRandom* gen) { + using Eigen::numext::abs; + using Eigen::numext::floor; + using Eigen::numext::log; + using Eigen::numext::log1p; + using Eigen::numext::sqrt; + + // This is spq in the paper. + const double stddev = sqrt(count * prob * (1 - prob)); + + // Other coefficients for Transformed Rejection sampling. + const double b = 1.15 + 2.53 * stddev; + const double a = -0.0873 + 0.0248 * b + 0.01 * prob; + const double c = count * prob + 0.5; + const double v_r = 0.92 - 4.2 / b; + const double r = prob / (1 - prob); + + Uniform uniform; + typename Uniform::ResultType uniform_result; + int16 uniform_remaining = 0; + + while (true) { + UNIFORM(u); + UNIFORM(v); + u = u - 0.5; + double us = 0.5 - abs(u); + double k = floor((2 * a / us + b) * u + c); + + // Region for which the box is tight, and we + // can return our calculated value This should happen + // 0.86 * v_r times. In the limit as n * p is large, + // the acceptance rate converges to ~79% (and in the lower + // regime it is ~24%). + if (us >= 0.07 && v <= v_r) { + return k; + } + // Reject non-sensical answers. + if (k < 0 || k > count) { + continue; + } + + double alpha = (2.83 + 5.1 / b) * stddev; + double m = floor((count + 1) * prob); + // This deviates from Hormann's BRTS algorithm, as there is a log missing. + // For all (u, v) pairs outside of the bounding box, this calculates the + // transformed-reject ratio. + v = log(v * alpha / (a / (us * us) + b)); + double upperbound = + ((m + 0.5) * log((m + 1) / (r * (count - m + 1))) + + (count + 1) * log((count - m + 1) / (count - k + 1)) + + (k + 0.5) * log(r * (count - k + 1) / (k + 1)) + + stirling_approx_tail(m) + stirling_approx_tail(count - m) - + stirling_approx_tail(k) - stirling_approx_tail(count - k)); + if (v <= upperbound) { + return k; + } + } +} + +} // namespace + +namespace functor { + +template <typename T, typename U> +struct RandomBinomialFunctor<CPUDevice, T, U> { + void operator()(OpKernelContext* ctx, const CPUDevice& d, int64 num_batches, + int64 samples_per_batch, int64 num_elements, + typename TTypes<T>::ConstFlat counts, + typename TTypes<T>::ConstFlat probs, + const random::PhiloxRandom& gen, + typename TTypes<U>::Flat output) { + auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads()); + + auto DoWork = [samples_per_batch, num_elements, &counts, &probs, &gen, + &output](int start_batch, int limit_batch) { + // Capturing "gen" by-value would only make a copy for the _shared_ + // lambda. Since we want to let each worker have its own copy, we pass + // "gen" by reference and explicitly do a copy assignment here. + random::PhiloxRandom gen_copy = gen; + // Skip takes units of 128 bytes. +3 is so rounding doesn't lead to + // us using the same state in different batches. + // The sample from each iteration uses 2 random numbers. + gen_copy.Skip(start_batch * 2 * 3 * (samples_per_batch + 3) / 4); + + // Vectorized intermediate calculations for uniform rejection sampling. + // We always generate at most 4 samples. + Eigen::array<T, 4> z; + Eigen::array<T, 4> g; + + for (int64 b = start_batch; b < limit_batch; ++b) { + // We are passed a flat array for each of the parameter tensors. + // The input is either a scalar broadcasted to all batches or a vector + // with length num_batches, but the scalar becomes an array of length 1. + T count = counts((counts.dimension(0) == 1) ? 0 : b); + T prob = probs((probs.dimension(0) == 1) ? 0 : b); + + // The last batch can be short, if we adjusted num_batches and + // samples_per_batch. + const int64 limit_sample = + std::min((b + 1) * samples_per_batch, num_elements); + int64 sample = b * samples_per_batch; + + // Calculate normalized samples, then convert them. + // Determine the method to use. + double dcount = static_cast<double>(count); + if (prob <= T(0.5)) { + double dp = static_cast<double>(prob); + if (count * prob >= T(10)) { + while (sample < limit_sample) { + output(sample) = static_cast<U>(btrs(dcount, dp, &gen_copy)); + sample++; + } + } else { + while (sample < limit_sample) { + output(sample) = + static_cast<U>(binomial_inversion(dcount, dp, &gen_copy)); + sample++; + } + } + } else { + T q = T(1) - prob; + double dcount = static_cast<double>(count); + double dq = static_cast<double>(q); + if (count * q >= T(10)) { + while (sample < limit_sample) { + output(sample) = + static_cast<U>(dcount - btrs(dcount, dq, &gen_copy)); + sample++; + } + } else { + while (sample < limit_sample) { + output(sample) = static_cast<U>( + dcount - binomial_inversion(dcount, dq, &gen_copy)); + sample++; + } + } + } + } + }; + + const int64 batch_init_cost = + // normMin, normMax + (Eigen::TensorOpCost::AddCost<T>() + + Eigen::TensorOpCost::MulCost<T>()) * + 2 + // sqrtFactor + + Eigen::TensorOpCost::AddCost<T>() + + Eigen::TensorOpCost::MulCost<T>() + + Eigen::internal::functor_traits< + Eigen::internal::scalar_sqrt_op<T>>::Cost + // cutoff + + Eigen::TensorOpCost::MulCost<T>() * 4 + + Eigen::internal::functor_traits<Eigen::internal::scalar_exp_op<T>>::Cost + // diff + + Eigen::TensorOpCost::AddCost<T>(); + // This will depend on count * p (or count * q). + // For n * p < 10, on average, O(n * p) calls to uniform are + // needed, with that + // many multiplies. ~10 uniform calls on average with ~200 cost op calls. + // + // Very roughly, for rate >= 10, the four calls to log + // occur for ~72 percent of samples. + // 4 x 100 (64-bit cycles per log) * 0.72 = ~288 + // Additionally, there are ~10 other ops (+, *, /, ...) at 3-6 cycles each: + // 40 * .72 = ~25. + // + // Finally, there are several other ops that are done every loop along with + // 2 uniform generations along with 5 other ops at 3-6 cycles each. + // ~15 / .89 = ~16 + // + // In total this should be ~529 + 2 * Uniform::kElementCost. + // We assume that half the tensor has rate < 10, so on average 6 + // uniform's + // will be needed. We will upper bound the other op cost by the one for + // rate > 10. + static const int kElementCost = 529 + 6 * Uniform::kElementCost + + 6 * random::PhiloxRandom::kElementCost; + // Assume we use uniform sampling, and accept the 2nd sample on average. + const int64 batch_cost = batch_init_cost + kElementCost * samples_per_batch; + Shard(worker_threads.num_threads, worker_threads.workers, num_batches, + batch_cost, DoWork); + } +}; + +} // namespace functor + +namespace { + +// Samples from a binomial distribution, using the given parameters. +template <typename Device, typename T, typename U> +class RandomBinomialOp : public OpKernel { + // Reshape batches so each batch is this size if possible. + static const int32 kDesiredBatchSize = 100; + + public: + explicit RandomBinomialOp(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& alg_tensor = ctx->input(1); + const Tensor& shape_tensor = ctx->input(2); + const Tensor& counts_tensor = ctx->input(3); + const Tensor& probs_tensor = ctx->input(4); + + OP_REQUIRES(ctx, alg_tensor.dims() == 0, + errors::InvalidArgument("algorithm must be of shape [], not ", + alg_tensor.shape().DebugString())); + Algorithm alg = alg_tensor.flat<Algorithm>()(0); + + OP_REQUIRES( + ctx, TensorShapeUtils::IsVector(shape_tensor.shape()), + errors::InvalidArgument("Input shape should be a vector, got shape: ", + shape_tensor.shape().DebugString())); + int32 num_batches = shape_tensor.flat<int32>()(0); + + int32 samples_per_batch = 1; + const int32 num_dims = shape_tensor.dim_size(0); + for (int32 i = 1; i < num_dims; i++) { + samples_per_batch *= shape_tensor.flat<int32>()(i); + } + const int32 num_elements = num_batches * samples_per_batch; + + // Allocate the output before fudging num_batches and samples_per_batch. + auto shape_vec = shape_tensor.flat<int32>(); + TensorShape tensor_shape; + OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape( + shape_vec.data(), shape_vec.size(), &tensor_shape)); + Tensor* samples_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, tensor_shape, &samples_tensor)); + + // Parameters must be 0-d or 1-d. + OP_REQUIRES(ctx, counts_tensor.dims() <= 1, + errors::InvalidArgument( + "Input counts should be a scalar or vector, got shape: ", + counts_tensor.shape().DebugString())); + OP_REQUIRES(ctx, probs_tensor.dims() <= 1, + errors::InvalidArgument( + "Input probs should be a scalar or vector, got shape: ", + probs_tensor.shape().DebugString())); + + if ((counts_tensor.dims() == 0 || counts_tensor.dim_size(0) == 1) && + (probs_tensor.dims() == 0 || probs_tensor.dim_size(0) == 1)) { + // All batches have the same parameters, so we can update the batch size + // to a reasonable value to improve parallelism (ensure enough batches, + // and no very small batches which have high overhead). + int32 size = num_batches * samples_per_batch; + int32 adjusted_samples = kDesiredBatchSize; + // Ensure adjusted_batches * adjusted_samples >= size. + int32 adjusted_batches = Eigen::divup(size, adjusted_samples); + num_batches = adjusted_batches; + samples_per_batch = adjusted_samples; + } else { + // Parameters must be broadcastable to the shape [num_batches]. + OP_REQUIRES( + ctx, + TensorShapeUtils::IsScalar(counts_tensor.shape()) || + counts_tensor.dim_size(0) == 1 || + counts_tensor.dim_size(0) == num_batches, + errors::InvalidArgument( + "Input counts should have length 1 or shape[0], got shape: ", + counts_tensor.shape().DebugString())); + OP_REQUIRES( + ctx, + TensorShapeUtils::IsScalar(probs_tensor.shape()) || + probs_tensor.dim_size(0) == 1 || + probs_tensor.dim_size(0) == num_batches, + errors::InvalidArgument( + "Input probs should have length 1 or shape[0], got shape: ", + probs_tensor.shape().DebugString())); + } + Var* var = nullptr; + OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &var)); + + ScopedUnlockUnrefVar var_guard(var); + Tensor* var_tensor = var->tensor(); + OP_REQUIRES( + ctx, var_tensor->dtype() == STATE_ELEMENT_DTYPE, + errors::InvalidArgument("dtype of RNG state variable must be ", + DataTypeString(STATE_ELEMENT_DTYPE), ", not ", + DataTypeString(var_tensor->dtype()))); + OP_REQUIRES(ctx, var_tensor->dims() == 1, + errors::InvalidArgument( + "RNG state must have one and only one dimension, not ", + var_tensor->dims())); + auto var_tensor_flat = var_tensor->flat<StateElementType>(); + OP_REQUIRES(ctx, alg == RNG_ALG_PHILOX, + errors::InvalidArgument("Unsupported algorithm id: ", alg)); + static_assert(std::is_same<StateElementType, int64>::value, + "StateElementType must be int64"); + static_assert(std::is_same<PhiloxRandom::ResultElementType, uint32>::value, + "PhiloxRandom::ResultElementType must be uint32"); + OP_REQUIRES(ctx, var_tensor_flat.size() >= PHILOX_MIN_STATE_SIZE, + errors::InvalidArgument( + "For Philox algorithm, the size of state must be at least ", + PHILOX_MIN_STATE_SIZE, "; got ", var_tensor_flat.size())); + + // Each worker has the fudge factor for samples_per_batch, so use it here. + OP_REQUIRES_OK(ctx, PrepareToUpdateVariable<Device, StateElementType>( + ctx, var_tensor, var->copy_on_read_mode.load())); + auto var_data = var_tensor_flat.data(); + auto philox = GetPhiloxRandomFromMem(var_data); + UpdateMemWithPhiloxRandom( + philox, num_batches * 2 * 100 * (samples_per_batch + 3) / 4, var_data); + var_guard.Release(); + + auto binomial_functor = functor::RandomBinomialFunctor<Device, T, U>(); + binomial_functor(ctx, ctx->eigen_device<Device>(), num_batches, + samples_per_batch, num_elements, counts_tensor.flat<T>(), + probs_tensor.flat<T>(), philox, samples_tensor->flat<U>()); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(RandomBinomialOp); +}; + +} // namespace + +#define REGISTER(RTYPE, TYPE) \ + REGISTER_KERNEL_BUILDER(Name("StatefulRandomBinomial") \ + .Device(DEVICE_CPU) \ + .HostMemory("resource") \ + .HostMemory("algorithm") \ + .HostMemory("shape") \ + .HostMemory("counts") \ + .HostMemory("probs") \ + .TypeConstraint<RTYPE>("dtype") \ + .TypeConstraint<TYPE>("T"), \ + RandomBinomialOp<CPUDevice, TYPE, RTYPE>) + +#define REGISTER_ALL(RTYPE) \ + REGISTER(RTYPE, Eigen::half); \ + REGISTER(RTYPE, float); \ + REGISTER(RTYPE, double); + +REGISTER_ALL(Eigen::half); +REGISTER_ALL(float); +REGISTER_ALL(double); +REGISTER_ALL(int32); +REGISTER_ALL(int64); + +#undef REGISTER +#undef REGISTER_ALL + +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/random_binomial_op.h b/tensorflow/core/kernels/random_binomial_op.h new file mode 100644 index 00000000000..05c489da83a --- /dev/null +++ b/tensorflow/core/kernels/random_binomial_op.h @@ -0,0 +1,61 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_RANDOM_BINOMIAL_OP_H_ +#define TENSORFLOW_CORE_KERNELS_RANDOM_BINOMIAL_OP_H_ + +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/random/random_distributions.h" + +namespace tensorflow { + +class OpKernelContext; + +namespace functor { + +// Sample a binomial random variable, with probs and counts for each batch. +// Uses binomial inversion and a transformed rejection sampling method as +// described in +// https://pdfs.semanticscholar.org/471b/c2726e25bbf8801ef781630a2c13f654268e.pdf. +// Two different algorithms are employed, depending on the size of +// counts * probs (or counts * (1 - probs) if probs > 0.5. +// If counts * probs < 10, we simply sum up Geometric random variables until +// they exceed count, and the number we used is binomially distributed. +// In expectation, this will take O(counts * probs) time, and requiring in +// expectation the same number of random variates. +// This can be much cheaper than summing bernoulli random variates, as we +// will always need O(counts) bernoulli random variates (so this requires fewer +// uniform r.v.s as well as can be faster). +// +// If counts * probs > 10, we use a transformed-rejection algorithm based on +// pairs of uniform random variates due to Hormann. +// https://pdfs.semanticscholar.org/471b/c2726e25bbf8801ef781630a2c13f654268e.pdf +// This algorithm has higher acceptance rates for counts * probs large, as the +// proposal distribution becomes quite tight, requiring approximately two +// uniform random variates as counts * probs becomes large. +template <typename Device, typename T, typename U> +struct RandomBinomialFunctor { + void operator()(OpKernelContext* ctx, const Device& d, int64 num_batches, + int64 samples_per_batch, int64 num_elements, + typename TTypes<T>::ConstFlat counts, + typename TTypes<T>::ConstFlat probs, + const random::PhiloxRandom& gen, + typename TTypes<U>::Flat output); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_RANDOM_BINOMIAL_OP_H_ diff --git a/tensorflow/core/kernels/random_binomial_op_test.cc b/tensorflow/core/kernels/random_binomial_op_test.cc new file mode 100644 index 00000000000..9f8f47ef853 --- /dev/null +++ b/tensorflow/core/kernels/random_binomial_op_test.cc @@ -0,0 +1,107 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <functional> +#include <memory> +#include <vector> + +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { + +static Graph* RandomBinomialGraph(double count, double prob, int num_batches, + int samples_per_batch) { + Graph* g = new Graph(OpRegistry::Global()); + Tensor shape_t(DT_INT32, TensorShape({2})); + shape_t.flat<int32>().setValues({num_batches, samples_per_batch}); + + Tensor counts_t(DT_FLOAT, TensorShape({num_batches})); + counts_t.flat<float>().setConstant(count); + Tensor probs_t(DT_FLOAT, TensorShape({num_batches})); + probs_t.flat<float>().setConstant(prob); + + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("randombinomial"), "RandomBinomial") + .Input(test::graph::Constant(g, shape_t)) + .Input(test::graph::Constant(g, counts_t)) + .Input(test::graph::Constant(g, probs_t)) + .Attr("dtype", DT_FLOAT) + .Finalize(g, &ret)); + return g; +} + +static Graph* RandomBinomialInv(int num_batches, int samples_per_batch) { + // Because counts * probs < 10, we are guaranteed to use inversion. + return RandomBinomialGraph(10., 0.3, num_batches, samples_per_batch); +} + +static Graph* RandomBinomialRej(int num_batches, int samples_per_batch) { + // Because counts * probs > 10, we are guaranteed to use rejection. + return RandomBinomialGraph(100., 0.3, num_batches, samples_per_batch); +} + +static Graph* RandomBinomialInvComplement(int num_batches, + int samples_per_batch) { + // Because counts * (1 - probs) < 10, we are guaranteed to use inversion. + return RandomBinomialGraph(10., 0.8, num_batches, samples_per_batch); +} + +static Graph* RandomBinomialRejComplement(int num_batches, + int samples_per_batch) { + // Because counts * (1 - probs) > 10, we are guaranteed to use inversion. + return RandomBinomialGraph(100., 0.2, num_batches, samples_per_batch); +} + +#define BM_RandomBinomialInv(DEVICE, B, S) \ + static void BM_RandomBinomialInv_##DEVICE##_##B##_##S(int iters) { \ + test::Benchmark(#DEVICE, RandomBinomialInv(B, S)).Run(iters); \ + testing::ItemsProcessed(static_cast<int64>(B) * S * iters); \ + } \ + BENCHMARK(BM_RandomBinomialInv_##DEVICE##_##B##_##S); + +#define BM_RandomBinomialRej(DEVICE, B, S) \ + static void BM_RandomBinomialRej_##DEVICE##_##B##_##S(int iters) { \ + test::Benchmark(#DEVICE, RandomBinomialRej(B, S)).Run(iters); \ + testing::ItemsProcessed(static_cast<int64>(B) * S * iters); \ + } \ + BENCHMARK(BM_RandomBinomialRej_##DEVICE##_##B##_##S); + +#define BM_RandomBinomialInvComplement(DEVICE, B, S) \ + static void BM_RandomBinomialInvComplement_##DEVICE##_##B##_##S(int iters) { \ + test::Benchmark(#DEVICE, RandomBinomialInvComplement(B, S)).Run(iters); \ + testing::ItemsProcessed(static_cast<int64>(B) * S * iters); \ + } \ + BENCHMARK(BM_RandomBinomialInvComplement_##DEVICE##_##B##_##S); + +#define BM_RandomBinomialRejComplement(DEVICE, B, S) \ + static void BM_RandomBinomialRejComplement_##DEVICE##_##B##_##S(int iters) { \ + test::Benchmark(#DEVICE, RandomBinomialRejComplement(B, S)).Run(iters); \ + testing::ItemsProcessed(static_cast<int64>(B) * S * iters); \ + } \ + BENCHMARK(BM_RandomBinomialRejComplement_##DEVICE##_##B##_##S); + +BM_RandomBinomialInv(cpu, 1000, 1000); +BM_RandomBinomialRej(cpu, 1000, 1000); +BM_RandomBinomialInvComplement(cpu, 1000, 1000); +BM_RandomBinomialRejComplement(cpu, 1000, 1000); +BM_RandomBinomialInv(gpu, 1000, 1000); +BM_RandomBinomialRej(gpu, 1000, 1000); +BM_RandomBinomialInvComplement(gpu, 1000, 1000); +BM_RandomBinomialRejComplement(gpu, 1000, 1000); + +} // namespace tensorflow diff --git a/tensorflow/core/ops/stateful_random_ops.cc b/tensorflow/core/ops/stateful_random_ops.cc index c351391580c..80e766cd617 100644 --- a/tensorflow/core/ops/stateful_random_ops.cc +++ b/tensorflow/core/ops/stateful_random_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" @@ -82,6 +83,29 @@ REGISTER_OP("NonDeterministicInts") return Status::OK(); }); +REGISTER_OP("StatefulRandomBinomial") + .Input("resource: resource") + .Input("algorithm: int64") + .Input("shape: S") + .Input("counts: T") + .Input("probs: T") + .Output("output: dtype") + .Attr("S: {int32, int64}") + .Attr("T: {half, float, double, int32, int64} = DT_DOUBLE") + .Attr("dtype: {half, float, double, int32, int64} = DT_INT64") + .SetShapeFn([](shape_inference::InferenceContext* c) { + using shape_inference::ShapeHandle; + + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(3), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &unused)); + + ShapeHandle out; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &out)); + c->set_output(0, out); + return Status::OK(); + }); + // Register the depracated 'StatefulStandardNormal' op. This op is a short-lived // version where the 'resource' variable also contains the algorithm tag. // It is deprecated in favor of 'StatefulStandardNormalV2'. diff --git a/tensorflow/python/kernel_tests/random/BUILD b/tensorflow/python/kernel_tests/random/BUILD index 8452982a447..f6afae97791 100644 --- a/tensorflow/python/kernel_tests/random/BUILD +++ b/tensorflow/python/kernel_tests/random/BUILD @@ -155,6 +155,23 @@ cuda_py_test( xla_enable_strict_auto_jit = True, ) +cuda_py_test( + name = "random_binomial_test", + size = "medium", + srcs = ["random_binomial_test.py"], + additional_deps = [ + ":util", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform", + "//tensorflow/python:stateful_random_ops", + ], + xla_enable_strict_auto_jit = True, +) + cuda_py_test( name = "random_poisson_test", size = "medium", @@ -169,5 +186,4 @@ cuda_py_test( "//tensorflow/python:platform", "//tensorflow/python:random_ops", ], - xla_enable_strict_auto_jit = True, ) diff --git a/tensorflow/python/kernel_tests/random/random_binomial_test.py b/tensorflow/python/kernel_tests/random/random_binomial_test.py new file mode 100644 index 00000000000..7214d7ef3c9 --- /dev/null +++ b/tensorflow/python/kernel_tests/random/random_binomial_test.py @@ -0,0 +1,120 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.stateful_random_ops.binomial.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.kernel_tests.random import util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import stateful_random_ops +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging + +# All supported dtypes for binomial(). +_SUPPORTED_DTYPES = (dtypes.float16, dtypes.float32, dtypes.float64, + dtypes.int32, dtypes.int64) + + +class RandomBinomialTest(test.TestCase): + """This is a large test due to the moments computation taking some time.""" + + def _Sampler(self, num, counts, probs, dtype, seed=None): + + def func(): + rng = stateful_random_ops.Generator(seed=seed).binomial( + shape=[10 * num], counts=counts, probs=probs, dtype=dtype) + ret = array_ops.reshape(rng, [10, num]) + ret = self.evaluate(ret) + return ret + + return func + + @test_util.run_v2_only + def testMoments(self): + try: + from scipy import stats # pylint: disable=g-import-not-at-top + except ImportError as e: + tf_logging.warn("Cannot test moments: %s", e) + return + # The moments test is a z-value test. This is the largest z-value + # we want to tolerate. Since the z-test approximates a unit normal + # distribution, it should almost definitely never exceed 6. + z_limit = 6.0 + for dt in _SUPPORTED_DTYPES: + # Test when n * p > 10, and n * p < 10 + for stride in 0, 4, 10: + for counts in (1., 10., 22., 50.): + for prob in (0.1, 0.5, 0.8): + sampler = self._Sampler(int(1e5), counts, prob, dt, seed=12345) + z_scores = util.test_moment_matching( + # Use float64 samples. + sampler().astype(np.float64), + number_moments=6, + dist=stats.binom(counts, prob), + stride=stride, + ) + self.assertAllLess(z_scores, z_limit) + + @test_util.run_v2_only + def testSeed(self): + for dt in dtypes.float16, dtypes.float32, dtypes.float64: + sx = self._Sampler(1000, counts=10., probs=0.4, dtype=dt, seed=345) + sy = self._Sampler(1000, counts=10., probs=0.4, dtype=dt, seed=345) + self.assertAllEqual(sx(), sy()) + + def testZeroShape(self): + rnd = stateful_random_ops.Generator(seed=12345).binomial([0], [], []) + self.assertEqual([0], rnd.shape.as_list()) + + def testShape(self): + rng = stateful_random_ops.Generator(seed=12345) + # Scalar parameters. + rnd = rng.binomial(shape=[10], counts=np.float32(2.), probs=np.float32(0.5)) + self.assertEqual([10], rnd.shape.as_list()) + + # Vector parameters. + rnd = rng.binomial( + shape=[10], + counts=array_ops.ones([10], dtype=np.float32), + probs=0.3 * array_ops.ones([10], dtype=np.float32)) + self.assertEqual([10], rnd.shape.as_list()) + rnd = rng.binomial( + shape=[2, 5], + counts=array_ops.ones([2], dtype=np.float32), + probs=0.4 * array_ops.ones([2], dtype=np.float32)) + self.assertEqual([2, 5], rnd.shape.as_list()) + + # Scalar counts, vector probs. + rnd = rng.binomial( + shape=[10], + counts=np.float32(5.), + probs=0.8 * array_ops.ones([10], dtype=np.float32)) + self.assertEqual([10], rnd.shape.as_list()) + + # Vector counts, scalar probs. + rnd = rng.binomial( + shape=[10], + counts=array_ops.ones([10], dtype=np.float32), + probs=np.float32(0.9)) + self.assertEqual([10], rnd.shape.as_list()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/ops/stateful_random_ops.py b/tensorflow/python/ops/stateful_random_ops.py index ca92fe006f2..9f8884224e9 100644 --- a/tensorflow/python/ops/stateful_random_ops.py +++ b/tensorflow/python/ops/stateful_random_ops.py @@ -368,6 +368,51 @@ class Generator(tracking.AutoTrackable): self.state.handle, self.algorithm, shape=shape, dtype=dtype, name=name) + def binomial(self, shape, counts, probs, dtype=dtypes.int32, name=None): + """Outputs random values from a binomial distribution. + + The generated values follow a binomial distribution with specified count and + probability of success parameters. + + Example: + + ```python + counts = [10., 20.] + # Probability of success. + probs = [0.8, 0.9] + + rng = tf.random.experimental.Generator(seed=234) + binomial_samples = rng.binomial(shape=[2], counts=counts, probs=probs) + ``` + + + Args: + shape: A 1-D integer Tensor or Python array. The shape of the output + tensor. + counts: A 0/1-D Tensor or Python value`. The counts of the binomial + distribution. + probs: A 0/1-D Tensor or Python value`. The probability of success for the + binomial distribution. + dtype: The type of the output. Default: tf.int32 + name: A name for the operation (optional). + + Returns: + A tensor of the specified shape filled with random binomial values. + """ + dtype = dtypes.as_dtype(dtype) + with ops.name_scope(name, "binomial", [shape, counts, probs]) as name: + counts = ops.convert_to_tensor(counts, name="counts") + probs = ops.convert_to_tensor(probs, name="probs") + shape_tensor = _shape_tensor(shape) + return gen_stateful_random_ops.stateful_random_binomial( + self.state.handle, + self.algorithm, + shape=shape_tensor, + counts=counts, + probs=probs, + dtype=dtype, + name=name) + # TODO(wangpeng): implement other distributions def _make_int64_keys(self, shape=()): diff --git a/tensorflow/tools/api/golden/v1/tensorflow.random.experimental.-generator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.random.experimental.-generator.pbtxt index 012e4a87079..98b3e8220ab 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.random.experimental.-generator.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.random.experimental.-generator.pbtxt @@ -16,6 +16,10 @@ tf_class { name: "__init__" argspec: "args=[\'self\', \'copy_from\', \'seed\', \'algorithm\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " } + member_method { + name: "binomial" + argspec: "args=[\'self\', \'shape\', \'counts\', \'probs\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], " + } member_method { name: "make_seeds" argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index b3da1e0af23..3902aa5fe25 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -3708,6 +3708,10 @@ tf_module { name: "StatefulPartitionedCall" argspec: "args=[\'args\', \'Tout\', \'f\', \'config\', \'config_proto\', \'executor_type\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'\', \'None\'], " } + member_method { + name: "StatefulRandomBinomial" + argspec: "args=[\'resource\', \'algorithm\', \'shape\', \'counts\', \'probs\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], " + } member_method { name: "StatefulStandardNormal" argspec: "args=[\'resource\', \'shape\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.random.experimental.-generator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.random.experimental.-generator.pbtxt index 012e4a87079..98b3e8220ab 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.random.experimental.-generator.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.random.experimental.-generator.pbtxt @@ -16,6 +16,10 @@ tf_class { name: "__init__" argspec: "args=[\'self\', \'copy_from\', \'seed\', \'algorithm\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " } + member_method { + name: "binomial" + argspec: "args=[\'self\', \'shape\', \'counts\', \'probs\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], " + } member_method { name: "make_seeds" argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index b3da1e0af23..3902aa5fe25 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -3708,6 +3708,10 @@ tf_module { name: "StatefulPartitionedCall" argspec: "args=[\'args\', \'Tout\', \'f\', \'config\', \'config_proto\', \'executor_type\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'\', \'None\'], " } + member_method { + name: "StatefulRandomBinomial" + argspec: "args=[\'resource\', \'algorithm\', \'shape\', \'counts\', \'probs\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], " + } member_method { name: "StatefulStandardNormal" argspec: "args=[\'resource\', \'shape\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "