Move tf.random.stateless_gamma CPU kernel to a separate file.
Minor cleanup: replace the UNIFORM(X) macro with a class.

PiperOrigin-RevId: 348811389
Change-Id: Ia996963b0db4b6cf66bb0f6b1bd900a008b0df46
Parent: 2d52be6a52
Commit: aaf94e8166
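
In short, the cleanup replaces an in-function macro with a reusable class. A condensed before/after sketch, extracted from the diff below (uniform_remaining, uniform_result, uniform, and gen are locals of the old Fill method):

    // Before: a macro expanded inline at each draw site.
    #define UNIFORM(X)                                    \
      if (uniform_remaining == 0) {                       \
        uniform_remaining = Uniform::kResultElementCount; \
        uniform_result = uniform(&gen);                   \
      }                                                   \
      uniform_remaining--;                                \
      double X = uniform_result[uniform_remaining]
    UNIFORM(u);

    // After: the same buffering logic as a class, reusable for any
    // distribution type and resettable via Clear().
    SampleBuffer<Uniform> uniform_buffer(&uniform);
    double u = uniform_buffer(&gen);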
@@ -646,6 +646,7 @@ cc_library(
         "//tensorflow/core/kernels:sparse",
         "//tensorflow/core/kernels:state",
         "//tensorflow/core/kernels:stateless_random_ops",
+        "//tensorflow/core/kernels:stateless_random_gamma_op",
         "//tensorflow/core/kernels:string",
         "//tensorflow/core/kernels:summary_kernels",
         "//tensorflow/core/kernels:training_ops",

@@ -4445,6 +4445,16 @@ tf_kernel_library(
     ],
 )
 
+tf_kernel_library(
+    name = "stateless_random_gamma_op",
+    prefix = "stateless_random_gamma_op",
+    deps = [
+        ":stateless_random_ops",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+    ],
+)
+
 tf_kernel_library(
     name = "stateless_random_ops",
     prefix = "stateless_random_ops",

@@ -5929,6 +5939,7 @@ filegroup(
         "spacetobatch_functor.h",
         "spacetodepth_op.h",
         "spectrogram.h",
+        "stateless_random_gamma_op.h",
         "stateless_random_ops.h",
         "stateless_random_ops_v2.h",
         "sparse_fill_empty_rows_op.h",

@@ -6196,6 +6207,7 @@ filegroup(
         "stack.cc",
         "stack.h",
         "stack_ops.cc",
+        "stateless_random_gamma_op.cc",
         "stateless_random_ops.cc",
         "stateless_random_ops_v2.cc",
         "string_join_op.cc",

tensorflow/core/kernels/stateless_random_gamma_op.cc (new file, 312 lines)
@@ -0,0 +1,312 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/math_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/stateless_random_gamma_op.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor_util.h"
+#include "tensorflow/core/kernels/stateless_random_ops.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/random_distributions.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+#if EIGEN_COMP_GNUC && __cplusplus > 199711L
+#define DISABLE_FLOAT_EQUALITY_WARNING \
+  _Pragma("GCC diagnostic push")       \
+  _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
+#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
+#else
+#define DISABLE_FLOAT_EQUALITY_WARNING
+#define ENABLE_FLOAT_EQUALITY_WARNING
+#endif
+
+namespace tensorflow {
+
+namespace {
+
+// Each attempt to generate a new draw from the Gamma distribution is 95+%
+// successful, and requires 1-2 normal + 1 uniform sample.
+static constexpr int kReservedSamplesPerOutput = 256;
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+// Buffer that holds multiple samples. Operator()(random::PhiloxRandom*)
+// returns a single sample from this buffer. If the buffer is empty, it first
+// generates new samples using the provided distribution.
+//
+// If the call to Distribution::operator() returns samples[0...N-1], then this
+// class returns samples in the following order:
+//
+//   samples[N-1], samples[N-2], ..., samples[1], samples[0]
+//
+// For comparison, random::SingleSampleAdapter returns samples in the
+// following order:
+//
+//   samples[0], samples[1], ..., samples[N-2], samples[N-1].
+//
+template <class Distribution>
+class SampleBuffer {
+ public:
+  typedef typename Distribution::ResultElementType ResultElementType;
+
+  PHILOX_DEVICE_INLINE
+  explicit SampleBuffer(Distribution* distribution)
+      : distribution_(distribution), remaining_numbers_(0) {}
+
+  PHILOX_DEVICE_INLINE
+  ResultElementType operator()(random::PhiloxRandom* random) {
+    if (remaining_numbers_ == 0) {
+      results_ = (*distribution_)(random);
+      remaining_numbers_ = Distribution::kResultElementCount;
+    }
+
+    remaining_numbers_--;
+    return results_[remaining_numbers_];
+  }
+
+  // Mark this buffer as empty. The next call to operator() will fill it
+  // with new random numbers.
+  PHILOX_DEVICE_INLINE
+  void Clear() { remaining_numbers_ = 0; }
+
+ private:
+  typedef typename Distribution::ResultType ResultType;
+
+  Distribution* distribution_;
+  ResultType results_;
+  int remaining_numbers_;
+};
+
+}  // namespace
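
For intuition, a minimal standalone sketch of the buffering pattern SampleBuffer implements (not part of the commit; FakeRng and FakeDistribution are stand-ins for the Philox-backed types):

    #include <array>
    #include <iostream>

    struct FakeRng {};  // stand-in for random::PhiloxRandom

    struct FakeDistribution {
      static constexpr int kResultElementCount = 4;
      typedef double ResultElementType;
      typedef std::array<double, kResultElementCount> ResultType;
      ResultType operator()(FakeRng*) {
        return {0.1, 0.2, 0.3, 0.4};  // one fresh batch of "samples"
      }
    };

    template <class Distribution>
    class Buffer {
     public:
      explicit Buffer(Distribution* d) : distribution_(d) {}
      typename Distribution::ResultElementType operator()(FakeRng* rng) {
        if (remaining_ == 0) {  // refill only when the batch is used up
          results_ = (*distribution_)(rng);
          remaining_ = Distribution::kResultElementCount;
        }
        return results_[--remaining_];  // consumed in reverse order
      }
     private:
      Distribution* distribution_;
      typename Distribution::ResultType results_;
      int remaining_ = 0;
    };

    int main() {
      FakeDistribution dist;
      FakeRng rng;
      Buffer<FakeDistribution> buf(&dist);
      for (int i = 0; i < 4; ++i) std::cout << buf(&rng) << " ";
      // Prints "0.4 0.3 0.2 0.1": the reverse order documented above.
    }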
+
+namespace functor {
+
+template <typename T>
+struct StatelessRandomGammaFunctor<CPUDevice, T> {
+  static Status Fill(OpKernelContext* ctx, const T* alpha_flat,
+                     int64 num_alphas, int64 samples_per_alpha,
+                     const random::PhiloxRandom& random, T* samples_flat) {
+    typedef random::NormalDistribution<random::PhiloxRandom, double> Normal;
+    typedef random::UniformDistribution<random::PhiloxRandom, double> Uniform;
+
+    // We partition work first across alphas then across samples-per-alpha to
+    // avoid a couple of flops which can be done on a per-alpha basis.
+
+    auto DoWork = [samples_per_alpha, num_alphas, &random, samples_flat,
+                   alpha_flat](int64 start_output, int64 limit_output) {
+      // Capturing "random" by value would only make a copy for the _shared_
+      // lambda. Since we want to let each worker have its own copy, we pass
+      // "random" by reference and explicitly do a copy assignment.
+
+      using Eigen::numext::exp;
+      using Eigen::numext::log;
+      using Eigen::numext::log1p;
+      using Eigen::numext::pow;
+
+      Normal normal;
+      Uniform uniform;
+
+      SampleBuffer<Normal> normal_buffer(&normal);
+      SampleBuffer<Uniform> uniform_buffer(&uniform);
+
+      for (int64 output_idx = start_output; output_idx < limit_output;
+           /* output_idx incremented within inner loop below */) {
+        int64 alpha_idx = output_idx / samples_per_alpha;
+
+        // Instead of +alpha_idx for each sample, we offset the pointer once.
+        T* const samples_alpha_offset = samples_flat + alpha_idx;
+
+        // Several calculations can be done on a per-alpha basis.
+        const double alpha = static_cast<double>(alpha_flat[alpha_idx]);
+
+        DISABLE_FLOAT_EQUALITY_WARNING
+        if (alpha == static_cast<double>(1.0)) {
+          ENABLE_FLOAT_EQUALITY_WARNING
+          // Sample from an exponential distribution.
+          for (int64 sample_idx = output_idx % samples_per_alpha;
+               sample_idx < samples_per_alpha && output_idx < limit_output;
+               sample_idx++, output_idx++) {
+            // As we want data stable regardless of sharding (including
+            // eventually on GPU), we skip on a per-sample basis.
+            random::PhiloxRandom gen = random;
+            gen.Skip(kReservedSamplesPerOutput * output_idx);
+            double u = uniform(&gen)[Uniform::kResultElementCount - 1];
+            const double res = -log1p(-u);
+            samples_alpha_offset[sample_idx * num_alphas] =
+                static_cast<T>(res);
+          }  // for (sample_idx)
+        } else {  // if alpha != 1.0
+          // Transformation-rejection from pairs of uniform and normal random
+          // variables. http://dl.acm.org/citation.cfm?id=358414
+          //
+          // The algorithm has an acceptance rate of ~95% for small alpha
+          // (~1), and higher acceptance rates for higher alpha, so runtime
+          // is O(NumAlphas * NumSamples * k) with k ~ 1 / 0.95.
+          //
+          // For alpha < 1, we add one to d = alpha - 1/3, and multiply the
+          // final result by uniform()^(1/alpha).
+          const bool alpha_less_than_one = alpha < 1;
+          const double d = alpha + (alpha_less_than_one ? 2.0 / 3 : -1.0 / 3);
+          const double c = 1.0 / 3 / sqrt(d);
+
+          // Compute the rest of the samples for the current alpha value.
+          for (int64 sample_idx = output_idx % samples_per_alpha;
+               sample_idx < samples_per_alpha && output_idx < limit_output;
+               sample_idx++, output_idx++) {
+            // Since each sample may use a variable number of normal/uniform
+            // samples, and we want data stable regardless of sharding
+            // (including eventually on GPU), we skip on a per-sample basis.
+            random::PhiloxRandom gen = random;
+            gen.Skip(kReservedSamplesPerOutput * output_idx);
+
+            // To prevent overwriting SampleBuffer's underlying array with
+            // zeros (in the tensorflow::random::Array constructor), we just
+            // mark the buffer as empty instead of initializing a new
+            // SampleBuffer object here. The next call to operator() will
+            // fill the buffer with new numbers.
+            normal_buffer.Clear();
+            uniform_buffer.Clear();
+
+            // Keep trying until we don't reject a sample. In practice, we
+            // will only reject ~5% at worst, for low alpha near 1.
+            while (true) {
+              const double x = normal_buffer(&gen);
+              double v = 1 + c * x;
+              if (v <= 0) {
+                continue;
+              }
+              v = v * v * v;
+              double u = uniform_buffer(&gen);
+              // The first option in the if is a "squeeze" short-circuit to
+              // dodge the two logs. Magic constant sourced from the paper
+              // linked above. Upward of .91 of the area covered by the log
+              // inequality is covered by the squeeze as well (larger
+              // coverage for smaller values of alpha).
+              if ((u < 1 - 0.0331 * (x * x) * (x * x)) ||
+                  (log(u) < 0.5 * x * x + d * (1 - v + log(v)))) {
+                double res = d * v;
+                if (alpha_less_than_one) {
+                  double b = uniform_buffer(&gen);
+                  res *= pow(b, 1 / alpha);
+                }
+                samples_alpha_offset[sample_idx * num_alphas] =
+                    static_cast<T>(res);
+                break;
+              }
+            }  // while: true
+          }  // for: sample_idx
+        }  // if (alpha == 1.0)
+      }  // for: output_idx
+    };  // DoWork
+
+    // Two calls to log only occur for ~10% of samples reaching the log line.
+    //   2 x 100 (64-bit cycles per log) x 0.10 = ~20.
+    // Other ops: sqrt, +, *, /, %... something like 15 of these, at 3-6
+    // cycles each = ~60.
+    // All of this /0.95 (the expected value of a geometric distribution is
+    // 1/p) due to the rejection possibility = ~85.
+    static const int kElementCost = 85 + 2 * Normal::kElementCost +
+                                    Uniform::kElementCost +
+                                    3 * random::PhiloxRandom::kElementCost;
+    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
+    Shard(worker_threads.num_threads, worker_threads.workers,
+          num_alphas * samples_per_alpha, kElementCost, DoWork);
+    return Status::OK();
+  }
+};
+
+}  // namespace functor
+
+namespace {
+
+template <typename Device, typename T>
+class StatelessRandomGammaOp : public OpKernel {
+ public:
+  explicit StatelessRandomGammaOp(OpKernelConstruction* context)
+      : OpKernel(context) {}
+
+  void Compute(OpKernelContext* context) override {
+    // Sanitize input.
+    const Tensor& shape_t = context->input(0);
+    const Tensor& seed_t = context->input(1);
+    TensorShape shape;
+    OP_REQUIRES_OK(context, tensor::MakeShape(shape_t, &shape));
+    OP_REQUIRES(context, seed_t.dims() == 1 && seed_t.dim_size(0) == 2,
+                errors::InvalidArgument("seed must have shape [2], not ",
+                                        seed_t.shape().DebugString()));
+
+    // Allocate output.
+    Tensor* output;
+    OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output));
+    if (shape.num_elements() == 0) return;
+
+    random::PhiloxRandom::Key key;
+    random::PhiloxRandom::ResultType counter;
+    OP_REQUIRES_OK(context, GenerateKey(seed_t, &key, &counter));
+
+    // Fill in the random numbers.
+    Fill(context, random::PhiloxRandom(counter, key), output);
+  }
+
+ private:
+  void Fill(OpKernelContext* ctx, random::PhiloxRandom random,
+            Tensor* output) {
+    const Tensor& alpha_t = ctx->input(2);
+
+    TensorShape samples_shape = output->shape();
+    OP_REQUIRES(ctx,
+                TensorShapeUtils::EndsWith(samples_shape, alpha_t.shape()),
+                errors::InvalidArgument(
+                    "Shape passed in must end with broadcasted shape."));
+
+    const int64 num_alphas = alpha_t.NumElements();
+    OP_REQUIRES(ctx, num_alphas > 0,
+                errors::InvalidArgument(
+                    "Input alpha should have non-zero element count, got: ",
+                    num_alphas));
+
+    const int64 samples_per_alpha = samples_shape.num_elements() / num_alphas;
+    const auto alpha_flat = alpha_t.flat<T>().data();
+    auto samples_flat = output->flat<T>().data();
+
+    OP_REQUIRES_OK(ctx, functor::StatelessRandomGammaFunctor<Device, T>::Fill(
+                            ctx, alpha_flat, num_alphas, samples_per_alpha,
+                            random, samples_flat));
+  }
+
+  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomGammaOp);
+};
+
+#define REGISTER_GAMMA(TYPE)                                  \
+  REGISTER_KERNEL_BUILDER(Name("StatelessRandomGammaV2")      \
+                              .Device(DEVICE_CPU)             \
+                              .HostMemory("shape")            \
+                              .HostMemory("seed")             \
+                              .HostMemory("alpha")            \
+                              .TypeConstraint<TYPE>("dtype"), \
+                          StatelessRandomGammaOp<CPUDevice, TYPE>)
+
+TF_CALL_half(REGISTER_GAMMA);
+TF_CALL_bfloat16(REGISTER_GAMMA);
+TF_CALL_float(REGISTER_GAMMA);
+TF_CALL_double(REGISTER_GAMMA);
+
+#undef REGISTER_GAMMA
+
+}  // namespace
+}  // namespace tensorflow
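
To see the sampling logic in isolation, here is a minimal standalone sketch of the same Marsaglia-Tsang transformation-rejection scheme (not the committed code; std::mt19937 stands in for the Philox generator, and error handling is omitted):

    #include <cmath>
    #include <iostream>
    #include <random>

    double SampleGamma(double alpha, std::mt19937& gen) {
      std::uniform_real_distribution<double> uniform(0.0, 1.0);
      if (alpha == 1.0) {
        // Gamma(1) is Exp(1): invert the CDF, like the alpha == 1 branch.
        return -std::log1p(-uniform(gen));
      }
      std::normal_distribution<double> normal(0.0, 1.0);
      // For alpha < 1, sample Gamma(alpha + 1), scale by uniform^(1/alpha).
      const bool alpha_less_than_one = alpha < 1;
      const double d = alpha + (alpha_less_than_one ? 2.0 / 3 : -1.0 / 3);
      const double c = 1.0 / 3 / std::sqrt(d);
      while (true) {
        const double x = normal(gen);
        double v = 1 + c * x;
        if (v <= 0) continue;  // v must stay positive before cubing
        v = v * v * v;
        const double u = uniform(gen);
        // Cheap "squeeze" test first; exact log inequality as fallback.
        if ((u < 1 - 0.0331 * (x * x) * (x * x)) ||
            (std::log(u) < 0.5 * x * x + d * (1 - v + std::log(v)))) {
          double res = d * v;
          if (alpha_less_than_one) res *= std::pow(uniform(gen), 1 / alpha);
          return res;
        }
      }
    }

    int main() {
      std::mt19937 gen(42);
      double sum = 0;
      constexpr int kN = 100000;
      for (int i = 0; i < kN; ++i) sum += SampleGamma(2.5, gen);
      // E[Gamma(alpha, 1)] = alpha, so the mean should be close to 2.5.
      std::cout << "mean ~ " << sum / kN << "\n";
    }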

tensorflow/core/kernels/stateless_random_gamma_op.h (new file, 38 lines)
@@ -0,0 +1,38 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_GAMMA_OP_H_
+#define TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_GAMMA_OP_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+
+namespace tensorflow {
+
+namespace functor {
+
+template <typename Device, typename T>
+struct StatelessRandomGammaFunctor {
+  static Status Fill(OpKernelContext* ctx, const T* alpha_flat,
+                     int64 num_alphas, int64 samples_per_alpha,
+                     const random::PhiloxRandom& random, T* samples_flat);
+};
+
+}  // namespace functor
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_GAMMA_OP_H_
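
Declaring the functor in this header is what lets another device backend specialize it in its own translation unit later (the comments above mention "eventually on GPU"). A purely hypothetical sketch of the shape such a specialization would take; GPUDevice and the body are illustrative, not part of this commit:

    // Hypothetical sketch only, mirroring the CPU specialization above.
    typedef Eigen::GpuDevice GPUDevice;

    namespace tensorflow {
    namespace functor {

    template <typename T>
    struct StatelessRandomGammaFunctor<GPUDevice, T> {
      static Status Fill(OpKernelContext* ctx, const T* alpha_flat,
                         int64 num_alphas, int64 samples_per_alpha,
                         const random::PhiloxRandom& random,
                         T* samples_flat) {
        // A real implementation would launch a device kernel that skips the
        // Philox stream per output index, exactly as the CPU path does, so
        // results stay identical across devices and shardings.
        return errors::Unimplemented("GPU gamma sampling: sketch only");
      }
    };

    }  // namespace functor
    }  // namespace tensorflow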

tensorflow/core/kernels/stateless_random_ops.cc
@@ -23,17 +23,6 @@ limitations under the License.
 #include "tensorflow/core/kernels/random_poisson_op.h"
 #include "tensorflow/core/lib/random/random_distributions.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/util/work_sharder.h"
 
-#if EIGEN_COMP_GNUC && __cplusplus > 199711L
-#define DISABLE_FLOAT_EQUALITY_WARNING \
-  _Pragma("GCC diagnostic push")       \
-  _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
-#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
-#else
-#define DISABLE_FLOAT_EQUALITY_WARNING
-#define ENABLE_FLOAT_EQUALITY_WARNING
-#endif
-
 namespace tensorflow {
 

@@ -212,163 +201,6 @@ class StatelessRandomPoissonOp : public StatelessRandomOpBase {
   TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomPoissonOp);
 };
 
-template <typename Device, typename T>
-class StatelessRandomGammaOp : public StatelessRandomOpBase {
- public:
-  using StatelessRandomOpBase::StatelessRandomOpBase;
-
-  void Fill(OpKernelContext* ctx, random::PhiloxRandom random,
-            Tensor* output) override {
-    const Tensor& alpha_t = ctx->input(2);
-
-    TensorShape samples_shape = output->shape();
-    OP_REQUIRES(ctx,
-                TensorShapeUtils::EndsWith(samples_shape, alpha_t.shape()),
-                errors::InvalidArgument(
-                    "Shape passed in must end with broadcasted shape."));
-
-    typedef random::NormalDistribution<random::PhiloxRandom, double> Normal;
-    typedef random::UniformDistribution<random::PhiloxRandom, double> Uniform;
-#define UNIFORM(X)                                    \
-  if (uniform_remaining == 0) {                       \
-    uniform_remaining = Uniform::kResultElementCount; \
-    uniform_result = uniform(&gen);                   \
-  }                                                   \
-  uniform_remaining--;                                \
-  double X = uniform_result[uniform_remaining]
-
-    // Each attempt is 95+% successful, and requires 1-2 normal + 1 uniform
-    static constexpr int kReservedSamplesPerOutput = 256;
-
-    const int64 num_alphas = alpha_t.NumElements();
-    OP_REQUIRES(ctx, num_alphas > 0,
-                errors::InvalidArgument(
-                    "Input alpha should have non-zero element count, got: ",
-                    num_alphas));
-    const int64 samples_per_alpha = samples_shape.num_elements() / num_alphas;
-    const auto alpha_flat = alpha_t.flat<T>().data();
-    auto samples_flat = output->flat<T>().data();
-
-    // We partition work first across alphas then across samples-per-alpha to
-    // avoid a couple flops which can be done on a per-alpha basis.
-
-    auto DoWork = [samples_per_alpha, num_alphas, &random, samples_flat,
-                   alpha_flat](int64 start_output, int64 limit_output) {
-      // Capturing "random" by-value would only make a copy for the _shared_
-      // lambda. Since we want to let each worker have its own copy, we pass
-      // "random" by reference and explicitly do a copy assignment.
-
-      using Eigen::numext::exp;
-      using Eigen::numext::log;
-      using Eigen::numext::log1p;
-      using Eigen::numext::pow;
-
-      Normal normal;
-      Uniform uniform;
-      typename Normal::ResultType norm_result;
-      typename Uniform::ResultType uniform_result;
-      for (int64 output_idx = start_output; output_idx < limit_output;
-           /* output_idx incremented within inner loop below */) {
-        int64 alpha_idx = output_idx / samples_per_alpha;
-
-        // Instead of +alpha_idx for each sample, we offset the pointer once.
-        T* const samples_alpha_offset = samples_flat + alpha_idx;
-
-        // Several calculations can be done on a per-alpha basis.
-        const double alpha = static_cast<double>(alpha_flat[alpha_idx]);
-
-        DISABLE_FLOAT_EQUALITY_WARNING
-        if (alpha == static_cast<double>(1.0)) {
-          ENABLE_FLOAT_EQUALITY_WARNING
-          // Sample from an exponential distribution.
-          for (int64 sample_idx = output_idx % samples_per_alpha;
-               sample_idx < samples_per_alpha && output_idx < limit_output;
-               sample_idx++, output_idx++) {
-            // As we want data stable regardless of sharding
-            // (including eventually on GPU), we skip on a per-sample basis.
-            random::PhiloxRandom gen = random;
-            gen.Skip(kReservedSamplesPerOutput * output_idx);
-            int16 uniform_remaining = 0;
-            UNIFORM(u);
-            const double res = -log1p(-u);
-            samples_alpha_offset[sample_idx * num_alphas] =
-                static_cast<T>(res);
-          }  // for (sample_idx)
-        } else {  // if alpha != 1.0
-          // Transformation-rejection from pairs of uniform and normal random
-          // variables. http://dl.acm.org/citation.cfm?id=358414
-          //
-          // The algorithm has an acceptance rate of ~95% for small alpha
-          // (~1), and higher accept rates for higher alpha, so runtime is
-          // O(NumAlphas * NumSamples * k) with k ~ 1 / 0.95.
-          //
-          // For alpha<1, we add one to d=alpha-1/3, and multiply the final
-          // result by uniform()^(1/alpha)
-          const bool alpha_less_than_one = alpha < 1;
-          const double d = alpha + (alpha_less_than_one ? 2.0 / 3 : -1.0 / 3);
-          const double c = 1.0 / 3 / sqrt(d);
-
-          // Compute the rest of the samples for the current alpha value.
-          for (int64 sample_idx = output_idx % samples_per_alpha;
-               sample_idx < samples_per_alpha && output_idx < limit_output;
-               sample_idx++, output_idx++) {
-            // Since each sample may use a variable number of normal/uniform
-            // samples, and we want data stable regardless of sharding
-            // (including eventually on GPU), we skip on a per-sample basis.
-            random::PhiloxRandom gen = random;
-            gen.Skip(kReservedSamplesPerOutput * output_idx);
-            int16 norm_remaining = 0;
-            int16 uniform_remaining = 0;
-
-            // Keep trying until we don't reject a sample. In practice, we
-            // will only reject ~5% at worst, for low alpha near 1.
-            while (true) {
-              if (norm_remaining == 0) {
-                norm_remaining = Normal::kResultElementCount;
-                norm_result = normal(&gen);
-              }
-              norm_remaining--;
-              const double x = norm_result[norm_remaining];
-              double v = 1 + c * x;
-              if (v <= 0) {
-                continue;
-              }
-              v = v * v * v;
-              UNIFORM(u);
-              // The first option in the if is a "squeeze" short-circuit to
-              // dodge the two logs. Magic constant sourced from the paper
-              // linked above. Upward of .91 of the area covered by the log
-              // inequality is covered by the squeeze as well (larger
-              // coverage for smaller values of alpha).
-              if ((u < 1 - 0.0331 * (x * x) * (x * x)) ||
-                  (log(u) < 0.5 * x * x + d * (1 - v + log(v)))) {
-                double res = d * v;
-                if (alpha_less_than_one) {
-                  UNIFORM(b);
-                  res *= pow(b, 1 / alpha);
-                }
-                samples_alpha_offset[sample_idx * num_alphas] =
-                    static_cast<T>(res);
-                break;
-              }
-            }  // while: true
-          }  // for: sample_idx
-        }  // if (alpha == 1.0)
-      }  // for: output_idx
-    };  // DoWork
-#undef UNIFORM
-    // Two calls to log only occur for ~10% of samples reaching the log line.
-    //   2 x 100 (64-bit cycles per log) x 0.10 = ~20.
-    // Other ops: sqrt, +, *, /, %... something like 15 of these, at 3-6
-    // cycles each = ~60.
-    // All of this /0.95 due to the rejection possibility = ~85.
-    static const int kElementCost = 85 + 2 * Normal::kElementCost +
-                                    Uniform::kElementCost +
-                                    3 * random::PhiloxRandom::kElementCost;
-    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
-    Shard(worker_threads.num_threads, worker_threads.workers,
-          num_alphas * samples_per_alpha, kElementCost, DoWork);
-  }
-};
-
 #define REGISTER(DEVICE, TYPE)                             \
   REGISTER_KERNEL_BUILDER(                                 \
       Name("StatelessRandomUniform")                       \

@@ -459,22 +291,6 @@ TF_CALL_int64(REGISTER_ALL_POISSON);
 #undef REGISTER_ALL_POISSON
 #undef REGISTER_POISSON
 
-#define REGISTER_GAMMA(TYPE)                                  \
-  REGISTER_KERNEL_BUILDER(Name("StatelessRandomGammaV2")      \
-                              .Device(DEVICE_CPU)             \
-                              .HostMemory("shape")            \
-                              .HostMemory("seed")             \
-                              .HostMemory("alpha")            \
-                              .TypeConstraint<TYPE>("dtype"), \
-                          StatelessRandomGammaOp<CPUDevice, TYPE>)
-
-TF_CALL_half(REGISTER_GAMMA);
-TF_CALL_bfloat16(REGISTER_GAMMA);
-TF_CALL_float(REGISTER_GAMMA);
-TF_CALL_double(REGISTER_GAMMA);
-
-#undef REGISTER_GAMMA
-
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 TF_CALL_half(REGISTER_GPU);