Adds a tf.random.stateless_gamma sampler for CPU. (Stateless analogue to tf.random.gamma.)

PiperOrigin-RevId: 293454705
Change-Id: I103900822dac37989246eaf7b410157e5563b78b
This commit is contained in:
Brian Patton 2020-02-05 14:39:49 -08:00 committed by TensorFlower Gardener
parent f95a6caa8b
commit 31c94250fc
13 changed files with 426 additions and 12 deletions

View File

@ -0,0 +1,41 @@
op {
graph_op_name: "StatelessRandomGammaV2"
visibility: HIDDEN
in_arg {
name: "shape"
description: <<END
The shape of the output tensor.
END
}
in_arg {
name: "seed"
description: <<END
2 seeds (shape [2]).
END
}
in_arg {
name: "alpha"
description: <<END
The concentration of the gamma distribution. Shape must match the rightmost
dimensions of `shape`.
END
}
out_arg {
name: "output"
description: <<END
Random values with specified shape.
END
}
attr {
name: "dtype"
description: <<END
The type of the output.
END
}
summary: "Outputs deterministic pseudorandom random numbers from a gamma distribution."
description: <<END
Outputs random values from a gamma distribution.
The outputs are a deterministic function of `shape`, `seed`, and `alpha`.
END
}

View File

