Add a GPU implementation of the ParameterizedTruncatedNormalOp kernel.

Benchmarks:

Benchmark                                   Time(ns)    CPU(ns) Iterations
--------------------------------------------------------------------------
BM_PTruncatedNormal_gpu_1000_1000            4632369    5175938        100  184.251M items/s
BM_PTruncatedNormal_2SD_gpu_10000_100        2849437    3368804        206  283.090M items/s
BM_PTruncatedNormal_OneTail_gpu_10000_100    3300317    3905713        179  244.174M items/s
Change: 138074670
This commit is contained in:
A. Unique TensorFlower 2016-11-03 07:42:29 -08:00 committed by TensorFlower Gardener
parent 81743e74f4
commit f19dc6c5e2
6 changed files with 299 additions and 49 deletions

View File

@ -46,25 +46,6 @@ namespace functor {
using random::PhiloxRandom;
using random::SingleSampleAdapter;
// Sample a truncated normal random variable, with mean, stddev, minval, and
// maxval parameters for each batch. Uses two rejection sampling algorithms
// described in http://rd.springer.com/article/10.1007/BF00143942.
//
// Either minval may be -infinity, or maxval may be +infinity. If the interval
// (minval, maxval) is empty, the result is NaN. Large intervals which include
// both tails may have reduced accuracy.
template <typename Device, typename T>
struct TruncatedNormalFunctor {
void operator()(OpKernelContext* ctx, const Device& d, int64 num_batches,
int64 samples_per_batch, int64 num_elements,
typename TTypes<T>::ConstFlat means,
typename TTypes<T>::ConstFlat stddevs,
typename TTypes<T>::ConstFlat minvals,
typename TTypes<T>::ConstFlat maxvals,
const random::PhiloxRandom& gen,
typename TTypes<T>::Flat output);
};
template <typename T>
struct TruncatedNormalFunctor<CPUDevice, T> {
static const int kMaxIterations = 100;
@ -96,8 +77,8 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
// Vectorized intermediate calculations for uniform rejection sampling.
// We always generate at most 4 samples.
tensorflow::random::Array<T, 4> z;
tensorflow::random::Array<T, 4> g;
Eigen::array<T, 4> z;
Eigen::array<T, 4> g;
for (int64 b = start_batch; b < limit_batch; ++b) {
// We are passed a flat array for each of the parameter tensors.
@ -145,13 +126,7 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
if (diff < cutoff) {
// Sample from a uniform distribution on [normMin, normMax].
T plusFactor;
if (normMin < T(0)) {
// normMax > 0 because it is flipped otherwise.
plusFactor = T(0);
} else {
plusFactor = normMin * normMin;
}
const T plusFactor = (normMin < T(0)) ? T(0) : normMin * normMin;
while (sample < limit_sample) {
const auto rand = dist(&gen_copy);
@ -395,4 +370,21 @@ TF_CALL_double(REGISTER);
#undef REGISTER
#if GOOGLE_CUDA
#define REGISTER(TYPE) \
REGISTER_KERNEL_BUILDER(Name("ParameterizedTruncatedNormal") \
.Device(DEVICE_GPU) \
.HostMemory("shape") \
.TypeConstraint<TYPE>("dtype"), \
ParameterizedTruncatedNormalOp<GPUDevice, TYPE>)
TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
#undef REGISTER
#endif // GOOGLE_CUDA
} // end namespace tensorflow

View File

@ -16,14 +16,35 @@ limitations under the License.
#ifndef TENSORFLOW_KERNELS_PARAMETERIZED_TRUNCATED_NORMAL_OP_H_
#define TENSORFLOW_KERNELS_PARAMETERIZED_TRUNCATED_NORMAL_OP_H_
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/random/random_distributions.h"
namespace tensorflow {
class OpKernelContext;
namespace functor {
// Sample a truncated normal random variable, with mean, stddev, minval, and
// maxval parameters for each batch. Uses two rejection sampling algorithms
// described in http://rd.springer.com/article/10.1007/BF00143942.
//
// Either minval may be -infinity, or maxval may be +infinity. If the interval
// (minval, maxval) is empty, the result is NaN. Large intervals which include
// both tails may have reduced accuracy.
template <typename Device, typename T>
struct TruncatedNormalFunctor;
struct TruncatedNormalFunctor {
void operator()(OpKernelContext* ctx, const Device& d, int64 num_batches,
int64 samples_per_batch, int64 num_elements,
typename TTypes<T>::ConstFlat means,
typename TTypes<T>::ConstFlat stddevs,
typename TTypes<T>::ConstFlat minvals,
typename TTypes<T>::ConstFlat maxvals,
const random::PhiloxRandom& gen,
typename TTypes<T>::Flat output);
static const int kMaxIterations = 100;
};
} // namespace functor
} // namespace tensorflow

View File

@ -0,0 +1,214 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "tensorflow/core/kernels/parameterized_truncated_normal_op.h"
#include <assert.h>
#include <stdio.h>
#include <cmath>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#define UNROLL _Pragma("unroll")
namespace tensorflow {
class OpKernelContext;
namespace functor {
typedef Eigen::GpuDevice GPUDevice;
template <typename T>
__global__ void __launch_bounds__(1024)
TruncatedNormalKernel(random::PhiloxRandom gen, T* data, int64 num_batches,
int64 samples_per_batch, int64 num_elements,
const T* means, bool single_mean, const T* stddevs,
bool single_stddev, const T* minvals,
bool single_minval, const T* maxvals,
bool single_maxval, int64 kMaxIterations) {
const int32 max_samples_per_item = 2 * kMaxIterations;
// Initial offset as given by CUDA_1D_KERNEL_LOOP.
const int32 initial_offset = blockIdx.x * blockDim.x + threadIdx.x;
gen.Skip(max_samples_per_item * initial_offset);
typedef random::UniformDistribution<random::PhiloxRandom, T> Uniform;
Uniform dist;
const int kDistSize = Uniform::kResultElementCount;
const T quietNaN = Eigen::NumTraits<T>::quiet_NaN();
// We skip the total number of threads to get to the next element. To produce
// deterministic results between devices, each element in the output array
// skips max_samples_per_item in the generator. Then after generating this
// item, we need to skip the samples for one element for every thread to get
// to the next element that we actually process.
const int32 samples_between_processed_elements =
max_samples_per_item * (gridDim.x * blockDim.x);
CUDA_1D_KERNEL_LOOP(offset, num_elements) {
// Track how many more samples we need to skip before we process the next
// element.
int32 remaining_samples = samples_between_processed_elements;
const int64 batch_id = offset / samples_per_batch;
T mean = means[single_mean ? 0 : batch_id];
const T input_stddev = stddevs[single_stddev ? 0 : batch_id];
T minval = minvals[single_minval ? 0 : batch_id];
T maxval = maxvals[single_maxval ? 0 : batch_id];
// Flip the distribution if we can make the lower bound positive.
T stddev;
if (Eigen::numext::isinf(minval) || maxval < mean) {
// Reverse all calculations. normMin and normMax will be flipped.
// std::swap is a host function (not available in CUDA).
T temp = minval;
minval = maxval;
maxval = temp;
stddev = -input_stddev;
} else {
stddev = input_stddev;
}
// Calculate normalized samples, then scale them.
const T normMin = (minval - mean) / stddev;
const T normMax = (maxval - mean) / stddev;
// Determine the method to use.
const T sqrtFactor = Eigen::numext::sqrt((normMin * normMin) + T(4));
const T cutoff =
T(2) *
Eigen::numext::exp(T(0.5) + (normMin * (normMin - sqrtFactor)) / T(4)) /
(normMin + sqrtFactor);
const T diff = normMax - normMin;
// Validate the normalized min and max, because the originals may have been
// flipped already.
if (!(input_stddev > T(0) && normMin < normMax &&
(Eigen::numext::isfinite(normMin) ||
Eigen::numext::isfinite(normMax)))) {
data[offset] = quietNaN;
} else if (diff < cutoff) {
// Sample from a uniform distribution on [normMin, normMax].
// Vectorized intermediate calculations for uniform rejection sampling.
// We always generate at most 4 samples.
Eigen::array<T, 4> z;
Eigen::array<T, 4> g;
const T plusFactor = (normMin < T(0)) ? T(0) : normMin * normMin;
int numIterations = 0;
while (numIterations < kMaxIterations) {
const auto rand = dist(&gen);
remaining_samples -= gen.kResultElementCount;
UNROLL for (int i = 0; i < kDistSize; i++) {
z[i] = rand[i] * diff + normMin;
}
UNROLL for (int i = 0; i < kDistSize; i++) {
g[i] = (plusFactor - z[i] * z[i]) / 2.0;
}
const auto u = dist(&gen);
remaining_samples -= gen.kResultElementCount;
UNROLL for (int i = 0; i < kDistSize; i++) {
if (u[i] <= Eigen::numext::exp(g[i]) ||
numIterations + 1 >= kMaxIterations) {
// Accept the sample z.
// If we run out of iterations, just use the current uniform
// sample. Emperically, the probability of accepting each sample
// is at least 50% for typical inputs, so we will always accept
// by 100 iterations.
// This introduces a slight inaccuracy when at least one bound
// is large, minval is negative and maxval is positive.
data[offset] = z[i] * stddev + mean;
// Break out of the nested loop by updating numIterations.
numIterations = kMaxIterations;
break;
} else {
numIterations++;
}
}
}
} else {
// Sample from an exponential distribution with alpha maximizing
// acceptance probability, offset by normMin from the origin.
// Accept only if less than normMax.
const T alpha =
(normMin + Eigen::numext::sqrt((normMin * normMin) + T(4))) / T(2);
int numIterations = 0;
while (numIterations < kMaxIterations) {
auto rand = dist(&gen);
remaining_samples -= gen.kResultElementCount;
UNROLL for (int i = 0; i < kDistSize; i += 2) {
const T z = -Eigen::numext::log(rand[i]) / alpha + normMin;
const T x = normMin < alpha ? alpha - z : normMin - alpha;
const T g = Eigen::numext::exp(-x * x / 2.0);
const T u = rand[i + 1];
if ((u <= g && z < normMax) || numIterations + 1 >= kMaxIterations) {
data[offset] = z * stddev + mean;
// Break out of the nested loop by updating numIterations.
numIterations = kMaxIterations;
break;
} else {
numIterations++;
}
}
}
}
gen.Skip(remaining_samples);
}
}
// Partial specialization for GPU
template <typename T>
struct TruncatedNormalFunctor<GPUDevice, T> {
static const int kMaxIterations = 100;
void operator()(OpKernelContext* ctx, const GPUDevice& d, int64 num_batches,
int64 samples_per_batch, int64 num_elements,
typename TTypes<T>::ConstFlat means,
typename TTypes<T>::ConstFlat stddevs,
typename TTypes<T>::ConstFlat minvals,
typename TTypes<T>::ConstFlat maxvals,
const random::PhiloxRandom& gen,
typename TTypes<T>::Flat output) {
const auto config = GetCudaLaunchConfig(num_elements, d);
TruncatedNormalKernel<
T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
gen, output.data(), num_batches, samples_per_batch, num_elements,
means.data(), means.dimension(0) == 1, stddevs.data(),
stddevs.dimension(0) == 1, minvals.data(), minvals.dimension(0) == 1,
maxvals.data(), maxvals.dimension(0) == 1, kMaxIterations);
};
};
// Explicit instantiation of the GPU distributions functors
template struct TruncatedNormalFunctor<GPUDevice, Eigen::half>;
template struct TruncatedNormalFunctor<GPUDevice, float>;
template struct TruncatedNormalFunctor<GPUDevice, double>;
} // namespace functor
} // namespace tensorflow
#endif // GOOGLE_CUDA

View File

@ -131,5 +131,8 @@ static Graph* PTruncatedNormalOneTail(int num_batches, int samples_per_batch) {
BM_PTruncatedNormalDev(cpu, 1000, 1000);
BM_PTruncatedNormalDev_2SD(cpu, 10000, 100);
BM_PTruncatedNormalDev_OneTail(cpu, 10000, 100);
BM_PTruncatedNormalDev(gpu, 1000, 1000);
BM_PTruncatedNormalDev_2SD(gpu, 10000, 100);
BM_PTruncatedNormalDev_OneTail(gpu, 10000, 100);
} // namespace tensorflow

View File

@ -220,7 +220,7 @@ cuda_py_test(
additional_deps = ["//tensorflow:tensorflow_py"],
)
tf_py_test(
cuda_py_test(
name = "parameterized_truncated_normal_op_test",
size = "small",
srcs = ["parameterized_truncated_normal_op_test.py"],

View File

@ -97,10 +97,10 @@ def z_test(real, expected, i, num_samples):
class ParameterizedTruncatedNormalTest(tf.test.TestCase):
use_gpu = False
_use_gpu = False
z_limit = 6.0
# Stop at moment 20 to avoid numerical errors in the theoretical moments.
# Stop at moment 10 to avoid numerical errors in the theoretical moments.
max_moment = 10
def validateMoments(self, shape, mean, stddev, minval, maxval, seed=1618):
@ -109,9 +109,11 @@ class ParameterizedTruncatedNormalTest(tf.test.TestCase):
# Give up early if we are unable to import it.
import scipy.stats # pylint: disable=g-import-not-at-top,unused-variable
tf.set_random_seed(seed)
with self.test_session(use_gpu=self.use_gpu):
samples = random_ops.parameterized_truncated_normal(
shape, mean, stddev, minval, maxval).eval()
with self.test_session(use_gpu=self._use_gpu):
samples = random_ops.parameterized_truncated_normal(shape, mean, stddev,
minval,
maxval).eval()
assert (~np.isnan(samples)).all()
moments = calculate_moments(samples, self.max_moment)
expected_moments = TruncatedNormalMoments(mean, stddev, minval, maxval)
num_samples = functools.reduce(lambda x, y: x * y, shape, 1)
@ -131,9 +133,11 @@ class ParameterizedTruncatedNormalTest(tf.test.TestCase):
try:
import scipy.stats # pylint: disable=g-import-not-at-top
tf.set_random_seed(seed)
with self.test_session(use_gpu=self.use_gpu):
samples = random_ops.parameterized_truncated_normal(
shape, mean, stddev, minval, maxval).eval()
with self.test_session(use_gpu=self._use_gpu):
samples = random_ops.parameterized_truncated_normal(shape, mean, stddev,
minval,
maxval).eval()
assert (~np.isnan(samples)).all()
minval = max(mean - stddev * 10, minval)
maxval = min(mean + stddev * 10, maxval)
dist = scipy.stats.norm(loc=mean, scale=stddev)
@ -173,8 +177,12 @@ class ParameterizedTruncatedNormalTest(tf.test.TestCase):
self.validateKolmogorovSmirnov([10**5], 0.0, 0.1, 0.05, 0.10)
class ParameterizedTruncatedNormalGpuTest(ParameterizedTruncatedNormalTest):
_use_gpu = True
# Benchmarking code
def parameterized_vs_naive(shape, num_iters):
def parameterized_vs_naive(shape, num_iters, use_gpu=False):
np.random.seed(1618) # Make it reproducible.
# No CSE/CF.
@ -183,17 +191,29 @@ def parameterized_vs_naive(shape, num_iters):
graph_options=tf.GraphOptions(optimizer_options=optimizer_options))
with tf.Session(config=config) as sess:
param_op = tf.group(random_ops.parameterized_truncated_normal(shape))
naive_op = tf.group(random_ops.truncated_normal(shape))
with tf.device("/cpu:0" if not use_gpu else None):
param_op = tf.group(random_ops.parameterized_truncated_normal(shape))
naive_op = tf.group(random_ops.truncated_normal(shape))
# Burn-in to avoid session setup costs in the timing.
sess.run(param_op)
sess.run(param_op)
param_dt = timeit.timeit(lambda: sess.run(param_op), number=num_iters)
sess.run(naive_op)
sess.run(naive_op)
naive_dt = timeit.timeit(lambda: sess.run(naive_op), number=num_iters)
return param_dt, naive_dt
class TruncatedNormalBenchmark(tf.test.Benchmark):
def benchmarkParameterizedOpVsNaiveOp(self):
def benchmarkParameterizedOpVsNaiveOpCpu(self):
self._benchmarkParameterizedOpVsNaiveOp(False)
def benchmarkParameterizedOpVsNaiveOpGpu(self):
self._benchmarkParameterizedOpVsNaiveOp(True)
def _benchmarkParameterizedOpVsNaiveOp(self, use_gpu):
num_iters = 50
print(("Composition of new ParameterizedTruncatedNormalOp vs. "
"naive TruncatedNormalOp [%d iters]") % num_iters)
@ -201,16 +221,16 @@ class TruncatedNormalBenchmark(tf.test.Benchmark):
for shape in [[10000, 100], [1000, 1000], [1000000], [100, 100, 100],
[20, 20, 20, 20]]:
p_dt, n_dt = parameterized_vs_naive(shape, num_iters)
p_dt, n_dt = parameterized_vs_naive(shape, num_iters, use_gpu)
print("%s\t%.3f\t%.3f\t%.2f" % (shape, p_dt, n_dt, p_dt / n_dt))
shape_str = "-".join(map(str, shape))
self.report_benchmark(name="parameterized_shape" + shape_str,
iters=num_iters,
wall_time=p_dt)
self.report_benchmark(name="naive_shape" + shape_str,
iters=num_iters,
wall_time=n_dt)
self.report_benchmark(
name="parameterized_shape" + shape_str,
iters=num_iters,
wall_time=p_dt)
self.report_benchmark(
name="naive_shape" + shape_str, iters=num_iters, wall_time=n_dt)
if __name__ == "__main__":