Add a GPU implementation of the ParameterizedTruncatedNormalOp kernel.
Benchmarks:
Benchmark                                    Time(ns)    CPU(ns)  Iterations
--------------------------------------------------------------------------
BM_PTruncatedNormal_gpu_1000_1000             4632369    5175938         100  184.251M items/s
BM_PTruncatedNormal_2SD_gpu_10000_100         2849437    3368804         206  283.090M items/s
BM_PTruncatedNormal_OneTail_gpu_10000_100     3300317    3905713         179  244.174M items/s
Change: 138074670
This commit is contained in:
parent
81743e74f4
commit
f19dc6c5e2
@ -46,25 +46,6 @@ namespace functor {
|
|||||||
using random::PhiloxRandom;
|
using random::PhiloxRandom;
|
||||||
using random::SingleSampleAdapter;
|
using random::SingleSampleAdapter;
|
||||||
|
|
||||||
// Sample a truncated normal random variable, with mean, stddev, minval, and
|
|
||||||
// maxval parameters for each batch. Uses two rejection sampling algorithms
|
|
||||||
// described in http://rd.springer.com/article/10.1007/BF00143942.
|
|
||||||
//
|
|
||||||
// Either minval may be -infinity, or maxval may be +infinity. If the interval
|
|
||||||
// (minval, maxval) is empty, the result is NaN. Large intervals which include
|
|
||||||
// both tails may have reduced accuracy.
|
|
||||||
template <typename Device, typename T>
|
|
||||||
struct TruncatedNormalFunctor {
|
|
||||||
void operator()(OpKernelContext* ctx, const Device& d, int64 num_batches,
|
|
||||||
int64 samples_per_batch, int64 num_elements,
|
|
||||||
typename TTypes<T>::ConstFlat means,
|
|
||||||
typename TTypes<T>::ConstFlat stddevs,
|
|
||||||
typename TTypes<T>::ConstFlat minvals,
|
|
||||||
typename TTypes<T>::ConstFlat maxvals,
|
|
||||||
const random::PhiloxRandom& gen,
|
|
||||||
typename TTypes<T>::Flat output);
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct TruncatedNormalFunctor<CPUDevice, T> {
|
struct TruncatedNormalFunctor<CPUDevice, T> {
|
||||||
static const int kMaxIterations = 100;
|
static const int kMaxIterations = 100;
|
||||||
@ -96,8 +77,8 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
|
|||||||
|
|
||||||
// Vectorized intermediate calculations for uniform rejection sampling.
|
// Vectorized intermediate calculations for uniform rejection sampling.
|
||||||
// We always generate at most 4 samples.
|
// We always generate at most 4 samples.
|
||||||
tensorflow::random::Array<T, 4> z;
|
Eigen::array<T, 4> z;
|
||||||
tensorflow::random::Array<T, 4> g;
|
Eigen::array<T, 4> g;
|
||||||
|
|
||||||
for (int64 b = start_batch; b < limit_batch; ++b) {
|
for (int64 b = start_batch; b < limit_batch; ++b) {
|
||||||
// We are passed a flat array for each of the parameter tensors.
|
// We are passed a flat array for each of the parameter tensors.
|
||||||
@ -145,13 +126,7 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
|
|||||||
if (diff < cutoff) {
|
if (diff < cutoff) {
|
||||||
// Sample from a uniform distribution on [normMin, normMax].
|
// Sample from a uniform distribution on [normMin, normMax].
|
||||||
|
|
||||||
T plusFactor;
|
const T plusFactor = (normMin < T(0)) ? T(0) : normMin * normMin;
|
||||||
if (normMin < T(0)) {
|
|
||||||
// normMax > 0 because it is flipped otherwise.
|
|
||||||
plusFactor = T(0);
|
|
||||||
} else {
|
|
||||||
plusFactor = normMin * normMin;
|
|
||||||
}
|
|
||||||
|
|
||||||
while (sample < limit_sample) {
|
while (sample < limit_sample) {
|
||||||
const auto rand = dist(&gen_copy);
|
const auto rand = dist(&gen_copy);
|
||||||
@ -395,4 +370,21 @@ TF_CALL_double(REGISTER);
|
|||||||
|
|
||||||
#undef REGISTER
|
#undef REGISTER
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
|
||||||
|
#define REGISTER(TYPE) \
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("ParameterizedTruncatedNormal") \
|
||||||
|
.Device(DEVICE_GPU) \
|
||||||
|
.HostMemory("shape") \
|
||||||
|
.TypeConstraint<TYPE>("dtype"), \
|
||||||
|
ParameterizedTruncatedNormalOp<GPUDevice, TYPE>)
|
||||||
|
|
||||||
|
TF_CALL_half(REGISTER);
|
||||||
|
TF_CALL_float(REGISTER);
|
||||||
|
TF_CALL_double(REGISTER);
|
||||||
|
|
||||||
|
#undef REGISTER
|
||||||
|
|
||||||
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
@ -16,14 +16,35 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_KERNELS_PARAMETERIZED_TRUNCATED_NORMAL_OP_H_
|
#ifndef TENSORFLOW_KERNELS_PARAMETERIZED_TRUNCATED_NORMAL_OP_H_
|
||||||
#define TENSORFLOW_KERNELS_PARAMETERIZED_TRUNCATED_NORMAL_OP_H_
|
#define TENSORFLOW_KERNELS_PARAMETERIZED_TRUNCATED_NORMAL_OP_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
|
#include "tensorflow/core/lib/random/random_distributions.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
class OpKernelContext;
|
class OpKernelContext;
|
||||||
|
|
||||||
namespace functor {
|
namespace functor {
|
||||||
|
|
||||||
|
// Sample a truncated normal random variable, with mean, stddev, minval, and
|
||||||
|
// maxval parameters for each batch. Uses two rejection sampling algorithms
|
||||||
|
// described in http://rd.springer.com/article/10.1007/BF00143942.
|
||||||
|
//
|
||||||
|
// Either minval may be -infinity, or maxval may be +infinity. If the interval
|
||||||
|
// (minval, maxval) is empty, the result is NaN. Large intervals which include
|
||||||
|
// both tails may have reduced accuracy.
|
||||||
template <typename Device, typename T>
|
template <typename Device, typename T>
|
||||||
struct TruncatedNormalFunctor;
|
struct TruncatedNormalFunctor {
|
||||||
|
void operator()(OpKernelContext* ctx, const Device& d, int64 num_batches,
|
||||||
|
int64 samples_per_batch, int64 num_elements,
|
||||||
|
typename TTypes<T>::ConstFlat means,
|
||||||
|
typename TTypes<T>::ConstFlat stddevs,
|
||||||
|
typename TTypes<T>::ConstFlat minvals,
|
||||||
|
typename TTypes<T>::ConstFlat maxvals,
|
||||||
|
const random::PhiloxRandom& gen,
|
||||||
|
typename TTypes<T>::Flat output);
|
||||||
|
|
||||||
|
static const int kMaxIterations = 100;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace functor
|
} // namespace functor
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -0,0 +1,214 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
|
||||||
|
#define EIGEN_USE_GPU
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/parameterized_truncated_normal_op.h"
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
|
#include "tensorflow/core/lib/random/philox_random.h"
|
||||||
|
#include "tensorflow/core/lib/random/random_distributions.h"
|
||||||
|
#include "tensorflow/core/util/cuda_kernel_helper.h"
|
||||||
|
|
||||||
|
#define UNROLL _Pragma("unroll")
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
class OpKernelContext;
|
||||||
|
|
||||||
|
namespace functor {
|
||||||
|
|
||||||
|
typedef Eigen::GpuDevice GPUDevice;
|
||||||
|
|
||||||
|
// CUDA kernel that fills `data` with truncated-normal samples, one output
// element per iteration of CUDA_1D_KERNEL_LOOP.
//
// For each output element the batch index is offset / samples_per_batch; the
// mean, stddev, minval and maxval of that batch are read from the matching
// arrays, or from index 0 when the corresponding single_* flag is set.
// If stddev is not greater than zero, the normalized interval is empty, or
// both normalized bounds are non-finite, the element is set to NaN. Otherwise
// one of two rejection samplers is used, chosen by comparing the normalized
// interval width against `cutoff`.
// kMaxIterations bounds the number of rejection attempts per element; the
// last permitted attempt is accepted unconditionally.
// num_batches is not read by the kernel body.
template <typename T>
|
||||||
|
__global__ void __launch_bounds__(1024)
|
||||||
|
TruncatedNormalKernel(random::PhiloxRandom gen, T* data, int64 num_batches,
|
||||||
|
int64 samples_per_batch, int64 num_elements,
|
||||||
|
const T* means, bool single_mean, const T* stddevs,
|
||||||
|
bool single_stddev, const T* minvals,
|
||||||
|
bool single_minval, const T* maxvals,
|
||||||
|
bool single_maxval, int64 kMaxIterations) {
|
||||||
|
// Generator budget reserved for every output element.
const int32 max_samples_per_item = 2 * kMaxIterations;
|
||||||
|
// Initial offset as given by CUDA_1D_KERNEL_LOOP.
|
||||||
|
const int32 initial_offset = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
// Move this thread's private copy of the generator to the start of the budget
// belonging to its first element.
gen.Skip(max_samples_per_item * initial_offset);
|
||||||
|
typedef random::UniformDistribution<random::PhiloxRandom, T> Uniform;
|
||||||
|
Uniform dist;
|
||||||
|
const int kDistSize = Uniform::kResultElementCount;
|
||||||
|
const T quietNaN = Eigen::NumTraits<T>::quiet_NaN();
|
||||||
|
|
||||||
|
// We skip the total number of threads to get to the next element. To produce
|
||||||
|
// deterministic results between devices, each element in the output array
|
||||||
|
// skips max_samples_per_item in the generator. Then after generating this
|
||||||
|
// item, we need to skip the samples for one element for every thread to get
|
||||||
|
// to the next element that we actually process.
|
||||||
|
const int32 samples_between_processed_elements =
|
||||||
|
max_samples_per_item * (gridDim.x * blockDim.x);
|
||||||
|
|
||||||
|
CUDA_1D_KERNEL_LOOP(offset, num_elements) {
|
||||||
|
// Track how many more samples we need to skip before we process the next
|
||||||
|
// element.
|
||||||
|
int32 remaining_samples = samples_between_processed_elements;
|
||||||
|
|
||||||
|
const int64 batch_id = offset / samples_per_batch;
|
||||||
|
T mean = means[single_mean ? 0 : batch_id];
|
||||||
|
const T input_stddev = stddevs[single_stddev ? 0 : batch_id];
|
||||||
|
T minval = minvals[single_minval ? 0 : batch_id];
|
||||||
|
T maxval = maxvals[single_maxval ? 0 : batch_id];
|
||||||
|
|
||||||
|
// Flip the distribution if we can make the lower bound positive.
|
||||||
|
T stddev;
|
||||||
|
if (Eigen::numext::isinf(minval) || maxval < mean) {
|
||||||
|
// Reverse all calculations. normMin and normMax will be flipped.
|
||||||
|
// std::swap is a host function (not available in CUDA).
|
||||||
|
T temp = minval;
|
||||||
|
minval = maxval;
|
||||||
|
maxval = temp;
|
||||||
|
stddev = -input_stddev;
|
||||||
|
} else {
|
||||||
|
stddev = input_stddev;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate normalized samples, then scale them.
|
||||||
|
const T normMin = (minval - mean) / stddev;
|
||||||
|
const T normMax = (maxval - mean) / stddev;
|
||||||
|
|
||||||
|
// Determine the method to use.
|
||||||
|
const T sqrtFactor = Eigen::numext::sqrt((normMin * normMin) + T(4));
|
||||||
|
const T cutoff =
|
||||||
|
T(2) *
|
||||||
|
Eigen::numext::exp(T(0.5) + (normMin * (normMin - sqrtFactor)) / T(4)) /
|
||||||
|
(normMin + sqrtFactor);
|
||||||
|
const T diff = normMax - normMin;
|
||||||
|
|
||||||
|
// Validate the normalized min and max, because the originals may have been
|
||||||
|
// flipped already.
|
||||||
|
if (!(input_stddev > T(0) && normMin < normMax &&
|
||||||
|
(Eigen::numext::isfinite(normMin) ||
|
||||||
|
Eigen::numext::isfinite(normMax)))) {
|
||||||
|
data[offset] = quietNaN;
|
||||||
|
} else if (diff < cutoff) {
|
||||||
|
// Sample from a uniform distribution on [normMin, normMax].
|
||||||
|
|
||||||
|
// Vectorized intermediate calculations for uniform rejection sampling.
|
||||||
|
// We always generate at most 4 samples.
|
||||||
|
Eigen::array<T, 4> z;
|
||||||
|
Eigen::array<T, 4> g;
|
||||||
|
|
||||||
|
const T plusFactor = (normMin < T(0)) ? T(0) : normMin * normMin;
|
||||||
|
|
||||||
|
int numIterations = 0;
|
||||||
|
while (numIterations < kMaxIterations) {
|
||||||
|
const auto rand = dist(&gen);
|
||||||
|
// NOTE(review): every dist() call is charged gen.kResultElementCount against
// remaining_samples, which is later handed to gen.Skip(). Whether Skip()
// counts in the same unit is not visible here - confirm against PhiloxRandom.
remaining_samples -= gen.kResultElementCount;
|
||||||
|
UNROLL for (int i = 0; i < kDistSize; i++) {
|
||||||
|
z[i] = rand[i] * diff + normMin;
|
||||||
|
}
|
||||||
|
UNROLL for (int i = 0; i < kDistSize; i++) {
|
||||||
|
g[i] = (plusFactor - z[i] * z[i]) / 2.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second batch of uniforms: u[i] is the acceptance threshold for candidate z[i].
const auto u = dist(&gen);
|
||||||
|
remaining_samples -= gen.kResultElementCount;
|
||||||
|
UNROLL for (int i = 0; i < kDistSize; i++) {
|
||||||
|
if (u[i] <= Eigen::numext::exp(g[i]) ||
|
||||||
|
numIterations + 1 >= kMaxIterations) {
|
||||||
|
// Accept the sample z.
|
||||||
|
// If we run out of iterations, just use the current uniform
|
||||||
|
// sample. Empirically, the probability of accepting each sample
|
||||||
|
// is at least 50% for typical inputs, so we will always accept
|
||||||
|
// by 100 iterations.
|
||||||
|
// This introduces a slight inaccuracy when at least one bound
|
||||||
|
// is large, minval is negative and maxval is positive.
|
||||||
|
data[offset] = z[i] * stddev + mean;
|
||||||
|
// Break out of the nested loop by updating numIterations.
|
||||||
|
numIterations = kMaxIterations;
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
numIterations++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Sample from an exponential distribution with alpha maximizing
|
||||||
|
// acceptance probability, offset by normMin from the origin.
|
||||||
|
// Accept only if less than normMax.
|
||||||
|
const T alpha =
|
||||||
|
(normMin + Eigen::numext::sqrt((normMin * normMin) + T(4))) / T(2);
|
||||||
|
int numIterations = 0;
|
||||||
|
while (numIterations < kMaxIterations) {
|
||||||
|
auto rand = dist(&gen);
|
||||||
|
remaining_samples -= gen.kResultElementCount;
|
||||||
|
// Each attempt consumes two uniforms: rand[i] drives the exponential
// proposal and rand[i + 1] is the acceptance threshold.
UNROLL for (int i = 0; i < kDistSize; i += 2) {
|
||||||
|
const T z = -Eigen::numext::log(rand[i]) / alpha + normMin;
|
||||||
|
const T x = normMin < alpha ? alpha - z : normMin - alpha;
|
||||||
|
const T g = Eigen::numext::exp(-x * x / 2.0);
|
||||||
|
const T u = rand[i + 1];
|
||||||
|
// As in the uniform branch, the last permitted attempt is accepted
// regardless of the rejection test.
if ((u <= g && z < normMax) || numIterations + 1 >= kMaxIterations) {
|
||||||
|
data[offset] = z * stddev + mean;
|
||||||
|
// Break out of the nested loop by updating numIterations.
|
||||||
|
numIterations = kMaxIterations;
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
numIterations++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Advance past whatever is left of the budget so the generator is positioned
// for the next element this thread processes.
gen.Skip(remaining_samples);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Partial specialization for GPU: launches TruncatedNormalKernel on the
// device stream with a launch configuration sized for num_elements.
// A parameter tensor whose dimension(0) is 1 is flagged as "single", so the
// kernel reads index 0 of that tensor for every batch.
// `ctx` is not used by this specialization.
|
||||||
|
template <typename T>
|
||||||
|
struct TruncatedNormalFunctor<GPUDevice, T> {
|
||||||
|
// Upper bound on rejection-sampling attempts per output element; forwarded
// to the kernel.
static const int kMaxIterations = 100;
|
||||||
|
|
||||||
|
void operator()(OpKernelContext* ctx, const GPUDevice& d, int64 num_batches,
|
||||||
|
int64 samples_per_batch, int64 num_elements,
|
||||||
|
typename TTypes<T>::ConstFlat means,
|
||||||
|
typename TTypes<T>::ConstFlat stddevs,
|
||||||
|
typename TTypes<T>::ConstFlat minvals,
|
||||||
|
typename TTypes<T>::ConstFlat maxvals,
|
||||||
|
const random::PhiloxRandom& gen,
|
||||||
|
typename TTypes<T>::Flat output) {
|
||||||
|
// Nothing to generate; skip the launch. A grid with zero blocks is not a
// valid CUDA launch configuration.
if (num_elements <= 0) return;
|
||||||
|
|
||||||
|
const auto config = GetCudaLaunchConfig(num_elements, d);
|
||||||
|
|
||||||
|
TruncatedNormalKernel<
|
||||||
|
T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
|
||||||
|
gen, output.data(), num_batches, samples_per_batch, num_elements,
|
||||||
|
means.data(), means.dimension(0) == 1, stddevs.data(),
|
||||||
|
stddevs.dimension(0) == 1, minvals.data(), minvals.dimension(0) == 1,
|
||||||
|
maxvals.data(), maxvals.dimension(0) == 1, kMaxIterations);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Explicit instantiations of the GPU functor for the element types
// Eigen::half, float and double.
|
||||||
|
template struct TruncatedNormalFunctor<GPUDevice, Eigen::half>;
|
||||||
|
template struct TruncatedNormalFunctor<GPUDevice, float>;
|
||||||
|
template struct TruncatedNormalFunctor<GPUDevice, double>;
|
||||||
|
|
||||||
|
} // namespace functor
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // GOOGLE_CUDA
|
@ -131,5 +131,8 @@ static Graph* PTruncatedNormalOneTail(int num_batches, int samples_per_batch) {
|
|||||||
BM_PTruncatedNormalDev(cpu, 1000, 1000);
|
BM_PTruncatedNormalDev(cpu, 1000, 1000);
|
||||||
BM_PTruncatedNormalDev_2SD(cpu, 10000, 100);
|
BM_PTruncatedNormalDev_2SD(cpu, 10000, 100);
|
||||||
BM_PTruncatedNormalDev_OneTail(cpu, 10000, 100);
|
BM_PTruncatedNormalDev_OneTail(cpu, 10000, 100);
|
||||||
|
BM_PTruncatedNormalDev(gpu, 1000, 1000);
|
||||||
|
BM_PTruncatedNormalDev_2SD(gpu, 10000, 100);
|
||||||
|
BM_PTruncatedNormalDev_OneTail(gpu, 10000, 100);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -220,7 +220,7 @@ cuda_py_test(
|
|||||||
additional_deps = ["//tensorflow:tensorflow_py"],
|
additional_deps = ["//tensorflow:tensorflow_py"],
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_py_test(
|
cuda_py_test(
|
||||||
name = "parameterized_truncated_normal_op_test",
|
name = "parameterized_truncated_normal_op_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["parameterized_truncated_normal_op_test.py"],
|
srcs = ["parameterized_truncated_normal_op_test.py"],
|
||||||
|
@ -97,10 +97,10 @@ def z_test(real, expected, i, num_samples):
|
|||||||
|
|
||||||
|
|
||||||
class ParameterizedTruncatedNormalTest(tf.test.TestCase):
|
class ParameterizedTruncatedNormalTest(tf.test.TestCase):
|
||||||
use_gpu = False
|
_use_gpu = False
|
||||||
z_limit = 6.0
|
z_limit = 6.0
|
||||||
|
|
||||||
# Stop at moment 20 to avoid numerical errors in the theoretical moments.
|
# Stop at moment 10 to avoid numerical errors in the theoretical moments.
|
||||||
max_moment = 10
|
max_moment = 10
|
||||||
|
|
||||||
def validateMoments(self, shape, mean, stddev, minval, maxval, seed=1618):
|
def validateMoments(self, shape, mean, stddev, minval, maxval, seed=1618):
|
||||||
@ -109,9 +109,11 @@ class ParameterizedTruncatedNormalTest(tf.test.TestCase):
|
|||||||
# Give up early if we are unable to import it.
|
# Give up early if we are unable to import it.
|
||||||
import scipy.stats # pylint: disable=g-import-not-at-top,unused-variable
|
import scipy.stats # pylint: disable=g-import-not-at-top,unused-variable
|
||||||
tf.set_random_seed(seed)
|
tf.set_random_seed(seed)
|
||||||
with self.test_session(use_gpu=self.use_gpu):
|
with self.test_session(use_gpu=self._use_gpu):
|
||||||
samples = random_ops.parameterized_truncated_normal(
|
samples = random_ops.parameterized_truncated_normal(shape, mean, stddev,
|
||||||
shape, mean, stddev, minval, maxval).eval()
|
minval,
|
||||||
|
maxval).eval()
|
||||||
|
assert (~np.isnan(samples)).all()
|
||||||
moments = calculate_moments(samples, self.max_moment)
|
moments = calculate_moments(samples, self.max_moment)
|
||||||
expected_moments = TruncatedNormalMoments(mean, stddev, minval, maxval)
|
expected_moments = TruncatedNormalMoments(mean, stddev, minval, maxval)
|
||||||
num_samples = functools.reduce(lambda x, y: x * y, shape, 1)
|
num_samples = functools.reduce(lambda x, y: x * y, shape, 1)
|
||||||
@ -131,9 +133,11 @@ class ParameterizedTruncatedNormalTest(tf.test.TestCase):
|
|||||||
try:
|
try:
|
||||||
import scipy.stats # pylint: disable=g-import-not-at-top
|
import scipy.stats # pylint: disable=g-import-not-at-top
|
||||||
tf.set_random_seed(seed)
|
tf.set_random_seed(seed)
|
||||||
with self.test_session(use_gpu=self.use_gpu):
|
with self.test_session(use_gpu=self._use_gpu):
|
||||||
samples = random_ops.parameterized_truncated_normal(
|
samples = random_ops.parameterized_truncated_normal(shape, mean, stddev,
|
||||||
shape, mean, stddev, minval, maxval).eval()
|
minval,
|
||||||
|
maxval).eval()
|
||||||
|
assert (~np.isnan(samples)).all()
|
||||||
minval = max(mean - stddev * 10, minval)
|
minval = max(mean - stddev * 10, minval)
|
||||||
maxval = min(mean + stddev * 10, maxval)
|
maxval = min(mean + stddev * 10, maxval)
|
||||||
dist = scipy.stats.norm(loc=mean, scale=stddev)
|
dist = scipy.stats.norm(loc=mean, scale=stddev)
|
||||||
@ -173,8 +177,12 @@ class ParameterizedTruncatedNormalTest(tf.test.TestCase):
|
|||||||
self.validateKolmogorovSmirnov([10**5], 0.0, 0.1, 0.05, 0.10)
|
self.validateKolmogorovSmirnov([10**5], 0.0, 0.1, 0.05, 0.10)
|
||||||
|
|
||||||
|
|
||||||
|
class ParameterizedTruncatedNormalGpuTest(ParameterizedTruncatedNormalTest):
|
||||||
|
_use_gpu = True
|
||||||
|
|
||||||
|
|
||||||
# Benchmarking code
|
# Benchmarking code
|
||||||
def parameterized_vs_naive(shape, num_iters):
|
def parameterized_vs_naive(shape, num_iters, use_gpu=False):
|
||||||
np.random.seed(1618) # Make it reproducible.
|
np.random.seed(1618) # Make it reproducible.
|
||||||
|
|
||||||
# No CSE/CF.
|
# No CSE/CF.
|
||||||
@ -183,17 +191,29 @@ def parameterized_vs_naive(shape, num_iters):
|
|||||||
graph_options=tf.GraphOptions(optimizer_options=optimizer_options))
|
graph_options=tf.GraphOptions(optimizer_options=optimizer_options))
|
||||||
|
|
||||||
with tf.Session(config=config) as sess:
|
with tf.Session(config=config) as sess:
|
||||||
param_op = tf.group(random_ops.parameterized_truncated_normal(shape))
|
with tf.device("/cpu:0" if not use_gpu else None):
|
||||||
naive_op = tf.group(random_ops.truncated_normal(shape))
|
param_op = tf.group(random_ops.parameterized_truncated_normal(shape))
|
||||||
|
naive_op = tf.group(random_ops.truncated_normal(shape))
|
||||||
|
|
||||||
|
# Burn-in to avoid session setup costs in the timing.
|
||||||
|
sess.run(param_op)
|
||||||
|
sess.run(param_op)
|
||||||
param_dt = timeit.timeit(lambda: sess.run(param_op), number=num_iters)
|
param_dt = timeit.timeit(lambda: sess.run(param_op), number=num_iters)
|
||||||
|
sess.run(naive_op)
|
||||||
|
sess.run(naive_op)
|
||||||
naive_dt = timeit.timeit(lambda: sess.run(naive_op), number=num_iters)
|
naive_dt = timeit.timeit(lambda: sess.run(naive_op), number=num_iters)
|
||||||
return param_dt, naive_dt
|
return param_dt, naive_dt
|
||||||
|
|
||||||
|
|
||||||
class TruncatedNormalBenchmark(tf.test.Benchmark):
|
class TruncatedNormalBenchmark(tf.test.Benchmark):
|
||||||
|
|
||||||
def benchmarkParameterizedOpVsNaiveOp(self):
|
def benchmarkParameterizedOpVsNaiveOpCpu(self):
|
||||||
|
self._benchmarkParameterizedOpVsNaiveOp(False)
|
||||||
|
|
||||||
|
def benchmarkParameterizedOpVsNaiveOpGpu(self):
|
||||||
|
self._benchmarkParameterizedOpVsNaiveOp(True)
|
||||||
|
|
||||||
|
def _benchmarkParameterizedOpVsNaiveOp(self, use_gpu):
|
||||||
num_iters = 50
|
num_iters = 50
|
||||||
print(("Composition of new ParameterizedTruncatedNormalOp vs. "
|
print(("Composition of new ParameterizedTruncatedNormalOp vs. "
|
||||||
"naive TruncatedNormalOp [%d iters]") % num_iters)
|
"naive TruncatedNormalOp [%d iters]") % num_iters)
|
||||||
@ -201,16 +221,16 @@ class TruncatedNormalBenchmark(tf.test.Benchmark):
|
|||||||
|
|
||||||
for shape in [[10000, 100], [1000, 1000], [1000000], [100, 100, 100],
|
for shape in [[10000, 100], [1000, 1000], [1000000], [100, 100, 100],
|
||||||
[20, 20, 20, 20]]:
|
[20, 20, 20, 20]]:
|
||||||
p_dt, n_dt = parameterized_vs_naive(shape, num_iters)
|
p_dt, n_dt = parameterized_vs_naive(shape, num_iters, use_gpu)
|
||||||
print("%s\t%.3f\t%.3f\t%.2f" % (shape, p_dt, n_dt, p_dt / n_dt))
|
print("%s\t%.3f\t%.3f\t%.2f" % (shape, p_dt, n_dt, p_dt / n_dt))
|
||||||
|
|
||||||
shape_str = "-".join(map(str, shape))
|
shape_str = "-".join(map(str, shape))
|
||||||
self.report_benchmark(name="parameterized_shape" + shape_str,
|
self.report_benchmark(
|
||||||
iters=num_iters,
|
name="parameterized_shape" + shape_str,
|
||||||
wall_time=p_dt)
|
iters=num_iters,
|
||||||
self.report_benchmark(name="naive_shape" + shape_str,
|
wall_time=p_dt)
|
||||||
iters=num_iters,
|
self.report_benchmark(
|
||||||
wall_time=n_dt)
|
name="naive_shape" + shape_str, iters=num_iters, wall_time=n_dt)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
Reference in New Issue
Block a user