Move tf.random.stateless_gamma CPU kernel to a separate file.

Minor cleanup: replace UNIFORM(X) macro with a class

PiperOrigin-RevId: 348811389
Change-Id: Ia996963b0db4b6cf66bb0f6b1bd900a008b0df46
Author: Matej Rizman, 2020-12-23 10:04:17 -08:00; committed by TensorFlower Gardener
commit aaf94e8166 (parent 2d52be6a52)
5 changed files with 363 additions and 184 deletions
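To make the "replace UNIFORM(X) macro with a class" cleanup concrete, here is a condensed, self-contained toy sketch (FakeUniform, Buffered, and main are invented stand-ins, not from this commit) contrasting the old macro idiom with the new buffered-class idiom; the real UNIFORM macro and its SampleBuffer replacement appear verbatim in the diffs below.

#include <array>
#include <cstdio>

// Stand-in for random::UniformDistribution; the real code draws from a
// random::PhiloxRandom generator rather than the dummy int* used here.
struct FakeUniform {
  static constexpr int kResultElementCount = 4;
  typedef std::array<double, kResultElementCount> ResultType;
  typedef double ResultElementType;
  int calls = 0;
  ResultType operator()(int* /*gen*/) {
    ++calls;
    return {0.1 * calls, 0.2 * calls, 0.3 * calls, 0.4 * calls};
  }
};

// Old idiom: a macro that refills loose caller-owned state and declares X in place.
#define UNIFORM(X)                                        \
  if (uniform_remaining == 0) {                           \
    uniform_remaining = FakeUniform::kResultElementCount; \
    uniform_result = uniform(&gen);                       \
  }                                                       \
  uniform_remaining--;                                    \
  double X = uniform_result[uniform_remaining]

// New idiom: the cache and countdown live in one object (compare SampleBuffer below).
template <class Distribution>
class Buffered {
 public:
  explicit Buffered(Distribution* d) : d_(d) {}
  typename Distribution::ResultElementType operator()(int* gen) {
    if (remaining_ == 0) {
      results_ = (*d_)(gen);
      remaining_ = Distribution::kResultElementCount;
    }
    remaining_--;
    return results_[remaining_];
  }

 private:
  Distribution* d_;
  typename Distribution::ResultType results_ = {};
  int remaining_ = 0;
};

int main() {
  int gen = 0;
  FakeUniform uniform;
  int uniform_remaining = 0;                    // caller-owned state for the macro
  FakeUniform::ResultType uniform_result = {};
  UNIFORM(u);                                   // u == 0.4: last element of a fresh block
  Buffered<FakeUniform> buffered(&uniform);
  double v = buffered(&gen);                    // v == 0.8: same reverse order, state hidden
  std::printf("macro draw: %g, class draw: %g\n", u, v);
  return 0;
}

Both versions hand back cached values in reverse order; the class simply owns its own state, which is what lets the kernel below reuse one type for both the normal and uniform streams. The actual diff follows.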


@@ -646,6 +646,7 @@ cc_library(
"//tensorflow/core/kernels:sparse",
"//tensorflow/core/kernels:state",
"//tensorflow/core/kernels:stateless_random_ops",
"//tensorflow/core/kernels:stateless_random_gamma_op",
"//tensorflow/core/kernels:string",
"//tensorflow/core/kernels:summary_kernels",
"//tensorflow/core/kernels:training_ops",

tensorflow/core/kernels/BUILD

@@ -4445,6 +4445,16 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "stateless_random_gamma_op",
prefix = "stateless_random_gamma_op",
deps = [
":stateless_random_ops",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
)
tf_kernel_library(
name = "stateless_random_ops",
prefix = "stateless_random_ops",
@@ -5929,6 +5939,7 @@ filegroup(
"spacetobatch_functor.h",
"spacetodepth_op.h",
"spectrogram.h",
"stateless_random_gamma_op.h",
"stateless_random_ops.h",
"stateless_random_ops_v2.h",
"sparse_fill_empty_rows_op.h",
@@ -6196,6 +6207,7 @@ filegroup(
"stack.cc",
"stack.h",
"stack_ops.cc",
"stateless_random_gamma_op.cc",
"stateless_random_ops.cc",
"stateless_random_ops_v2.cc",
"string_join_op.cc",

tensorflow/core/kernels/stateless_random_gamma_op.cc (new file)

@@ -0,0 +1,312 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/math_ops.cc.
#define EIGEN_USE_THREADS
#include "tensorflow/core/kernels/stateless_random_gamma_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/stateless_random_ops.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/work_sharder.h"
#if EIGEN_COMP_GNUC && __cplusplus > 199711L
#define DISABLE_FLOAT_EQUALITY_WARNING \
_Pragma("GCC diagnostic push") \
_Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
#else
#define DISABLE_FLOAT_EQUALITY_WARNING
#define ENABLE_FLOAT_EQUALITY_WARNING
#endif
namespace tensorflow {
namespace {
// Each attempt to generate a new draw from the Gamma distribution is 95+%
// successful, and requires 1-2 normal + 1 uniform sample.
static constexpr int kReservedSamplesPerOutput = 256;
typedef Eigen::ThreadPoolDevice CPUDevice;
// Buffer that holds multiple samples. Operator()(random::PhiloxRandom*) returns
// a single sample from this buffer. If the buffer is empty, it first generates
// new samples using the provided distribution.
//
// If the call to Distribution::operator() returns samples[0...N-1], then this
// class returns samples in the following order:
//
// samples[N-1], samples[N-2],..., samples[1], samples[0]
//
// For comparison, random::SingleSampleAdapter returns samples in
// the following order:
//
// samples[0], samples[1],...,samples[N-2], samples[N-1].
//
template <class Distribution>
class SampleBuffer {
public:
typedef typename Distribution::ResultElementType ResultElementType;
PHILOX_DEVICE_INLINE
explicit SampleBuffer(Distribution* distribution)
: distribution_(distribution), remaining_numbers_(0) {}
PHILOX_DEVICE_INLINE
ResultElementType operator()(random::PhiloxRandom* random) {
if (remaining_numbers_ == 0) {
results_ = (*distribution_)(random);
remaining_numbers_ = Distribution::kResultElementCount;
}
remaining_numbers_--;
return results_[remaining_numbers_];
}
// Mark this buffer as empty. The next call to operator() will fill it
// with new random numbers.
PHILOX_DEVICE_INLINE
void Clear() { remaining_numbers_ = 0; }
private:
typedef typename Distribution::ResultType ResultType;
Distribution* distribution_;
ResultType results_;
int remaining_numbers_;
};
}  // namespace
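A hypothetical usage sketch (not part of this commit; SampleBufferOrderDemo is an invented name, written as if it sat in this file so the class in the anonymous namespace above is visible) to make the documented reverse draw order concrete:

// Hypothetical sketch only -- illustrates the ordering described above.
inline void SampleBufferOrderDemo() {
  typedef random::UniformDistribution<random::PhiloxRandom, double> Uniform;
  random::PhiloxRandom gen(/*seed=*/42);
  Uniform uniform;
  SampleBuffer<Uniform> buffer(&uniform);
  // The first call fills the internal cache via uniform(&gen) and returns the
  // last element of that block; subsequent calls walk backwards through it.
  double a = buffer(&gen);  // samples[Uniform::kResultElementCount - 1]
  double b = buffer(&gen);  // samples[Uniform::kResultElementCount - 2]
  // random::SingleSampleAdapter would instead yield samples[0], samples[1], ...
  buffer.Clear();  // mark empty; the next draw regenerates a fresh block
  (void)a;
  (void)b;
}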
namespace functor {
template <typename T>
struct StatelessRandomGammaFunctor<CPUDevice, T> {
static Status Fill(OpKernelContext* ctx, const T* alpha_flat,
int64 num_alphas, int64 samples_per_alpha,
const random::PhiloxRandom& random, T* samples_flat) {
typedef random::NormalDistribution<random::PhiloxRandom, double> Normal;
typedef random::UniformDistribution<random::PhiloxRandom, double> Uniform;
// We partition work first across alphas then across samples-per-alpha to
// avoid a couple flops which can be done on a per-alpha basis.
auto DoWork = [samples_per_alpha, num_alphas, &random, samples_flat,
alpha_flat](int64 start_output, int64 limit_output) {
// Capturing "random" by-value would only make a copy for the _shared_
// lambda. Since we want to let each worker have its own copy, we pass
// "random" by reference and explicitly do a copy assignment.
using Eigen::numext::exp;
using Eigen::numext::log;
using Eigen::numext::log1p;
using Eigen::numext::pow;
Normal normal;
Uniform uniform;
SampleBuffer<Normal> normal_buffer(&normal);
SampleBuffer<Uniform> uniform_buffer(&uniform);
for (int64 output_idx = start_output; output_idx < limit_output;
/* output_idx incremented within inner loop below */) {
int64 alpha_idx = output_idx / samples_per_alpha;
// Instead of +alpha_idx for each sample, we offset the pointer once.
T* const samples_alpha_offset = samples_flat + alpha_idx;
// Several calculations can be done on a per-alpha basis.
const double alpha = static_cast<double>(alpha_flat[alpha_idx]);
DISABLE_FLOAT_EQUALITY_WARNING
if (alpha == static_cast<double>(1.0)) {
ENABLE_FLOAT_EQUALITY_WARNING
// Sample from an exponential distribution.
for (int64 sample_idx = output_idx % samples_per_alpha;
sample_idx < samples_per_alpha && output_idx < limit_output;
sample_idx++, output_idx++) {
// As we want data stable regardless of sharding
// (including eventually on GPU), we skip on a per-sample basis.
random::PhiloxRandom gen = random;
gen.Skip(kReservedSamplesPerOutput * output_idx);
double u = uniform(&gen)[Uniform::kResultElementCount - 1];
const double res = -log1p(-u);
samples_alpha_offset[sample_idx * num_alphas] = static_cast<T>(res);
} // for (sample_idx)
} else { // if alpha != 1.0
// Transformation-rejection from pairs of uniform and normal random
// variables. http://dl.acm.org/citation.cfm?id=358414
//
// The algorithm has an acceptance rate of ~95% for small alpha (~1),
// and higher accept rates for higher alpha, so runtime is
// O(NumAlphas * NumSamples * k) with k ~ 1 / 0.95.
//
// For alpha<1, we add one to d=alpha-1/3, and multiply the final
// result by uniform()^(1/alpha)
const bool alpha_less_than_one = alpha < 1;
const double d = alpha + (alpha_less_than_one ? 2.0 / 3 : -1.0 / 3);
const double c = 1.0 / 3 / sqrt(d);
// Compute the rest of the samples for the current alpha value.
for (int64 sample_idx = output_idx % samples_per_alpha;
sample_idx < samples_per_alpha && output_idx < limit_output;
sample_idx++, output_idx++) {
// Since each sample may use a variable number of normal/uniform
// samples, and we want data stable regardless of sharding
// (including eventually on GPU), we skip on a per-sample basis.
random::PhiloxRandom gen = random;
gen.Skip(kReservedSamplesPerOutput * output_idx);
// To prevent overwriting SampleBuffer's underlying array with
// zeros (in tensorflow::random::Array constructor), we just mark
// the buffer as empty instead of initializing a new SampleBuffer
// object here. The next call to operator() will fill the buffer
// with new numbers.
normal_buffer.Clear();
uniform_buffer.Clear();
// Keep trying until we don't reject a sample. In practice, we will
// only reject ~5% at worst, for low alpha near 1.
while (true) {
const double x = normal_buffer(&gen);
double v = 1 + c * x;
if (v <= 0) {
continue;
}
v = v * v * v;
double u = uniform_buffer(&gen);
// The first option in the if is a "squeeze" short-circuit to
// dodge the two logs. Magic constant sourced from the paper
// linked above. Upward of .91 of the area covered by the log
// inequality is covered by the squeeze as well (larger coverage
// for smaller values of alpha).
if ((u < 1 - 0.0331 * (x * x) * (x * x)) ||
(log(u) < 0.5 * x * x + d * (1 - v + log(v)))) {
double res = d * v;
if (alpha_less_than_one) {
double b = uniform_buffer(&gen);
res *= pow(b, 1 / alpha);
}
samples_alpha_offset[sample_idx * num_alphas] =
static_cast<T>(res);
break;
}
} // while: true
} // for: sample_idx
} // if (alpha == 1.0)
} // for: output_idx
}; // DoWork
// Two calls to log only occur for ~10% of samples reaching the log line.
// 2 x 100 (64-bit cycles per log) x 0.10 = ~20.
// Other ops: sqrt, +, *, /, %... something like 15 of these, at 3-6 cycles
// each = ~60.
// All of this /0.95 (expected value of geometric distribution is 1/p) due
// to the rejection possibility = ~85.
static const int kElementCost = 85 + 2 * Normal::kElementCost +
Uniform::kElementCost +
3 * random::PhiloxRandom::kElementCost;
auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
Shard(worker_threads.num_threads, worker_threads.workers,
num_alphas * samples_per_alpha, kElementCost, DoWork);
return Status::OK();
}
};
} // namespace functor
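A standalone toy (invented main with assumed sizes num_alphas = 3, samples_per_alpha = 4; not part of the kernel) illustrating the DoWork indexing above: the flat work index groups samples by alpha so per-alpha setup is amortized, while the write offset interleaves them, since alpha is the innermost dimension of the flattened output.

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t num_alphas = 3;
  const int64_t samples_per_alpha = 4;
  for (int64_t output_idx = 0; output_idx < num_alphas * samples_per_alpha;
       ++output_idx) {
    // Mirrors the kernel: alpha_idx = output_idx / samples_per_alpha, and the
    // sample is written at sample_idx * num_alphas + alpha_idx.
    const int64_t alpha_idx = output_idx / samples_per_alpha;
    const int64_t sample_idx = output_idx % samples_per_alpha;
    std::printf("output_idx=%2d -> alpha_idx=%d, sample_idx=%d, write offset=%d\n",
                static_cast<int>(output_idx), static_cast<int>(alpha_idx),
                static_cast<int>(sample_idx),
                static_cast<int>(sample_idx * num_alphas + alpha_idx));
  }
  return 0;
}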
namespace {
template <typename Device, typename T>
class StatelessRandomGammaOp : public OpKernel {
public:
explicit StatelessRandomGammaOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// Sanitize input
const Tensor& shape_t = context->input(0);
const Tensor& seed_t = context->input(1);
TensorShape shape;
OP_REQUIRES_OK(context, tensor::MakeShape(shape_t, &shape));
OP_REQUIRES(context, seed_t.dims() == 1 && seed_t.dim_size(0) == 2,
errors::InvalidArgument("seed must have shape [2], not ",
seed_t.shape().DebugString()));
// Allocate output
Tensor* output;
OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output));
if (shape.num_elements() == 0) return;
random::PhiloxRandom::Key key;
random::PhiloxRandom::ResultType counter;
OP_REQUIRES_OK(context, GenerateKey(seed_t, &key, &counter));
// Fill in the random numbers
Fill(context, random::PhiloxRandom(counter, key), output);
}
private:
void Fill(OpKernelContext* ctx, random::PhiloxRandom random, Tensor* output) {
const Tensor& alpha_t = ctx->input(2);
TensorShape samples_shape = output->shape();
OP_REQUIRES(ctx, TensorShapeUtils::EndsWith(samples_shape, alpha_t.shape()),
errors::InvalidArgument(
"Shape passed in must end with broadcasted shape."));
const int64 num_alphas = alpha_t.NumElements();
OP_REQUIRES(ctx, num_alphas > 0,
errors::InvalidArgument(
"Input alpha should have non-zero element count, got: ",
num_alphas));
const int64 samples_per_alpha = samples_shape.num_elements() / num_alphas;
const auto alpha_flat = alpha_t.flat<T>().data();
auto samples_flat = output->flat<T>().data();
OP_REQUIRES_OK(ctx, functor::StatelessRandomGammaFunctor<Device, T>::Fill(
ctx, alpha_flat, num_alphas, samples_per_alpha,
random, samples_flat));
}
TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomGammaOp);
};
#define REGISTER_GAMMA(TYPE) \
REGISTER_KERNEL_BUILDER(Name("StatelessRandomGammaV2") \
.Device(DEVICE_CPU) \
.HostMemory("shape") \
.HostMemory("seed") \
.HostMemory("alpha") \
.TypeConstraint<TYPE>("dtype"), \
StatelessRandomGammaOp<CPUDevice, TYPE>)
TF_CALL_half(REGISTER_GAMMA);
TF_CALL_bfloat16(REGISTER_GAMMA);
TF_CALL_float(REGISTER_GAMMA);
TF_CALL_double(REGISTER_GAMMA);
#undef REGISTER_GAMMA
} // namespace
} // namespace tensorflow

tensorflow/core/kernels/stateless_random_gamma_op.h (new file)

@@ -0,0 +1,38 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_GAMMA_OP_H_
#define TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_GAMMA_OP_H_
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/random/philox_random.h"
namespace tensorflow {
namespace functor {
template <typename Device, typename T>
struct StatelessRandomGammaFunctor {
static Status Fill(OpKernelContext* ctx, const T* alpha_flat,
int64 num_alphas, int64 samples_per_alpha,
const random::PhiloxRandom& random, T* samples_flat);
};
} // namespace functor
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_GAMMA_OP_H_
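Putting the functor behind this device-templated declaration is what would let a non-CPU backend plug in later without touching the op wrapper; a hypothetical sketch of such a specialization follows (the file name stateless_random_gamma_op_gpu.cu.cc and its contents are assumptions, no GPU kernel is added by this commit):

// Hypothetical sketch only -- NOT part of this commit. Assumes a future
// stateless_random_gamma_op_gpu.cu.cc compiled under GOOGLE_CUDA / TENSORFLOW_USE_ROCM.
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/stateless_random_gamma_op.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace functor {

// The op wrapper only calls StatelessRandomGammaFunctor<Device, T>::Fill, so a
// device backend can be added by specializing the functor declared above.
template <typename T>
struct StatelessRandomGammaFunctor<Eigen::GpuDevice, T> {
  static Status Fill(OpKernelContext* ctx, const T* alpha_flat,
                     int64 num_alphas, int64 samples_per_alpha,
                     const random::PhiloxRandom& random, T* samples_flat) {
    // A real backend would launch a device kernel here; this stub merely
    // mirrors the interface declared in stateless_random_gamma_op.h.
    return errors::Unimplemented("stateless gamma sampling on GPU: sketch only");
  }
};

}  // namespace functor
}  // namespace tensorflow

A matching kernel registration (the REGISTER_GAMMA pattern from the .cc above, but with DEVICE_GPU) would also be needed; that part is omitted here.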

tensorflow/core/kernels/stateless_random_ops.cc

@@ -23,17 +23,6 @@ limitations under the License.
#include "tensorflow/core/kernels/random_poisson_op.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/work_sharder.h"
#if EIGEN_COMP_GNUC && __cplusplus > 199711L
#define DISABLE_FLOAT_EQUALITY_WARNING \
_Pragma("GCC diagnostic push") \
_Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
#else
#define DISABLE_FLOAT_EQUALITY_WARNING
#define ENABLE_FLOAT_EQUALITY_WARNING
#endif
namespace tensorflow {
@@ -212,163 +201,6 @@ class StatelessRandomPoissonOp : public StatelessRandomOpBase {
TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomPoissonOp);
};
template <typename Device, typename T>
class StatelessRandomGammaOp : public StatelessRandomOpBase {
public:
using StatelessRandomOpBase::StatelessRandomOpBase;
void Fill(OpKernelContext* ctx, random::PhiloxRandom random,
Tensor* output) override {
const Tensor& alpha_t = ctx->input(2);
TensorShape samples_shape = output->shape();
OP_REQUIRES(ctx, TensorShapeUtils::EndsWith(samples_shape, alpha_t.shape()),
errors::InvalidArgument(
"Shape passed in must end with broadcasted shape."));
typedef random::NormalDistribution<random::PhiloxRandom, double> Normal;
typedef random::UniformDistribution<random::PhiloxRandom, double> Uniform;
#define UNIFORM(X) \
if (uniform_remaining == 0) { \
uniform_remaining = Uniform::kResultElementCount; \
uniform_result = uniform(&gen); \
} \
uniform_remaining--; \
double X = uniform_result[uniform_remaining]
// Each attempt is 95+% successful, and requires 1-2 normal + 1 uniform
static constexpr int kReservedSamplesPerOutput = 256;
const int64 num_alphas = alpha_t.NumElements();
OP_REQUIRES(ctx, num_alphas > 0,
errors::InvalidArgument(
"Input alpha should have non-zero element count, got: ",
num_alphas));
const int64 samples_per_alpha = samples_shape.num_elements() / num_alphas;
const auto alpha_flat = alpha_t.flat<T>().data();
auto samples_flat = output->flat<T>().data();
// We partition work first across alphas then across samples-per-alpha to
// avoid a couple flops which can be done on a per-alpha basis.
auto DoWork = [samples_per_alpha, num_alphas, &random, samples_flat,
alpha_flat](int64 start_output, int64 limit_output) {
// Capturing "random" by-value would only make a copy for the _shared_
// lambda. Since we want to let each worker have its own copy, we pass
// "random" by reference and explicitly do a copy assignment.
using Eigen::numext::exp;
using Eigen::numext::log;
using Eigen::numext::log1p;
using Eigen::numext::pow;
Normal normal;
Uniform uniform;
typename Normal::ResultType norm_result;
typename Uniform::ResultType uniform_result;
for (int64 output_idx = start_output; output_idx < limit_output;
/* output_idx incremented within inner loop below */) {
int64 alpha_idx = output_idx / samples_per_alpha;
// Instead of +alpha_idx for each sample, we offset the pointer once.
T* const samples_alpha_offset = samples_flat + alpha_idx;
// Several calculations can be done on a per-alpha basis.
const double alpha = static_cast<double>(alpha_flat[alpha_idx]);
DISABLE_FLOAT_EQUALITY_WARNING
if (alpha == static_cast<double>(1.0)) {
ENABLE_FLOAT_EQUALITY_WARNING
// Sample from an exponential distribution.
for (int64 sample_idx = output_idx % samples_per_alpha;
sample_idx < samples_per_alpha && output_idx < limit_output;
sample_idx++, output_idx++) {
// As we want data stable regardless of sharding
// (including eventually on GPU), we skip on a per-sample basis.
random::PhiloxRandom gen = random;
gen.Skip(kReservedSamplesPerOutput * output_idx);
int16 uniform_remaining = 0;
UNIFORM(u);
const double res = -log1p(-u);
samples_alpha_offset[sample_idx * num_alphas] = static_cast<T>(res);
} // for (sample_idx)
} else { // if alpha != 1.0
// Transformation-rejection from pairs of uniform and normal random
// variables. http://dl.acm.org/citation.cfm?id=358414
//
// The algorithm has an acceptance rate of ~95% for small alpha (~1),
// and higher accept rates for higher alpha, so runtime is
// O(NumAlphas * NumSamples * k) with k ~ 1 / 0.95.
//
// For alpha<1, we add one to d=alpha-1/3, and multiply the final
// result by uniform()^(1/alpha)
const bool alpha_less_than_one = alpha < 1;
const double d = alpha + (alpha_less_than_one ? 2.0 / 3 : -1.0 / 3);
const double c = 1.0 / 3 / sqrt(d);
// Compute the rest of the samples for the current alpha value.
for (int64 sample_idx = output_idx % samples_per_alpha;
sample_idx < samples_per_alpha && output_idx < limit_output;
sample_idx++, output_idx++) {
// Since each sample may use a variable number of normal/uniform
// samples, and we want data stable regardless of sharding
// (including eventually on GPU), we skip on a per-sample basis.
random::PhiloxRandom gen = random;
gen.Skip(kReservedSamplesPerOutput * output_idx);
int16 norm_remaining = 0;
int16 uniform_remaining = 0;
// Keep trying until we don't reject a sample. In practice, we will
// only reject ~5% at worst, for low alpha near 1.
while (true) {
if (norm_remaining == 0) {
norm_remaining = Normal::kResultElementCount;
norm_result = normal(&gen);
}
norm_remaining--;
const double x = norm_result[norm_remaining];
double v = 1 + c * x;
if (v <= 0) {
continue;
}
v = v * v * v;
UNIFORM(u);
// The first option in the if is a "squeeze" short-circuit to
// dodge the two logs. Magic constant sourced from the paper
// linked above. Upward of .91 of the area covered by the log
// inequality is covered by the squeeze as well (larger coverage
// for smaller values of alpha).
if ((u < 1 - 0.0331 * (x * x) * (x * x)) ||
(log(u) < 0.5 * x * x + d * (1 - v + log(v)))) {
double res = d * v;
if (alpha_less_than_one) {
UNIFORM(b);
res *= pow(b, 1 / alpha);
}
samples_alpha_offset[sample_idx * num_alphas] =
static_cast<T>(res);
break;
}
} // while: true
} // for: sample_idx
} // if (alpha == 1.0)
} // for: output_idx
}; // DoWork
#undef UNIFORM
// Two calls to log only occur for ~10% of samples reaching the log line.
// 2 x 100 (64-bit cycles per log) x 0.10 = ~20.
// Other ops: sqrt, +, *, /, %... something like 15 of these, at 3-6 cycles
// each = ~60.
// All of this /0.95 due to the rejection possibility = ~85.
static const int kElementCost = 85 + 2 * Normal::kElementCost +
Uniform::kElementCost +
3 * random::PhiloxRandom::kElementCost;
auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
Shard(worker_threads.num_threads, worker_threads.workers,
num_alphas * samples_per_alpha, kElementCost, DoWork);
}
};
#define REGISTER(DEVICE, TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("StatelessRandomUniform") \
@@ -459,22 +291,6 @@ TF_CALL_int64(REGISTER_ALL_POISSON);
#undef REGISTER_ALL_POISSON
#undef REGISTER_POISSON
#define REGISTER_GAMMA(TYPE) \
REGISTER_KERNEL_BUILDER(Name("StatelessRandomGammaV2") \
.Device(DEVICE_CPU) \
.HostMemory("shape") \
.HostMemory("seed") \
.HostMemory("alpha") \
.TypeConstraint<TYPE>("dtype"), \
StatelessRandomGammaOp<CPUDevice, TYPE>)
TF_CALL_half(REGISTER_GAMMA);
TF_CALL_bfloat16(REGISTER_GAMMA);
TF_CALL_float(REGISTER_GAMMA);
TF_CALL_double(REGISTER_GAMMA);
#undef REGISTER_GAMMA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TF_CALL_half(REGISTER_GPU);