A continuing partial implementation of RFC "Random numbers in TensorFlow 2.0"
(https://github.com/tensorflow/community/blob/master/rfcs/20181217-tf2-random-numbers.md).

In this change:
- CPU and GPU kernels for ops 'StatefulUniformInt' and 'StatefulUniformFullInt'.

To be done:
- ops for other distributions;
- other RNG algorithms;
- batch seeds;
- initializers ('RandomUniform', etc.).

PiperOrigin-RevId: 239104292

commit 949474a448 (parent e4df992d2e)
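For orientation: a minimal NumPy sketch, not TensorFlow code and not the API added here, of what a stateful uniform-int kernel does. It reads Philox counter/key state held by a variable, draws bits, maps them into [minval, maxval), and writes the advanced counter back. The class name and the coarse counter advance are hypothetical.

```python
import numpy as np

class StatefulUniformIntSketch:
    """Hypothetical stand-in for the resource variable that holds RNG state."""

    def __init__(self, key):
        self.counter = 0  # mutable state persisted across calls
        self.key = key

    def uniform_int(self, n, minval, maxval):
        # Draw raw Philox bits starting at the stored counter position.
        bits = np.random.Philox(counter=self.counter, key=self.key).random_raw(n)
        # Advance the stored state past the consumed region (coarsely: one
        # counter step per output is more than the ~n/4 blocks actually used).
        self.counter += n
        # Map bits into [minval, maxval); plain modulo, ignoring the small
        # bias a production kernel corrects for.
        return minval + (bits % np.uint64(maxval - minval)).astype(np.int64)

gen = StatefulUniformIntSketch(key=42)
print(gen.uniform_int(5, 0, 10))  # five draws in [0, 10)
print(gen.uniform_int(5, 0, 10))  # fresh draws: the state advanced
```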
@@ -873,6 +873,7 @@ tf_xla_py_test(
         "//tensorflow/python:platform_test",
         "//tensorflow/python:standard_ops",
         "//tensorflow/python:stateful_random_ops",
+        "//tensorflow/python/kernel_tests/random:util",
     ],
 )
 
@@ -887,6 +888,7 @@ tf_xla_py_test(
         "//tensorflow/python:platform_test",
         "//tensorflow/python:standard_ops",
         "//tensorflow/python:stateless_random_ops",
+        "//tensorflow/python/kernel_tests/random:util",
     ],
 )
 
@@ -18,8 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import math
-
 import numpy as np
 
 from tensorflow.compiler.tests import xla_test
@@ -29,6 +27,8 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
+from tensorflow.python.kernel_tests.random import util as \
+random_test_util
 from tensorflow.python.ops import gen_stateful_random_ops
 from tensorflow.python.ops import stateful_random_ops as \
 random
@@ -181,14 +181,6 @@ class StatefulRandomOpsTest(xla_test.XLATestCase):
       x = gen.normal(shape=[10000], dtype=dtype).numpy()
       self.assertTrue(np.all(np.isfinite(x)))
 
-  def _chi_squared(self, x, bins):
-    """Pearson's Chi-squared test."""
-    x = np.ravel(x)
-    n = len(x)
-    histogram, _ = np.histogram(x, bins=bins, range=(0, 1))
-    expected = n / float(bins)
-    return np.sum(np.square(histogram - expected) / expected)
-
   @test_util.run_v2_only
   def testDistributionOfUniform(self):
     """Use Pearson's Chi-squared test to test for uniformity."""
@@ -208,22 +200,9 @@ class StatefulRandomOpsTest(xla_test.XLATestCase):
       # probability. 16.92 is the Chi^2 value for 9 degrees of freedom with
       # p=0.05. This test is probabilistic and would be flaky if the random
       # seed were not fixed.
-      val = self._chi_squared(x, 10)
+      val = random_test_util.chi_squared(x, 10)
       self.assertLess(val, 16.92)
 
-  def _normal_cdf(self, x):
-    """Cumulative distribution function for a standard normal distribution."""
-    return 0.5 + 0.5 * np.vectorize(math.erf)(x / math.sqrt(2))
-
-  def _anderson_darling(self, x):
-    """Anderson-Darling test for a standard normal distribution."""
-    x = np.sort(np.ravel(x))
-    n = len(x)
-    i = np.linspace(1, n, n)
-    z = np.sum((2 * i - 1) * np.log(self._normal_cdf(x)) +
-               (2 * (n - i) + 1) * np.log(1 - self._normal_cdf(x)))
-    return -n - z / n
-
   @test_util.run_v2_only
   def testDistributionOfNormal(self):
     """Use Anderson-Darling test to test distribution appears normal."""
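The helper both test files now import from tensorflow/python/kernel_tests/random/util is the same statistic the deleted method computed. A standalone sketch for trying it outside the test harness:

```python
import numpy as np

def chi_squared(x, bins):
    """Pearson's chi-squared statistic for uniformity on [0, 1)."""
    x = np.ravel(x)
    n = len(x)
    histogram, _ = np.histogram(x, bins=bins, range=(0, 1))
    expected = n / float(bins)
    return np.sum(np.square(histogram - expected) / expected)

# 16.92 is the p=0.05 critical value at 9 degrees of freedom (10 bins), so
# a genuinely uniform sample stays below it with ~95% probability per run.
rng = np.random.default_rng(0)
val = chi_squared(rng.uniform(size=10000), 10)
print(val, val < 16.92)
```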
@@ -235,7 +214,8 @@ class StatefulRandomOpsTest(xla_test.XLATestCase):
       # The constant 2.492 is the 5% critical value for the Anderson-Darling
       # test where the mean and variance are known. This test is probabilistic
       # so to avoid flakiness the seed is fixed.
-      self.assertLess(self._anderson_darling(x.astype(float)), 2.492)
+      self.assertLess(
+          random_test_util.anderson_darling(x.astype(float)), 2.492)
 
   @test_util.run_v2_only
   def testErrors(self):
@@ -24,6 +24,8 @@ import numpy as np
 
 from tensorflow.compiler.tests import xla_test
 from tensorflow.python.framework import dtypes
+from tensorflow.python.kernel_tests.random import util as \
+random_test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import stateless_random_ops as stateless
 from tensorflow.python.ops.distributions import special_math
@@ -43,7 +45,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
     # Stateless values should be equal iff the seeds are equal (roughly)
     with self.cached_session(), self.test_scope():
       seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
-      seeds = [(x, y) for x in range(5) for y in range(5)] * 3
+      seeds = [(x, y) for x in range(5) for y in range(5)] * 3  # pylint: disable=g-complex-comprehension
       for stateless_op in [
           stateless.stateless_random_uniform, stateless.stateless_random_normal
       ]:
@@ -75,14 +77,6 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
         self.assertTrue(np.all(y >= 0))
         self.assertTrue(np.all(y < maxval))
 
-  def _chi_squared(self, x, bins):
-    """Pearson's Chi-squared test."""
-    x = np.ravel(x)
-    n = len(x)
-    histogram, _ = np.histogram(x, bins=bins, range=(0, 1))
-    expected = n / float(bins)
-    return np.sum(np.square(histogram - expected) / expected)
-
   def testDistributionOfStatelessRandomUniform(self):
     """Use Pearson's Chi-squared test to test for uniformity."""
     with self.cached_session() as sess, self.test_scope():
@@ -102,7 +96,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
       # probability. 16.92 is the Chi^2 value for 9 degrees of freedom with
       # p=0.05. This test is probabilistic and would be flaky if the random
       # seed were not fixed.
-      self.assertTrue(self._chi_squared(y, 10) < 16.92)
+      self.assertLess(random_test_util.chi_squared(y, 10), 16.92)
 
   def testRandomNormalIsFinite(self):
     with self.cached_session() as sess, self.test_scope():
@@ -113,19 +107,6 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
       y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]})
       self.assertTrue(np.all(np.isfinite(y)))
 
-  def _normal_cdf(self, x):
-    """Cumulative distribution function for a standard normal distribution."""
-    return 0.5 + 0.5 * np.vectorize(math.erf)(x / math.sqrt(2))
-
-  def _anderson_darling(self, x):
-    """Anderson-Darling test for a standard normal distribution."""
-    x = np.sort(np.ravel(x))
-    n = len(x)
-    i = np.linspace(1, n, n)
-    z = np.sum((2 * i - 1) * np.log(self._normal_cdf(x)) +
-               (2 * (n - i) + 1) * np.log(1 - self._normal_cdf(x)))
-    return -n - z / n
-
   def testDistributionOfStatelessRandomNormal(self):
     """Use Anderson-Darling test to test distribution appears normal."""
     with self.cached_session() as sess, self.test_scope():
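Likewise, the Anderson-Darling helpers deleted here live on in random_test_util; a standalone sketch mirroring them:

```python
import math
import numpy as np

def normal_cdf(x):
    """Cumulative distribution function for a standard normal distribution."""
    return 0.5 + 0.5 * np.vectorize(math.erf)(x / math.sqrt(2))

def anderson_darling(x):
    """Anderson-Darling statistic against a standard normal."""
    x = np.sort(np.ravel(x))
    n = len(x)
    i = np.linspace(1, n, n)
    z = np.sum((2 * i - 1) * np.log(normal_cdf(x)) +
               (2 * (n - i) + 1) * np.log(1 - normal_cdf(x)))
    return -n - z / n

# 2.492 is the 5% critical value when mean and variance are known, per the
# test comments above; a normal sample stays below it on ~95% of runs.
rng = np.random.default_rng(0)
print(anderson_darling(rng.normal(size=1000).astype(float)) < 2.492)
```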
@@ -138,7 +119,8 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
       # The constant 2.492 is the 5% critical value for the Anderson-Darling
       # test where the mean and variance are known. This test is probabilistic
       # so to avoid flakiness the seed is fixed.
-      self.assertTrue(self._anderson_darling(y.astype(float)) < 2.492)
+      self.assertLess(
+          random_test_util.anderson_darling(y.astype(float)), 2.492)
 
   def testTruncatedNormalIsInRange(self):
     for dtype in self._random_types():
@@ -155,7 +137,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
       def normal_pdf(x):
         return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi)
 
-      def probit(x, sess=sess):
+      def probit(x):
         return self.evaluate(special_math.ndtri(x))
 
       a = -2.
@@ -257,6 +257,7 @@ tensorflow/core/kernels/split_op.cc
 tensorflow/core/kernels/split_v_op.cc
 tensorflow/core/kernels/stack.cc
 tensorflow/core/kernels/stack_ops.cc
+tensorflow/core/kernels/stateful_random_ops.cc
 tensorflow/core/kernels/stateless_random_ops.cc
 tensorflow/core/kernels/strided_slice_op.cc
 tensorflow/core/kernels/strided_slice_op_inst_0.cc
@@ -5784,6 +5784,7 @@ filegroup(
         "queue_op.cc",
         "queue_ops.cc",
         "random_op.cc",
+        "random_op_cpu.h",
         "reduction_ops_all.cc",
         "reduction_ops_any.cc",
         "reduction_ops_common.cc",
@@ -17,8 +17,6 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
-#include "tensorflow/core/kernels/random_op.h"
-
 #include <algorithm>
 #include <cmath>
 #include <memory>
@@ -27,6 +25,7 @@ limitations under the License.
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/random_op_cpu.h"
 #include "tensorflow/core/lib/hash/crc32c.h"
 #include "tensorflow/core/lib/random/random_distributions.h"
 #include "tensorflow/core/lib/random/simple_philox.h"
@@ -52,131 +51,6 @@ typedef Eigen::GpuDevice GPUDevice;
 typedef Eigen::SyclDevice SYCLDevice;
 #endif  // TENSORFLOW_USE_SYCL
 
-namespace functor {
-using random::PhiloxRandom;
-using random::SingleSampleAdapter;
-
-// The default implementation of the functor, which should never be invoked
-// But we still need to provide implementation for now for the linker to work,
-// since we do not support all the distributions yet.
-template <typename Device, class Distribution>
-struct FillPhiloxRandom {
-  typedef typename Distribution::ResultElementType T;
-  void operator()(OpKernelContext*, const Device&, random::PhiloxRandom gen,
-                  T* data, int64 size, Distribution dist) {
-    LOG(FATAL) << "Default FillPhiloxRandom should not be executed.";
-  }
-};
-
-// A class to fill a specified range of random groups
-template <class Distribution, bool VariableSamplesPerOutput>
-struct FillPhiloxRandomTask;
-
-// Specialization for distribution that takes a fixed number of samples for
-// each output.
-template <class Distribution>
-struct FillPhiloxRandomTask<Distribution, false> {
-  typedef typename Distribution::ResultElementType T;
-  static void Run(random::PhiloxRandom gen, T* data, int64 size,
-                  int64 start_group, int64 limit_group, Distribution dist) {
-    const int kGroupSize = Distribution::kResultElementCount;
-
-    gen.Skip(start_group);
-    int64 offset = start_group * kGroupSize;
-
-    // First fill all the full-size groups
-    int64 limit_group_full = std::min(limit_group, size / kGroupSize);
-    for (int64 index = start_group; index < limit_group_full; ++index) {
-      auto samples = dist(&gen);
-      std::copy(&samples[0], &samples[0] + kGroupSize, data + offset);
-      offset += kGroupSize;
-    }
-
-    // If there are any remaining elements that need to be filled, process them
-    if (limit_group_full < limit_group) {
-      int64 remaining_size = size - limit_group_full * kGroupSize;
-      auto samples = dist(&gen);
-      std::copy(&samples[0], &samples[0] + remaining_size, data + offset);
-    }
-  }
-};
-
-// Specialization for distribution that takes a variable number of samples for
-// each output. This will be slower due to the generality.
-template <class Distribution>
-struct FillPhiloxRandomTask<Distribution, true> {
-  typedef typename Distribution::ResultElementType T;
-  static const int64 kReservedSamplesPerOutput = 256;
-
-  static void Run(random::PhiloxRandom base_gen, T* data, int64 size,
-                  int64 start_group, int64 limit_group, Distribution dist) {
-    const int kGroupSize = Distribution::kResultElementCount;
-
-    static const int kGeneratorSkipPerOutputGroup =
-        kGroupSize * kReservedSamplesPerOutput /
-        PhiloxRandom::kResultElementCount;
-
-    int64 offset = start_group * kGroupSize;
-
-    // First fill all the full-size groups
-    int64 limit_group_full = std::min(limit_group, size / kGroupSize);
-    int64 group_index;
-    for (group_index = start_group; group_index < limit_group_full;
-         ++group_index) {
-      // Reset the generator to the beginning of the output group region
-      // This is necessary if we want the results to be independent of order
-      // of work
-      PhiloxRandom gen = base_gen;
-      gen.Skip(group_index * kGeneratorSkipPerOutputGroup);
-      SingleSampleAdapter<PhiloxRandom> single_samples(&gen);
-
-      auto samples = dist(&single_samples);
-      std::copy(&samples[0], &samples[0] + kGroupSize, data + offset);
-      offset += kGroupSize;
-    }
-
-    // If there are any remaining elements that need to be filled, process them
-    if (limit_group_full < limit_group) {
-      PhiloxRandom gen = base_gen;
-      gen.Skip(group_index * kGeneratorSkipPerOutputGroup);
-      SingleSampleAdapter<PhiloxRandom> single_samples(&gen);
-
-      int64 remaining_size = size - limit_group_full * kGroupSize;
-      auto samples = dist(&single_samples);
-      std::copy(&samples[0], &samples[0] + remaining_size, data + offset);
-    }
-  }
-};
-
-// Partial specialization for CPU to fill the entire region with randoms
-// It splits the work into several tasks and run them in parallel
-template <class Distribution>
-void FillPhiloxRandom<CPUDevice, Distribution>::operator()(
-    OpKernelContext* context, const CPUDevice&, random::PhiloxRandom gen,
-    typename Distribution::ResultElementType* data, int64 size,
-    Distribution dist) {
-  const int kGroupSize = Distribution::kResultElementCount;
-
-  auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
-
-  int64 total_group_count = (size + kGroupSize - 1) / kGroupSize;
-
-  const int kGroupCost =
-      random::PhiloxRandom::kResultElementCount *
-      (random::PhiloxRandom::kElementCost + Distribution::kElementCost);
-  Shard(worker_threads.num_threads, worker_threads.workers, total_group_count,
-        kGroupCost,
-        [&gen, data, size, dist](int64 start_group, int64 limit_group) {
-          FillPhiloxRandomTask<
-              Distribution,
-              Distribution::kVariableSamplesPerOutput>::Run(gen, data, size,
-                                                            start_group,
-                                                            limit_group, dist);
-        });
-}
-
-}  // namespace functor
-
 namespace {
 
 static Status AllocateOutputWithShape(OpKernelContext* ctx, const Tensor& shape,
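The CPU functor removed here reappears verbatim in the new random_op_cpu.h further down. The key property it relies on: Philox is counter-based, so each shard can Skip() directly to its group, and the filled buffer is identical no matter how Shard() splits the range. A NumPy demonstration of that property, using numpy's Philox bit generator as a stand-in (numpy's is a 4x64 variant rather than TF's 4x32, but the counter-based behavior is the same idea):

```python
import numpy as np

def fill_shard(key, start_group, limit_group, group_size=4):
    # One "group" here is one Philox block of 4 outputs; starting the
    # counter at start_group plays the role of gen.Skip(start_group).
    gen = np.random.Philox(counter=start_group, key=key)
    return gen.random_raw((limit_group - start_group) * group_size)

key = 123
whole = fill_shard(key, 0, 8)                   # one worker fills everything
split = np.concatenate([fill_shard(key, 0, 3),  # ...or three workers split it
                        fill_shard(key, 3, 6),
                        fill_shard(key, 6, 8)])
assert np.array_equal(whole, split)             # same stream either way
```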
@@ -354,7 +228,7 @@ class RandomGammaOp : public OpKernel {
       const double alpha = static_cast<double>(alpha_flat[alpha_idx]);
 
       DISABLE_FLOAT_EQUALITY_WARNING
-      if (alpha == double(1.0)) {
+      if (alpha == static_cast<double>(1.0)) {
       ENABLE_FLOAT_EQUALITY_WARNING
         // Sample from an exponential distribution.
         for (int64 sample_idx = output_idx % num_samples;
@@ -364,7 +238,7 @@ class RandomGammaOp : public OpKernel {
           // (including eventually on GPU), we skip on a per-sample basis.
           PhiloxRandom gen = rng;
           gen.Skip(kReservedSamplesPerOutput * output_idx);
-          short uniform_remaining = 0;
+          int16 uniform_remaining = 0;
           UNIFORM(u);
           const double res = -log(1.0 - u);
           samples_alpha_offset[sample_idx * num_alphas] = static_cast<T>(res);
@@ -392,8 +266,8 @@ class RandomGammaOp : public OpKernel {
         // (including eventually on GPU), we skip on a per-sample basis.
         PhiloxRandom gen = rng;
         gen.Skip(kReservedSamplesPerOutput * output_idx);
-        short norm_remaining = 0;
-        short uniform_remaining = 0;
+        int16 norm_remaining = 0;
+        int16 uniform_remaining = 0;
 
         // Keep trying until we don't reject a sample. In practice, we will
         // only reject ~5% at worst, for low alpha near 1.
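For the `res = -log(1.0 - u)` line kept by the middle hunk: when alpha == 1, Gamma(alpha) reduces to Exponential(1), which inverse-CDF sampling produces exactly from a uniform draw. A quick numerical check:

```python
import numpy as np

u = np.random.uniform(size=100000)
samples = -np.log(1.0 - u)  # mirrors `res = -log(1.0 - u)` in the kernel
assert abs(samples.mean() - 1.0) < 0.05  # Exp(1) has mean 1
```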
@@ -565,145 +439,6 @@ TF_CALL_int64(REGISTER_INT);
 
 #ifdef TENSORFLOW_USE_SYCL
 
-namespace functor {
-
-using namespace cl;
-
-template <class Distribution, bool VariableSamplesPerOutput>
-struct FillPhiloxRandomKernel;
-
-template <class Distribution>
-struct FillPhiloxRandomKernel<Distribution, false> {
-  typedef typename Distribution::ResultElementType T;
-  using write_accessor = sycl::accessor<uint8_t, 1, sycl::access::mode::write,
-                                        sycl::access::target::global_buffer>;
-
-  FillPhiloxRandomKernel(write_accessor& data, random::PhiloxRandom& gen,
-                         Distribution& dist)
-      : data_(data), gen_(gen), dist_(dist) {}
-
-  void operator()(sycl::nd_item<1> item) {
-    const size_t kGroupSize = Distribution::kResultElementCount;
-
-    const size_t item_id = item.get_global(0);
-    const size_t total_item_count = item.get_global_range();
-    size_t offset = item_id * kGroupSize;
-    gen_.Skip(item_id);
-
-    const size_t size = data_.get_size() / sizeof(T);
-    T* data = ConvertToActualTypeSycl(T, data_);
-
-    while (offset + kGroupSize <= size) {
-      const typename Distribution::ResultType samples = dist_(&gen_);
-      for (size_t i = 0; i < kGroupSize; ++i) {
-        data[offset + i] = samples[i];
-      }
-
-      offset += (total_item_count - 1) * kGroupSize;
-      gen_.Skip(total_item_count - 1);
-    }
-
-    const typename Distribution::ResultType samples = dist_(&gen_);
-    for (size_t i = 0; i < kGroupSize; ++i) {
-      if (offset >= size) {
-        return;
-      }
-      data[offset] = samples[i];
-      ++offset;
-    }
-  }
-
- private:
-  write_accessor data_;
-  random::PhiloxRandom gen_;
-  Distribution dist_;
-};
-
-template <class Distribution>
-struct FillPhiloxRandomKernel<Distribution, true> {
-  typedef typename Distribution::ResultElementType T;
-  using write_accessor = sycl::accessor<uint8_t, 1, sycl::access::mode::write,
-                                        sycl::access::target::global_buffer>;
-
-  FillPhiloxRandomKernel(write_accessor& data, random::PhiloxRandom& gen,
-                         Distribution& dist)
-      : data_(data), gen_(gen), dist_(dist) {}
-
-  void operator()(sycl::nd_item<1> item) {
-    using random::PhiloxRandom;
-    using random::SingleSampleAdapter;
-
-    const size_t kReservedSamplesPerOutput = 256;
-    const size_t kGroupSize = Distribution::kResultElementCount;
-    const size_t kGeneratorSkipPerOutputGroup =
-        kGroupSize * kReservedSamplesPerOutput /
-        PhiloxRandom::kResultElementCount;
-
-    const size_t item_id = item.get_global(0);
-    const size_t total_item_count = item.get_global_range();
-    size_t group_index = item_id;
-    size_t offset = group_index * kGroupSize;
-
-    T* data = ConvertToActualTypeSycl(T, data_);
-    const size_t size = data_.get_size() / sizeof(T);
-
-    while (offset < size) {
-      // Since each output takes a variable number of samples, we need to
-      // realign the generator to the beginning for the current output group
-      PhiloxRandom gen = gen_;
-      gen.Skip(group_index * kGeneratorSkipPerOutputGroup);
-      SingleSampleAdapter<PhiloxRandom> single_samples(&gen);
-
-      const typename Distribution::ResultType samples = dist_(&single_samples);
-
-      for (size_t i = 0; i < kGroupSize; ++i) {
-        if (offset >= size) {
-          return;
-        }
-        data[offset] = samples[i];
-        ++offset;
-      }
-
-      offset += (total_item_count - 1) * kGroupSize;
-      group_index += total_item_count;
-    }
-  }
-
- private:
-  write_accessor data_;
-  random::PhiloxRandom gen_;
-  Distribution dist_;
-};
-
-template <typename T>
-class FillRandomKernel;
-// Partial specialization for SYCL to fill the entire region with randoms
-// It splits the work into several tasks and run them in parallel
-template <class Distribution>
-void FillPhiloxRandom<SYCLDevice, Distribution>::operator()(
-    OpKernelContext* context, const SYCLDevice& device,
-    random::PhiloxRandom gen, typename Distribution::ResultElementType* data,
-    int64 size, Distribution dist) {
-  const size_t group_size = device.maxSyclThreadsPerBlock();
-  const size_t group_count = (size + group_size - 1) / group_size;
-
-  auto buffer = device.get_sycl_buffer(data);
-
-  device.sycl_queue().submit([&](sycl::handler& cgh) {
-    auto access = buffer.template get_access<sycl::access::mode::write>(cgh);
-
-    FillPhiloxRandomKernel<Distribution,
-                           Distribution::kVariableSamplesPerOutput>
-        task(access, gen, dist);
-    cgh.parallel_for<class FillRandomKernel<Distribution>>(
-        sycl::nd_range<1>(sycl::range<1>(group_count * group_size),
-                          sycl::range<1>(group_size)),
-        task);
-  });
-}
-
-}  // namespace functor
-
 #define REGISTER(TYPE)                                                      \
   template struct functor::FillPhiloxRandom<                                \
       SYCLDevice, random::UniformDistribution<random::PhiloxRandom, TYPE>>; \
 tensorflow/core/kernels/random_op_cpu.h  (new file, 325 lines)
@@ -0,0 +1,325 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_RANDOM_OP_CPU_H_
+#define TENSORFLOW_CORE_KERNELS_RANDOM_OP_CPU_H_
+
+#define EIGEN_USE_THREADS
+
+#include <algorithm>
+#include <cmath>
+#include <memory>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/random_op.h"
+#include "tensorflow/core/lib/hash/crc32c.h"
+#include "tensorflow/core/lib/random/random_distributions.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/guarded_philox_random.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+#if EIGEN_COMP_GNUC && __cplusplus > 199711L
+#define DISABLE_FLOAT_EQUALITY_WARNING \
+  _Pragma("GCC diagnostic push")       \
+      _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
+#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
+#else
+#define DISABLE_FLOAT_EQUALITY_WARNING
+#define ENABLE_FLOAT_EQUALITY_WARNING
+#endif
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif  // TENSORFLOW_USE_SYCL
+
+namespace functor {
+using random::PhiloxRandom;
+using random::SingleSampleAdapter;
+
+// The default implementation of the functor, which should never be invoked
+// But we still need to provide implementation for now for the linker to work,
+// since we do not support all the distributions yet.
+template <typename Device, class Distribution>
+struct FillPhiloxRandom {
+  typedef typename Distribution::ResultElementType T;
+  void operator()(OpKernelContext* ctx, const Device&, random::PhiloxRandom gen,
+                  T* data, int64 size, Distribution dist) {
+    OP_REQUIRES(
+        ctx, false,
+        errors::Internal("Default FillPhiloxRandom should not be executed."));
+  }
+};
+
+// A class to fill a specified range of random groups
+template <class Distribution, bool VariableSamplesPerOutput>
+struct FillPhiloxRandomTask;
+
+// Specialization for distribution that takes a fixed number of samples for
+// each output.
+template <class Distribution>
+struct FillPhiloxRandomTask<Distribution, false> {
+  typedef typename Distribution::ResultElementType T;
+  static void Run(random::PhiloxRandom gen, T* data, int64 size,
+                  int64 start_group, int64 limit_group, Distribution dist) {
+    const int kGroupSize = Distribution::kResultElementCount;
+
+    gen.Skip(start_group);
+    int64 offset = start_group * kGroupSize;
+
+    // First fill all the full-size groups
+    int64 limit_group_full = std::min(limit_group, size / kGroupSize);
+    for (int64 index = start_group; index < limit_group_full; ++index) {
+      auto samples = dist(&gen);
+      std::copy(&samples[0], &samples[0] + kGroupSize, data + offset);
+      offset += kGroupSize;
+    }
+
+    // If there are any remaining elements that need to be filled, process them
+    if (limit_group_full < limit_group) {
+      int64 remaining_size = size - limit_group_full * kGroupSize;
+      auto samples = dist(&gen);
+      std::copy(&samples[0], &samples[0] + remaining_size, data + offset);
+    }
+  }
+};
+
+// Specialization for distribution that takes a variable number of samples for
+// each output. This will be slower due to the generality.
+template <class Distribution>
+struct FillPhiloxRandomTask<Distribution, true> {
+  typedef typename Distribution::ResultElementType T;
+  static const int64 kReservedSamplesPerOutput = 256;
+
+  static void Run(random::PhiloxRandom base_gen, T* data, int64 size,
+                  int64 start_group, int64 limit_group, Distribution dist) {
+    const int kGroupSize = Distribution::kResultElementCount;
+
+    static const int kGeneratorSkipPerOutputGroup =
+        kGroupSize * kReservedSamplesPerOutput /
+        PhiloxRandom::kResultElementCount;
+
+    int64 offset = start_group * kGroupSize;
+
+    // First fill all the full-size groups
+    int64 limit_group_full = std::min(limit_group, size / kGroupSize);
+    int64 group_index;
+    for (group_index = start_group; group_index < limit_group_full;
+         ++group_index) {
+      // Reset the generator to the beginning of the output group region
+      // This is necessary if we want the results to be independent of order
+      // of work
+      PhiloxRandom gen = base_gen;
+      gen.Skip(group_index * kGeneratorSkipPerOutputGroup);
+      SingleSampleAdapter<PhiloxRandom> single_samples(&gen);
+
+      auto samples = dist(&single_samples);
+      std::copy(&samples[0], &samples[0] + kGroupSize, data + offset);
+      offset += kGroupSize;
+    }
+
+    // If there are any remaining elements that need to be filled, process them
+    if (limit_group_full < limit_group) {
+      PhiloxRandom gen = base_gen;
+      gen.Skip(group_index * kGeneratorSkipPerOutputGroup);
+      SingleSampleAdapter<PhiloxRandom> single_samples(&gen);
+
+      int64 remaining_size = size - limit_group_full * kGroupSize;
+      auto samples = dist(&single_samples);
+      std::copy(&samples[0], &samples[0] + remaining_size, data + offset);
+    }
+  }
+};
+
+// Partial specialization for CPU to fill the entire region with randoms
+// It splits the work into several tasks and run them in parallel
+template <class Distribution>
+void FillPhiloxRandom<CPUDevice, Distribution>::operator()(
+    OpKernelContext* context, const CPUDevice&, random::PhiloxRandom gen,
+    typename Distribution::ResultElementType* data, int64 size,
+    Distribution dist) {
+  const int kGroupSize = Distribution::kResultElementCount;
+
+  auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
+
+  int64 total_group_count = (size + kGroupSize - 1) / kGroupSize;
+
+  const int kGroupCost =
+      random::PhiloxRandom::kResultElementCount *
+      (random::PhiloxRandom::kElementCost + Distribution::kElementCost);
+  Shard(worker_threads.num_threads, worker_threads.workers, total_group_count,
+        kGroupCost,
+        [&gen, data, size, dist](int64 start_group, int64 limit_group) {
+          FillPhiloxRandomTask<
+              Distribution,
+              Distribution::kVariableSamplesPerOutput>::Run(gen, data, size,
+                                                            start_group,
+                                                            limit_group, dist);
+        });
+}
+
+}  // namespace functor
+
+#ifdef TENSORFLOW_USE_SYCL
+
+namespace functor {
+
+template <class Distribution, bool VariableSamplesPerOutput>
+struct FillPhiloxRandomKernel;
+
+template <class Distribution>
+struct FillPhiloxRandomKernel<Distribution, false> {
+  typedef typename Distribution::ResultElementType T;
+  using write_accessor = sycl::accessor<uint8_t, 1, sycl::access::mode::write,
+                                        sycl::access::target::global_buffer>;
+
+  FillPhiloxRandomKernel(write_accessor& data, random::PhiloxRandom& gen,
+                         Distribution& dist)
+      : data_(data), gen_(gen), dist_(dist) {}
+
+  void operator()(sycl::nd_item<1> item) {
+    const size_t kGroupSize = Distribution::kResultElementCount;
+
+    const size_t item_id = item.get_global(0);
+    const size_t total_item_count = item.get_global_range();
+    size_t offset = item_id * kGroupSize;
+    gen_.Skip(item_id);
+
+    const size_t size = data_.get_size() / sizeof(T);
+    T* data = ConvertToActualTypeSycl(T, data_);
+
+    while (offset + kGroupSize <= size) {
+      const typename Distribution::ResultType samples = dist_(&gen_);
+      for (size_t i = 0; i < kGroupSize; ++i) {
+        data[offset + i] = samples[i];
+      }
+
+      offset += (total_item_count - 1) * kGroupSize;
+      gen_.Skip(total_item_count - 1);
+    }
+
+    const typename Distribution::ResultType samples = dist_(&gen_);
+    for (size_t i = 0; i < kGroupSize; ++i) {
+      if (offset >= size) {
+        return;
+      }
+      data[offset] = samples[i];
+      ++offset;
+    }
+  }
+
+ private:
+  write_accessor data_;
+  random::PhiloxRandom gen_;
+  Distribution dist_;
+};
+
+template <class Distribution>
+struct FillPhiloxRandomKernel<Distribution, true> {
+  typedef typename Distribution::ResultElementType T;
+  using write_accessor = sycl::accessor<uint8_t, 1, sycl::access::mode::write,
+                                        sycl::access::target::global_buffer>;
+
+  FillPhiloxRandomKernel(write_accessor& data, random::PhiloxRandom& gen,
+                         Distribution& dist)
+      : data_(data), gen_(gen), dist_(dist) {}
+
+  void operator()(sycl::nd_item<1> item) {
+    using random::PhiloxRandom;
+    using random::SingleSampleAdapter;
+
+    const size_t kReservedSamplesPerOutput = 256;
+    const size_t kGroupSize = Distribution::kResultElementCount;
+    const size_t kGeneratorSkipPerOutputGroup =
+        kGroupSize * kReservedSamplesPerOutput /
+        PhiloxRandom::kResultElementCount;
+
+    const size_t item_id = item.get_global(0);
+    const size_t total_item_count = item.get_global_range();
+    size_t group_index = item_id;
+    size_t offset = group_index * kGroupSize;
+
+    T* data = ConvertToActualTypeSycl(T, data_);
+    const size_t size = data_.get_size() / sizeof(T);
+
+    while (offset < size) {
+      // Since each output takes a variable number of samples, we need to
+      // realign the generator to the beginning for the current output group
+      PhiloxRandom gen = gen_;
+      gen.Skip(group_index * kGeneratorSkipPerOutputGroup);
+      SingleSampleAdapter<PhiloxRandom> single_samples(&gen);
+
+      const typename Distribution::ResultType samples = dist_(&single_samples);
+
+      for (size_t i = 0; i < kGroupSize; ++i) {
+        if (offset >= size) {
+          return;
+        }
+        data[offset] = samples[i];
+        ++offset;
+      }
+
+      offset += (total_item_count - 1) * kGroupSize;
+      group_index += total_item_count;
+    }
+  }
+
+ private:
+  write_accessor data_;
+  random::PhiloxRandom gen_;
+  Distribution dist_;
+};
+
+template <typename T>
+class FillRandomKernel;
+// Partial specialization for SYCL to fill the entire region with randoms
+// It splits the work into several tasks and run them in parallel
+template <class Distribution>
+void FillPhiloxRandom<SYCLDevice, Distribution>::operator()(
+    OpKernelContext* context, const SYCLDevice& device,
+    random::PhiloxRandom gen, typename Distribution::ResultElementType* data,
+    int64 size, Distribution dist) {
+  const size_t group_size = device.maxSyclThreadsPerBlock();
+  const size_t group_count = (size + group_size - 1) / group_size;
+
+  auto buffer = device.get_sycl_buffer(data);
+
+  device.sycl_queue().submit([&](sycl::handler& cgh) {
+    auto access = buffer.template get_access<sycl::access::mode::write>(cgh);
+
+    FillPhiloxRandomKernel<Distribution,
+                           Distribution::kVariableSamplesPerOutput>
+        task(access, gen, dist);
+    cgh.parallel_for<class FillRandomKernel<Distribution>>(
+        sycl::nd_range<1>(sycl::range<1>(group_count * group_size),
+                          sycl::range<1>(group_size)),
+        task);
+  });
+}
+
+}  // namespace functor
+
+#endif  // TENSORFLOW_USE_SYCL
+
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_RANDOM_OP_CPU_H_
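A note on the fixed-stride skipping in the variable-sample path of the new header: each output is budgeted kReservedSamplesPerOutput = 256 single samples, and PhiloxRandom yields kResultElementCount = 4 samples per counter step, so every output group owns a precomputable slice of the stream even when rejection sampling consumes an unpredictable number of draws. The arithmetic, restated:

```python
# Constants as defined in random_op_cpu.h above.
K_RESERVED_SAMPLES_PER_OUTPUT = 256
PHILOX_RESULT_ELEMENT_COUNT = 4  # samples per PhiloxRandom counter step

def generator_skip_per_output_group(group_size):
    return (group_size * K_RESERVED_SAMPLES_PER_OUTPUT
            // PHILOX_RESULT_ELEMENT_COUNT)

# A distribution producing 4 outputs per group skips 256 counter steps per
# group, so group g always starts at counter g * 256 regardless of how many
# samples earlier groups actually consumed.
assert generator_skip_per_output_group(4) == 256
```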
@@ -17,17 +17,15 @@ limitations under the License.
 
 #define EIGEN_USE_GPU
 
-#include "tensorflow/core/kernels/random_op.h"
-#include "tensorflow/core/kernels/random_op_gpu.h"
-
 #include <assert.h>
 #include <stdio.h>
 
+#include "tensorflow/core/kernels/random_op_gpu.h"
+
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/lib/random/philox_random.h"
 #include "tensorflow/core/lib/random/random_distributions.h"
-#include "tensorflow/core/util/cuda_kernel_helper.h"
 
 namespace tensorflow {
 
@@ -37,33 +35,6 @@ namespace functor {
 
 typedef Eigen::GpuDevice GPUDevice;
 
-// A simple launch pad to call the correct function templates to fill the data
-template <class Distribution>
-__global__ void __launch_bounds__(1024)
-    FillPhiloxRandomKernelLaunch(random::PhiloxRandom base_gen,
-                                 typename Distribution::ResultElementType* data,
-                                 int64 size, Distribution dist) {
-  FillPhiloxRandomKernel<Distribution,
-                         Distribution::kVariableSamplesPerOutput>()
-      .Run(base_gen, data, size, dist);
-}
-
-// Partial specialization for GPU
-template <class Distribution>
-void FillPhiloxRandom<GPUDevice, Distribution>::operator()(
-    OpKernelContext*, const GPUDevice& d, random::PhiloxRandom gen,
-    typename Distribution::ResultElementType* data, int64 size,
-    Distribution dist) {
-  const int32 block_size = d.maxGpuThreadsPerBlock();
-  const int32 num_blocks =
-      (d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor()) /
-      block_size;
-
-  TF_CHECK_OK(CudaLaunchKernel(FillPhiloxRandomKernelLaunch<Distribution>,
-                               num_blocks, block_size, 0, d.stream(), gen, data,
-                               size, dist));
-}
-
 // Explicit instantiation of the GPU distributions functors
 // clang-format off
 // NVCC cannot handle ">>" properly
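The launch pad and GPU partial specialization deleted here move into random_op_gpu.h (the @@ -198,6 +200,33 @@ hunk below). The kernels they launch cover arbitrarily large outputs with a fixed-size grid via a grid-stride pattern; a sketch of that ownership rule with illustrative names:

```python
def groups_for_thread(thread_id, total_threads, total_groups):
    # Thread t handles groups t, t + T, t + 2T, ... (T = total threads).
    return list(range(thread_id, total_groups, total_threads))

# Every group is owned by exactly one thread, whatever the grid size.
owned = sorted(g for t in range(4) for g in groups_for_thread(t, 4, 10))
assert owned == list(range(10))
```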
@@ -18,8 +18,10 @@ limitations under the License.
 
 #if defined(__CUDACC__)
 
+#include "tensorflow/core/kernels/random_op.h"
 #include "tensorflow/core/lib/random/philox_random.h"
 #include "tensorflow/core/lib/random/random_distributions.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
 
 namespace tensorflow {
 
@@ -31,14 +33,14 @@ struct FillPhiloxRandomKernel;
 template <class Distribution>
 struct FillPhiloxRandomKernel<Distribution, false> {
   typedef typename Distribution::ResultElementType T;
-  PHILOX_DEVICE_FUNC void Run(random::PhiloxRandom gen, T* data, int64 size,
-                              Distribution dist);
+  PHILOX_DEVICE_INLINE void Run(random::PhiloxRandom gen, T* data, int64 size,
+                                Distribution dist);
 };
 
 template <class Distribution>
 struct FillPhiloxRandomKernel<Distribution, true> {
   typedef typename Distribution::ResultElementType T;
-  PHILOX_DEVICE_FUNC void Run(const random::PhiloxRandom& base_gen, T* data,
-                              int64 size, Distribution dist);
+  PHILOX_DEVICE_INLINE void Run(const random::PhiloxRandom& base_gen, T* data,
+                                int64 size, Distribution dist);
 };
 
@@ -128,7 +130,7 @@ class SampleCopier<int64, 2> {
 // A cuda kernel to fill the data with random numbers from the specified
 // distribution. Each output takes a fixed number of samples.
 template <class Distribution>
-PHILOX_DEVICE_FUNC void FillPhiloxRandomKernel<Distribution, false>::Run(
+PHILOX_DEVICE_INLINE void FillPhiloxRandomKernel<Distribution, false>::Run(
     random::PhiloxRandom gen, T* data, int64 size, Distribution dist) {
   const int kGroupSize = Distribution::kResultElementCount;
 
@@ -159,7 +161,7 @@ PHILOX_DEVICE_FUNC void FillPhiloxRandomKernel<Distribution, false>::Run(
 // A cuda kernel to fill the data with random numbers from the specified
 // distribution. Each output takes a variable number of samples.
 template <class Distribution>
-PHILOX_DEVICE_FUNC void FillPhiloxRandomKernel<Distribution, true>::Run(
+PHILOX_DEVICE_INLINE void FillPhiloxRandomKernel<Distribution, true>::Run(
     const random::PhiloxRandom& base_gen, T* data, int64 size,
     Distribution dist) {
   using random::PhiloxRandom;
@@ -198,6 +200,33 @@ PHILOX_DEVICE_FUNC void FillPhiloxRandomKernel<Distribution, true>::Run(
   }
 }
 
+// A simple launch pad to call the correct function templates to fill the data
+template <class Distribution>
+__global__ void __launch_bounds__(1024)
+    FillPhiloxRandomKernelLaunch(random::PhiloxRandom base_gen,
+                                 typename Distribution::ResultElementType* data,
+                                 int64 size, Distribution dist) {
+  FillPhiloxRandomKernel<Distribution,
+                         Distribution::kVariableSamplesPerOutput>()
+      .Run(base_gen, data, size, dist);
+}
+
+// Partial specialization for GPU
+template <class Distribution>
+void FillPhiloxRandom<GPUDevice, Distribution>::operator()(
+    OpKernelContext*, const GPUDevice& d, random::PhiloxRandom gen,
+    typename Distribution::ResultElementType* data, int64 size,
+    Distribution dist) {
+  const int32 block_size = d.maxGpuThreadsPerBlock();
+  const int32 num_blocks =
+      (d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor()) /
+      block_size;
+
+  TF_CHECK_OK(CudaLaunchKernel(FillPhiloxRandomKernelLaunch<Distribution>,
+                               num_blocks, block_size, 0, d.stream(), gen, data,
+                               size, dist));
+}
+
 }  // namespace functor
 }  // namespace tensorflow
 
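The grid dimensions in the moved FillPhiloxRandom<GPUDevice, ...> come from device limits rather than from the output size: enough blocks to keep every multiprocessor at full thread occupancy, with grid-striding absorbing the rest. Restated with hypothetical device numbers:

```python
def philox_launch_config(num_sms, max_threads_per_sm, max_threads_per_block):
    block_size = max_threads_per_block
    num_blocks = (num_sms * max_threads_per_sm) // block_size
    return num_blocks, block_size

# e.g. 80 SMs x 2048 resident threads each, with 1024-thread blocks
# (matching the kernel's __launch_bounds__(1024)) -> 160 blocks.
assert philox_launch_config(80, 2048, 1024) == (160, 1024)
```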
@@ -15,7 +15,7 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
-#include "tensorflow/core/kernels/random_op.h"
+#include "tensorflow/core/kernels/random_op_cpu.h"
 #include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h"
 #include "tensorflow/core/kernels/training_op_helpers.h"
 #include "tensorflow/core/lib/random/random.h"
@@ -25,7 +25,7 @@ namespace tensorflow {
 template <typename Distribution>
 struct UpdateVariableAndFill_Philox<CPUDevice, Distribution> {
   void operator()(OpKernelContext* ctx, const CPUDevice& device,
-                  int64 output_size, int64 alg_tag_skip,
+                  Distribution dist, int64 output_size, int64 alg_tag_skip,
                   ScopedUnlockUnrefVar* state_var_guard, Tensor* state_tensor,
                   typename Distribution::ResultElementType* output_data) {
     auto state_tensor_flat = state_tensor->flat<StateElementType>();
@@ -36,14 +36,14 @@ struct UpdateVariableAndFill_Philox<CPUDevice, Distribution> {
     // No longer needs the lock.
     state_var_guard->Release();
     functor::FillPhiloxRandom<CPUDevice, Distribution>()(
-        ctx, device, philox, output_data, output_size, Distribution());
+        ctx, device, philox, output_data, output_size, dist);
   }
 };
 
 template <typename Device, typename Distribution>
 Status UpdateVariableAndFill(
-    OpKernelContext* ctx, int state_input_idx, bool read_alg_from_state,
-    Algorithm alg, int64 output_size,
+    OpKernelContext* ctx, Distribution dist, int state_input_idx,
+    bool read_alg_from_state, Algorithm alg, int64 output_size,
     typename Distribution::ResultElementType* output_data) {
   Var* var = nullptr;
   TF_RETURN_IF_ERROR(
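Why these hunks plumb `dist` through instead of rebuilding `Distribution()`: a parameterized distribution, like the uniform-int one the new kernels construct with lo/hi bounds, silently loses its parameters if the fill path default-constructs a fresh instance. A minimal Python analogy (all names illustrative):

```python
class UniformInt:
    def __init__(self, lo=0, hi=2):
        self.lo, self.hi = lo, hi

    def __call__(self, bits):
        return self.lo + bits % (self.hi - self.lo)

def fill_old(dist_type, bits):
    return dist_type()(bits)  # rebuilds Distribution(): bounds reset

def fill_new(dist, bits):
    return dist(bits)         # uses the caller's configured instance

d = UniformInt(10, 20)
assert fill_new(d, 7) == 17
assert fill_old(UniformInt, 7) == 1  # the [10, 20) bounds were lost
```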
@@ -89,7 +89,7 @@ Status UpdateVariableAndFill(
     TF_RETURN_IF_ERROR(PrepareToUpdateVariable<Device, StateElementType>(
         ctx, var_tensor, var->copy_on_read_mode.load()));
     UpdateVariableAndFill_Philox<Device, Distribution>()(
-        ctx, ctx->eigen_device<Device>(), output_size, alg_tag_skip,
+        ctx, ctx->eigen_device<Device>(), dist, output_size, alg_tag_skip,
         &state_var_guard, var_tensor, output_data);
     return Status::OK();
   } else {
@@ -99,7 +99,8 @@ Status UpdateVariableAndFill(
 
 // Preconditon: input(0) is an existing resource.
 template <typename Device, class Distribution>
-void ComputeImpl(OpKernelContext* ctx, int state_input_idx, int shape_input_idx,
-                 bool read_alg_from_state, Algorithm alg) {
+void StatefulRandomCompute(OpKernelContext* ctx, Distribution dist,
+                           int state_input_idx, int shape_input_idx,
+                           bool read_alg_from_state, Algorithm alg) {
   using T = typename Distribution::ResultElementType;
   const Tensor& shape_t = ctx->input(shape_input_idx);
@@ -108,8 +109,8 @@ void ComputeImpl(OpKernelContext* ctx, int state_input_idx, int shape_input_idx,
   Tensor* output;
   OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape, &output));
   auto output_flat = output->flat<T>();
-  OP_REQUIRES_OK(ctx, UpdateVariableAndFill<Device, Distribution>(
-                          ctx, state_input_idx, read_alg_from_state, alg,
+  OP_REQUIRES_OK(ctx, UpdateVariableAndFill<Device>(
+                          ctx, dist, state_input_idx, read_alg_from_state, alg,
                           output_flat.size(), output_flat.data()));
 }
 
@@ -119,27 +120,89 @@ class StatefulRandomOp : public OpKernel {
   explicit StatefulRandomOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
 
   void Compute(OpKernelContext* ctx) override {
-    ComputeImpl<Device, Distribution>(ctx, 0, 1, true, 0);
+    StatefulRandomCompute<Device>(ctx, Distribution(), 0, 1, true, 0);
   }
 };
 
+Status GetAlgorithm(OpKernelContext* ctx, int alg_input_idx, Algorithm* alg) {
+  const Tensor& alg_tensor = ctx->input(alg_input_idx);
+  if (alg_tensor.dims() != 0) {
+    return errors::InvalidArgument("algorithm must be of shape [], not ",
+                                   alg_tensor.shape().DebugString());
+  }
+  if (alg_tensor.dtype() != ALGORITHM_DTYPE) {
+    return errors::InvalidArgument("algorithm's dtype must be ",
+                                   DataTypeString(ALGORITHM_DTYPE), ", not ",
+                                   DataTypeString(alg_tensor.dtype()));
+  }
+  *alg = alg_tensor.flat<Algorithm>()(0);
+  return Status::OK();
+}
+
 template <typename Device, class Distribution>
 class StatefulRandomOpV2 : public OpKernel {
  public:
   explicit StatefulRandomOpV2(OpKernelConstruction* ctx) : OpKernel(ctx) {}
 
   void Compute(OpKernelContext* ctx) override {
-    const Tensor& alg_tensor = ctx->input(1);
-    OP_REQUIRES(ctx, alg_tensor.dims() == 0,
-                errors::InvalidArgument("algorithm must be of shape [], not ",
-                                        alg_tensor.shape().DebugString()));
-    OP_REQUIRES(
-        ctx, alg_tensor.dtype() == ALGORITHM_DTYPE,
-        errors::InvalidArgument("algorithm's dtype must be ",
-                                DataTypeString(ALGORITHM_DTYPE), ", not ",
-                                DataTypeString(alg_tensor.dtype())));
-    auto alg = alg_tensor.flat<Algorithm>()(0);
-    ComputeImpl<Device, Distribution>(ctx, 0, 2, false, alg);
+    Algorithm alg;
+    OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, /*alg_input_idx=*/1, &alg));
+    StatefulRandomCompute<Device>(ctx, Distribution(), /*state_input_idx=*/0,
+                                  /*shape_input_idx=*/2,
+                                  /*read_alg_from_state=*/false, alg);
+  }
+};
+
+template <typename Device, class IntType>
+class StatefulUniformIntOp : public OpKernel {
+ public:
+  explicit StatefulUniformIntOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    Algorithm alg;
+    OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, /*alg_input_idx=*/1, &alg));
+    const Tensor& minval = ctx->input(3);
+    const Tensor& maxval = ctx->input(4);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval.shape()),
+                errors::InvalidArgument("minval must be 0-D, got shape ",
+                                        minval.shape().DebugString()));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval.shape()),
+                errors::InvalidArgument("maxval must be 0-D, got shape ",
+                                        maxval.shape().DebugString()));
+
+    // Verify that minval < maxval. This check intentionally happens after the
+    // early exit for empty output. Zero impossible things are fine.
+    IntType lo = minval.scalar<IntType>()();
+    IntType hi = maxval.scalar<IntType>()();
+    OP_REQUIRES(
+        ctx, lo < hi,
+        errors::InvalidArgument("Need minval < maxval, got ", lo, " >= ", hi));
+
+    // Build distribution
+    typedef random::UniformDistribution<random::PhiloxRandom, IntType>
+        Distribution;
+    Distribution dist(lo, hi);
+
+    StatefulRandomCompute<Device>(ctx, dist, /*state_input_idx=*/0,
+                                  /*shape_input_idx=*/2,
+                                  /*read_alg_from_state=*/false, alg);
+  }
+};
+
+template <typename Device, class IntType>
+class StatefulUniformFullIntOp : public OpKernel {
+ public:
+  explicit StatefulUniformFullIntOp(OpKernelConstruction* ctx)
+      : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    Algorithm alg;
+    OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, /*alg_input_idx=*/1, &alg));
+    StatefulRandomCompute<Device>(
+        ctx,
+        random::UniformFullIntDistribution<random::PhiloxRandom, IntType>(),
+        /*state_input_idx=*/0, /*shape_input_idx=*/2,
+        /*read_alg_from_state=*/false, alg);
   }
 };
 
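StatefulUniformIntOp validates the scalar bounds and then hands (lo, hi) to random::UniformDistribution, which maps raw Philox output words into [lo, hi). A minimal numpy sketch of that mapping, for intuition only (the real element-wise logic lives in random_distributions.h and treats signed arithmetic more carefully):

    import numpy as np

    def uniform_int_sketch(raw_words, lo, hi):
      # Reduce each raw word modulo the range, then shift by lo, so every
      # output lands in [lo, hi).
      rng = np.uint64(hi - lo)
      return (np.uint64(lo) + raw_words.astype(np.uint64) % rng).astype(np.int64)

    raw = np.array([0, 1, 2**32 - 1], dtype=np.uint32)
    print(uniform_int_sketch(raw, 2, 33))  # -> [2 3 5], all within [2, 33)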
@@ -213,14 +276,65 @@ TF_CALL_bfloat16(REGISTER_CPU);
 TF_CALL_float(REGISTER_CPU);
 TF_CALL_double(REGISTER_CPU);
 
+#define REGISTER_StatefulUniformInt(DEVICE, TYPE)             \
+  REGISTER_KERNEL_BUILDER(Name("StatefulUniformInt")          \
+                              .Device(DEVICE_##DEVICE)        \
+                              .HostMemory("resource")         \
+                              .HostMemory("algorithm")        \
+                              .HostMemory("shape")            \
+                              .HostMemory("minval")           \
+                              .HostMemory("maxval")           \
+                              .TypeConstraint<TYPE>("dtype"), \
+                          StatefulUniformIntOp<DEVICE##Device, TYPE>);
+
+#define REGISTER_StatefulUniformInt_CPU(TYPE) \
+  REGISTER_StatefulUniformInt(CPU, TYPE)
+#define REGISTER_StatefulUniformInt_GPU(TYPE) \
+  REGISTER_StatefulUniformInt(GPU, TYPE)
+
+TF_CALL_int32(REGISTER_StatefulUniformInt_CPU);
+TF_CALL_int64(REGISTER_StatefulUniformInt_CPU);
+
+#define REGISTER_StatefulUniformFullInt(DEVICE, TYPE)         \
+  REGISTER_KERNEL_BUILDER(Name("StatefulUniformFullInt")      \
+                              .Device(DEVICE_##DEVICE)        \
+                              .HostMemory("resource")         \
+                              .HostMemory("algorithm")        \
+                              .HostMemory("shape")            \
+                              .TypeConstraint<TYPE>("dtype"), \
+                          StatefulUniformFullIntOp<DEVICE##Device, TYPE>);
+
+#define REGISTER_StatefulUniformFullInt_CPU(TYPE) \
+  REGISTER_StatefulUniformFullInt(CPU, TYPE)
+#define REGISTER_StatefulUniformFullInt_GPU(TYPE) \
+  REGISTER_StatefulUniformFullInt(GPU, TYPE)
+
+TF_CALL_int32(REGISTER_StatefulUniformFullInt_CPU);
+TF_CALL_int64(REGISTER_StatefulUniformFullInt_CPU);
+TF_CALL_uint32(REGISTER_StatefulUniformFullInt_CPU);
+TF_CALL_uint64(REGISTER_StatefulUniformFullInt_CPU);
+
 #if GOOGLE_CUDA
 
 TF_CALL_half(REGISTER_GPU);
+TF_CALL_bfloat16(REGISTER_GPU);
 TF_CALL_float(REGISTER_GPU);
 TF_CALL_double(REGISTER_GPU);
+TF_CALL_int32(REGISTER_StatefulUniformInt_GPU);
+TF_CALL_int64(REGISTER_StatefulUniformInt_GPU);
+TF_CALL_int32(REGISTER_StatefulUniformFullInt_GPU);
+TF_CALL_int64(REGISTER_StatefulUniformFullInt_GPU);
+TF_CALL_uint32(REGISTER_StatefulUniformFullInt_GPU);
+TF_CALL_uint64(REGISTER_StatefulUniformFullInt_GPU);
 
 #endif  // GOOGLE_CUDA
 
+#undef REGISTER_StatefulUniformFullInt_GPU
+#undef REGISTER_StatefulUniformFullInt_CPU
+#undef REGISTER_StatefulUniformFullInt
+#undef REGISTER_StatefulUniformInt_GPU
+#undef REGISTER_StatefulUniformInt_CPU
+#undef REGISTER_StatefulUniformInt
 #undef REGISTER_GPU
 #undef REGISTER_CPU
 #undef REGISTER
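Taken together, these registrations define the supported dtype matrix: 'StatefulUniformInt' for int32/int64 and 'StatefulUniformFullInt' additionally for uint32/uint64, each on CPU and, under GOOGLE_CUDA, on GPU. A hypothetical Python-level call that the bounded-int registration serves, using the module path the tests in this change use:

    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import stateful_random_ops as random

    gen = random.Generator(seed=1234)
    # Served by the StatefulUniformInt registrations above; a dtype without a
    # registered kernel would presumably fail with a no-kernel error.
    x = gen.uniform(shape=[8], minval=2, maxval=33, dtype=dtypes.int64)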
@@ -87,7 +87,7 @@ using GPUDevice = Eigen::GpuDevice;
 template <typename Distribution>
 struct UpdateVariableAndFill_Philox<GPUDevice, Distribution> {
   void operator()(OpKernelContext* ctx, const GPUDevice& device,
-                  int64 output_size, int64 alg_tag_skip,
+                  Distribution dist, int64 output_size, int64 alg_tag_skip,
                   ScopedUnlockUnrefVar* not_used, Tensor* state_tensor,
                   typename Distribution::ResultElementType* output_data);
 };
@@ -53,8 +53,9 @@ __global__ void FillKernel(
 
 template <typename Distribution>
 void UpdateVariableAndFill_Philox<GPUDevice, Distribution>::operator()(
-    OpKernelContext* ctx, const GPUDevice& d, int64 output_size,
-    int64 alg_tag_skip, ScopedUnlockUnrefVar* not_used, Tensor* state_tensor,
+    OpKernelContext* ctx, const GPUDevice& d, Distribution dist,
+    int64 output_size, int64 alg_tag_skip, ScopedUnlockUnrefVar* not_used,
+    Tensor* state_tensor,
     typename Distribution::ResultElementType* output_data) {
   OP_REQUIRES(
       ctx, alg_tag_skip == 0,
@@ -74,10 +75,9 @@ void UpdateVariableAndFill_Philox<GPUDevice, Distribution>::operator()(
 
   int zero = 0;
   cudaMemcpyToSymbol(thread_counter, &zero, sizeof(int));
-  TF_CHECK_OK(CudaLaunchKernel(FillKernel<Distribution>, cfg.block_count,
-                               cfg.thread_per_block, 0, d.stream(),
-                               Distribution(), state_size, output_size,
-                               state_data, output_data));
+  TF_CHECK_OK(CudaLaunchKernel(
+      FillKernel<Distribution>, cfg.block_count, cfg.thread_per_block, 0,
+      d.stream(), dist, state_size, output_size, state_data, output_data));
 }
 
 // Explicit instantiation of the GPU distributions functors.
@@ -86,10 +86,28 @@ void UpdateVariableAndFill_Philox<GPUDevice, Distribution>::operator()(
 // NVCC cannot handle ">>" properly
 template struct UpdateVariableAndFill_Philox<
     GPUDevice, random::NormalDistribution<random::PhiloxRandom, Eigen::half> >;
+template struct UpdateVariableAndFill_Philox<
+    GPUDevice, random::NormalDistribution<random::PhiloxRandom, bfloat16> >;
 template struct UpdateVariableAndFill_Philox<
     GPUDevice, random::NormalDistribution<random::PhiloxRandom, float> >;
 template struct UpdateVariableAndFill_Philox<
     GPUDevice, random::NormalDistribution<random::PhiloxRandom, double> >;
+template struct UpdateVariableAndFill_Philox<
+    GPUDevice, random::UniformDistribution<random::PhiloxRandom, int32> >;
+template struct UpdateVariableAndFill_Philox<
+    GPUDevice, random::UniformDistribution<random::PhiloxRandom, int64> >;
+template struct UpdateVariableAndFill_Philox<
+    GPUDevice, random::UniformFullIntDistribution<
+                   random::PhiloxRandom, int32> >;
+template struct UpdateVariableAndFill_Philox<
+    GPUDevice, random::UniformFullIntDistribution<
+                   random::PhiloxRandom, int64> >;
+template struct UpdateVariableAndFill_Philox<
+    GPUDevice, random::UniformFullIntDistribution<
+                   random::PhiloxRandom, uint32> >;
+template struct UpdateVariableAndFill_Philox<
+    GPUDevice, random::UniformFullIntDistribution<
+                   random::PhiloxRandom, uint64> >;
 // clang-format on
 
 } // end namespace tensorflow
@@ -16,12 +16,13 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_
 #define TENSORFLOW_CORE_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_
 
+#include <string.h>
+
 #define _USE_MATH_DEFINES
 #include <math.h>
 #include <cmath>
 #undef _USE_MATH_DEFINES
 
-#include <string.h>
 #include <algorithm>
 #include <type_traits>
 
@@ -236,6 +237,73 @@ class UniformDistribution<Generator, int64> {
   uint64 range_;
 };
 
+// Similar to `UniformDistribution`, except that instead of generating numbers
+// in the range [low, high), it generates numbers covering the whole range of
+// the integer type.
+template <typename Generator, typename IntType>
+class UniformFullIntDistribution;
+
+template <typename Generator, typename IntType>
+class UniformFullIntDistribution32 {
+ public:
+  // The number of elements that will be returned.
+  static const int kResultElementCount = Generator::kResultElementCount;
+  // Cost of generation of a single element (in cycles).
+  static const int kElementCost = 3;
+  // Indicate that this distribution may take variable number of samples
+  // during the runtime.
+  static const bool kVariableSamplesPerOutput = false;
+  typedef Array<IntType, kResultElementCount> ResultType;
+  typedef IntType ResultElementType;
+
+  PHILOX_DEVICE_INLINE
+  ResultType operator()(Generator* gen) {
+    typename Generator::ResultType sample = (*gen)();
+    ResultType result;
+    for (int i = 0; i < kResultElementCount; ++i) {
+      result[i] = sample[i];
+    }
+    return result;
+  }
+};
+
+template <typename Generator, typename IntType>
+class UniformFullIntDistribution64 {
+ public:
+  // The number of elements that will be returned.
+  static const int kResultElementCount = Generator::kResultElementCount / 2;
+  // Cost of generation of a single element (in cycles).
+  static const int kElementCost = 3;
+  // Indicate that this distribution may take variable number of samples
+  // during the runtime.
+  static const bool kVariableSamplesPerOutput = false;
+  typedef Array<IntType, kResultElementCount> ResultType;
+  typedef IntType ResultElementType;
+
+  PHILOX_DEVICE_INLINE
+  ResultType operator()(Generator* gen) {
+    typename Generator::ResultType sample = (*gen)();
+    ResultType result;
+    for (int i = 0; i < kResultElementCount; ++i) {
+      result[i] = sample[2 * i] | static_cast<uint64>(sample[2 * i + 1]) << 32;
+    }
+    return result;
+  }
+};
+
+template <typename Generator>
+class UniformFullIntDistribution<Generator, int32>
+    : public UniformFullIntDistribution32<Generator, int32> {};
+template <typename Generator>
+class UniformFullIntDistribution<Generator, uint32>
+    : public UniformFullIntDistribution32<Generator, uint32> {};
+template <typename Generator>
+class UniformFullIntDistribution<Generator, int64>
+    : public UniformFullIntDistribution64<Generator, int64> {};
+template <typename Generator>
+class UniformFullIntDistribution<Generator, uint64>
+    : public UniformFullIntDistribution64<Generator, uint64> {};
+
 // A class that adapts the underlying native multiple samples to return a single
 // sample at a time.
 template <class Generator>
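Note how UniformFullIntDistribution64 packs two consecutive 32-bit generator words into each 64-bit result, low word first. The same bit manipulation in a standalone numpy sketch (illustrative only):

    import numpy as np

    def pack64(words):
      # Mirrors result[i] = sample[2*i] | uint64(sample[2*i + 1]) << 32:
      # even-indexed words become the low halves, odd-indexed the high halves.
      w = np.asarray(words, dtype=np.uint64)
      return w[0::2] | (w[1::2] << np.uint64(32))

    words = np.array([0x89abcdef, 0x01234567], dtype=np.uint32)
    assert pack64(words)[0] == 0x0123456789abcdef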
@@ -2992,6 +2992,7 @@ cuda_py_test(
         ":client_testlib",
         ":logging_ops",
         ":random_ops_gen",
+        "//tensorflow/python/kernel_tests/random:util",
     ],
     xla_enable_strict_auto_jit = True,
 )
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import math
+
 import numpy as np
 
 
@@ -70,3 +72,26 @@ def test_moment_matching(
               total_variance)))
   return z_test_scores
 
+
+def chi_squared(x, bins):
+  """Pearson's Chi-squared test."""
+  x = np.ravel(x)
+  n = len(x)
+  histogram, _ = np.histogram(x, bins=bins, range=(0, 1))
+  expected = n / float(bins)
+  return np.sum(np.square(histogram - expected) / expected)
+
+
+def normal_cdf(x):
+  """Cumulative distribution function for a standard normal distribution."""
+  return 0.5 + 0.5 * np.vectorize(math.erf)(x / math.sqrt(2))
+
+
+def anderson_darling(x):
+  """Anderson-Darling test for a standard normal distribution."""
+  x = np.sort(np.ravel(x))
+  n = len(x)
+  i = np.linspace(1, n, n)
+  z = np.sum((2 * i - 1) * np.log(normal_cdf(x)) +
+             (2 * (n - i) + 1) * np.log(1 - normal_cdf(x)))
+  return -n - z / n
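Both statistics can be sanity-checked against numpy's own generators; a quick sketch of the expected magnitudes (not part of the change, just an illustration of how the helpers are meant to be read):

    import numpy as np

    rng = np.random.RandomState(0)
    # 10 bins give 9 degrees of freedom; a genuinely uniform sample should
    # usually fall below the 5% critical value 16.92 used by the tests below.
    print(chi_squared(rng.uniform(size=1000), 10))
    # A standard-normal sample should usually fall below 2.492, the 5%
    # critical value for the Anderson-Darling statistic used below.
    print(anderson_darling(rng.standard_normal(1000)))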
@@ -160,7 +160,19 @@ def _shape_tensor(shape):
 class Generator(tracking.AutoTrackable):
   """Random-number generator.
 
-  It uses Variable to manage its internal state.
+  It uses Variable to manage its internal state, and allows choosing a
+  Random-Number-Generation (RNG) algorithm.
+
+  CPU and GPU with the same algorithm and seed will generate the same integer
+  random numbers. Floating-point results (such as the output of `normal`) may
+  have small numerical discrepancies between CPU and GPU.
+
+  Because of different counter-reservation schemes, TPU's integer random numbers
+  will be different from CPU/GPU even with the same algorithm and seed.
+  Also, TPU uses different sampling algorithms for some distributions
+  (e.g. using reverse CDF for sampling normal distribution instead of
+  Box-Muller used by CPU/GPU). Harmonizing TPU's RNG behavior with CPU/GPU is
+  work in progress.
   """
 
   def __init__(self, copy_from=None, seed=None, algorithm=None):
@@ -23,9 +23,11 @@ import numpy as np
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
+from tensorflow.python.kernel_tests.random import util as \
+    random_test_util
 from tensorflow.python.ops import gen_random_ops
 from tensorflow.python.ops import gen_stateful_random_ops
 from tensorflow.python.ops import logging_ops
@@ -189,33 +191,60 @@ class StatefulRandomOpsTest(test.TestCase):
     compare(True, True)
     compare(True, False)
 
+  def _sameAsOldRandomOps(self, device):
+    def compare(dtype, old, new):
+      seed1, seed2 = 79, 25
+      # note how the two seeds for the old op correspond to the seed for the
+      # new op
+      with ops.device(device):
+        gen = random.Generator(seed=[0, seed2, seed1])
+
+      # create a graph for the old op in order to call it many times
+      @def_function.function
+      def run_old():
+        with ops.device(device):
+          return old(dtype, seed1, seed2)
+
+      def run_new():
+        with ops.device(device):
+          return new(dtype, gen)
+
+      for _ in range(100):
+        self.assertAllEqual(run_old(), run_new())
+
+    shape = constant_op.constant([4, 7])
+    minval = 128
+    maxval = 256
+
+    # passing `dtype` around to compress go/gpylint-faq#cell-var-from-loop and
+    # go/gpylint-faq#undefined-loop-variable
+    def old_normal(dtype, seed1, seed2):
+      return gen_random_ops.random_standard_normal(
+          shape, dtype=dtype, seed=seed1, seed2=seed2)
+    def new_normal(dtype, gen):
+      return gen._standard_normal(shape, dtype=dtype)
+    def old_uniform(dtype, seed1, seed2):
+      minval2 = constant_op.constant(minval, dtype=dtype)
+      maxval2 = constant_op.constant(maxval, dtype=dtype)
+      return gen_random_ops.random_uniform_int(
+          shape, minval=minval2, maxval=maxval2, seed=seed1, seed2=seed2)
+    def new_uniform(dtype, gen):
+      return gen.uniform(
+          shape, minval=minval, maxval=maxval, dtype=dtype)
+
+    for dtype in (dtypes.float16, dtypes.bfloat16, dtypes.float32,
+                  dtypes.float64):
+      compare(dtype, old_normal, new_normal)
+    for dtype in [dtypes.int32, dtypes.int64]:
+      compare(dtype, old_uniform, new_uniform)
+
   @test_util.run_v2_only
   def testCPUSameAsOldRandomOps(self):
     """Tests that the generated numbers are the same as the old random_ops.py.
 
     The CPU version.
     """
-    seed1, seed2 = 79, 25
-    # note how the two seeds for the old op correspond to the seed for the new
-    # op
-    with ops.device("/device:CPU:0"):
-      random.reset_global_generator([0, seed2, seed1])
-      shape = constant_op.constant([4, 7])
-      dtype = dtypes.float64
-
-    # create a graph for the old op in order to call it many times
-    @def_function.function
-    def old():
-      with ops.device("/device:CPU:0"):
-        return gen_random_ops.random_standard_normal(
-            shape, dtype=dtype, seed=seed1, seed2=seed2)
-
-    def new():
-      with ops.device("/device:CPU:0"):
-        return random.get_global_generator().normal(shape, dtype=dtype)
-
-    for _ in range(100):
-      self.assertAllEqual(old(), new())
+    self._sameAsOldRandomOps("/device:CPU:0")
 
   @test_util.run_v2_only
   @test_util.run_cuda_only
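The refactored helper above pins down the correspondence between the legacy ops' two scalar seeds and the new Generator's three-element Philox state; spelled out on its own (names as in this test file, `_standard_normal` being a private test hook):

    seed1, seed2 = 79, 25
    # Legacy API: two scalar seed attributes on the op.
    old = gen_random_ops.random_standard_normal(
        [4, 7], dtype=dtypes.float64, seed=seed1, seed2=seed2)
    # New API: the same seeds packed as seed=[0, seed2, seed1]; per the test,
    # both calls then produce bitwise-identical values on the same device.
    new = random.Generator(seed=[0, seed2, seed1])._standard_normal(
        [4, 7], dtype=dtypes.float64)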
@@ -224,28 +253,103 @@
 
     The GPU version.
     """
-    seed1, seed2 = 79, 25
-    with ops.device(test_util.gpu_device_name()):
-      random.reset_global_generator([0, seed2, seed1])
-      shape = constant_op.constant([4, 7])
-      dtype = dtypes.float64
-
-    @def_function.function
-    def old():
-      with ops.device(test_util.gpu_device_name()):
-        return gen_random_ops.random_standard_normal(
-            shape, dtype=dtype, seed=seed1, seed2=seed2)
-
-    def new():
-      with ops.device(test_util.gpu_device_name()):
-        return random.get_global_generator().normal(shape, dtype=dtype)
-
-    for _ in range(100):
-      self.assertAllEqual(old(), new())
+    self._sameAsOldRandomOps(test_util.gpu_device_name())
+
+  @test_util.run_v2_only
+  def testUniformIntIsInRange(self):
+    minval = 2
+    maxval = 33
+    size = 1000
+    gen = random.Generator(seed=1234)
+    for dtype in [dtypes.int32, dtypes.int64]:
+      x = gen.uniform(
+          shape=[size], dtype=dtype, minval=minval, maxval=maxval).numpy()
+      self.assertTrue(np.all(x >= minval))
+      self.assertTrue(np.all(x < maxval))
+
+  @test_util.run_v2_only
+  def testNormalIsFinite(self):
+    gen = random.Generator(seed=1234)
+    for dtype in [dtypes.float32]:
+      x = gen.normal(shape=[10000], dtype=dtype).numpy()
+      self.assertTrue(np.all(np.isfinite(x)))
+
+  @test_util.run_v2_only
+  def testDistributionOfUniform(self):
+    """Use Pearson's Chi-squared test to test for uniformity."""
+    n = 1000
+    seed = 12
+    for dtype in [dtypes.int32, dtypes.int64]:
+      gen = random.Generator(seed=seed)
+      maxval = 1
+      if dtype.is_integer:
+        maxval = 100
+      x = gen.uniform(shape=[n], maxval=maxval, dtype=dtype).numpy()
+      if maxval > 1:
+        # Normalize y to range [0, 1).
+        x = x.astype(float) / maxval
+      # Tests that the values are distributed amongst 10 bins with equal
+      # probability. 16.92 is the Chi^2 value for 9 degrees of freedom with
+      # p=0.05. This test is probabilistic and would be flaky if the random
+      # seed were not fixed.
+      val = random_test_util.chi_squared(x, 10)
+      self.assertLess(val, 16.92)
+
+  @test_util.run_v2_only
+  def testDistributionOfNormal(self):
+    """Use Anderson-Darling test to test distribution appears normal."""
+    n = 1000
+    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+      gen = random.Generator(seed=1234)
+      x = gen.normal(shape=[n], dtype=dtype).numpy()
+      # The constant 2.492 is the 5% critical value for the Anderson-Darling
+      # test where the mean and variance are known. This test is probabilistic
+      # so to avoid flakiness the seed is fixed.
+      self.assertLess(
+          random_test_util.anderson_darling(x.astype(float)), 2.492)
+
+  @test_util.run_v2_only
+  def testErrors(self):
+    """Tests that proper errors are raised.
+    """
+    shape = [2, 3]
+    gen = random.Generator(seed=1234)
+    with self.assertRaisesWithPredicateMatch(
+        errors.InvalidArgumentError,
+        r"algorithm must be of shape \[\], not"):
+      gen_stateful_random_ops.stateful_standard_normal_v2(
+          gen.state.handle, [0, 0], shape)
+    with self.assertRaisesWithPredicateMatch(
+        TypeError, "Requested dtype: int64"):
+      gen_stateful_random_ops.stateful_standard_normal_v2(
+          gen.state.handle, 1.1, shape)
+    with self.assertRaisesWithPredicateMatch(
+        errors.InvalidArgumentError,
+        "Unsupported algorithm id"):
+      gen_stateful_random_ops.stateful_standard_normal_v2(
+          gen.state.handle, 123, shape)
+    var = variables.Variable([0, 0], dtype=dtypes.int32)
+    with self.assertRaisesWithPredicateMatch(
+        errors.InvalidArgumentError,
+        "dtype of RNG state variable must be int64, not"):
+      gen_stateful_random_ops.stateful_standard_normal_v2(
+          var.handle, random.RNG_ALG_PHILOX, shape)
+    var = variables.Variable([[0]], dtype=dtypes.int64)
+    with self.assertRaisesWithPredicateMatch(
+        errors.InvalidArgumentError,
+        "RNG state must have one and only one dimension, not"):
+      gen_stateful_random_ops.stateful_standard_normal_v2(
+          var.handle, random.RNG_ALG_PHILOX, shape)
+    var = variables.Variable([0], dtype=dtypes.int64)
+    with self.assertRaisesWithPredicateMatch(
+        errors.InvalidArgumentError,
+        "For the Philox algorithm, the size of state must be at least"):
+      gen_stateful_random_ops.stateful_standard_normal_v2(
+          var.handle, random.RNG_ALG_PHILOX, shape)
+
   @test_util.run_v2_only
   def testStatefulStandardNormal(self):
-    """Tests that op 'StatefulStandardNormal' still works.
+    """Tests that the deprecated op 'StatefulStandardNormal' still works.
     """
     shape = constant_op.constant([4, 7])
     dtype = dtypes.float64
@@ -277,7 +381,7 @@ class StatefulRandomOpsTest(test.TestCase):
 
     random.reset_global_generator(50)
     with self.assertRaisesWithPredicateMatch(
-        errors_impl.NotFoundError, "Resource .+ does not exist"):
+        errors.NotFoundError, "Resource .+ does not exist"):
       a = f()
     random.reset_global_generator(50)
     b = f()