A continuing partial implementation of RFC "Random numbers in TensorFlow 2.0" (https://github.com/tensorflow/community/blob/master/rfcs/20181217-tf2-random-numbers.md):

In this change:
- CPU and GPU kernels for ops 'StatefulUniformInt' and 'StatefulUniformFullInt' (see the usage sketch below).

To be done:
- ops for other distributions;
- other RNG algorithms;
- batch seeds;
- initializers ('RandomUniform', etc.);
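
A minimal usage sketch (illustrative only): this change only adds the C++ kernels; the high-level `tf.random.Generator` methods shown below ('uniform' with an integer dtype and 'uniform_full_int') are the eventual TF 2.x entry points expected to dispatch to these kernels and are not part of this commit.

import tensorflow as tf

g = tf.random.Generator.from_seed(1234)  # stateful Philox generator

# Bounded integers; expected to lower to the StatefulUniformInt kernel.
bounded = g.uniform(shape=[4], minval=0, maxval=10, dtype=tf.int32)

# Integers covering the whole dtype range; expected to lower to the
# StatefulUniformFullInt kernel.
full_range = g.uniform_full_int(shape=[4], dtype=tf.uint64)

print(bounded.numpy(), full_range.numpy())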

PiperOrigin-RevId: 239104292
Peng Wang 2019-03-18 18:28:46 -07:00 committed by TensorFlower Gardener
parent e4df992d2e
commit 949474a448
17 changed files with 794 additions and 426 deletions


@ -873,6 +873,7 @@ tf_xla_py_test(
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
"//tensorflow/python:standard_ops", "//tensorflow/python:standard_ops",
"//tensorflow/python:stateful_random_ops", "//tensorflow/python:stateful_random_ops",
"//tensorflow/python/kernel_tests/random:util",
], ],
) )
@ -887,6 +888,7 @@ tf_xla_py_test(
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
"//tensorflow/python:standard_ops", "//tensorflow/python:standard_ops",
"//tensorflow/python:stateless_random_ops", "//tensorflow/python:stateless_random_ops",
"//tensorflow/python/kernel_tests/random:util",
], ],
) )


@ -18,8 +18,6 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import math
import numpy as np import numpy as np
from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tests import xla_test
@ -29,6 +27,8 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.kernel_tests.random import util as \
random_test_util
from tensorflow.python.ops import gen_stateful_random_ops from tensorflow.python.ops import gen_stateful_random_ops
from tensorflow.python.ops import stateful_random_ops as \ from tensorflow.python.ops import stateful_random_ops as \
random random
@ -181,14 +181,6 @@ class StatefulRandomOpsTest(xla_test.XLATestCase):
x = gen.normal(shape=[10000], dtype=dtype).numpy() x = gen.normal(shape=[10000], dtype=dtype).numpy()
self.assertTrue(np.all(np.isfinite(x))) self.assertTrue(np.all(np.isfinite(x)))
def _chi_squared(self, x, bins):
"""Pearson's Chi-squared test."""
x = np.ravel(x)
n = len(x)
histogram, _ = np.histogram(x, bins=bins, range=(0, 1))
expected = n / float(bins)
return np.sum(np.square(histogram - expected) / expected)
@test_util.run_v2_only @test_util.run_v2_only
def testDistributionOfUniform(self): def testDistributionOfUniform(self):
"""Use Pearson's Chi-squared test to test for uniformity.""" """Use Pearson's Chi-squared test to test for uniformity."""
@ -208,22 +200,9 @@ class StatefulRandomOpsTest(xla_test.XLATestCase):
# probability. 16.92 is the Chi^2 value for 9 degrees of freedom with # probability. 16.92 is the Chi^2 value for 9 degrees of freedom with
# p=0.05. This test is probabilistic and would be flaky if the random # p=0.05. This test is probabilistic and would be flaky if the random
# seed were not fixed. # seed were not fixed.
val = self._chi_squared(x, 10) val = random_test_util.chi_squared(x, 10)
self.assertLess(val, 16.92) self.assertLess(val, 16.92)
def _normal_cdf(self, x):
"""Cumulative distribution function for a standard normal distribution."""
return 0.5 + 0.5 * np.vectorize(math.erf)(x / math.sqrt(2))
def _anderson_darling(self, x):
"""Anderson-Darling test for a standard normal distribution."""
x = np.sort(np.ravel(x))
n = len(x)
i = np.linspace(1, n, n)
z = np.sum((2 * i - 1) * np.log(self._normal_cdf(x)) +
(2 * (n - i) + 1) * np.log(1 - self._normal_cdf(x)))
return -n - z / n
@test_util.run_v2_only @test_util.run_v2_only
def testDistributionOfNormal(self): def testDistributionOfNormal(self):
"""Use Anderson-Darling test to test distribution appears normal.""" """Use Anderson-Darling test to test distribution appears normal."""
@ -235,7 +214,8 @@ class StatefulRandomOpsTest(xla_test.XLATestCase):
# The constant 2.492 is the 5% critical value for the Anderson-Darling # The constant 2.492 is the 5% critical value for the Anderson-Darling
# test where the mean and variance are known. This test is probabilistic # test where the mean and variance are known. This test is probabilistic
# so to avoid flakiness the seed is fixed. # so to avoid flakiness the seed is fixed.
self.assertLess(self._anderson_darling(x.astype(float)), 2.492) self.assertLess(
random_test_util.anderson_darling(x.astype(float)), 2.492)
@test_util.run_v2_only @test_util.run_v2_only
def testErrors(self): def testErrors(self):


@ -24,6 +24,8 @@ import numpy as np
from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.kernel_tests.random import util as \
random_test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import stateless_random_ops as stateless from tensorflow.python.ops import stateless_random_ops as stateless
from tensorflow.python.ops.distributions import special_math from tensorflow.python.ops.distributions import special_math
@ -43,7 +45,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
# Stateless values should be equal iff the seeds are equal (roughly) # Stateless values should be equal iff the seeds are equal (roughly)
with self.cached_session(), self.test_scope(): with self.cached_session(), self.test_scope():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
seeds = [(x, y) for x in range(5) for y in range(5)] * 3 seeds = [(x, y) for x in range(5) for y in range(5)] * 3 # pylint: disable=g-complex-comprehension
for stateless_op in [ for stateless_op in [
stateless.stateless_random_uniform, stateless.stateless_random_normal stateless.stateless_random_uniform, stateless.stateless_random_normal
]: ]:
@ -75,14 +77,6 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
self.assertTrue(np.all(y >= 0)) self.assertTrue(np.all(y >= 0))
self.assertTrue(np.all(y < maxval)) self.assertTrue(np.all(y < maxval))
def _chi_squared(self, x, bins):
"""Pearson's Chi-squared test."""
x = np.ravel(x)
n = len(x)
histogram, _ = np.histogram(x, bins=bins, range=(0, 1))
expected = n / float(bins)
return np.sum(np.square(histogram - expected) / expected)
def testDistributionOfStatelessRandomUniform(self): def testDistributionOfStatelessRandomUniform(self):
"""Use Pearson's Chi-squared test to test for uniformity.""" """Use Pearson's Chi-squared test to test for uniformity."""
with self.cached_session() as sess, self.test_scope(): with self.cached_session() as sess, self.test_scope():
@ -102,7 +96,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
# probability. 16.92 is the Chi^2 value for 9 degrees of freedom with # probability. 16.92 is the Chi^2 value for 9 degrees of freedom with
# p=0.05. This test is probabilistic and would be flaky if the random # p=0.05. This test is probabilistic and would be flaky if the random
# seed were not fixed. # seed were not fixed.
self.assertTrue(self._chi_squared(y, 10) < 16.92) self.assertLess(random_test_util.chi_squared(y, 10), 16.92)
def testRandomNormalIsFinite(self): def testRandomNormalIsFinite(self):
with self.cached_session() as sess, self.test_scope(): with self.cached_session() as sess, self.test_scope():
@ -113,19 +107,6 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]})
self.assertTrue(np.all(np.isfinite(y))) self.assertTrue(np.all(np.isfinite(y)))
def _normal_cdf(self, x):
"""Cumulative distribution function for a standard normal distribution."""
return 0.5 + 0.5 * np.vectorize(math.erf)(x / math.sqrt(2))
def _anderson_darling(self, x):
"""Anderson-Darling test for a standard normal distribution."""
x = np.sort(np.ravel(x))
n = len(x)
i = np.linspace(1, n, n)
z = np.sum((2 * i - 1) * np.log(self._normal_cdf(x)) +
(2 * (n - i) + 1) * np.log(1 - self._normal_cdf(x)))
return -n - z / n
def testDistributionOfStatelessRandomNormal(self): def testDistributionOfStatelessRandomNormal(self):
"""Use Anderson-Darling test to test distribution appears normal.""" """Use Anderson-Darling test to test distribution appears normal."""
with self.cached_session() as sess, self.test_scope(): with self.cached_session() as sess, self.test_scope():
@ -138,7 +119,8 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
# The constant 2.492 is the 5% critical value for the Anderson-Darling # The constant 2.492 is the 5% critical value for the Anderson-Darling
# test where the mean and variance are known. This test is probabilistic # test where the mean and variance are known. This test is probabilistic
# so to avoid flakiness the seed is fixed. # so to avoid flakiness the seed is fixed.
self.assertTrue(self._anderson_darling(y.astype(float)) < 2.492) self.assertLess(
random_test_util.anderson_darling(y.astype(float)), 2.492)
def testTruncatedNormalIsInRange(self): def testTruncatedNormalIsInRange(self):
for dtype in self._random_types(): for dtype in self._random_types():
@ -155,7 +137,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
def normal_pdf(x): def normal_pdf(x):
return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi) return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi)
def probit(x, sess=sess): def probit(x):
return self.evaluate(special_math.ndtri(x)) return self.evaluate(special_math.ndtri(x))
a = -2. a = -2.


@ -257,6 +257,7 @@ tensorflow/core/kernels/split_op.cc
tensorflow/core/kernels/split_v_op.cc tensorflow/core/kernels/split_v_op.cc
tensorflow/core/kernels/stack.cc tensorflow/core/kernels/stack.cc
tensorflow/core/kernels/stack_ops.cc tensorflow/core/kernels/stack_ops.cc
tensorflow/core/kernels/stateful_random_ops.cc
tensorflow/core/kernels/stateless_random_ops.cc tensorflow/core/kernels/stateless_random_ops.cc
tensorflow/core/kernels/strided_slice_op.cc tensorflow/core/kernels/strided_slice_op.cc
tensorflow/core/kernels/strided_slice_op_inst_0.cc tensorflow/core/kernels/strided_slice_op_inst_0.cc


@ -5784,6 +5784,7 @@ filegroup(
"queue_op.cc", "queue_op.cc",
"queue_ops.cc", "queue_ops.cc",
"random_op.cc", "random_op.cc",
"random_op_cpu.h",
"reduction_ops_all.cc", "reduction_ops_all.cc",
"reduction_ops_any.cc", "reduction_ops_any.cc",
"reduction_ops_common.cc", "reduction_ops_common.cc",


@ -17,8 +17,6 @@ limitations under the License.
#define EIGEN_USE_THREADS #define EIGEN_USE_THREADS
#include "tensorflow/core/kernels/random_op.h"
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
#include <memory> #include <memory>
@ -27,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/random_op_cpu.h"
#include "tensorflow/core/lib/hash/crc32c.h" #include "tensorflow/core/lib/hash/crc32c.h"
#include "tensorflow/core/lib/random/random_distributions.h" #include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/lib/random/simple_philox.h" #include "tensorflow/core/lib/random/simple_philox.h"
@ -52,131 +51,6 @@ typedef Eigen::GpuDevice GPUDevice;
typedef Eigen::SyclDevice SYCLDevice; typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL #endif // TENSORFLOW_USE_SYCL
namespace functor {
using random::PhiloxRandom;
using random::SingleSampleAdapter;
// The default implementation of the functor, which should never be invoked
// But we still need to provide implementation for now for the linker to work,
// since we do not support all the distributions yet.
template <typename Device, class Distribution>
struct FillPhiloxRandom {
typedef typename Distribution::ResultElementType T;
void operator()(OpKernelContext*, const Device&, random::PhiloxRandom gen,
T* data, int64 size, Distribution dist) {
LOG(FATAL) << "Default FillPhiloxRandom should not be executed.";
}
};
// A class to fill a specified range of random groups
template <class Distribution, bool VariableSamplesPerOutput>
struct FillPhiloxRandomTask;
// Specialization for distribution that takes a fixed number of samples for
// each output.
template <class Distribution>
struct FillPhiloxRandomTask<Distribution, false> {
typedef typename Distribution::ResultElementType T;
static void Run(random::PhiloxRandom gen, T* data, int64 size,
int64 start_group, int64 limit_group, Distribution dist) {
const int kGroupSize = Distribution::kResultElementCount;
gen.Skip(start_group);
int64 offset = start_group * kGroupSize;
// First fill all the full-size groups
int64 limit_group_full = std::min(limit_group, size / kGroupSize);
for (int64 index = start_group; index < limit_group_full; ++index) {
auto samples = dist(&gen);
std::copy(&samples[0], &samples[0] + kGroupSize, data + offset);
offset += kGroupSize;
}
// If there are any remaining elements that need to be filled, process them
if (limit_group_full < limit_group) {
int64 remaining_size = size - limit_group_full * kGroupSize;
auto samples = dist(&gen);
std::copy(&samples[0], &samples[0] + remaining_size, data + offset);
}
}
};
// Specialization for distribution that takes a variable number of samples for
// each output. This will be slower due to the generality.
template <class Distribution>
struct FillPhiloxRandomTask<Distribution, true> {
typedef typename Distribution::ResultElementType T;
static const int64 kReservedSamplesPerOutput = 256;
static void Run(random::PhiloxRandom base_gen, T* data, int64 size,
int64 start_group, int64 limit_group, Distribution dist) {
const int kGroupSize = Distribution::kResultElementCount;
static const int kGeneratorSkipPerOutputGroup =
kGroupSize * kReservedSamplesPerOutput /
PhiloxRandom::kResultElementCount;
int64 offset = start_group * kGroupSize;
// First fill all the full-size groups
int64 limit_group_full = std::min(limit_group, size / kGroupSize);
int64 group_index;
for (group_index = start_group; group_index < limit_group_full;
++group_index) {
// Reset the generator to the beginning of the output group region
// This is necessary if we want the results to be independent of order
// of work
PhiloxRandom gen = base_gen;
gen.Skip(group_index * kGeneratorSkipPerOutputGroup);
SingleSampleAdapter<PhiloxRandom> single_samples(&gen);
auto samples = dist(&single_samples);
std::copy(&samples[0], &samples[0] + kGroupSize, data + offset);
offset += kGroupSize;
}
// If there are any remaining elements that need to be filled, process them
if (limit_group_full < limit_group) {
PhiloxRandom gen = base_gen;
gen.Skip(group_index * kGeneratorSkipPerOutputGroup);
SingleSampleAdapter<PhiloxRandom> single_samples(&gen);
int64 remaining_size = size - limit_group_full * kGroupSize;
auto samples = dist(&single_samples);
std::copy(&samples[0], &samples[0] + remaining_size, data + offset);
}
}
};
// Partial specialization for CPU to fill the entire region with randoms
// It splits the work into several tasks and run them in parallel
template <class Distribution>
void FillPhiloxRandom<CPUDevice, Distribution>::operator()(
OpKernelContext* context, const CPUDevice&, random::PhiloxRandom gen,
typename Distribution::ResultElementType* data, int64 size,
Distribution dist) {
const int kGroupSize = Distribution::kResultElementCount;
auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
int64 total_group_count = (size + kGroupSize - 1) / kGroupSize;
const int kGroupCost =
random::PhiloxRandom::kResultElementCount *
(random::PhiloxRandom::kElementCost + Distribution::kElementCost);
Shard(worker_threads.num_threads, worker_threads.workers, total_group_count,
kGroupCost,
[&gen, data, size, dist](int64 start_group, int64 limit_group) {
FillPhiloxRandomTask<
Distribution,
Distribution::kVariableSamplesPerOutput>::Run(gen, data, size,
start_group,
limit_group, dist);
});
}
} // namespace functor
namespace { namespace {
static Status AllocateOutputWithShape(OpKernelContext* ctx, const Tensor& shape, static Status AllocateOutputWithShape(OpKernelContext* ctx, const Tensor& shape,
@ -354,7 +228,7 @@ class RandomGammaOp : public OpKernel {
const double alpha = static_cast<double>(alpha_flat[alpha_idx]); const double alpha = static_cast<double>(alpha_flat[alpha_idx]);
DISABLE_FLOAT_EQUALITY_WARNING DISABLE_FLOAT_EQUALITY_WARNING
if (alpha == double(1.0)) { if (alpha == static_cast<double>(1.0)) {
ENABLE_FLOAT_EQUALITY_WARNING ENABLE_FLOAT_EQUALITY_WARNING
// Sample from an exponential distribution. // Sample from an exponential distribution.
for (int64 sample_idx = output_idx % num_samples; for (int64 sample_idx = output_idx % num_samples;
@ -364,7 +238,7 @@ class RandomGammaOp : public OpKernel {
// (including eventually on GPU), we skip on a per-sample basis. // (including eventually on GPU), we skip on a per-sample basis.
PhiloxRandom gen = rng; PhiloxRandom gen = rng;
gen.Skip(kReservedSamplesPerOutput * output_idx); gen.Skip(kReservedSamplesPerOutput * output_idx);
short uniform_remaining = 0; int16 uniform_remaining = 0;
UNIFORM(u); UNIFORM(u);
const double res = -log(1.0 - u); const double res = -log(1.0 - u);
samples_alpha_offset[sample_idx * num_alphas] = static_cast<T>(res); samples_alpha_offset[sample_idx * num_alphas] = static_cast<T>(res);
@ -392,8 +266,8 @@ class RandomGammaOp : public OpKernel {
// (including eventually on GPU), we skip on a per-sample basis. // (including eventually on GPU), we skip on a per-sample basis.
PhiloxRandom gen = rng; PhiloxRandom gen = rng;
gen.Skip(kReservedSamplesPerOutput * output_idx); gen.Skip(kReservedSamplesPerOutput * output_idx);
short norm_remaining = 0; int16 norm_remaining = 0;
short uniform_remaining = 0; int16 uniform_remaining = 0;
// Keep trying until we don't reject a sample. In practice, we will // Keep trying until we don't reject a sample. In practice, we will
// only reject ~5% at worst, for low alpha near 1. // only reject ~5% at worst, for low alpha near 1.
@ -565,145 +439,6 @@ TF_CALL_int64(REGISTER_INT);
#ifdef TENSORFLOW_USE_SYCL #ifdef TENSORFLOW_USE_SYCL
namespace functor {
using namespace cl;
template <class Distribution, bool VariableSamplesPerOutput>
struct FillPhiloxRandomKernel;
template <class Distribution>
struct FillPhiloxRandomKernel<Distribution, false> {
typedef typename Distribution::ResultElementType T;
using write_accessor = sycl::accessor<uint8_t, 1, sycl::access::mode::write,
sycl::access::target::global_buffer>;
FillPhiloxRandomKernel(write_accessor& data, random::PhiloxRandom& gen,
Distribution& dist)
: data_(data), gen_(gen), dist_(dist) {}
void operator()(sycl::nd_item<1> item) {
const size_t kGroupSize = Distribution::kResultElementCount;
const size_t item_id = item.get_global(0);
const size_t total_item_count = item.get_global_range();
size_t offset = item_id * kGroupSize;
gen_.Skip(item_id);
const size_t size = data_.get_size() / sizeof(T);
T* data = ConvertToActualTypeSycl(T, data_);
while (offset + kGroupSize <= size) {
const typename Distribution::ResultType samples = dist_(&gen_);
for (size_t i = 0; i < kGroupSize; ++i) {
data[offset + i] = samples[i];
}
offset += (total_item_count - 1) * kGroupSize;
gen_.Skip(total_item_count - 1);
}
const typename Distribution::ResultType samples = dist_(&gen_);
for (size_t i = 0; i < kGroupSize; ++i) {
if (offset >= size) {
return;
}
data[offset] = samples[i];
++offset;
}
}
private:
write_accessor data_;
random::PhiloxRandom gen_;
Distribution dist_;
};
template <class Distribution>
struct FillPhiloxRandomKernel<Distribution, true> {
typedef typename Distribution::ResultElementType T;
using write_accessor = sycl::accessor<uint8_t, 1, sycl::access::mode::write,
sycl::access::target::global_buffer>;
FillPhiloxRandomKernel(write_accessor& data, random::PhiloxRandom& gen,
Distribution& dist)
: data_(data), gen_(gen), dist_(dist) {}
void operator()(sycl::nd_item<1> item) {
using random::PhiloxRandom;
using random::SingleSampleAdapter;
const size_t kReservedSamplesPerOutput = 256;
const size_t kGroupSize = Distribution::kResultElementCount;
const size_t kGeneratorSkipPerOutputGroup =
kGroupSize * kReservedSamplesPerOutput /
PhiloxRandom::kResultElementCount;
const size_t item_id = item.get_global(0);
const size_t total_item_count = item.get_global_range();
size_t group_index = item_id;
size_t offset = group_index * kGroupSize;
T* data = ConvertToActualTypeSycl(T, data_);
const size_t size = data_.get_size() / sizeof(T);
while (offset < size) {
// Since each output takes a variable number of samples, we need to
// realign the generator to the beginning for the current output group
PhiloxRandom gen = gen_;
gen.Skip(group_index * kGeneratorSkipPerOutputGroup);
SingleSampleAdapter<PhiloxRandom> single_samples(&gen);
const typename Distribution::ResultType samples = dist_(&single_samples);
for (size_t i = 0; i < kGroupSize; ++i) {
if (offset >= size) {
return;
}
data[offset] = samples[i];
++offset;
}
offset += (total_item_count - 1) * kGroupSize;
group_index += total_item_count;
}
}
private:
write_accessor data_;
random::PhiloxRandom gen_;
Distribution dist_;
};
template <typename T>
class FillRandomKernel;
// Partial specialization for SYCL to fill the entire region with randoms
// It splits the work into several tasks and run them in parallel
template <class Distribution>
void FillPhiloxRandom<SYCLDevice, Distribution>::operator()(
OpKernelContext* context, const SYCLDevice& device,
random::PhiloxRandom gen, typename Distribution::ResultElementType* data,
int64 size, Distribution dist) {
const size_t group_size = device.maxSyclThreadsPerBlock();
const size_t group_count = (size + group_size - 1) / group_size;
auto buffer = device.get_sycl_buffer(data);
device.sycl_queue().submit([&](sycl::handler& cgh) {
auto access = buffer.template get_access<sycl::access::mode::write>(cgh);
FillPhiloxRandomKernel<Distribution,
Distribution::kVariableSamplesPerOutput>
task(access, gen, dist);
cgh.parallel_for<class FillRandomKernel<Distribution>>(
sycl::nd_range<1>(sycl::range<1>(group_count * group_size),
sycl::range<1>(group_size)),
task);
});
}
} // namespace functor
#define REGISTER(TYPE) \ #define REGISTER(TYPE) \
template struct functor::FillPhiloxRandom< \ template struct functor::FillPhiloxRandom< \
SYCLDevice, random::UniformDistribution<random::PhiloxRandom, TYPE>>; \ SYCLDevice, random::UniformDistribution<random::PhiloxRandom, TYPE>>; \


@ -0,0 +1,325 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_RANDOM_OP_CPU_H_
#define TENSORFLOW_CORE_KERNELS_RANDOM_OP_CPU_H_
#define EIGEN_USE_THREADS
#include <algorithm>
#include <cmath>
#include <memory>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/random_op.h"
#include "tensorflow/core/lib/hash/crc32c.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/guarded_philox_random.h"
#include "tensorflow/core/util/work_sharder.h"
#if EIGEN_COMP_GNUC && __cplusplus > 199711L
#define DISABLE_FLOAT_EQUALITY_WARNING \
_Pragma("GCC diagnostic push") \
_Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
#else
#define DISABLE_FLOAT_EQUALITY_WARNING
#define ENABLE_FLOAT_EQUALITY_WARNING
#endif
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL
namespace functor {
using random::PhiloxRandom;
using random::SingleSampleAdapter;
// The default implementation of the functor, which should never be invoked
// But we still need to provide implementation for now for the linker to work,
// since we do not support all the distributions yet.
template <typename Device, class Distribution>
struct FillPhiloxRandom {
typedef typename Distribution::ResultElementType T;
void operator()(OpKernelContext* ctx, const Device&, random::PhiloxRandom gen,
T* data, int64 size, Distribution dist) {
OP_REQUIRES(
ctx, false,
errors::Internal("Default FillPhiloxRandom should not be executed."));
}
};
// A class to fill a specified range of random groups
template <class Distribution, bool VariableSamplesPerOutput>
struct FillPhiloxRandomTask;
// Specialization for distribution that takes a fixed number of samples for
// each output.
template <class Distribution>
struct FillPhiloxRandomTask<Distribution, false> {
typedef typename Distribution::ResultElementType T;
static void Run(random::PhiloxRandom gen, T* data, int64 size,
int64 start_group, int64 limit_group, Distribution dist) {
const int kGroupSize = Distribution::kResultElementCount;
gen.Skip(start_group);
int64 offset = start_group * kGroupSize;
// First fill all the full-size groups
int64 limit_group_full = std::min(limit_group, size / kGroupSize);
for (int64 index = start_group; index < limit_group_full; ++index) {
auto samples = dist(&gen);
std::copy(&samples[0], &samples[0] + kGroupSize, data + offset);
offset += kGroupSize;
}
// If there are any remaining elements that need to be filled, process them
if (limit_group_full < limit_group) {
int64 remaining_size = size - limit_group_full * kGroupSize;
auto samples = dist(&gen);
std::copy(&samples[0], &samples[0] + remaining_size, data + offset);
}
}
};
// Specialization for distribution that takes a variable number of samples for
// each output. This will be slower due to the generality.
template <class Distribution>
struct FillPhiloxRandomTask<Distribution, true> {
typedef typename Distribution::ResultElementType T;
static const int64 kReservedSamplesPerOutput = 256;
static void Run(random::PhiloxRandom base_gen, T* data, int64 size,
int64 start_group, int64 limit_group, Distribution dist) {
const int kGroupSize = Distribution::kResultElementCount;
static const int kGeneratorSkipPerOutputGroup =
kGroupSize * kReservedSamplesPerOutput /
PhiloxRandom::kResultElementCount;
int64 offset = start_group * kGroupSize;
// First fill all the full-size groups
int64 limit_group_full = std::min(limit_group, size / kGroupSize);
int64 group_index;
for (group_index = start_group; group_index < limit_group_full;
++group_index) {
// Reset the generator to the beginning of the output group region
// This is necessary if we want the results to be independent of order
// of work
PhiloxRandom gen = base_gen;
gen.Skip(group_index * kGeneratorSkipPerOutputGroup);
SingleSampleAdapter<PhiloxRandom> single_samples(&gen);
auto samples = dist(&single_samples);
std::copy(&samples[0], &samples[0] + kGroupSize, data + offset);
offset += kGroupSize;
}
// If there are any remaining elements that need to be filled, process them
if (limit_group_full < limit_group) {
PhiloxRandom gen = base_gen;
gen.Skip(group_index * kGeneratorSkipPerOutputGroup);
SingleSampleAdapter<PhiloxRandom> single_samples(&gen);
int64 remaining_size = size - limit_group_full * kGroupSize;
auto samples = dist(&single_samples);
std::copy(&samples[0], &samples[0] + remaining_size, data + offset);
}
}
};
// Partial specialization for CPU to fill the entire region with randoms
// It splits the work into several tasks and run them in parallel
template <class Distribution>
void FillPhiloxRandom<CPUDevice, Distribution>::operator()(
OpKernelContext* context, const CPUDevice&, random::PhiloxRandom gen,
typename Distribution::ResultElementType* data, int64 size,
Distribution dist) {
const int kGroupSize = Distribution::kResultElementCount;
auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
int64 total_group_count = (size + kGroupSize - 1) / kGroupSize;
const int kGroupCost =
random::PhiloxRandom::kResultElementCount *
(random::PhiloxRandom::kElementCost + Distribution::kElementCost);
Shard(worker_threads.num_threads, worker_threads.workers, total_group_count,
kGroupCost,
[&gen, data, size, dist](int64 start_group, int64 limit_group) {
FillPhiloxRandomTask<
Distribution,
Distribution::kVariableSamplesPerOutput>::Run(gen, data, size,
start_group,
limit_group, dist);
});
}
} // namespace functor
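
A toy Python sketch (not Philox and not TensorFlow code) of the property the CPU specialization above relies on: because a counter-based generator can skip directly to any output group, each shard produces exactly the values for its own groups, so the filled buffer is identical no matter how the work is partitioned.

import hashlib
import numpy as np

GROUP = 4  # stand-in for Distribution::kResultElementCount

def group_samples(counter):
  """Deterministic pseudo-random group derived only from the group counter."""
  digest = hashlib.sha256(str(counter).encode()).digest()
  return np.frombuffer(digest[:GROUP * 4], dtype=np.uint32)

def fill(size, num_shards):
  out = np.zeros(size, dtype=np.uint32)
  num_groups = (size + GROUP - 1) // GROUP
  bounds = np.linspace(0, num_groups, num_shards + 1, dtype=int)
  for start, limit in zip(bounds[:-1], bounds[1:]):  # one task per shard
    for g in range(start, limit):                    # analogue of gen.Skip(start_group)
      lo = g * GROUP
      n = min(GROUP, size - lo)
      out[lo:lo + n] = group_samples(g)[:n]
  return out

# Same output whether the work runs as one task or is split across three shards.
assert np.array_equal(fill(10, 1), fill(10, 3))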
#ifdef TENSORFLOW_USE_SYCL
namespace functor {
template <class Distribution, bool VariableSamplesPerOutput>
struct FillPhiloxRandomKernel;
template <class Distribution>
struct FillPhiloxRandomKernel<Distribution, false> {
typedef typename Distribution::ResultElementType T;
using write_accessor = sycl::accessor<uint8_t, 1, sycl::access::mode::write,
sycl::access::target::global_buffer>;
FillPhiloxRandomKernel(write_accessor& data, random::PhiloxRandom& gen,
Distribution& dist)
: data_(data), gen_(gen), dist_(dist) {}
void operator()(sycl::nd_item<1> item) {
const size_t kGroupSize = Distribution::kResultElementCount;
const size_t item_id = item.get_global(0);
const size_t total_item_count = item.get_global_range();
size_t offset = item_id * kGroupSize;
gen_.Skip(item_id);
const size_t size = data_.get_size() / sizeof(T);
T* data = ConvertToActualTypeSycl(T, data_);
while (offset + kGroupSize <= size) {
const typename Distribution::ResultType samples = dist_(&gen_);
for (size_t i = 0; i < kGroupSize; ++i) {
data[offset + i] = samples[i];
}
offset += (total_item_count - 1) * kGroupSize;
gen_.Skip(total_item_count - 1);
}
const typename Distribution::ResultType samples = dist_(&gen_);
for (size_t i = 0; i < kGroupSize; ++i) {
if (offset >= size) {
return;
}
data[offset] = samples[i];
++offset;
}
}
private:
write_accessor data_;
random::PhiloxRandom gen_;
Distribution dist_;
};
template <class Distribution>
struct FillPhiloxRandomKernel<Distribution, true> {
typedef typename Distribution::ResultElementType T;
using write_accessor = sycl::accessor<uint8_t, 1, sycl::access::mode::write,
sycl::access::target::global_buffer>;
FillPhiloxRandomKernel(write_accessor& data, random::PhiloxRandom& gen,
Distribution& dist)
: data_(data), gen_(gen), dist_(dist) {}
void operator()(sycl::nd_item<1> item) {
using random::PhiloxRandom;
using random::SingleSampleAdapter;
const size_t kReservedSamplesPerOutput = 256;
const size_t kGroupSize = Distribution::kResultElementCount;
const size_t kGeneratorSkipPerOutputGroup =
kGroupSize * kReservedSamplesPerOutput /
PhiloxRandom::kResultElementCount;
const size_t item_id = item.get_global(0);
const size_t total_item_count = item.get_global_range();
size_t group_index = item_id;
size_t offset = group_index * kGroupSize;
T* data = ConvertToActualTypeSycl(T, data_);
const size_t size = data_.get_size() / sizeof(T);
while (offset < size) {
// Since each output takes a variable number of samples, we need to
// realign the generator to the beginning for the current output group
PhiloxRandom gen = gen_;
gen.Skip(group_index * kGeneratorSkipPerOutputGroup);
SingleSampleAdapter<PhiloxRandom> single_samples(&gen);
const typename Distribution::ResultType samples = dist_(&single_samples);
for (size_t i = 0; i < kGroupSize; ++i) {
if (offset >= size) {
return;
}
data[offset] = samples[i];
++offset;
}
offset += (total_item_count - 1) * kGroupSize;
group_index += total_item_count;
}
}
private:
write_accessor data_;
random::PhiloxRandom gen_;
Distribution dist_;
};
template <typename T>
class FillRandomKernel;
// Partial specialization for SYCL to fill the entire region with randoms
// It splits the work into several tasks and run them in parallel
template <class Distribution>
void FillPhiloxRandom<SYCLDevice, Distribution>::operator()(
OpKernelContext* context, const SYCLDevice& device,
random::PhiloxRandom gen, typename Distribution::ResultElementType* data,
int64 size, Distribution dist) {
const size_t group_size = device.maxSyclThreadsPerBlock();
const size_t group_count = (size + group_size - 1) / group_size;
auto buffer = device.get_sycl_buffer(data);
device.sycl_queue().submit([&](sycl::handler& cgh) {
auto access = buffer.template get_access<sycl::access::mode::write>(cgh);
FillPhiloxRandomKernel<Distribution,
Distribution::kVariableSamplesPerOutput>
task(access, gen, dist);
cgh.parallel_for<class FillRandomKernel<Distribution>>(
sycl::nd_range<1>(sycl::range<1>(group_count * group_size),
sycl::range<1>(group_size)),
task);
});
}
} // namespace functor
#endif // TENSORFLOW_USE_SYCL
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_RANDOM_OP_CPU_H_


@ -17,17 +17,15 @@ limitations under the License.
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "tensorflow/core/kernels/random_op.h"
#include "tensorflow/core/kernels/random_op_gpu.h"
#include <assert.h> #include <assert.h>
#include <stdio.h> #include <stdio.h>
#include "tensorflow/core/kernels/random_op_gpu.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/random/philox_random.h" #include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random_distributions.h" #include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
namespace tensorflow { namespace tensorflow {
@ -37,33 +35,6 @@ namespace functor {
typedef Eigen::GpuDevice GPUDevice; typedef Eigen::GpuDevice GPUDevice;
// A simple launch pad to call the correct function templates to fill the data
template <class Distribution>
__global__ void __launch_bounds__(1024)
FillPhiloxRandomKernelLaunch(random::PhiloxRandom base_gen,
typename Distribution::ResultElementType* data,
int64 size, Distribution dist) {
FillPhiloxRandomKernel<Distribution,
Distribution::kVariableSamplesPerOutput>()
.Run(base_gen, data, size, dist);
}
// Partial specialization for GPU
template <class Distribution>
void FillPhiloxRandom<GPUDevice, Distribution>::operator()(
OpKernelContext*, const GPUDevice& d, random::PhiloxRandom gen,
typename Distribution::ResultElementType* data, int64 size,
Distribution dist) {
const int32 block_size = d.maxGpuThreadsPerBlock();
const int32 num_blocks =
(d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor()) /
block_size;
TF_CHECK_OK(CudaLaunchKernel(FillPhiloxRandomKernelLaunch<Distribution>,
num_blocks, block_size, 0, d.stream(), gen, data,
size, dist));
}
// Explicit instantiation of the GPU distributions functors // Explicit instantiation of the GPU distributions functors
// clang-format off // clang-format off
// NVCC cannot handle ">>" properly // NVCC cannot handle ">>" properly


@ -18,8 +18,10 @@ limitations under the License.
#if defined(__CUDACC__) #if defined(__CUDACC__)
#include "tensorflow/core/kernels/random_op.h"
#include "tensorflow/core/lib/random/philox_random.h" #include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random_distributions.h" #include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
namespace tensorflow { namespace tensorflow {
@ -31,14 +33,14 @@ struct FillPhiloxRandomKernel;
template <class Distribution> template <class Distribution>
struct FillPhiloxRandomKernel<Distribution, false> { struct FillPhiloxRandomKernel<Distribution, false> {
typedef typename Distribution::ResultElementType T; typedef typename Distribution::ResultElementType T;
PHILOX_DEVICE_FUNC void Run(random::PhiloxRandom gen, T* data, int64 size, PHILOX_DEVICE_INLINE void Run(random::PhiloxRandom gen, T* data, int64 size,
Distribution dist); Distribution dist);
}; };
template <class Distribution> template <class Distribution>
struct FillPhiloxRandomKernel<Distribution, true> { struct FillPhiloxRandomKernel<Distribution, true> {
typedef typename Distribution::ResultElementType T; typedef typename Distribution::ResultElementType T;
PHILOX_DEVICE_FUNC void Run(const random::PhiloxRandom& base_gen, T* data, PHILOX_DEVICE_INLINE void Run(const random::PhiloxRandom& base_gen, T* data,
int64 size, Distribution dist); int64 size, Distribution dist);
}; };
@ -128,7 +130,7 @@ class SampleCopier<int64, 2> {
// A cuda kernel to fill the data with random numbers from the specified // A cuda kernel to fill the data with random numbers from the specified
// distribution. Each output takes a fixed number of samples. // distribution. Each output takes a fixed number of samples.
template <class Distribution> template <class Distribution>
PHILOX_DEVICE_FUNC void FillPhiloxRandomKernel<Distribution, false>::Run( PHILOX_DEVICE_INLINE void FillPhiloxRandomKernel<Distribution, false>::Run(
random::PhiloxRandom gen, T* data, int64 size, Distribution dist) { random::PhiloxRandom gen, T* data, int64 size, Distribution dist) {
const int kGroupSize = Distribution::kResultElementCount; const int kGroupSize = Distribution::kResultElementCount;
@ -159,7 +161,7 @@ PHILOX_DEVICE_FUNC void FillPhiloxRandomKernel<Distribution, false>::Run(
// A cuda kernel to fill the data with random numbers from the specified // A cuda kernel to fill the data with random numbers from the specified
// distribution. Each output takes a variable number of samples. // distribution. Each output takes a variable number of samples.
template <class Distribution> template <class Distribution>
PHILOX_DEVICE_FUNC void FillPhiloxRandomKernel<Distribution, true>::Run( PHILOX_DEVICE_INLINE void FillPhiloxRandomKernel<Distribution, true>::Run(
const random::PhiloxRandom& base_gen, T* data, int64 size, const random::PhiloxRandom& base_gen, T* data, int64 size,
Distribution dist) { Distribution dist) {
using random::PhiloxRandom; using random::PhiloxRandom;
@ -198,6 +200,33 @@ PHILOX_DEVICE_FUNC void FillPhiloxRandomKernel<Distribution, true>::Run(
} }
} }
// A simple launch pad to call the correct function templates to fill the data
template <class Distribution>
__global__ void __launch_bounds__(1024)
FillPhiloxRandomKernelLaunch(random::PhiloxRandom base_gen,
typename Distribution::ResultElementType* data,
int64 size, Distribution dist) {
FillPhiloxRandomKernel<Distribution,
Distribution::kVariableSamplesPerOutput>()
.Run(base_gen, data, size, dist);
}
// Partial specialization for GPU
template <class Distribution>
void FillPhiloxRandom<GPUDevice, Distribution>::operator()(
OpKernelContext*, const GPUDevice& d, random::PhiloxRandom gen,
typename Distribution::ResultElementType* data, int64 size,
Distribution dist) {
const int32 block_size = d.maxGpuThreadsPerBlock();
const int32 num_blocks =
(d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor()) /
block_size;
TF_CHECK_OK(CudaLaunchKernel(FillPhiloxRandomKernelLaunch<Distribution>,
num_blocks, block_size, 0, d.stream(), gen, data,
size, dist));
}
} // namespace functor } // namespace functor
} // namespace tensorflow } // namespace tensorflow


@ -15,7 +15,7 @@ limitations under the License.
#define EIGEN_USE_THREADS #define EIGEN_USE_THREADS
#include "tensorflow/core/kernels/random_op.h" #include "tensorflow/core/kernels/random_op_cpu.h"
#include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h" #include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h"
#include "tensorflow/core/kernels/training_op_helpers.h" #include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/random/random.h"
@ -25,7 +25,7 @@ namespace tensorflow {
template <typename Distribution> template <typename Distribution>
struct UpdateVariableAndFill_Philox<CPUDevice, Distribution> { struct UpdateVariableAndFill_Philox<CPUDevice, Distribution> {
void operator()(OpKernelContext* ctx, const CPUDevice& device, void operator()(OpKernelContext* ctx, const CPUDevice& device,
int64 output_size, int64 alg_tag_skip, Distribution dist, int64 output_size, int64 alg_tag_skip,
ScopedUnlockUnrefVar* state_var_guard, Tensor* state_tensor, ScopedUnlockUnrefVar* state_var_guard, Tensor* state_tensor,
typename Distribution::ResultElementType* output_data) { typename Distribution::ResultElementType* output_data) {
auto state_tensor_flat = state_tensor->flat<StateElementType>(); auto state_tensor_flat = state_tensor->flat<StateElementType>();
@ -36,14 +36,14 @@ struct UpdateVariableAndFill_Philox<CPUDevice, Distribution> {
// No longer needs the lock. // No longer needs the lock.
state_var_guard->Release(); state_var_guard->Release();
functor::FillPhiloxRandom<CPUDevice, Distribution>()( functor::FillPhiloxRandom<CPUDevice, Distribution>()(
ctx, device, philox, output_data, output_size, Distribution()); ctx, device, philox, output_data, output_size, dist);
} }
}; };
template <typename Device, typename Distribution> template <typename Device, typename Distribution>
Status UpdateVariableAndFill( Status UpdateVariableAndFill(
OpKernelContext* ctx, int state_input_idx, bool read_alg_from_state, OpKernelContext* ctx, Distribution dist, int state_input_idx,
Algorithm alg, int64 output_size, bool read_alg_from_state, Algorithm alg, int64 output_size,
typename Distribution::ResultElementType* output_data) { typename Distribution::ResultElementType* output_data) {
Var* var = nullptr; Var* var = nullptr;
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
@ -89,7 +89,7 @@ Status UpdateVariableAndFill(
TF_RETURN_IF_ERROR(PrepareToUpdateVariable<Device, StateElementType>( TF_RETURN_IF_ERROR(PrepareToUpdateVariable<Device, StateElementType>(
ctx, var_tensor, var->copy_on_read_mode.load())); ctx, var_tensor, var->copy_on_read_mode.load()));
UpdateVariableAndFill_Philox<Device, Distribution>()( UpdateVariableAndFill_Philox<Device, Distribution>()(
ctx, ctx->eigen_device<Device>(), output_size, alg_tag_skip, ctx, ctx->eigen_device<Device>(), dist, output_size, alg_tag_skip,
&state_var_guard, var_tensor, output_data); &state_var_guard, var_tensor, output_data);
return Status::OK(); return Status::OK();
} else { } else {
@ -99,7 +99,8 @@ Status UpdateVariableAndFill(
// Preconditon: input(0) is an existing resource. // Preconditon: input(0) is an existing resource.
template <typename Device, class Distribution> template <typename Device, class Distribution>
void ComputeImpl(OpKernelContext* ctx, int state_input_idx, int shape_input_idx, void StatefulRandomCompute(OpKernelContext* ctx, Distribution dist,
int state_input_idx, int shape_input_idx,
bool read_alg_from_state, Algorithm alg) { bool read_alg_from_state, Algorithm alg) {
using T = typename Distribution::ResultElementType; using T = typename Distribution::ResultElementType;
const Tensor& shape_t = ctx->input(shape_input_idx); const Tensor& shape_t = ctx->input(shape_input_idx);
@ -108,8 +109,8 @@ void ComputeImpl(OpKernelContext* ctx, int state_input_idx, int shape_input_idx,
Tensor* output; Tensor* output;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape, &output)); OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape, &output));
auto output_flat = output->flat<T>(); auto output_flat = output->flat<T>();
OP_REQUIRES_OK(ctx, UpdateVariableAndFill<Device, Distribution>( OP_REQUIRES_OK(ctx, UpdateVariableAndFill<Device>(
ctx, state_input_idx, read_alg_from_state, alg, ctx, dist, state_input_idx, read_alg_from_state, alg,
output_flat.size(), output_flat.data())); output_flat.size(), output_flat.data()));
} }
@ -119,27 +120,89 @@ class StatefulRandomOp : public OpKernel {
explicit StatefulRandomOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} explicit StatefulRandomOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override { void Compute(OpKernelContext* ctx) override {
ComputeImpl<Device, Distribution>(ctx, 0, 1, true, 0); StatefulRandomCompute<Device>(ctx, Distribution(), 0, 1, true, 0);
} }
}; };
Status GetAlgorithm(OpKernelContext* ctx, int alg_input_idx, Algorithm* alg) {
const Tensor& alg_tensor = ctx->input(alg_input_idx);
if (alg_tensor.dims() != 0) {
return errors::InvalidArgument("algorithm must be of shape [], not ",
alg_tensor.shape().DebugString());
}
if (alg_tensor.dtype() != ALGORITHM_DTYPE) {
return errors::InvalidArgument("algorithm's dtype must be ",
DataTypeString(ALGORITHM_DTYPE), ", not ",
DataTypeString(alg_tensor.dtype()));
}
*alg = alg_tensor.flat<Algorithm>()(0);
return Status::OK();
}
template <typename Device, class Distribution> template <typename Device, class Distribution>
class StatefulRandomOpV2 : public OpKernel { class StatefulRandomOpV2 : public OpKernel {
public: public:
explicit StatefulRandomOpV2(OpKernelConstruction* ctx) : OpKernel(ctx) {} explicit StatefulRandomOpV2(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override { void Compute(OpKernelContext* ctx) override {
const Tensor& alg_tensor = ctx->input(1); Algorithm alg;
OP_REQUIRES(ctx, alg_tensor.dims() == 0, OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, /*alg_input_idx=*/1, &alg));
errors::InvalidArgument("algorithm must be of shape [], not ", StatefulRandomCompute<Device>(ctx, Distribution(), /*state_input_idx=*/0,
alg_tensor.shape().DebugString())); /*shape_input_idx=*/2,
/*read_alg_from_state=*/false, alg);
}
};
template <typename Device, class IntType>
class StatefulUniformIntOp : public OpKernel {
public:
explicit StatefulUniformIntOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
Algorithm alg;
OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, /*alg_input_idx=*/1, &alg));
const Tensor& minval = ctx->input(3);
const Tensor& maxval = ctx->input(4);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval.shape()),
errors::InvalidArgument("minval must be 0-D, got shape ",
minval.shape().DebugString()));
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval.shape()),
errors::InvalidArgument("maxval must be 0-D, got shape ",
maxval.shape().DebugString()));
// Verify that minval < maxval. This check intentionally happens after the
// early exit for empty output. Zero impossible things are fine.
IntType lo = minval.scalar<IntType>()();
IntType hi = maxval.scalar<IntType>()();
OP_REQUIRES( OP_REQUIRES(
ctx, alg_tensor.dtype() == ALGORITHM_DTYPE, ctx, lo < hi,
errors::InvalidArgument("algorithm's dtype must be ", errors::InvalidArgument("Need minval < maxval, got ", lo, " >= ", hi));
DataTypeString(ALGORITHM_DTYPE), ", not ",
DataTypeString(alg_tensor.dtype()))); // Build distribution
auto alg = alg_tensor.flat<Algorithm>()(0); typedef random::UniformDistribution<random::PhiloxRandom, IntType>
ComputeImpl<Device, Distribution>(ctx, 0, 2, false, alg); Distribution;
Distribution dist(lo, hi);
StatefulRandomCompute<Device>(ctx, dist, /*state_input_idx=*/0,
/*shape_input_idx=*/2,
/*read_alg_from_state=*/false, alg);
}
};
template <typename Device, class IntType>
class StatefulUniformFullIntOp : public OpKernel {
public:
explicit StatefulUniformFullIntOp(OpKernelConstruction* ctx)
: OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
Algorithm alg;
OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, /*alg_input_idx=*/1, &alg));
StatefulRandomCompute<Device>(
ctx,
random::UniformFullIntDistribution<random::PhiloxRandom, IntType>(),
/*state_input_idx=*/0, /*shape_input_idx=*/2,
/*read_alg_from_state=*/false, alg);
} }
}; };
@ -213,14 +276,65 @@ TF_CALL_bfloat16(REGISTER_CPU);
TF_CALL_float(REGISTER_CPU); TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU); TF_CALL_double(REGISTER_CPU);
#define REGISTER_StatefulUniformInt(DEVICE, TYPE) \
REGISTER_KERNEL_BUILDER(Name("StatefulUniformInt") \
.Device(DEVICE_##DEVICE) \
.HostMemory("resource") \
.HostMemory("algorithm") \
.HostMemory("shape") \
.HostMemory("minval") \
.HostMemory("maxval") \
.TypeConstraint<TYPE>("dtype"), \
StatefulUniformIntOp<DEVICE##Device, TYPE>);
#define REGISTER_StatefulUniformInt_CPU(TYPE) \
REGISTER_StatefulUniformInt(CPU, TYPE)
#define REGISTER_StatefulUniformInt_GPU(TYPE) \
REGISTER_StatefulUniformInt(GPU, TYPE)
TF_CALL_int32(REGISTER_StatefulUniformInt_CPU);
TF_CALL_int64(REGISTER_StatefulUniformInt_CPU);
#define REGISTER_StatefulUniformFullInt(DEVICE, TYPE) \
REGISTER_KERNEL_BUILDER(Name("StatefulUniformFullInt") \
.Device(DEVICE_##DEVICE) \
.HostMemory("resource") \
.HostMemory("algorithm") \
.HostMemory("shape") \
.TypeConstraint<TYPE>("dtype"), \
StatefulUniformFullIntOp<DEVICE##Device, TYPE>);
#define REGISTER_StatefulUniformFullInt_CPU(TYPE) \
REGISTER_StatefulUniformFullInt(CPU, TYPE)
#define REGISTER_StatefulUniformFullInt_GPU(TYPE) \
REGISTER_StatefulUniformFullInt(GPU, TYPE)
TF_CALL_int32(REGISTER_StatefulUniformFullInt_CPU);
TF_CALL_int64(REGISTER_StatefulUniformFullInt_CPU);
TF_CALL_uint32(REGISTER_StatefulUniformFullInt_CPU);
TF_CALL_uint64(REGISTER_StatefulUniformFullInt_CPU);
#if GOOGLE_CUDA #if GOOGLE_CUDA
TF_CALL_half(REGISTER_GPU); TF_CALL_half(REGISTER_GPU);
TF_CALL_bfloat16(REGISTER_GPU);
TF_CALL_float(REGISTER_GPU); TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU); TF_CALL_double(REGISTER_GPU);
TF_CALL_int32(REGISTER_StatefulUniformInt_GPU);
TF_CALL_int64(REGISTER_StatefulUniformInt_GPU);
TF_CALL_int32(REGISTER_StatefulUniformFullInt_GPU);
TF_CALL_int64(REGISTER_StatefulUniformFullInt_GPU);
TF_CALL_uint32(REGISTER_StatefulUniformFullInt_GPU);
TF_CALL_uint64(REGISTER_StatefulUniformFullInt_GPU);
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA
#undef REGISTER_StatefulUniformFullInt_GPU
#undef REGISTER_StatefulUniformFullInt_CPU
#undef REGISTER_StatefulUniformFullInt
#undef REGISTER_StatefulUniformInt_GPU
#undef REGISTER_StatefulUniformInt_CPU
#undef REGISTER_StatefulUniformInt
#undef REGISTER_GPU #undef REGISTER_GPU
#undef REGISTER_CPU #undef REGISTER_CPU
#undef REGISTER #undef REGISTER


@ -87,7 +87,7 @@ using GPUDevice = Eigen::GpuDevice;
template <typename Distribution> template <typename Distribution>
struct UpdateVariableAndFill_Philox<GPUDevice, Distribution> { struct UpdateVariableAndFill_Philox<GPUDevice, Distribution> {
void operator()(OpKernelContext* ctx, const GPUDevice& device, void operator()(OpKernelContext* ctx, const GPUDevice& device,
int64 output_size, int64 alg_tag_skip, Distribution dist, int64 output_size, int64 alg_tag_skip,
ScopedUnlockUnrefVar* not_used, Tensor* state_tensor, ScopedUnlockUnrefVar* not_used, Tensor* state_tensor,
typename Distribution::ResultElementType* output_data); typename Distribution::ResultElementType* output_data);
}; };


@ -53,8 +53,9 @@ __global__ void FillKernel(
template <typename Distribution> template <typename Distribution>
void UpdateVariableAndFill_Philox<GPUDevice, Distribution>::operator()( void UpdateVariableAndFill_Philox<GPUDevice, Distribution>::operator()(
OpKernelContext* ctx, const GPUDevice& d, int64 output_size, OpKernelContext* ctx, const GPUDevice& d, Distribution dist,
int64 alg_tag_skip, ScopedUnlockUnrefVar* not_used, Tensor* state_tensor, int64 output_size, int64 alg_tag_skip, ScopedUnlockUnrefVar* not_used,
Tensor* state_tensor,
typename Distribution::ResultElementType* output_data) { typename Distribution::ResultElementType* output_data) {
OP_REQUIRES( OP_REQUIRES(
ctx, alg_tag_skip == 0, ctx, alg_tag_skip == 0,
@ -74,10 +75,9 @@ void UpdateVariableAndFill_Philox<GPUDevice, Distribution>::operator()(
int zero = 0; int zero = 0;
cudaMemcpyToSymbol(thread_counter, &zero, sizeof(int)); cudaMemcpyToSymbol(thread_counter, &zero, sizeof(int));
TF_CHECK_OK(CudaLaunchKernel(FillKernel<Distribution>, cfg.block_count, TF_CHECK_OK(CudaLaunchKernel(
cfg.thread_per_block, 0, d.stream(), FillKernel<Distribution>, cfg.block_count, cfg.thread_per_block, 0,
Distribution(), state_size, output_size, d.stream(), dist, state_size, output_size, state_data, output_data));
state_data, output_data));
} }
// Explicit instantiation of the GPU distributions functors. // Explicit instantiation of the GPU distributions functors.
@ -86,10 +86,28 @@ void UpdateVariableAndFill_Philox<GPUDevice, Distribution>::operator()(
// NVCC cannot handle ">>" properly // NVCC cannot handle ">>" properly
template struct UpdateVariableAndFill_Philox< template struct UpdateVariableAndFill_Philox<
GPUDevice, random::NormalDistribution<random::PhiloxRandom, Eigen::half> >; GPUDevice, random::NormalDistribution<random::PhiloxRandom, Eigen::half> >;
template struct UpdateVariableAndFill_Philox<
GPUDevice, random::NormalDistribution<random::PhiloxRandom, bfloat16> >;
template struct UpdateVariableAndFill_Philox< template struct UpdateVariableAndFill_Philox<
GPUDevice, random::NormalDistribution<random::PhiloxRandom, float> >; GPUDevice, random::NormalDistribution<random::PhiloxRandom, float> >;
template struct UpdateVariableAndFill_Philox< template struct UpdateVariableAndFill_Philox<
GPUDevice, random::NormalDistribution<random::PhiloxRandom, double> >; GPUDevice, random::NormalDistribution<random::PhiloxRandom, double> >;
template struct UpdateVariableAndFill_Philox<
GPUDevice, random::UniformDistribution<random::PhiloxRandom, int32> >;
template struct UpdateVariableAndFill_Philox<
GPUDevice, random::UniformDistribution<random::PhiloxRandom, int64> >;
template struct UpdateVariableAndFill_Philox<
GPUDevice, random::UniformFullIntDistribution<
random::PhiloxRandom, int32> >;
template struct UpdateVariableAndFill_Philox<
GPUDevice, random::UniformFullIntDistribution<
random::PhiloxRandom, int64> >;
template struct UpdateVariableAndFill_Philox<
GPUDevice, random::UniformFullIntDistribution<
random::PhiloxRandom, uint32> >;
template struct UpdateVariableAndFill_Philox<
GPUDevice, random::UniformFullIntDistribution<
random::PhiloxRandom, uint64> >;
// clang-format on // clang-format on
} // end namespace tensorflow } // end namespace tensorflow


@ -16,12 +16,13 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_ #ifndef TENSORFLOW_CORE_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_
#define TENSORFLOW_CORE_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_ #define TENSORFLOW_CORE_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_
#include <string.h>
#define _USE_MATH_DEFINES #define _USE_MATH_DEFINES
#include <math.h> #include <math.h>
#include <cmath> #include <cmath>
#undef _USE_MATH_DEFINES #undef _USE_MATH_DEFINES
#include <string.h>
#include <algorithm> #include <algorithm>
#include <type_traits> #include <type_traits>
@ -236,6 +237,73 @@ class UniformDistribution<Generator, int64> {
uint64 range_; uint64 range_;
}; };
// Similar to `UniformDistribution`, except that instead of generating numbers
// in the range [low, high), it generates numbers covering the whole range of
// the integer type.
template <typename Generator, typename IntType>
class UniformFullIntDistribution;
template <typename Generator, typename IntType>
class UniformFullIntDistribution32 {
public:
// The number of elements that will be returned.
static const int kResultElementCount = Generator::kResultElementCount;
// Cost of generation of a single element (in cycles).
static const int kElementCost = 3;
// Indicate that this distribution may take variable number of samples
// during the runtime.
static const bool kVariableSamplesPerOutput = false;
typedef Array<IntType, kResultElementCount> ResultType;
typedef IntType ResultElementType;
PHILOX_DEVICE_INLINE
ResultType operator()(Generator* gen) {
typename Generator::ResultType sample = (*gen)();
ResultType result;
for (int i = 0; i < kResultElementCount; ++i) {
result[i] = sample[i];
}
return result;
}
};
template <typename Generator, typename IntType>
class UniformFullIntDistribution64 {
public:
// The number of elements that will be returned.
static const int kResultElementCount = Generator::kResultElementCount / 2;
// Cost of generation of a single element (in cycles).
static const int kElementCost = 3;
  // Indicates that this distribution may take a variable number of samples
  // during runtime.
static const bool kVariableSamplesPerOutput = false;
typedef Array<IntType, kResultElementCount> ResultType;
typedef IntType ResultElementType;
PHILOX_DEVICE_INLINE
ResultType operator()(Generator* gen) {
typename Generator::ResultType sample = (*gen)();
ResultType result;
for (int i = 0; i < kResultElementCount; ++i) {
result[i] = sample[2 * i] | static_cast<uint64>(sample[2 * i + 1]) << 32;
}
return result;
}
};
template <typename Generator>
class UniformFullIntDistribution<Generator, int32>
: public UniformFullIntDistribution32<Generator, int32> {};
template <typename Generator>
class UniformFullIntDistribution<Generator, uint32>
: public UniformFullIntDistribution32<Generator, uint32> {};
template <typename Generator>
class UniformFullIntDistribution<Generator, int64>
: public UniformFullIntDistribution64<Generator, int64> {};
template <typename Generator>
class UniformFullIntDistribution<Generator, uint64>
: public UniformFullIntDistribution64<Generator, uint64> {};
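As an aside (not part of the diff), a small NumPy sketch of the packing performed by `UniformFullIntDistribution64` above: consecutive 32-bit generator outputs are combined into full-range 64-bit integers, low word first.

import numpy as np


def pack_uint64(words32):
  """Packs pairs of 32-bit words into 64-bit values, low word first."""
  w = np.asarray(words32, dtype=np.uint64)
  return w[0::2] | (w[1::2] << np.uint64(32))


# Four 32-bit outputs become two full-range 64-bit integers:
# 0x00000002_00000001 and 0x00000003_FFFFFFFF.
print(pack_uint64([0x1, 0x2, 0xFFFFFFFF, 0x3]))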
// A class that adapts the underlying native multiple samples to return a single
// sample at a time.
template <class Generator>


@ -2992,6 +2992,7 @@ cuda_py_test(
        ":client_testlib",
        ":logging_ops",
        ":random_ops_gen",
        "//tensorflow/python/kernel_tests/random:util",
    ],
    xla_enable_strict_auto_jit = True,
)


@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import numpy as np
@ -70,3 +72,26 @@ def test_moment_matching(
            total_variance)))
  return z_test_scores
def chi_squared(x, bins):
"""Pearson's Chi-squared test."""
x = np.ravel(x)
n = len(x)
histogram, _ = np.histogram(x, bins=bins, range=(0, 1))
expected = n / float(bins)
return np.sum(np.square(histogram - expected) / expected)
def normal_cdf(x):
"""Cumulative distribution function for a standard normal distribution."""
return 0.5 + 0.5 * np.vectorize(math.erf)(x / math.sqrt(2))
def anderson_darling(x):
"""Anderson-Darling test for a standard normal distribution."""
x = np.sort(np.ravel(x))
n = len(x)
i = np.linspace(1, n, n)
z = np.sum((2 * i - 1) * np.log(normal_cdf(x)) +
(2 * (n - i) + 1) * np.log(1 - normal_cdf(x)))
return -n - z / n
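As a quick sanity check of these helpers (an aside, not part of the change; it assumes the module is importable under the path added to the BUILD deps above), well-behaved NumPy samples should typically stay below the critical values used by the tests: 16.92 for the Chi-squared statistic with 10 bins, and 2.492 for the Anderson-Darling statistic.

import numpy as np

from tensorflow.python.kernel_tests.random import util as random_test_util

rng = np.random.RandomState(2019)
uniform_sample = rng.uniform(size=1000)
normal_sample = rng.standard_normal(size=1000)

# Both checks are probabilistic; with a fixed seed they typically pass.
print(random_test_util.chi_squared(uniform_sample, 10))   # expected < 16.92
print(random_test_util.anderson_darling(normal_sample))   # expected < 2.492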


@ -160,7 +160,19 @@ def _shape_tensor(shape):
class Generator(tracking.AutoTrackable):
  """Random-number generator.

  It uses Variable to manage its internal state, and allows choosing a
  Random-Number-Generation (RNG) algorithm.

  CPU and GPU with the same algorithm and seed will generate the same integer
  random numbers. Floating-point results (such as the output of `normal`) may
  have small numerical discrepancies between CPU and GPU.

  Because of different counter-reservation schemes, TPU's integer random
  numbers will differ from those on CPU/GPU even with the same algorithm and
  seed. TPU also uses different sampling algorithms for some distributions
  (e.g. the inverse-CDF method for the normal distribution instead of the
  Box-Muller method used by CPU/GPU). Harmonizing TPU's RNG behavior with
  CPU/GPU is work in progress.
  """
  def __init__(self, copy_from=None, seed=None, algorithm=None):

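To make the determinism claim in the docstring above concrete, here is a minimal sketch (an illustration only, not part of the change; it assumes eager TF 2.x with a GPU available, and reuses the import paths from the tests below):

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import stateful_random_ops as random


def sample_ints(device):
  # Same seed and the default algorithm on every device.
  with ops.device(device):
    gen = random.Generator(seed=1234)
    return gen.uniform(
        shape=[8], minval=0, maxval=1000, dtype=dtypes.int64).numpy()


# Integer outputs should match bit-for-bit across CPU and GPU; floating-point
# outputs (e.g. gen.normal) may show small rounding differences.
print(sample_ints("/device:CPU:0"))
print(sample_ints("/device:GPU:0"))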

@ -23,9 +23,11 @@ import numpy as np
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.kernel_tests.random import util as \
    random_test_util
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import gen_stateful_random_ops
from tensorflow.python.ops import logging_ops
@ -189,33 +191,60 @@ class StatefulRandomOpsTest(test.TestCase):
    compare(True, True)
    compare(True, False)
def _sameAsOldRandomOps(self, device):
def compare(dtype, old, new):
seed1, seed2 = 79, 25
# note how the two seeds for the old op correspond to the seed for the new
# op
with ops.device(device):
gen = random.Generator(seed=[0, seed2, seed1])
# create a graph for the old op in order to call it many times
@def_function.function
def run_old():
with ops.device(device):
return old(dtype, seed1, seed2)
def run_new():
with ops.device(device):
return new(dtype, gen)
for _ in range(100):
self.assertAllEqual(run_old(), run_new())
shape = constant_op.constant([4, 7])
minval = 128
maxval = 256
    # passing `dtype` around to suppress go/gpylint-faq#cell-var-from-loop and
    # go/gpylint-faq#undefined-loop-variable
def old_normal(dtype, seed1, seed2):
return gen_random_ops.random_standard_normal(
shape, dtype=dtype, seed=seed1, seed2=seed2)
def new_normal(dtype, gen):
return gen._standard_normal(shape, dtype=dtype)
def old_uniform(dtype, seed1, seed2):
minval2 = constant_op.constant(minval, dtype=dtype)
maxval2 = constant_op.constant(maxval, dtype=dtype)
return gen_random_ops.random_uniform_int(
shape, minval=minval2, maxval=maxval2, seed=seed1, seed2=seed2)
def new_uniform(dtype, gen):
return gen.uniform(
shape, minval=minval, maxval=maxval, dtype=dtype)
for dtype in (dtypes.float16, dtypes.bfloat16, dtypes.float32,
dtypes.float64):
compare(dtype, old_normal, new_normal)
for dtype in [dtypes.int32, dtypes.int64]:
compare(dtype, old_uniform, new_uniform)
  @test_util.run_v2_only
  def testCPUSameAsOldRandomOps(self):
    """Tests that the generated numbers are the same as the old random_ops.py.

    The CPU version.
    """
    seed1, seed2 = 79, 25
    # note how the two seeds for the old op correspond to the seed for the new
    # op
    with ops.device("/device:CPU:0"):
      random.reset_global_generator([0, seed2, seed1])
    shape = constant_op.constant([4, 7])
    dtype = dtypes.float64

    # create a graph for the old op in order to call it many times
    @def_function.function
    def old():
      with ops.device("/device:CPU:0"):
        return gen_random_ops.random_standard_normal(
            shape, dtype=dtype, seed=seed1, seed2=seed2)

    def new():
      with ops.device("/device:CPU:0"):
        return random.get_global_generator().normal(shape, dtype=dtype)

    for _ in range(100):
      self.assertAllEqual(old(), new())

    self._sameAsOldRandomOps("/device:CPU:0")
  @test_util.run_v2_only
  @test_util.run_cuda_only
@ -224,28 +253,103 @@ class StatefulRandomOpsTest(test.TestCase):
    The GPU version.
    """
    seed1, seed2 = 79, 25
    with ops.device(test_util.gpu_device_name()):
      random.reset_global_generator([0, seed2, seed1])
    shape = constant_op.constant([4, 7])
    dtype = dtypes.float64

    @def_function.function
    def old():
      with ops.device(test_util.gpu_device_name()):
        return gen_random_ops.random_standard_normal(
            shape, dtype=dtype, seed=seed1, seed2=seed2)

    def new():
      with ops.device(test_util.gpu_device_name()):
        return random.get_global_generator().normal(shape, dtype=dtype)

    for _ in range(100):
      self.assertAllEqual(old(), new())

    self._sameAsOldRandomOps(test_util.gpu_device_name())

  @test_util.run_v2_only
  def testUniformIntIsInRange(self):
    minval = 2
    maxval = 33
    size = 1000
    gen = random.Generator(seed=1234)
    for dtype in [dtypes.int32, dtypes.int64]:
      x = gen.uniform(
          shape=[size], dtype=dtype, minval=minval, maxval=maxval).numpy()
      self.assertTrue(np.all(x >= minval))
      self.assertTrue(np.all(x < maxval))

  @test_util.run_v2_only
  def testNormalIsFinite(self):
    gen = random.Generator(seed=1234)
    for dtype in [dtypes.float32]:
      x = gen.normal(shape=[10000], dtype=dtype).numpy()
      self.assertTrue(np.all(np.isfinite(x)))

  @test_util.run_v2_only
  def testDistributionOfUniform(self):
    """Use Pearson's Chi-squared test to test for uniformity."""
    n = 1000
    seed = 12
    for dtype in [dtypes.int32, dtypes.int64]:
      gen = random.Generator(seed=seed)
      maxval = 1
      if dtype.is_integer:
        maxval = 100
      x = gen.uniform(shape=[n], maxval=maxval, dtype=dtype).numpy()
      if maxval > 1:
        # Normalize x to range [0, 1).
        x = x.astype(float) / maxval
      # Tests that the values are distributed amongst 10 bins with equal
      # probability. 16.92 is the Chi^2 value for 9 degrees of freedom with
      # p=0.05. This test is probabilistic and would be flaky if the random
      # seed were not fixed.
      val = random_test_util.chi_squared(x, 10)
      self.assertLess(val, 16.92)
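The 16.92 threshold can be reproduced from the chi-squared distribution with 10 - 1 = 9 degrees of freedom at the 5% significance level (an aside, not part of the change; assumes SciPy is available):

from scipy import stats

# 95th percentile of the Chi^2 distribution with 9 degrees of freedom: ~16.919.
print(stats.chi2.ppf(0.95, df=9))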
@test_util.run_v2_only
def testDistributionOfNormal(self):
"""Use Anderson-Darling test to test distribution appears normal."""
n = 1000
for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
gen = random.Generator(seed=1234)
x = gen.normal(shape=[n], dtype=dtype).numpy()
# The constant 2.492 is the 5% critical value for the Anderson-Darling
# test where the mean and variance are known. This test is probabilistic
# so to avoid flakiness the seed is fixed.
self.assertLess(
random_test_util.anderson_darling(x.astype(float)), 2.492)
@test_util.run_v2_only
def testErrors(self):
"""Tests that proper errors are raised.
"""
shape = [2, 3]
gen = random.Generator(seed=1234)
with self.assertRaisesWithPredicateMatch(
errors.InvalidArgumentError,
r"algorithm must be of shape \[\], not"):
gen_stateful_random_ops.stateful_standard_normal_v2(
gen.state.handle, [0, 0], shape)
with self.assertRaisesWithPredicateMatch(
TypeError, "Requested dtype: int64"):
gen_stateful_random_ops.stateful_standard_normal_v2(
gen.state.handle, 1.1, shape)
with self.assertRaisesWithPredicateMatch(
errors.InvalidArgumentError,
"Unsupported algorithm id"):
gen_stateful_random_ops.stateful_standard_normal_v2(
gen.state.handle, 123, shape)
var = variables.Variable([0, 0], dtype=dtypes.int32)
with self.assertRaisesWithPredicateMatch(
errors.InvalidArgumentError,
"dtype of RNG state variable must be int64, not"):
gen_stateful_random_ops.stateful_standard_normal_v2(
var.handle, random.RNG_ALG_PHILOX, shape)
var = variables.Variable([[0]], dtype=dtypes.int64)
with self.assertRaisesWithPredicateMatch(
errors.InvalidArgumentError,
"RNG state must have one and only one dimension, not"):
gen_stateful_random_ops.stateful_standard_normal_v2(
var.handle, random.RNG_ALG_PHILOX, shape)
var = variables.Variable([0], dtype=dtypes.int64)
with self.assertRaisesWithPredicateMatch(
errors.InvalidArgumentError,
"For the Philox algorithm, the size of state must be at least"):
gen_stateful_random_ops.stateful_standard_normal_v2(
var.handle, random.RNG_ALG_PHILOX, shape)
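For contrast with the error cases above, a sketch of a well-formed call (illustration only; it assumes a state vector of three int64 values, two counter words plus one key word, satisfies the Philox minimum-size check exercised above):

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import gen_stateful_random_ops
from tensorflow.python.ops import stateful_random_ops as random
from tensorflow.python.ops import variables

# Assumed state layout: [counter_low, counter_high, key], all int64.
state = variables.Variable([0, 0, 1234], dtype=dtypes.int64)
# Draws a [2, 3] tensor of standard-normal samples and advances the counter.
sample = gen_stateful_random_ops.stateful_standard_normal_v2(
    state.handle, random.RNG_ALG_PHILOX, [2, 3])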
  @test_util.run_v2_only
  def testStatefulStandardNormal(self):
    """Tests that the deprecated op 'StatefulStandardNormal' still works.
    """
    shape = constant_op.constant([4, 7])
    dtype = dtypes.float64
@ -277,7 +381,7 @@ class StatefulRandomOpsTest(test.TestCase):
    random.reset_global_generator(50)
    with self.assertRaisesWithPredicateMatch(
        errors.NotFoundError, "Resource .+ does not exist"):
      a = f()
      random.reset_global_generator(50)
      b = f()