Adds a tf.random.stateless_poisson
sampler for CPU.
PiperOrigin-RevId: 293834809 Change-Id: I719e218b43f8aecbd74d1472f1291748a61979b8
This commit is contained in:
parent
c879a09690
commit
d8dc9415b0
tensorflow
core
api_def/base_api
kernels
ops
python
eager
kernel_tests/random
ops
tools/api/golden
@ -0,0 +1,41 @@
|
||||
op {
|
||||
graph_op_name: "StatelessRandomPoisson"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "shape"
|
||||
description: <<END
|
||||
The shape of the output tensor.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "seed"
|
||||
description: <<END
|
||||
2 seeds (shape [2]).
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "lam"
|
||||
description: <<END
|
||||
The rate of the Poisson distribution. Shape must match the rightmost dimensions
|
||||
of `shape`.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "output"
|
||||
description: <<END
|
||||
Random values with specified shape.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "dtype"
|
||||
description: <<END
|
||||
The type of the output.
|
||||
END
|
||||
}
|
||||
summary: "Outputs deterministic pseudorandom random numbers from a Poisson distribution."
|
||||
description: <<END
|
||||
Outputs random values from a Poisson distribution.
|
||||
|
||||
The outputs are a deterministic function of `shape`, `seed`, and `lam`.
|
||||
END
|
||||
}
|
@ -5133,6 +5133,7 @@ tf_kernel_library(
|
||||
deps = [
|
||||
":bounds_check",
|
||||
":random_op",
|
||||
":random_poisson_op",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
@ -6465,6 +6466,7 @@ filegroup(
|
||||
"pad_op.h",
|
||||
"pooling_ops_3d.h",
|
||||
"random_op.h",
|
||||
"random_poisson_op.h",
|
||||
"reduction_ops.h",
|
||||
"reduction_ops_common.h",
|
||||
"relu_op.h",
|
||||
@ -6658,6 +6660,7 @@ filegroup(
|
||||
"queue_ops.cc",
|
||||
"random_op.cc",
|
||||
"random_op_cpu.h",
|
||||
"random_poisson_op.cc",
|
||||
"reduction_ops_all.cc",
|
||||
"reduction_ops_any.cc",
|
||||
"reduction_ops_common.cc",
|
||||
|
@ -68,13 +68,6 @@ struct PoissonComputeType {
|
||||
|
||||
namespace functor {
|
||||
|
||||
template <typename Device, typename T, typename U>
|
||||
struct PoissonFunctor {
|
||||
void operator()(OpKernelContext* ctx, const Device& d, const T* rate_flat,
|
||||
int num_rate, int num_samples,
|
||||
const random::PhiloxRandom& rng, U* samples_flat);
|
||||
};
|
||||
|
||||
template <typename T, typename U>
|
||||
struct PoissonFunctor<CPUDevice, T, U> {
|
||||
void operator()(OpKernelContext* ctx, const CPUDevice& d, const T* rate_flat,
|
||||
@ -329,11 +322,12 @@ TF_CALL_half(REGISTER);
|
||||
TF_CALL_float(REGISTER);
|
||||
TF_CALL_double(REGISTER);
|
||||
|
||||
#define REGISTER_V2(RTYPE, OTYPE) \
|
||||
REGISTER_KERNEL_BUILDER(Name("RandomPoissonV2") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<RTYPE>("R") \
|
||||
.TypeConstraint<OTYPE>("dtype"), \
|
||||
#define REGISTER_V2(RTYPE, OTYPE) \
|
||||
template struct functor::PoissonFunctor<CPUDevice, RTYPE, OTYPE>; \
|
||||
REGISTER_KERNEL_BUILDER(Name("RandomPoissonV2") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<RTYPE>("R") \
|
||||
.TypeConstraint<OTYPE>("dtype"), \
|
||||
RandomPoissonOp<RTYPE, OTYPE>);
|
||||
|
||||
#define REGISTER_ALL(RTYPE) \
|
||||
|
@ -16,13 +16,20 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_RANDOM_POISSON_OP_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_RANDOM_POISSON_OP_H_
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/lib/random/simple_philox.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace functor {
|
||||
|
||||
// Generic helper functor for the Random Poisson Op.
|
||||
template <typename Device, typename T /* rate */, typename U /* output */>
|
||||
struct PoissonFunctor;
|
||||
struct PoissonFunctor {
|
||||
void operator()(OpKernelContext* ctx, const Device& d, const T* rate_flat,
|
||||
int num_rate, int num_samples,
|
||||
const random::PhiloxRandom& rng, U* samples_flat);
|
||||
};
|
||||
|
||||
} // namespace functor
|
||||
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_util.h"
|
||||
#include "tensorflow/core/kernels/random_op.h"
|
||||
#include "tensorflow/core/kernels/random_poisson_op.h"
|
||||
#include "tensorflow/core/lib/random/random_distributions.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/util/work_sharder.h"
|
||||
@ -162,6 +163,35 @@ class StatelessRandomUniformIntOp : public StatelessRandomOpBase {
|
||||
}
|
||||
};
|
||||
|
||||
// Samples from one or more Poisson distributions.
|
||||
template <typename T, typename U>
|
||||
class StatelessRandomPoissonOp : public StatelessRandomOpBase {
|
||||
public:
|
||||
using StatelessRandomOpBase::StatelessRandomOpBase;
|
||||
|
||||
void Fill(OpKernelContext* ctx, random::PhiloxRandom random,
|
||||
Tensor* output) override {
|
||||
const Tensor& rate_t = ctx->input(2);
|
||||
|
||||
TensorShape samples_shape = output->shape();
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::EndsWith(samples_shape, rate_t.shape()),
|
||||
errors::InvalidArgument(
|
||||
"Shape passed in must end with broadcasted shape."));
|
||||
|
||||
const int64 num_rate = rate_t.NumElements();
|
||||
const int64 samples_per_rate = samples_shape.num_elements() / num_rate;
|
||||
const auto rate_flat = rate_t.flat<T>().data();
|
||||
auto samples_flat = output->flat<U>().data();
|
||||
|
||||
functor::PoissonFunctor<CPUDevice, T, U>()(
|
||||
ctx, ctx->eigen_device<CPUDevice>(), rate_flat, num_rate,
|
||||
samples_per_rate, random, samples_flat);
|
||||
}
|
||||
|
||||
private:
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomPoissonOp);
|
||||
};
|
||||
|
||||
template <typename Device, typename T>
|
||||
class StatelessRandomGammaOp : public StatelessRandomOpBase {
|
||||
public:
|
||||
@ -354,7 +384,7 @@ class StatelessRandomGammaOp : public StatelessRandomOpBase {
|
||||
.HostMemory("minval") \
|
||||
.HostMemory("maxval") \
|
||||
.TypeConstraint<TYPE>("dtype"), \
|
||||
StatelessRandomUniformIntOp<DEVICE##Device, TYPE>);
|
||||
StatelessRandomUniformIntOp<DEVICE##Device, TYPE>)
|
||||
|
||||
#define REGISTER_CPU(TYPE) REGISTER(CPU, TYPE)
|
||||
#define REGISTER_GPU(TYPE) REGISTER(GPU, TYPE)
|
||||
@ -368,6 +398,32 @@ TF_CALL_double(REGISTER_CPU);
|
||||
TF_CALL_int32(REGISTER_INT_CPU);
|
||||
TF_CALL_int64(REGISTER_INT_CPU);
|
||||
|
||||
#define REGISTER_POISSON(RATE_TYPE, OUT_TYPE) \
|
||||
REGISTER_KERNEL_BUILDER(Name("StatelessRandomPoisson") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.HostMemory("shape") \
|
||||
.HostMemory("seed") \
|
||||
.HostMemory("lam") \
|
||||
.TypeConstraint<RATE_TYPE>("Rtype") \
|
||||
.TypeConstraint<OUT_TYPE>("dtype"), \
|
||||
StatelessRandomPoissonOp<RATE_TYPE, OUT_TYPE>)
|
||||
|
||||
#define REGISTER_ALL_POISSON(RATE_TYPE) \
|
||||
REGISTER_POISSON(RATE_TYPE, Eigen::half); \
|
||||
REGISTER_POISSON(RATE_TYPE, float); \
|
||||
REGISTER_POISSON(RATE_TYPE, double); \
|
||||
REGISTER_POISSON(RATE_TYPE, int32); \
|
||||
REGISTER_POISSON(RATE_TYPE, int64)
|
||||
|
||||
TF_CALL_half(REGISTER_ALL_POISSON);
|
||||
TF_CALL_float(REGISTER_ALL_POISSON);
|
||||
TF_CALL_double(REGISTER_ALL_POISSON);
|
||||
TF_CALL_int32(REGISTER_ALL_POISSON);
|
||||
TF_CALL_int64(REGISTER_ALL_POISSON);
|
||||
|
||||
#undef REGISTER_ALL_POISSON
|
||||
#undef REGISTER_POISSON
|
||||
|
||||
#define REGISTER_GAMMA(TYPE) \
|
||||
REGISTER_KERNEL_BUILDER(Name("StatelessRandomGammaV2") \
|
||||
.Device(DEVICE_CPU) \
|
||||
|
@ -105,6 +105,17 @@ REGISTER_OP("StatelessRandomBinomial")
|
||||
.Attr("dtype: {half, float, double, int32, int64} = DT_INT64")
|
||||
.SetShapeFn(StatelessShape);
|
||||
|
||||
REGISTER_OP("StatelessRandomPoisson")
|
||||
.Input("shape: T")
|
||||
.Input("seed: Tseed")
|
||||
.Input("lam: Rtype")
|
||||
.Output("output: dtype")
|
||||
.Attr("Rtype: {float16, float32, float64, int32, int64}")
|
||||
.Attr("dtype: {float16, float32, float64, int32, int64}")
|
||||
.Attr("T: {int32, int64}")
|
||||
.Attr("Tseed: {int32, int64} = DT_INT64")
|
||||
.SetShapeFn(StatelessShape);
|
||||
|
||||
REGISTER_OP("StatelessRandomGammaV2")
|
||||
.Input("shape: T")
|
||||
.Input("seed: Tseed")
|
||||
|
@ -310,6 +310,7 @@ bool OpGradientDoesntRequireInputIndices(
|
||||
{"StatelessRandomBinomial", {true, {}}},
|
||||
{"StatelessRandomGammaV2", {false, {1}}},
|
||||
{"StatelessRandomNormal", {true, {}}},
|
||||
{"StatelessRandomPoisson", {true, {}}},
|
||||
{"StatelessRandomUniform", {true, {}}},
|
||||
{"StatelessRandomUniformInt", {true, {}}},
|
||||
{"StatelessTruncatedNormal", {true, {}}},
|
||||
@ -765,6 +766,7 @@ bool OpGradientDoesntRequireOutputIndices(
|
||||
{"StatelessMultinomial", {true, {}}},
|
||||
{"StatelessRandomBinomial", {true, {}}},
|
||||
{"StatelessRandomNormal", {true, {}}},
|
||||
{"StatelessRandomPoisson", {true, {}}},
|
||||
{"StatelessRandomUniform", {true, {}}},
|
||||
{"StatelessRandomUniformInt", {true, {}}},
|
||||
{"StatelessTruncatedNormal", {true, {}}},
|
||||
|
@ -133,11 +133,23 @@ class StatelessOpsTest(test.TestCase):
|
||||
for dtype in np.float16, np.float32, np.float64:
|
||||
for alpha in ([[.5, 1., 2.]], [[0.5, 0.5], [0.8, 0.2], [0.25, 0.75]]):
|
||||
kwds = dict(alpha=constant_op.constant(alpha, dtype=dtype), dtype=dtype)
|
||||
yield (functools.partial(
|
||||
stateless.stateless_random_gamma,
|
||||
shape=(10,) + tuple(np.shape(alpha)),
|
||||
**kwds),
|
||||
functools.partial(random_ops.random_gamma, shape=(10,), **kwds))
|
||||
yield (
|
||||
functools.partial(stateless.stateless_random_gamma,
|
||||
shape=(10,) + tuple(np.shape(alpha)), **kwds),
|
||||
functools.partial(random_ops.random_gamma, shape=(10,), **kwds))
|
||||
|
||||
def _poisson_cases(self):
|
||||
for lam_dtype in np.float16, np.float32, np.float64, np.int32, np.int64:
|
||||
for out_dtype in np.float16, np.float32, np.float64, np.int32, np.int64:
|
||||
for lam in ([[5.5, 1., 2.]], [[7.5, 10.5], [3.8, 8.2], [1.25, 9.75]]):
|
||||
kwds = dict(
|
||||
lam=constant_op.constant(lam_dtype(lam), dtype=lam_dtype),
|
||||
dtype=out_dtype)
|
||||
yield (
|
||||
functools.partial(stateless.stateless_random_poisson,
|
||||
shape=(10,) + tuple(np.shape(lam)),
|
||||
**kwds),
|
||||
functools.partial(random_ops.random_poisson, shape=(10,), **kwds))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testMatchFloat(self):
|
||||
@ -155,6 +167,10 @@ class StatelessOpsTest(test.TestCase):
|
||||
def testMatchGamma(self):
|
||||
self._test_match(self._gamma_cases())
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testMatchPoisson(self):
|
||||
self._test_match(self._poisson_cases())
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testDeterminismFloat(self):
|
||||
self._test_determinism(
|
||||
@ -173,6 +189,10 @@ class StatelessOpsTest(test.TestCase):
|
||||
def testDeterminismGamma(self):
|
||||
self._test_determinism(self._gamma_cases())
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testDeterminismPoisson(self):
|
||||
self._test_determinism(self._poisson_cases())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -32,6 +32,7 @@ from tensorflow.python.util.tf_export import tf_export
|
||||
ops.NotDifferentiable("StatelessMultinomial")
|
||||
ops.NotDifferentiable("StatelessRandomBinomial")
|
||||
ops.NotDifferentiable("StatelessRandomNormal")
|
||||
ops.NotDifferentiable("StatelessRandomPoisson")
|
||||
ops.NotDifferentiable("StatelessRandomUniform")
|
||||
ops.NotDifferentiable("StatelessRandomUniformInt")
|
||||
ops.NotDifferentiable("StatelessTruncatedNormal")
|
||||
@ -271,6 +272,67 @@ def stateless_random_gamma(shape,
|
||||
return result
|
||||
|
||||
|
||||
@tf_export("random.stateless_poisson")
|
||||
def stateless_random_poisson(shape,
|
||||
seed,
|
||||
lam,
|
||||
dtype=dtypes.int32,
|
||||
name=None):
|
||||
"""Outputs deterministic pseudorandom values from a Poisson distribution.
|
||||
|
||||
The generated values follow a Poisson distribution with specified rate
|
||||
parameter.
|
||||
|
||||
This is a stateless version of `tf.random.poisson`: if run twice with the same
|
||||
seeds, it will produce the same pseudorandom numbers. The output is consistent
|
||||
across multiple runs on the same hardware (and between CPU and GPU), but may
|
||||
change between versions of TensorFlow or on non-CPU/GPU hardware.
|
||||
|
||||
A slight difference exists in the interpretation of the `shape` parameter
|
||||
between `stateless_poisson` and `poisson`: in `poisson`, the `shape` is always
|
||||
prepended to the shape of `rate`; whereas in `stateless_poisson` the shape of
|
||||
`rate` must match the trailing dimensions of `shape`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
samples = tf.random.stateless_poisson([10, 2], seed=[12, 34], lam=[5, 15])
|
||||
# samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents
|
||||
# the samples drawn from each distribution
|
||||
|
||||
samples = tf.random.stateless_poisson([7, 5, 2], seed=[12, 34], lam=[5, 15])
|
||||
# samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1]
|
||||
# represents the 7x5 samples drawn from each of the two distributions
|
||||
|
||||
rate = tf.constant([[1.], [3.], [5.]])
|
||||
samples = tf.random.stateless_poisson([30, 3, 1], seed=[12, 34], lam=rate)
|
||||
# samples has shape [30, 3, 1], with 30 samples each of 3x1 distributions.
|
||||
```
|
||||
|
||||
Args:
|
||||
shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
|
||||
seed: A shape [2] integer Tensor of seeds to the random number generator.
|
||||
lam: Tensor. The rate parameter "lambda" of the Poisson distribution. Shape
|
||||
must match the rightmost dimensions of `shape`.
|
||||
dtype: Dtype of the samples (int or float dtypes are permissible, as samples
|
||||
are discrete). Default: int32.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
samples: A Tensor of the specified shape filled with random Poisson values.
|
||||
For each i, each `samples[..., i]` is an independent draw from the Poisson
|
||||
distribution with rate `lam[i]`.
|
||||
|
||||
"""
|
||||
with ops.name_scope(name, "stateless_random_poisson",
|
||||
[shape, seed, lam]) as name:
|
||||
shape = tensor_util.shape_tensor(shape)
|
||||
result = gen_stateless_random_ops.stateless_random_poisson(
|
||||
shape, seed=seed, lam=lam, dtype=dtype)
|
||||
tensor_util.maybe_set_static_shape(result, shape)
|
||||
return result
|
||||
|
||||
|
||||
@tf_export("random.stateless_normal")
|
||||
def stateless_random_normal(shape,
|
||||
seed,
|
||||
|
@ -92,6 +92,10 @@ tf_module {
|
||||
name: "stateless_normal"
|
||||
argspec: "args=[\'shape\', \'seed\', \'mean\', \'stddev\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \"<dtype: \'float32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "stateless_poisson"
|
||||
argspec: "args=[\'shape\', \'seed\', \'lam\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "stateless_truncated_normal"
|
||||
argspec: "args=[\'shape\', \'seed\', \'mean\', \'stddev\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \"<dtype: \'float32\'>\", \'None\'], "
|
||||
|
@ -4292,6 +4292,10 @@ tf_module {
|
||||
name: "StatelessRandomNormal"
|
||||
argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatelessRandomPoisson"
|
||||
argspec: "args=[\'shape\', \'seed\', \'lam\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatelessRandomUniform"
|
||||
argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
|
||||
|
@ -80,6 +80,10 @@ tf_module {
|
||||
name: "stateless_normal"
|
||||
argspec: "args=[\'shape\', \'seed\', \'mean\', \'stddev\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \"<dtype: \'float32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "stateless_poisson"
|
||||
argspec: "args=[\'shape\', \'seed\', \'lam\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "stateless_truncated_normal"
|
||||
argspec: "args=[\'shape\', \'seed\', \'mean\', \'stddev\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \"<dtype: \'float32\'>\", \'None\'], "
|
||||
|
@ -4292,6 +4292,10 @@ tf_module {
|
||||
name: "StatelessRandomNormal"
|
||||
argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatelessRandomPoisson"
|
||||
argspec: "args=[\'shape\', \'seed\', \'lam\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatelessRandomUniform"
|
||||
argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user