Add StatelessParameterizedTruncatedNormal sampler.
This sampler supports broadcasting of its input parameters, and it puts the sample dimensions at the left of the output shape rather than the right. PiperOrigin-RevId: 317129622 Change-Id: I4b62ad2e89a9637ae8b30b73af4b662ad0caa943
This commit is contained in:
parent
89b80c5fb9
commit
455750f362
@ -0,0 +1,54 @@
|
||||
op {
|
||||
graph_op_name: "StatelessParameterizedTruncatedNormal"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "shape"
|
||||
description: <<END
|
||||
The shape of the output tensor.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "seed"
|
||||
description: <<END
|
||||
2 seeds (shape [2]).
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "means"
|
||||
description: <<END
|
||||
The mean parameter of each batch.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "stddevs"
|
||||
description: <<END
|
||||
The standard deviation parameter of each batch. Must be greater than 0.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "minvals"
|
||||
description: <<END
|
||||
The minimum cutoff. May be -infinity.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "maxvals"
|
||||
description: <<END
|
||||
The maximum cutoff. May be +infinity, and must be more than the minval
|
||||
for each batch.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "dtype"
|
||||
description: <<END
|
||||
The type of the output.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "output"
|
||||
description: <<END
|
||||
The outputs are truncated normal samples and are a deterministic function of
|
||||
`shape`, `seed`, `minvals`, `maxvals`, `means` and `stddevs`.
|
||||
END
|
||||
}
|
||||
}
|
@ -6156,6 +6156,7 @@ tf_kernel_library(
|
||||
]),
|
||||
prefix = "parameterized_truncated_normal_op",
|
||||
deps = [
|
||||
":stateless_random_ops",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
|
@ -32,6 +32,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/kernels/stateless_random_ops.h"
|
||||
#include "tensorflow/core/lib/random/random_distributions.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/util/guarded_philox_random.h"
|
||||
@ -67,7 +68,7 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
|
||||
const T kStdDevsInsideBoundsToUseRandnSampler = T(1.3);
|
||||
auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
|
||||
|
||||
auto DoWork = [samples_per_batch, num_elements, &ctx, &means, &stddevs,
|
||||
auto do_work = [samples_per_batch, num_elements, &ctx, &means, &stddevs,
|
||||
&minvals, &maxvals, &gen, &output,
|
||||
kStdDevsInsideBoundsToUseRandnSampler](int start_batch,
|
||||
int limit_batch) {
|
||||
@ -80,9 +81,9 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
|
||||
// The sample from each iteration uses 2 random numbers.
|
||||
gen_copy.Skip(start_batch * 2 * kMaxIterations * (samples_per_batch + 3) /
|
||||
4);
|
||||
typedef random::UniformDistribution<random::PhiloxRandom, T> Uniform;
|
||||
using Uniform = random::UniformDistribution<random::PhiloxRandom, T>;
|
||||
Uniform dist;
|
||||
typedef random::NormalDistribution<random::PhiloxRandom, T> Normal;
|
||||
using Normal = random::NormalDistribution<random::PhiloxRandom, T>;
|
||||
Normal normal_dist;
|
||||
|
||||
// Vectorized intermediate calculations for uniform rejection sampling.
|
||||
@ -112,7 +113,7 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
|
||||
Eigen::numext::isfinite(maxval)),
|
||||
errors::InvalidArgument("Invalid parameters"));
|
||||
|
||||
int numIterations = 0;
|
||||
int num_iterations = 0;
|
||||
|
||||
// If possible, make one-sided bound be the lower bound, or make both
|
||||
// bounds positive. Otherwise, the bounds are on either side of the
|
||||
@ -160,10 +161,10 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
|
||||
if (sample >= limit_sample) {
|
||||
break;
|
||||
}
|
||||
numIterations = 0;
|
||||
num_iterations = 0;
|
||||
} else {
|
||||
numIterations++;
|
||||
if (numIterations > kMaxIterations) {
|
||||
num_iterations++;
|
||||
if (num_iterations > kMaxIterations) {
|
||||
// This should never occur because this sampler should
|
||||
// (by the selection criteria above) be used if at least 3
|
||||
// standard deviations of one side of the distribution
|
||||
@ -201,7 +202,7 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
|
||||
const auto u = dist(&gen_copy);
|
||||
for (int i = 0; i < size; i++) {
|
||||
auto accept = u[i] <= Eigen::numext::exp(g[i]);
|
||||
if (accept || numIterations + 1 >= kMaxIterations) {
|
||||
if (accept || num_iterations + 1 >= kMaxIterations) {
|
||||
// Accept the sample z.
|
||||
// If we run out of iterations, just use the current uniform
|
||||
// sample, but emit a warning.
|
||||
@ -223,9 +224,9 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
|
||||
if (sample >= limit_sample) {
|
||||
break;
|
||||
}
|
||||
numIterations = 0;
|
||||
num_iterations = 0;
|
||||
} else {
|
||||
numIterations++;
|
||||
num_iterations++;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -248,7 +249,7 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
|
||||
const T u = rand[i];
|
||||
i++;
|
||||
auto accept = (u <= g && z < normMax);
|
||||
if (accept || numIterations + 1 >= kMaxIterations) {
|
||||
if (accept || num_iterations + 1 >= kMaxIterations) {
|
||||
if (!accept) {
|
||||
LOG(ERROR) << "TruncatedNormal exponential distribution "
|
||||
<< "rejection sampler exceeds max iterations. "
|
||||
@ -263,9 +264,9 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
|
||||
if (sample >= limit_sample) {
|
||||
break;
|
||||
}
|
||||
numIterations = 0;
|
||||
num_iterations = 0;
|
||||
} else {
|
||||
numIterations++;
|
||||
num_iterations++;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -305,7 +306,297 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
|
||||
const int64 batchCost =
|
||||
batchInitCost + uniformRejectionSamplingCost * 2 * samples_per_batch;
|
||||
Shard(worker_threads.num_threads, worker_threads.workers, num_batches,
|
||||
batchCost, DoWork);
|
||||
batchCost, do_work);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct TruncatedNormalFunctorV2<CPUDevice, T> {
|
||||
void operator()(OpKernelContext* ctx, const CPUDevice& d, int64 num_batches,
|
||||
int64 samples_per_batch, int64 num_elements,
|
||||
const BCastList<4>& bcast,
|
||||
typename TTypes<T>::ConstFlat means,
|
||||
typename TTypes<T>::ConstFlat stddevs,
|
||||
typename TTypes<T>::ConstFlat minvals,
|
||||
typename TTypes<T>::ConstFlat maxvals,
|
||||
const random::PhiloxRandom& gen,
|
||||
typename TTypes<T>::Flat output) {
|
||||
// The randn rejection sampling is used when the mean and at least this many
|
||||
// standard deviations are inside the bounds.
|
||||
// The uniform proposal samplers become less efficient as the bounds are
|
||||
// further from the mean, the reverse is true for the randn sampler.
|
||||
// This number was chosen by empirical benchmarking. If modified, the
|
||||
// benchmarks in parameterized_truncated_normal_op_test should also be
|
||||
// changed.
|
||||
const T kStdDevsInsideBoundsToUseRandnSampler = T(1.3);
|
||||
auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
|
||||
|
||||
auto do_work = [num_batches, samples_per_batch, &ctx, &bcast, &means,
|
||||
&stddevs, &minvals, &maxvals, &gen, &output,
|
||||
kStdDevsInsideBoundsToUseRandnSampler](int start_output,
|
||||
int limit_output) {
|
||||
// Capturing "gen" by-value would only make a copy for the _shared_
|
||||
// lambda. Since we want to let each worker have its own copy, we pass
|
||||
// "gen" by reference and explicitly do a copy assignment here.
|
||||
random::PhiloxRandom gen_copy = gen;
|
||||
using Uniform = random::UniformDistribution<random::PhiloxRandom, T>;
|
||||
Uniform dist;
|
||||
using Normal = random::NormalDistribution<random::PhiloxRandom, T>;
|
||||
Normal normal_dist;
|
||||
// Skip takes units of 128 bits. The Uniform::kResultElementCount - 1
|
||||
// is so rounding doesn't lead to
|
||||
// us using the same state in different workloads.
|
||||
// The sample from each iteration uses 2 random numbers.
|
||||
gen_copy.Skip((start_output * 2 * kMaxIterations +
|
||||
Uniform::kResultElementCount - 1) /
|
||||
Uniform::kResultElementCount);
|
||||
|
||||
// Vectorized intermediate calculations for uniform rejection sampling.
|
||||
// We always generate at most 4 samples.
|
||||
Eigen::array<T, Uniform::kResultElementCount> z;
|
||||
Eigen::array<T, Uniform::kResultElementCount> g;
|
||||
|
||||
const bool should_bcast = bcast.IsBroadcastingRequired();
|
||||
const auto& means_batch_indices = bcast.batch_indices(0);
|
||||
const auto& stddevs_batch_indices = bcast.batch_indices(1);
|
||||
const auto& minvals_batch_indices = bcast.batch_indices(2);
|
||||
const auto& maxvals_batch_indices = bcast.batch_indices(3);
|
||||
auto output_flat = output.data();
|
||||
|
||||
// We partition work across batches and then across samples
|
||||
// per batch member, to avoid extra work.
|
||||
for (int64 output_idx = start_output; output_idx < limit_output;
|
||||
// output_idx is incremented with the inner loops below.
|
||||
) {
|
||||
int64 batch_idx = output_idx / samples_per_batch;
|
||||
// The output layout is [samples_per_batch, num_batches]. Thus
|
||||
// the output address is sample_idx * num_batches + batch_idx.
|
||||
// Below, code will index at output_batch_offset[sample_idx *
|
||||
// num_batches] matching this.
|
||||
T* const output_batch_offset = output_flat + batch_idx;
|
||||
// Generate batch counts from BCast, as it has the right indices to loop
|
||||
// over.
|
||||
T mean, stddev, minval, maxval;
|
||||
if (should_bcast) {
|
||||
mean = means(means_batch_indices[batch_idx]);
|
||||
stddev = stddevs(stddevs_batch_indices[batch_idx]);
|
||||
minval = minvals(minvals_batch_indices[batch_idx]);
|
||||
maxval = maxvals(maxvals_batch_indices[batch_idx]);
|
||||
} else {
|
||||
mean = means(batch_idx);
|
||||
stddev = stddevs(batch_idx);
|
||||
minval = minvals(batch_idx);
|
||||
maxval = maxvals(batch_idx);
|
||||
}
|
||||
|
||||
// On GPU, this check will just fill samples with NAN if it fails.
|
||||
OP_REQUIRES(ctx,
|
||||
stddev > T(0) && minval < maxval &&
|
||||
(Eigen::numext::isfinite(minval) ||
|
||||
Eigen::numext::isfinite(maxval)),
|
||||
errors::InvalidArgument("Invalid parameters"));
|
||||
|
||||
int num_iterations = 0;
|
||||
|
||||
// If possible, make one-sided bound be the lower bound, or make both
|
||||
// bounds positive. Otherwise, the bounds are on either side of the
|
||||
// mean.
|
||||
if ((Eigen::numext::isinf(minval) && minval < T(0)) || maxval < mean) {
|
||||
// Reverse all calculations. normMin and normMax will be flipped.
|
||||
std::swap(minval, maxval);
|
||||
stddev = -stddev;
|
||||
}
|
||||
|
||||
// Calculate normalized samples, then convert them.
|
||||
const T normMin = (minval - mean) / stddev;
|
||||
const T normMax = (maxval - mean) / stddev;
|
||||
|
||||
// Determine the method to use.
|
||||
const T sqrtFactor = Eigen::numext::sqrt((normMin * normMin) + T(4));
|
||||
const T cutoff =
|
||||
T(2) *
|
||||
Eigen::numext::exp(T(0.5) +
|
||||
(normMin * (normMin - sqrtFactor)) / T(4)) /
|
||||
(normMin + sqrtFactor);
|
||||
const T diff = normMax - normMin;
|
||||
|
||||
if (((normMin < -kStdDevsInsideBoundsToUseRandnSampler) &&
|
||||
(normMax >= T(0.))) ||
|
||||
((normMax > kStdDevsInsideBoundsToUseRandnSampler) &&
|
||||
(normMin <= T(0.)))) {
|
||||
// If the bounds are a least 3 standard deviations from the mean
|
||||
// on at least one side then we rejection sample by sampling
|
||||
// from the normal distribution and rejecting samples outside
|
||||
// the bounds.
|
||||
// Under this condition the acceptance rate per iteration should
|
||||
// always be ~ 50%. This sampler is more efficient (and more
|
||||
// numerically stable when one or both bounds is far from the mean).
|
||||
for (int64 sample_idx = output_idx % samples_per_batch;
|
||||
sample_idx < samples_per_batch && output_idx < limit_output;) {
|
||||
const auto randn_sample = normal_dist(&gen_copy);
|
||||
const int size = randn_sample.size();
|
||||
for (int i = 0; i < size; ++i) {
|
||||
if ((randn_sample[i] >= normMin) &&
|
||||
(randn_sample[i] <= normMax)) {
|
||||
output_batch_offset[sample_idx * num_batches] =
|
||||
randn_sample[i] * stddev + mean;
|
||||
++sample_idx;
|
||||
++output_idx;
|
||||
if (sample_idx >= samples_per_batch ||
|
||||
output_idx >= limit_output) {
|
||||
break;
|
||||
}
|
||||
num_iterations = 0;
|
||||
} else {
|
||||
++num_iterations;
|
||||
if (num_iterations > kMaxIterations) {
|
||||
// This should never occur because this sampler should
|
||||
// (by the selection criteria above) be used if at least 3
|
||||
// standard deviations of one side of the distribution
|
||||
// is within the limits (so acceptance probability per
|
||||
// iterations >~ 1/2 per iteration).
|
||||
LOG(ERROR) << "TruncatedNormal randn rejection sampler "
|
||||
<< "exceeded maximum iterations for "
|
||||
<< "normMin=" << normMin << " normMax=" << normMax
|
||||
<< " kMaxIterations=" << kMaxIterations;
|
||||
ctx->SetStatus(errors::Internal(
|
||||
"TruncatedNormal randn rejection sampler failed to accept"
|
||||
" a sample."));
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (diff < cutoff) {
|
||||
// Sample from a uniform distribution on [normMin, normMax].
|
||||
|
||||
const T plusFactor = (normMin < T(0)) ? T(0) : normMin * normMin;
|
||||
|
||||
for (int64 sample_idx = output_idx % samples_per_batch;
|
||||
sample_idx < samples_per_batch && output_idx < limit_output;) {
|
||||
const auto rand = dist(&gen_copy);
|
||||
const int size = rand.size();
|
||||
// NOTE(ringwalt): These loops seem to only generate packed AVX
|
||||
// instructions for float32.
|
||||
for (int i = 0; i < size; i++) {
|
||||
z[i] = rand[i] * diff + normMin;
|
||||
g[i] = (plusFactor - z[i] * z[i]) / T(2.0);
|
||||
}
|
||||
|
||||
const auto u = dist(&gen_copy);
|
||||
for (int i = 0; i < size; i++) {
|
||||
auto accept = u[i] <= Eigen::numext::exp(g[i]);
|
||||
if (accept || num_iterations + 1 >= kMaxIterations) {
|
||||
// Accept the sample z.
|
||||
// If we run out of iterations, just use the current uniform
|
||||
// sample, but emit a warning.
|
||||
// TODO(jjhunt) For small entropies (relative to the bounds),
|
||||
// this sampler is poor and may take many iterations since
|
||||
// the proposal distribution is the uniform distribution
|
||||
// U(lower_bound, upper_bound).
|
||||
if (!accept) {
|
||||
LOG(ERROR) << "TruncatedNormal uniform rejection sampler "
|
||||
<< "exceeded max iterations. Sample may contain "
|
||||
<< "outliers.";
|
||||
ctx->SetStatus(errors::Internal(
|
||||
"TruncatedNormal uniform rejection sampler failed to "
|
||||
" accept a sample."));
|
||||
return;
|
||||
}
|
||||
output_batch_offset[sample_idx * num_batches] =
|
||||
z[i] * stddev + mean;
|
||||
++sample_idx;
|
||||
++output_idx;
|
||||
if (sample_idx >= samples_per_batch ||
|
||||
output_idx >= limit_output) {
|
||||
break;
|
||||
}
|
||||
num_iterations = 0;
|
||||
} else {
|
||||
num_iterations++;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Sample from an exponential distribution with alpha maximizing
|
||||
// acceptance probability, offset by normMin from the origin.
|
||||
// Accept only if less than normMax.
|
||||
const T alpha =
|
||||
(normMin + Eigen::numext::sqrt((normMin * normMin) + T(4))) /
|
||||
T(2);
|
||||
for (int64 sample_idx = output_idx % samples_per_batch;
|
||||
sample_idx < samples_per_batch && output_idx < limit_output;) {
|
||||
auto rand = dist(&gen_copy);
|
||||
const int size = rand.size();
|
||||
int i = 0;
|
||||
while (i < size) {
|
||||
const T z = -Eigen::numext::log(rand[i]) / alpha + normMin;
|
||||
i++;
|
||||
const T x = normMin < alpha ? alpha - z : normMin - alpha;
|
||||
const T g = Eigen::numext::exp(-x * x / T(2.0));
|
||||
const T u = rand[i];
|
||||
i++;
|
||||
auto accept = (u <= g && z < normMax);
|
||||
if (accept || num_iterations + 1 >= kMaxIterations) {
|
||||
if (!accept) {
|
||||
LOG(ERROR) << "TruncatedNormal exponential distribution "
|
||||
<< "rejection sampler exceeds max iterations. "
|
||||
<< "Sample may contain outliers.";
|
||||
ctx->SetStatus(errors::Internal(
|
||||
"TruncatedNormal exponential distribution rejection"
|
||||
" sampler failed to accept a sample."));
|
||||
return;
|
||||
}
|
||||
output_batch_offset[sample_idx * num_batches] =
|
||||
z * stddev + mean;
|
||||
++sample_idx;
|
||||
++output_idx;
|
||||
if (sample_idx >= samples_per_batch ||
|
||||
output_idx >= limit_output) {
|
||||
break;
|
||||
}
|
||||
num_iterations = 0;
|
||||
} else {
|
||||
num_iterations++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
// The cost of the initial calculations for the batch.
|
||||
const int64 batchInitCost =
|
||||
// normMin, normMax
|
||||
(Eigen::TensorOpCost::AddCost<T>() +
|
||||
Eigen::TensorOpCost::MulCost<T>()) *
|
||||
2
|
||||
// sqrtFactor
|
||||
+ Eigen::TensorOpCost::AddCost<T>() +
|
||||
Eigen::TensorOpCost::MulCost<T>() +
|
||||
Eigen::internal::functor_traits<
|
||||
Eigen::internal::scalar_sqrt_op<T>>::Cost
|
||||
// cutoff
|
||||
+ Eigen::TensorOpCost::MulCost<T>() * 4 +
|
||||
Eigen::internal::functor_traits<Eigen::internal::scalar_exp_op<T>>::Cost
|
||||
// diff
|
||||
+ Eigen::TensorOpCost::AddCost<T>();
|
||||
const int64 uniformSampleCost =
|
||||
random::PhiloxRandom::kElementCost +
|
||||
random::UniformDistribution<random::PhiloxRandom, T>::kElementCost;
|
||||
// The cost of a single uniform sampling round.
|
||||
const int64 uniformRejectionSamplingCost =
|
||||
uniformSampleCost + Eigen::TensorOpCost::MulCost<T>() +
|
||||
Eigen::TensorOpCost::AddCost<T>() +
|
||||
Eigen::TensorOpCost::MulCost<T>() * 2 +
|
||||
Eigen::TensorOpCost::AddCost<T>() + uniformSampleCost +
|
||||
Eigen::internal::functor_traits<
|
||||
Eigen::internal::scalar_exp_op<T>>::Cost +
|
||||
Eigen::TensorOpCost::MulCost<T>() + Eigen::TensorOpCost::AddCost<T>();
|
||||
// Estimate the cost for an entire batch.
|
||||
// Assume we use uniform sampling, and accept the 2nd sample on average.
|
||||
const int64 batchCost = batchInitCost + uniformRejectionSamplingCost * 2;
|
||||
Shard(worker_threads.num_threads, worker_threads.workers, num_elements,
|
||||
batchCost, do_work);
|
||||
}
|
||||
};
|
||||
|
||||
@ -436,13 +727,113 @@ class ParameterizedTruncatedNormalOp : public OpKernel {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(ParameterizedTruncatedNormalOp);
|
||||
};
|
||||
|
||||
// Samples from a truncated normal distribution, using the given parameters.
|
||||
template <typename Device, typename T>
|
||||
class StatelessParameterizedTruncatedNormal : public OpKernel {
|
||||
// Reshape batches so each batch is this size if possible.
|
||||
static const int32 kDesiredBatchSize = 100;
|
||||
|
||||
public:
|
||||
explicit StatelessParameterizedTruncatedNormal(OpKernelConstruction* context)
|
||||
: OpKernel(context) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
const Tensor& shape_tensor = ctx->input(0);
|
||||
const Tensor& seed_tensor = ctx->input(1);
|
||||
const Tensor& means_tensor = ctx->input(2);
|
||||
const Tensor& stddevs_tensor = ctx->input(3);
|
||||
const Tensor& minvals_tensor = ctx->input(4);
|
||||
const Tensor& maxvals_tensor = ctx->input(5);
|
||||
|
||||
OP_REQUIRES(ctx, seed_tensor.dims() == 1 && seed_tensor.dim_size(0) == 2,
|
||||
errors::InvalidArgument("seed must have shape [2], not ",
|
||||
seed_tensor.shape().DebugString()));
|
||||
|
||||
tensorflow::BCastList<4> bcast(
|
||||
{means_tensor.shape().dim_sizes(), stddevs_tensor.shape().dim_sizes(),
|
||||
minvals_tensor.shape().dim_sizes(),
|
||||
maxvals_tensor.shape().dim_sizes()},
|
||||
/*fewer_dims_optimization=*/false,
|
||||
/*return_flattened_batch_indices=*/true);
|
||||
|
||||
OP_REQUIRES(ctx, bcast.IsValid(),
|
||||
errors::InvalidArgument(
|
||||
"means, stddevs, minvals, maxvals must have compatible "
|
||||
"batch dimensions: ",
|
||||
means_tensor.shape().DebugString(), " vs. ",
|
||||
stddevs_tensor.shape().DebugString(), " vs. ",
|
||||
minvals_tensor.shape().DebugString(), " vs. ",
|
||||
maxvals_tensor.shape().DebugString()));
|
||||
|
||||
// Let's check that the shape tensor dominates the broadcasted tensor.
|
||||
TensorShape bcast_shape = BCast::ToShape(bcast.output_shape());
|
||||
OP_REQUIRES(
|
||||
ctx, TensorShapeUtils::IsVector(shape_tensor.shape()),
|
||||
errors::InvalidArgument("Input shape should be a vector, got shape: ",
|
||||
shape_tensor.shape().DebugString()));
|
||||
TensorShape output_shape;
|
||||
if (shape_tensor.dtype() == DataType::DT_INT32) {
|
||||
OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(shape_tensor.vec<int32>(),
|
||||
&output_shape));
|
||||
} else {
|
||||
OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(shape_tensor.vec<int64>(),
|
||||
&output_shape));
|
||||
}
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::EndsWith(output_shape, bcast_shape),
|
||||
errors::InvalidArgument(
|
||||
"Shape passed in must end with broadcasted shape."));
|
||||
|
||||
int64 samples_per_batch = 1;
|
||||
const int64 num_sample_dims =
|
||||
(shape_tensor.dim_size(0) - bcast.output_shape().size());
|
||||
for (int64 i = 0; i < num_sample_dims; ++i) {
|
||||
samples_per_batch *= output_shape.dim_size(i);
|
||||
}
|
||||
int64 num_batches = 1;
|
||||
for (int64 i = num_sample_dims; i < shape_tensor.dim_size(0); ++i) {
|
||||
num_batches *= output_shape.dim_size(i);
|
||||
}
|
||||
const int64 num_elements = num_batches * samples_per_batch;
|
||||
|
||||
Tensor* samples_tensor;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &samples_tensor));
|
||||
|
||||
auto truncFunctor = functor::TruncatedNormalFunctorV2<Device, T>();
|
||||
// Each worker has the same fudge factor, so use it here.
|
||||
random::PhiloxRandom::Key key;
|
||||
random::PhiloxRandom::ResultType counter;
|
||||
OP_REQUIRES_OK(ctx, GenerateKey(seed_tensor, &key, &counter));
|
||||
|
||||
auto philox = random::PhiloxRandom(counter, key);
|
||||
|
||||
truncFunctor(ctx, ctx->eigen_device<Device>(), num_batches,
|
||||
samples_per_batch, num_elements, bcast, means_tensor.flat<T>(),
|
||||
stddevs_tensor.flat<T>(), minvals_tensor.flat<T>(),
|
||||
maxvals_tensor.flat<T>(), philox, samples_tensor->flat<T>());
|
||||
}
|
||||
|
||||
private:
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(StatelessParameterizedTruncatedNormal);
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
#define REGISTER(TYPE) \
|
||||
REGISTER_KERNEL_BUILDER(Name("ParameterizedTruncatedNormal") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<TYPE>("dtype"), \
|
||||
ParameterizedTruncatedNormalOp<CPUDevice, TYPE>)
|
||||
ParameterizedTruncatedNormalOp<CPUDevice, TYPE>) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("StatelessParameterizedTruncatedNormal") \
|
||||
.HostMemory("shape") \
|
||||
.HostMemory("seed") \
|
||||
.HostMemory("means") \
|
||||
.HostMemory("stddevs") \
|
||||
.HostMemory("minvals") \
|
||||
.HostMemory("maxvals") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<TYPE>("dtype"), \
|
||||
StatelessParameterizedTruncatedNormal<CPUDevice, TYPE>)
|
||||
|
||||
TF_CALL_half(REGISTER);
|
||||
TF_CALL_float(REGISTER);
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/lib/random/random_distributions.h"
|
||||
#include "tensorflow/core/util/bcast.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -44,6 +45,21 @@ struct TruncatedNormalFunctor {
|
||||
typename TTypes<T>::Flat output);
|
||||
};
|
||||
|
||||
// Variant of TruncatedNormalFunctor that supports broadcasting of the
// distribution parameters and puts the sample dimension on the left of the
|
||||
// output. `bcast` describes how means, stddevs, minvals and maxvals (in that
// order) broadcast against each other; `gen` is taken by const reference.
|
||||
template <typename Device, typename T>
|
||||
struct TruncatedNormalFunctorV2 {
|
||||
void operator()(OpKernelContext* ctx, const Device& d, int64 num_batches,
|
||||
int64 samples_per_batch, int64 num_elements,
|
||||
const BCastList<4>& bcast,
|
||||
typename TTypes<T>::ConstFlat means,
|
||||
typename TTypes<T>::ConstFlat stddevs,
|
||||
typename TTypes<T>::ConstFlat minvals,
|
||||
typename TTypes<T>::ConstFlat maxvals,
|
||||
const random::PhiloxRandom& gen,
|
||||
typename TTypes<T>::Flat output);
|
||||
};
|
||||
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -124,6 +124,41 @@ REGISTER_OP("StatelessRandomBinomial")
|
||||
.Attr("dtype: {half, float, double, int32, int64} = DT_INT64")
|
||||
.SetShapeFn(StatelessShape);
|
||||
|
||||
// Op definition for the stateless, broadcasting truncated-normal sampler.
// The shape function validates the seed and the mutual broadcastability of
// the four distribution parameters; the output shape is taken from the
// `shape` input tensor.
REGISTER_OP("StatelessParameterizedTruncatedNormal")
    .Input("shape: S")
    .Input("seed: Tseed")
    .Input("means: dtype")
    .Input("stddevs: dtype")
    .Input("minvals: dtype")
    .Input("maxvals: dtype")
    .Output("output: dtype")
    .Attr("S: {int32, int64}")
    .Attr("Tseed: {int32, int64} = DT_INT64")
    .Attr("dtype: {float16, float32, float64}")
    .SetShapeFn([](InferenceContext* c) {
      // The seed has to be a rank-1 tensor holding exactly two values.
      ShapeHandle seed_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seed_shape));
      DimensionHandle seed_len;
      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed_shape, 0), 2, &seed_len));

      // Fold means, stddevs, minvals and maxvals together one at a time; any
      // incompatible pair surfaces as an error from the helper. The resulting
      // shapes are only computed for validation.
      ShapeHandle means_x_stddevs;
      TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
          c, c->input(2), c->input(3), true, &means_x_stddevs));
      ShapeHandle with_minvals;
      TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
          c, c->input(4), means_x_stddevs, true, &with_minvals));
      ShapeHandle with_maxvals;
      TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
          c, c->input(5), with_minvals, true, &with_maxvals));

      // The output shape is whatever the `shape` input tensor requests.
      ShapeHandle requested;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &requested));
      c->set_output(0, requested);
      return Status::OK();
    });