@ -167,7 +167,7 @@ class RandomGammaOp : public OpKernel {
OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(vec.data(), vec.size(),
&samples_shape));
}
const int64 num_samples = samples_shape.num_elements();
const int64 samples_per_alpha = samples_shape.num_elements();
samples_shape.AppendShape(alpha_t.shape());
// Allocate output samples.
@ -199,13 +199,13 @@ class RandomGammaOp : public OpKernel {
num_alphas));
auto samples_flat = samples_t->flat<T>().data();
PhiloxRandom rng = generator_.ReserveRandomOutputs(
num_samples * num_alphas, kReservedSamplesPerOutput);
samples_per_alpha * num_alphas, kReservedSamplesPerOutput);
// We partition work first across alphas then across samples-per-alpha to
// avoid a couple flops which can be done on a per-alpha basis.
auto DoWork = [num_samples, num_alphas, &rng, samples_flat, alpha_flat](
int start_output, int limit_output) {
auto DoWork = [samples_per_alpha, num_alphas, &rng, samples_flat,
alpha_flat](int start_output, int limit_output) {
using Eigen::numext::exp;
using Eigen::numext::log;
using Eigen::numext::pow;
@ -220,7 +220,7 @@ class RandomGammaOp : public OpKernel {
typename Uniform::ResultType uniform_result;
for (int64 output_idx = start_output; output_idx < limit_output;
/* output_idx incremented within inner loop below */) {
int64 alpha_idx = output_idx / num_samples;
int64 alpha_idx = output_idx / samples_per_alpha;
// Instead of +alpha_idx for each sample, we offset the pointer once.
T* const samples_alpha_offset = samples_flat + alpha_idx;
@ -232,8 +232,8 @@ class RandomGammaOp : public OpKernel {
if (alpha == static_cast<double>(1.0)) {
ENABLE_FLOAT_EQUALITY_WARNING
// Sample from an exponential distribution.
for (int64 sample_idx = output_idx % num_samples;
sample_idx < num_samples && output_idx < limit_output;
for (int64 sample_idx = output_idx % samples_per_alpha;
sample_idx < samples_per_alpha && output_idx < limit_output;
sample_idx++, output_idx++) {
// As we want data stable regardless of sharding
// (including eventually on GPU), we skip on a per-sample basis.
@ -259,8 +259,8 @@ class RandomGammaOp : public OpKernel {
const double c = 1.0 / 3 / sqrt(d);
// Compute the rest of the samples for the current alpha value.
for (int64 sample_idx = output_idx % num_samples;
sample_idx < num_samples && output_idx < limit_output;
for (int64 sample_idx = output_idx % samples_per_alpha;
sample_idx < samples_per_alpha && output_idx < limit_output;
sample_idx++, output_idx++) {
// Since each sample may use a variable number of normal/uniform
// samples, and we want data stable regardless of sharding
@ -317,7 +317,7 @@ class RandomGammaOp : public OpKernel {
3 * PhiloxRandom::kElementCost;
auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
Shard(worker_threads.num_threads, worker_threads.workers,
num_alphas * num_samples, kElementCost, DoWork);
num_alphas * samples_per_alpha, kElementCost, DoWork);
}
private:

View File

@ -22,6 +22,17 @@ limitations under the License.
#include "tensorflow/core/kernels/random_op.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/work_sharder.h"
#if EIGEN_COMP_GNUC && __cplusplus > 199711L
#define DISABLE_FLOAT_EQUALITY_WARNING \
_Pragma("GCC diagnostic push") \
_Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
#else
#define DISABLE_FLOAT_EQUALITY_WARNING
#define ENABLE_FLOAT_EQUALITY_WARNING
#endif
namespace tensorflow {
@ -151,6 +162,162 @@ class StatelessRandomUniformIntOp : public StatelessRandomOpBase {
}
};
template <typename Device, typename T>
class StatelessRandomGammaOp : public StatelessRandomOpBase {
public:
using StatelessRandomOpBase::StatelessRandomOpBase;
void Fill(OpKernelContext* ctx, random::PhiloxRandom random,
Tensor* output) override {
const Tensor& alpha_t = ctx->input(2);
TensorShape samples_shape = output->shape();
OP_REQUIRES(ctx, TensorShapeUtils::EndsWith(samples_shape, alpha_t.shape()),
errors::InvalidArgument(
"Shape passed in must end with broadcasted shape."));
typedef random::NormalDistribution<random::PhiloxRandom, double> Normal;
typedef random::UniformDistribution<random::PhiloxRandom, double> Uniform;
#define UNIFORM(X) \
if (uniform_remaining == 0) { \
uniform_remaining = Uniform::kResultElementCount; \
uniform_result = uniform(&gen); \
} \
uniform_remaining--; \
double X = uniform_result[uniform_remaining]
// Each attempt is 95+% successful, and requires 1-2 normal + 1 uniform
static constexpr int kReservedSamplesPerOutput = 256;
const int64 num_alphas = alpha_t.NumElements();
OP_REQUIRES(ctx, num_alphas > 0,
errors::InvalidArgument(
"Input alpha should have non-zero element count, got: ",
num_alphas));
const int64 samples_per_alpha = samples_shape.num_elements() / num_alphas;
const auto alpha_flat = alpha_t.flat<T>().data();
auto samples_flat = output->flat<T>().data();
// We partition work first across alphas then across samples-per-alpha to
// avoid a couple flops which can be done on a per-alpha basis.
auto DoWork = [samples_per_alpha, num_alphas, &random, samples_flat,
alpha_flat](int start_output, int limit_output) {
// Capturing "random" by-value would only make a copy for the _shared_
// lambda. Since we want to let each worker have its own copy, we pass
// "random" by reference and explicitly do a copy assignment.
using Eigen::numext::exp;
using Eigen::numext::log;
using Eigen::numext::pow;
Normal normal;
Uniform uniform;
typename Normal::ResultType norm_result;
typename Uniform::ResultType uniform_result;
for (int64 output_idx = start_output; output_idx < limit_output;
/* output_idx incremented within inner loop below */) {
int64 alpha_idx = output_idx / samples_per_alpha;
// Instead of +alpha_idx for each sample, we offset the pointer once.
T* const samples_alpha_offset = samples_flat + alpha_idx;
// Several calculations can be done on a per-alpha basis.
const double alpha = static_cast<double>(alpha_flat[alpha_idx]);
DISABLE_FLOAT_EQUALITY_WARNING
if (alpha == static_cast<double>(1.0)) {
ENABLE_FLOAT_EQUALITY_WARNING
// Sample from an exponential distribution.
for (int64 sample_idx = output_idx % samples_per_alpha;
sample_idx < samples_per_alpha && output_idx < limit_output;
sample_idx++, output_idx++) {
// As we want data stable regardless of sharding
// (including eventually on GPU), we skip on a per-sample basis.
random::PhiloxRandom gen = random;
gen.Skip(kReservedSamplesPerOutput * output_idx);
int16 uniform_remaining = 0;
UNIFORM(u);
const double res = -log(1.0 - u);
samples_alpha_offset[sample_idx * num_alphas] = static_cast<T>(res);
} // for (sample_idx)
} else { // if alpha != 1.0
// Transformation-rejection from pairs of uniform and normal random
// variables. http://dl.acm.org/citation.cfm?id=358414
//
// The algorithm has an acceptance rate of ~95% for small alpha (~1),
// and higher accept rates for higher alpha, so runtime is
// O(NumAlphas * NumSamples * k) with k ~ 1 / 0.95.
//
// For alpha<1, we add one to d=alpha-1/3, and multiply the final
// result by uniform()^(1/alpha)
const bool alpha_less_than_one = alpha < 1;
const double d = alpha + (alpha_less_than_one ? 2.0 / 3 : -1.0 / 3);
const double c = 1.0 / 3 / sqrt(d);
// Compute the rest of the samples for the current alpha value.
for (int64 sample_idx = output_idx % samples_per_alpha;
sample_idx < samples_per_alpha && output_idx < limit_output;
sample_idx++, output_idx++) {
// Since each sample may use a variable number of normal/uniform
// samples, and we want data stable regardless of sharding
// (including eventually on GPU), we skip on a per-sample basis.
random::PhiloxRandom gen = random;
gen.Skip(kReservedSamplesPerOutput * output_idx);
int16 norm_remaining = 0;
int16 uniform_remaining = 0;
// Keep trying until we don't reject a sample. In practice, we will
// only reject ~5% at worst, for low alpha near 1.
while (true) {
if (norm_remaining == 0) {
norm_remaining = Normal::kResultElementCount;
norm_result = normal(&gen);
}
norm_remaining--;
const double x = norm_result[norm_remaining];
double v = 1 + c * x;
if (v <= 0) {
continue;
}
v = v * v * v;
UNIFORM(u);
// The first option in the if is a "squeeze" short-circuit to
// dodge the two logs. Magic constant sourced from the paper
// linked above. Upward of .91 of the area covered by the log
// inequality is covered by the squeeze as well (larger coverage
// for smaller values of alpha).
if ((u < 1 - 0.0331 * (x * x) * (x * x)) ||
(log(u) < 0.5 * x * x + d * (1 - v + log(v)))) {
double res = d * v;
if (alpha_less_than_one) {
UNIFORM(b);
res *= pow(b, 1 / alpha);
}
samples_alpha_offset[sample_idx * num_alphas] =
static_cast<T>(res);
break;
}
} // while: true
} // for: sample_idx
} // if (alpha == 1.0)
} // for: output_idx
}; // DoWork
#undef UNIFORM
// Two calls to log only occur for ~10% of samples reaching the log line.
// 2 x 100 (64-bit cycles per log) x 0.10 = ~20.
// Other ops: sqrt, +, *, /, %... something like 15 of these, at 3-6 cycles
// each = ~60.
// All of this /0.95 due to the rejection possibility = ~85.
static const int kElementCost = 85 + 2 * Normal::kElementCost +
Uniform::kElementCost +
3 * random::PhiloxRandom::kElementCost;
auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
Shard(worker_threads.num_threads, worker_threads.workers,
num_alphas * samples_per_alpha, kElementCost, DoWork);
}
};
#define REGISTER(DEVICE, TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("StatelessRandomUniform") \
@ -177,7 +344,7 @@ class StatelessRandomUniformIntOp : public StatelessRandomOpBase {
StatelessRandomOp< \
DEVICE##Device, \
random::TruncatedNormalDistribution< \
random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >);
random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >)
#define REGISTER_INT(DEVICE, TYPE) \
REGISTER_KERNEL_BUILDER(Name("StatelessRandomUniformInt") \
@ -201,6 +368,22 @@ TF_CALL_double(REGISTER_CPU);
TF_CALL_int32(REGISTER_INT_CPU);
TF_CALL_int64(REGISTER_INT_CPU);
#define REGISTER_GAMMA(TYPE) \
REGISTER_KERNEL_BUILDER(Name("StatelessRandomGammaV2") \
.Device(DEVICE_CPU) \
.HostMemory("shape") \
.HostMemory("seed") \
.HostMemory("alpha") \
.TypeConstraint<TYPE>("dtype"), \
StatelessRandomGammaOp<CPUDevice, TYPE>)
TF_CALL_half(REGISTER_GAMMA);
TF_CALL_bfloat16(REGISTER_GAMMA);
TF_CALL_float(REGISTER_GAMMA);
TF_CALL_double(REGISTER_GAMMA);
#undef REGISTER_GAMMA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TF_CALL_half(REGISTER_GPU);

View File

@ -105,4 +105,14 @@ REGISTER_OP("StatelessRandomBinomial")
.Attr("dtype: {half, float, double, int32, int64} = DT_INT64")
.SetShapeFn(StatelessShape);
REGISTER_OP("StatelessRandomGammaV2")
.Input("shape: T")
.Input("seed: Tseed")
.Input("alpha: dtype")
.Output("output: dtype")
.Attr("dtype: {float16, float32, float64}")
.Attr("T: {int32, int64}")
.Attr("Tseed: {int32, int64} = DT_INT64")
.SetShapeFn(StatelessShape);
} // namespace tensorflow

View File

@ -308,6 +308,7 @@ bool OpGradientDoesntRequireInputIndices(
{"StackPush", {true, {}}},
{"StatelessMultinomial", {true, {}}},
{"StatelessRandomBinomial", {true, {}}},
{"StatelessRandomGammaV2", {false, {1}}},
{"StatelessRandomNormal", {true, {}}},
{"StatelessRandomUniform", {true, {}}},
{"StatelessRandomUniformInt", {true, {}}},

View File

@ -100,6 +100,7 @@ cuda_py_test(
name = "stateless_random_ops_test",
size = "medium",
srcs = ["stateless_random_ops_test.py"],
shard_count = 2,
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",

View File

@ -129,6 +129,16 @@ class StatelessOpsTest(test.TestCase):
yield (functools.partial(stateless.stateless_multinomial, **kwds),
functools.partial(random_ops.multinomial, **kwds))
def _gamma_cases(self):
for dtype in np.float16, np.float32, np.float64:
for alpha in ([[.5, 1., 2.]], [[0.5, 0.5], [0.8, 0.2], [0.25, 0.75]]):
kwds = dict(alpha=constant_op.constant(alpha, dtype=dtype), dtype=dtype)
yield (functools.partial(
stateless.stateless_random_gamma,
shape=(10,) + tuple(np.shape(alpha)),
**kwds),
functools.partial(random_ops.random_gamma, shape=(10,), **kwds))
@test_util.run_deprecated_v1
def testMatchFloat(self):
self._test_match(self._float_cases())
@ -141,6 +151,10 @@ class StatelessOpsTest(test.TestCase):
def testMatchMultinomial(self):
self._test_match(self._multinomial_cases())
@test_util.run_deprecated_v1
def testMatchGamma(self):
self._test_match(self._gamma_cases())
@test_util.run_deprecated_v1
def testDeterminismFloat(self):
self._test_determinism(
@ -155,6 +169,10 @@ class StatelessOpsTest(test.TestCase):
def testDeterminismMultinomial(self):
self._test_determinism(self._multinomial_cases())
@test_util.run_deprecated_v1
def testDeterminismGamma(self):
self._test_determinism(self._gamma_cases())
if __name__ == '__main__':
test.main()

View File

@ -25,7 +25,7 @@ from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import math_ops
def add_leading_unit_dimensions(x, num_dimensions):
def add_leading_unit_dimensions(x, num_dimensions): # pylint: disable=invalid-name
new_shape = array_ops.concat(
[array_ops.ones([num_dimensions], dtype=dtypes.int32),
array_ops.shape(x)], axis=0)
@ -70,3 +70,47 @@ def _RandomGammaGrad(op, grad): # pylint: disable=invalid-name
# The first input is shape; the second input is alpha.
return (None, math_ops.reduce_sum(
grad * partial_a, axis=math_ops.range(num_sample_dimensions)))
@ops.RegisterGradient("StatelessRandomGammaV2")
def _StatelessRandomGammaV2Grad(op, grad): # pylint: disable=invalid-name
"""Returns the gradient of a Gamma sample w.r.t. alpha.
The gradient is computed using implicit differentiation
(Figurnov et al., 2018).
Args:
op: A `StatelessRandomGamma` operation. We assume that the inputs to the
operation are `shape`, `seed` and `alpha` tensors, and the output is the
`sample` tensor.
grad: The incoming gradient `dloss / dsample` of the same shape as
`op.outputs[0]`.
Returns:
A `Tensor` with derivatives `dloss / dalpha`.
References:
Implicit Reparameterization Gradients:
[Figurnov et al., 2018]
(http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients)
([pdf]
(http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients.pdf))
"""
shape = op.inputs[0]
alpha = op.inputs[2]
sample = op.outputs[0]
with ops.control_dependencies([grad]):
# Note that the shape handling is slightly different for stateless_gamma,
# in particular num_sample_dimensions is different.
num_sample_dimensions = array_ops.shape(shape)[0] - array_ops.rank(alpha)
# Make the parameters alpha broadcastable with samples by appending
# unit dimensions.
alpha_broadcastable = add_leading_unit_dimensions(alpha,
num_sample_dimensions)
partial_a = gen_random_ops.random_gamma_grad(alpha_broadcastable, sample)
# The first two inputs are shape, seed, third input is alpha.
return (None, None,
math_ops.reduce_sum(
grad * partial_a, axis=math_ops.range(num_sample_dimensions)))

View File

@ -18,9 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_stateless_random_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import deprecation
@ -171,6 +174,103 @@ def stateless_random_binomial(shape,
return result
@tf_export("random.stateless_gamma")
def stateless_random_gamma(shape,
seed,
alpha,
beta=None,
dtype=dtypes.float32,
name=None):
"""Outputs deterministic pseudorandom values from a gamma distribution.
The generated values follow a gamma distribution with specified concentration
(`alpha`) and inverse scale (`beta`) parameters.
This is a stateless version of `tf.random.gamma`: if run twice with the same
seeds, it will produce the same pseudorandom numbers. The output is consistent
across multiple runs on the same hardware (and between CPU and GPU), but may
change between versions of TensorFlow or on non-CPU/GPU hardware.
A slight difference exists in the interpretation of the `shape` parameter
between `stateless_gamma` and `gamma`: in `gamma`, the `shape` is always
prepended to the shape of the broadcast of `alpha` with `beta`; whereas in
`stateless_gamma` the `shape` parameter must always encompass the shapes of
each of `alpha` and `beta` (which must broadcast together to match the
trailing dimensions of `shape`).
Note: Because internal calculations are done using `float64` and casting has
`floor` semantics, we must manually map zero outcomes to the smallest
possible positive floating-point value, i.e., `np.finfo(dtype).tiny`. This
means that `np.finfo(dtype).tiny` occurs more frequently than it otherwise
should. This bias can only happen for small values of `alpha`, i.e.,
`alpha << 1` or large values of `beta`, i.e., `beta >> 1`.
The samples are differentiable w.r.t. alpha and beta.
The derivatives are computed using the approach described in
(Figurnov et al., 2018).
Example:
```python
samples = tf.random.stateless_gamma([10, 2], seed=[12, 34], alpha=[0.5, 1.5])
# samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents
# the samples drawn from each distribution
samples = tf.random.stateless_gamma([7, 5, 2], seed=[12, 34], alpha=[.5, 1.5])
# samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1]
# represents the 7x5 samples drawn from each of the two distributions
alpha = tf.constant([[1.], [3.], [5.]])
beta = tf.constant([[3., 4.]])
samples = tf.random.stateless_gamma(
[30, 3, 2], seed=[12, 34], alpha=alpha, beta=beta)
# samples has shape [30, 3, 2], with 30 samples each of 3x2 distributions.
with tf.GradientTape() as tape:
tape.watch([alpha, beta])
loss = tf.reduce_mean(tf.square(tf.random.stateless_gamma(
[30, 3, 2], seed=[12, 34], alpha=alpha, beta=beta)))
dloss_dalpha, dloss_dbeta = tape.gradient(loss, [alpha, beta])
# unbiased stochastic derivatives of the loss function
alpha.shape == dloss_dalpha.shape # True
beta.shape == dloss_dbeta.shape # True
```
Args:
shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
seed: A shape [2] integer Tensor of seeds to the random number generator.
alpha: Tensor. The concentration parameter of the gamma distribution. Must
be broadcastable with `beta`, and broadcastable with the rightmost
dimensions of `shape`.
beta: Tensor. The inverse scale parameter of the gamma distribution. Must be
broadcastable with `alpha` and broadcastable with the rightmost dimensions
of `shape`.
dtype: Floating point dtype of `alpha`, `beta`, and the output.
name: A name for the operation (optional).
Returns:
samples: A Tensor of the specified shape filled with random gamma values.
For each i, each `samples[..., i] is an independent draw from the gamma
distribution with concentration alpha[i] and scale beta[i].
"""
with ops.name_scope(name, "stateless_random_gamma",
[shape, seed, alpha, beta]) as name:
shape = tensor_util.shape_tensor(shape)
alpha = ops.convert_to_tensor(alpha, dtype=dtype, name="alpha")
beta = ops.convert_to_tensor(
beta if beta is not None else 1, name="beta", dtype=dtype)
broadcast_shape = array_ops.broadcast_dynamic_shape(
array_ops.shape(alpha), array_ops.shape(beta))
alpha_broadcast = array_ops.broadcast_to(alpha, broadcast_shape)
result = math_ops.maximum(
np.finfo(alpha.dtype.as_numpy_dtype).tiny,
gen_stateless_random_ops.stateless_random_gamma_v2(
shape, seed=seed, alpha=alpha_broadcast) / beta)
tensor_util.maybe_set_static_shape(result, shape)
return result
@tf_export("random.stateless_normal")
def stateless_random_normal(shape,
seed,

View File

@ -80,6 +80,10 @@ tf_module {
name: "stateless_categorical"
argspec: "args=[\'logits\', \'num_samples\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
}
member_method {
name: "stateless_gamma"
argspec: "args=[\'shape\', \'seed\', \'alpha\', \'beta\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "stateless_multinomial"
argspec: "args=[\'logits\', \'num_samples\', \'seed\', \'output_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "

View File

@ -4284,6 +4284,10 @@ tf_module {
name: "StatelessRandomBinomial"
argspec: "args=[\'shape\', \'seed\', \'counts\', \'probs\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
}
member_method {
name: "StatelessRandomGammaV2"
argspec: "args=[\'shape\', \'seed\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "StatelessRandomNormal"
argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "

View File

@ -72,6 +72,10 @@ tf_module {
name: "stateless_categorical"
argspec: "args=[\'logits\', \'num_samples\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
}
member_method {
name: "stateless_gamma"
argspec: "args=[\'shape\', \'seed\', \'alpha\', \'beta\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "stateless_normal"
argspec: "args=[\'shape\', \'seed\', \'mean\', \'stddev\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \"<dtype: \'float32\'>\", \'None\'], "

View File

@ -4284,6 +4284,10 @@ tf_module {
name: "StatelessRandomBinomial"
argspec: "args=[\'shape\', \'seed\', \'counts\', \'probs\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
}
member_method {
name: "StatelessRandomGammaV2"
argspec: "args=[\'shape\', \'seed\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "StatelessRandomNormal"
argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "