Add StatelessParameterizedTruncatedNormal sampler.

This sampler supports broadcasting of its input parameters and places the number of samples on the left of the output shape, rather than the right.

PiperOrigin-RevId: 317129622
Change-Id: I4b62ad2e89a9637ae8b30b73af4b662ad0caa943
parent 89b80c5fb9
commit 455750f362
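
As a usage sketch of the Python endpoint this change adds (the `tf.random.stateless_parameterized_truncated_normal` export appears in the last hunk below), the two behaviors called out above look like this:

    import tensorflow as tf

    # The four parameters broadcast against each other to a [2, 3] batch shape.
    means = 0.
    stddevs = tf.math.exp(tf.random.uniform(shape=[2, 3]))
    minvals = [-1., -2., -1000.]   # broadcasts across rows
    maxvals = [[10000.], [1.]]     # broadcasts across columns

    # The trailing [2, 3] of `shape` must match the broadcast batch shape;
    # the leading 10 is the number of samples, placed on the left.
    y = tf.random.stateless_parameterized_truncated_normal(
        shape=[10, 2, 3], seed=[7, 17],
        means=means, stddevs=stddevs, minvals=minvals, maxvals=maxvals)
    print(y.shape)  # (10, 2, 3)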
@@ -0,0 +1,54 @@
+op {
+  graph_op_name: "StatelessParameterizedTruncatedNormal"
+  visibility: HIDDEN
+  in_arg {
+    name: "shape"
+    description: <<END
+The shape of the output tensor.
+END
+  }
+  in_arg {
+    name: "seed"
+    description: <<END
+2 seeds (shape [2]).
+END
+  }
+  in_arg {
+    name: "means"
+    description: <<END
+The mean parameter of each batch.
+END
+  }
+  in_arg {
+    name: "stddevs"
+    description: <<END
+The standard deviation parameter of each batch. Must be greater than 0.
+END
+  }
+  in_arg {
+    name: "minvals"
+    description: <<END
+The minimum cutoff. May be -infinity.
+END
+  }
+  in_arg {
+    name: "maxvals"
+    description: <<END
+The maximum cutoff. May be +infinity, and must be more than the minval
+for each batch.
+END
+  }
+  attr {
+    name: "dtype"
+    description: <<END
+The type of the output.
+END
+  }
+  out_arg {
+    name: "output"
+    description: <<END
+The outputs are truncated normal samples and are a deterministic function of
+`shape`, `seed`, `minvals`, `maxvals`, `means` and `stddevs`.
+END
+  }
+}
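
The `output` description above is the stateless contract: unlike the stateful `ParameterizedTruncatedNormal`, rerunning the op with identical inputs reproduces the draw exactly. A small sketch (assuming the public endpoint added later in this commit):

    import tensorflow as tf

    kwargs = dict(shape=[4], seed=[1, 2], means=0., stddevs=1.,
                  minvals=-2., maxvals=2.)
    a = tf.random.stateless_parameterized_truncated_normal(**kwargs)
    b = tf.random.stateless_parameterized_truncated_normal(**kwargs)
    # Stateless ops are pure functions of their inputs.
    assert bool(tf.reduce_all(a == b))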
@@ -6156,6 +6156,7 @@ tf_kernel_library(
     ]),
     prefix = "parameterized_truncated_normal_op",
     deps = [
+        ":stateless_random_ops",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
@@ -32,6 +32,7 @@ limitations under the License.
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/stateless_random_ops.h"
 #include "tensorflow/core/lib/random/random_distributions.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/util/guarded_philox_random.h"
@@ -67,10 +68,10 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
     const T kStdDevsInsideBoundsToUseRandnSampler = T(1.3);
     auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());

-    auto DoWork = [samples_per_batch, num_elements, &ctx, &means, &stddevs,
+    auto do_work = [samples_per_batch, num_elements, &ctx, &means, &stddevs,
                     &minvals, &maxvals, &gen, &output,
                     kStdDevsInsideBoundsToUseRandnSampler](int start_batch,
                                                            int limit_batch) {
       // Capturing "gen" by-value would only make a copy for the _shared_
       // lambda. Since we want to let each worker have its own copy, we pass
       // "gen" by reference and explicitly do a copy assignment here.
@@ -80,9 +81,9 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
       // The sample from each iteration uses 2 random numbers.
       gen_copy.Skip(start_batch * 2 * kMaxIterations * (samples_per_batch + 3) /
                     4);
-      typedef random::UniformDistribution<random::PhiloxRandom, T> Uniform;
+      using Uniform = random::UniformDistribution<random::PhiloxRandom, T>;
       Uniform dist;
-      typedef random::NormalDistribution<random::PhiloxRandom, T> Normal;
+      using Normal = random::NormalDistribution<random::PhiloxRandom, T>;
       Normal normal_dist;

       // Vectorized intermediate calculations for uniform rejection sampling.
@@ -112,7 +113,7 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
                          Eigen::numext::isfinite(maxval)),
                     errors::InvalidArgument("Invalid parameters"));

-        int numIterations = 0;
+        int num_iterations = 0;

         // If possible, make one-sided bound be the lower bound, or make both
         // bounds positive. Otherwise, the bounds are on either side of the
@@ -160,10 +161,10 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
               if (sample >= limit_sample) {
                 break;
               }
-              numIterations = 0;
+              num_iterations = 0;
             } else {
-              numIterations++;
-              if (numIterations > kMaxIterations) {
+              num_iterations++;
+              if (num_iterations > kMaxIterations) {
                 // This should never occur because this sampler should
                 // (by the selection criteria above) be used if at least 3
                 // standard deviations of one side of the distribution
@@ -201,7 +202,7 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
           const auto u = dist(&gen_copy);
           for (int i = 0; i < size; i++) {
             auto accept = u[i] <= Eigen::numext::exp(g[i]);
-            if (accept || numIterations + 1 >= kMaxIterations) {
+            if (accept || num_iterations + 1 >= kMaxIterations) {
               // Accept the sample z.
               // If we run out of iterations, just use the current uniform
               // sample, but emit a warning.
@@ -223,9 +224,9 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
               if (sample >= limit_sample) {
                 break;
               }
-              numIterations = 0;
+              num_iterations = 0;
             } else {
-              numIterations++;
+              num_iterations++;
             }
           }
         }
@@ -248,7 +249,7 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
             const T u = rand[i];
             i++;
             auto accept = (u <= g && z < normMax);
-            if (accept || numIterations + 1 >= kMaxIterations) {
+            if (accept || num_iterations + 1 >= kMaxIterations) {
               if (!accept) {
                 LOG(ERROR) << "TruncatedNormal exponential distribution "
                            << "rejection sampler exceeds max iterations. "
@@ -263,9 +264,9 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
               if (sample >= limit_sample) {
                 break;
               }
-              numIterations = 0;
+              num_iterations = 0;
             } else {
-              numIterations++;
+              num_iterations++;
             }
           }
         }
@@ -305,7 +306,297 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
     const int64 batchCost =
         batchInitCost + uniformRejectionSamplingCost * 2 * samples_per_batch;
     Shard(worker_threads.num_threads, worker_threads.workers, num_batches,
-          batchCost, DoWork);
+          batchCost, do_work);
+  }
+};
+
+template <typename T>
+struct TruncatedNormalFunctorV2<CPUDevice, T> {
+  void operator()(OpKernelContext* ctx, const CPUDevice& d, int64 num_batches,
+                  int64 samples_per_batch, int64 num_elements,
+                  const BCastList<4>& bcast,
+                  typename TTypes<T>::ConstFlat means,
+                  typename TTypes<T>::ConstFlat stddevs,
+                  typename TTypes<T>::ConstFlat minvals,
+                  typename TTypes<T>::ConstFlat maxvals,
+                  const random::PhiloxRandom& gen,
+                  typename TTypes<T>::Flat output) {
+    // The randn rejection sampling is used when the mean and at least this many
+    // standard deviations are inside the bounds.
+    // The uniform proposal samplers become less efficient as the bounds are
+    // further from the mean, the reverse is true for the randn sampler.
+    // This number was chosen by empirical benchmarking. If modified, the
+    // benchmarks in parameterized_truncated_normal_op_test should also be
+    // changed.
+    const T kStdDevsInsideBoundsToUseRandnSampler = T(1.3);
+    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
+
+    auto do_work = [num_batches, samples_per_batch, &ctx, &bcast, &means,
+                    &stddevs, &minvals, &maxvals, &gen, &output,
+                    kStdDevsInsideBoundsToUseRandnSampler](int start_output,
+                                                           int limit_output) {
+      // Capturing "gen" by-value would only make a copy for the _shared_
+      // lambda. Since we want to let each worker have its own copy, we pass
+      // "gen" by reference and explicitly do a copy assignment here.
+      random::PhiloxRandom gen_copy = gen;
+      using Uniform = random::UniformDistribution<random::PhiloxRandom, T>;
+      Uniform dist;
+      using Normal = random::NormalDistribution<random::PhiloxRandom, T>;
+      Normal normal_dist;
+      // Skip takes units of 128 bits. The Uniform::kResultElementCount - 1
+      // is so rounding doesn't lead to
+      // us using the same state in different workloads.
+      // The sample from each iteration uses 2 random numbers.
+      gen_copy.Skip((start_output * 2 * kMaxIterations +
+                     Uniform::kResultElementCount - 1) /
+                    Uniform::kResultElementCount);
+
+      // Vectorized intermediate calculations for uniform rejection sampling.
+      // We always generate at most 4 samples.
+      Eigen::array<T, Uniform::kResultElementCount> z;
+      Eigen::array<T, Uniform::kResultElementCount> g;
+
+      const bool should_bcast = bcast.IsBroadcastingRequired();
+      const auto& means_batch_indices = bcast.batch_indices(0);
+      const auto& stddevs_batch_indices = bcast.batch_indices(1);
+      const auto& minvals_batch_indices = bcast.batch_indices(2);
+      const auto& maxvals_batch_indices = bcast.batch_indices(3);
+      auto output_flat = output.data();
+
+      // We partition work across batches and then across samples
+      // per batch member, to avoid extra work.
+      for (int64 output_idx = start_output; output_idx < limit_output;
+           // output_idx is incremented with the inner loops below.
+      ) {
+        int64 batch_idx = output_idx / samples_per_batch;
+        // The output layout is [samples_per_batch, num_batches]. Thus
+        // the output address is sample_idx * num_batches + batch_idx.
+        // Below, code will index at output_batch_offset[sample_idx *
+        // num_batches] matching this.
+        T* const output_batch_offset = output_flat + batch_idx;
+        // Generate batch counts from BCast, as it has the right indices to
+        // loop over.
+        T mean, stddev, minval, maxval;
+        if (should_bcast) {
+          mean = means(means_batch_indices[batch_idx]);
+          stddev = stddevs(stddevs_batch_indices[batch_idx]);
+          minval = minvals(minvals_batch_indices[batch_idx]);
+          maxval = maxvals(maxvals_batch_indices[batch_idx]);
+        } else {
+          mean = means(batch_idx);
+          stddev = stddevs(batch_idx);
+          minval = minvals(batch_idx);
+          maxval = maxvals(batch_idx);
+        }
+
+        // On GPU, this check will just fill samples with NAN if it fails.
+        OP_REQUIRES(ctx,
+                    stddev > T(0) && minval < maxval &&
+                        (Eigen::numext::isfinite(minval) ||
+                         Eigen::numext::isfinite(maxval)),
+                    errors::InvalidArgument("Invalid parameters"));
+
+        int num_iterations = 0;
+
+        // If possible, make one-sided bound be the lower bound, or make both
+        // bounds positive. Otherwise, the bounds are on either side of the
+        // mean.
+        if ((Eigen::numext::isinf(minval) && minval < T(0)) || maxval < mean) {
+          // Reverse all calculations. normMin and normMax will be flipped.
+          std::swap(minval, maxval);
+          stddev = -stddev;
+        }
+
+        // Calculate normalized samples, then convert them.
+        const T normMin = (minval - mean) / stddev;
+        const T normMax = (maxval - mean) / stddev;
+
+        // Determine the method to use.
+        const T sqrtFactor = Eigen::numext::sqrt((normMin * normMin) + T(4));
+        const T cutoff =
+            T(2) *
+            Eigen::numext::exp(T(0.5) +
+                               (normMin * (normMin - sqrtFactor)) / T(4)) /
+            (normMin + sqrtFactor);
+        const T diff = normMax - normMin;
+
+        if (((normMin < -kStdDevsInsideBoundsToUseRandnSampler) &&
+             (normMax >= T(0.))) ||
+            ((normMax > kStdDevsInsideBoundsToUseRandnSampler) &&
+             (normMin <= T(0.)))) {
+          // If the bounds are at least 3 standard deviations from the mean
+          // on at least one side then we rejection sample by sampling
+          // from the normal distribution and rejecting samples outside
+          // the bounds.
+          // Under this condition the acceptance rate per iteration should
+          // always be ~ 50%. This sampler is more efficient (and more
+          // numerically stable when one or both bounds is far from the mean).
+          for (int64 sample_idx = output_idx % samples_per_batch;
+               sample_idx < samples_per_batch && output_idx < limit_output;) {
+            const auto randn_sample = normal_dist(&gen_copy);
+            const int size = randn_sample.size();
+            for (int i = 0; i < size; ++i) {
+              if ((randn_sample[i] >= normMin) &&
+                  (randn_sample[i] <= normMax)) {
+                output_batch_offset[sample_idx * num_batches] =
+                    randn_sample[i] * stddev + mean;
+                ++sample_idx;
+                ++output_idx;
+                if (sample_idx >= samples_per_batch ||
+                    output_idx >= limit_output) {
+                  break;
+                }
+                num_iterations = 0;
+              } else {
+                ++num_iterations;
+                if (num_iterations > kMaxIterations) {
+                  // This should never occur because this sampler should
+                  // (by the selection criteria above) be used if at least 3
+                  // standard deviations of one side of the distribution
+                  // is within the limits (so acceptance probability per
+                  // iteration is >~ 1/2).
+                  LOG(ERROR) << "TruncatedNormal randn rejection sampler "
+                             << "exceeded maximum iterations for "
+                             << "normMin=" << normMin << " normMax=" << normMax
+                             << " kMaxIterations=" << kMaxIterations;
+                  ctx->SetStatus(errors::Internal(
+                      "TruncatedNormal randn rejection sampler failed to accept"
+                      " a sample."));
+                  return;
+                }
+              }
+            }
+          }
+        } else if (diff < cutoff) {
+          // Sample from a uniform distribution on [normMin, normMax].
+
+          const T plusFactor = (normMin < T(0)) ? T(0) : normMin * normMin;
+
+          for (int64 sample_idx = output_idx % samples_per_batch;
+               sample_idx < samples_per_batch && output_idx < limit_output;) {
+            const auto rand = dist(&gen_copy);
+            const int size = rand.size();
+            // NOTE(ringwalt): These loops seem to only generate packed AVX
+            // instructions for float32.
+            for (int i = 0; i < size; i++) {
+              z[i] = rand[i] * diff + normMin;
+              g[i] = (plusFactor - z[i] * z[i]) / T(2.0);
+            }
+
+            const auto u = dist(&gen_copy);
+            for (int i = 0; i < size; i++) {
+              auto accept = u[i] <= Eigen::numext::exp(g[i]);
+              if (accept || num_iterations + 1 >= kMaxIterations) {
+                // Accept the sample z.
+                // If we run out of iterations, just use the current uniform
+                // sample, but emit a warning.
+                // TODO(jjhunt) For small entropies (relative to the bounds),
+                // this sampler is poor and may take many iterations since
+                // the proposal distribution is the uniform distribution
+                // U(lower_bound, upper_bound).
+                if (!accept) {
+                  LOG(ERROR) << "TruncatedNormal uniform rejection sampler "
+                             << "exceeded max iterations. Sample may contain "
+                             << "outliers.";
+                  ctx->SetStatus(errors::Internal(
+                      "TruncatedNormal uniform rejection sampler failed to "
+                      " accept a sample."));
+                  return;
+                }
+                output_batch_offset[sample_idx * num_batches] =
+                    z[i] * stddev + mean;
+                ++sample_idx;
+                ++output_idx;
+                if (sample_idx >= samples_per_batch ||
+                    output_idx >= limit_output) {
+                  break;
+                }
+                num_iterations = 0;
+              } else {
+                num_iterations++;
+              }
+            }
+          }
+        } else {
+          // Sample from an exponential distribution with alpha maximizing
+          // acceptance probability, offset by normMin from the origin.
+          // Accept only if less than normMax.
+          const T alpha =
+              (normMin + Eigen::numext::sqrt((normMin * normMin) + T(4))) /
+              T(2);
+          for (int64 sample_idx = output_idx % samples_per_batch;
+               sample_idx < samples_per_batch && output_idx < limit_output;) {
+            auto rand = dist(&gen_copy);
+            const int size = rand.size();
+            int i = 0;
+            while (i < size) {
+              const T z = -Eigen::numext::log(rand[i]) / alpha + normMin;
+              i++;
+              const T x = normMin < alpha ? alpha - z : normMin - alpha;
+              const T g = Eigen::numext::exp(-x * x / T(2.0));
+              const T u = rand[i];
+              i++;
+              auto accept = (u <= g && z < normMax);
+              if (accept || num_iterations + 1 >= kMaxIterations) {
+                if (!accept) {
+                  LOG(ERROR) << "TruncatedNormal exponential distribution "
+                             << "rejection sampler exceeds max iterations. "
+                             << "Sample may contain outliers.";
+                  ctx->SetStatus(errors::Internal(
+                      "TruncatedNormal exponential distribution rejection"
+                      " sampler failed to accept a sample."));
+                  return;
+                }
+                output_batch_offset[sample_idx * num_batches] =
+                    z * stddev + mean;
+                ++sample_idx;
+                ++output_idx;
+                if (sample_idx >= samples_per_batch ||
+                    output_idx >= limit_output) {
+                  break;
+                }
+                num_iterations = 0;
+              } else {
+                num_iterations++;
+              }
+            }
+          }
+        }
+      }
+    };
+    // The cost of the initial calculations for the batch.
+    const int64 batchInitCost =
+        // normMin, normMax
+        (Eigen::TensorOpCost::AddCost<T>() +
+         Eigen::TensorOpCost::MulCost<T>()) *
+            2
+        // sqrtFactor
+        + Eigen::TensorOpCost::AddCost<T>() +
+        Eigen::TensorOpCost::MulCost<T>() +
+        Eigen::internal::functor_traits<
+            Eigen::internal::scalar_sqrt_op<T>>::Cost
+        // cutoff
+        + Eigen::TensorOpCost::MulCost<T>() * 4 +
+        Eigen::internal::functor_traits<Eigen::internal::scalar_exp_op<T>>::Cost
+        // diff
+        + Eigen::TensorOpCost::AddCost<T>();
+    const int64 uniformSampleCost =
+        random::PhiloxRandom::kElementCost +
+        random::UniformDistribution<random::PhiloxRandom, T>::kElementCost;
+    // The cost of a single uniform sampling round.
+    const int64 uniformRejectionSamplingCost =
+        uniformSampleCost + Eigen::TensorOpCost::MulCost<T>() +
+        Eigen::TensorOpCost::AddCost<T>() +
+        Eigen::TensorOpCost::MulCost<T>() * 2 +
+        Eigen::TensorOpCost::AddCost<T>() + uniformSampleCost +
+        Eigen::internal::functor_traits<
+            Eigen::internal::scalar_exp_op<T>>::Cost +
+        Eigen::TensorOpCost::MulCost<T>() + Eigen::TensorOpCost::AddCost<T>();
+    // Estimate the cost for an entire batch.
+    // Assume we use uniform sampling, and accept the 2nd sample on average.
+    const int64 batchCost = batchInitCost + uniformRejectionSamplingCost * 2;
+    Shard(worker_threads.num_threads, worker_threads.workers, num_elements,
+          batchCost, do_work);
   }
 };
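
The `[samples_per_batch, num_batches]` layout commented in the functor means sample `s` of batch `b` lands at flat offset `s * num_batches + b`, which is exactly what `output_batch_offset[sample_idx * num_batches]` writes. An illustrative NumPy sketch of that addressing (not part of the commit):

    import numpy as np

    num_batches, samples_per_batch = 3, 4
    flat = np.empty(num_batches * samples_per_batch)
    for b in range(num_batches):            # batch_idx
        for s in range(samples_per_batch):  # sample_idx
            flat[s * num_batches + b] = 10 * b + s

    # Reshaping recovers sample dims on the left, batch dims on the right.
    print(flat.reshape(samples_per_batch, num_batches))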
@@ -436,13 +727,113 @@ class ParameterizedTruncatedNormalOp : public OpKernel {
   TF_DISALLOW_COPY_AND_ASSIGN(ParameterizedTruncatedNormalOp);
 };

+// Samples from a truncated normal distribution, using the given parameters.
+template <typename Device, typename T>
+class StatelessParameterizedTruncatedNormal : public OpKernel {
+  // Reshape batches so each batch is this size if possible.
+  static const int32 kDesiredBatchSize = 100;
+
+ public:
+  explicit StatelessParameterizedTruncatedNormal(OpKernelConstruction* context)
+      : OpKernel(context) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor& shape_tensor = ctx->input(0);
+    const Tensor& seed_tensor = ctx->input(1);
+    const Tensor& means_tensor = ctx->input(2);
+    const Tensor& stddevs_tensor = ctx->input(3);
+    const Tensor& minvals_tensor = ctx->input(4);
+    const Tensor& maxvals_tensor = ctx->input(5);
+
+    OP_REQUIRES(ctx, seed_tensor.dims() == 1 && seed_tensor.dim_size(0) == 2,
+                errors::InvalidArgument("seed must have shape [2], not ",
+                                        seed_tensor.shape().DebugString()));
+
+    tensorflow::BCastList<4> bcast(
+        {means_tensor.shape().dim_sizes(), stddevs_tensor.shape().dim_sizes(),
+         minvals_tensor.shape().dim_sizes(),
+         maxvals_tensor.shape().dim_sizes()},
+        /*fewer_dims_optimization=*/false,
+        /*return_flattened_batch_indices=*/true);
+
+    OP_REQUIRES(ctx, bcast.IsValid(),
+                errors::InvalidArgument(
+                    "means, stddevs, minvals, maxvals must have compatible "
+                    "batch dimensions: ",
+                    means_tensor.shape().DebugString(), " vs. ",
+                    stddevs_tensor.shape().DebugString(), " vs. ",
+                    minvals_tensor.shape().DebugString(), " vs. ",
+                    maxvals_tensor.shape().DebugString()));
+
+    // Let's check that the shape tensor dominates the broadcasted tensor.
+    TensorShape bcast_shape = BCast::ToShape(bcast.output_shape());
+    OP_REQUIRES(
+        ctx, TensorShapeUtils::IsVector(shape_tensor.shape()),
+        errors::InvalidArgument("Input shape should be a vector, got shape: ",
+                                shape_tensor.shape().DebugString()));
+    TensorShape output_shape;
+    if (shape_tensor.dtype() == DataType::DT_INT32) {
+      OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(shape_tensor.vec<int32>(),
+                                                      &output_shape));
+    } else {
+      OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(shape_tensor.vec<int64>(),
+                                                      &output_shape));
+    }
+    OP_REQUIRES(ctx, TensorShapeUtils::EndsWith(output_shape, bcast_shape),
+                errors::InvalidArgument(
+                    "Shape passed in must end with broadcasted shape."));
+
+    int64 samples_per_batch = 1;
+    const int64 num_sample_dims =
+        (shape_tensor.dim_size(0) - bcast.output_shape().size());
+    for (int64 i = 0; i < num_sample_dims; ++i) {
+      samples_per_batch *= output_shape.dim_size(i);
+    }
+    int64 num_batches = 1;
+    for (int64 i = num_sample_dims; i < shape_tensor.dim_size(0); ++i) {
+      num_batches *= output_shape.dim_size(i);
+    }
+    const int64 num_elements = num_batches * samples_per_batch;
+
+    Tensor* samples_tensor;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &samples_tensor));
+
+    auto truncFunctor = functor::TruncatedNormalFunctorV2<Device, T>();
+    // Each worker has the same fudge factor, so use it here.
+    random::PhiloxRandom::Key key;
+    random::PhiloxRandom::ResultType counter;
+    OP_REQUIRES_OK(ctx, GenerateKey(seed_tensor, &key, &counter));
+
+    auto philox = random::PhiloxRandom(counter, key);
+
+    truncFunctor(ctx, ctx->eigen_device<Device>(), num_batches,
+                 samples_per_batch, num_elements, bcast, means_tensor.flat<T>(),
+                 stddevs_tensor.flat<T>(), minvals_tensor.flat<T>(),
+                 maxvals_tensor.flat<T>(), philox, samples_tensor->flat<T>());
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(StatelessParameterizedTruncatedNormal);
+};
+
 }  // namespace

 #define REGISTER(TYPE)                                         \
   REGISTER_KERNEL_BUILDER(Name("ParameterizedTruncatedNormal") \
                               .Device(DEVICE_CPU)              \
                               .TypeConstraint<TYPE>("dtype"),  \
-                          ParameterizedTruncatedNormalOp<CPUDevice, TYPE>)
+                          ParameterizedTruncatedNormalOp<CPUDevice, TYPE>) \
+  REGISTER_KERNEL_BUILDER(                                     \
+      Name("StatelessParameterizedTruncatedNormal")            \
+          .HostMemory("shape")                                 \
+          .HostMemory("seed")                                  \
+          .HostMemory("means")                                 \
+          .HostMemory("stddevs")                               \
+          .HostMemory("minvals")                               \
+          .HostMemory("maxvals")                               \
+          .Device(DEVICE_CPU)                                  \
+          .TypeConstraint<TYPE>("dtype"),                      \
+      StatelessParameterizedTruncatedNormal<CPUDevice, TYPE>)

 TF_CALL_half(REGISTER);
 TF_CALL_float(REGISTER);
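
To make the `samples_per_batch` / `num_batches` split in `Compute` concrete: the kernel requires `shape` to end with the broadcast batch shape, and treats any leading dimensions as sample dimensions. A hypothetical re-derivation in plain Python:

    output_shape = [10, 2, 3]  # the `shape` input
    bcast_shape = [2, 3]       # broadcast of means/stddevs/minvals/maxvals

    # Leading dims not covered by the broadcast shape are sample dims.
    num_sample_dims = len(output_shape) - len(bcast_shape)  # 1

    samples_per_batch = 1
    for d in output_shape[:num_sample_dims]:
        samples_per_batch *= d  # 10
    num_batches = 1
    for d in output_shape[num_sample_dims:]:
        num_batches *= d        # 6
    num_elements = samples_per_batch * num_batches  # 60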
@@ -18,6 +18,7 @@ limitations under the License.

 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/lib/random/random_distributions.h"
+#include "tensorflow/core/util/bcast.h"

 namespace tensorflow {

@@ -44,6 +45,21 @@ struct TruncatedNormalFunctor {
                   typename TTypes<T>::Flat output);
 };

+// This version supports broadcasting of the arguments, as well as puts
+// the sample dimension on the left.
+template <typename Device, typename T>
+struct TruncatedNormalFunctorV2 {
+  void operator()(OpKernelContext* ctx, const Device& d, int64 num_batches,
+                  int64 samples_per_batch, int64 num_elements,
+                  const BCastList<4>& bcast,
+                  typename TTypes<T>::ConstFlat means,
+                  typename TTypes<T>::ConstFlat stddevs,
+                  typename TTypes<T>::ConstFlat minvals,
+                  typename TTypes<T>::ConstFlat maxvals,
+                  const random::PhiloxRandom& gen,
+                  typename TTypes<T>::Flat output);
+};
+
 }  // namespace functor
 }  // namespace tensorflow

@@ -124,6 +124,41 @@ REGISTER_OP("StatelessRandomBinomial")
     .Attr("dtype: {half, float, double, int32, int64} = DT_INT64")
     .SetShapeFn(StatelessShape);

+REGISTER_OP("StatelessParameterizedTruncatedNormal")
+    .Input("shape: S")
+    .Input("seed: Tseed")
+    .Input("means: dtype")
+    .Input("stddevs: dtype")
+    .Input("minvals: dtype")
+    .Input("maxvals: dtype")
+    .Output("output: dtype")
+    .Attr("S: {int32, int64}")
+    .Attr("Tseed: {int32, int64} = DT_INT64")
+    .Attr("dtype: {float16, float32, float64}")
+    .SetShapeFn([](InferenceContext* c) {
+      // Check seed shape
+      ShapeHandle seed;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seed));
+      DimensionHandle unused_dim;
+      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused_dim));
+
+      ShapeHandle bcast_means_stddevs;
+      ShapeHandle bcast_except_maxvals;
+      ShapeHandle bcast_all;
+      TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
+          c, c->input(2), c->input(3), true, &bcast_means_stddevs));
+      TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
+          c, c->input(4), bcast_means_stddevs, true, &bcast_except_maxvals));
+      TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
+          c, c->input(5), bcast_except_maxvals, true, &bcast_all));
+
+      // Set output shape
+      ShapeHandle out;
+      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
+      c->set_output(0, out);
+      return Status::OK();
+    });
+
 REGISTER_OP("StatelessRandomPoisson")
     .Input("shape: T")
     .Input("seed: Tseed")
|
|||||||
|
|
||||||
absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
|
absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
|
||||||
const tensorflow::string &op_name) {
|
const tensorflow::string &op_name) {
|
||||||
static std::array<OpIndexInfo, 348> a = {{
|
static std::array<OpIndexInfo, 349> a = {{
|
||||||
{"Acosh"},
|
{"Acosh"},
|
||||||
{"AllToAll", 1, {0}},
|
{"AllToAll", 1, {0}},
|
||||||
{"ApproximateEqual"},
|
{"ApproximateEqual"},
|
||||||
@@ -326,6 +326,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
       {"StackPop"},
       {"StackPush"},
       {"StatelessMultinomial"},
+      {"StatelessParameterizedTruncatedNormal", 1, {1}},
       {"StatelessRandomBinomial"},
       {"StatelessRandomGammaV2", 1, {1}},
       {"StatelessRandomNormal"},
@@ -785,24 +785,6 @@ cuda_py_test(
     ],
 )

-cuda_py_test(
-    name = "parameterized_truncated_normal_op_test",
-    size = "medium",
-    srcs = ["parameterized_truncated_normal_op_test.py"],
-    deps = [
-        "//tensorflow/core:protos_all_py",
-        "//tensorflow/python:client",
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:control_flow_ops",
-        "//tensorflow/python:framework",
-        "//tensorflow/python:framework_for_generated_wrappers",
-        "//tensorflow/python:platform",
-        "//tensorflow/python:random_ops",
-        "//third_party/py/numpy",
-        "@absl_py//absl/testing:parameterized",
-    ],
-)
-
 tf_py_test(
     name = "parsing_ops_test",
     size = "medium",
@@ -20,6 +20,24 @@ py_library(
     ],
 )

+cuda_py_test(
+    name = "parameterized_truncated_normal_op_test",
+    size = "medium",
+    srcs = ["parameterized_truncated_normal_op_test.py"],
+    deps = [
+        "//tensorflow/core:protos_all_py",
+        "//tensorflow/python:client",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:control_flow_ops",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:platform",
+        "//tensorflow/python:random_ops",
+        "//third_party/py/numpy",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
+
 tf_py_test(
     name = "random_shuffle_queue_test",
     size = "small",
@@ -27,11 +27,15 @@ from six.moves import range  # pylint: disable=redefined-builtin

 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session
+from tensorflow.python.eager import backprop
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
 from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import stateless_random_ops as stateless
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging

@@ -91,13 +95,8 @@ class TruncatedNormalMoments(object):

 def calculate_moments(samples, max_moment):
   moments = [0.0] * (max_moment + 1)
-  for sample in samples:
-    value = 1.0
-    for k in range(len(moments)):
-      moments[k] += value
-      value *= sample
-  for i in range(len(moments)):
-    moments[i] /= len(samples)
+  for k in range(len(moments)):
+    moments[k] = np.mean(samples**k, axis=0)
   return moments

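
The rewritten `calculate_moments` computes the same raw moments as the old nested loop, just vectorized over the sample axis. A quick equivalence check (illustrative only):

    import numpy as np

    samples = np.random.randn(1000)
    max_moment = 4
    # Old behavior: accumulate powers of each sample, then divide by the count.
    old = [np.sum(samples**k) / len(samples) for k in range(max_moment + 1)]
    # New behavior: one vectorized mean per moment order.
    new = [np.mean(samples**k, axis=0) for k in range(max_moment + 1)]
    np.testing.assert_allclose(old, new)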
@@ -118,16 +117,31 @@ class ParameterizedTruncatedNormalTest(test.TestCase):
   # Stop at moment 10 to avoid numerical errors in the theoretical moments.
   max_moment = 10

-  def validateMoments(self, shape, mean, stddev, minval, maxval, seed=1618):
+  def validateMoments(self,
+                      shape,
+                      mean,
+                      stddev,
+                      minval,
+                      maxval,
+                      use_stateless=False,
+                      seed=1618):
     try:
       # TruncatedNormalMoments requires scipy.stats.
       # Give up early if we are unable to import it.
-      import scipy.stats  # pylint: disable=g-import-not-at-top,unused-variable
       random_seed.set_random_seed(seed)
       with self.cached_session(use_gpu=True):
-        samples = random_ops.parameterized_truncated_normal(shape, mean, stddev,
-                                                            minval,
-                                                            maxval).eval()
+        if use_stateless:
+          # Generate a seed that stateless ops can use.
+          new_seed = random_ops.random_uniform([2],
+                                               seed=seed,
+                                               minval=0,
+                                               maxval=(2**31 - 1),
+                                               dtype=np.int32)
+          samples = stateless.stateless_parameterized_truncated_normal(
+              shape, new_seed, mean, stddev, minval, maxval).eval()
+        else:
+          samples = random_ops.parameterized_truncated_normal(
+              shape, mean, stddev, minval, maxval).eval()
       assert (~np.isnan(samples)).all()
       moments = calculate_moments(samples, self.max_moment)
       expected_moments = TruncatedNormalMoments(mean, stddev, minval, maxval)
@@ -144,14 +158,24 @@ class ParameterizedTruncatedNormalTest(test.TestCase):
                                  stddev,
                                  minval,
                                  maxval,
+                                 use_stateless=False,
                                  seed=1618):
     try:
       import scipy.stats  # pylint: disable=g-import-not-at-top
       random_seed.set_random_seed(seed)
       with self.cached_session(use_gpu=True):
-        samples = random_ops.parameterized_truncated_normal(shape, mean, stddev,
-                                                            minval,
-                                                            maxval).eval()
+        if use_stateless:
+          new_seed = random_ops.random_uniform([2],
+                                               seed=seed,
+                                               minval=0,
+                                               maxval=(2**31 - 1),
+                                               dtype=np.int32)
+          samples = stateless.stateless_parameterized_truncated_normal(
+              shape, new_seed, mean, stddev, minval, maxval).eval()
+        else:
+          samples = random_ops.parameterized_truncated_normal(
+              shape, mean, stddev, minval, maxval).eval()
+
       assert (~np.isnan(samples)).all()
       minval = max(mean - stddev * 10, minval)
       maxval = min(mean + stddev * 10, maxval)
@@ -169,61 +193,160 @@ class ParameterizedTruncatedNormalTest(test.TestCase):

   @test_util.run_deprecated_v1
   def testDefaults(self):
-    self.validateMoments([10**5], 0.0, 1.0, -2.0, 2.0)
+    self.validateMoments([int(1e5)], 0.0, 1.0, -2.0, 2.0)
+    self.validateMoments([int(1e5)], 0.0, 1.0, -2.0, 2.0, use_stateless=True)

   @test_util.run_deprecated_v1
   def testShifted(self):
-    self.validateMoments([10**5], -1.0, 1.0, -2.0, 2.0)
+    self.validateMoments([int(1e5)], -1.0, 1.0, -2.0, 2.0)
+    self.validateMoments([int(1e5)], -1.0, 1.0, -2.0, 2.0, use_stateless=True)

   @test_util.run_deprecated_v1
   def testRightTail(self):
-    self.validateMoments([10**5], 0.0, 1.0, 4.0, np.infty)
+    self.validateMoments([int(1e5)], 0.0, 1.0, 4.0, np.infty)
+    self.validateMoments([int(1e5)],
+                         0.0,
+                         1.0,
+                         4.0,
+                         np.infty,
+                         use_stateless=True)

   @test_util.run_deprecated_v1
   def testLeftTail(self):
-    self.validateMoments([10**5], 0.0, 1.0, -np.infty, -4.0)
+    self.validateMoments([int(1e5)], 0.0, 1.0, -np.infty, -4.0)
+    self.validateMoments([int(1e5)],
+                         0.0,
+                         1.0,
+                         -np.infty,
+                         -4.0,
+                         use_stateless=True)

   @test_util.run_deprecated_v1
   def testLeftTailTwoSidedBounds(self):
-    self.validateMoments([10**5], 0.0, 1.0, -6.0, -3.0)
+    self.validateMoments([int(1e5)], 0.0, 1.0, -6.0, -3.0)
+    self.validateMoments([int(1e5)], 0.0, 1.0, -6.0, -3.0, use_stateless=True)

   @test_util.run_deprecated_v1
   @test_util.disable_xla("Low probability region")
   def testTwoSidedLeftTailShifted(self):
-    self.validateKolmogorovSmirnov([10**5], 6.0, 1.0, -1.0, 1.0)
+    self.validateKolmogorovSmirnov([int(1e5)], 6.0, 1.0, -1.0, 1.0)
+    self.validateKolmogorovSmirnov([int(1e5)],
+                                   6.0,
+                                   1.0,
+                                   -1.0,
+                                   1.0,
+                                   use_stateless=True)

   @test_util.run_deprecated_v1
   @test_util.disable_xla("Low probability region")
   def testRightTailShifted(self):
-    self.validateMoments([10**5], -5.0, 1.0, 2.0, np.infty)
+    self.validateMoments([int(1e5)], -5.0, 1.0, 2.0, np.infty)
+    self.validateMoments([int(1e5)],
+                         -5.0,
+                         1.0,
+                         2.0,
+                         np.infty,
+                         use_stateless=True)

   # Take the normal distribution around the mean, but truncating the left tail
   # far from the mean.
   @test_util.run_deprecated_v1
   def testTruncateOnLeft_entireTailOnRight(self):
-    self.validateKolmogorovSmirnov([10**5], 10.0, 1.0, 4.0, np.infty)
+    self.validateKolmogorovSmirnov([int(1e5)], 10.0, 1.0, 4.0, np.infty)
+    self.validateKolmogorovSmirnov([int(1e5)],
+                                   10.0,
+                                   1.0,
+                                   4.0,
+                                   np.infty,
+                                   use_stateless=True)

   # Take the normal distribution around the mean, but truncating the right tail.
   @test_util.run_deprecated_v1
   def testTruncateOnRight_entireTailOnLeft(self):
-    self.validateKolmogorovSmirnov([10**5], -8, 1.0, -np.infty, -4.0)
+    self.validateKolmogorovSmirnov([int(1e5)], -8, 1.0, -np.infty, -4.0)
+    self.validateKolmogorovSmirnov([int(1e5)],
+                                   -8.,
+                                   1.0,
+                                   -np.infty,
+                                   -4.0,
+                                   use_stateless=True)

   @test_util.run_deprecated_v1
   def testSmallStddev(self):
-    self.validateKolmogorovSmirnov([10**5], 0.0, 0.1, 0.05, 0.10)
+    self.validateKolmogorovSmirnov([int(1e5)], 0.0, 0.1, 0.05, 0.10)
+    self.validateKolmogorovSmirnov([int(1e5)],
+                                   0.0,
+                                   0.1,
+                                   0.05,
+                                   0.10,
+                                   use_stateless=True)

   @test_util.run_deprecated_v1
   def testSamplingWithSmallStdDevFarFromBound(self):
     sample_op = random_ops.parameterized_truncated_normal(
         shape=(int(1e5),), means=0.8, stddevs=0.05, minvals=-1., maxvals=1.)
+    new_seed = random_ops.random_uniform([2],
+                                         seed=1234,
+                                         minval=0,
+                                         maxval=(2**31 - 1),
+                                         dtype=np.int32)
+    sample_op_stateless = stateless.stateless_parameterized_truncated_normal(
+        shape=(int(1e5),),
+        seed=new_seed,
+        means=0.8,
+        stddevs=0.05,
+        minvals=-1.,
+        maxvals=1.)

     with self.session(use_gpu=True) as sess:
-      samples = sess.run(sample_op)
+      samples, samples_stateless = sess.run([sample_op, sample_op_stateless])
       # 0. is more than 16 standard deviations from the mean, and
       # should have a likelihood < 1e-57.
       assert (~np.isnan(samples)).all()
-      no_neg_samples = np.sum(samples < 0.)
-      self.assertEqual(no_neg_samples, 0.)
+      assert (~np.isnan(samples_stateless)).all()
+      self.assertAllGreater(samples, 0.)
+      self.assertAllGreater(samples_stateless, 0.)
+
+  def testStatelessParameterizedTruncatedNormalHasGrads(self):
+    mean = variables.Variable(0.01)
+    stddev = variables.Variable(1.)
+    minval = variables.Variable(-1.)
+    maxval = variables.Variable(1.)
+
+    with self.cached_session(use_gpu=True) as sess:
+      with backprop.GradientTape(persistent=True) as tape:
+        samples = stateless.stateless_parameterized_truncated_normal(
+            [1], [1, 2], mean, stddev, minval, maxval)
+
+      sess.run(variables.variables_initializer([mean, stddev, minval, maxval]))
+      [mean_grad, std_grad], mean_actual_grad, std_actual_grad = sess.run([
+          tape.gradient(samples, [mean, stddev]),
+          array_ops.ones_like(mean),
+          (samples - mean) / stddev])
+      self.assertAllClose(mean_grad, mean_actual_grad)
+      self.assertAllClose(std_grad, std_actual_grad[0])
+
+      try:
+        import scipy.stats  # pylint:disable=g-import-not-at-top
+        truncnorm = scipy.stats.truncnorm(a=-1., b=1., loc=0., scale=1.)
+        samples_np, [minval_grad, maxval_grad] = sess.run([
+            samples, tape.gradient(samples, [minval, maxval])])
+
+        sample_cdf = truncnorm.cdf(samples_np)
+        # These come from the implicit reparameterization trick.
+        scipy_maxval_grad = np.exp(
+            0.5 * (samples_np ** 2 - ((1. - 0.01) / 1.) ** 2) +
+            np.log(sample_cdf))
+
+        scipy_minval_grad = np.exp(
+            0.5 * (samples_np ** 2 - ((-1. - 0.01) / 1.) ** 2) +
+            np.log1p(-sample_cdf))
+
+        self.assertAllClose(minval_grad, scipy_minval_grad[0], rtol=1e-2)
+        self.assertAllClose(maxval_grad, scipy_maxval_grad[0], rtol=1e-2)
+
+      except ImportError as e:
+        tf_logging.warn("Cannot test truncated normal op: %s" % str(e))

   @test_util.run_deprecated_v1
   def testSamplingAtRandnSwitchover(self):
@@ -239,18 +362,33 @@ class ParameterizedTruncatedNormalTest(test.TestCase):

     epsilon = 0.001
     self.validateMoments(
-        shape=[10**6],
+        shape=[int(1e6)],
         mean=0.,
         stddev=1.0,
         minval=-epsilon,
         maxval=stddev_inside_bounds_before_using_randn - epsilon)
     self.validateMoments(
-        shape=[10**6],
+        shape=[int(1e6)],
         mean=0.,
         stddev=1.0,
         minval=-epsilon,
         maxval=stddev_inside_bounds_before_using_randn + epsilon)
+
+    self.validateMoments(
+        shape=[int(1e6)],
+        mean=0.,
+        stddev=1.0,
+        minval=-epsilon,
+        maxval=stddev_inside_bounds_before_using_randn - epsilon,
+        use_stateless=True)
+    self.validateMoments(
+        shape=[int(1e6)],
+        mean=0.,
+        stddev=1.0,
+        minval=-epsilon,
+        maxval=stddev_inside_bounds_before_using_randn + epsilon,
+        use_stateless=True)


 # Benchmarking code
 def parameterized_vs_naive(shape, num_iters, use_gpu=False):
@@ -18,9 +18,14 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import numpy as np
+
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.ops import gen_random_ops
 from tensorflow.python.ops import math_ops

@@ -114,3 +119,118 @@ def _StatelessRandomGammaV2Grad(op, grad):  # pylint: disable=invalid-name
   return (None, None,
           math_ops.reduce_sum(
               grad * partial_a, axis=math_ops.range(num_sample_dimensions)))
+
+
+def _Ndtr(x):
+  """Normal distribution function."""
+  half_sqrt_2 = constant_op.constant(
+      0.5 * np.sqrt(2.), dtype=x.dtype, name="half_sqrt_2")
+  w = x * half_sqrt_2
+  z = math_ops.abs(w)
+  y = array_ops.where(
+      z < half_sqrt_2,
+      1. + math_ops.erf(w),
+      array_ops.where(
+          w > 0., 2. - math_ops.erfc(z), math_ops.erfc(z)))
+  return 0.5 * y
+
+
+@ops.RegisterGradient("StatelessParameterizedTruncatedNormal")
+def _StatelessParameterizedTruncatedNormalGrad(op, grad):  # pylint: disable=invalid-name
+  """Returns the gradient of a TruncatedNormal sample w.r.t. parameters.
+
+  The gradient is computed using implicit differentiation
+  (Figurnov et al., 2018).
+
+  Args:
+    op: A `StatelessParameterizedTruncatedNormal` operation. We assume that the
+      inputs to the operation are `shape`, `seed`, `mean`, `stddev`, `minval`,
+      and `maxval` tensors, and the output is the `sample` tensor.
+    grad: The incoming gradient `dloss / dsample` of the same shape as
+      `op.outputs[0]`.
+
+  Returns:
+    A list of `Tensor` with derivatives with respect to each parameter.
+
+  References:
+    Implicit Reparameterization Gradients:
+      [Figurnov et al., 2018]
+      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients)
+      ([pdf]
+      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients.pdf))
+  """
+  shape = op.inputs[0]
+  mean = op.inputs[2]
+  stddev = op.inputs[3]
+  minval = op.inputs[4]
+  maxval = op.inputs[5]
+  sample = op.outputs[0]
+
+  with ops.control_dependencies([grad]):
+    minval_std = (minval - mean) / stddev
+    maxval_std = (maxval - mean) / stddev
+    sample_std = (sample - mean) / stddev
+
+    cdf_sample = (_Ndtr(sample_std) - _Ndtr(minval_std)) / (
+        _Ndtr(maxval_std) - _Ndtr(minval_std))
+
+    # Clip to avoid zero argument for log_cdf expression
+    tiny = np.finfo(mean.dtype.as_numpy_dtype).tiny
+    eps = np.finfo(mean.dtype.as_numpy_dtype).eps
+    cdf_sample = clip_ops.clip_by_value(cdf_sample, tiny, 1 - eps)
+
+    dmaxval = math_ops.exp(0.5 * (sample_std ** 2 - maxval_std ** 2) +
+                           math_ops.log(cdf_sample))
+    dminval = math_ops.exp(0.5 * (sample_std ** 2 - minval_std ** 2) +
+                           math_ops.log1p(-cdf_sample))
+    dmean = array_ops.ones_like(sample_std)
+    dstddev = sample_std
+
+    # Reduce over extra dimensions caused by `shape`. We need to get the
+    # difference in rank from shape vs. the broadcasted rank.
+
+    mean_shape = array_ops.shape(mean)
+    stddev_shape = array_ops.shape(stddev)
+    minval_shape = array_ops.shape(minval)
+    maxval_shape = array_ops.shape(maxval)
+
+    broadcast_shape = array_ops.broadcast_dynamic_shape(
+        mean_shape, stddev_shape)
+    broadcast_shape = array_ops.broadcast_dynamic_shape(
+        minval_shape, broadcast_shape)
+    broadcast_shape = array_ops.broadcast_dynamic_shape(
+        maxval_shape, broadcast_shape)
+    extra_dims = math_ops.range(
+        array_ops.size(shape) - array_ops.size(broadcast_shape))
+
+    grad_mean = math_ops.reduce_sum(grad * dmean, axis=extra_dims)
+    grad_stddev = math_ops.reduce_sum(grad * dstddev, axis=extra_dims)
+    grad_minval = math_ops.reduce_sum(grad * dminval, axis=extra_dims)
+    grad_maxval = math_ops.reduce_sum(grad * dmaxval, axis=extra_dims)
+
+    _, rmean = gen_array_ops.broadcast_gradient_args(
+        broadcast_shape, mean_shape)
+    _, rstddev = gen_array_ops.broadcast_gradient_args(
+        broadcast_shape, stddev_shape)
+    _, rminval = gen_array_ops.broadcast_gradient_args(
+        broadcast_shape, minval_shape)
+    _, rmaxval = gen_array_ops.broadcast_gradient_args(
+        broadcast_shape, maxval_shape)
+
+    grad_mean = array_ops.reshape(
+        math_ops.reduce_sum(grad_mean, axis=rmean, keepdims=True), mean_shape)
+
+    grad_stddev = array_ops.reshape(
+        math_ops.reduce_sum(grad_stddev, axis=rstddev, keepdims=True),
+        stddev_shape)
+
+    grad_minval = array_ops.reshape(
+        math_ops.reduce_sum(grad_minval, axis=rminval, keepdims=True),
+        minval_shape)
+
+    grad_maxval = array_ops.reshape(
+        math_ops.reduce_sum(grad_maxval, axis=rmaxval, keepdims=True),
+        maxval_shape)
+
+    # The first two inputs are shape.
+    return (None, None, grad_mean, grad_stddev, grad_minval, grad_maxval)
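
A sketch of where the `dminval`/`dmaxval` expressions above come from, following the implicit reparameterization argument of Figurnov et al. (2018): a sample $z$ with parameters $\theta$ satisfies $F(z; \theta) = u$ for a fixed uniform draw $u$, so

    \frac{\partial z}{\partial \theta}
        = -\left(\frac{\partial F}{\partial z}\right)^{-1}
           \frac{\partial F}{\partial \theta}.

For the truncated normal with standardized bounds $\tilde a = (a - \mu)/\sigma$, $\tilde b = (b - \mu)/\sigma$ and standardized sample $\tilde z = (z - \mu)/\sigma$,

    F(z) = \frac{\Phi(\tilde z) - \Phi(\tilde a)}{\Phi(\tilde b) - \Phi(\tilde a)},
    \qquad
    \frac{\partial z}{\partial b}
        = \exp\!\left(\tfrac{1}{2}\big(\tilde z^{2} - \tilde b^{2}\big)\right) F(z),
    \qquad
    \frac{\partial z}{\partial a}
        = \exp\!\left(\tfrac{1}{2}\big(\tilde z^{2} - \tilde a^{2}\big)\right) \big(1 - F(z)\big).

The code evaluates these in log space (`math_ops.log(cdf_sample)` and `math_ops.log1p(-cdf_sample)`) for numerical stability; `dmean = 1` and `dstddev = sample_std` follow the same way, since $F$ depends on $\mu$ and $\sigma$ only through the standardized values.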
@ -618,3 +618,73 @@ def stateless_multinomial_categorical_impl(logits, num_samples, dtype, seed):
  logits = ops.convert_to_tensor(logits, name="logits")
  return gen_stateless_random_ops.stateless_multinomial(
      logits, num_samples, seed, output_dtype=dtype)

@dispatch.add_dispatch_support
@tf_export("random.stateless_parameterized_truncated_normal")
def stateless_parameterized_truncated_normal(shape,
                                             seed,
                                             means=0.0,
                                             stddevs=1.0,
                                             minvals=-2.0,
                                             maxvals=2.0,
                                             name=None):
  """Outputs random values from a truncated normal distribution.

  The generated values follow a normal distribution with the specified mean
  and standard deviation, except that values falling outside
  `[minvals, maxvals]` are dropped and re-picked.

  Examples:

  Sample from a truncated normal whose distribution parameters broadcast
  against each other:

  >>> means = 0.
  >>> stddevs = tf.math.exp(tf.random.uniform(shape=[2, 3]))
  >>> minvals = [-1., -2., -1000.]
  >>> maxvals = [[10000.], [1.]]
  >>> y = tf.random.stateless_parameterized_truncated_normal(
  ...     shape=[10, 2, 3], seed=[7, 17],
  ...     means=means, stddevs=stddevs, minvals=minvals, maxvals=maxvals)
  >>> y.shape
  TensorShape([10, 2, 3])

  Args:
    shape: A 1-D integer `Tensor` or Python array. The shape of the output
      tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must
      have dtype `int32` or `int64`. (When using XLA, only `int32` is
      allowed.)
    means: A `Tensor` or Python value of type `dtype`. The mean of the
      truncated normal distribution. This must broadcast with `stddevs`,
      `minvals` and `maxvals`, and the broadcasted shape must be dominated by
      `shape`.
    stddevs: A `Tensor` or Python value of type `dtype`. The standard
      deviation of the truncated normal distribution. This must broadcast
      with `means`, `minvals` and `maxvals`, and the broadcasted shape must
      be dominated by `shape`.
    minvals: A `Tensor` or Python value of type `dtype`. The minimum value of
      the truncated normal distribution. This must broadcast with `means`,
      `stddevs` and `maxvals`, and the broadcasted shape must be dominated by
      `shape`.
    maxvals: A `Tensor` or Python value of type `dtype`. The maximum value of
      the truncated normal distribution. This must broadcast with `means`,
      `stddevs` and `minvals`, and the broadcasted shape must be dominated by
      `shape`.
    name: A name for the operation (optional).

  Returns:
    A tensor of the specified shape filled with random truncated normal
    values.
  """
  with ops.name_scope(name, "stateless_parameterized_truncated_normal",
                      [shape, means, stddevs, minvals, maxvals]) as name:
    shape_tensor = tensor_util.shape_tensor(shape)
    means_tensor = ops.convert_to_tensor(means, name="means")
    stddevs_tensor = ops.convert_to_tensor(stddevs, name="stddevs")
    minvals_tensor = ops.convert_to_tensor(minvals, name="minvals")
    maxvals_tensor = ops.convert_to_tensor(maxvals, name="maxvals")
    rnd = gen_stateless_random_ops.stateless_parameterized_truncated_normal(
        shape_tensor, seed, means_tensor, stddevs_tensor, minvals_tensor,
        maxvals_tensor)
    tensor_util.maybe_set_static_shape(rnd, shape)
    return rnd
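A quick eager-mode check of the documented behavior (a sketch, assuming a build that includes this change): the same seed and parameters reproduce the same samples, and the distribution parameters broadcast against each other.

import tensorflow as tf

seed = [7, 17]
minvals = [-1., -2., -1000.]  # shape [3]
maxvals = [[10000.], [1.]]    # shape [2, 1]; broadcasts with minvals to [2, 3]
y1 = tf.random.stateless_parameterized_truncated_normal(
    shape=[10, 2, 3], seed=seed, minvals=minvals, maxvals=maxvals)
y2 = tf.random.stateless_parameterized_truncated_normal(
    shape=[10, 2, 3], seed=seed, minvals=minvals, maxvals=maxvals)
assert y1.shape == [10, 2, 3]
# Stateless: the output is a deterministic function of seed and parameters.
assert bool(tf.reduce_all(y1 == y2))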
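Because the gradient function above is registered for the op, the samples are also differentiable in the distribution parameters. A minimal sketch (treating `means` as a trainable variable is my example, not from this change):

import tensorflow as tf

means = tf.Variable(0.5)
with tf.GradientTape() as tape:
  samples = tf.random.stateless_parameterized_truncated_normal(
      shape=[1000], seed=[7, 17], means=means)
  loss = tf.reduce_mean(samples)
# `dmean` above is ones_like(sample), so this prints a value of ~1.0.
print(tape.gradient(loss, means))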
@ -92,6 +92,10 @@ tf_module {
    name: "stateless_normal"
    argspec: "args=[\'shape\', \'seed\', \'mean\', \'stddev\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \"<dtype: \'float32\'>\", \'None\'], "
  }
  member_method {
    name: "stateless_parameterized_truncated_normal"
    argspec: "args=[\'shape\', \'seed\', \'means\', \'stddevs\', \'minvals\', \'maxvals\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'-2.0\', \'2.0\', \'None\'], "
  }
  member_method {
    name: "stateless_poisson"
    argspec: "args=[\'shape\', \'seed\', \'lam\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
@ -4492,6 +4492,10 @@ tf_module {
    name: "StatelessMultinomial"
    argspec: "args=[\'logits\', \'num_samples\', \'seed\', \'output_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
  }
  member_method {
    name: "StatelessParameterizedTruncatedNormal"
    argspec: "args=[\'shape\', \'seed\', \'means\', \'stddevs\', \'minvals\', \'maxvals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "StatelessRandomBinomial"
    argspec: "args=[\'shape\', \'seed\', \'counts\', \'probs\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
@ -80,6 +80,10 @@ tf_module {
    name: "stateless_normal"
    argspec: "args=[\'shape\', \'seed\', \'mean\', \'stddev\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \"<dtype: \'float32\'>\", \'None\'], "
  }
  member_method {
    name: "stateless_parameterized_truncated_normal"
    argspec: "args=[\'shape\', \'seed\', \'means\', \'stddevs\', \'minvals\', \'maxvals\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'-2.0\', \'2.0\', \'None\'], "
  }
  member_method {
    name: "stateless_poisson"
    argspec: "args=[\'shape\', \'seed\', \'lam\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
@ -4492,6 +4492,10 @@ tf_module {
    name: "StatelessMultinomial"
    argspec: "args=[\'logits\', \'num_samples\', \'seed\', \'output_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
  }
  member_method {
    name: "StatelessParameterizedTruncatedNormal"
    argspec: "args=[\'shape\', \'seed\', \'means\', \'stddevs\', \'minvals\', \'maxvals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "StatelessRandomBinomial"
    argspec: "args=[\'shape\', \'seed\', \'counts\', \'probs\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "