Adds a tf.random.stateless_gamma sampler for CPU. (Stateless analogue to tf.random.gamma.)

PiperOrigin-RevId: 293454705
Change-Id: I103900822dac37989246eaf7b410157e5563b78b
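For context, a minimal usage sketch (not part of the commit; assumes a TensorFlow build that already exposes the new `tf.random.stateless_gamma` endpoint, with illustrative shapes and seeds):

import tensorflow as tf

# Stateless: the same (shape, seed, alpha) triple always yields the same draws.
a = tf.random.stateless_gamma(shape=[10, 2], seed=[12, 34], alpha=[0.5, 1.5])
b = tf.random.stateless_gamma(shape=[10, 2], seed=[12, 34], alpha=[0.5, 1.5])
print(bool(tf.reduce_all(tf.equal(a, b))))  # True: deterministic for a fixed seed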
parent f95a6caa8b
commit 31c94250fc

Changed paths:
  tensorflow/core/api_def/base_api
  tensorflow/core/kernels
  tensorflow/core/ops
  tensorflow/python/eager
  tensorflow/python/kernel_tests/random
  tensorflow/python/ops
  tensorflow/tools/api/golden
@@ -0,0 +1,41 @@
op {
  graph_op_name: "StatelessRandomGammaV2"
  visibility: HIDDEN
  in_arg {
    name: "shape"
    description: <<END
The shape of the output tensor.
END
  }
  in_arg {
    name: "seed"
    description: <<END
2 seeds (shape [2]).
END
  }
  in_arg {
    name: "alpha"
    description: <<END
The concentration of the gamma distribution. Shape must match the rightmost
dimensions of `shape`.
END
  }
  out_arg {
    name: "output"
    description: <<END
Random values with specified shape.
END
  }
  attr {
    name: "dtype"
    description: <<END
The type of the output.
END
  }
  summary: "Outputs deterministic pseudorandom random numbers from a gamma distribution."
  description: <<END
Outputs random values from a gamma distribution.

The outputs are a deterministic function of `shape`, `seed`, and `alpha`.
END
}
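As a hedged illustration of the op interface defined above: the op is HIDDEN, so the call below goes through `tf.raw_ops` (values are made up). Note that `shape` must already end with the dimensions of `alpha`, per the shape constraint in `alpha`'s description.

import tensorflow as tf

# Three concentrations, ten samples each: shape [10, 3] ends with alpha's shape [3].
samples = tf.raw_ops.StatelessRandomGammaV2(
    shape=[10, 3],
    seed=[12, 34],
    alpha=tf.constant([0.5, 1.0, 2.0], dtype=tf.float32))
print(samples.shape)  # (10, 3)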
@@ -167,7 +167,7 @@ class RandomGammaOp : public OpKernel {
      OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(vec.data(), vec.size(),
                                                      &samples_shape));
    }
-    const int64 num_samples = samples_shape.num_elements();
+    const int64 samples_per_alpha = samples_shape.num_elements();

    samples_shape.AppendShape(alpha_t.shape());
    // Allocate output samples.
@@ -199,13 +199,13 @@ class RandomGammaOp : public OpKernel {
                                        num_alphas));
    auto samples_flat = samples_t->flat<T>().data();
    PhiloxRandom rng = generator_.ReserveRandomOutputs(
-        num_samples * num_alphas, kReservedSamplesPerOutput);
+        samples_per_alpha * num_alphas, kReservedSamplesPerOutput);

    // We partition work first across alphas then across samples-per-alpha to
    // avoid a couple flops which can be done on a per-alpha basis.

-    auto DoWork = [num_samples, num_alphas, &rng, samples_flat, alpha_flat](
-                      int start_output, int limit_output) {
+    auto DoWork = [samples_per_alpha, num_alphas, &rng, samples_flat,
+                   alpha_flat](int start_output, int limit_output) {
      using Eigen::numext::exp;
      using Eigen::numext::log;
      using Eigen::numext::pow;
@@ -220,7 +220,7 @@ class RandomGammaOp : public OpKernel {
      typename Uniform::ResultType uniform_result;
      for (int64 output_idx = start_output; output_idx < limit_output;
           /* output_idx incremented within inner loop below */) {
-        int64 alpha_idx = output_idx / num_samples;
+        int64 alpha_idx = output_idx / samples_per_alpha;

        // Instead of +alpha_idx for each sample, we offset the pointer once.
        T* const samples_alpha_offset = samples_flat + alpha_idx;
@@ -232,8 +232,8 @@ class RandomGammaOp : public OpKernel {
        if (alpha == static_cast<double>(1.0)) {
          ENABLE_FLOAT_EQUALITY_WARNING
          // Sample from an exponential distribution.
-          for (int64 sample_idx = output_idx % num_samples;
-               sample_idx < num_samples && output_idx < limit_output;
+          for (int64 sample_idx = output_idx % samples_per_alpha;
+               sample_idx < samples_per_alpha && output_idx < limit_output;
               sample_idx++, output_idx++) {
            // As we want data stable regardless of sharding
            // (including eventually on GPU), we skip on a per-sample basis.
@@ -259,8 +259,8 @@ class RandomGammaOp : public OpKernel {
          const double c = 1.0 / 3 / sqrt(d);

          // Compute the rest of the samples for the current alpha value.
-          for (int64 sample_idx = output_idx % num_samples;
-               sample_idx < num_samples && output_idx < limit_output;
+          for (int64 sample_idx = output_idx % samples_per_alpha;
+               sample_idx < samples_per_alpha && output_idx < limit_output;
               sample_idx++, output_idx++) {
            // Since each sample may use a variable number of normal/uniform
            // samples, and we want data stable regardless of sharding
@@ -317,7 +317,7 @@ class RandomGammaOp : public OpKernel {
                                    3 * PhiloxRandom::kElementCost;
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers,
-          num_alphas * num_samples, kElementCost, DoWork);
+          num_alphas * samples_per_alpha, kElementCost, DoWork);
  }

 private:
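The rename above (num_samples to samples_per_alpha) only clarifies what the quantity means; the work partitioning and output layout are unchanged. A small Python sketch of the index mapping the sharded loop uses (illustrative only, mirroring the `alpha_idx` / `samples_alpha_offset` arithmetic in the kernel):

def flat_output_position(output_idx, samples_per_alpha, num_alphas):
    # Work is sharded alpha-major: consecutive output_idx values stay on one alpha.
    alpha_idx = output_idx // samples_per_alpha
    sample_idx = output_idx % samples_per_alpha
    # The output buffer is sample-major with alpha as the trailing axis,
    # hence the stride of num_alphas per sample.
    return sample_idx * num_alphas + alpha_idx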
@@ -22,6 +22,17 @@ limitations under the License.
#include "tensorflow/core/kernels/random_op.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/work_sharder.h"

#if EIGEN_COMP_GNUC && __cplusplus > 199711L
#define DISABLE_FLOAT_EQUALITY_WARNING \
  _Pragma("GCC diagnostic push")       \
      _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
#else
#define DISABLE_FLOAT_EQUALITY_WARNING
#define ENABLE_FLOAT_EQUALITY_WARNING
#endif

namespace tensorflow {

@@ -151,6 +162,162 @@ class StatelessRandomUniformIntOp : public StatelessRandomOpBase {
  }
};

template <typename Device, typename T>
class StatelessRandomGammaOp : public StatelessRandomOpBase {
 public:
  using StatelessRandomOpBase::StatelessRandomOpBase;

  void Fill(OpKernelContext* ctx, random::PhiloxRandom random,
            Tensor* output) override {
    const Tensor& alpha_t = ctx->input(2);

    TensorShape samples_shape = output->shape();
    OP_REQUIRES(ctx, TensorShapeUtils::EndsWith(samples_shape, alpha_t.shape()),
                errors::InvalidArgument(
                    "Shape passed in must end with broadcasted shape."));

    typedef random::NormalDistribution<random::PhiloxRandom, double> Normal;
    typedef random::UniformDistribution<random::PhiloxRandom, double> Uniform;
#define UNIFORM(X)                                    \
  if (uniform_remaining == 0) {                       \
    uniform_remaining = Uniform::kResultElementCount; \
    uniform_result = uniform(&gen);                   \
  }                                                   \
  uniform_remaining--;                                \
  double X = uniform_result[uniform_remaining]

    // Each attempt is 95+% successful, and requires 1-2 normal + 1 uniform
    static constexpr int kReservedSamplesPerOutput = 256;

    const int64 num_alphas = alpha_t.NumElements();
    OP_REQUIRES(ctx, num_alphas > 0,
                errors::InvalidArgument(
                    "Input alpha should have non-zero element count, got: ",
                    num_alphas));
    const int64 samples_per_alpha = samples_shape.num_elements() / num_alphas;
    const auto alpha_flat = alpha_t.flat<T>().data();
    auto samples_flat = output->flat<T>().data();

    // We partition work first across alphas then across samples-per-alpha to
    // avoid a couple flops which can be done on a per-alpha basis.

    auto DoWork = [samples_per_alpha, num_alphas, &random, samples_flat,
                   alpha_flat](int start_output, int limit_output) {
      // Capturing "random" by-value would only make a copy for the _shared_
      // lambda. Since we want to let each worker have its own copy, we pass
      // "random" by reference and explicitly do a copy assignment.

      using Eigen::numext::exp;
      using Eigen::numext::log;
      using Eigen::numext::pow;

      Normal normal;
      Uniform uniform;
      typename Normal::ResultType norm_result;
      typename Uniform::ResultType uniform_result;
      for (int64 output_idx = start_output; output_idx < limit_output;
           /* output_idx incremented within inner loop below */) {
        int64 alpha_idx = output_idx / samples_per_alpha;

        // Instead of +alpha_idx for each sample, we offset the pointer once.
        T* const samples_alpha_offset = samples_flat + alpha_idx;

        // Several calculations can be done on a per-alpha basis.
        const double alpha = static_cast<double>(alpha_flat[alpha_idx]);

        DISABLE_FLOAT_EQUALITY_WARNING
        if (alpha == static_cast<double>(1.0)) {
          ENABLE_FLOAT_EQUALITY_WARNING
          // Sample from an exponential distribution.
          for (int64 sample_idx = output_idx % samples_per_alpha;
               sample_idx < samples_per_alpha && output_idx < limit_output;
               sample_idx++, output_idx++) {
            // As we want data stable regardless of sharding
            // (including eventually on GPU), we skip on a per-sample basis.
            random::PhiloxRandom gen = random;
            gen.Skip(kReservedSamplesPerOutput * output_idx);
            int16 uniform_remaining = 0;
            UNIFORM(u);
            const double res = -log(1.0 - u);
            samples_alpha_offset[sample_idx * num_alphas] = static_cast<T>(res);
          }  // for (sample_idx)
        } else {  // if alpha != 1.0
          // Transformation-rejection from pairs of uniform and normal random
          // variables. http://dl.acm.org/citation.cfm?id=358414
          //
          // The algorithm has an acceptance rate of ~95% for small alpha (~1),
          // and higher accept rates for higher alpha, so runtime is
          // O(NumAlphas * NumSamples * k) with k ~ 1 / 0.95.
          //
          // For alpha<1, we add one to d=alpha-1/3, and multiply the final
          // result by uniform()^(1/alpha)
          const bool alpha_less_than_one = alpha < 1;
          const double d = alpha + (alpha_less_than_one ? 2.0 / 3 : -1.0 / 3);
          const double c = 1.0 / 3 / sqrt(d);

          // Compute the rest of the samples for the current alpha value.
          for (int64 sample_idx = output_idx % samples_per_alpha;
               sample_idx < samples_per_alpha && output_idx < limit_output;
               sample_idx++, output_idx++) {
            // Since each sample may use a variable number of normal/uniform
            // samples, and we want data stable regardless of sharding
            // (including eventually on GPU), we skip on a per-sample basis.
            random::PhiloxRandom gen = random;
            gen.Skip(kReservedSamplesPerOutput * output_idx);
            int16 norm_remaining = 0;
            int16 uniform_remaining = 0;

            // Keep trying until we don't reject a sample. In practice, we will
            // only reject ~5% at worst, for low alpha near 1.
            while (true) {
              if (norm_remaining == 0) {
                norm_remaining = Normal::kResultElementCount;
                norm_result = normal(&gen);
              }
              norm_remaining--;
              const double x = norm_result[norm_remaining];
              double v = 1 + c * x;
              if (v <= 0) {
                continue;
              }
              v = v * v * v;
              UNIFORM(u);
              // The first option in the if is a "squeeze" short-circuit to
              // dodge the two logs. Magic constant sourced from the paper
              // linked above. Upward of .91 of the area covered by the log
              // inequality is covered by the squeeze as well (larger coverage
              // for smaller values of alpha).
              if ((u < 1 - 0.0331 * (x * x) * (x * x)) ||
                  (log(u) < 0.5 * x * x + d * (1 - v + log(v)))) {
                double res = d * v;
                if (alpha_less_than_one) {
                  UNIFORM(b);
                  res *= pow(b, 1 / alpha);
                }
                samples_alpha_offset[sample_idx * num_alphas] =
                    static_cast<T>(res);
                break;
              }
            }  // while: true
          }    // for: sample_idx
        }      // if (alpha == 1.0)
      }        // for: output_idx
    };         // DoWork
#undef UNIFORM
    // Two calls to log only occur for ~10% of samples reaching the log line.
    // 2 x 100 (64-bit cycles per log) x 0.10 = ~20.
    // Other ops: sqrt, +, *, /, %... something like 15 of these, at 3-6 cycles
    // each = ~60.
    // All of this /0.95 due to the rejection possibility = ~85.
    static const int kElementCost = 85 + 2 * Normal::kElementCost +
                                    Uniform::kElementCost +
                                    3 * random::PhiloxRandom::kElementCost;
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers,
          num_alphas * samples_per_alpha, kElementCost, DoWork);
  }
};

#define REGISTER(DEVICE, TYPE)       \
  REGISTER_KERNEL_BUILDER(           \
      Name("StatelessRandomUniform") \
@@ -177,7 +344,7 @@ class StatelessRandomUniformIntOp : public StatelessRandomOpBase {
      StatelessRandomOp<                                                   \
          DEVICE##Device,                                                  \
          random::TruncatedNormalDistribution<                             \
-              random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >);
+              random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >)

#define REGISTER_INT(DEVICE, TYPE)                            \
  REGISTER_KERNEL_BUILDER(Name("StatelessRandomUniformInt")   \
@@ -201,6 +368,22 @@ TF_CALL_double(REGISTER_CPU);
TF_CALL_int32(REGISTER_INT_CPU);
TF_CALL_int64(REGISTER_INT_CPU);

#define REGISTER_GAMMA(TYPE)                                  \
  REGISTER_KERNEL_BUILDER(Name("StatelessRandomGammaV2")      \
                              .Device(DEVICE_CPU)             \
                              .HostMemory("shape")            \
                              .HostMemory("seed")             \
                              .HostMemory("alpha")            \
                              .TypeConstraint<TYPE>("dtype"), \
                          StatelessRandomGammaOp<CPUDevice, TYPE>)

TF_CALL_half(REGISTER_GAMMA);
TF_CALL_bfloat16(REGISTER_GAMMA);
TF_CALL_float(REGISTER_GAMMA);
TF_CALL_double(REGISTER_GAMMA);

#undef REGISTER_GAMMA

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

TF_CALL_half(REGISTER_GPU);
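The new kernel above uses the Marsaglia-Tsang transformation-rejection scheme referenced in its comments. A rough NumPy sketch of the same math, for orientation only; it omits the Philox counter bookkeeping, per-output skipping, and sharding the kernel needs for determinism:

import numpy as np

def gamma_sample(alpha, rng):
    """One gamma(alpha, 1) draw via Marsaglia-Tsang rejection (alpha > 0)."""
    boost = 1.0
    if alpha < 1.0:
        # Sample for alpha + 1, then scale by U^(1/alpha), as the kernel does.
        boost = rng.uniform() ** (1.0 / alpha)
        alpha += 1.0
    d = alpha - 1.0 / 3.0
    c = 1.0 / (3.0 * np.sqrt(d))
    while True:
        x = rng.standard_normal()
        v = 1.0 + c * x
        if v <= 0.0:
            continue
        v = v * v * v
        u = rng.uniform()
        # Squeeze test first; it avoids the two logs for most accepted samples.
        if (u < 1.0 - 0.0331 * x**4 or
                np.log(u) < 0.5 * x * x + d * (1.0 - v + np.log(v))):
            return boost * d * v

rng = np.random.default_rng(0)
print(np.mean([gamma_sample(0.5, rng) for _ in range(10000)]))  # ~0.5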
@@ -105,4 +105,14 @@ REGISTER_OP("StatelessRandomBinomial")
    .Attr("dtype: {half, float, double, int32, int64} = DT_INT64")
    .SetShapeFn(StatelessShape);

REGISTER_OP("StatelessRandomGammaV2")
    .Input("shape: T")
    .Input("seed: Tseed")
    .Input("alpha: dtype")
    .Output("output: dtype")
    .Attr("dtype: {float16, float32, float64}")
    .Attr("T: {int32, int64}")
    .Attr("Tseed: {int32, int64} = DT_INT64")
    .SetShapeFn(StatelessShape);

}  // namespace tensorflow
@@ -308,6 +308,7 @@ bool OpGradientDoesntRequireInputIndices(
      {"StackPush", {true, {}}},
      {"StatelessMultinomial", {true, {}}},
      {"StatelessRandomBinomial", {true, {}}},
      {"StatelessRandomGammaV2", {false, {1}}},
      {"StatelessRandomNormal", {true, {}}},
      {"StatelessRandomUniform", {true, {}}},
      {"StatelessRandomUniformInt", {true, {}}},
@@ -100,6 +100,7 @@ cuda_py_test(
    name = "stateless_random_ops_test",
    size = "medium",
    srcs = ["stateless_random_ops_test.py"],
    shard_count = 2,
    deps = [
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
@@ -129,6 +129,16 @@ class StatelessOpsTest(test.TestCase):
      yield (functools.partial(stateless.stateless_multinomial, **kwds),
             functools.partial(random_ops.multinomial, **kwds))

  def _gamma_cases(self):
    for dtype in np.float16, np.float32, np.float64:
      for alpha in ([[.5, 1., 2.]], [[0.5, 0.5], [0.8, 0.2], [0.25, 0.75]]):
        kwds = dict(alpha=constant_op.constant(alpha, dtype=dtype), dtype=dtype)
        yield (functools.partial(
            stateless.stateless_random_gamma,
            shape=(10,) + tuple(np.shape(alpha)),
            **kwds),
               functools.partial(random_ops.random_gamma, shape=(10,), **kwds))

  @test_util.run_deprecated_v1
  def testMatchFloat(self):
    self._test_match(self._float_cases())
@@ -141,6 +151,10 @@ class StatelessOpsTest(test.TestCase):
  def testMatchMultinomial(self):
    self._test_match(self._multinomial_cases())

  @test_util.run_deprecated_v1
  def testMatchGamma(self):
    self._test_match(self._gamma_cases())

  @test_util.run_deprecated_v1
  def testDeterminismFloat(self):
    self._test_determinism(
@@ -155,6 +169,10 @@ class StatelessOpsTest(test.TestCase):
  def testDeterminismMultinomial(self):
    self._test_determinism(self._multinomial_cases())

  @test_util.run_deprecated_v1
  def testDeterminismGamma(self):
    self._test_determinism(self._gamma_cases())


if __name__ == '__main__':
  test.main()
@@ -25,7 +25,7 @@ from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import math_ops


-def add_leading_unit_dimensions(x, num_dimensions):
+def add_leading_unit_dimensions(x, num_dimensions):  # pylint: disable=invalid-name
  new_shape = array_ops.concat(
      [array_ops.ones([num_dimensions], dtype=dtypes.int32),
       array_ops.shape(x)], axis=0)
@@ -70,3 +70,47 @@ def _RandomGammaGrad(op, grad):  # pylint: disable=invalid-name
  # The first input is shape; the second input is alpha.
  return (None, math_ops.reduce_sum(
      grad * partial_a, axis=math_ops.range(num_sample_dimensions)))


@ops.RegisterGradient("StatelessRandomGammaV2")
def _StatelessRandomGammaV2Grad(op, grad):  # pylint: disable=invalid-name
  """Returns the gradient of a Gamma sample w.r.t. alpha.

  The gradient is computed using implicit differentiation
  (Figurnov et al., 2018).

  Args:
    op: A `StatelessRandomGamma` operation. We assume that the inputs to the
      operation are `shape`, `seed` and `alpha` tensors, and the output is the
      `sample` tensor.
    grad: The incoming gradient `dloss / dsample` of the same shape as
      `op.outputs[0]`.

  Returns:
    A `Tensor` with derivatives `dloss / dalpha`.

  References:
    Implicit Reparameterization Gradients:
      [Figurnov et al., 2018]
      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients)
      ([pdf]
      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients.pdf))
  """
  shape = op.inputs[0]
  alpha = op.inputs[2]
  sample = op.outputs[0]

  with ops.control_dependencies([grad]):
    # Note that the shape handling is slightly different for stateless_gamma,
    # in particular num_sample_dimensions is different.
    num_sample_dimensions = array_ops.shape(shape)[0] - array_ops.rank(alpha)
    # Make the parameters alpha broadcastable with samples by appending
    # unit dimensions.
    alpha_broadcastable = add_leading_unit_dimensions(alpha,
                                                      num_sample_dimensions)
    partial_a = gen_random_ops.random_gamma_grad(alpha_broadcastable, sample)

    # The first two inputs are shape, seed, third input is alpha.
    return (None, None,
            math_ops.reduce_sum(
                grad * partial_a, axis=math_ops.range(num_sample_dimensions)))
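A short, hedged check of what the gradient registered above provides (eager-mode sketch; shapes are illustrative): the per-element `dsample/dalpha` terms from `random_gamma_grad` are summed over the leading sample dimensions, so the result has `alpha`'s shape.

import tensorflow as tf

alpha = tf.constant([0.5, 1.5])
with tf.GradientTape() as tape:
  tape.watch(alpha)
  sample = tf.random.stateless_gamma(shape=[10, 2], seed=[12, 34], alpha=alpha)
  loss = tf.reduce_sum(sample)
dalpha = tape.gradient(loss, alpha)  # implicit-reparameterization gradient
print(dalpha.shape)  # (2,), matching alpha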
@@ -18,9 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_stateless_random_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import deprecation
@@ -171,6 +174,103 @@ def stateless_random_binomial(shape,
  return result


@tf_export("random.stateless_gamma")
def stateless_random_gamma(shape,
                           seed,
                           alpha,
                           beta=None,
                           dtype=dtypes.float32,
                           name=None):
  """Outputs deterministic pseudorandom values from a gamma distribution.

  The generated values follow a gamma distribution with specified concentration
  (`alpha`) and inverse scale (`beta`) parameters.

  This is a stateless version of `tf.random.gamma`: if run twice with the same
  seeds, it will produce the same pseudorandom numbers. The output is consistent
  across multiple runs on the same hardware (and between CPU and GPU), but may
  change between versions of TensorFlow or on non-CPU/GPU hardware.

  A slight difference exists in the interpretation of the `shape` parameter
  between `stateless_gamma` and `gamma`: in `gamma`, the `shape` is always
  prepended to the shape of the broadcast of `alpha` with `beta`; whereas in
  `stateless_gamma` the `shape` parameter must always encompass the shapes of
  each of `alpha` and `beta` (which must broadcast together to match the
  trailing dimensions of `shape`).

  Note: Because internal calculations are done using `float64` and casting has
  `floor` semantics, we must manually map zero outcomes to the smallest
  possible positive floating-point value, i.e., `np.finfo(dtype).tiny`. This
  means that `np.finfo(dtype).tiny` occurs more frequently than it otherwise
  should. This bias can only happen for small values of `alpha`, i.e.,
  `alpha << 1` or large values of `beta`, i.e., `beta >> 1`.

  The samples are differentiable w.r.t. alpha and beta.
  The derivatives are computed using the approach described in
  (Figurnov et al., 2018).

  Example:

  ```python
  samples = tf.random.stateless_gamma([10, 2], seed=[12, 34], alpha=[0.5, 1.5])
  # samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents
  # the samples drawn from each distribution

  samples = tf.random.stateless_gamma([7, 5, 2], seed=[12, 34], alpha=[.5, 1.5])
  # samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1]
  # represents the 7x5 samples drawn from each of the two distributions

  alpha = tf.constant([[1.], [3.], [5.]])
  beta = tf.constant([[3., 4.]])
  samples = tf.random.stateless_gamma(
      [30, 3, 2], seed=[12, 34], alpha=alpha, beta=beta)
  # samples has shape [30, 3, 2], with 30 samples each of 3x2 distributions.

  with tf.GradientTape() as tape:
    tape.watch([alpha, beta])
    loss = tf.reduce_mean(tf.square(tf.random.stateless_gamma(
        [30, 3, 2], seed=[12, 34], alpha=alpha, beta=beta)))
  dloss_dalpha, dloss_dbeta = tape.gradient(loss, [alpha, beta])
  # unbiased stochastic derivatives of the loss function
  alpha.shape == dloss_dalpha.shape  # True
  beta.shape == dloss_dbeta.shape  # True
  ```

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    seed: A shape [2] integer Tensor of seeds to the random number generator.
    alpha: Tensor. The concentration parameter of the gamma distribution. Must
      be broadcastable with `beta`, and broadcastable with the rightmost
      dimensions of `shape`.
    beta: Tensor. The inverse scale parameter of the gamma distribution. Must be
      broadcastable with `alpha` and broadcastable with the rightmost dimensions
      of `shape`.
    dtype: Floating point dtype of `alpha`, `beta`, and the output.
    name: A name for the operation (optional).

  Returns:
    samples: A Tensor of the specified shape filled with random gamma values.
      For each i, each `samples[..., i]` is an independent draw from the gamma
      distribution with concentration alpha[i] and scale beta[i].

  """
  with ops.name_scope(name, "stateless_random_gamma",
                      [shape, seed, alpha, beta]) as name:
    shape = tensor_util.shape_tensor(shape)
    alpha = ops.convert_to_tensor(alpha, dtype=dtype, name="alpha")
    beta = ops.convert_to_tensor(
        beta if beta is not None else 1, name="beta", dtype=dtype)
    broadcast_shape = array_ops.broadcast_dynamic_shape(
        array_ops.shape(alpha), array_ops.shape(beta))
    alpha_broadcast = array_ops.broadcast_to(alpha, broadcast_shape)
    result = math_ops.maximum(
        np.finfo(alpha.dtype.as_numpy_dtype).tiny,
        gen_stateless_random_ops.stateless_random_gamma_v2(
            shape, seed=seed, alpha=alpha_broadcast) / beta)
    tensor_util.maybe_set_static_shape(result, shape)
    return result


@tf_export("random.stateless_normal")
def stateless_random_normal(shape,
                            seed,
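To make the `shape` convention difference described in the docstring concrete, a small comparative sketch (hedged; values illustrative):

import tensorflow as tf

# tf.random.gamma prepends `shape` to the broadcast of alpha and beta:
g = tf.random.gamma([10], alpha=[0.5, 1.5])                             # shape [10, 2]
# tf.random.stateless_gamma expects `shape` to already include those
# trailing distribution dimensions:
s = tf.random.stateless_gamma([10, 2], seed=[1, 2], alpha=[0.5, 1.5])   # shape [10, 2]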
@@ -80,6 +80,10 @@ tf_module {
    name: "stateless_categorical"
    argspec: "args=[\'logits\', \'num_samples\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
  }
  member_method {
    name: "stateless_gamma"
    argspec: "args=[\'shape\', \'seed\', \'alpha\', \'beta\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\", \'None\'], "
  }
  member_method {
    name: "stateless_multinomial"
    argspec: "args=[\'logits\', \'num_samples\', \'seed\', \'output_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
@@ -4284,6 +4284,10 @@ tf_module {
    name: "StatelessRandomBinomial"
    argspec: "args=[\'shape\', \'seed\', \'counts\', \'probs\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
  }
  member_method {
    name: "StatelessRandomGammaV2"
    argspec: "args=[\'shape\', \'seed\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "StatelessRandomNormal"
    argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
@@ -72,6 +72,10 @@ tf_module {
    name: "stateless_categorical"
    argspec: "args=[\'logits\', \'num_samples\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
  }
  member_method {
    name: "stateless_gamma"
    argspec: "args=[\'shape\', \'seed\', \'alpha\', \'beta\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\", \'None\'], "
  }
  member_method {
    name: "stateless_normal"
    argspec: "args=[\'shape\', \'seed\', \'mean\', \'stddev\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \"<dtype: \'float32\'>\", \'None\'], "
@@ -4284,6 +4284,10 @@ tf_module {
    name: "StatelessRandomBinomial"
    argspec: "args=[\'shape\', \'seed\', \'counts\', \'probs\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
  }
  member_method {
    name: "StatelessRandomGammaV2"
    argspec: "args=[\'shape\', \'seed\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "StatelessRandomNormal"
    argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "