Add tf.random_poisson(shape, lam) to tf core.
Fixes #6798 Change: 146861107
This commit is contained in:
parent
f6e6984833
commit
8b07605e45
tensorflow
@ -170,6 +170,56 @@ class PoissonTest(test.TestCase):
|
||||
self.assertEqual((6,), poisson.mode().get_shape())
|
||||
self.assertAllClose(lam_v, poisson.mode().eval())
|
||||
|
||||
def testPoissonSample(self):
|
||||
with self.test_session():
|
||||
lam_v = 4.0
|
||||
lam = constant_op.constant(lam_v)
|
||||
# Choosing `n >= (k/rtol)**2, roughly ensures our sample mean should be
|
||||
# within `k` std. deviations of actual up to rtol precision.
|
||||
n = int(100e3)
|
||||
poisson = poisson_lib.Poisson(rate=lam)
|
||||
samples = poisson.sample(n, seed=123456)
|
||||
sample_values = samples.eval()
|
||||
self.assertEqual(samples.get_shape(), (n,))
|
||||
self.assertEqual(sample_values.shape, (n,))
|
||||
self.assertAllClose(
|
||||
sample_values.mean(), stats.poisson.mean(lam_v), rtol=.01)
|
||||
self.assertAllClose(
|
||||
sample_values.var(), stats.poisson.var(lam_v), rtol=.01)
|
||||
|
||||
def testPoissonSampleMultidimensionalMean(self):
|
||||
with self.test_session():
|
||||
lam_v = np.array([np.arange(1, 51, dtype=np.float32)]) # 1 x 50
|
||||
poisson = poisson_lib.Poisson(rate=lam_v)
|
||||
# Choosing `n >= (k/rtol)**2, roughly ensures our sample mean should be
|
||||
# within `k` std. deviations of actual up to rtol precision.
|
||||
n = int(100e3)
|
||||
samples = poisson.sample(n, seed=123456)
|
||||
sample_values = samples.eval()
|
||||
self.assertEqual(samples.get_shape(), (n, 1, 50))
|
||||
self.assertEqual(sample_values.shape, (n, 1, 50))
|
||||
self.assertAllClose(
|
||||
sample_values.mean(axis=0),
|
||||
stats.poisson.mean(lam_v),
|
||||
rtol=.01,
|
||||
atol=0)
|
||||
|
||||
def testPoissonSampleMultidimensionalVariance(self):
|
||||
with self.test_session():
|
||||
lam_v = np.array([np.arange(5, 15, dtype=np.float32)]) # 1 x 10
|
||||
poisson = poisson_lib.Poisson(rate=lam_v)
|
||||
# Choosing `n >= 2 * lam * (k/rtol)**2, roughly ensures our sample
|
||||
# variance should be within `k` std. deviations of actual up to rtol
|
||||
# precision.
|
||||
n = int(300e3)
|
||||
samples = poisson.sample(n, seed=123456)
|
||||
sample_values = samples.eval()
|
||||
self.assertEqual(samples.get_shape(), (n, 1, 10))
|
||||
self.assertEqual(sample_values.shape, (n, 1, 10))
|
||||
|
||||
self.assertAllClose(
|
||||
sample_values.var(axis=0), stats.poisson.var(lam_v), rtol=.03, atol=0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -28,7 +28,7 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
||||
from tensorflow.python.ops import random_ops
|
||||
|
||||
__all__ = [
|
||||
"Poisson",
|
||||
@ -148,6 +148,10 @@ class Poisson(distribution.Distribution):
|
||||
def _mode(self):
|
||||
return math_ops.floor(self.rate)
|
||||
|
||||
def _sample_n(self, n, seed=None):
|
||||
return random_ops.random_poisson(
|
||||
self.rate, [n], dtype=self.dtype, seed=seed)
|
||||
|
||||
def _assert_valid_sample(self, x, check_integer=True):
|
||||
if not self.validate_args:
|
||||
return x
|
||||
|
@ -635,6 +635,7 @@ cc_library(
|
||||
"//tensorflow/core/kernels:parameterized_truncated_normal_op",
|
||||
"//tensorflow/core/kernels:parsing",
|
||||
"//tensorflow/core/kernels:random_ops",
|
||||
"//tensorflow/core/kernels:random_poisson_op",
|
||||
"//tensorflow/core/kernels:remote_fused_graph_ops",
|
||||
"//tensorflow/core/kernels:required",
|
||||
"//tensorflow/core/kernels:resource_variable_ops",
|
||||
|
@ -219,6 +219,16 @@ Node* RandomGamma(Graph* g, Node* shape, Node* alpha) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
Node* RandomPoisson(Graph* g, Node* shape, Node* lam) {
|
||||
Node* ret;
|
||||
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "RandomPoisson")
|
||||
.Input(shape)
|
||||
.Input(lam)
|
||||
.Attr("seed", 0)
|
||||
.Finalize(g, &ret));
|
||||
return ret;
|
||||
}
|
||||
|
||||
Node* Unary(Graph* g, const string& func, Node* input, int index) {
|
||||
Node* ret;
|
||||
TF_CHECK_OK(NodeBuilder(g->NewName("n"), func, g->op_registry())
|
||||
|
@ -113,6 +113,10 @@ Node* RandomGaussian(Graph* g, Node* input, DataType dtype);
|
||||
// Output dtype determined by alpha.
|
||||
Node* RandomGamma(Graph* g, Node* shape, Node* alpha);
|
||||
|
||||
// Generates random poisson distribution with the given shape and lam[s].
|
||||
// Output dtype determined by lam.
|
||||
Node* RandomPoisson(Graph* g, Node* shape, Node* lam);
|
||||
|
||||
// Generates random parameters from the truncated standard normal distribution
|
||||
// of the nput shape
|
||||
Node* TruncatedNormal(Graph* g, Node* input, DataType dtype);
|
||||
|
@ -3361,6 +3361,33 @@ tf_cuda_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "random_poisson_op",
|
||||
prefix = "random_poisson_op",
|
||||
deps = [
|
||||
":random_ops",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:random_ops_op_lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "random_poisson_op_test",
|
||||
size = "small",
|
||||
srcs = ["random_poisson_op_test.cc"],
|
||||
deps = [
|
||||
":ops_util",
|
||||
":random_poisson_op",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "word2vec_kernels",
|
||||
prefix = "word2vec_kernels",
|
||||
|
@ -541,7 +541,7 @@ TF_CALL_int64(REGISTER_INT);
|
||||
PhiloxRandomOp< \
|
||||
GPUDevice, \
|
||||
random::TruncatedNormalDistribution< \
|
||||
random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >)
|
||||
random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >);
|
||||
|
||||
#define REGISTER_INT(IntType) \
|
||||
REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
|
||||
|
357
tensorflow/core/kernels/random_poisson_op.cc
Normal file
357
tensorflow/core/kernels/random_poisson_op.cc
Normal file
@ -0,0 +1,357 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// See docs in ../ops/random_ops.cc.
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "tensorflow/core/kernels/random_poisson_op.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/lib/random/random_distributions.h"
|
||||
#include "tensorflow/core/lib/random/simple_philox.h"
|
||||
#include "tensorflow/core/util/guarded_philox_random.h"
|
||||
#include "tensorflow/core/util/work_sharder.h"
|
||||
|
||||
#if EIGEN_COMP_GNUC && __cplusplus > 199711L
|
||||
#define DISABLE_FLOAT_EQUALITY_WARNING \
|
||||
_Pragma("GCC diagnostic push") \
|
||||
_Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
|
||||
#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
|
||||
#else
|
||||
#define DISABLE_FLOAT_EQUALITY_WARNING
|
||||
#define ENABLE_FLOAT_EQUALITY_WARNING
|
||||
#endif
|
||||
|
||||
#define UNIFORM(X) \
|
||||
if (uniform_remaining == 0) { \
|
||||
uniform_remaining = Uniform::kResultElementCount; \
|
||||
uniform_result = uniform(&gen); \
|
||||
} \
|
||||
uniform_remaining--; \
|
||||
CT X = uniform_result[uniform_remaining]
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
static constexpr int kReservedSamplesPerOutput = 256;
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
|
||||
// We will compute half-precision Poisson samples with float precision
|
||||
// intermediate calculations.
|
||||
template <typename T>
|
||||
struct PoissonComputeType {
|
||||
typedef T ComputeType;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PoissonComputeType<Eigen::half> {
|
||||
typedef float ComputeType;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace functor {
|
||||
|
||||
template <typename Device, typename T>
|
||||
struct PoissonFunctor {
|
||||
void operator()(OpKernelContext* ctx, const Device& d, const T* rate_flat,
|
||||
int num_rate, int num_samples,
|
||||
const random::PhiloxRandom& rng, T* samples_flat);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct PoissonFunctor<CPUDevice, T> {
|
||||
void operator()(OpKernelContext* ctx, const CPUDevice& d, const T* rate_flat,
|
||||
int num_rate, int num_samples,
|
||||
const random::PhiloxRandom& rng, T* samples_flat) {
|
||||
// Two different algorithms are employed, depending on the size of
|
||||
// rate.
|
||||
// If rate < 10, we use an algorithm attributed to Knuth:
|
||||
// Seminumerical Algorithms. Art of Computer Programming, Volume 2.
|
||||
//
|
||||
// This algorithm runs in O(rate) time, and will require O(rate)
|
||||
// uniform
|
||||
// variates.
|
||||
//
|
||||
// If rate >= 10 we use a transformation-rejection algorithm from
|
||||
// pairs
|
||||
// of uniform random variables due to Hormann.
|
||||
// http://www.sciencedirect.com/science/article/pii/0167668793909974
|
||||
//
|
||||
// The algorithm has an acceptance rate of ~89% for the smallest rate
|
||||
// (~10),
|
||||
// and higher accept rates for higher rate, so runtime is
|
||||
// O(NumRate * NumSamples * k) with k ~ 1 / 0.89.
|
||||
//
|
||||
// We partition work first across rates then across
|
||||
// samples-per-rate to
|
||||
// avoid a couple flops which can be done on a per-rate basis.
|
||||
|
||||
typedef random::UniformDistribution<random::PhiloxRandom, CT> Uniform;
|
||||
|
||||
auto DoWork = [num_samples, num_rate, &rng, samples_flat, rate_flat](
|
||||
int start_output, int limit_output) {
|
||||
// Capturing "rng" by value would only make a copy for the _shared_
|
||||
// lambda. Since we want to let each worker have its own copy, we pass
|
||||
// "rng" by reference and explicitly do a copy assignment.
|
||||
|
||||
Uniform uniform;
|
||||
typename Uniform::ResultType uniform_result;
|
||||
for (int64 output_idx = start_output; output_idx < limit_output;
|
||||
/* output_idx incremented within inner loop below */) {
|
||||
const int64 rate_idx = output_idx / num_samples;
|
||||
|
||||
// Several calculations can be done on a per-rate basis.
|
||||
const CT rate = CT(rate_flat[rate_idx]);
|
||||
|
||||
auto samples_rate_output = samples_flat + rate_idx;
|
||||
|
||||
if (rate < CT(10)) {
|
||||
// Knuth's algorithm for generating Poisson random variates.
|
||||
// Given a Poisson process, the time between events is exponentially
|
||||
// distributed. If we have a Poisson process with rate lambda, then,
|
||||
// the time between events is distributed Exp(lambda). If X ~
|
||||
// Uniform(0, 1), then Y ~ Exp(lambda), where Y = -log(X) / lambda.
|
||||
// Thus to simulate a Poisson draw, we can draw X_i ~ Exp(lambda),
|
||||
// and N ~ Poisson(lambda), where N is the least number such that
|
||||
// \sum_i^N X_i > 1.
|
||||
const CT exp_neg_rate = Eigen::numext::exp(-rate);
|
||||
|
||||
// Compute the rest of the samples for the current rate value.
|
||||
for (int64 sample_idx = output_idx % num_samples;
|
||||
sample_idx < num_samples && output_idx < limit_output;
|
||||
sample_idx++, output_idx++) {
|
||||
random::PhiloxRandom gen = rng;
|
||||
gen.Skip(kReservedSamplesPerOutput * output_idx);
|
||||
int16 uniform_remaining = 0;
|
||||
|
||||
CT prod = 1;
|
||||
CT x = 0;
|
||||
|
||||
// Keep trying until we surpass e^(-rate). This will take
|
||||
// expected time proportional to rate.
|
||||
while (true) {
|
||||
UNIFORM(u);
|
||||
prod = prod * u;
|
||||
if (prod <= exp_neg_rate) {
|
||||
samples_rate_output[sample_idx * num_rate] = T(x);
|
||||
break;
|
||||
}
|
||||
x += 1;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
// Transformed rejection due to Hormann.
|
||||
//
|
||||
// Given a CDF F(x), and G(x), a dominating distribution chosen such
|
||||
// that it is close to the inverse CDF F^-1(x), compute the following
|
||||
// steps:
|
||||
//
|
||||
// 1) Generate U and V, two independent random variates. Set U = U - 0.5
|
||||
// (this step isn't strictly necessary, but is done to make some
|
||||
// calculations symmetric and convenient. Henceforth, G is defined on
|
||||
// [-0.5, 0.5]).
|
||||
//
|
||||
// 2) If V <= alpha * F'(G(U)) * G'(U), return floor(G(U)), else return
|
||||
// to step 1. alpha is the acceptance probability of the rejection
|
||||
// algorithm.
|
||||
//
|
||||
// For more details on transformed rejection, see:
|
||||
// http://citeseer.ist.psu.edu/viewdoc/citations;jsessionid=1BEB35946CC807879F55D42512E5490C?doi=10.1.1.48.3054.
|
||||
//
|
||||
// The dominating distribution in this case:
|
||||
//
|
||||
// G(u) = (2 * a / (2 - |u|) + b) * u + c
|
||||
|
||||
using Eigen::numext::log;
|
||||
const CT log_rate = log(rate);
|
||||
|
||||
// Constants used to define the dominating distribution. Names taken
|
||||
// from Hormann's paper. Constants were chosen to define the tightest
|
||||
// G(u) for the inverse Poisson CDF.
|
||||
const CT b = CT(0.931) + CT(2.53) * Eigen::numext::sqrt(rate);
|
||||
const CT a = CT(-0.059) + CT(0.02483) * b;
|
||||
|
||||
// This is the inverse acceptance rate. At a minimum (when rate = 10),
|
||||
// this corresponds to ~75% acceptance. As the rate becomes larger, this
|
||||
// approaches ~89%.
|
||||
const CT inv_alpha = CT(1.1239) + CT(1.1328) / (b - CT(3.4));
|
||||
|
||||
// Compute the rest of the samples for the current rate value.
|
||||
for (int64 sample_idx = output_idx % num_samples;
|
||||
sample_idx < num_samples && output_idx < limit_output;
|
||||
sample_idx++, output_idx++) {
|
||||
random::PhiloxRandom gen = rng;
|
||||
gen.Skip(kReservedSamplesPerOutput * output_idx);
|
||||
int16 uniform_remaining = 0;
|
||||
|
||||
while (true) {
|
||||
UNIFORM(u);
|
||||
u -= CT(0.5);
|
||||
UNIFORM(v);
|
||||
|
||||
CT u_shifted = CT(0.5) - Eigen::numext::abs(u);
|
||||
CT k = Eigen::numext::floor((CT(2) * a / u_shifted + b) * u + rate +
|
||||
CT(0.43));
|
||||
|
||||
// When alpha * f(G(U)) * G'(U) is close to 1, it is possible to
|
||||
// find a rectangle (-u_r, u_r) x (0, v_r) under the curve, such
|
||||
// that if v <= v_r and |u| <= u_r, then we can accept.
|
||||
// Here v_r = 0.9227 - 3.6224 / (b - 2) and u_r = 0.43.
|
||||
if (u_shifted >= CT(0.07) &&
|
||||
v <= CT(0.9277) - CT(3.6224) / (b - CT(2))) {
|
||||
samples_rate_output[sample_idx * num_rate] = T(k);
|
||||
break;
|
||||
}
|
||||
|
||||
if (k < 0 || (u_shifted < CT(0.013) && v > u_shifted)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// The expression below is equivalent to the computation of step 2)
|
||||
// in transformed rejection (v <= alpha * F'(G(u)) * G'(u)).
|
||||
CT s = log(v * inv_alpha / (a / (u_shifted * u_shifted) + b));
|
||||
CT t = -rate + k * log_rate - Eigen::numext::lgamma(k + 1);
|
||||
if (s <= t) {
|
||||
samples_rate_output[sample_idx * num_rate] = T(k);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// This will depend on rate.
|
||||
// For rate < 10, on average, O(rate) calls to uniform are
|
||||
// needed, with that
|
||||
// many multiplies. ~10 uniform calls on average with ~25 cost op calls.
|
||||
//
|
||||
// Very roughly, for rate >= 10, the single call to log + call to
|
||||
// lgamma
|
||||
// occur for ~60 percent of samples.
|
||||
// 2 x 100 (64-bit cycles per log) * 0.62 = ~124
|
||||
// Additionally, there are ~10 other ops (+, *, /, ...) at 3-6 cycles each:
|
||||
// 40 * .62 = ~25.
|
||||
//
|
||||
// Finally, there are several other ops that are done every loop along with
|
||||
// 2 uniform generations along with 5 other ops at 3-6 cycles each.
|
||||
// ~15 / .89 = ~16
|
||||
//
|
||||
// In total this should be ~165 + 2 * Uniform::kElementCost.
|
||||
// We assume that half the tensor has rate < 10, so on average 6
|
||||
// uniform's
|
||||
// will be needed. We will upper bound the other op cost by the one for
|
||||
// rate > 10.
|
||||
static const int kElementCost = 165 + 6 * Uniform::kElementCost +
|
||||
6 * random::PhiloxRandom::kElementCost;
|
||||
auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
|
||||
Shard(worker_threads.num_threads, worker_threads.workers,
|
||||
num_rate * num_samples, kElementCost, DoWork);
|
||||
}
|
||||
|
||||
private:
|
||||
typedef typename PoissonComputeType<T>::ComputeType CT;
|
||||
};
|
||||
|
||||
} // namespace functor
|
||||
|
||||
namespace {
|
||||
|
||||
// Samples from one or more Poisson distributions.
|
||||
template <typename T>
|
||||
class RandomPoissonOp : public OpKernel {
|
||||
public:
|
||||
explicit RandomPoissonOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||
OP_REQUIRES_OK(context, generator_.Init(context));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
const Tensor& shape_t = ctx->input(0);
|
||||
const Tensor& rate_t = ctx->input(1);
|
||||
|
||||
OP_REQUIRES(ctx,
|
||||
TensorShapeUtils::IsVector(shape_t.shape()) &&
|
||||
(shape_t.dtype() == DataType::DT_INT32 ||
|
||||
shape_t.dtype() == DataType::DT_INT64),
|
||||
errors::InvalidArgument(
|
||||
"shape must be a vector of {int32,int64}, got shape: ",
|
||||
shape_t.DebugString()));
|
||||
TensorShape samples_shape;
|
||||
if (shape_t.dtype() == DataType::DT_INT32) {
|
||||
auto vec = shape_t.flat<int32>();
|
||||
OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(vec.data(), vec.size(),
|
||||
&samples_shape));
|
||||
} else if (shape_t.dtype() == DataType::DT_INT64) {
|
||||
auto vec = shape_t.flat<int64>();
|
||||
OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(vec.data(), vec.size(),
|
||||
&samples_shape));
|
||||
}
|
||||
const int64 num_samples = samples_shape.num_elements();
|
||||
OP_REQUIRES(ctx, num_samples > 0,
|
||||
errors::InvalidArgument(
|
||||
"Input shape should have non-zero element count, got: ",
|
||||
num_samples));
|
||||
|
||||
samples_shape.AppendShape(rate_t.shape());
|
||||
// Allocate output samples.
|
||||
Tensor* samples_t = nullptr;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, samples_shape, &samples_t));
|
||||
|
||||
const auto rate_flat = rate_t.flat<T>().data();
|
||||
const int64 num_rate = rate_t.NumElements();
|
||||
OP_REQUIRES(
|
||||
ctx, num_rate > 0,
|
||||
errors::InvalidArgument(
|
||||
"Input rate should have non-zero element count, got: ", num_rate));
|
||||
auto samples_flat = samples_t->flat<T>().data();
|
||||
random::PhiloxRandom rng = generator_.ReserveRandomOutputs(
|
||||
num_samples * num_rate, kReservedSamplesPerOutput);
|
||||
|
||||
functor::PoissonFunctor<CPUDevice, T>()(ctx, ctx->eigen_device<CPUDevice>(),
|
||||
rate_flat, num_rate, num_samples,
|
||||
rng, samples_flat);
|
||||
}
|
||||
|
||||
private:
|
||||
GuardedPhiloxRandom generator_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(RandomPoissonOp);
|
||||
};
|
||||
} // namespace
|
||||
|
||||
#undef UNIFORM
|
||||
|
||||
#define REGISTER(TYPE) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("RandomPoisson").Device(DEVICE_CPU).TypeConstraint<TYPE>("dtype"), \
|
||||
RandomPoissonOp<TYPE>);
|
||||
|
||||
TF_CALL_half(REGISTER);
|
||||
TF_CALL_float(REGISTER);
|
||||
TF_CALL_double(REGISTER);
|
||||
|
||||
#undef REGISTER
|
||||
|
||||
} // end namespace tensorflow
|
31
tensorflow/core/kernels/random_poisson_op.h
Normal file
31
tensorflow/core/kernels/random_poisson_op.h
Normal file
@ -0,0 +1,31 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_KERNELS_RANDOM_POISSON_OP_H_
|
||||
#define TENSORFLOW_KERNELS_RANDOM_POISSON_OP_H_
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace functor {
|
||||
|
||||
// Generic helper functor for the Random Poisson Op.
|
||||
template <typename Device, typename T>
|
||||
struct PoissonFunctor;
|
||||
|
||||
} // namespace functor
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_KERNELS_RANDOM_POISSON_OP_H_
|
82
tensorflow/core/kernels/random_poisson_op_test.cc
Normal file
82
tensorflow/core/kernels/random_poisson_op_test.cc
Normal file
@ -0,0 +1,82 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <random>
|
||||
|
||||
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
Tensor VecShape(int64 v) {
|
||||
if (v >= std::numeric_limits<int32>::max()) {
|
||||
Tensor shape(DT_INT64, TensorShape({1}));
|
||||
shape.vec<int64>()(0) = v;
|
||||
return shape;
|
||||
} else {
|
||||
Tensor shape(DT_INT32, TensorShape({1}));
|
||||
shape.vec<int32>()(0) = v;
|
||||
return shape;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor VecLam32(int64 n, int magnitude) {
|
||||
std::mt19937 gen(0x12345);
|
||||
std::uniform_real_distribution<float> dist(0.0, 1.0);
|
||||
Tensor lams(DT_FLOAT, TensorShape({n}));
|
||||
for (int i = 0; i < n; i++) {
|
||||
// Generate in range (magnitude, 2 * magnitude)
|
||||
lams.vec<float>()(i) = magnitude * (1 + dist(gen));
|
||||
}
|
||||
return lams;
|
||||
}
|
||||
|
||||
Tensor VecLam64(int64 n, int magnitude) {
|
||||
std::mt19937 gen(0x12345);
|
||||
std::uniform_real_distribution<double> dist(0.0, 1.0);
|
||||
Tensor lams(DT_DOUBLE, TensorShape({n}));
|
||||
for (int i = 0; i < n; i++) {
|
||||
// Generate in range (magnitude, 2 * magnitude)
|
||||
lams.vec<double>()(i) = magnitude * (1 + dist(gen));
|
||||
}
|
||||
return lams;
|
||||
}
|
||||
|
||||
#define BM_Poisson(DEVICE, BITS, MAGNITUDE) \
|
||||
static void BM_##DEVICE##_RandomPoisson_lam_##MAGNITUDE##_##BITS( \
|
||||
int iters, int nsamp, int nlam) { \
|
||||
testing::ItemsProcessed(static_cast<int64>(iters) * nsamp * nlam); \
|
||||
Graph* g = new Graph(OpRegistry::Global()); \
|
||||
test::graph::RandomPoisson( \
|
||||
g, test::graph::Constant(g, VecShape(nsamp)), \
|
||||
test::graph::Constant(g, VecLam##BITS(nlam, MAGNITUDE))); \
|
||||
test::Benchmark(#DEVICE, g).Run(iters); \
|
||||
} \
|
||||
BENCHMARK(BM_##DEVICE##_RandomPoisson_lam_##MAGNITUDE##_##BITS) \
|
||||
->RangePair(1, 64, 2, 50);
|
||||
|
||||
BM_Poisson(cpu, 32, 1);
|
||||
BM_Poisson(cpu, 32, 8);
|
||||
BM_Poisson(cpu, 32, 32);
|
||||
|
||||
BM_Poisson(cpu, 64, 1);
|
||||
BM_Poisson(cpu, 64, 8);
|
||||
BM_Poisson(cpu, 64, 32);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -276,4 +276,48 @@ output: A tensor with shape `shape + shape(alpha)`. Each slice
|
||||
`alpha[i0, i1, ...iN]`. The dtype of the output matches the dtype of alpha.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("RandomPoisson")
|
||||
.SetIsStateful()
|
||||
.Input("shape: S")
|
||||
.Input("rate: dtype")
|
||||
.Output("output: dtype")
|
||||
.Attr("seed: int = 0")
|
||||
.Attr("seed2: int = 0")
|
||||
.Attr("S: {int32, int64}")
|
||||
.Attr("dtype: {half, float, double}")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle out;
|
||||
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
|
||||
TF_RETURN_IF_ERROR(c->Concatenate(out, c->input(1), &out));
|
||||
c->set_output(0, out);
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
Outputs random values from the Poisson distribution(s) described by rate.
|
||||
|
||||
This op uses two algorithms, depending on rate. If rate >= 10, then
|
||||
the algorithm by Hormann is used to acquire samples via
|
||||
transformation-rejection.
|
||||
See http://www.sciencedirect.com/science/article/pii/0167668793909974.
|
||||
|
||||
Otherwise, Knuth's algorithm is used to acquire samples via multiplying uniform
|
||||
random variables.
|
||||
See Donald E. Knuth (1969). Seminumerical Algorithms. The Art of Computer
|
||||
Programming, Volume 2. Addison Wesley
|
||||
|
||||
shape: 1-D integer tensor. Shape of independent samples to draw from each
|
||||
distribution described by the shape parameters given in rate.
|
||||
rate: A tensor in which each scalar is a "rate" parameter describing the
|
||||
associated poisson distribution.
|
||||
seed: If either `seed` or `seed2` are set to be non-zero, the random number
|
||||
generator is seeded by the given seed. Otherwise, it is seeded by a
|
||||
random seed.
|
||||
seed2: A second seed to avoid seed collision.
|
||||
|
||||
output: A tensor with shape `shape + shape(rate)`. Each slice
|
||||
`[:, ..., :, i0, i1, ...iN]` contains the samples drawn for
|
||||
`rate[i0, i1, ...iN]`. The dtype of the output matches the dtype of
|
||||
rate.
|
||||
)doc");
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -53,4 +53,19 @@ TEST(RandomOpsTest, RandomGamma_ShapeFn) {
|
||||
INFER_OK(op, "[3];[]", "[1,2,3]");
|
||||
}
|
||||
|
||||
TEST(RandomOpsTest, RandomPoisson_ShapeFn) {
|
||||
ShapeInferenceTestOp op("RandomPoisson");
|
||||
op.input_tensors.resize(2);
|
||||
|
||||
INFER_OK(op, "?;?", "?");
|
||||
INFER_OK(op, "?;[3]", "?");
|
||||
INFER_OK(op, "[1];?", "?");
|
||||
INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[1,2];[3,4]");
|
||||
Tensor shape = test::AsTensor<int32>({1, 2, 3});
|
||||
op.input_tensors[0] = &shape;
|
||||
INFER_OK(op, "[3];[4,?]", "[1,2,3,d1_0,d1_1]");
|
||||
INFER_OK(op, "[3];[4,5]", "[1,2,3,d1_0,d1_1]");
|
||||
INFER_OK(op, "[3];[]", "[1,2,3]");
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
@ -96,6 +96,7 @@ print(sess.run(var))
|
||||
@@random_crop
|
||||
@@multinomial
|
||||
@@random_gamma
|
||||
@@random_poisson
|
||||
@@set_random_seed
|
||||
"""
|
||||
|
||||
|
@ -2112,6 +2112,21 @@ cuda_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "random_poisson_test",
|
||||
size = "medium",
|
||||
srcs = ["random_poisson_test.py"],
|
||||
additional_deps = [
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:random_ops",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "rnn_test",
|
||||
size = "medium",
|
||||
|
178
tensorflow/python/kernel_tests/random_poisson_test.py
Normal file
178
tensorflow/python/kernel_tests/random_poisson_test.py
Normal file
@ -0,0 +1,178 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tensorflow.ops.random_ops.random_poisson."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging
|
||||
|
||||
|
||||
class RandomPoissonTest(test.TestCase):
|
||||
"""This is a large test due to the moments computation taking some time."""
|
||||
|
||||
def _Sampler(self, num, lam, dtype, use_gpu, seed=None):
|
||||
|
||||
def func():
|
||||
with self.test_session(use_gpu=use_gpu, graph=ops.Graph()) as sess:
|
||||
rng = random_ops.random_poisson(lam, [num], dtype=dtype, seed=seed)
|
||||
ret = np.empty([10, num])
|
||||
for i in xrange(10):
|
||||
ret[i, :] = sess.run(rng)
|
||||
return ret
|
||||
|
||||
return func
|
||||
|
||||
# TODO(srvasude): Factor this out along with the corresponding moment testing
|
||||
# method in random_gamma_test into a single library.
|
||||
def testMoments(self):
|
||||
try:
|
||||
from scipy import stats # pylint: disable=g-import-not-at-top
|
||||
except ImportError as e:
|
||||
tf_logging.warn("Cannot test moments: %s", e)
|
||||
return
|
||||
# The moments test is a z-value test. This is the largest z-value
|
||||
# we want to tolerate. Since the z-test approximates a unit normal
|
||||
# distribution, it should almost definitely never exceed 6.
|
||||
z_limit = 6.0
|
||||
for dt in dtypes.float16, dtypes.float32, dtypes.float64:
|
||||
# Test when lam < 10 and when lam >= 10
|
||||
for stride in 0, 4, 10:
|
||||
for lam in (3., 20):
|
||||
max_moment = 5
|
||||
sampler = self._Sampler(10000, lam, dt, use_gpu=False, seed=12345)
|
||||
moments = [0] * (max_moment + 1)
|
||||
moments_sample_count = [0] * (max_moment + 1)
|
||||
x = np.array(sampler().flat) # sampler does 10x samples
|
||||
for k in range(len(x)):
|
||||
moment = 1.
|
||||
for i in range(max_moment + 1):
|
||||
index = k + i * stride
|
||||
if index >= len(x):
|
||||
break
|
||||
moments[i] += moment
|
||||
moments_sample_count[i] += 1
|
||||
moment *= x[index]
|
||||
for i in range(max_moment + 1):
|
||||
moments[i] /= moments_sample_count[i]
|
||||
for i in range(1, max_moment + 1):
|
||||
g = stats.poisson(lam)
|
||||
if stride == 0:
|
||||
moments_i_mean = g.moment(i)
|
||||
moments_i_squared = g.moment(2 * i)
|
||||
else:
|
||||
moments_i_mean = pow(g.moment(1), i)
|
||||
moments_i_squared = pow(g.moment(2), i)
|
||||
moments_i_var = (
|
||||
moments_i_squared - moments_i_mean * moments_i_mean)
|
||||
# Assume every operation has a small numerical error.
|
||||
# It takes i multiplications to calculate one i-th moment.
|
||||
error_per_moment = i * 1e-6
|
||||
total_variance = (
|
||||
moments_i_var / moments_sample_count[i] + error_per_moment)
|
||||
if not total_variance:
|
||||
total_variance = 1e-10
|
||||
# z_test is approximately a unit normal distribution.
|
||||
z_test = abs(
|
||||
(moments[i] - moments_i_mean) / np.sqrt(total_variance))
|
||||
self.assertLess(z_test, z_limit)
|
||||
|
||||
# Checks that the CPU and GPU implementation returns the same results,
|
||||
# given the same random seed
|
||||
def testCPUGPUMatch(self):
|
||||
for dt in dtypes.float16, dtypes.float32, dtypes.float64:
|
||||
results = {}
|
||||
for use_gpu in [False, True]:
|
||||
sampler = self._Sampler(1000, 1.0, dt, use_gpu=use_gpu, seed=12345)
|
||||
results[use_gpu] = sampler()
|
||||
if dt == dtypes.float16:
|
||||
self.assertAllClose(results[False], results[True], rtol=1e-3, atol=1e-3)
|
||||
else:
|
||||
self.assertAllClose(results[False], results[True], rtol=1e-6, atol=1e-6)
|
||||
|
||||
def testSeed(self):
|
||||
for dt in dtypes.float16, dtypes.float32, dtypes.float64:
|
||||
sx = self._Sampler(1000, 1.0, dt, use_gpu=True, seed=345)
|
||||
sy = self._Sampler(1000, 1.0, dt, use_gpu=True, seed=345)
|
||||
self.assertAllEqual(sx(), sy())
|
||||
|
||||
def testNoCSE(self):
|
||||
"""CSE = constant subexpression eliminator.
|
||||
|
||||
SetIsStateful() should prevent two identical random ops from getting
|
||||
merged.
|
||||
"""
|
||||
for dtype in dtypes.float16, dtypes.float32, dtypes.float64:
|
||||
with self.test_session(use_gpu=True):
|
||||
rnd1 = random_ops.random_poisson(2.0, [24], dtype=dtype)
|
||||
rnd2 = random_ops.random_poisson(2.0, [24], dtype=dtype)
|
||||
diff = rnd2 - rnd1
|
||||
# Since these are all positive integers, the norm will
|
||||
# be at least 1 if they are different.
|
||||
self.assertGreaterEqual(np.linalg.norm(diff.eval()), 1)
|
||||
|
||||
def testShape(self):
|
||||
# Fully known shape.
|
||||
rnd = random_ops.random_poisson(2.0, [150], seed=12345)
|
||||
self.assertEqual([150], rnd.get_shape().as_list())
|
||||
rnd = random_ops.random_poisson(
|
||||
lam=array_ops.ones([1, 2, 3]),
|
||||
shape=[150],
|
||||
seed=12345)
|
||||
self.assertEqual([150, 1, 2, 3], rnd.get_shape().as_list())
|
||||
rnd = random_ops.random_poisson(
|
||||
lam=array_ops.ones([1, 2, 3]),
|
||||
shape=[20, 30],
|
||||
seed=12345)
|
||||
self.assertEqual([20, 30, 1, 2, 3], rnd.get_shape().as_list())
|
||||
rnd = random_ops.random_poisson(
|
||||
lam=array_ops.placeholder(dtypes.float32, shape=(2,)),
|
||||
shape=[12],
|
||||
seed=12345)
|
||||
self.assertEqual([12, 2], rnd.get_shape().as_list())
|
||||
# Partially known shape.
|
||||
rnd = random_ops.random_poisson(
|
||||
lam=array_ops.ones([7, 3]),
|
||||
shape=array_ops.placeholder(dtypes.int32, shape=(1,)),
|
||||
seed=12345)
|
||||
self.assertEqual([None, 7, 3], rnd.get_shape().as_list())
|
||||
rnd = random_ops.random_poisson(
|
||||
lam=array_ops.ones([9, 6]),
|
||||
shape=array_ops.placeholder(dtypes.int32, shape=(3,)),
|
||||
seed=12345)
|
||||
self.assertEqual([None, None, None, 9, 6], rnd.get_shape().as_list())
|
||||
# Unknown shape.
|
||||
rnd = random_ops.random_poisson(
|
||||
lam=array_ops.placeholder(dtypes.float32),
|
||||
shape=array_ops.placeholder(dtypes.int32),
|
||||
seed=12345)
|
||||
self.assertIs(None, rnd.get_shape().ndims)
|
||||
rnd = random_ops.random_poisson(
|
||||
lam=array_ops.placeholder(dtypes.float32),
|
||||
shape=[50],
|
||||
seed=12345)
|
||||
self.assertIs(None, rnd.get_shape().ndims)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -282,6 +282,7 @@ ParseSingleSequenceExample
|
||||
|
||||
# random_ops
|
||||
RandomGamma
|
||||
RandomPoisson
|
||||
RandomUniform
|
||||
RandomUniformInt
|
||||
RandomShuffle
|
||||
|
@ -71,10 +71,8 @@ def random_normal(shape,
|
||||
mean_tensor = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
|
||||
stddev_tensor = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
|
||||
seed1, seed2 = random_seed.get_seed(seed)
|
||||
rnd = gen_random_ops._random_standard_normal(shape_tensor,
|
||||
dtype,
|
||||
seed=seed1,
|
||||
seed2=seed2)
|
||||
rnd = gen_random_ops._random_standard_normal(
|
||||
shape_tensor, dtype, seed=seed1, seed2=seed2)
|
||||
mul = rnd * stddev_tensor
|
||||
value = math_ops.add(mul, mean_tensor, name=name)
|
||||
return value
|
||||
@ -125,13 +123,14 @@ def parameterized_truncated_normal(shape,
|
||||
minvals_tensor = ops.convert_to_tensor(minvals, dtype=dtype, name="minvals")
|
||||
maxvals_tensor = ops.convert_to_tensor(maxvals, dtype=dtype, name="maxvals")
|
||||
seed1, seed2 = random_seed.get_seed(seed)
|
||||
rnd = gen_random_ops._parameterized_truncated_normal(shape_tensor,
|
||||
means_tensor,
|
||||
stddevs_tensor,
|
||||
minvals_tensor,
|
||||
maxvals_tensor,
|
||||
seed=seed1,
|
||||
seed2=seed2)
|
||||
rnd = gen_random_ops._parameterized_truncated_normal(
|
||||
shape_tensor,
|
||||
means_tensor,
|
||||
stddevs_tensor,
|
||||
minvals_tensor,
|
||||
maxvals_tensor,
|
||||
seed=seed1,
|
||||
seed2=seed2)
|
||||
return rnd
|
||||
|
||||
|
||||
@ -168,10 +167,8 @@ def truncated_normal(shape,
|
||||
mean_tensor = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
|
||||
stddev_tensor = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
|
||||
seed1, seed2 = random_seed.get_seed(seed)
|
||||
rnd = gen_random_ops._truncated_normal(shape_tensor,
|
||||
dtype,
|
||||
seed=seed1,
|
||||
seed2=seed2)
|
||||
rnd = gen_random_ops._truncated_normal(
|
||||
shape_tensor, dtype, seed=seed1, seed2=seed2)
|
||||
mul = rnd * stddev_tensor
|
||||
value = math_ops.add(mul, mean_tensor, name=name)
|
||||
return value
|
||||
@ -232,17 +229,11 @@ def random_uniform(shape,
|
||||
maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
|
||||
seed1, seed2 = random_seed.get_seed(seed)
|
||||
if dtype.is_integer:
|
||||
return gen_random_ops._random_uniform_int(shape,
|
||||
minval,
|
||||
maxval,
|
||||
seed=seed1,
|
||||
seed2=seed2,
|
||||
name=name)
|
||||
return gen_random_ops._random_uniform_int(
|
||||
shape, minval, maxval, seed=seed1, seed2=seed2, name=name)
|
||||
else:
|
||||
rnd = gen_random_ops._random_uniform(shape,
|
||||
dtype,
|
||||
seed=seed1,
|
||||
seed2=seed2)
|
||||
rnd = gen_random_ops._random_uniform(
|
||||
shape, dtype, seed=seed1, seed2=seed2)
|
||||
return math_ops.add(rnd * (maxval - minval), minval, name=name)
|
||||
|
||||
|
||||
@ -275,10 +266,8 @@ def random_shuffle(value, seed=None, name=None):
|
||||
dimension.
|
||||
"""
|
||||
seed1, seed2 = random_seed.get_seed(seed)
|
||||
return gen_random_ops._random_shuffle(value,
|
||||
seed=seed1,
|
||||
seed2=seed2,
|
||||
name=name)
|
||||
return gen_random_ops._random_shuffle(
|
||||
value, seed=seed1, seed2=seed2, name=name)
|
||||
|
||||
|
||||
def random_crop(value, size, seed=None, name=None):
|
||||
@ -349,10 +338,8 @@ def multinomial(logits, num_samples, seed=None, name=None):
|
||||
with ops.name_scope(name, "multinomial", [logits]):
|
||||
logits = ops.convert_to_tensor(logits, name="logits")
|
||||
seed1, seed2 = random_seed.get_seed(seed)
|
||||
return gen_random_ops.multinomial(logits,
|
||||
num_samples,
|
||||
seed=seed1,
|
||||
seed2=seed2)
|
||||
return gen_random_ops.multinomial(
|
||||
logits, num_samples, seed=seed1, seed2=seed2)
|
||||
|
||||
|
||||
ops.NotDifferentiable("Multinomial")
|
||||
@ -426,15 +413,52 @@ def random_gamma(shape,
|
||||
with ops.name_scope(name, "random_gamma", [shape, alpha, beta]):
|
||||
shape = ops.convert_to_tensor(shape, name="shape", dtype=dtypes.int32)
|
||||
alpha = ops.convert_to_tensor(alpha, name="alpha", dtype=dtype)
|
||||
beta = ops.convert_to_tensor(beta if beta is not None else 1,
|
||||
name="beta",
|
||||
dtype=dtype)
|
||||
beta = ops.convert_to_tensor(
|
||||
beta if beta is not None else 1, name="beta", dtype=dtype)
|
||||
alpha_broadcast = alpha + array_ops.zeros_like(beta)
|
||||
seed1, seed2 = random_seed.get_seed(seed)
|
||||
return gen_random_ops._random_gamma(shape,
|
||||
alpha_broadcast,
|
||||
seed=seed1,
|
||||
seed2=seed2) / beta
|
||||
return gen_random_ops._random_gamma(
|
||||
shape, alpha_broadcast, seed=seed1, seed2=seed2) / beta
|
||||
|
||||
|
||||
ops.NotDifferentiable("RandomGamma")
|
||||
|
||||
|
||||
def random_poisson(lam, shape, dtype=dtypes.float32, seed=None, name=None):
|
||||
"""Draws `shape` samples from each of the given Poisson distribution(s).
|
||||
|
||||
`lam` is the rate parameter describing the distribution(s).
|
||||
|
||||
Example:
|
||||
|
||||
samples = tf.random_poisson([0.5, 1.5], [10])
|
||||
# samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents
|
||||
# the samples drawn from each distribution
|
||||
|
||||
samples = tf.random_poisson([12.2, 3.3], [7, 5])
|
||||
# samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1]
|
||||
# represents the 7x5 samples drawn from each of the two distributions
|
||||
|
||||
Args:
|
||||
lam: A Tensor or Python value or N-D array of type `dtype`.
|
||||
`lam` provides the rate parameter(s) describing the poisson
|
||||
distribution(s) to sample.
|
||||
shape: A 1-D integer Tensor or Python array. The shape of the output samples
|
||||
to be drawn per "rate"-parameterized distribution.
|
||||
dtype: The type of `lam` and the output: `float16`, `float32`, or
|
||||
`float64`.
|
||||
seed: A Python integer. Used to create a random seed for the distributions.
|
||||
See
|
||||
[`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed)
|
||||
for behavior.
|
||||
name: Optional name for the operation.
|
||||
|
||||
Returns:
|
||||
samples: a `Tensor` of shape `tf.concat(shape, tf.shape(lam))` with
|
||||
values of type `dtype`.
|
||||
"""
|
||||
with ops.name_scope(name, "random_poisson", [lam, shape]):
|
||||
lam = ops.convert_to_tensor(lam, name="lam", dtype=dtype)
|
||||
shape = ops.convert_to_tensor(shape, name="shape", dtype=dtypes.int32)
|
||||
seed1, seed2 = random_seed.get_seed(seed)
|
||||
return gen_random_ops._random_poisson(shape, lam, seed=seed1, seed2=seed2)
|
||||
|
Loading…
Reference in New Issue
Block a user