|
||||
|
||||
REGISTER_OP("StatelessRandomPoisson")
|
||||
.Input("shape: T")
|
||||
.Input("seed: Tseed")
|
||||
|
@ -50,7 +50,7 @@ auto OpGradientInfoInit(const T &a) {
|
||||
|
||||
absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
|
||||
const tensorflow::string &op_name) {
|
||||
static std::array<OpIndexInfo, 348> a = {{
|
||||
static std::array<OpIndexInfo, 349> a = {{
|
||||
{"Acosh"},
|
||||
{"AllToAll", 1, {0}},
|
||||
{"ApproximateEqual"},
|
||||
@ -326,6 +326,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
|
||||
{"StackPop"},
|
||||
{"StackPush"},
|
||||
{"StatelessMultinomial"},
|
||||
{"StatelessParameterizedTruncatedNormal", 1, {1}},
|
||||
{"StatelessRandomBinomial"},
|
||||
{"StatelessRandomGammaV2", 1, {1}},
|
||||
{"StatelessRandomNormal"},
|
||||
|
@ -785,24 +785,6 @@ cuda_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "parameterized_truncated_normal_op_test",
|
||||
size = "medium",
|
||||
srcs = ["parameterized_truncated_normal_op_test.py"],
|
||||
deps = [
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:client",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:random_ops",
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "parsing_ops_test",
|
||||
size = "medium",
|
||||
|
@ -20,6 +20,24 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "parameterized_truncated_normal_op_test",
|
||||
size = "medium",
|
||||
srcs = ["parameterized_truncated_normal_op_test.py"],
|
||||
deps = [
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:client",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:random_ops",
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "random_shuffle_queue_test",
|
||||
size = "small",
|
||||
|
@ -27,11 +27,15 @@ from six.moves import range # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import stateless_random_ops as stateless
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging
|
||||
|
||||
@ -91,13 +95,8 @@ class TruncatedNormalMoments(object):
|
||||
|
||||
def calculate_moments(samples, max_moment):
|
||||
moments = [0.0] * (max_moment + 1)
|
||||
for sample in samples:
|
||||
value = 1.0
|
||||
for k in range(len(moments)):
|
||||
moments[k] += value
|
||||
value *= sample
|
||||
for i in range(len(moments)):
|
||||
moments[i] /= len(samples)
|
||||
moments[k] = np.mean(samples**k, axis=0)
|
||||
return moments
|
||||
|
||||
|
||||
@ -118,16 +117,31 @@ class ParameterizedTruncatedNormalTest(test.TestCase):
|
||||
# Stop at moment 10 to avoid numerical errors in the theoretical moments.
|
||||
max_moment = 10
|
||||
|
||||
def validateMoments(self, shape, mean, stddev, minval, maxval, seed=1618):
|
||||
def validateMoments(self,
|
||||
shape,
|
||||
mean,
|
||||
stddev,
|
||||
minval,
|
||||
maxval,
|
||||
use_stateless=False,
|
||||
seed=1618):
|
||||
try:
|
||||
# TruncatedNormalMoments requires scipy.stats.
|
||||
# Give up early if we are unable to import it.
|
||||
import scipy.stats # pylint: disable=g-import-not-at-top,unused-variable
|
||||
random_seed.set_random_seed(seed)
|
||||
with self.cached_session(use_gpu=True):
|
||||
samples = random_ops.parameterized_truncated_normal(shape, mean, stddev,
|
||||
minval,
|
||||
maxval).eval()
|
||||
if use_stateless:
|
||||
# Generate a seed that stateless ops can use.
|
||||
new_seed = random_ops.random_uniform([2],
|
||||
seed=seed,
|
||||
minval=0,
|
||||
maxval=(2**31 - 1),
|
||||
dtype=np.int32)
|
||||
samples = stateless.stateless_parameterized_truncated_normal(
|
||||
shape, new_seed, mean, stddev, minval, maxval).eval()
|
||||
else:
|
||||
samples = random_ops.parameterized_truncated_normal(
|
||||
shape, mean, stddev, minval, maxval).eval()
|
||||
assert (~np.isnan(samples)).all()
|
||||
moments = calculate_moments(samples, self.max_moment)
|
||||
expected_moments = TruncatedNormalMoments(mean, stddev, minval, maxval)
|
||||
@ -144,14 +158,24 @@ class ParameterizedTruncatedNormalTest(test.TestCase):
|
||||
stddev,
|
||||
minval,
|
||||
maxval,
|
||||
use_stateless=False,
|
||||
seed=1618):
|
||||
try:
|
||||
import scipy.stats # pylint: disable=g-import-not-at-top
|
||||
random_seed.set_random_seed(seed)
|
||||
with self.cached_session(use_gpu=True):
|
||||
samples = random_ops.parameterized_truncated_normal(shape, mean, stddev,
|
||||
minval,
|
||||
maxval).eval()
|
||||
if use_stateless:
|
||||
new_seed = random_ops.random_uniform([2],
|
||||
seed=seed,
|
||||
minval=0,
|
||||
maxval=(2**31 - 1),
|
||||
dtype=np.int32)
|
||||
samples = stateless.stateless_parameterized_truncated_normal(
|
||||
shape, new_seed, mean, stddev, minval, maxval).eval()
|
||||
else:
|
||||
samples = random_ops.parameterized_truncated_normal(
|
||||
shape, mean, stddev, minval, maxval).eval()
|
||||
|
||||
assert (~np.isnan(samples)).all()
|
||||
minval = max(mean - stddev * 10, minval)
|
||||
maxval = min(mean + stddev * 10, maxval)
|
||||
@ -169,61 +193,160 @@ class ParameterizedTruncatedNormalTest(test.TestCase):
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testDefaults(self):
|
||||
self.validateMoments([10**5], 0.0, 1.0, -2.0, 2.0)
|
||||
self.validateMoments([int(1e5)], 0.0, 1.0, -2.0, 2.0)
|
||||
self.validateMoments([int(1e5)], 0.0, 1.0, -2.0, 2.0, use_stateless=True)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testShifted(self):
|
||||
self.validateMoments([10**5], -1.0, 1.0, -2.0, 2.0)
|
||||
self.validateMoments([int(1e5)], -1.0, 1.0, -2.0, 2.0)
|
||||
self.validateMoments([int(1e5)], -1.0, 1.0, -2.0, 2.0, use_stateless=True)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testRightTail(self):
|
||||
self.validateMoments([10**5], 0.0, 1.0, 4.0, np.infty)
|
||||
self.validateMoments([int(1e5)], 0.0, 1.0, 4.0, np.infty)
|
||||
self.validateMoments([int(1e5)],
|
||||
0.0,
|
||||
1.0,
|
||||
4.0,
|
||||
np.infty,
|
||||
use_stateless=True)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testLeftTail(self):
|
||||
self.validateMoments([10**5], 0.0, 1.0, -np.infty, -4.0)
|
||||
self.validateMoments([int(1e5)], 0.0, 1.0, -np.infty, -4.0)
|
||||
self.validateMoments([int(1e5)],
|
||||
0.0,
|
||||
1.0,
|
||||
-np.infty,
|
||||
-4.0,
|
||||
use_stateless=True)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testLeftTailTwoSidedBounds(self):
|
||||
self.validateMoments([10**5], 0.0, 1.0, -6.0, -3.0)
|
||||
self.validateMoments([int(1e5)], 0.0, 1.0, -6.0, -3.0)
|
||||
self.validateMoments([int(1e5)], 0.0, 1.0, -6.0, -3.0, use_stateless=True)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla("Low probability region")
|
||||
def testTwoSidedLeftTailShifted(self):
|
||||
self.validateKolmogorovSmirnov([10**5], 6.0, 1.0, -1.0, 1.0)
|
||||
self.validateKolmogorovSmirnov([int(1e5)], 6.0, 1.0, -1.0, 1.0)
|
||||
self.validateKolmogorovSmirnov([int(1e5)],
|
||||
6.0,
|
||||
1.0,
|
||||
-1.0,
|
||||
1.0,
|
||||
use_stateless=True)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla("Low probability region")
|
||||
def testRightTailShifted(self):
|
||||
self.validateMoments([10**5], -5.0, 1.0, 2.0, np.infty)
|
||||
self.validateMoments([int(1e5)], -5.0, 1.0, 2.0, np.infty)
|
||||
self.validateMoments([int(1e5)],
|
||||
-5.0,
|
||||
1.0,
|
||||
2.0,
|
||||
np.infty,
|
||||
use_stateless=True)
|
||||
|
||||
# Take the normal distribution around the mean, but truncating the left tail
|
||||
# far from the mean.
|
||||
@test_util.run_deprecated_v1
|
||||
def testTruncateOnLeft_entireTailOnRight(self):
|
||||
self.validateKolmogorovSmirnov([10**5], 10.0, 1.0, 4.0, np.infty)
|
||||
self.validateKolmogorovSmirnov([int(1e5)], 10.0, 1.0, 4.0, np.infty)
|
||||
self.validateKolmogorovSmirnov([int(1e5)],
|
||||
10.0,
|
||||
1.0,
|
||||
4.0,
|
||||
np.infty,
|
||||
use_stateless=True)
|
||||
|
||||
# Take the normal distribution around the mean, but truncating the right tail.
|
||||
@test_util.run_deprecated_v1
|
||||
def testTruncateOnRight_entireTailOnLeft(self):
|
||||
self.validateKolmogorovSmirnov([10**5], -8, 1.0, -np.infty, -4.0)
|
||||
self.validateKolmogorovSmirnov([int(1e5)], -8, 1.0, -np.infty, -4.0)
|
||||
self.validateKolmogorovSmirnov([int(1e5)],
|
||||
-8.,
|
||||
1.0,
|
||||
-np.infty,
|
||||
-4.0,
|
||||
use_stateless=True)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSmallStddev(self):
|
||||
self.validateKolmogorovSmirnov([10**5], 0.0, 0.1, 0.05, 0.10)
|
||||
self.validateKolmogorovSmirnov([int(1e5)], 0.0, 0.1, 0.05, 0.10)
|
||||
self.validateKolmogorovSmirnov([int(1e5)],
|
||||
0.0,
|
||||
0.1,
|
||||
0.05,
|
||||
0.10,
|
||||
use_stateless=True)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSamplingWithSmallStdDevFarFromBound(self):
|
||||
sample_op = random_ops.parameterized_truncated_normal(
|
||||
shape=(int(1e5),), means=0.8, stddevs=0.05, minvals=-1., maxvals=1.)
|
||||
new_seed = random_ops.random_uniform([2],
|
||||
seed=1234,
|
||||
minval=0,
|
||||
maxval=(2**31 - 1),
|
||||
dtype=np.int32)
|
||||
sample_op_stateless = stateless.stateless_parameterized_truncated_normal(
|
||||
shape=(int(1e5),),
|
||||
seed=new_seed,
|
||||
means=0.8,
|
||||
stddevs=0.05,
|
||||
minvals=-1.,
|
||||
maxvals=1.)
|
||||
|
||||
with self.session(use_gpu=True) as sess:
|
||||
samples = sess.run(sample_op)
|
||||
samples, samples_stateless = sess.run([sample_op, sample_op_stateless])
|
||||
# 0. is more than 16 standard deviations from the mean, and
|
||||
# should have a likelihood < 1e-57.
|
||||
assert (~np.isnan(samples)).all()
|
||||
no_neg_samples = np.sum(samples < 0.)
|
||||
self.assertEqual(no_neg_samples, 0.)
|
||||
assert (~np.isnan(samples_stateless)).all()
|
||||
self.assertAllGreater(samples, 0.)
|
||||
self.assertAllGreater(samples_stateless, 0.)
|
||||
|
||||
def testStatelessParameterizedTruncatedNormalHasGrads(self):
  """Checks the gradients of the stateless sampler w.r.t. its parameters.

  The mean/stddev gradients are compared against `1` and
  `(sample - mean) / stddev`. When scipy is importable, the minval/maxval
  gradients are additionally compared (to 1% relative tolerance) against
  closed-form expressions built from `scipy.stats.truncnorm`.
  """
  mean = variables.Variable(0.01)
  stddev = variables.Variable(1.)
  minval = variables.Variable(-1.)
  maxval = variables.Variable(1.)

  with self.cached_session(use_gpu=True) as sess:
    # The tape is persistent because gradients are requested from it twice.
    with backprop.GradientTape(persistent=True) as tape:
      samples = stateless.stateless_parameterized_truncated_normal(
          [1], [1, 2], mean, stddev, minval, maxval)

    sess.run(variables.variables_initializer([mean, stddev, minval, maxval]))

    loc_scale_grad_ops = tape.gradient(samples, [mean, stddev])
    expected_mean_grad_op = array_ops.ones_like(mean)
    expected_std_grad_op = (samples - mean) / stddev
    [mean_grad, std_grad], expected_mean_grad, expected_std_grad = sess.run(
        [loc_scale_grad_ops, expected_mean_grad_op, expected_std_grad_op])
    self.assertAllClose(mean_grad, expected_mean_grad)
    self.assertAllClose(std_grad, expected_std_grad[0])

    try:
      import scipy.stats  # pylint:disable=g-import-not-at-top
      truncnorm = scipy.stats.truncnorm(a=-1., b=1., loc=0., scale=1.)
      bound_grad_ops = tape.gradient(samples, [minval, maxval])
      samples_np, [minval_grad, maxval_grad] = sess.run(
          [samples, bound_grad_ops])

      sample_cdf = truncnorm.cdf(samples_np)
      squared_samples = samples_np ** 2
      # Bounds expressed in standard units of the (mean=0.01, stddev=1) normal.
      upper_std = (1. - 0.01) / 1.
      lower_std = (-1. - 0.01) / 1.

      # These come from the implicit reparameterization trick.
      scipy_maxval_grad = np.exp(
          0.5 * (squared_samples - upper_std ** 2) + np.log(sample_cdf))
      scipy_minval_grad = np.exp(
          0.5 * (squared_samples - lower_std ** 2) + np.log1p(-sample_cdf))

      self.assertAllClose(minval_grad, scipy_minval_grad[0], rtol=1e-2)
      self.assertAllClose(maxval_grad, scipy_maxval_grad[0], rtol=1e-2)

    except ImportError as e:
      tf_logging.warn("Cannot test truncated normal op: %s" % str(e))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSamplingAtRandnSwitchover(self):
|
||||
@ -239,18 +362,33 @@ class ParameterizedTruncatedNormalTest(test.TestCase):
|
||||
|
||||
epsilon = 0.001
|
||||
self.validateMoments(
|
||||
shape=[10**6],
|
||||
shape=[int(1e6)],
|
||||
mean=0.,
|
||||
stddev=1.0,
|
||||
minval=-epsilon,
|
||||
maxval=stddev_inside_bounds_before_using_randn - epsilon)
|
||||
self.validateMoments(
|
||||
shape=[10**6],
|
||||
shape=[int(1e6)],
|
||||
mean=0.,
|
||||
stddev=1.0,
|
||||
minval=-epsilon,
|
||||
maxval=stddev_inside_bounds_before_using_randn + epsilon)
|
||||
|
||||
self.validateMoments(
|
||||
shape=[int(1e6)],
|
||||
mean=0.,
|
||||
stddev=1.0,
|
||||
minval=-epsilon,
|
||||
maxval=stddev_inside_bounds_before_using_randn - epsilon,
|
||||
use_stateless=True)
|
||||
self.validateMoments(
|
||||
shape=[int(1e6)],
|
||||
mean=0.,
|
||||
stddev=1.0,
|
||||
minval=-epsilon,
|
||||
maxval=stddev_inside_bounds_before_using_randn + epsilon,
|
||||
use_stateless=True)
|
||||
|
||||
|
||||
# Benchmarking code
|
||||
def parameterized_vs_naive(shape, num_iters, use_gpu=False):
|
@ -18,9 +18,14 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import clip_ops
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.ops import gen_random_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
||||
@ -114,3 +119,118 @@ def _StatelessRandomGammaV2Grad(op, grad): # pylint: disable=invalid-name
|
||||
return (None, None,
|
||||
math_ops.reduce_sum(
|
||||
grad * partial_a, axis=math_ops.range(num_sample_dimensions)))
|
||||
|
||||
|
||||
def _Ndtr(x):
  """Evaluates the standard normal distribution function elementwise.

  Args:
    x: A floating-point `Tensor`.

  Returns:
    A `Tensor` with the dtype of `x` holding the normal CDF at each element.
    `erf` is used where `|x| < 1` and `erfc` elsewhere (presumably to keep
    accuracy in the tails).
  """
  half_sqrt_2 = constant_op.constant(
      0.5 * np.sqrt(2.), dtype=x.dtype, name="half_sqrt_2")
  scaled = x * half_sqrt_2
  magnitude = math_ops.abs(scaled)

  # Twice the CDF, computed separately for the central region and the tails.
  central = 1. + math_ops.erf(scaled)
  tails = array_ops.where(
      scaled > 0., 2. - math_ops.erfc(magnitude), math_ops.erfc(magnitude))
  twice_cdf = array_ops.where(magnitude < half_sqrt_2, central, tails)
  return 0.5 * twice_cdf
|
||||
|
||||
|
||||
@ops.RegisterGradient("StatelessParameterizedTruncatedNormal")
def _StatelessParameterizedTruncatedNormalGrad(op, grad):  # pylint: disable=invalid-name
  """Returns the gradient of a TruncatedNormal sample w.r.t. parameters.

  The gradient is computed using implicit differentiation
  (Figurnov et al., 2018).

  Args:
    op: A `StatelessParameterizedTruncatedNormal` operation. We assume that the
      inputs to the operation are `shape`, `seed`, `mean`, `stddev`, `minval`,
      and `maxval` tensors, and the output is the `sample` tensor.
    grad: The incoming gradient `dloss / dsample` of the same shape as
      `op.outputs[0]`.

  Returns:
    A tuple with one entry per op input: `None` for `shape` and `seed`,
    followed by the derivatives with respect to `mean`, `stddev`, `minval` and
    `maxval`, each reshaped to the shape of the corresponding input.

  References:
    Implicit Reparameterization Gradients:
      [Figurnov et al., 2018]
      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients)
      ([pdf]
      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients.pdf))
  """
  shape = op.inputs[0]
  mean = op.inputs[2]
  stddev = op.inputs[3]
  minval = op.inputs[4]
  maxval = op.inputs[5]
  sample = op.outputs[0]

  # Every op built below carries a control dependency on `grad`.
  with ops.control_dependencies([grad]):
    # Bounds and sample expressed in standard units.
    minval_std = (minval - mean) / stddev
    maxval_std = (maxval - mean) / stddev
    sample_std = (sample - mean) / stddev

    # CDF of the truncated distribution evaluated at the sample.
    cdf_sample = (_Ndtr(sample_std) - _Ndtr(minval_std)) / (
        _Ndtr(maxval_std) - _Ndtr(minval_std))

    # Clip to avoid zero argument for log_cdf expression
    tiny = np.finfo(mean.dtype.as_numpy_dtype).tiny
    eps = np.finfo(mean.dtype.as_numpy_dtype).eps
    cdf_sample = clip_ops.clip_by_value(cdf_sample, tiny, 1 - eps)

    dmaxval = math_ops.exp(0.5 * (sample_std ** 2 - maxval_std ** 2) +
                           math_ops.log(cdf_sample))
    dminval = math_ops.exp(0.5 * (sample_std ** 2 - minval_std ** 2) +
                           math_ops.log1p(-cdf_sample))
    dmean = array_ops.ones_like(sample_std)
    dstddev = sample_std

    # Reduce over extra dimensions caused by `shape`. We need to get the
    # difference in rank from shape vs. the broadcasted rank.
    mean_shape = array_ops.shape(mean)
    stddev_shape = array_ops.shape(stddev)
    minval_shape = array_ops.shape(minval)
    maxval_shape = array_ops.shape(maxval)

    broadcast_shape = array_ops.broadcast_dynamic_shape(
        mean_shape, stddev_shape)
    broadcast_shape = array_ops.broadcast_dynamic_shape(
        minval_shape, broadcast_shape)
    broadcast_shape = array_ops.broadcast_dynamic_shape(
        maxval_shape, broadcast_shape)
    # Leading output axes that are not part of the broadcast parameter shape.
    extra_dims = math_ops.range(
        array_ops.size(shape) - array_ops.size(broadcast_shape))

    def _reduce_to_param_shape(partial, param_shape):
      """Sums `grad * partial` down to a parameter of shape `param_shape`."""
      # First collapse the extra leading axes introduced by `shape`.
      summed = math_ops.reduce_sum(grad * partial, axis=extra_dims)
      # Then collapse the axes along which this parameter was broadcast.
      _, reduce_axes = gen_array_ops.broadcast_gradient_args(
          broadcast_shape, param_shape)
      return array_ops.reshape(
          math_ops.reduce_sum(summed, axis=reduce_axes, keepdims=True),
          param_shape)

    grad_mean = _reduce_to_param_shape(dmean, mean_shape)
    grad_stddev = _reduce_to_param_shape(dstddev, stddev_shape)
    grad_minval = _reduce_to_param_shape(dminval, minval_shape)
    grad_maxval = _reduce_to_param_shape(dmaxval, maxval_shape)

    # The first two inputs are `shape` and `seed`; neither has a gradient.
    return (None, None, grad_mean, grad_stddev, grad_minval, grad_maxval)
|
||||
|
@ -618,3 +618,73 @@ def stateless_multinomial_categorical_impl(logits, num_samples, dtype, seed):
|
||||
logits = ops.convert_to_tensor(logits, name="logits")
|
||||
return gen_stateless_random_ops.stateless_multinomial(
|
||||
logits, num_samples, seed, output_dtype=dtype)
|
||||
|
||||
|
||||
@dispatch.add_dispatch_support
@tf_export("random.stateless_parameterized_truncated_normal")
def stateless_parameterized_truncated_normal(shape,
                                             seed,
                                             means=0.0,
                                             stddevs=1.0,
                                             minvals=-2.0,
                                             maxvals=2.0,
                                             name=None):
  """Outputs random values from a truncated normal distribution.

  The generated values follow a normal distribution with the specified means
  and standard deviations, truncated to the interval between `minvals` and
  `maxvals`.

  The output is a deterministic function of `shape`, `seed`, `means`,
  `stddevs`, `minvals` and `maxvals`.


  Examples:

  Sample from a Truncated normal, with differing shape parameters that
  broadcast.

  >>> means = 0.
  >>> stddevs = tf.math.exp(tf.random.uniform(shape=[2, 3]))
  >>> minvals = [-1., -2., -1000.]
  >>> maxvals = [[10000.], [1.]]
  >>> y = tf.random.stateless_parameterized_truncated_normal(
  ...   shape=[10, 2, 3], seed=[7, 17],
  ...   means=means, stddevs=stddevs, minvals=minvals, maxvals=maxvals)
  >>> y.shape
  TensorShape([10, 2, 3])

  Args:
    shape: A 1-D integer `Tensor` or Python array. The shape of the output
      tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    means: A `Tensor` or Python value. The mean of the truncated normal
      distribution. This must broadcast with `stddevs`, `minvals` and
      `maxvals`, and the broadcasted shape must be dominated by `shape`.
    stddevs: A `Tensor` or Python value. The standard deviation of the
      truncated normal distribution. Must be greater than 0. This must
      broadcast with `means`, `minvals` and `maxvals`, and the broadcasted
      shape must be dominated by `shape`.
    minvals: A `Tensor` or Python value. The minimum value of the truncated
      normal distribution; may be -infinity. This must broadcast with `means`,
      `stddevs` and `maxvals`, and the broadcasted shape must be dominated by
      `shape`.
    maxvals: A `Tensor` or Python value. The maximum value of the truncated
      normal distribution; may be +infinity and must be greater than
      `minvals`. This must broadcast with `means`, `stddevs` and `minvals`,
      and the broadcasted shape must be dominated by `shape`.
    name: A name for the operation (optional).

  Returns:
    A tensor of the specified shape filled with random truncated normal values.
  """
  with ops.name_scope(name, "stateless_parameterized_truncated_normal",
                      [shape, seed, means, stddevs, minvals, maxvals]) as name:
    shape_tensor = tensor_util.shape_tensor(shape)
    means_tensor = ops.convert_to_tensor(means, name="means")
    stddevs_tensor = ops.convert_to_tensor(stddevs, name="stddevs")
    minvals_tensor = ops.convert_to_tensor(minvals, name="minvals")
    maxvals_tensor = ops.convert_to_tensor(maxvals, name="maxvals")
    rnd = gen_stateless_random_ops.stateless_parameterized_truncated_normal(
        shape_tensor, seed, means_tensor, stddevs_tensor, minvals_tensor,
        maxvals_tensor)
    # Propagate the statically known part of `shape` onto the result.
    tensor_util.maybe_set_static_shape(rnd, shape)
    return rnd
|
||||
|
@ -92,6 +92,10 @@ tf_module {
|
||||
name: "stateless_normal"
|
||||
argspec: "args=[\'shape\', \'seed\', \'mean\', \'stddev\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \"<dtype: \'float32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "stateless_parameterized_truncated_normal"
|
||||
argspec: "args=[\'shape\', \'seed\', \'means\', \'stddevs\', \'minvals\', \'maxvals\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'-2.0\', \'2.0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "stateless_poisson"
|
||||
argspec: "args=[\'shape\', \'seed\', \'lam\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
|
||||
|
@ -4492,6 +4492,10 @@ tf_module {
|
||||
name: "StatelessMultinomial"
|
||||
argspec: "args=[\'logits\', \'num_samples\', \'seed\', \'output_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatelessParameterizedTruncatedNormal"
|
||||
argspec: "args=[\'shape\', \'seed\', \'means\', \'stddevs\', \'minvals\', \'maxvals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatelessRandomBinomial"
|
||||
argspec: "args=[\'shape\', \'seed\', \'counts\', \'probs\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
|
||||
|
@ -80,6 +80,10 @@ tf_module {
|
||||
name: "stateless_normal"
|
||||
argspec: "args=[\'shape\', \'seed\', \'mean\', \'stddev\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \"<dtype: \'float32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "stateless_parameterized_truncated_normal"
|
||||
argspec: "args=[\'shape\', \'seed\', \'means\', \'stddevs\', \'minvals\', \'maxvals\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'-2.0\', \'2.0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "stateless_poisson"
|
||||
argspec: "args=[\'shape\', \'seed\', \'lam\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
|
||||
|
@ -4492,6 +4492,10 @@ tf_module {
|
||||
name: "StatelessMultinomial"
|
||||
argspec: "args=[\'logits\', \'num_samples\', \'seed\', \'output_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatelessParameterizedTruncatedNormal"
|
||||
argspec: "args=[\'shape\', \'seed\', \'means\', \'stddevs\', \'minvals\', \'maxvals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatelessRandomBinomial"
|
||||
argspec: "args=[\'shape\', \'seed\', \'counts\', \'probs\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user