Add GPU kernel for tf.random.stateless_gamma.

PiperOrigin-RevId: 351972885
Change-Id: Iabef759abd14aabc6e4c7f9783ba5899d824b40c
Matej Rizman 2021-01-15 02:56:56 -08:00 committed by TensorFlower Gardener
parent 1059e04b92
commit 489133d42a
4 changed files with 295 additions and 72 deletions
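With this change tf.random.stateless_gamma can be placed directly on GPU devices instead of silently requiring the CPU kernel. A minimal usage sketch of the public API this commit targets; the seed, alpha values, and device string below are arbitrary examples and assume a GPU is visible:

import tensorflow as tf

seed = [1, 2]                          # arbitrary example seed
alpha = tf.constant([0.5, 1.0, 2.5])   # one gamma concentration per output column

with tf.device('/GPU:0'):              # assumes a GPU is available
    samples = tf.random.stateless_gamma(
        shape=[10, 3], seed=seed, alpha=alpha, dtype=tf.float32)

print(samples.shape)  # (10, 3): 10 draws for each of the 3 alphas

Stateless ops are deterministic, so the same (shape, seed, alpha) always yields the same draws; the new test below additionally checks that CPU and GPU placements agree.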


@@ -47,53 +47,7 @@ namespace {

 static constexpr int kReservedSamplesPerOutput = 256;

 typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;

-// Buffer that holds multiple samples. Operator()(random::PhiloxRandom*) returns
-// a single sample from this buffer. If the buffer is empty, it first generates
-// new samples using the provided distribution.
-//
-// If the call to Distribution::operator() returns samples[0...N-1], then this
-// class returns samples in the following order:
-//
-//   samples[N-1], samples[N-2],..., samples[1], samples[0]
-//
-// For comparison, random::SingleSampleAdapter returns samples in
-// the following order:
-//
-//   samples[0], samples[1],...,samples[N-2], samples[N-1].
-//
-template <class Distribution>
-class SampleBuffer {
- public:
-  typedef typename Distribution::ResultElementType ResultElementType;
-
-  PHILOX_DEVICE_INLINE
-  explicit SampleBuffer(Distribution* distribution)
-      : distribution_(distribution), remaining_numbers_(0) {}
-
-  PHILOX_DEVICE_INLINE
-  ResultElementType operator()(random::PhiloxRandom* random) {
-    if (remaining_numbers_ == 0) {
-      results_ = (*distribution_)(random);
-      remaining_numbers_ = Distribution::kResultElementCount;
-    }
-
-    remaining_numbers_--;
-    return results_[remaining_numbers_];
-  }
-
-  // Mark this buffer as empty. The next call to operator() will fill it
-  // with new random numbers.
-  PHILOX_DEVICE_INLINE
-  void Clear() { remaining_numbers_ = 0; }
-
- private:
-  typedef typename Distribution::ResultType ResultType;
-
-  Distribution* distribution_;
-  ResultType results_;
-  int remaining_numbers_;
-};
-
 };  // namespace
@@ -102,7 +56,8 @@ namespace functor {

 template <typename T>
 struct StatelessRandomGammaFunctor<CPUDevice, T> {
   static Status Fill(OpKernelContext* ctx, const T* alpha_flat,
-                     int64 num_alphas, int64 samples_per_alpha,
+                     int64 num_samples, int64 num_alphas,
+                     int64 samples_per_alpha,
                      const random::PhiloxRandom& random, T* samples_flat) {
     typedef random::NormalDistribution<random::PhiloxRandom, double> Normal;
     typedef random::UniformDistribution<random::PhiloxRandom, double> Uniform;
@@ -124,8 +79,8 @@ struct StatelessRandomGammaFunctor<CPUDevice, T> {
       Normal normal;
       Uniform uniform;
-      SampleBuffer<Normal> normal_buffer(&normal);
-      SampleBuffer<Uniform> uniform_buffer(&uniform);
+      RandomSampleBuffer<Normal> normal_buffer(&normal);
+      RandomSampleBuffer<Uniform> uniform_buffer(&uniform);

       for (int64 output_idx = start_output; output_idx < limit_output;
            /* output_idx incremented within inner loop below */) {
@@ -138,14 +93,14 @@ struct StatelessRandomGammaFunctor<CPUDevice, T> {
         const double alpha = static_cast<double>(alpha_flat[alpha_idx]);

         DISABLE_FLOAT_EQUALITY_WARNING
-        if (alpha == static_cast<double>(1.0)) {
+        if (alpha == 1.0) {
           ENABLE_FLOAT_EQUALITY_WARNING
           // Sample from an exponential distribution.
           for (int64 sample_idx = output_idx % samples_per_alpha;
                sample_idx < samples_per_alpha && output_idx < limit_output;
                sample_idx++, output_idx++) {
-            // As we want data stable regardless of sharding
-            // (including eventually on GPU), we skip on a per-sample basis.
+            // As we want data stable regardless of sharding, we skip on a
+            // per-sample basis.
             random::PhiloxRandom gen = random;
             gen.Skip(kReservedSamplesPerOutput * output_idx);
             double u = uniform(&gen)[Uniform::kResultElementCount - 1];
@@ -162,7 +117,7 @@ struct StatelessRandomGammaFunctor<CPUDevice, T> {
           //
           // For alpha<1, we add one to d=alpha-1/3, and multiply the final
           // result by uniform()^(1/alpha)
-          const bool alpha_less_than_one = alpha < 1;
+          const bool alpha_less_than_one = alpha < 1.0;
           const double d = alpha + (alpha_less_than_one ? 2.0 / 3 : -1.0 / 3);
           const double c = 1.0 / 3 / sqrt(d);
@@ -171,8 +126,8 @@ struct StatelessRandomGammaFunctor<CPUDevice, T> {
                sample_idx < samples_per_alpha && output_idx < limit_output;
                sample_idx++, output_idx++) {
             // Since each sample may use a variable number of normal/uniform
-            // samples, and we want data stable regardless of sharding
-            // (including eventually on GPU), we skip on a per-sample basis.
+            // samples, and we want data stable regardless of sharding, we skip
+            // on a per-sample basis.
             random::PhiloxRandom gen = random;
             gen.Skip(kReservedSamplesPerOutput * output_idx);
@@ -226,8 +181,8 @@ struct StatelessRandomGammaFunctor<CPUDevice, T> {
                                    Uniform::kElementCost +
                                    3 * random::PhiloxRandom::kElementCost;
     auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
-    Shard(worker_threads.num_threads, worker_threads.workers,
-          num_alphas * samples_per_alpha, kElementCost, DoWork);
+    Shard(worker_threads.num_threads, worker_threads.workers, num_samples,
+          kElementCost, DoWork);

     return Status::OK();
   }
 };
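Both the CPU functor above and the GPU kernel added below implement the Marsaglia-Tsang transformation-rejection sampler that the inline comments describe: d = alpha - 1/3, c = 1/(3*sqrt(d)), a squeeze test before the exact log test, and a uniform()^(1/alpha) boost when alpha < 1. A rough NumPy illustration of that recurrence; this is an explanatory sketch, not the Philox-based kernel code:

import numpy as np

def sample_gamma_one(alpha, rng):
    # One Gamma(alpha, 1) draw via Marsaglia-Tsang transformation-rejection.
    if alpha == 1.0:
        # Exponential special case, mirroring the kernel: -log1p(-u).
        return -np.log1p(-rng.uniform())
    boost = 1.0
    if alpha < 1.0:
        # Draw from Gamma(alpha + 1) instead, then multiply by uniform()**(1/alpha).
        boost = rng.uniform() ** (1.0 / alpha)
        alpha = alpha + 1.0
    d = alpha - 1.0 / 3.0
    c = 1.0 / 3.0 / np.sqrt(d)
    while True:  # ~95% acceptance near alpha == 1, higher for larger alpha
        x = rng.normal()
        v = 1.0 + c * x
        if v <= 0.0:
            continue
        v = v * v * v
        u = rng.uniform()
        # Squeeze test first to avoid the logs, then the exact acceptance test.
        if (u < 1.0 - 0.0331 * (x * x) * (x * x) or
                np.log(u) < 0.5 * x * x + d * (1.0 - v + np.log(v))):
            return boost * d * v

# Example: sample_gamma_one(2.5, np.random.default_rng(0))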
@@ -280,19 +235,21 @@ class StatelessRandomGammaOp : public OpKernel {
                     "Input alpha should have non-zero element count, got: ",
                     num_alphas));

-    const int64 samples_per_alpha = samples_shape.num_elements() / num_alphas;
+    const int64 num_samples = samples_shape.num_elements();
+    const int64 samples_per_alpha = num_samples / num_alphas;
     const auto alpha_flat = alpha_t.flat<T>().data();
     auto samples_flat = output->flat<T>().data();

     OP_REQUIRES_OK(ctx, functor::StatelessRandomGammaFunctor<Device, T>::Fill(
-                            ctx, alpha_flat, num_alphas, samples_per_alpha,
-                            random, samples_flat));
+                            ctx, alpha_flat, num_samples, num_alphas,
+                            samples_per_alpha, random, samples_flat));
   }

   TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomGammaOp);
 };

-#define REGISTER_GAMMA(TYPE)                              \
+// Register CPU kernels for stateless gamma op.
+#define REGISTER_GAMMA_CPU(TYPE)                          \
   REGISTER_KERNEL_BUILDER(Name("StatelessRandomGammaV2")  \
                               .Device(DEVICE_CPU)         \
                               .HostMemory("shape")        \
@@ -301,12 +258,32 @@ class StatelessRandomGammaOp : public OpKernel {
                               .TypeConstraint<TYPE>("dtype"), \
                           StatelessRandomGammaOp<CPUDevice, TYPE>)

-TF_CALL_half(REGISTER_GAMMA);
-TF_CALL_bfloat16(REGISTER_GAMMA);
-TF_CALL_float(REGISTER_GAMMA);
-TF_CALL_double(REGISTER_GAMMA);
+TF_CALL_half(REGISTER_GAMMA_CPU);
+TF_CALL_bfloat16(REGISTER_GAMMA_CPU);
+TF_CALL_float(REGISTER_GAMMA_CPU);
+TF_CALL_double(REGISTER_GAMMA_CPU);

-#undef REGISTER_GAMMA
+#undef REGISTER_GAMMA_CPU
+
+// Register GPU kernels for stateless gamma op.
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
+#define REGISTER_GAMMA_GPU(TYPE)                             \
+  REGISTER_KERNEL_BUILDER(Name("StatelessRandomGammaV2")     \
+                              .Device(DEVICE_GPU)            \
+                              .HostMemory("shape")           \
+                              .HostMemory("seed")            \
+                              .TypeConstraint<TYPE>("dtype"), \
+                          StatelessRandomGammaOp<GPUDevice, TYPE>)
+
+TF_CALL_half(REGISTER_GAMMA_GPU);
+TF_CALL_bfloat16(REGISTER_GAMMA_GPU);
+TF_CALL_float(REGISTER_GAMMA_GPU);
+TF_CALL_double(REGISTER_GAMMA_GPU);
+
+#undef REGISTER_GAMMA_GPU
+
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

 }  // namespace
 }  // namespace tensorflow


@@ -27,12 +27,60 @@ namespace functor {

 template <typename Device, typename T>
 struct StatelessRandomGammaFunctor {
   static Status Fill(OpKernelContext* ctx, const T* alpha_flat,
-                     int64 num_alphas, int64 samples_per_alpha,
+                     int64 num_samples, int64 num_alphas,
+                     int64 samples_per_alpha,
                      const random::PhiloxRandom& random, T* samples_flat);
 };

 }  // namespace functor

+// Buffer that holds multiple samples. Operator()(random::PhiloxRandom*) returns
+// a single sample from this buffer. If the buffer is empty, it first generates
+// new samples using the provided distribution.
+//
+// If the call to Distribution::operator() returns samples[0...N-1], then this
+// class returns samples in the following order:
+//
+//   samples[N-1], samples[N-2],..., samples[1], samples[0]
+//
+// For comparison, random::SingleSampleAdapter returns samples in
+// the following order:
+//
+//   samples[0], samples[1],...,samples[N-2], samples[N-1].
+//
+template <class Distribution>
+class RandomSampleBuffer {
+ public:
+  typedef typename Distribution::ResultElementType ResultElementType;
+
+  PHILOX_DEVICE_INLINE
+  explicit RandomSampleBuffer(Distribution* distribution)
+      : distribution_(distribution), remaining_numbers_(0) {}
+
+  PHILOX_DEVICE_INLINE
+  ResultElementType operator()(random::PhiloxRandom* random) {
+    if (remaining_numbers_ == 0) {
+      results_ = (*distribution_)(random);
+      remaining_numbers_ = Distribution::kResultElementCount;
+    }
+
+    remaining_numbers_--;
+    return results_[remaining_numbers_];
+  }
+
+  // Mark this buffer as empty. The next call to operator() will fill it
+  // with new random numbers.
+  PHILOX_DEVICE_INLINE
+  void Clear() { remaining_numbers_ = 0; }
+
+ private:
+  typedef typename Distribution::ResultType ResultType;
+
+  Distribution* distribution_;
+  ResultType results_;
+  int remaining_numbers_;
+};
+
 }  // namespace tensorflow

 #endif  // TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_GAMMA_OP_H_
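RandomSampleBuffer batches the draws produced by the wrapped Distribution and then hands them out one at a time in reverse order, as the header comment spells out. A small Python analogue of that buffering behaviour; the names here are purely illustrative:

class ReversedSampleBuffer:
    # Python analogue of RandomSampleBuffer: draw a batch once, pop from the end.

    def __init__(self, draw_batch):
        self._draw_batch = draw_batch  # callable returning a list of samples
        self._results = []
        self._remaining = 0

    def __call__(self):
        if self._remaining == 0:
            self._results = self._draw_batch()
            self._remaining = len(self._results)
        self._remaining -= 1
        return self._results[self._remaining]

    def clear(self):
        # Mark the buffer empty; the next call refills it.
        self._remaining = 0

buf = ReversedSampleBuffer(lambda: [0, 1, 2, 3])
print([buf() for _ in range(4)])  # [3, 2, 1, 0], i.e. the batch in reverse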


@@ -0,0 +1,179 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/stateless_random_gamma_op.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
namespace tensorflow {
namespace {
typedef Eigen::GpuDevice GPUDevice;
// Each attempt to generate a new draw from the Gamma distribution is 95+%
// successful, and requires 1-2 normal + 1 uniform sample.
static constexpr int kReservedSamplesPerOutput = 256;
#if EIGEN_COMP_GNUC && __cplusplus > 199711L
#define DISABLE_FLOAT_EQUALITY_WARNING \
_Pragma("GCC diagnostic push") \
_Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
#else
#define DISABLE_FLOAT_EQUALITY_WARNING
#define ENABLE_FLOAT_EQUALITY_WARNING
#endif
template <typename T>
__global__ void __launch_bounds__(1024)
    FillKernel(int64 num_samples, int64 num_alphas, int64 samples_per_alpha,
               random::PhiloxRandom random, T* samples_flat,
               const T* alpha_flat) {
  using Eigen::numext::exp;
  using Eigen::numext::log;
  using Eigen::numext::log1p;
  using Eigen::numext::pow;

  typedef random::NormalDistribution<random::PhiloxRandom, double> Normal;
  typedef random::UniformDistribution<random::PhiloxRandom, double> Uniform;

  Normal normal;
  Uniform uniform;
  RandomSampleBuffer<Normal> normal_buffer(&normal);
  RandomSampleBuffer<Uniform> uniform_buffer(&uniform);

  for (int64 output_idx : GpuGridRangeX(num_samples)) {
    int64 alpha_idx = output_idx / samples_per_alpha;
    int64 sample_idx = output_idx % samples_per_alpha;

    const double alpha = static_cast<double>(alpha_flat[alpha_idx]);

    DISABLE_FLOAT_EQUALITY_WARNING
    if (alpha == 1.0) {
      ENABLE_FLOAT_EQUALITY_WARNING
      // Sample from an exponential distribution.
      // As we want data stable regardless of sharding, we skip on a per-sample
      // basis.
      random::PhiloxRandom gen = random;
      gen.Skip(kReservedSamplesPerOutput * output_idx);
      double u = uniform(&gen)[Uniform::kResultElementCount - 1];
      const double res = -log1p(-u);
      // We use alpha_idx + sample_idx * num_alphas instead of output_idx
      // to generate numbers in the right order (CPU and GPU kernels
      // must generate numbers in the same order).
      samples_flat[alpha_idx + sample_idx * num_alphas] = static_cast<T>(res);
    } else {  // if alpha != 1.0
      // Transformation-rejection from pairs of uniform and normal random
      // variables. http://dl.acm.org/citation.cfm?id=358414
      //
      // The algorithm has an acceptance rate of ~95% for small alpha (~1),
      // and higher accept rates for higher alpha, so runtime is
      // O(NumAlphas * NumSamples * k) with k ~ 1 / 0.95.
      //
      // For alpha<1, we add one to d=alpha-1/3, and multiply the final
      // result by uniform()^(1/alpha)
      const bool alpha_less_than_one = alpha < 1.0;
      const double d = alpha + (alpha_less_than_one ? 2.0 / 3 : -1.0 / 3);
      const double c = 1.0 / 3 / sqrt(d);

      // Since each sample may use a variable number of normal/uniform
      // samples, and we want data stable regardless of sharding, we skip on a
      // per-sample basis.
      random::PhiloxRandom gen = random;
      gen.Skip(kReservedSamplesPerOutput * output_idx);

      // To prevent overwriting SampleBuffer's underlying array with
      // zeros (in tensorflow::random::Array constructor), we just mark
      // the buffer as empty instead of initializing a new SampleBuffer
      // object here. The next call to operator() will fill the buffer
      // with new numbers.
      normal_buffer.Clear();
      uniform_buffer.Clear();

      // Keep trying until we don't reject a sample. In practice, we will
      // only reject ~5% at worst, for low alpha near 1.
      while (true) {
        const double x = normal_buffer(&gen);
        double v = 1 + c * x;
        if (v <= 0) {
          continue;
        }
        v = v * v * v;
        double u = uniform_buffer(&gen);
        // The first option in the if is a "squeeze" short-circuit to
        // dodge the two logs. Magic constant sourced from the paper
        // linked above. Upward of .91 of the area covered by the log
        // inequality is covered by the squeeze as well (larger coverage
        // for smaller values of alpha).
        if ((u < 1 - 0.0331 * (x * x) * (x * x)) ||
            (log(u) < 0.5 * x * x + d * (1 - v + log(v)))) {
          double res = d * v;
          if (alpha_less_than_one) {
            double b = uniform_buffer(&gen);
            res *= pow(b, 1 / alpha);
          }
          // We use alpha_idx + sample_idx * num_alphas instead of output_idx
          // to generate numbers in the right order (CPU and GPU kernels
          // must generate numbers in the same order).
          samples_flat[alpha_idx + sample_idx * num_alphas] =
              static_cast<T>(res);
          break;
        }
      }  // while: true
    }    // if (alpha == 1.0)
  }      // for: output_idx
}
} // namespace
namespace functor {
template <typename T>
struct StatelessRandomGammaFunctor<GPUDevice, T> {
  static Status Fill(OpKernelContext* ctx, const T* alpha_flat,
                     int64 num_samples, int64 num_alphas,
                     int64 samples_per_alpha,
                     const random::PhiloxRandom& random, T* samples_flat) {
    const GPUDevice& d = ctx->eigen_device<GPUDevice>();
    GpuLaunchConfig cfg = GetGpuLaunchConfig(num_samples, d);

    TF_CHECK_OK(GpuLaunchKernel(FillKernel<T>, cfg.block_count,
                                cfg.thread_per_block, 0, d.stream(),
                                num_samples, num_alphas, samples_per_alpha,
                                random, samples_flat, alpha_flat));
    return Status::OK();
  }
};
} // namespace functor
#define REGISTER_GPU_SPEC(type) \
template struct functor::StatelessRandomGammaFunctor<GPUDevice, type>;
TF_CALL_half(REGISTER_GPU_SPEC);
TF_CALL_bfloat16(REGISTER_GPU_SPEC);
TF_CALL_float(REGISTER_GPU_SPEC);
TF_CALL_double(REGISTER_GPU_SPEC);
#undef REGISTER_GPU_SPEC
} // namespace tensorflow
#endif // GOOGLE_CUDA
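One subtlety in FillKernel above: the Philox stream is indexed by output_idx (all draws for a given alpha are consecutive in the random stream), but the result is stored at alpha_idx + sample_idx * num_alphas because alpha is the innermost dimension of the output tensor. A tiny sketch of that mapping, using num_alphas = 2 and samples_per_alpha = 3 purely as an example:

num_alphas, samples_per_alpha = 2, 3   # example sizes only

for output_idx in range(num_alphas * samples_per_alpha):
    alpha_idx = output_idx // samples_per_alpha   # which alpha this draw uses
    sample_idx = output_idx % samples_per_alpha   # which draw for that alpha
    flat_pos = alpha_idx + sample_idx * num_alphas
    print(output_idx, '->', flat_pos)
# Prints 0->0, 1->2, 2->4, 3->1, 4->3, 5->5. Both kernels skip the Philox
# stream by kReservedSamplesPerOutput * output_idx and store to flat_pos, so
# CPU and GPU produce the same stream and the same output layout.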


@@ -216,6 +216,17 @@ class StatelessOpsTest(test.TestCase, parameterized.TestCase):
       pure = stateless_op(seed=preseed)
       self.assertAllEqual(stateful, pure)

+  def _test_match_stateless_cpu_gpu(self, case, seed):
+    # Stateless ops should produce the same result on CPUs and GPUs.
+    _, stateless_op, _ = case
+    with ops.device('CPU'):
+      result_cpu = stateless_op(seed=seed)
+    with ops.device(get_device().name):
+      result_gpu = stateless_op(seed=seed)
+    self.assertAllClose(result_cpu, result_gpu)
+
   def _test_old_and_new_stateless_match(self, case, seed):
     """Tests that the new stateless ops match the old stateless ones."""
     with ops.device(get_device().name):
@@ -306,6 +317,18 @@ class StatelessOpsTest(test.TestCase, parameterized.TestCase):
       self.skipTest('Lacking XLA kernel')
     self._test_match(case, seed)

+  @parameterized.named_parameters(
+      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
+      for seed_id, seed in enumerate(SEEDS)
+      for case_id, case in enumerate(gamma_cases()))
+  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
+  def testStatelessGammaCpuGpuMatch(self, case, seed):
+    if get_device().device_type != 'GPU':
+      # This test compares the numbers produced by the CPU and GPU kernel for
+      # stateless_random_gamma.
+      self.skipTest('This test requires GPU')
+    self._test_match_stateless_cpu_gpu(case, seed)
+
   @parameterized.named_parameters(
       ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
       for seed_id, seed in enumerate(SEEDS)
@@ -387,10 +410,6 @@ class StatelessOpsTest(test.TestCase, parameterized.TestCase):
       for case_id, case in enumerate(gamma_cases()))
   @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
   def testDeterminismGamma(self, case, seed_type):
-    if get_device().device_type == 'GPU':
-      # This test was passing before because soft placement silently picked the
-      # CPU kernels.
-      self.skipTest('Lacking GPU kernel')
     if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
       # This test was passing before because soft placement silently picked the
       # CPU kernels.
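The skip removed above is what previously kept testDeterminismGamma off GPUs; with the new kernel it now runs there too. Outside the test harness, the CPU/GPU agreement that testStatelessGammaCpuGpuMatch verifies can be reproduced with the public API roughly as follows (a sketch that assumes a visible GPU):

import numpy as np
import tensorflow as tf

seed = [7, 42]                         # arbitrary example seed
alpha = tf.constant([0.3, 1.0, 4.0])

with tf.device('/CPU:0'):
    cpu_draws = tf.random.stateless_gamma(shape=[5, 3], seed=seed, alpha=alpha)
with tf.device('/GPU:0'):              # assumes a GPU is available
    gpu_draws = tf.random.stateless_gamma(shape=[5, 3], seed=seed, alpha=alpha)

# Both kernels consume the Philox stream identically, so the draws agree up to
# floating-point tolerance.
np.testing.assert_allclose(cpu_draws.numpy(), gpu_draws.numpy(), rtol=1e-5)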