Add GPU kernel for tf.random.stateless_gamma.

PiperOrigin-RevId: 351972885
Change-Id: Iabef759abd14aabc6e4c7f9783ba5899d824b40c
This commit is contained in:
Matej Rizman 2021-01-15 02:56:56 -08:00 committed by TensorFlower Gardener
parent 1059e04b92
commit 489133d42a
4 changed files with 295 additions and 72 deletions

View File

@ -47,53 +47,7 @@ namespace {
static constexpr int kReservedSamplesPerOutput = 256;
typedef Eigen::ThreadPoolDevice CPUDevice;
// Buffer that holds multiple samples. Operator()(random::PhiloxRandom*) returns
// a single sample from this buffer. If the buffer is empty, it first generates
// new samples using the provided distribution.
//
// If the call to Distribution::operator() returns samples[0...N-1], then this
// class returns samples in the following order:
//
// samples[N-1], samples[N-2],..., samples[1], samples[0]
//
// For comparison, random::SingleSampleAdapter returns samples in
// the following order:
//
// samples[0], samples[1],...,samples[N-2], samples[N-1].
//
template <class Distribution>
class SampleBuffer {
public:
typedef typename Distribution::ResultElementType ResultElementType;
PHILOX_DEVICE_INLINE
explicit SampleBuffer(Distribution* distribution)
: distribution_(distribution), remaining_numbers_(0) {}
PHILOX_DEVICE_INLINE
ResultElementType operator()(random::PhiloxRandom* random) {
if (remaining_numbers_ == 0) {
results_ = (*distribution_)(random);
remaining_numbers_ = Distribution::kResultElementCount;
}
remaining_numbers_--;
return results_[remaining_numbers_];
}
// Mark this buffer as empty. The next call to operator() will fill it
// with new random numbers.
PHILOX_DEVICE_INLINE
void Clear() { remaining_numbers_ = 0; }
private:
typedef typename Distribution::ResultType ResultType;
Distribution* distribution_;
ResultType results_;
int remaining_numbers_;
};
typedef Eigen::GpuDevice GPUDevice;
}; // namespace
@ -102,7 +56,8 @@ namespace functor {
template <typename T>
struct StatelessRandomGammaFunctor<CPUDevice, T> {
static Status Fill(OpKernelContext* ctx, const T* alpha_flat,
int64 num_alphas, int64 samples_per_alpha,
int64 num_samples, int64 num_alphas,
int64 samples_per_alpha,
const random::PhiloxRandom& random, T* samples_flat) {
typedef random::NormalDistribution<random::PhiloxRandom, double> Normal;
typedef random::UniformDistribution<random::PhiloxRandom, double> Uniform;
@ -124,8 +79,8 @@ struct StatelessRandomGammaFunctor<CPUDevice, T> {
Normal normal;
Uniform uniform;
SampleBuffer<Normal> normal_buffer(&normal);
SampleBuffer<Uniform> uniform_buffer(&uniform);
RandomSampleBuffer<Normal> normal_buffer(&normal);
RandomSampleBuffer<Uniform> uniform_buffer(&uniform);
for (int64 output_idx = start_output; output_idx < limit_output;
/* output_idx incremented within inner loop below */) {
@ -138,14 +93,14 @@ struct StatelessRandomGammaFunctor<CPUDevice, T> {
const double alpha = static_cast<double>(alpha_flat[alpha_idx]);
DISABLE_FLOAT_EQUALITY_WARNING
if (alpha == static_cast<double>(1.0)) {
if (alpha == 1.0) {
ENABLE_FLOAT_EQUALITY_WARNING
// Sample from an exponential distribution.
for (int64 sample_idx = output_idx % samples_per_alpha;
sample_idx < samples_per_alpha && output_idx < limit_output;
sample_idx++, output_idx++) {
// As we want data stable regardless of sharding
// (including eventually on GPU), we skip on a per-sample basis.
// As we want data stable regardless of sharding, we skip on a
// per-sample basis.
random::PhiloxRandom gen = random;
gen.Skip(kReservedSamplesPerOutput * output_idx);
double u = uniform(&gen)[Uniform::kResultElementCount - 1];
@ -162,7 +117,7 @@ struct StatelessRandomGammaFunctor<CPUDevice, T> {
//
// For alpha<1, we add one to d=alpha-1/3, and multiply the final
// result by uniform()^(1/alpha)
const bool alpha_less_than_one = alpha < 1;
const bool alpha_less_than_one = alpha < 1.0;
const double d = alpha + (alpha_less_than_one ? 2.0 / 3 : -1.0 / 3);
const double c = 1.0 / 3 / sqrt(d);
@ -171,8 +126,8 @@ struct StatelessRandomGammaFunctor<CPUDevice, T> {
sample_idx < samples_per_alpha && output_idx < limit_output;
sample_idx++, output_idx++) {
// Since each sample may use a variable number of normal/uniform
// samples, and we want data stable regardless of sharding
// (including eventually on GPU), we skip on a per-sample basis.
// samples, and we want data stable regardless of sharding, we skip
// on a per-sample basis.
random::PhiloxRandom gen = random;
gen.Skip(kReservedSamplesPerOutput * output_idx);
@ -226,8 +181,8 @@ struct StatelessRandomGammaFunctor<CPUDevice, T> {
Uniform::kElementCost +
3 * random::PhiloxRandom::kElementCost;
auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
Shard(worker_threads.num_threads, worker_threads.workers,
num_alphas * samples_per_alpha, kElementCost, DoWork);
Shard(worker_threads.num_threads, worker_threads.workers, num_samples,
kElementCost, DoWork);
return Status::OK();
}
};
@ -280,19 +235,21 @@ class StatelessRandomGammaOp : public OpKernel {
"Input alpha should have non-zero element count, got: ",
num_alphas));
const int64 samples_per_alpha = samples_shape.num_elements() / num_alphas;
const int64 num_samples = samples_shape.num_elements();
const int64 samples_per_alpha = num_samples / num_alphas;
const auto alpha_flat = alpha_t.flat<T>().data();
auto samples_flat = output->flat<T>().data();
OP_REQUIRES_OK(ctx, functor::StatelessRandomGammaFunctor<Device, T>::Fill(
ctx, alpha_flat, num_alphas, samples_per_alpha,
random, samples_flat));
ctx, alpha_flat, num_samples, num_alphas,
samples_per_alpha, random, samples_flat));
}
TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomGammaOp);
};
#define REGISTER_GAMMA(TYPE) \
// Register CPU kernels for stateless gamma op.
#define REGISTER_GAMMA_CPU(TYPE) \
REGISTER_KERNEL_BUILDER(Name("StatelessRandomGammaV2") \
.Device(DEVICE_CPU) \
.HostMemory("shape") \
@ -301,12 +258,32 @@ class StatelessRandomGammaOp : public OpKernel {
.TypeConstraint<TYPE>("dtype"), \
StatelessRandomGammaOp<CPUDevice, TYPE>)
TF_CALL_half(REGISTER_GAMMA);
TF_CALL_bfloat16(REGISTER_GAMMA);
TF_CALL_float(REGISTER_GAMMA);
TF_CALL_double(REGISTER_GAMMA);
TF_CALL_half(REGISTER_GAMMA_CPU);
TF_CALL_bfloat16(REGISTER_GAMMA_CPU);
TF_CALL_float(REGISTER_GAMMA_CPU);
TF_CALL_double(REGISTER_GAMMA_CPU);
#undef REGISTER_GAMMA
#undef REGISTER_GAMMA_CPU
// Register GPU kernels for stateless gamma op.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GAMMA_GPU(TYPE) \
REGISTER_KERNEL_BUILDER(Name("StatelessRandomGammaV2") \
.Device(DEVICE_GPU) \
.HostMemory("shape") \
.HostMemory("seed") \
.TypeConstraint<TYPE>("dtype"), \
StatelessRandomGammaOp<GPUDevice, TYPE>)
TF_CALL_half(REGISTER_GAMMA_GPU);
TF_CALL_bfloat16(REGISTER_GAMMA_GPU);
TF_CALL_float(REGISTER_GAMMA_GPU);
TF_CALL_double(REGISTER_GAMMA_GPU);
#undef REGISTER_GAMMA_GPU
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace
} // namespace tensorflow

View File

@ -27,12 +27,60 @@ namespace functor {
template <typename Device, typename T>
struct StatelessRandomGammaFunctor {
static Status Fill(OpKernelContext* ctx, const T* alpha_flat,
int64 num_alphas, int64 samples_per_alpha,
int64 num_samples, int64 num_alphas,
int64 samples_per_alpha,
const random::PhiloxRandom& random, T* samples_flat);
};
} // namespace functor
// Buffer that holds multiple samples. Operator()(random::PhiloxRandom*) returns
// a single sample from this buffer. If the buffer is empty, it first generates
// new samples using the provided distribution.
//
// If the call to Distribution::operator() returns samples[0...N-1], then this
// class returns samples in the following order:
//
// samples[N-1], samples[N-2],..., samples[1], samples[0]
//
// For comparison, random::SingleSampleAdapter returns samples in
// the following order:
//
// samples[0], samples[1],...,samples[N-2], samples[N-1].
//
template <class Distribution>
class RandomSampleBuffer {
public:
typedef typename Distribution::ResultElementType ResultElementType;
PHILOX_DEVICE_INLINE
explicit RandomSampleBuffer(Distribution* distribution)
: distribution_(distribution), remaining_numbers_(0) {}
PHILOX_DEVICE_INLINE
ResultElementType operator()(random::PhiloxRandom* random) {
if (remaining_numbers_ == 0) {
results_ = (*distribution_)(random);
remaining_numbers_ = Distribution::kResultElementCount;
}
remaining_numbers_--;
return results_[remaining_numbers_];
}
// Mark this buffer as empty. The next call to operator() will fill it
// with new random numbers.
PHILOX_DEVICE_INLINE
void Clear() { remaining_numbers_ = 0; }
private:
typedef typename Distribution::ResultType ResultType;
Distribution* distribution_;
ResultType results_;
int remaining_numbers_;
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_GAMMA_OP_H_

View File

@ -0,0 +1,179 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/stateless_random_gamma_op.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
namespace tensorflow {
namespace {
typedef Eigen::GpuDevice GPUDevice;
// Each attempt to generate a new draw from the Gamma distribution is 95+%
// successful, and requires 1-2 normal + 1 uniform sample.
static constexpr int kReservedSamplesPerOutput = 256;
#if EIGEN_COMP_GNUC && __cplusplus > 199711L
#define DISABLE_FLOAT_EQUALITY_WARNING \
_Pragma("GCC diagnostic push") \
_Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
#else
#define DISABLE_FLOAT_EQUALITY_WARNING
#define ENABLE_FLOAT_EQUALITY_WARNING
#endif
template <typename T>
__global__ void __launch_bounds__(1024)
FillKernel(int64 num_samples, int64 num_alphas, int64 samples_per_alpha,
random::PhiloxRandom random, T* samples_flat,
const T* alpha_flat) {
using Eigen::numext::exp;
using Eigen::numext::log;
using Eigen::numext::log1p;
using Eigen::numext::pow;
typedef random::NormalDistribution<random::PhiloxRandom, double> Normal;
typedef random::UniformDistribution<random::PhiloxRandom, double> Uniform;
Normal normal;
Uniform uniform;
RandomSampleBuffer<Normal> normal_buffer(&normal);
RandomSampleBuffer<Uniform> uniform_buffer(&uniform);
for (int64 output_idx : GpuGridRangeX(num_samples)) {
int64 alpha_idx = output_idx / samples_per_alpha;
int64 sample_idx = output_idx % samples_per_alpha;
const double alpha = static_cast<double>(alpha_flat[alpha_idx]);
DISABLE_FLOAT_EQUALITY_WARNING
if (alpha == 1.0) {
ENABLE_FLOAT_EQUALITY_WARNING
// Sample from an exponential distribution.
// As we want data stable regardless of sharding, we skip on a per-sample
// basis.
random::PhiloxRandom gen = random;
gen.Skip(kReservedSamplesPerOutput * output_idx);
double u = uniform(&gen)[Uniform::kResultElementCount - 1];
const double res = -log1p(-u);
// We use alpha_idx + sample_idx * num_alphas instead of output_idx
// to generate numbers in the right order (CPU and GPU kernels
// must generate numbers in the same order).
samples_flat[alpha_idx + sample_idx * num_alphas] = static_cast<T>(res);
} else { // if alpha != 1.0
// Transformation-rejection from pairs of uniform and normal random
// variables. http://dl.acm.org/citation.cfm?id=358414
//
// The algorithm has an acceptance rate of ~95% for small alpha (~1),
// and higher accept rates for higher alpha, so runtime is
// O(NumAlphas * NumSamples * k) with k ~ 1 / 0.95.
//
// For alpha<1, we add one to d=alpha-1/3, and multiply the final
// result by uniform()^(1/alpha)
const bool alpha_less_than_one = alpha < 1.0;
const double d = alpha + (alpha_less_than_one ? 2.0 / 3 : -1.0 / 3);
const double c = 1.0 / 3 / sqrt(d);
// Since each sample may use a variable number of normal/uniform
// samples, and we want data stable regardless of sharding, we skip on a
// per-sample basis.
random::PhiloxRandom gen = random;
gen.Skip(kReservedSamplesPerOutput * output_idx);
// To prevent overwriting SampleBuffer's underlying array with
// zeros (in tensorflow::random::Array constructor), we just mark
// the buffer as empty instead of initializing a new SampleBuffer
// object here. The next call to operator() will fill the buffer
// with new numbers.
normal_buffer.Clear();
uniform_buffer.Clear();
// Keep trying until we don't reject a sample. In practice, we will
// only reject ~5% at worst, for low alpha near 1.
while (true) {
const double x = normal_buffer(&gen);
double v = 1 + c * x;
if (v <= 0) {
continue;
}
v = v * v * v;
double u = uniform_buffer(&gen);
// The first option in the if is a "squeeze" short-circuit to
// dodge the two logs. Magic constant sourced from the paper
// linked above. Upward of .91 of the area covered by the log
// inequality is covered by the squeeze as well (larger coverage
// for smaller values of alpha).
if ((u < 1 - 0.0331 * (x * x) * (x * x)) ||
(log(u) < 0.5 * x * x + d * (1 - v + log(v)))) {
double res = d * v;
if (alpha_less_than_one) {
double b = uniform_buffer(&gen);
res *= pow(b, 1 / alpha);
}
// We use alpha_idx + sample_idx * num_alphas instead of output_idx
// to generate numbers in the right order (CPU and GPU kernels
// must generate numbers in the same order).
samples_flat[alpha_idx + sample_idx * num_alphas] =
static_cast<T>(res);
break;
}
} // while: true
} // if (alpha == 1.0)
} // for: output_idx
}
} // namespace
namespace functor {
template <typename T>
struct StatelessRandomGammaFunctor<GPUDevice, T> {
static Status Fill(OpKernelContext* ctx, const T* alpha_flat,
int64 num_samples, int64 num_alphas,
int64 samples_per_alpha,
const random::PhiloxRandom& random, T* samples_flat) {
const GPUDevice& d = ctx->eigen_device<GPUDevice>();
GpuLaunchConfig cfg = GetGpuLaunchConfig(num_samples, d);
TF_CHECK_OK(GpuLaunchKernel(FillKernel<T>, cfg.block_count,
cfg.thread_per_block, 0, d.stream(),
num_samples, num_alphas, samples_per_alpha,
random, samples_flat, alpha_flat));
return Status::OK();
}
};
} // namespace functor
#define REGISTER_GPU_SPEC(type) \
template struct functor::StatelessRandomGammaFunctor<GPUDevice, type>;
TF_CALL_half(REGISTER_GPU_SPEC);
TF_CALL_bfloat16(REGISTER_GPU_SPEC);
TF_CALL_float(REGISTER_GPU_SPEC);
TF_CALL_double(REGISTER_GPU_SPEC);
#undef REGISTER_GPU_SPEC
} // namespace tensorflow
#endif // GOOGLE_CUDA

View File

@ -216,6 +216,17 @@ class StatelessOpsTest(test.TestCase, parameterized.TestCase):
pure = stateless_op(seed=preseed)
self.assertAllEqual(stateful, pure)
def _test_match_stateless_cpu_gpu(self, case, seed):
# Stateless ops should produce the same result on CPUs and GPUs.
_, stateless_op, _ = case
with ops.device('CPU'):
result_cpu = stateless_op(seed=seed)
with ops.device(get_device().name):
result_gpu = stateless_op(seed=seed)
self.assertAllClose(result_cpu, result_gpu)
def _test_old_and_new_stateless_match(self, case, seed):
"""Tests that the new stateless ops match the old stateless ones."""
with ops.device(get_device().name):
@ -306,6 +317,18 @@ class StatelessOpsTest(test.TestCase, parameterized.TestCase):
self.skipTest('Lacking XLA kernel')
self._test_match(case, seed)
@parameterized.named_parameters(
('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed) # pylint: disable=g-complex-comprehension
for seed_id, seed in enumerate(SEEDS)
for case_id, case in enumerate(gamma_cases()))
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
def testStatelessGammaCpuGpuMatch(self, case, seed):
if get_device().device_type != 'GPU':
# This test compares the numbers produced by the CPU and GPU kernel for
# stateless_random_gamma.
self.skipTest('This test requires GPU')
self._test_match_stateless_cpu_gpu(case, seed)
@parameterized.named_parameters(
('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed) # pylint: disable=g-complex-comprehension
for seed_id, seed in enumerate(SEEDS)
@ -387,10 +410,6 @@ class StatelessOpsTest(test.TestCase, parameterized.TestCase):
for case_id, case in enumerate(gamma_cases()))
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
def testDeterminismGamma(self, case, seed_type):
if get_device().device_type == 'GPU':
# This test was passing before because soft placement silently picked the
# CPU kernels.
self.skipTest('Lacking GPU kernel')
if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
# This test was passing before because soft placement silently picked the
# CPU kernels.