From 455750f3623b15a3b5d46c11d4c5102e9388dbda Mon Sep 17 00:00:00 2001 From: Srinivas Vasudevan Date: Thu, 18 Jun 2020 10:23:23 -0700 Subject: [PATCH] Add StatelessParameterizedTruncatedNormal sampler. This sampler supports broadcasting of its input parameters as well as puts the # samples at the left of the output shape, rather than the right. PiperOrigin-RevId: 317129622 Change-Id: I4b62ad2e89a9637ae8b30b73af4b662ad0caa943 --- ...tatelessParameterizedTruncatedNormal.pbtxt | 54 +++ tensorflow/core/kernels/BUILD | 1 + .../parameterized_truncated_normal_op.cc | 435 +++++++++++++++++- .../parameterized_truncated_normal_op.h | 16 + tensorflow/core/ops/stateless_random_ops.cc | 35 ++ .../eager/pywrap_gradient_exclusions.cc | 3 +- tensorflow/python/kernel_tests/BUILD | 18 - tensorflow/python/kernel_tests/random/BUILD | 18 + .../parameterized_truncated_normal_op_test.py | 198 ++++++-- tensorflow/python/ops/random_grad.py | 120 +++++ tensorflow/python/ops/stateless_random_ops.py | 70 +++ .../api/golden/v1/tensorflow.random.pbtxt | 4 + .../api/golden/v1/tensorflow.raw_ops.pbtxt | 4 + .../api/golden/v2/tensorflow.random.pbtxt | 4 + .../api/golden/v2/tensorflow.raw_ops.pbtxt | 4 + 15 files changed, 913 insertions(+), 71 deletions(-) create mode 100644 tensorflow/core/api_def/base_api/api_def_StatelessParameterizedTruncatedNormal.pbtxt rename tensorflow/python/kernel_tests/{ => random}/parameterized_truncated_normal_op_test.py (63%) diff --git a/tensorflow/core/api_def/base_api/api_def_StatelessParameterizedTruncatedNormal.pbtxt b/tensorflow/core/api_def/base_api/api_def_StatelessParameterizedTruncatedNormal.pbtxt new file mode 100644 index 00000000000..15bd4670cef --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_StatelessParameterizedTruncatedNormal.pbtxt @@ -0,0 +1,54 @@ +op { + graph_op_name: "StatelessParameterizedTruncatedNormal" + visibility: HIDDEN + in_arg { + name: "shape" + description: < { const T kStdDevsInsideBoundsToUseRandnSampler = T(1.3); auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads()); - auto DoWork = [samples_per_batch, num_elements, &ctx, &means, &stddevs, - &minvals, &maxvals, &gen, &output, - kStdDevsInsideBoundsToUseRandnSampler](int start_batch, - int limit_batch) { + auto do_work = [samples_per_batch, num_elements, &ctx, &means, &stddevs, + &minvals, &maxvals, &gen, &output, + kStdDevsInsideBoundsToUseRandnSampler](int start_batch, + int limit_batch) { // Capturing "gen" by-value would only make a copy for the _shared_ // lambda. Since we want to let each worker have its own copy, we pass // "gen" by reference and explicitly do a copy assignment here. @@ -80,9 +81,9 @@ struct TruncatedNormalFunctor { // The sample from each iteration uses 2 random numbers. gen_copy.Skip(start_batch * 2 * kMaxIterations * (samples_per_batch + 3) / 4); - typedef random::UniformDistribution Uniform; + using Uniform = random::UniformDistribution; Uniform dist; - typedef random::NormalDistribution Normal; + using Normal = random::NormalDistribution; Normal normal_dist; // Vectorized intermediate calculations for uniform rejection sampling. @@ -112,7 +113,7 @@ struct TruncatedNormalFunctor { Eigen::numext::isfinite(maxval)), errors::InvalidArgument("Invalid parameters")); - int numIterations = 0; + int num_iterations = 0; // If possible, make one-sided bound be the lower bound, or make both // bounds positive. 
Otherwise, the bounds are on either side of the @@ -160,10 +161,10 @@ struct TruncatedNormalFunctor { if (sample >= limit_sample) { break; } - numIterations = 0; + num_iterations = 0; } else { - numIterations++; - if (numIterations > kMaxIterations) { + num_iterations++; + if (num_iterations > kMaxIterations) { // This should never occur because this sampler should // (by the selection criteria above) be used if at least 3 // standard deviations of one side of the distribution @@ -201,7 +202,7 @@ struct TruncatedNormalFunctor { const auto u = dist(&gen_copy); for (int i = 0; i < size; i++) { auto accept = u[i] <= Eigen::numext::exp(g[i]); - if (accept || numIterations + 1 >= kMaxIterations) { + if (accept || num_iterations + 1 >= kMaxIterations) { // Accept the sample z. // If we run out of iterations, just use the current uniform // sample, but emit a warning. @@ -223,9 +224,9 @@ struct TruncatedNormalFunctor { if (sample >= limit_sample) { break; } - numIterations = 0; + num_iterations = 0; } else { - numIterations++; + num_iterations++; } } } @@ -248,7 +249,7 @@ struct TruncatedNormalFunctor { const T u = rand[i]; i++; auto accept = (u <= g && z < normMax); - if (accept || numIterations + 1 >= kMaxIterations) { + if (accept || num_iterations + 1 >= kMaxIterations) { if (!accept) { LOG(ERROR) << "TruncatedNormal exponential distribution " << "rejection sampler exceeds max iterations. " @@ -263,9 +264,9 @@ struct TruncatedNormalFunctor { if (sample >= limit_sample) { break; } - numIterations = 0; + num_iterations = 0; } else { - numIterations++; + num_iterations++; } } } @@ -305,7 +306,297 @@ struct TruncatedNormalFunctor { const int64 batchCost = batchInitCost + uniformRejectionSamplingCost * 2 * samples_per_batch; Shard(worker_threads.num_threads, worker_threads.workers, num_batches, - batchCost, DoWork); + batchCost, do_work); + } +}; + +template +struct TruncatedNormalFunctorV2 { + void operator()(OpKernelContext* ctx, const CPUDevice& d, int64 num_batches, + int64 samples_per_batch, int64 num_elements, + const BCastList<4>& bcast, + typename TTypes::ConstFlat means, + typename TTypes::ConstFlat stddevs, + typename TTypes::ConstFlat minvals, + typename TTypes::ConstFlat maxvals, + const random::PhiloxRandom& gen, + typename TTypes::Flat output) { + // The randn rejection sampling is used when the mean and at least this many + // standard deviations are inside the bounds. + // The uniform proposal samplers become less efficient as the bounds are + // further from the mean, the reverse is true for the randn sampler. + // This number was chosen by empirical benchmarking. If modified, the + // benchmarks in parameterized_truncated_normal_op_test should also be + // changed. + const T kStdDevsInsideBoundsToUseRandnSampler = T(1.3); + auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads()); + + auto do_work = [num_batches, samples_per_batch, &ctx, &bcast, &means, + &stddevs, &minvals, &maxvals, &gen, &output, + kStdDevsInsideBoundsToUseRandnSampler](int start_output, + int limit_output) { + // Capturing "gen" by-value would only make a copy for the _shared_ + // lambda. Since we want to let each worker have its own copy, we pass + // "gen" by reference and explicitly do a copy assignment here. + random::PhiloxRandom gen_copy = gen; + using Uniform = random::UniformDistribution; + Uniform dist; + using Normal = random::NormalDistribution; + Normal normal_dist; + // Skip takes units of 128 bits. 
The Uniform::kResultElementCount - 1 + // is so rounding doesn't lead to + // us using the same state in different workloads. + // The sample from each iteration uses 2 random numbers. + gen_copy.Skip((start_output * 2 * kMaxIterations + + Uniform::kResultElementCount - 1) / + Uniform::kResultElementCount); + + // Vectorized intermediate calculations for uniform rejection sampling. + // We always generate at most 4 samples. + Eigen::array<T, 4> z; + Eigen::array<T, 4> g; + + const bool should_bcast = bcast.IsBroadcastingRequired(); + const auto& means_batch_indices = bcast.batch_indices(0); + const auto& stddevs_batch_indices = bcast.batch_indices(1); + const auto& minvals_batch_indices = bcast.batch_indices(2); + const auto& maxvals_batch_indices = bcast.batch_indices(3); + auto output_flat = output.data(); + + // We partition work across batches and then across samples + // per batch member, to avoid extra work. + for (int64 output_idx = start_output; output_idx < limit_output; + // output_idx is incremented with the inner loops below. + ) { + int64 batch_idx = output_idx / samples_per_batch; + // The output layout is [samples_per_batch, num_batches]. Thus + // the output address is sample_idx * num_batches + batch_idx. + // Below, code will index at output_batch_offset[sample_idx * + // num_batches] matching this. + T* const output_batch_offset = output_flat + batch_idx; + // Generate batch counts from BCast, as it has the right indices to loop + // over. + T mean, stddev, minval, maxval; + if (should_bcast) { + mean = means(means_batch_indices[batch_idx]); + stddev = stddevs(stddevs_batch_indices[batch_idx]); + minval = minvals(minvals_batch_indices[batch_idx]); + maxval = maxvals(maxvals_batch_indices[batch_idx]); + } else { + mean = means(batch_idx); + stddev = stddevs(batch_idx); + minval = minvals(batch_idx); + maxval = maxvals(batch_idx); + } + + // On GPU, this check will just fill samples with NAN if it fails. + OP_REQUIRES(ctx, + stddev > T(0) && minval < maxval && + (Eigen::numext::isfinite(minval) || + Eigen::numext::isfinite(maxval)), + errors::InvalidArgument("Invalid parameters")); + + int num_iterations = 0; + + // If possible, make one-sided bound be the lower bound, or make both + // bounds positive. Otherwise, the bounds are on either side of the + // mean. + if ((Eigen::numext::isinf(minval) && minval < T(0)) || maxval < mean) { + // Reverse all calculations. normMin and normMax will be flipped. + std::swap(minval, maxval); + stddev = -stddev; + } + + // Calculate normalized samples, then convert them. + const T normMin = (minval - mean) / stddev; + const T normMax = (maxval - mean) / stddev; + + // Determine the method to use. + const T sqrtFactor = Eigen::numext::sqrt((normMin * normMin) + T(4)); + const T cutoff = + T(2) * + Eigen::numext::exp(T(0.5) + + (normMin * (normMin - sqrtFactor)) / T(4)) / + (normMin + sqrtFactor); + const T diff = normMax - normMin; + + if (((normMin < -kStdDevsInsideBoundsToUseRandnSampler) && + (normMax >= T(0.))) || + ((normMax > kStdDevsInsideBoundsToUseRandnSampler) && + (normMin <= T(0.)))) { + // If the bounds are at least 3 standard deviations from the mean + // on at least one side, then we rejection sample by sampling + // from the normal distribution and rejecting samples outside + // the bounds. + // Under this condition the acceptance rate per iteration should + // always be ~ 50%. This sampler is more efficient (and more + // numerically stable when one or both bounds are far from the mean).
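+ // (The two branches that follow handle the remaining cases: a uniform + // proposal on [normMin, normMax] when the truncated region is narrow + // (diff < cutoff), and an exponential proposal offset to normMin for a + // far one-sided tail.)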
+ for (int64 sample_idx = output_idx % samples_per_batch; + sample_idx < samples_per_batch && output_idx < limit_output;) { + const auto randn_sample = normal_dist(&gen_copy); + const int size = randn_sample.size(); + for (int i = 0; i < size; ++i) { + if ((randn_sample[i] >= normMin) && + (randn_sample[i] <= normMax)) { + output_batch_offset[sample_idx * num_batches] = + randn_sample[i] * stddev + mean; + ++sample_idx; + ++output_idx; + if (sample_idx >= samples_per_batch || + output_idx >= limit_output) { + break; + } + num_iterations = 0; + } else { + ++num_iterations; + if (num_iterations > kMaxIterations) { + // This should never occur because this sampler should + // (by the selection criteria above) be used if at least 3 + // standard deviations of one side of the distribution + // are within the limits (so the acceptance probability per + // iteration is >~ 1/2). + LOG(ERROR) << "TruncatedNormal randn rejection sampler " + << "exceeded maximum iterations for " + << "normMin=" << normMin << " normMax=" << normMax + << " kMaxIterations=" << kMaxIterations; + ctx->SetStatus(errors::Internal( + "TruncatedNormal randn rejection sampler failed to accept" + " a sample.")); + return; + } + } + } + } + } else if (diff < cutoff) { + // Sample from a uniform distribution on [normMin, normMax]. + + const T plusFactor = (normMin < T(0)) ? T(0) : normMin * normMin; + + for (int64 sample_idx = output_idx % samples_per_batch; + sample_idx < samples_per_batch && output_idx < limit_output;) { + const auto rand = dist(&gen_copy); + const int size = rand.size(); + // NOTE(ringwalt): These loops seem to only generate packed AVX + // instructions for float32. + for (int i = 0; i < size; i++) { + z[i] = rand[i] * diff + normMin; + g[i] = (plusFactor - z[i] * z[i]) / T(2.0); + } + + const auto u = dist(&gen_copy); + for (int i = 0; i < size; i++) { + auto accept = u[i] <= Eigen::numext::exp(g[i]); + if (accept || num_iterations + 1 >= kMaxIterations) { + // Accept the sample z. + // If we run out of iterations, just use the current uniform + // sample, but emit a warning. + // TODO(jjhunt) For small entropies (relative to the bounds), + // this sampler is poor and may take many iterations since + // the proposal distribution is the uniform distribution + // U(lower_bound, upper_bound). + if (!accept) { + LOG(ERROR) << "TruncatedNormal uniform rejection sampler " + << "exceeded max iterations. Sample may contain " + << "outliers."; + ctx->SetStatus(errors::Internal( + "TruncatedNormal uniform rejection sampler failed to " + "accept a sample.")); + return; + } + output_batch_offset[sample_idx * num_batches] = + z[i] * stddev + mean; + ++sample_idx; + ++output_idx; + if (sample_idx >= samples_per_batch || + output_idx >= limit_output) { + break; + } + num_iterations = 0; + } else { + num_iterations++; + } + } + } + } else { + // Sample from an exponential distribution with alpha maximizing + // acceptance probability, offset by normMin from the origin. + // Accept only if less than normMax. + const T alpha = + (normMin + Eigen::numext::sqrt((normMin * normMin) + T(4))) / + T(2); + for (int64 sample_idx = output_idx % samples_per_batch; + sample_idx < samples_per_batch && output_idx < limit_output;) { + auto rand = dist(&gen_copy); + const int size = rand.size(); + int i = 0; + while (i < size) { + const T z = -Eigen::numext::log(rand[i]) / alpha + normMin; + i++; + const T x = normMin < alpha ?
alpha - z : normMin - alpha; + const T g = Eigen::numext::exp(-x * x / T(2.0)); + const T u = rand[i]; + i++; + auto accept = (u <= g && z < normMax); + if (accept || num_iterations + 1 >= kMaxIterations) { + if (!accept) { + LOG(ERROR) << "TruncatedNormal exponential distribution " + << "rejection sampler exceeds max iterations. " + << "Sample may contain outliers."; + ctx->SetStatus(errors::Internal( + "TruncatedNormal exponential distribution rejection" + " sampler failed to accept a sample.")); + return; + } + output_batch_offset[sample_idx * num_batches] = + z * stddev + mean; + ++sample_idx; + ++output_idx; + if (sample_idx >= samples_per_batch || + output_idx >= limit_output) { + break; + } + num_iterations = 0; + } else { + num_iterations++; + } + } + } + } + } + }; + // The cost of the initial calculations for the batch. + const int64 batchInitCost = + // normMin, normMax + (Eigen::TensorOpCost::AddCost() + + Eigen::TensorOpCost::MulCost()) * + 2 + // sqrtFactor + + Eigen::TensorOpCost::AddCost() + + Eigen::TensorOpCost::MulCost() + + Eigen::internal::functor_traits< + Eigen::internal::scalar_sqrt_op>::Cost + // cutoff + + Eigen::TensorOpCost::MulCost() * 4 + + Eigen::internal::functor_traits>::Cost + // diff + + Eigen::TensorOpCost::AddCost(); + const int64 uniformSampleCost = + random::PhiloxRandom::kElementCost + + random::UniformDistribution::kElementCost; + // The cost of a single uniform sampling round. + const int64 uniformRejectionSamplingCost = + uniformSampleCost + Eigen::TensorOpCost::MulCost() + + Eigen::TensorOpCost::AddCost() + + Eigen::TensorOpCost::MulCost() * 2 + + Eigen::TensorOpCost::AddCost() + uniformSampleCost + + Eigen::internal::functor_traits< + Eigen::internal::scalar_exp_op>::Cost + + Eigen::TensorOpCost::MulCost() + Eigen::TensorOpCost::AddCost(); + // Estimate the cost for an entire batch. + // Assume we use uniform sampling, and accept the 2nd sample on average. + const int64 batchCost = batchInitCost + uniformRejectionSamplingCost * 2; + Shard(worker_threads.num_threads, worker_threads.workers, num_elements, + batchCost, do_work); } }; @@ -436,13 +727,113 @@ class ParameterizedTruncatedNormalOp : public OpKernel { TF_DISALLOW_COPY_AND_ASSIGN(ParameterizedTruncatedNormalOp); }; +// Samples from a truncated normal distribution, using the given parameters. +template +class StatelessParameterizedTruncatedNormal : public OpKernel { + // Reshape batches so each batch is this size if possible. 
+ static const int32 kDesiredBatchSize = 100; + + public: + explicit StatelessParameterizedTruncatedNormal(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& shape_tensor = ctx->input(0); + const Tensor& seed_tensor = ctx->input(1); + const Tensor& means_tensor = ctx->input(2); + const Tensor& stddevs_tensor = ctx->input(3); + const Tensor& minvals_tensor = ctx->input(4); + const Tensor& maxvals_tensor = ctx->input(5); + + OP_REQUIRES(ctx, seed_tensor.dims() == 1 && seed_tensor.dim_size(0) == 2, + errors::InvalidArgument("seed must have shape [2], not ", + seed_tensor.shape().DebugString())); + + tensorflow::BCastList<4> bcast( + {means_tensor.shape().dim_sizes(), stddevs_tensor.shape().dim_sizes(), + minvals_tensor.shape().dim_sizes(), + maxvals_tensor.shape().dim_sizes()}, + /*fewer_dims_optimization=*/false, + /*return_flattened_batch_indices=*/true); + + OP_REQUIRES(ctx, bcast.IsValid(), + errors::InvalidArgument( + "means, stddevs, minvals, maxvals must have compatible " + "batch dimensions: ", + means_tensor.shape().DebugString(), " vs. ", + stddevs_tensor.shape().DebugString(), " vs. ", + minvals_tensor.shape().DebugString(), " vs. ", + maxvals_tensor.shape().DebugString())); + + // Let's check that the shape tensor dominates the broadcasted tensor. + TensorShape bcast_shape = BCast::ToShape(bcast.output_shape()); + OP_REQUIRES( + ctx, TensorShapeUtils::IsVector(shape_tensor.shape()), + errors::InvalidArgument("Input shape should be a vector, got shape: ", + shape_tensor.shape().DebugString())); + TensorShape output_shape; + if (shape_tensor.dtype() == DataType::DT_INT32) { + OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(shape_tensor.vec(), + &output_shape)); + } else { + OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(shape_tensor.vec(), + &output_shape)); + } + OP_REQUIRES(ctx, TensorShapeUtils::EndsWith(output_shape, bcast_shape), + errors::InvalidArgument( + "Shape passed in must end with broadcasted shape.")); + + int64 samples_per_batch = 1; + const int64 num_sample_dims = + (shape_tensor.dim_size(0) - bcast.output_shape().size()); + for (int64 i = 0; i < num_sample_dims; ++i) { + samples_per_batch *= output_shape.dim_size(i); + } + int64 num_batches = 1; + for (int64 i = num_sample_dims; i < shape_tensor.dim_size(0); ++i) { + num_batches *= output_shape.dim_size(i); + } + const int64 num_elements = num_batches * samples_per_batch; + + Tensor* samples_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &samples_tensor)); + + auto truncFunctor = functor::TruncatedNormalFunctorV2(); + // Each worker has the same fudge factor, so use it here. 
+ random::PhiloxRandom::Key key; + random::PhiloxRandom::ResultType counter; + OP_REQUIRES_OK(ctx, GenerateKey(seed_tensor, &key, &counter)); + + auto philox = random::PhiloxRandom(counter, key); + + truncFunctor(ctx, ctx->eigen_device(), num_batches, + samples_per_batch, num_elements, bcast, means_tensor.flat(), + stddevs_tensor.flat(), minvals_tensor.flat(), + maxvals_tensor.flat(), philox, samples_tensor->flat()); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(StatelessParameterizedTruncatedNormal); +}; + } // namespace -#define REGISTER(TYPE) \ - REGISTER_KERNEL_BUILDER(Name("ParameterizedTruncatedNormal") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("dtype"), \ - ParameterizedTruncatedNormalOp) +#define REGISTER(TYPE) \ + REGISTER_KERNEL_BUILDER(Name("ParameterizedTruncatedNormal") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("dtype"), \ + ParameterizedTruncatedNormalOp) \ + REGISTER_KERNEL_BUILDER( \ + Name("StatelessParameterizedTruncatedNormal") \ + .HostMemory("shape") \ + .HostMemory("seed") \ + .HostMemory("means") \ + .HostMemory("stddevs") \ + .HostMemory("minvals") \ + .HostMemory("maxvals") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("dtype"), \ + StatelessParameterizedTruncatedNormal) TF_CALL_half(REGISTER); TF_CALL_float(REGISTER); diff --git a/tensorflow/core/kernels/parameterized_truncated_normal_op.h b/tensorflow/core/kernels/parameterized_truncated_normal_op.h index c919a22c7b0..ee7fb7bf605 100644 --- a/tensorflow/core/kernels/parameterized_truncated_normal_op.h +++ b/tensorflow/core/kernels/parameterized_truncated_normal_op.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/lib/random/random_distributions.h" +#include "tensorflow/core/util/bcast.h" namespace tensorflow { @@ -44,6 +45,21 @@ struct TruncatedNormalFunctor { typename TTypes::Flat output); }; +// This version supports broadcasting of the arguments, as well as puts +// the sample dimension on the left. 
+template +struct TruncatedNormalFunctorV2 { + void operator()(OpKernelContext* ctx, const Device& d, int64 num_batches, + int64 samples_per_batch, int64 num_elements, + const BCastList<4>& bcast, + typename TTypes::ConstFlat means, + typename TTypes::ConstFlat stddevs, + typename TTypes::ConstFlat minvals, + typename TTypes::ConstFlat maxvals, + const random::PhiloxRandom& gen, + typename TTypes::Flat output); +}; + } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/ops/stateless_random_ops.cc b/tensorflow/core/ops/stateless_random_ops.cc index d540b9a04d9..e1820ea4feb 100644 --- a/tensorflow/core/ops/stateless_random_ops.cc +++ b/tensorflow/core/ops/stateless_random_ops.cc @@ -124,6 +124,41 @@ REGISTER_OP("StatelessRandomBinomial") .Attr("dtype: {half, float, double, int32, int64} = DT_INT64") .SetShapeFn(StatelessShape); +REGISTER_OP("StatelessParameterizedTruncatedNormal") + .Input("shape: S") + .Input("seed: Tseed") + .Input("means: dtype") + .Input("stddevs: dtype") + .Input("minvals: dtype") + .Input("maxvals: dtype") + .Output("output: dtype") + .Attr("S: {int32, int64}") + .Attr("Tseed: {int32, int64} = DT_INT64") + .Attr("dtype: {float16, float32, float64}") + .SetShapeFn([](InferenceContext* c) { + // Check seed shape + ShapeHandle seed; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seed)); + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused_dim)); + + ShapeHandle bcast_means_stddevs; + ShapeHandle bcast_except_maxvals; + ShapeHandle bcast_all; + TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper( + c, c->input(2), c->input(3), true, &bcast_means_stddevs)); + TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper( + c, c->input(4), bcast_means_stddevs, true, &bcast_except_maxvals)); + TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper( + c, c->input(5), bcast_except_maxvals, true, &bcast_all)); + + // Set output shape + ShapeHandle out; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); + c->set_output(0, out); + return Status::OK(); + }); + REGISTER_OP("StatelessRandomPoisson") .Input("shape: T") .Input("seed: Tseed") diff --git a/tensorflow/python/eager/pywrap_gradient_exclusions.cc b/tensorflow/python/eager/pywrap_gradient_exclusions.cc index a7c7ab7abc7..7da45e36118 100644 --- a/tensorflow/python/eager/pywrap_gradient_exclusions.cc +++ b/tensorflow/python/eager/pywrap_gradient_exclusions.cc @@ -50,7 +50,7 @@ auto OpGradientInfoInit(const T &a) { absl::optional> OpGradientUnusedInputIndices( const tensorflow::string &op_name) { - static std::array a = {{ + static std::array a = {{ {"Acosh"}, {"AllToAll", 1, {0}}, {"ApproximateEqual"}, @@ -326,6 +326,7 @@ absl::optional> OpGradientUnusedInputIndices( {"StackPop"}, {"StackPush"}, {"StatelessMultinomial"}, + {"StatelessParameterizedTruncatedNormal", 1, {1}}, {"StatelessRandomBinomial"}, {"StatelessRandomGammaV2", 1, {1}}, {"StatelessRandomNormal"}, diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index f2c614974f5..f93bf5cd1ae 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -785,24 +785,6 @@ cuda_py_test( ], ) -cuda_py_test( - name = "parameterized_truncated_normal_op_test", - size = "medium", - srcs = ["parameterized_truncated_normal_op_test.py"], - deps = [ - "//tensorflow/core:protos_all_py", - "//tensorflow/python:client", - "//tensorflow/python:client_testlib", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework", - 
"//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:platform", - "//tensorflow/python:random_ops", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", - ], -) - tf_py_test( name = "parsing_ops_test", size = "medium", diff --git a/tensorflow/python/kernel_tests/random/BUILD b/tensorflow/python/kernel_tests/random/BUILD index b5d291d2973..6e404b4cd5f 100644 --- a/tensorflow/python/kernel_tests/random/BUILD +++ b/tensorflow/python/kernel_tests/random/BUILD @@ -20,6 +20,24 @@ py_library( ], ) +cuda_py_test( + name = "parameterized_truncated_normal_op_test", + size = "medium", + srcs = ["parameterized_truncated_normal_op_test.py"], + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform", + "//tensorflow/python:random_ops", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + tf_py_test( name = "random_shuffle_queue_test", size = "small", diff --git a/tensorflow/python/kernel_tests/parameterized_truncated_normal_op_test.py b/tensorflow/python/kernel_tests/random/parameterized_truncated_normal_op_test.py similarity index 63% rename from tensorflow/python/kernel_tests/parameterized_truncated_normal_op_test.py rename to tensorflow/python/kernel_tests/random/parameterized_truncated_normal_op_test.py index ac8ad7a2bd4..309c3e404db 100644 --- a/tensorflow/python/kernel_tests/parameterized_truncated_normal_op_test.py +++ b/tensorflow/python/kernel_tests/random/parameterized_truncated_normal_op_test.py @@ -27,11 +27,15 @@ from six.moves import range # pylint: disable=redefined-builtin from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session +from tensorflow.python.eager import backprop from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import stateless_random_ops as stateless +from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging @@ -91,13 +95,8 @@ class TruncatedNormalMoments(object): def calculate_moments(samples, max_moment): moments = [0.0] * (max_moment + 1) - for sample in samples: - value = 1.0 - for k in range(len(moments)): - moments[k] += value - value *= sample - for i in range(len(moments)): - moments[i] /= len(samples) + for k in range(len(moments)): + moments[k] = np.mean(samples**k, axis=0) return moments @@ -118,16 +117,31 @@ class ParameterizedTruncatedNormalTest(test.TestCase): # Stop at moment 10 to avoid numerical errors in the theoretical moments. max_moment = 10 - def validateMoments(self, shape, mean, stddev, minval, maxval, seed=1618): + def validateMoments(self, + shape, + mean, + stddev, + minval, + maxval, + use_stateless=False, + seed=1618): try: # TruncatedNormalMoments requires scipy.stats. # Give up early if we are unable to import it. 
- import scipy.stats # pylint: disable=g-import-not-at-top,unused-variable random_seed.set_random_seed(seed) with self.cached_session(use_gpu=True): - samples = random_ops.parameterized_truncated_normal(shape, mean, stddev, - minval, - maxval).eval() + if use_stateless: + # Generate a seed that stateless ops can use. + new_seed = random_ops.random_uniform([2], + seed=seed, + minval=0, + maxval=(2**31 - 1), + dtype=np.int32) + samples = stateless.stateless_parameterized_truncated_normal( + shape, new_seed, mean, stddev, minval, maxval).eval() + else: + samples = random_ops.parameterized_truncated_normal( + shape, mean, stddev, minval, maxval).eval() assert (~np.isnan(samples)).all() moments = calculate_moments(samples, self.max_moment) expected_moments = TruncatedNormalMoments(mean, stddev, minval, maxval) @@ -144,14 +158,24 @@ class ParameterizedTruncatedNormalTest(test.TestCase): stddev, minval, maxval, + use_stateless=False, seed=1618): try: import scipy.stats # pylint: disable=g-import-not-at-top random_seed.set_random_seed(seed) with self.cached_session(use_gpu=True): - samples = random_ops.parameterized_truncated_normal(shape, mean, stddev, - minval, - maxval).eval() + if use_stateless: + new_seed = random_ops.random_uniform([2], + seed=seed, + minval=0, + maxval=(2**31 - 1), + dtype=np.int32) + samples = stateless.stateless_parameterized_truncated_normal( + shape, new_seed, mean, stddev, minval, maxval).eval() + else: + samples = random_ops.parameterized_truncated_normal( + shape, mean, stddev, minval, maxval).eval() + assert (~np.isnan(samples)).all() minval = max(mean - stddev * 10, minval) maxval = min(mean + stddev * 10, maxval) @@ -169,61 +193,160 @@ class ParameterizedTruncatedNormalTest(test.TestCase): @test_util.run_deprecated_v1 def testDefaults(self): - self.validateMoments([10**5], 0.0, 1.0, -2.0, 2.0) + self.validateMoments([int(1e5)], 0.0, 1.0, -2.0, 2.0) + self.validateMoments([int(1e5)], 0.0, 1.0, -2.0, 2.0, use_stateless=True) @test_util.run_deprecated_v1 def testShifted(self): - self.validateMoments([10**5], -1.0, 1.0, -2.0, 2.0) + self.validateMoments([int(1e5)], -1.0, 1.0, -2.0, 2.0) + self.validateMoments([int(1e5)], -1.0, 1.0, -2.0, 2.0, use_stateless=True) @test_util.run_deprecated_v1 def testRightTail(self): - self.validateMoments([10**5], 0.0, 1.0, 4.0, np.infty) + self.validateMoments([int(1e5)], 0.0, 1.0, 4.0, np.infty) + self.validateMoments([int(1e5)], + 0.0, + 1.0, + 4.0, + np.infty, + use_stateless=True) @test_util.run_deprecated_v1 def testLeftTail(self): - self.validateMoments([10**5], 0.0, 1.0, -np.infty, -4.0) + self.validateMoments([int(1e5)], 0.0, 1.0, -np.infty, -4.0) + self.validateMoments([int(1e5)], + 0.0, + 1.0, + -np.infty, + -4.0, + use_stateless=True) @test_util.run_deprecated_v1 def testLeftTailTwoSidedBounds(self): - self.validateMoments([10**5], 0.0, 1.0, -6.0, -3.0) + self.validateMoments([int(1e5)], 0.0, 1.0, -6.0, -3.0) + self.validateMoments([int(1e5)], 0.0, 1.0, -6.0, -3.0, use_stateless=True) @test_util.run_deprecated_v1 @test_util.disable_xla("Low probability region") def testTwoSidedLeftTailShifted(self): - self.validateKolmogorovSmirnov([10**5], 6.0, 1.0, -1.0, 1.0) + self.validateKolmogorovSmirnov([int(1e5)], 6.0, 1.0, -1.0, 1.0) + self.validateKolmogorovSmirnov([int(1e5)], + 6.0, + 1.0, + -1.0, + 1.0, + use_stateless=True) @test_util.run_deprecated_v1 @test_util.disable_xla("Low probability region") def testRightTailShifted(self): - self.validateMoments([10**5], -5.0, 1.0, 2.0, np.infty) + self.validateMoments([int(1e5)], 
-5.0, 1.0, 2.0, np.infty) + self.validateMoments([int(1e5)], + -5.0, + 1.0, + 2.0, + np.infty, + use_stateless=True) # Take the normal distribution around the mean, but truncating the left tail # far from the mean. @test_util.run_deprecated_v1 def testTruncateOnLeft_entireTailOnRight(self): - self.validateKolmogorovSmirnov([10**5], 10.0, 1.0, 4.0, np.infty) + self.validateKolmogorovSmirnov([int(1e5)], 10.0, 1.0, 4.0, np.infty) + self.validateKolmogorovSmirnov([int(1e5)], + 10.0, + 1.0, + 4.0, + np.infty, + use_stateless=True) # Take the normal distribution around the mean, but truncating the right tail. @test_util.run_deprecated_v1 def testTruncateOnRight_entireTailOnLeft(self): - self.validateKolmogorovSmirnov([10**5], -8, 1.0, -np.infty, -4.0) + self.validateKolmogorovSmirnov([int(1e5)], -8, 1.0, -np.infty, -4.0) + self.validateKolmogorovSmirnov([int(1e5)], + -8., + 1.0, + -np.infty, + -4.0, + use_stateless=True) @test_util.run_deprecated_v1 def testSmallStddev(self): - self.validateKolmogorovSmirnov([10**5], 0.0, 0.1, 0.05, 0.10) + self.validateKolmogorovSmirnov([int(1e5)], 0.0, 0.1, 0.05, 0.10) + self.validateKolmogorovSmirnov([int(1e5)], + 0.0, + 0.1, + 0.05, + 0.10, + use_stateless=True) @test_util.run_deprecated_v1 def testSamplingWithSmallStdDevFarFromBound(self): sample_op = random_ops.parameterized_truncated_normal( shape=(int(1e5),), means=0.8, stddevs=0.05, minvals=-1., maxvals=1.) + new_seed = random_ops.random_uniform([2], + seed=1234, + minval=0, + maxval=(2**31 - 1), + dtype=np.int32) + sample_op_stateless = stateless.stateless_parameterized_truncated_normal( + shape=(int(1e5),), + seed=new_seed, + means=0.8, + stddevs=0.05, + minvals=-1., + maxvals=1.) with self.session(use_gpu=True) as sess: - samples = sess.run(sample_op) + samples, samples_stateless = sess.run([sample_op, sample_op_stateless]) # 0. is more than 16 standard deviations from the mean, and # should have a likelihood < 1e-57. assert (~np.isnan(samples)).all() - no_neg_samples = np.sum(samples < 0.) - self.assertEqual(no_neg_samples, 0.) + assert (~np.isnan(samples_stateless)).all() + self.assertAllGreater(samples, 0.) + self.assertAllGreater(samples_stateless, 0.) + + def testStatelessParameterizedTruncatedNormalHasGrads(self): + mean = variables.Variable(0.01) + stddev = variables.Variable(1.) + minval = variables.Variable(-1.) + maxval = variables.Variable(1.) + + with self.cached_session(use_gpu=True) as sess: + with backprop.GradientTape(persistent=True) as tape: + samples = stateless.stateless_parameterized_truncated_normal( + [1], [1, 2], mean, stddev, minval, maxval) + + sess.run(variables.variables_initializer([mean, stddev, minval, maxval])) + [mean_grad, std_grad], mean_actual_grad, std_actual_grad = sess.run([ + tape.gradient(samples, [mean, stddev]), + array_ops.ones_like(mean), + (samples - mean) / stddev]) + self.assertAllClose(mean_grad, mean_actual_grad) + self.assertAllClose(std_grad, std_actual_grad[0]) + + try: + import scipy.stats # pylint:disable=g-import-not-at-top + truncnorm = scipy.stats.truncnorm(a=-1., b=1., loc=0., scale=1.) + samples_np, [minval_grad, maxval_grad] = sess.run([ + samples, tape.gradient(samples, [minval, maxval])]) + + sample_cdf = truncnorm.cdf(samples_np) + # These come from the implicit reparameterization trick. + scipy_maxval_grad = np.exp( + 0.5 * (samples_np ** 2 - ((1. - 0.01) / 1.) ** 2) + + np.log(sample_cdf)) + + scipy_minval_grad = np.exp( + 0.5 * (samples_np ** 2 - ((-1. - 0.01) / 1.) 
** 2) + np.log1p(-sample_cdf)) + + self.assertAllClose(minval_grad, scipy_minval_grad[0], rtol=1e-2) + self.assertAllClose(maxval_grad, scipy_maxval_grad[0], rtol=1e-2) + + except ImportError as e: + tf_logging.warn("Cannot test truncated normal op: %s" % str(e)) @@ -239,18 +362,33 @@ class ParameterizedTruncatedNormalTest(test.TestCase): epsilon = 0.001 self.validateMoments( - shape=[10**6], + shape=[int(1e6)], mean=0., stddev=1.0, minval=-epsilon, maxval=stddev_inside_bounds_before_using_randn - epsilon) self.validateMoments( - shape=[10**6], + shape=[int(1e6)], mean=0., stddev=1.0, minval=-epsilon, maxval=stddev_inside_bounds_before_using_randn + epsilon) + self.validateMoments( + shape=[int(1e6)], + mean=0., + stddev=1.0, + minval=-epsilon, + maxval=stddev_inside_bounds_before_using_randn - epsilon, + use_stateless=True) + self.validateMoments( + shape=[int(1e6)], + mean=0., + stddev=1.0, + minval=-epsilon, + maxval=stddev_inside_bounds_before_using_randn + epsilon, + use_stateless=True) + # Benchmarking code def parameterized_vs_naive(shape, num_iters, use_gpu=False): diff --git a/tensorflow/python/ops/random_grad.py b/tensorflow/python/ops/random_grad.py index 771980932cb..3caa08d96f9 100644 --- a/tensorflow/python/ops/random_grad.py +++ b/tensorflow/python/ops/random_grad.py @@ -18,9 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_random_ops from tensorflow.python.ops import math_ops @@ -114,3 +119,118 @@ def _StatelessRandomGammaV2Grad(op, grad): # pylint: disable=invalid-name return (None, None, math_ops.reduce_sum( grad * partial_a, axis=math_ops.range(num_sample_dimensions))) + + +def _Ndtr(x): + """Normal distribution function.""" + half_sqrt_2 = constant_op.constant( + 0.5 * np.sqrt(2.), dtype=x.dtype, name="half_sqrt_2") + w = x * half_sqrt_2 + z = math_ops.abs(w) + y = array_ops.where( + z < half_sqrt_2, + 1. + math_ops.erf(w), + array_ops.where( + w > 0., 2. - math_ops.erfc(z), math_ops.erfc(z))) + return 0.5 * y + + +@ops.RegisterGradient("StatelessParameterizedTruncatedNormal") +def _StatelessParameterizedTruncatedNormalGrad(op, grad): # pylint: disable=invalid-name + """Returns the gradient of a TruncatedNormal sample w.r.t. parameters. + + The gradient is computed using implicit differentiation + (Figurnov et al., 2018). + + Args: + op: A `StatelessParameterizedTruncatedNormal` operation. We assume that the + inputs to the operation are `shape`, `seed`, `mean`, `stddev`, `minval`, + and `maxval` tensors, and the output is the `sample` tensor. + grad: The incoming gradient `dloss / dsample` of the same shape as + `op.outputs[0]`. + + Returns: + A list of `Tensor` with derivatives with respect to each parameter.
+ + References: + Implicit Reparameterization Gradients: + [Figurnov et al., 2018] + (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients) + ([pdf] + (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients.pdf)) + """ + shape = op.inputs[0] + mean = op.inputs[2] + stddev = op.inputs[3] + minval = op.inputs[4] + maxval = op.inputs[5] + sample = op.outputs[0] + + with ops.control_dependencies([grad]): + minval_std = (minval - mean) / stddev + maxval_std = (maxval - mean) / stddev + sample_std = (sample - mean) / stddev + + cdf_sample = (_Ndtr(sample_std) - _Ndtr(minval_std)) / ( + _Ndtr(maxval_std) - _Ndtr(minval_std)) + + # Clip to avoid zero argument for log_cdf expression + tiny = np.finfo(mean.dtype.as_numpy_dtype).tiny + eps = np.finfo(mean.dtype.as_numpy_dtype).eps + cdf_sample = clip_ops.clip_by_value(cdf_sample, tiny, 1 - eps) + + dmaxval = math_ops.exp(0.5 * (sample_std ** 2 - maxval_std ** 2) + + math_ops.log(cdf_sample)) + dminval = math_ops.exp(0.5 * (sample_std ** 2 - minval_std ** 2) + + math_ops.log1p(-cdf_sample)) + dmean = array_ops.ones_like(sample_std) + dstddev = sample_std + + # Reduce over extra dimensions caused by `shape`. We need to get the + # difference in rank from shape vs. the broadcasted rank. + + mean_shape = array_ops.shape(mean) + stddev_shape = array_ops.shape(stddev) + minval_shape = array_ops.shape(minval) + maxval_shape = array_ops.shape(maxval) + + broadcast_shape = array_ops.broadcast_dynamic_shape( + mean_shape, stddev_shape) + broadcast_shape = array_ops.broadcast_dynamic_shape( + minval_shape, broadcast_shape) + broadcast_shape = array_ops.broadcast_dynamic_shape( + maxval_shape, broadcast_shape) + extra_dims = math_ops.range( + array_ops.size(shape) - array_ops.size(broadcast_shape)) + + grad_mean = math_ops.reduce_sum(grad * dmean, axis=extra_dims) + grad_stddev = math_ops.reduce_sum(grad * dstddev, axis=extra_dims) + grad_minval = math_ops.reduce_sum(grad * dminval, axis=extra_dims) + grad_maxval = math_ops.reduce_sum(grad * dmaxval, axis=extra_dims) + + _, rmean = gen_array_ops.broadcast_gradient_args( + broadcast_shape, mean_shape) + _, rstddev = gen_array_ops.broadcast_gradient_args( + broadcast_shape, stddev_shape) + _, rminval = gen_array_ops.broadcast_gradient_args( + broadcast_shape, minval_shape) + _, rmaxval = gen_array_ops.broadcast_gradient_args( + broadcast_shape, maxval_shape) + + grad_mean = array_ops.reshape( + math_ops.reduce_sum(grad_mean, axis=rmean, keepdims=True), mean_shape) + + grad_stddev = array_ops.reshape( + math_ops.reduce_sum(grad_stddev, axis=rstddev, keepdims=True), + stddev_shape) + + grad_minval = array_ops.reshape( + math_ops.reduce_sum(grad_minval, axis=rminval, keepdims=True), + minval_shape) + + grad_maxval = array_ops.reshape( + math_ops.reduce_sum(grad_maxval, axis=rmaxval, keepdims=True), + maxval_shape) + + # The first two inputs are shape. 
+ return (None, None, grad_mean, grad_stddev, grad_minval, grad_maxval) diff --git a/tensorflow/python/ops/stateless_random_ops.py b/tensorflow/python/ops/stateless_random_ops.py index 25fefcc514c..3e825cc4775 100644 --- a/tensorflow/python/ops/stateless_random_ops.py +++ b/tensorflow/python/ops/stateless_random_ops.py @@ -618,3 +618,73 @@ def stateless_multinomial_categorical_impl(logits, num_samples, dtype, seed): logits = ops.convert_to_tensor(logits, name="logits") return gen_stateless_random_ops.stateless_multinomial( logits, num_samples, seed, output_dtype=dtype) + + +@dispatch.add_dispatch_support +@tf_export("random.stateless_parameterized_truncated_normal") +def stateless_parameterized_truncated_normal(shape, + seed, + means=0.0, + stddevs=1.0, + minvals=-2.0, + maxvals=2.0, + name=None): + """Outputs random values from a truncated normal distribution. + + The generated values follow a normal distribution with the specified mean and + standard deviation, except that values outside `[minvals, maxvals]` are + dropped and re-picked. + + + Examples: + + Sample from a truncated normal, with differing shape parameters that + broadcast. + + >>> means = 0. + >>> stddevs = tf.math.exp(tf.random.uniform(shape=[2, 3])) + >>> minvals = [-1., -2., -1000.] + >>> maxvals = [[10000.], [1.]] + >>> y = tf.random.stateless_parameterized_truncated_normal( + ... shape=[10, 2, 3], seed=[7, 17], + ... means=means, stddevs=stddevs, minvals=minvals, maxvals=maxvals) + >>> y.shape + TensorShape([10, 2, 3]) + + Args: + shape: A 1-D integer `Tensor` or Python array. The shape of the output + tensor. + seed: A shape [2] Tensor, the seed to the random number generator. Must have + dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.) + means: A `Tensor` or Python value of type `dtype`. The mean of the truncated + normal distribution. This must broadcast with `stddevs`, `minvals` and + `maxvals`, and the broadcasted shape must be dominated by `shape`. + stddevs: A `Tensor` or Python value of type `dtype`. The standard deviation + of the truncated normal distribution. This must broadcast with `means`, + `minvals` and `maxvals`, and the broadcasted shape must be dominated by + `shape`. + minvals: A `Tensor` or Python value of type `dtype`. The minimum value of + the truncated normal distribution. This must broadcast with `means`, + `stddevs` and `maxvals`, and the broadcasted shape must be dominated by + `shape`. + maxvals: A `Tensor` or Python value of type `dtype`. The maximum value of + the truncated normal distribution. This must broadcast with `means`, + `stddevs` and `minvals`, and the broadcasted shape must be dominated by + `shape`. + name: A name for the operation (optional). + + Returns: + A tensor of the specified shape filled with random truncated normal values.
+ """ + with ops.name_scope(name, "stateless_parameterized_truncated_normal", + [shape, means, stddevs, minvals, maxvals]) as name: + shape_tensor = tensor_util.shape_tensor(shape) + means_tensor = ops.convert_to_tensor(means, name="means") + stddevs_tensor = ops.convert_to_tensor(stddevs, name="stddevs") + minvals_tensor = ops.convert_to_tensor(minvals, name="minvals") + maxvals_tensor = ops.convert_to_tensor(maxvals, name="maxvals") + rnd = gen_stateless_random_ops.stateless_parameterized_truncated_normal( + shape_tensor, seed, means_tensor, stddevs_tensor, minvals_tensor, + maxvals_tensor) + tensor_util.maybe_set_static_shape(rnd, shape) + return rnd diff --git a/tensorflow/tools/api/golden/v1/tensorflow.random.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.random.pbtxt index 9c6fa7154a3..f5963f1324c 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.random.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.random.pbtxt @@ -92,6 +92,10 @@ tf_module { name: "stateless_normal" argspec: "args=[\'shape\', \'seed\', \'mean\', \'stddev\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \"\", \'None\'], " } + member_method { + name: "stateless_parameterized_truncated_normal" + argspec: "args=[\'shape\', \'seed\', \'means\', \'stddevs\', \'minvals\', \'maxvals\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'-2.0\', \'2.0\', \'None\'], " + } member_method { name: "stateless_poisson" argspec: "args=[\'shape\', \'seed\', \'lam\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 62969b5a0dd..8e5303cbea4 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -4492,6 +4492,10 @@ tf_module { name: "StatelessMultinomial" argspec: "args=[\'logits\', \'num_samples\', \'seed\', \'output_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " } + member_method { + name: "StatelessParameterizedTruncatedNormal" + argspec: "args=[\'shape\', \'seed\', \'means\', \'stddevs\', \'minvals\', \'maxvals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "StatelessRandomBinomial" argspec: "args=[\'shape\', \'seed\', \'counts\', \'probs\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.random.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.random.pbtxt index e3a11ee4610..d1b8c90bfae 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.random.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.random.pbtxt @@ -80,6 +80,10 @@ tf_module { name: "stateless_normal" argspec: "args=[\'shape\', \'seed\', \'mean\', \'stddev\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \"\", \'None\'], " } + member_method { + name: "stateless_parameterized_truncated_normal" + argspec: "args=[\'shape\', \'seed\', \'means\', \'stddevs\', \'minvals\', \'maxvals\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'-2.0\', \'2.0\', \'None\'], " + } member_method { name: "stateless_poisson" argspec: "args=[\'shape\', \'seed\', \'lam\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt 
b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 62969b5a0dd..8e5303cbea4 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -4492,6 +4492,10 @@ tf_module { name: "StatelessMultinomial" argspec: "args=[\'logits\', \'num_samples\', \'seed\', \'output_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " } + member_method { + name: "StatelessParameterizedTruncatedNormal" + argspec: "args=[\'shape\', \'seed\', \'means\', \'stddevs\', \'minvals\', \'maxvals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "StatelessRandomBinomial" argspec: "args=[\'shape\', \'seed\', \'counts\', \'probs\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], "
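
Usage note (illustrative sketch, not part of the diff): the snippet below exercises the new `tf.random.stateless_parameterized_truncated_normal` endpoint under the semantics documented above, assuming a TensorFlow build that includes this change. The shapes, seed, and parameter values are made up for illustration.

import tensorflow as tf

# The four distribution parameters broadcast against each other to a
# [3, 2] batch shape. `shape` must end with that broadcast shape; the
# leading dimension holds the independent samples (on the left, per this
# change).
samples = tf.random.stateless_parameterized_truncated_normal(
    shape=[1000, 3, 2],
    seed=[7, 17],                             # stateless: same seed, same draws
    means=tf.constant([0., 1.]),              # shape [2]
    stddevs=tf.constant([[1.], [2.], [3.]]),  # shape [3, 1]
    minvals=-1.,
    maxvals=1.)
print(samples.shape)  # (1000, 3, 2)

# Parameter gradients flow through the implicit reparameterization
# registered in random_grad.py (Figurnov et al., 2018).
mean = tf.Variable(0.1)
with tf.GradientTape() as tape:
  y = tf.random.stateless_parameterized_truncated_normal(
      [100], [1, 2], mean, 1., -1., 1.)
grad = tape.gradient(y, mean)  # d(sample)/d(mean) is 1, summed over samples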