Add StatelessParameterizedTruncatedNormal sampler.

This sampler supports broadcasting of its input parameters and places the sample dimensions on the left of the output shape, rather than on the right.

PiperOrigin-RevId: 317129622
Change-Id: I4b62ad2e89a9637ae8b30b73af4b662ad0caa943
Srinivas Vasudevan 2020-06-18 10:23:23 -07:00 committed by TensorFlower Gardener
parent 89b80c5fb9
commit 455750f362
15 changed files with 913 additions and 71 deletions
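For orientation, a minimal usage sketch of the new sampler through the public
wrapper this commit adds (`tf.random.stateless_parameterized_truncated_normal`);
the parameter values are illustrative and mirror the docstring example further
down.

import tensorflow as tf

means = 0.                                   # scalar broadcasts against the rest
stddevs = tf.math.exp(tf.random.uniform(shape=[2, 3]))
minvals = [-1., -2., -1000.]                 # shape [3]
maxvals = [[10000.], [1.]]                   # shape [2, 1]

# The parameters broadcast to shape [2, 3]; `shape` must end with that
# broadcasted shape, and the leading dimension (10 samples) sits on the left.
y = tf.random.stateless_parameterized_truncated_normal(
    shape=[10, 2, 3], seed=[7, 17],
    means=means, stddevs=stddevs, minvals=minvals, maxvals=maxvals)
print(y.shape)  # (10, 2, 3)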


@@ -0,0 +1,54 @@
op {
graph_op_name: "StatelessParameterizedTruncatedNormal"
visibility: HIDDEN
in_arg {
name: "shape"
description: <<END
The shape of the output tensor.
END
}
in_arg {
name: "seed"
description: <<END
2 seeds (shape [2]).
END
}
in_arg {
name: "means"
description: <<END
The mean parameter of each batch.
END
}
in_arg {
name: "stddevs"
description: <<END
The standard deviation parameter of each batch. Must be greater than 0.
END
}
in_arg {
name: "minvals"
description: <<END
The minimum cutoff. May be -infinity.
END
}
in_arg {
name: "maxvals"
description: <<END
The maximum cutoff. May be +infinity, and must be greater than the minval
for each batch.
END
}
attr {
name: "dtype"
description: <<END
The type of the output.
END
}
out_arg {
name: "output"
description: <<END
The outputs are truncated normal samples and are a deterministic function of
`shape`, `seed`, `minvals`, `maxvals`, `means` and `stddevs`.
END
}
}
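As the output description states, the samples are a deterministic function of
the inputs; a quick illustrative sketch of that property (assuming the public
wrapper added later in this commit):

import tensorflow as tf

a = tf.random.stateless_parameterized_truncated_normal(
    shape=[4], seed=[1, 2], means=0., stddevs=1., minvals=-2., maxvals=2.)
b = tf.random.stateless_parameterized_truncated_normal(
    shape=[4], seed=[1, 2], means=0., stddevs=1., minvals=-2., maxvals=2.)
# Same seed and parameters -> bitwise-identical samples.
assert bool(tf.reduce_all(tf.equal(a, b)))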


@@ -6156,6 +6156,7 @@ tf_kernel_library(
]),
prefix = "parameterized_truncated_normal_op",
deps = [
":stateless_random_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",


@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/stateless_random_ops.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/guarded_philox_random.h"
@@ -67,10 +68,10 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
const T kStdDevsInsideBoundsToUseRandnSampler = T(1.3);
auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
auto DoWork = [samples_per_batch, num_elements, &ctx, &means, &stddevs,
&minvals, &maxvals, &gen, &output,
kStdDevsInsideBoundsToUseRandnSampler](int start_batch,
int limit_batch) {
auto do_work = [samples_per_batch, num_elements, &ctx, &means, &stddevs,
&minvals, &maxvals, &gen, &output,
kStdDevsInsideBoundsToUseRandnSampler](int start_batch,
int limit_batch) {
// Capturing "gen" by-value would only make a copy for the _shared_
// lambda. Since we want to let each worker have its own copy, we pass
// "gen" by reference and explicitly do a copy assignment here.
@@ -80,9 +81,9 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
// The sample from each iteration uses 2 random numbers.
gen_copy.Skip(start_batch * 2 * kMaxIterations * (samples_per_batch + 3) /
4);
typedef random::UniformDistribution<random::PhiloxRandom, T> Uniform;
using Uniform = random::UniformDistribution<random::PhiloxRandom, T>;
Uniform dist;
typedef random::NormalDistribution<random::PhiloxRandom, T> Normal;
using Normal = random::NormalDistribution<random::PhiloxRandom, T>;
Normal normal_dist;
// Vectorized intermediate calculations for uniform rejection sampling.
@@ -112,7 +113,7 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
Eigen::numext::isfinite(maxval)),
errors::InvalidArgument("Invalid parameters"));
int numIterations = 0;
int num_iterations = 0;
// If possible, make one-sided bound be the lower bound, or make both
// bounds positive. Otherwise, the bounds are on either side of the
@@ -160,10 +161,10 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
if (sample >= limit_sample) {
break;
}
numIterations = 0;
num_iterations = 0;
} else {
numIterations++;
if (numIterations > kMaxIterations) {
num_iterations++;
if (num_iterations > kMaxIterations) {
// This should never occur because this sampler should
// (by the selection criteria above) be used if at least 3
// standard deviations of one side of the distribution
@@ -201,7 +202,7 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
const auto u = dist(&gen_copy);
for (int i = 0; i < size; i++) {
auto accept = u[i] <= Eigen::numext::exp(g[i]);
if (accept || numIterations + 1 >= kMaxIterations) {
if (accept || num_iterations + 1 >= kMaxIterations) {
// Accept the sample z.
// If we run out of iterations, just use the current uniform
// sample, but emit a warning.
@@ -223,9 +224,9 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
if (sample >= limit_sample) {
break;
}
numIterations = 0;
num_iterations = 0;
} else {
numIterations++;
num_iterations++;
}
}
}
@@ -248,7 +249,7 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
const T u = rand[i];
i++;
auto accept = (u <= g && z < normMax);
if (accept || numIterations + 1 >= kMaxIterations) {
if (accept || num_iterations + 1 >= kMaxIterations) {
if (!accept) {
LOG(ERROR) << "TruncatedNormal exponential distribution "
<< "rejection sampler exceeds max iterations. "
@@ -263,9 +264,9 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
if (sample >= limit_sample) {
break;
}
numIterations = 0;
num_iterations = 0;
} else {
numIterations++;
num_iterations++;
}
}
}
@@ -305,7 +306,297 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
const int64 batchCost =
batchInitCost + uniformRejectionSamplingCost * 2 * samples_per_batch;
Shard(worker_threads.num_threads, worker_threads.workers, num_batches,
batchCost, DoWork);
batchCost, do_work);
}
};
template <typename T>
struct TruncatedNormalFunctorV2<CPUDevice, T> {
void operator()(OpKernelContext* ctx, const CPUDevice& d, int64 num_batches,
int64 samples_per_batch, int64 num_elements,
const BCastList<4>& bcast,
typename TTypes<T>::ConstFlat means,
typename TTypes<T>::ConstFlat stddevs,
typename TTypes<T>::ConstFlat minvals,
typename TTypes<T>::ConstFlat maxvals,
const random::PhiloxRandom& gen,
typename TTypes<T>::Flat output) {
// The randn rejection sampling is used when the mean and at least this many
// standard deviations are inside the bounds.
// The uniform proposal samplers become less efficient as the bounds are
// further from the mean, the reverse is true for the randn sampler.
// This number was chosen by empirical benchmarking. If modified, the
// benchmarks in parameterized_truncated_normal_op_test should also be
// changed.
const T kStdDevsInsideBoundsToUseRandnSampler = T(1.3);
auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
auto do_work = [num_batches, samples_per_batch, &ctx, &bcast, &means,
&stddevs, &minvals, &maxvals, &gen, &output,
kStdDevsInsideBoundsToUseRandnSampler](int start_output,
int limit_output) {
// Capturing "gen" by-value would only make a copy for the _shared_
// lambda. Since we want to let each worker have its own copy, we pass
// "gen" by reference and explicitly do a copy assignment here.
random::PhiloxRandom gen_copy = gen;
using Uniform = random::UniformDistribution<random::PhiloxRandom, T>;
Uniform dist;
using Normal = random::NormalDistribution<random::PhiloxRandom, T>;
Normal normal_dist;
// Skip takes units of 128 bits. The Uniform::kResultElementCount - 1
// is so rounding doesn't lead to
// us using the same state in different workloads.
// The sample from each iteration uses 2 random numbers.
gen_copy.Skip((start_output * 2 * kMaxIterations +
Uniform::kResultElementCount - 1) /
Uniform::kResultElementCount);
// Vectorized intermediate calculations for uniform rejection sampling.
// We always generate at most 4 samples.
Eigen::array<T, Uniform::kResultElementCount> z;
Eigen::array<T, Uniform::kResultElementCount> g;
const bool should_bcast = bcast.IsBroadcastingRequired();
const auto& means_batch_indices = bcast.batch_indices(0);
const auto& stddevs_batch_indices = bcast.batch_indices(1);
const auto& minvals_batch_indices = bcast.batch_indices(2);
const auto& maxvals_batch_indices = bcast.batch_indices(3);
auto output_flat = output.data();
// We partition work across batches and then across samples
// per batch member, to avoid extra work.
for (int64 output_idx = start_output; output_idx < limit_output;
// output_idx is incremented with the inner loops below.
) {
int64 batch_idx = output_idx / samples_per_batch;
// The output layout is [samples_per_batch, num_batches]. Thus
// the output address is sample_idx * num_batches + batch_idx.
// Below, code will index at output_batch_offset[sample_idx *
// num_batches] matching this.
T* const output_batch_offset = output_flat + batch_idx;
// Generate batch counts from BCast, as it has the right indices to loop
// over.
T mean, stddev, minval, maxval;
if (should_bcast) {
mean = means(means_batch_indices[batch_idx]);
stddev = stddevs(stddevs_batch_indices[batch_idx]);
minval = minvals(minvals_batch_indices[batch_idx]);
maxval = maxvals(maxvals_batch_indices[batch_idx]);
} else {
mean = means(batch_idx);
stddev = stddevs(batch_idx);
minval = minvals(batch_idx);
maxval = maxvals(batch_idx);
}
// On GPU, this check will just fill samples with NAN if it fails.
OP_REQUIRES(ctx,
stddev > T(0) && minval < maxval &&
(Eigen::numext::isfinite(minval) ||
Eigen::numext::isfinite(maxval)),
errors::InvalidArgument("Invalid parameters"));
int num_iterations = 0;
// If possible, make one-sided bound be the lower bound, or make both
// bounds positive. Otherwise, the bounds are on either side of the
// mean.
if ((Eigen::numext::isinf(minval) && minval < T(0)) || maxval < mean) {
// Reverse all calculations. normMin and normMax will be flipped.
std::swap(minval, maxval);
stddev = -stddev;
}
// Calculate normalized samples, then convert them.
const T normMin = (minval - mean) / stddev;
const T normMax = (maxval - mean) / stddev;
// Determine the method to use.
const T sqrtFactor = Eigen::numext::sqrt((normMin * normMin) + T(4));
const T cutoff =
T(2) *
Eigen::numext::exp(T(0.5) +
(normMin * (normMin - sqrtFactor)) / T(4)) /
(normMin + sqrtFactor);
const T diff = normMax - normMin;
if (((normMin < -kStdDevsInsideBoundsToUseRandnSampler) &&
(normMax >= T(0.))) ||
((normMax > kStdDevsInsideBoundsToUseRandnSampler) &&
(normMin <= T(0.)))) {
// If the bounds are at least 3 standard deviations from the mean
// on at least one side then we rejection sample by sampling
// from the normal distribution and rejecting samples outside
// the bounds.
// Under this condition the acceptance rate per iteration should
// always be ~ 50%. This sampler is more efficient (and more
// numerically stable when one or both bounds is far from the mean).
for (int64 sample_idx = output_idx % samples_per_batch;
sample_idx < samples_per_batch && output_idx < limit_output;) {
const auto randn_sample = normal_dist(&gen_copy);
const int size = randn_sample.size();
for (int i = 0; i < size; ++i) {
if ((randn_sample[i] >= normMin) &&
(randn_sample[i] <= normMax)) {
output_batch_offset[sample_idx * num_batches] =
randn_sample[i] * stddev + mean;
++sample_idx;
++output_idx;
if (sample_idx >= samples_per_batch ||
output_idx >= limit_output) {
break;
}
num_iterations = 0;
} else {
++num_iterations;
if (num_iterations > kMaxIterations) {
// This should never occur because this sampler should
// (by the selection criteria above) be used if at least 3
// standard deviations of one side of the distribution
// is within the limits (so acceptance probability per
// iterations >~ 1/2 per iteration).
LOG(ERROR) << "TruncatedNormal randn rejection sampler "
<< "exceeded maximum iterations for "
<< "normMin=" << normMin << " normMax=" << normMax
<< " kMaxIterations=" << kMaxIterations;
ctx->SetStatus(errors::Internal(
"TruncatedNormal randn rejection sampler failed to accept"
" a sample."));
return;
}
}
}
}
} else if (diff < cutoff) {
// Sample from a uniform distribution on [normMin, normMax].
const T plusFactor = (normMin < T(0)) ? T(0) : normMin * normMin;
for (int64 sample_idx = output_idx % samples_per_batch;
sample_idx < samples_per_batch && output_idx < limit_output;) {
const auto rand = dist(&gen_copy);
const int size = rand.size();
// NOTE(ringwalt): These loops seem to only generate packed AVX
// instructions for float32.
for (int i = 0; i < size; i++) {
z[i] = rand[i] * diff + normMin;
g[i] = (plusFactor - z[i] * z[i]) / T(2.0);
}
const auto u = dist(&gen_copy);
for (int i = 0; i < size; i++) {
auto accept = u[i] <= Eigen::numext::exp(g[i]);
if (accept || num_iterations + 1 >= kMaxIterations) {
// Accept the sample z.
// If we run out of iterations, just use the current uniform
// sample, but emit a warning.
// TODO(jjhunt) For small entropies (relative to the bounds),
// this sampler is poor and may take many iterations since
// the proposal distribution is the uniform distribution
// U(lower_bound, upper_bound).
if (!accept) {
LOG(ERROR) << "TruncatedNormal uniform rejection sampler "
<< "exceeded max iterations. Sample may contain "
<< "outliers.";
ctx->SetStatus(errors::Internal(
"TruncatedNormal uniform rejection sampler failed to "
" accept a sample."));
return;
}
output_batch_offset[sample_idx * num_batches] =
z[i] * stddev + mean;
++sample_idx;
++output_idx;
if (sample_idx >= samples_per_batch ||
output_idx >= limit_output) {
break;
}
num_iterations = 0;
} else {
num_iterations++;
}
}
}
} else {
// Sample from an exponential distribution with alpha maximizing
// acceptance probability, offset by normMin from the origin.
// Accept only if less than normMax.
const T alpha =
(normMin + Eigen::numext::sqrt((normMin * normMin) + T(4))) /
T(2);
for (int64 sample_idx = output_idx % samples_per_batch;
sample_idx < samples_per_batch && output_idx < limit_output;) {
auto rand = dist(&gen_copy);
const int size = rand.size();
int i = 0;
while (i < size) {
const T z = -Eigen::numext::log(rand[i]) / alpha + normMin;
i++;
const T x = normMin < alpha ? alpha - z : normMin - alpha;
const T g = Eigen::numext::exp(-x * x / T(2.0));
const T u = rand[i];
i++;
auto accept = (u <= g && z < normMax);
if (accept || num_iterations + 1 >= kMaxIterations) {
if (!accept) {
LOG(ERROR) << "TruncatedNormal exponential distribution "
<< "rejection sampler exceeds max iterations. "
<< "Sample may contain outliers.";
ctx->SetStatus(errors::Internal(
"TruncatedNormal exponential distribution rejection"
" sampler failed to accept a sample."));
return;
}
output_batch_offset[sample_idx * num_batches] =
z * stddev + mean;
++sample_idx;
++output_idx;
if (sample_idx >= samples_per_batch ||
output_idx >= limit_output) {
break;
}
num_iterations = 0;
} else {
num_iterations++;
}
}
}
}
}
};
// The cost of the initial calculations for the batch.
const int64 batchInitCost =
// normMin, normMax
(Eigen::TensorOpCost::AddCost<T>() +
Eigen::TensorOpCost::MulCost<T>()) *
2
// sqrtFactor
+ Eigen::TensorOpCost::AddCost<T>() +
Eigen::TensorOpCost::MulCost<T>() +
Eigen::internal::functor_traits<
Eigen::internal::scalar_sqrt_op<T>>::Cost
// cutoff
+ Eigen::TensorOpCost::MulCost<T>() * 4 +
Eigen::internal::functor_traits<Eigen::internal::scalar_exp_op<T>>::Cost
// diff
+ Eigen::TensorOpCost::AddCost<T>();
const int64 uniformSampleCost =
random::PhiloxRandom::kElementCost +
random::UniformDistribution<random::PhiloxRandom, T>::kElementCost;
// The cost of a single uniform sampling round.
const int64 uniformRejectionSamplingCost =
uniformSampleCost + Eigen::TensorOpCost::MulCost<T>() +
Eigen::TensorOpCost::AddCost<T>() +
Eigen::TensorOpCost::MulCost<T>() * 2 +
Eigen::TensorOpCost::AddCost<T>() + uniformSampleCost +
Eigen::internal::functor_traits<
Eigen::internal::scalar_exp_op<T>>::Cost +
Eigen::TensorOpCost::MulCost<T>() + Eigen::TensorOpCost::AddCost<T>();
// Estimate the cost for an entire batch.
// Assume we use uniform sampling, and accept the 2nd sample on average.
const int64 batchCost = batchInitCost + uniformRejectionSamplingCost * 2;
Shard(worker_threads.num_threads, worker_threads.workers, num_elements,
batchCost, do_work);
}
};
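The three branches above choose between the randn, uniform, and exponential
rejection samplers. A pure-Python paraphrase of that selection rule, for
readability only (constants mirror the kernel):

import math

def pick_sampler(norm_min, norm_max, k_stddevs=1.3):
  # k_stddevs mirrors kStdDevsInsideBoundsToUseRandnSampler above.
  if ((norm_min < -k_stddevs and norm_max >= 0.) or
      (norm_max > k_stddevs and norm_min <= 0.)):
    return "randn rejection"
  sqrt_factor = math.sqrt(norm_min * norm_min + 4.)
  cutoff = (2. * math.exp(0.5 + norm_min * (norm_min - sqrt_factor) / 4.) /
            (norm_min + sqrt_factor))
  if norm_max - norm_min < cutoff:
    return "uniform rejection"
  return "exponential rejection"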
@@ -436,13 +727,113 @@ class ParameterizedTruncatedNormalOp : public OpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(ParameterizedTruncatedNormalOp);
};
// Samples from a truncated normal distribution, using the given parameters.
template <typename Device, typename T>
class StatelessParameterizedTruncatedNormal : public OpKernel {
// Reshape batches so each batch is this size if possible.
static const int32 kDesiredBatchSize = 100;
public:
explicit StatelessParameterizedTruncatedNormal(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* ctx) override {
const Tensor& shape_tensor = ctx->input(0);
const Tensor& seed_tensor = ctx->input(1);
const Tensor& means_tensor = ctx->input(2);
const Tensor& stddevs_tensor = ctx->input(3);
const Tensor& minvals_tensor = ctx->input(4);
const Tensor& maxvals_tensor = ctx->input(5);
OP_REQUIRES(ctx, seed_tensor.dims() == 1 && seed_tensor.dim_size(0) == 2,
errors::InvalidArgument("seed must have shape [2], not ",
seed_tensor.shape().DebugString()));
tensorflow::BCastList<4> bcast(
{means_tensor.shape().dim_sizes(), stddevs_tensor.shape().dim_sizes(),
minvals_tensor.shape().dim_sizes(),
maxvals_tensor.shape().dim_sizes()},
/*fewer_dims_optimization=*/false,
/*return_flattened_batch_indices=*/true);
OP_REQUIRES(ctx, bcast.IsValid(),
errors::InvalidArgument(
"means, stddevs, minvals, maxvals must have compatible "
"batch dimensions: ",
means_tensor.shape().DebugString(), " vs. ",
stddevs_tensor.shape().DebugString(), " vs. ",
minvals_tensor.shape().DebugString(), " vs. ",
maxvals_tensor.shape().DebugString()));
// Let's check that the shape tensor dominates the broadcasted tensor.
TensorShape bcast_shape = BCast::ToShape(bcast.output_shape());
OP_REQUIRES(
ctx, TensorShapeUtils::IsVector(shape_tensor.shape()),
errors::InvalidArgument("Input shape should be a vector, got shape: ",
shape_tensor.shape().DebugString()));
TensorShape output_shape;
if (shape_tensor.dtype() == DataType::DT_INT32) {
OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(shape_tensor.vec<int32>(),
&output_shape));
} else {
OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(shape_tensor.vec<int64>(),
&output_shape));
}
OP_REQUIRES(ctx, TensorShapeUtils::EndsWith(output_shape, bcast_shape),
errors::InvalidArgument(
"Shape passed in must end with broadcasted shape."));
int64 samples_per_batch = 1;
const int64 num_sample_dims =
(shape_tensor.dim_size(0) - bcast.output_shape().size());
for (int64 i = 0; i < num_sample_dims; ++i) {
samples_per_batch *= output_shape.dim_size(i);
}
int64 num_batches = 1;
for (int64 i = num_sample_dims; i < shape_tensor.dim_size(0); ++i) {
num_batches *= output_shape.dim_size(i);
}
const int64 num_elements = num_batches * samples_per_batch;
Tensor* samples_tensor;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &samples_tensor));
auto truncFunctor = functor::TruncatedNormalFunctorV2<Device, T>();
// Each worker has the same fudge factor, so use it here.
random::PhiloxRandom::Key key;
random::PhiloxRandom::ResultType counter;
OP_REQUIRES_OK(ctx, GenerateKey(seed_tensor, &key, &counter));
auto philox = random::PhiloxRandom(counter, key);
truncFunctor(ctx, ctx->eigen_device<Device>(), num_batches,
samples_per_batch, num_elements, bcast, means_tensor.flat<T>(),
stddevs_tensor.flat<T>(), minvals_tensor.flat<T>(),
maxvals_tensor.flat<T>(), philox, samples_tensor->flat<T>());
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(StatelessParameterizedTruncatedNormal);
};
} // namespace
#define REGISTER(TYPE) \
REGISTER_KERNEL_BUILDER(Name("ParameterizedTruncatedNormal") \
.Device(DEVICE_CPU) \
.TypeConstraint<TYPE>("dtype"), \
ParameterizedTruncatedNormalOp<CPUDevice, TYPE>)
#define REGISTER(TYPE) \
REGISTER_KERNEL_BUILDER(Name("ParameterizedTruncatedNormal") \
.Device(DEVICE_CPU) \
.TypeConstraint<TYPE>("dtype"), \
ParameterizedTruncatedNormalOp<CPUDevice, TYPE>) \
REGISTER_KERNEL_BUILDER( \
Name("StatelessParameterizedTruncatedNormal") \
.HostMemory("shape") \
.HostMemory("seed") \
.HostMemory("means") \
.HostMemory("stddevs") \
.HostMemory("minvals") \
.HostMemory("maxvals") \
.Device(DEVICE_CPU) \
.TypeConstraint<TYPE>("dtype"), \
StatelessParameterizedTruncatedNormal<CPUDevice, TYPE>)
TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);


@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/util/bcast.h"
namespace tensorflow {
@@ -44,6 +45,21 @@ struct TruncatedNormalFunctor {
typename TTypes<T>::Flat output);
};
// This version supports broadcasting of the arguments, as well as puts
// the sample dimension on the left.
template <typename Device, typename T>
struct TruncatedNormalFunctorV2 {
void operator()(OpKernelContext* ctx, const Device& d, int64 num_batches,
int64 samples_per_batch, int64 num_elements,
const BCastList<4>& bcast,
typename TTypes<T>::ConstFlat means,
typename TTypes<T>::ConstFlat stddevs,
typename TTypes<T>::ConstFlat minvals,
typename TTypes<T>::ConstFlat maxvals,
const random::PhiloxRandom& gen,
typename TTypes<T>::Flat output);
};
} // namespace functor
} // namespace tensorflow
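As the V2 comment says, the sample dimension is on the left: the functor
writes output in [samples_per_batch, num_batches] layout, so the flat offset
of a sample is sample_idx * num_batches + batch_idx. A toy numpy illustration
of that addressing:

import numpy as np

num_batches, samples_per_batch = 3, 4
out = np.arange(samples_per_batch * num_batches).reshape(
    samples_per_batch, num_batches)  # row-major, samples on the left
sample_idx, batch_idx = 2, 1
assert out[sample_idx, batch_idx] == sample_idx * num_batches + batch_idx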


@@ -124,6 +124,41 @@ REGISTER_OP("StatelessRandomBinomial")
.Attr("dtype: {half, float, double, int32, int64} = DT_INT64")
.SetShapeFn(StatelessShape);
REGISTER_OP("StatelessParameterizedTruncatedNormal")
.Input("shape: S")
.Input("seed: Tseed")
.Input("means: dtype")
.Input("stddevs: dtype")
.Input("minvals: dtype")
.Input("maxvals: dtype")
.Output("output: dtype")
.Attr("S: {int32, int64}")
.Attr("Tseed: {int32, int64} = DT_INT64")
.Attr("dtype: {float16, float32, float64}")
.SetShapeFn([](InferenceContext* c) {
// Check seed shape
ShapeHandle seed;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seed));
DimensionHandle unused_dim;
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused_dim));
ShapeHandle bcast_means_stddevs;
ShapeHandle bcast_except_maxvals;
ShapeHandle bcast_all;
TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
c, c->input(2), c->input(3), true, &bcast_means_stddevs));
TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
c, c->input(4), bcast_means_stddevs, true, &bcast_except_maxvals));
TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
c, c->input(5), bcast_except_maxvals, true, &bcast_all));
// Set output shape
ShapeHandle out;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
c->set_output(0, out);
return Status::OK();
});
REGISTER_OP("StatelessRandomPoisson")
.Input("shape: T")
.Input("seed: Tseed")


@@ -50,7 +50,7 @@ auto OpGradientInfoInit(const T &a) {
absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
const tensorflow::string &op_name) {
static std::array<OpIndexInfo, 348> a = {{
static std::array<OpIndexInfo, 349> a = {{
{"Acosh"},
{"AllToAll", 1, {0}},
{"ApproximateEqual"},
@@ -326,6 +326,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
{"StackPop"},
{"StackPush"},
{"StatelessMultinomial"},
{"StatelessParameterizedTruncatedNormal", 1, {1}},
{"StatelessRandomBinomial"},
{"StatelessRandomGammaV2", 1, {1}},
{"StatelessRandomNormal"},


@@ -785,24 +785,6 @@ cuda_py_test(
],
)
cuda_py_test(
name = "parameterized_truncated_normal_op_test",
size = "medium",
srcs = ["parameterized_truncated_normal_op_test.py"],
deps = [
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform",
"//tensorflow/python:random_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)
tf_py_test(
name = "parsing_ops_test",
size = "medium",


@@ -20,6 +20,24 @@ py_library(
],
)
cuda_py_test(
name = "parameterized_truncated_normal_op_test",
size = "medium",
srcs = ["parameterized_truncated_normal_op_test.py"],
deps = [
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform",
"//tensorflow/python:random_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)
tf_py_test(
name = "random_shuffle_queue_test",
size = "small",


@@ -27,11 +27,15 @@ from six.moves import range # pylint: disable=redefined-builtin
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import stateless_random_ops as stateless
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
@@ -91,13 +95,8 @@ class TruncatedNormalMoments(object):
def calculate_moments(samples, max_moment):
moments = [0.0] * (max_moment + 1)
for sample in samples:
value = 1.0
for k in range(len(moments)):
moments[k] += value
value *= sample
for i in range(len(moments)):
moments[i] /= len(samples)
for k in range(len(moments)):
moments[k] = np.mean(samples**k, axis=0)
return moments
@@ -118,16 +117,31 @@ class ParameterizedTruncatedNormalTest(test.TestCase):
# Stop at moment 10 to avoid numerical errors in the theoretical moments.
max_moment = 10
def validateMoments(self, shape, mean, stddev, minval, maxval, seed=1618):
def validateMoments(self,
shape,
mean,
stddev,
minval,
maxval,
use_stateless=False,
seed=1618):
try:
# TruncatedNormalMoments requires scipy.stats.
# Give up early if we are unable to import it.
import scipy.stats # pylint: disable=g-import-not-at-top,unused-variable
random_seed.set_random_seed(seed)
with self.cached_session(use_gpu=True):
samples = random_ops.parameterized_truncated_normal(shape, mean, stddev,
minval,
maxval).eval()
if use_stateless:
# Generate a seed that stateless ops can use.
new_seed = random_ops.random_uniform([2],
seed=seed,
minval=0,
maxval=(2**31 - 1),
dtype=np.int32)
samples = stateless.stateless_parameterized_truncated_normal(
shape, new_seed, mean, stddev, minval, maxval).eval()
else:
samples = random_ops.parameterized_truncated_normal(
shape, mean, stddev, minval, maxval).eval()
assert (~np.isnan(samples)).all()
moments = calculate_moments(samples, self.max_moment)
expected_moments = TruncatedNormalMoments(mean, stddev, minval, maxval)
@@ -144,14 +158,24 @@ class ParameterizedTruncatedNormalTest(test.TestCase):
stddev,
minval,
maxval,
use_stateless=False,
seed=1618):
try:
import scipy.stats # pylint: disable=g-import-not-at-top
random_seed.set_random_seed(seed)
with self.cached_session(use_gpu=True):
samples = random_ops.parameterized_truncated_normal(shape, mean, stddev,
minval,
maxval).eval()
if use_stateless:
new_seed = random_ops.random_uniform([2],
seed=seed,
minval=0,
maxval=(2**31 - 1),
dtype=np.int32)
samples = stateless.stateless_parameterized_truncated_normal(
shape, new_seed, mean, stddev, minval, maxval).eval()
else:
samples = random_ops.parameterized_truncated_normal(
shape, mean, stddev, minval, maxval).eval()
assert (~np.isnan(samples)).all()
minval = max(mean - stddev * 10, minval)
maxval = min(mean + stddev * 10, maxval)
@@ -169,61 +193,160 @@ class ParameterizedTruncatedNormalTest(test.TestCase):
@test_util.run_deprecated_v1
def testDefaults(self):
self.validateMoments([10**5], 0.0, 1.0, -2.0, 2.0)
self.validateMoments([int(1e5)], 0.0, 1.0, -2.0, 2.0)
self.validateMoments([int(1e5)], 0.0, 1.0, -2.0, 2.0, use_stateless=True)
@test_util.run_deprecated_v1
def testShifted(self):
self.validateMoments([10**5], -1.0, 1.0, -2.0, 2.0)
self.validateMoments([int(1e5)], -1.0, 1.0, -2.0, 2.0)
self.validateMoments([int(1e5)], -1.0, 1.0, -2.0, 2.0, use_stateless=True)
@test_util.run_deprecated_v1
def testRightTail(self):
self.validateMoments([10**5], 0.0, 1.0, 4.0, np.infty)
self.validateMoments([int(1e5)], 0.0, 1.0, 4.0, np.infty)
self.validateMoments([int(1e5)],
0.0,
1.0,
4.0,
np.infty,
use_stateless=True)
@test_util.run_deprecated_v1
def testLeftTail(self):
self.validateMoments([10**5], 0.0, 1.0, -np.infty, -4.0)
self.validateMoments([int(1e5)], 0.0, 1.0, -np.infty, -4.0)
self.validateMoments([int(1e5)],
0.0,
1.0,
-np.infty,
-4.0,
use_stateless=True)
@test_util.run_deprecated_v1
def testLeftTailTwoSidedBounds(self):
self.validateMoments([10**5], 0.0, 1.0, -6.0, -3.0)
self.validateMoments([int(1e5)], 0.0, 1.0, -6.0, -3.0)
self.validateMoments([int(1e5)], 0.0, 1.0, -6.0, -3.0, use_stateless=True)
@test_util.run_deprecated_v1
@test_util.disable_xla("Low probability region")
def testTwoSidedLeftTailShifted(self):
self.validateKolmogorovSmirnov([10**5], 6.0, 1.0, -1.0, 1.0)
self.validateKolmogorovSmirnov([int(1e5)], 6.0, 1.0, -1.0, 1.0)
self.validateKolmogorovSmirnov([int(1e5)],
6.0,
1.0,
-1.0,
1.0,
use_stateless=True)
@test_util.run_deprecated_v1
@test_util.disable_xla("Low probability region")
def testRightTailShifted(self):
self.validateMoments([10**5], -5.0, 1.0, 2.0, np.infty)
self.validateMoments([int(1e5)], -5.0, 1.0, 2.0, np.infty)
self.validateMoments([int(1e5)],
-5.0,
1.0,
2.0,
np.infty,
use_stateless=True)
# Take the normal distribution around the mean, but truncating the left tail
# far from the mean.
@test_util.run_deprecated_v1
def testTruncateOnLeft_entireTailOnRight(self):
self.validateKolmogorovSmirnov([10**5], 10.0, 1.0, 4.0, np.infty)
self.validateKolmogorovSmirnov([int(1e5)], 10.0, 1.0, 4.0, np.infty)
self.validateKolmogorovSmirnov([int(1e5)],
10.0,
1.0,
4.0,
np.infty,
use_stateless=True)
# Take the normal distribution around the mean, but truncating the right tail.
@test_util.run_deprecated_v1
def testTruncateOnRight_entireTailOnLeft(self):
self.validateKolmogorovSmirnov([10**5], -8, 1.0, -np.infty, -4.0)
self.validateKolmogorovSmirnov([int(1e5)], -8, 1.0, -np.infty, -4.0)
self.validateKolmogorovSmirnov([int(1e5)],
-8.,
1.0,
-np.infty,
-4.0,
use_stateless=True)
@test_util.run_deprecated_v1
def testSmallStddev(self):
self.validateKolmogorovSmirnov([10**5], 0.0, 0.1, 0.05, 0.10)
self.validateKolmogorovSmirnov([int(1e5)], 0.0, 0.1, 0.05, 0.10)
self.validateKolmogorovSmirnov([int(1e5)],
0.0,
0.1,
0.05,
0.10,
use_stateless=True)
@test_util.run_deprecated_v1
def testSamplingWithSmallStdDevFarFromBound(self):
sample_op = random_ops.parameterized_truncated_normal(
shape=(int(1e5),), means=0.8, stddevs=0.05, minvals=-1., maxvals=1.)
new_seed = random_ops.random_uniform([2],
seed=1234,
minval=0,
maxval=(2**31 - 1),
dtype=np.int32)
sample_op_stateless = stateless.stateless_parameterized_truncated_normal(
shape=(int(1e5),),
seed=new_seed,
means=0.8,
stddevs=0.05,
minvals=-1.,
maxvals=1.)
with self.session(use_gpu=True) as sess:
samples = sess.run(sample_op)
samples, samples_stateless = sess.run([sample_op, sample_op_stateless])
# 0. is more than 16 standard deviations from the mean, and
# should have a likelihood < 1e-57.
assert (~np.isnan(samples)).all()
no_neg_samples = np.sum(samples < 0.)
self.assertEqual(no_neg_samples, 0.)
assert (~np.isnan(samples_stateless)).all()
self.assertAllGreater(samples, 0.)
self.assertAllGreater(samples_stateless, 0.)
def testStatelessParameterizedTruncatedNormalHasGrads(self):
mean = variables.Variable(0.01)
stddev = variables.Variable(1.)
minval = variables.Variable(-1.)
maxval = variables.Variable(1.)
with self.cached_session(use_gpu=True) as sess:
with backprop.GradientTape(persistent=True) as tape:
samples = stateless.stateless_parameterized_truncated_normal(
[1], [1, 2], mean, stddev, minval, maxval)
sess.run(variables.variables_initializer([mean, stddev, minval, maxval]))
[mean_grad, std_grad], mean_actual_grad, std_actual_grad = sess.run([
tape.gradient(samples, [mean, stddev]),
array_ops.ones_like(mean),
(samples - mean) / stddev])
self.assertAllClose(mean_grad, mean_actual_grad)
self.assertAllClose(std_grad, std_actual_grad[0])
try:
import scipy.stats # pylint:disable=g-import-not-at-top
truncnorm = scipy.stats.truncnorm(a=-1., b=1., loc=0., scale=1.)
samples_np, [minval_grad, maxval_grad] = sess.run([
samples, tape.gradient(samples, [minval, maxval])])
sample_cdf = truncnorm.cdf(samples_np)
# These come from the implicit reparameterization trick.
scipy_maxval_grad = np.exp(
0.5 * (samples_np ** 2 - ((1. - 0.01) / 1.) ** 2) +
np.log(sample_cdf))
scipy_minval_grad = np.exp(
0.5 * (samples_np ** 2 - ((-1. - 0.01) / 1.) ** 2) +
np.log1p(-sample_cdf))
self.assertAllClose(minval_grad, scipy_minval_grad[0], rtol=1e-2)
self.assertAllClose(maxval_grad, scipy_maxval_grad[0], rtol=1e-2)
except ImportError as e:
tf_logging.warn("Cannot test truncated normal op: %s" % str(e))
@test_util.run_deprecated_v1
def testSamplingAtRandnSwitchover(self):
@@ -239,18 +362,33 @@ class ParameterizedTruncatedNormalTest(test.TestCase):
epsilon = 0.001
self.validateMoments(
shape=[10**6],
shape=[int(1e6)],
mean=0.,
stddev=1.0,
minval=-epsilon,
maxval=stddev_inside_bounds_before_using_randn - epsilon)
self.validateMoments(
shape=[10**6],
shape=[int(1e6)],
mean=0.,
stddev=1.0,
minval=-epsilon,
maxval=stddev_inside_bounds_before_using_randn + epsilon)
self.validateMoments(
shape=[int(1e6)],
mean=0.,
stddev=1.0,
minval=-epsilon,
maxval=stddev_inside_bounds_before_using_randn - epsilon,
use_stateless=True)
self.validateMoments(
shape=[int(1e6)],
mean=0.,
stddev=1.0,
minval=-epsilon,
maxval=stddev_inside_bounds_before_using_randn + epsilon,
use_stateless=True)
# Benchmarking code
def parameterized_vs_naive(shape, num_iters, use_gpu=False):


@@ -18,9 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import math_ops
@@ -114,3 +119,118 @@ def _StatelessRandomGammaV2Grad(op, grad): # pylint: disable=invalid-name
return (None, None,
math_ops.reduce_sum(
grad * partial_a, axis=math_ops.range(num_sample_dimensions)))
def _Ndtr(x):
"""Normal distribution function."""
half_sqrt_2 = constant_op.constant(
0.5 * np.sqrt(2.), dtype=x.dtype, name="half_sqrt_2")
w = x * half_sqrt_2
z = math_ops.abs(w)
y = array_ops.where(
z < half_sqrt_2,
1. + math_ops.erf(w),
array_ops.where(
w > 0., 2. - math_ops.erfc(z), math_ops.erfc(z)))
return 0.5 * y
@ops.RegisterGradient("StatelessParameterizedTruncatedNormal")
def _StatelessParameterizedTruncatedNormalGrad(op, grad): # pylint: disable=invalid-name
"""Returns the gradient of a TruncatedNormal sample w.r.t. parameters.
The gradient is computed using implicit differentiation
(Figurnov et al., 2018).
Args:
op: A `StatelessParameterizedTruncatedNormal` operation. We assume that the
inputs to the operation are `shape`, `seed`, `mean`, `stddev`, `minval`,
and `maxval` tensors, and the output is the `sample` tensor.
grad: The incoming gradient `dloss / dsample` of the same shape as
`op.outputs[0]`.
Returns:
A list of `Tensor` with derivatives with respect to each parameter.
References:
Implicit Reparameterization Gradients:
[Figurnov et al., 2018]
(http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients)
([pdf]
(http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients.pdf))
"""
shape = op.inputs[0]
mean = op.inputs[2]
stddev = op.inputs[3]
minval = op.inputs[4]
maxval = op.inputs[5]
sample = op.outputs[0]
with ops.control_dependencies([grad]):
minval_std = (minval - mean) / stddev
maxval_std = (maxval - mean) / stddev
sample_std = (sample - mean) / stddev
cdf_sample = (_Ndtr(sample_std) - _Ndtr(minval_std)) / (
_Ndtr(maxval_std) - _Ndtr(minval_std))
# Clip to avoid zero argument for log_cdf expression
tiny = np.finfo(mean.dtype.as_numpy_dtype).tiny
eps = np.finfo(mean.dtype.as_numpy_dtype).eps
cdf_sample = clip_ops.clip_by_value(cdf_sample, tiny, 1 - eps)
dmaxval = math_ops.exp(0.5 * (sample_std ** 2 - maxval_std ** 2) +
math_ops.log(cdf_sample))
dminval = math_ops.exp(0.5 * (sample_std ** 2 - minval_std ** 2) +
math_ops.log1p(-cdf_sample))
dmean = array_ops.ones_like(sample_std)
dstddev = sample_std
# Reduce over extra dimensions caused by `shape`. We need to get the
# difference in rank from shape vs. the broadcasted rank.
mean_shape = array_ops.shape(mean)
stddev_shape = array_ops.shape(stddev)
minval_shape = array_ops.shape(minval)
maxval_shape = array_ops.shape(maxval)
broadcast_shape = array_ops.broadcast_dynamic_shape(
mean_shape, stddev_shape)
broadcast_shape = array_ops.broadcast_dynamic_shape(
minval_shape, broadcast_shape)
broadcast_shape = array_ops.broadcast_dynamic_shape(
maxval_shape, broadcast_shape)
extra_dims = math_ops.range(
array_ops.size(shape) - array_ops.size(broadcast_shape))
grad_mean = math_ops.reduce_sum(grad * dmean, axis=extra_dims)
grad_stddev = math_ops.reduce_sum(grad * dstddev, axis=extra_dims)
grad_minval = math_ops.reduce_sum(grad * dminval, axis=extra_dims)
grad_maxval = math_ops.reduce_sum(grad * dmaxval, axis=extra_dims)
_, rmean = gen_array_ops.broadcast_gradient_args(
broadcast_shape, mean_shape)
_, rstddev = gen_array_ops.broadcast_gradient_args(
broadcast_shape, stddev_shape)
_, rminval = gen_array_ops.broadcast_gradient_args(
broadcast_shape, minval_shape)
_, rmaxval = gen_array_ops.broadcast_gradient_args(
broadcast_shape, maxval_shape)
grad_mean = array_ops.reshape(
math_ops.reduce_sum(grad_mean, axis=rmean, keepdims=True), mean_shape)
grad_stddev = array_ops.reshape(
math_ops.reduce_sum(grad_stddev, axis=rstddev, keepdims=True),
stddev_shape)
grad_minval = array_ops.reshape(
math_ops.reduce_sum(grad_minval, axis=rminval, keepdims=True),
minval_shape)
grad_maxval = array_ops.reshape(
math_ops.reduce_sum(grad_maxval, axis=rmaxval, keepdims=True),
maxval_shape)
# The first two inputs (shape and seed) get no gradient.
return (None, None, grad_mean, grad_stddev, grad_minval, grad_maxval)
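A short sketch of what this gradient yields for the location and scale
parameters (dmean is 1 and dstddev is the standardized sample, per the
implicit reparameterization above); TF 2.x eager style, illustrative only:

import tensorflow as tf

mean = tf.Variable(0.5)
stddev = tf.Variable(2.0)
with tf.GradientTape() as tape:
  sample = tf.random.stateless_parameterized_truncated_normal(
      [1], seed=[1, 2], means=mean, stddevs=stddev, minvals=-1., maxvals=4.)
d_mean, d_stddev = tape.gradient(sample, [mean, stddev])
# d_mean == 1.0 and d_stddev == (sample - mean) / stddev (up to the sum over
# the sample dimension).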


@@ -618,3 +618,73 @@ def stateless_multinomial_categorical_impl(logits, num_samples, dtype, seed):
logits = ops.convert_to_tensor(logits, name="logits")
return gen_stateless_random_ops.stateless_multinomial(
logits, num_samples, seed, output_dtype=dtype)
@dispatch.add_dispatch_support
@tf_export("random.stateless_parameterized_truncated_normal")
def stateless_parameterized_truncated_normal(shape,
seed,
means=0.0,
stddevs=1.0,
minvals=-2.0,
maxvals=2.0,
name=None):
"""Outputs random values from a truncated normal distribution.
The generated values follow a normal distribution with the specified mean and
standard deviation, except that values outside `[minvals, maxvals]` are
dropped and re-picked (with the default bounds, values more than two standard
deviations from the mean).
Examples:
Sample from a truncated normal with differing shape parameters that
broadcast.
>>> means = 0.
>>> stddevs = tf.math.exp(tf.random.uniform(shape=[2, 3]))
>>> minvals = [-1., -2., -1000.]
>>> maxvals = [[10000.], [1.]]
>>> y = tf.random.stateless_parameterized_truncated_normal(
... shape=[10, 2, 3], seed=[7, 17],
... means=means, stddevs=stddevs, minvals=minvals, maxvals=maxvals)
>>> y.shape
TensorShape([10, 2, 3])
Args:
shape: A 1-D integer `Tensor` or Python array. The shape of the output
tensor.
seed: A shape [2] Tensor, the seed to the random number generator. Must have
dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
means: A `Tensor` or Python value of type `dtype`. The mean of the truncated
normal distribution. This must broadcast with `stddevs`, `minvals` and
`maxvals`, and the broadcasted shape must be dominated by `shape`.
stddevs: A `Tensor` or Python value of type `dtype`. The standard deviation
of the truncated normal distribution. This must broadcast with `means`,
`minvals` and `maxvals`, and the broadcasted shape must be dominated by
`shape`.
minvals: A `Tensor` or Python value of type `dtype`. The minimum value of
the truncated normal distribution. This must broadcast with `means`,
`stddevs` and `maxvals`, and the broadcasted shape must be dominated by
`shape`.
maxvals: A `Tensor` or Python value of type `dtype`. The maximum value of
the truncated normal distribution. This must broadcast with `means`,
`stddevs` and `minvals`, and the broadcasted shape must be dominated by
`shape`.
name: A name for the operation (optional).
Returns:
A tensor of the specified shape filled with random truncated normal values.
"""
with ops.name_scope(name, "stateless_parameterized_truncated_normal",
[shape, means, stddevs, minvals, maxvals]) as name:
shape_tensor = tensor_util.shape_tensor(shape)
means_tensor = ops.convert_to_tensor(means, name="means")
stddevs_tensor = ops.convert_to_tensor(stddevs, name="stddevs")
minvals_tensor = ops.convert_to_tensor(minvals, name="minvals")
maxvals_tensor = ops.convert_to_tensor(maxvals, name="maxvals")
rnd = gen_stateless_random_ops.stateless_parameterized_truncated_normal(
shape_tensor, seed, means_tensor, stddevs_tensor, minvals_tensor,
maxvals_tensor)
tensor_util.maybe_set_static_shape(rnd, shape)
return rnd


@@ -92,6 +92,10 @@ tf_module {
name: "stateless_normal"
argspec: "args=[\'shape\', \'seed\', \'mean\', \'stddev\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "stateless_parameterized_truncated_normal"
argspec: "args=[\'shape\', \'seed\', \'means\', \'stddevs\', \'minvals\', \'maxvals\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'-2.0\', \'2.0\', \'None\'], "
}
member_method {
name: "stateless_poisson"
argspec: "args=[\'shape\', \'seed\', \'lam\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "


@@ -4492,6 +4492,10 @@ tf_module {
name: "StatelessMultinomial"
argspec: "args=[\'logits\', \'num_samples\', \'seed\', \'output_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
}
member_method {
name: "StatelessParameterizedTruncatedNormal"
argspec: "args=[\'shape\', \'seed\', \'means\', \'stddevs\', \'minvals\', \'maxvals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "StatelessRandomBinomial"
argspec: "args=[\'shape\', \'seed\', \'counts\', \'probs\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "


@@ -80,6 +80,10 @@ tf_module {
name: "stateless_normal"
argspec: "args=[\'shape\', \'seed\', \'mean\', \'stddev\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "stateless_parameterized_truncated_normal"
argspec: "args=[\'shape\', \'seed\', \'means\', \'stddevs\', \'minvals\', \'maxvals\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'-2.0\', \'2.0\', \'None\'], "
}
member_method {
name: "stateless_poisson"
argspec: "args=[\'shape\', \'seed\', \'lam\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "


@@ -4492,6 +4492,10 @@ tf_module {
name: "StatelessMultinomial"
argspec: "args=[\'logits\', \'num_samples\', \'seed\', \'output_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
}
member_method {
name: "StatelessParameterizedTruncatedNormal"
argspec: "args=[\'shape\', \'seed\', \'means\', \'stddevs\', \'minvals\', \'maxvals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "StatelessRandomBinomial"
argspec: "args=[\'shape\', \'seed\', \'counts\', \'probs\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "