From 949474a44831bc28bac8a97bdafd1f7dc6589a97 Mon Sep 17 00:00:00 2001
From: Peng Wang
Date: Mon, 18 Mar 2019 18:28:46 -0700
Subject: [PATCH] A continuing partial implementation of the RFC "Random
 numbers in TensorFlow 2.0"
 (https://github.com/tensorflow/community/blob/master/rfcs/20181217-tf2-random-numbers.md)

In this change:
- CPU and GPU kernels for the ops 'StatefulUniformInt' and
  'StatefulUniformFullInt'.

To be done:
- ops for other distributions;
- other RNG algorithms;
- batch seeds;
- initializers ('RandomUniform', etc.).

PiperOrigin-RevId: 239104292
---
 tensorflow/compiler/tests/BUILD               |   2 +
 .../tests/stateful_random_ops_test.py         |  30 +-
 .../tests/stateless_random_ops_test.py        |  32 +-
 tensorflow/contrib/makefile/tf_op_files.txt   |   1 +
 tensorflow/core/kernels/BUILD                 |   1 +
 tensorflow/core/kernels/random_op.cc          | 275 +--------------
 tensorflow/core/kernels/random_op_cpu.h       | 325 ++++++++++++++++++
 tensorflow/core/kernels/random_op_gpu.cu.cc   |  33 +-
 tensorflow/core/kernels/random_op_gpu.h       |  41 ++-
 .../core/kernels/stateful_random_ops.cc       | 156 +++++++--
 .../kernels/stateful_random_ops_cpu_gpu.h     |   2 +-
 .../kernels/stateful_random_ops_gpu.cu.cc     |  30 +-
 .../core/lib/random/random_distributions.h    |  70 +++-
 tensorflow/python/BUILD                       |   1 +
 tensorflow/python/kernel_tests/random/util.py |  25 ++
 tensorflow/python/ops/stateful_random_ops.py  |  14 +-
 .../python/ops/stateful_random_ops_test.py    | 182 +++++++---
 17 files changed, 794 insertions(+), 426 deletions(-)
 create mode 100644 tensorflow/core/kernels/random_op_cpu.h

diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 0c3adb0bcf9..9f63f323af9 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -873,6 +873,7 @@ tf_xla_py_test(
         "//tensorflow/python:platform_test",
         "//tensorflow/python:standard_ops",
         "//tensorflow/python:stateful_random_ops",
+        "//tensorflow/python/kernel_tests/random:util",
     ],
 )
 
@@ -887,6 +888,7 @@ tf_xla_py_test(
         "//tensorflow/python:platform_test",
         "//tensorflow/python:standard_ops",
         "//tensorflow/python:stateless_random_ops",
+        "//tensorflow/python/kernel_tests/random:util",
     ],
 )
diff --git a/tensorflow/compiler/tests/stateful_random_ops_test.py b/tensorflow/compiler/tests/stateful_random_ops_test.py
index f0535579bf2..fd1f69789ae 100644
--- a/tensorflow/compiler/tests/stateful_random_ops_test.py
+++ b/tensorflow/compiler/tests/stateful_random_ops_test.py
@@ -18,8 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import math
-
 import numpy as np
 
 from tensorflow.compiler.tests import xla_test
@@ -29,6 +27,8 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
+from tensorflow.python.kernel_tests.random import util as \
+random_test_util
 from tensorflow.python.ops import gen_stateful_random_ops
 from tensorflow.python.ops import stateful_random_ops as \
 random
@@ -181,14 +181,6 @@ class StatefulRandomOpsTest(xla_test.XLATestCase):
       x = gen.normal(shape=[10000], dtype=dtype).numpy()
       self.assertTrue(np.all(np.isfinite(x)))
 
-  def _chi_squared(self, x, bins):
-    """Pearson's Chi-squared test."""
-    x = np.ravel(x)
-    n = len(x)
-    histogram, _ = np.histogram(x, bins=bins, range=(0, 1))
-    expected = n / float(bins)
-    return np.sum(np.square(histogram - expected) / expected)
-
   @test_util.run_v2_only
   def testDistributionOfUniform(self):
    """Use Pearson's Chi-squared test to test for
uniformity."""
@@ -208,22 +200,9 @@ class StatefulRandomOpsTest(xla_test.XLATestCase):
       # probability. 16.92 is the Chi^2 value for 9 degrees of freedom with
       # p=0.05. This test is probabilistic and would be flaky if the random
       # seed were not fixed.
-      val = self._chi_squared(x, 10)
+      val = random_test_util.chi_squared(x, 10)
       self.assertLess(val, 16.92)
 
-  def _normal_cdf(self, x):
-    """Cumulative distribution function for a standard normal distribution."""
-    return 0.5 + 0.5 * np.vectorize(math.erf)(x / math.sqrt(2))
-
-  def _anderson_darling(self, x):
-    """Anderson-Darling test for a standard normal distribution."""
-    x = np.sort(np.ravel(x))
-    n = len(x)
-    i = np.linspace(1, n, n)
-    z = np.sum((2 * i - 1) * np.log(self._normal_cdf(x)) +
-               (2 * (n - i) + 1) * np.log(1 - self._normal_cdf(x)))
-    return -n - z / n
-
   @test_util.run_v2_only
   def testDistributionOfNormal(self):
     """Use Anderson-Darling test to test distribution appears normal."""
@@ -235,7 +214,8 @@ class StatefulRandomOpsTest(xla_test.XLATestCase):
       x = gen.normal(shape=[n], dtype=dtype).numpy()
       # The constant 2.492 is the 5% critical value for the Anderson-Darling
       # test where the mean and variance are known. This test is probabilistic
      # so to avoid flakiness the seed is fixed.
-      self.assertLess(self._anderson_darling(x.astype(float)), 2.492)
+      self.assertLess(
+          random_test_util.anderson_darling(x.astype(float)), 2.492)
 
   @test_util.run_v2_only
   def testErrors(self):
diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py
index df5914a518e..3fb3176ee00 100644
--- a/tensorflow/compiler/tests/stateless_random_ops_test.py
+++ b/tensorflow/compiler/tests/stateless_random_ops_test.py
@@ -24,6 +24,8 @@ import numpy as np
 
 from tensorflow.compiler.tests import xla_test
 from tensorflow.python.framework import dtypes
+from tensorflow.python.kernel_tests.random import util as \
+random_test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import stateless_random_ops as stateless
 from tensorflow.python.ops.distributions import special_math
@@ -43,7 +45,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
     # Stateless values should be equal iff the seeds are equal (roughly)
     with self.cached_session(), self.test_scope():
       seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
-      seeds = [(x, y) for x in range(5) for y in range(5)] * 3
+      seeds = [(x, y) for x in range(5) for y in range(5)] * 3  # pylint: disable=g-complex-comprehension
       for stateless_op in [
           stateless.stateless_random_uniform, stateless.stateless_random_normal
       ]:
@@ -75,14 +77,6 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
       self.assertTrue(np.all(y >= 0))
      self.assertTrue(np.all(y < maxval))
 
-  def _chi_squared(self, x, bins):
-    """Pearson's Chi-squared test."""
-    x = np.ravel(x)
-    n = len(x)
-    histogram, _ = np.histogram(x, bins=bins, range=(0, 1))
-    expected = n / float(bins)
-    return np.sum(np.square(histogram - expected) / expected)
-
   def testDistributionOfStatelessRandomUniform(self):
     """Use Pearson's Chi-squared test to test for uniformity."""
     with self.cached_session() as sess, self.test_scope():
@@ -102,7 +96,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
       # probability. 16.92 is the Chi^2 value for 9 degrees of freedom with
       # p=0.05. This test is probabilistic and would be flaky if the random
       # seed were not fixed.
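
To make the hard-coded critical values concrete: 16.92 is the p=0.05 cutoff for Pearson's Chi^2 statistic with 9 degrees of freedom (10 bins), and 2.492 is the 5% Anderson-Darling cutoff when mean and variance are known. A minimal numpy sketch against the shared helpers this patch introduces in tensorflow/python/kernel_tests/random/util.py — for truly random input each statistic stays below its cutoff with roughly 95% probability, which is why the tests pin their seeds:

    import numpy as np
    from tensorflow.python.kernel_tests.random import util as random_test_util

    rng = np.random.RandomState(12)  # fixed seed, mirroring the tests
    print(random_test_util.chi_squared(rng.uniform(size=1000), 10))  # expect < 16.92
    print(random_test_util.anderson_darling(
        rng.standard_normal(1000).astype(float)))                    # expect < 2.492
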
- self.assertTrue(self._chi_squared(y, 10) < 16.92) + self.assertLess(random_test_util.chi_squared(y, 10), 16.92) def testRandomNormalIsFinite(self): with self.cached_session() as sess, self.test_scope(): @@ -113,19 +107,6 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) self.assertTrue(np.all(np.isfinite(y))) - def _normal_cdf(self, x): - """Cumulative distribution function for a standard normal distribution.""" - return 0.5 + 0.5 * np.vectorize(math.erf)(x / math.sqrt(2)) - - def _anderson_darling(self, x): - """Anderson-Darling test for a standard normal distribution.""" - x = np.sort(np.ravel(x)) - n = len(x) - i = np.linspace(1, n, n) - z = np.sum((2 * i - 1) * np.log(self._normal_cdf(x)) + - (2 * (n - i) + 1) * np.log(1 - self._normal_cdf(x))) - return -n - z / n - def testDistributionOfStatelessRandomNormal(self): """Use Anderson-Darling test to test distribution appears normal.""" with self.cached_session() as sess, self.test_scope(): @@ -138,7 +119,8 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): # The constant 2.492 is the 5% critical value for the Anderson-Darling # test where the mean and variance are known. This test is probabilistic # so to avoid flakiness the seed is fixed. - self.assertTrue(self._anderson_darling(y.astype(float)) < 2.492) + self.assertLess( + random_test_util.anderson_darling(y.astype(float)), 2.492) def testTruncatedNormalIsInRange(self): for dtype in self._random_types(): @@ -155,7 +137,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): def normal_pdf(x): return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi) - def probit(x, sess=sess): + def probit(x): return self.evaluate(special_math.ndtri(x)) a = -2. diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index ea5f5913c66..8afd3c1a0a6 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -257,6 +257,7 @@ tensorflow/core/kernels/split_op.cc tensorflow/core/kernels/split_v_op.cc tensorflow/core/kernels/stack.cc tensorflow/core/kernels/stack_ops.cc +tensorflow/core/kernels/stateful_random_ops.cc tensorflow/core/kernels/stateless_random_ops.cc tensorflow/core/kernels/strided_slice_op.cc tensorflow/core/kernels/strided_slice_op_inst_0.cc diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index ab9d3e11607..624cd7236a8 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -5784,6 +5784,7 @@ filegroup( "queue_op.cc", "queue_ops.cc", "random_op.cc", + "random_op_cpu.h", "reduction_ops_all.cc", "reduction_ops_any.cc", "reduction_ops_common.cc", diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc index 3810d817ca9..996950b65f3 100644 --- a/tensorflow/core/kernels/random_op.cc +++ b/tensorflow/core/kernels/random_op.cc @@ -17,8 +17,6 @@ limitations under the License. #define EIGEN_USE_THREADS -#include "tensorflow/core/kernels/random_op.h" - #include #include #include @@ -27,6 +25,7 @@ limitations under the License. 
#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/random_op_cpu.h" #include "tensorflow/core/lib/hash/crc32c.h" #include "tensorflow/core/lib/random/random_distributions.h" #include "tensorflow/core/lib/random/simple_philox.h" @@ -52,131 +51,6 @@ typedef Eigen::GpuDevice GPUDevice; typedef Eigen::SyclDevice SYCLDevice; #endif // TENSORFLOW_USE_SYCL -namespace functor { -using random::PhiloxRandom; -using random::SingleSampleAdapter; - -// The default implementation of the functor, which should never be invoked -// But we still need to provide implementation for now for the linker to work, -// since we do not support all the distributions yet. -template -struct FillPhiloxRandom { - typedef typename Distribution::ResultElementType T; - void operator()(OpKernelContext*, const Device&, random::PhiloxRandom gen, - T* data, int64 size, Distribution dist) { - LOG(FATAL) << "Default FillPhiloxRandom should not be executed."; - } -}; - -// A class to fill a specified range of random groups -template -struct FillPhiloxRandomTask; - -// Specialization for distribution that takes a fixed number of samples for -// each output. -template -struct FillPhiloxRandomTask { - typedef typename Distribution::ResultElementType T; - static void Run(random::PhiloxRandom gen, T* data, int64 size, - int64 start_group, int64 limit_group, Distribution dist) { - const int kGroupSize = Distribution::kResultElementCount; - - gen.Skip(start_group); - int64 offset = start_group * kGroupSize; - - // First fill all the full-size groups - int64 limit_group_full = std::min(limit_group, size / kGroupSize); - for (int64 index = start_group; index < limit_group_full; ++index) { - auto samples = dist(&gen); - std::copy(&samples[0], &samples[0] + kGroupSize, data + offset); - offset += kGroupSize; - } - - // If there are any remaining elements that need to be filled, process them - if (limit_group_full < limit_group) { - int64 remaining_size = size - limit_group_full * kGroupSize; - auto samples = dist(&gen); - std::copy(&samples[0], &samples[0] + remaining_size, data + offset); - } - } -}; - -// Specialization for distribution that takes a variable number of samples for -// each output. This will be slower due to the generality. 
-template -struct FillPhiloxRandomTask { - typedef typename Distribution::ResultElementType T; - static const int64 kReservedSamplesPerOutput = 256; - - static void Run(random::PhiloxRandom base_gen, T* data, int64 size, - int64 start_group, int64 limit_group, Distribution dist) { - const int kGroupSize = Distribution::kResultElementCount; - - static const int kGeneratorSkipPerOutputGroup = - kGroupSize * kReservedSamplesPerOutput / - PhiloxRandom::kResultElementCount; - - int64 offset = start_group * kGroupSize; - - // First fill all the full-size groups - int64 limit_group_full = std::min(limit_group, size / kGroupSize); - int64 group_index; - for (group_index = start_group; group_index < limit_group_full; - ++group_index) { - // Reset the generator to the beginning of the output group region - // This is necessary if we want the results to be independent of order - // of work - PhiloxRandom gen = base_gen; - gen.Skip(group_index * kGeneratorSkipPerOutputGroup); - SingleSampleAdapter single_samples(&gen); - - auto samples = dist(&single_samples); - std::copy(&samples[0], &samples[0] + kGroupSize, data + offset); - offset += kGroupSize; - } - - // If there are any remaining elements that need to be filled, process them - if (limit_group_full < limit_group) { - PhiloxRandom gen = base_gen; - gen.Skip(group_index * kGeneratorSkipPerOutputGroup); - SingleSampleAdapter single_samples(&gen); - - int64 remaining_size = size - limit_group_full * kGroupSize; - auto samples = dist(&single_samples); - std::copy(&samples[0], &samples[0] + remaining_size, data + offset); - } - } -}; - -// Partial specialization for CPU to fill the entire region with randoms -// It splits the work into several tasks and run them in parallel -template -void FillPhiloxRandom::operator()( - OpKernelContext* context, const CPUDevice&, random::PhiloxRandom gen, - typename Distribution::ResultElementType* data, int64 size, - Distribution dist) { - const int kGroupSize = Distribution::kResultElementCount; - - auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); - - int64 total_group_count = (size + kGroupSize - 1) / kGroupSize; - - const int kGroupCost = - random::PhiloxRandom::kResultElementCount * - (random::PhiloxRandom::kElementCost + Distribution::kElementCost); - Shard(worker_threads.num_threads, worker_threads.workers, total_group_count, - kGroupCost, - [&gen, data, size, dist](int64 start_group, int64 limit_group) { - FillPhiloxRandomTask< - Distribution, - Distribution::kVariableSamplesPerOutput>::Run(gen, data, size, - start_group, - limit_group, dist); - }); -} - -} // namespace functor - namespace { static Status AllocateOutputWithShape(OpKernelContext* ctx, const Tensor& shape, @@ -354,7 +228,7 @@ class RandomGammaOp : public OpKernel { const double alpha = static_cast(alpha_flat[alpha_idx]); DISABLE_FLOAT_EQUALITY_WARNING - if (alpha == double(1.0)) { + if (alpha == static_cast(1.0)) { ENABLE_FLOAT_EQUALITY_WARNING // Sample from an exponential distribution. for (int64 sample_idx = output_idx % num_samples; @@ -364,7 +238,7 @@ class RandomGammaOp : public OpKernel { // (including eventually on GPU), we skip on a per-sample basis. 
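
The alpha == 1 branch just below samples the exponential distribution by inverting its CDF: if U ~ Uniform(0, 1), then -log(1 - U) ~ Exponential(1), which is exactly what the UNIFORM(u) draw followed by -log(1.0 - u) computes. A quick numpy check:

    import numpy as np

    u = np.random.RandomState(0).uniform(size=100000)
    x = -np.log(1.0 - u)
    print(x.mean(), x.var())  # both approach 1, the mean and variance of Exp(1)
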
PhiloxRandom gen = rng; gen.Skip(kReservedSamplesPerOutput * output_idx); - short uniform_remaining = 0; + int16 uniform_remaining = 0; UNIFORM(u); const double res = -log(1.0 - u); samples_alpha_offset[sample_idx * num_alphas] = static_cast(res); @@ -392,8 +266,8 @@ class RandomGammaOp : public OpKernel { // (including eventually on GPU), we skip on a per-sample basis. PhiloxRandom gen = rng; gen.Skip(kReservedSamplesPerOutput * output_idx); - short norm_remaining = 0; - short uniform_remaining = 0; + int16 norm_remaining = 0; + int16 uniform_remaining = 0; // Keep trying until we don't reject a sample. In practice, we will // only reject ~5% at worst, for low alpha near 1. @@ -565,145 +439,6 @@ TF_CALL_int64(REGISTER_INT); #ifdef TENSORFLOW_USE_SYCL -namespace functor { - -using namespace cl; - -template -struct FillPhiloxRandomKernel; - -template -struct FillPhiloxRandomKernel { - typedef typename Distribution::ResultElementType T; - using write_accessor = sycl::accessor; - - FillPhiloxRandomKernel(write_accessor& data, random::PhiloxRandom& gen, - Distribution& dist) - : data_(data), gen_(gen), dist_(dist) {} - - void operator()(sycl::nd_item<1> item) { - const size_t kGroupSize = Distribution::kResultElementCount; - - const size_t item_id = item.get_global(0); - const size_t total_item_count = item.get_global_range(); - size_t offset = item_id * kGroupSize; - gen_.Skip(item_id); - - const size_t size = data_.get_size() / sizeof(T); - T* data = ConvertToActualTypeSycl(T, data_); - - while (offset + kGroupSize <= size) { - const typename Distribution::ResultType samples = dist_(&gen_); - for (size_t i = 0; i < kGroupSize; ++i) { - data[offset + i] = samples[i]; - } - - offset += (total_item_count - 1) * kGroupSize; - gen_.Skip(total_item_count - 1); - } - - const typename Distribution::ResultType samples = dist_(&gen_); - for (size_t i = 0; i < kGroupSize; ++i) { - if (offset >= size) { - return; - } - data[offset] = samples[i]; - ++offset; - } - } - - private: - write_accessor data_; - random::PhiloxRandom gen_; - Distribution dist_; -}; - -template -struct FillPhiloxRandomKernel { - typedef typename Distribution::ResultElementType T; - using write_accessor = sycl::accessor; - - FillPhiloxRandomKernel(write_accessor& data, random::PhiloxRandom& gen, - Distribution& dist) - : data_(data), gen_(gen), dist_(dist) {} - - void operator()(sycl::nd_item<1> item) { - using random::PhiloxRandom; - using random::SingleSampleAdapter; - - const size_t kReservedSamplesPerOutput = 256; - const size_t kGroupSize = Distribution::kResultElementCount; - const size_t kGeneratorSkipPerOutputGroup = - kGroupSize * kReservedSamplesPerOutput / - PhiloxRandom::kResultElementCount; - - const size_t item_id = item.get_global(0); - const size_t total_item_count = item.get_global_range(); - size_t group_index = item_id; - size_t offset = group_index * kGroupSize; - - T* data = ConvertToActualTypeSycl(T, data_); - const size_t size = data_.get_size() / sizeof(T); - - while (offset < size) { - // Since each output takes a variable number of samples, we need to - // realign the generator to the beginning for the current output group - PhiloxRandom gen = gen_; - gen.Skip(group_index * kGeneratorSkipPerOutputGroup); - SingleSampleAdapter single_samples(&gen); - - const typename Distribution::ResultType samples = dist_(&single_samples); - - for (size_t i = 0; i < kGroupSize; ++i) { - if (offset >= size) { - return; - } - data[offset] = samples[i]; - ++offset; - } - - offset += (total_item_count - 1) * kGroupSize; - 
group_index += total_item_count; - } - } - - private: - write_accessor data_; - random::PhiloxRandom gen_; - Distribution dist_; -}; - -template -class FillRandomKernel; -// Partial specialization for SYCL to fill the entire region with randoms -// It splits the work into several tasks and run them in parallel -template -void FillPhiloxRandom::operator()( - OpKernelContext* context, const SYCLDevice& device, - random::PhiloxRandom gen, typename Distribution::ResultElementType* data, - int64 size, Distribution dist) { - const size_t group_size = device.maxSyclThreadsPerBlock(); - const size_t group_count = (size + group_size - 1) / group_size; - - auto buffer = device.get_sycl_buffer(data); - - device.sycl_queue().submit([&](sycl::handler& cgh) { - auto access = buffer.template get_access(cgh); - - FillPhiloxRandomKernel - task(access, gen, dist); - cgh.parallel_for>( - sycl::nd_range<1>(sycl::range<1>(group_count * group_size), - sycl::range<1>(group_size)), - task); - }); -} - -} // namespace functor - #define REGISTER(TYPE) \ template struct functor::FillPhiloxRandom< \ SYCLDevice, random::UniformDistribution>; \ diff --git a/tensorflow/core/kernels/random_op_cpu.h b/tensorflow/core/kernels/random_op_cpu.h new file mode 100644 index 00000000000..45561369957 --- /dev/null +++ b/tensorflow/core/kernels/random_op_cpu.h @@ -0,0 +1,325 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_RANDOM_OP_CPU_H_ +#define TENSORFLOW_CORE_KERNELS_RANDOM_OP_CPU_H_ + +#define EIGEN_USE_THREADS + +#include +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/random_op.h" +#include "tensorflow/core/lib/hash/crc32c.h" +#include "tensorflow/core/lib/random/random_distributions.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/guarded_philox_random.h" +#include "tensorflow/core/util/work_sharder.h" + +#if EIGEN_COMP_GNUC && __cplusplus > 199711L +#define DISABLE_FLOAT_EQUALITY_WARNING \ + _Pragma("GCC diagnostic push") \ + _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"") +#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop") +#else +#define DISABLE_FLOAT_EQUALITY_WARNING +#define ENABLE_FLOAT_EQUALITY_WARNING +#endif + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL + +namespace functor { +using random::PhiloxRandom; +using random::SingleSampleAdapter; + +// The default implementation of the functor, which should never be invoked +// But we still need to provide implementation for now for the linker to work, +// since we do not support all the distributions yet. +template +struct FillPhiloxRandom { + typedef typename Distribution::ResultElementType T; + void operator()(OpKernelContext* ctx, const Device&, random::PhiloxRandom gen, + T* data, int64 size, Distribution dist) { + OP_REQUIRES( + ctx, false, + errors::Internal("Default FillPhiloxRandom should not be executed.")); + } +}; + +// A class to fill a specified range of random groups +template +struct FillPhiloxRandomTask; + +// Specialization for distribution that takes a fixed number of samples for +// each output. +template +struct FillPhiloxRandomTask { + typedef typename Distribution::ResultElementType T; + static void Run(random::PhiloxRandom gen, T* data, int64 size, + int64 start_group, int64 limit_group, Distribution dist) { + const int kGroupSize = Distribution::kResultElementCount; + + gen.Skip(start_group); + int64 offset = start_group * kGroupSize; + + // First fill all the full-size groups + int64 limit_group_full = std::min(limit_group, size / kGroupSize); + for (int64 index = start_group; index < limit_group_full; ++index) { + auto samples = dist(&gen); + std::copy(&samples[0], &samples[0] + kGroupSize, data + offset); + offset += kGroupSize; + } + + // If there are any remaining elements that need to be filled, process them + if (limit_group_full < limit_group) { + int64 remaining_size = size - limit_group_full * kGroupSize; + auto samples = dist(&gen); + std::copy(&samples[0], &samples[0] + remaining_size, data + offset); + } + } +}; + +// Specialization for distribution that takes a variable number of samples for +// each output. This will be slower due to the generality. 
+template +struct FillPhiloxRandomTask { + typedef typename Distribution::ResultElementType T; + static const int64 kReservedSamplesPerOutput = 256; + + static void Run(random::PhiloxRandom base_gen, T* data, int64 size, + int64 start_group, int64 limit_group, Distribution dist) { + const int kGroupSize = Distribution::kResultElementCount; + + static const int kGeneratorSkipPerOutputGroup = + kGroupSize * kReservedSamplesPerOutput / + PhiloxRandom::kResultElementCount; + + int64 offset = start_group * kGroupSize; + + // First fill all the full-size groups + int64 limit_group_full = std::min(limit_group, size / kGroupSize); + int64 group_index; + for (group_index = start_group; group_index < limit_group_full; + ++group_index) { + // Reset the generator to the beginning of the output group region + // This is necessary if we want the results to be independent of order + // of work + PhiloxRandom gen = base_gen; + gen.Skip(group_index * kGeneratorSkipPerOutputGroup); + SingleSampleAdapter single_samples(&gen); + + auto samples = dist(&single_samples); + std::copy(&samples[0], &samples[0] + kGroupSize, data + offset); + offset += kGroupSize; + } + + // If there are any remaining elements that need to be filled, process them + if (limit_group_full < limit_group) { + PhiloxRandom gen = base_gen; + gen.Skip(group_index * kGeneratorSkipPerOutputGroup); + SingleSampleAdapter single_samples(&gen); + + int64 remaining_size = size - limit_group_full * kGroupSize; + auto samples = dist(&single_samples); + std::copy(&samples[0], &samples[0] + remaining_size, data + offset); + } + } +}; + +// Partial specialization for CPU to fill the entire region with randoms +// It splits the work into several tasks and run them in parallel +template +void FillPhiloxRandom::operator()( + OpKernelContext* context, const CPUDevice&, random::PhiloxRandom gen, + typename Distribution::ResultElementType* data, int64 size, + Distribution dist) { + const int kGroupSize = Distribution::kResultElementCount; + + auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); + + int64 total_group_count = (size + kGroupSize - 1) / kGroupSize; + + const int kGroupCost = + random::PhiloxRandom::kResultElementCount * + (random::PhiloxRandom::kElementCost + Distribution::kElementCost); + Shard(worker_threads.num_threads, worker_threads.workers, total_group_count, + kGroupCost, + [&gen, data, size, dist](int64 start_group, int64 limit_group) { + FillPhiloxRandomTask< + Distribution, + Distribution::kVariableSamplesPerOutput>::Run(gen, data, size, + start_group, + limit_group, dist); + }); +} + +} // namespace functor + +#ifdef TENSORFLOW_USE_SYCL + +namespace functor { + +template +struct FillPhiloxRandomKernel; + +template +struct FillPhiloxRandomKernel { + typedef typename Distribution::ResultElementType T; + using write_accessor = sycl::accessor; + + FillPhiloxRandomKernel(write_accessor& data, random::PhiloxRandom& gen, + Distribution& dist) + : data_(data), gen_(gen), dist_(dist) {} + + void operator()(sycl::nd_item<1> item) { + const size_t kGroupSize = Distribution::kResultElementCount; + + const size_t item_id = item.get_global(0); + const size_t total_item_count = item.get_global_range(); + size_t offset = item_id * kGroupSize; + gen_.Skip(item_id); + + const size_t size = data_.get_size() / sizeof(T); + T* data = ConvertToActualTypeSycl(T, data_); + + while (offset + kGroupSize <= size) { + const typename Distribution::ResultType samples = dist_(&gen_); + for (size_t i = 0; i < kGroupSize; ++i) { + data[offset 
+ i] = samples[i]; + } + + offset += (total_item_count - 1) * kGroupSize; + gen_.Skip(total_item_count - 1); + } + + const typename Distribution::ResultType samples = dist_(&gen_); + for (size_t i = 0; i < kGroupSize; ++i) { + if (offset >= size) { + return; + } + data[offset] = samples[i]; + ++offset; + } + } + + private: + write_accessor data_; + random::PhiloxRandom gen_; + Distribution dist_; +}; + +template +struct FillPhiloxRandomKernel { + typedef typename Distribution::ResultElementType T; + using write_accessor = sycl::accessor; + + FillPhiloxRandomKernel(write_accessor& data, random::PhiloxRandom& gen, + Distribution& dist) + : data_(data), gen_(gen), dist_(dist) {} + + void operator()(sycl::nd_item<1> item) { + using random::PhiloxRandom; + using random::SingleSampleAdapter; + + const size_t kReservedSamplesPerOutput = 256; + const size_t kGroupSize = Distribution::kResultElementCount; + const size_t kGeneratorSkipPerOutputGroup = + kGroupSize * kReservedSamplesPerOutput / + PhiloxRandom::kResultElementCount; + + const size_t item_id = item.get_global(0); + const size_t total_item_count = item.get_global_range(); + size_t group_index = item_id; + size_t offset = group_index * kGroupSize; + + T* data = ConvertToActualTypeSycl(T, data_); + const size_t size = data_.get_size() / sizeof(T); + + while (offset < size) { + // Since each output takes a variable number of samples, we need to + // realign the generator to the beginning for the current output group + PhiloxRandom gen = gen_; + gen.Skip(group_index * kGeneratorSkipPerOutputGroup); + SingleSampleAdapter single_samples(&gen); + + const typename Distribution::ResultType samples = dist_(&single_samples); + + for (size_t i = 0; i < kGroupSize; ++i) { + if (offset >= size) { + return; + } + data[offset] = samples[i]; + ++offset; + } + + offset += (total_item_count - 1) * kGroupSize; + group_index += total_item_count; + } + } + + private: + write_accessor data_; + random::PhiloxRandom gen_; + Distribution dist_; +}; + +template +class FillRandomKernel; +// Partial specialization for SYCL to fill the entire region with randoms +// It splits the work into several tasks and run them in parallel +template +void FillPhiloxRandom::operator()( + OpKernelContext* context, const SYCLDevice& device, + random::PhiloxRandom gen, typename Distribution::ResultElementType* data, + int64 size, Distribution dist) { + const size_t group_size = device.maxSyclThreadsPerBlock(); + const size_t group_count = (size + group_size - 1) / group_size; + + auto buffer = device.get_sycl_buffer(data); + + device.sycl_queue().submit([&](sycl::handler& cgh) { + auto access = buffer.template get_access(cgh); + + FillPhiloxRandomKernel + task(access, gen, dist); + cgh.parallel_for>( + sycl::nd_range<1>(sycl::range<1>(group_count * group_size), + sycl::range<1>(group_size)), + task); + }); +} + +} // namespace functor + +#endif // TENSORFLOW_USE_SYCL + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_RANDOM_OP_CPU_H_ diff --git a/tensorflow/core/kernels/random_op_gpu.cu.cc b/tensorflow/core/kernels/random_op_gpu.cu.cc index 55278d0480e..9c3db8742ba 100644 --- a/tensorflow/core/kernels/random_op_gpu.cu.cc +++ b/tensorflow/core/kernels/random_op_gpu.cu.cc @@ -17,17 +17,15 @@ limitations under the License. 
#define EIGEN_USE_GPU -#include "tensorflow/core/kernels/random_op.h" -#include "tensorflow/core/kernels/random_op_gpu.h" - #include #include +#include "tensorflow/core/kernels/random_op_gpu.h" + #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/lib/random/philox_random.h" #include "tensorflow/core/lib/random/random_distributions.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" namespace tensorflow { @@ -37,33 +35,6 @@ namespace functor { typedef Eigen::GpuDevice GPUDevice; -// A simple launch pad to call the correct function templates to fill the data -template -__global__ void __launch_bounds__(1024) - FillPhiloxRandomKernelLaunch(random::PhiloxRandom base_gen, - typename Distribution::ResultElementType* data, - int64 size, Distribution dist) { - FillPhiloxRandomKernel() - .Run(base_gen, data, size, dist); -} - -// Partial specialization for GPU -template -void FillPhiloxRandom::operator()( - OpKernelContext*, const GPUDevice& d, random::PhiloxRandom gen, - typename Distribution::ResultElementType* data, int64 size, - Distribution dist) { - const int32 block_size = d.maxGpuThreadsPerBlock(); - const int32 num_blocks = - (d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor()) / - block_size; - - TF_CHECK_OK(CudaLaunchKernel(FillPhiloxRandomKernelLaunch, - num_blocks, block_size, 0, d.stream(), gen, data, - size, dist)); -} - // Explicit instantiation of the GPU distributions functors // clang-format off // NVCC cannot handle ">>" properly diff --git a/tensorflow/core/kernels/random_op_gpu.h b/tensorflow/core/kernels/random_op_gpu.h index e32c755d782..cd20d95b634 100644 --- a/tensorflow/core/kernels/random_op_gpu.h +++ b/tensorflow/core/kernels/random_op_gpu.h @@ -18,8 +18,10 @@ limitations under the License. #if defined(__CUDACC__) +#include "tensorflow/core/kernels/random_op.h" #include "tensorflow/core/lib/random/philox_random.h" #include "tensorflow/core/lib/random/random_distributions.h" +#include "tensorflow/core/util/cuda_kernel_helper.h" namespace tensorflow { @@ -31,15 +33,15 @@ struct FillPhiloxRandomKernel; template struct FillPhiloxRandomKernel { typedef typename Distribution::ResultElementType T; - PHILOX_DEVICE_FUNC void Run(random::PhiloxRandom gen, T* data, int64 size, - Distribution dist); + PHILOX_DEVICE_INLINE void Run(random::PhiloxRandom gen, T* data, int64 size, + Distribution dist); }; template struct FillPhiloxRandomKernel { typedef typename Distribution::ResultElementType T; - PHILOX_DEVICE_FUNC void Run(const random::PhiloxRandom& base_gen, T* data, - int64 size, Distribution dist); + PHILOX_DEVICE_INLINE void Run(const random::PhiloxRandom& base_gen, T* data, + int64 size, Distribution dist); }; template @@ -128,7 +130,7 @@ class SampleCopier { // A cuda kernel to fill the data with random numbers from the specified // distribution. Each output takes a fixed number of samples. template -PHILOX_DEVICE_FUNC void FillPhiloxRandomKernel::Run( +PHILOX_DEVICE_INLINE void FillPhiloxRandomKernel::Run( random::PhiloxRandom gen, T* data, int64 size, Distribution dist) { const int kGroupSize = Distribution::kResultElementCount; @@ -159,7 +161,7 @@ PHILOX_DEVICE_FUNC void FillPhiloxRandomKernel::Run( // A cuda kernel to fill the data with random numbers from the specified // distribution. Each output takes a variable number of samples. 
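
In the variable-sample kernel that follows (moved into this header from random_op_gpu.cu.cc), each output group reserves a fixed budget of generator states so that a rejection sampler can consume an unpredictable number of draws without colliding with neighbouring groups. Worked numbers for the skip arithmetic — only kReservedSamplesPerOutput = 256 is fixed by the code; the group size shown is an assumption for illustration:

    PHILOX_RESULT_ELEMENT_COUNT = 4    # PhiloxRandom emits 4 uint32s per state
    RESERVED_SAMPLES_PER_OUTPUT = 256  # kReservedSamplesPerOutput
    GROUP_SIZE = 4                     # e.g. Distribution::kResultElementCount

    skip_per_output_group = (GROUP_SIZE * RESERVED_SAMPLES_PER_OUTPUT
                             // PHILOX_RESULT_ELEMENT_COUNT)
    print(skip_per_output_group)  # 256 Philox states reserved per output group
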
template -PHILOX_DEVICE_FUNC void FillPhiloxRandomKernel::Run( +PHILOX_DEVICE_INLINE void FillPhiloxRandomKernel::Run( const random::PhiloxRandom& base_gen, T* data, int64 size, Distribution dist) { using random::PhiloxRandom; @@ -198,6 +200,33 @@ PHILOX_DEVICE_FUNC void FillPhiloxRandomKernel::Run( } } +// A simple launch pad to call the correct function templates to fill the data +template +__global__ void __launch_bounds__(1024) + FillPhiloxRandomKernelLaunch(random::PhiloxRandom base_gen, + typename Distribution::ResultElementType* data, + int64 size, Distribution dist) { + FillPhiloxRandomKernel() + .Run(base_gen, data, size, dist); +} + +// Partial specialization for GPU +template +void FillPhiloxRandom::operator()( + OpKernelContext*, const GPUDevice& d, random::PhiloxRandom gen, + typename Distribution::ResultElementType* data, int64 size, + Distribution dist) { + const int32 block_size = d.maxGpuThreadsPerBlock(); + const int32 num_blocks = + (d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor()) / + block_size; + + TF_CHECK_OK(CudaLaunchKernel(FillPhiloxRandomKernelLaunch, + num_blocks, block_size, 0, d.stream(), gen, data, + size, dist)); +} + } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/stateful_random_ops.cc b/tensorflow/core/kernels/stateful_random_ops.cc index 1312593d2a5..e2e2b0fcac2 100644 --- a/tensorflow/core/kernels/stateful_random_ops.cc +++ b/tensorflow/core/kernels/stateful_random_ops.cc @@ -15,7 +15,7 @@ limitations under the License. #define EIGEN_USE_THREADS -#include "tensorflow/core/kernels/random_op.h" +#include "tensorflow/core/kernels/random_op_cpu.h" #include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h" #include "tensorflow/core/kernels/training_op_helpers.h" #include "tensorflow/core/lib/random/random.h" @@ -25,7 +25,7 @@ namespace tensorflow { template struct UpdateVariableAndFill_Philox { void operator()(OpKernelContext* ctx, const CPUDevice& device, - int64 output_size, int64 alg_tag_skip, + Distribution dist, int64 output_size, int64 alg_tag_skip, ScopedUnlockUnrefVar* state_var_guard, Tensor* state_tensor, typename Distribution::ResultElementType* output_data) { auto state_tensor_flat = state_tensor->flat(); @@ -36,14 +36,14 @@ struct UpdateVariableAndFill_Philox { // No longer needs the lock. state_var_guard->Release(); functor::FillPhiloxRandom()( - ctx, device, philox, output_data, output_size, Distribution()); + ctx, device, philox, output_data, output_size, dist); } }; template Status UpdateVariableAndFill( - OpKernelContext* ctx, int state_input_idx, bool read_alg_from_state, - Algorithm alg, int64 output_size, + OpKernelContext* ctx, Distribution dist, int state_input_idx, + bool read_alg_from_state, Algorithm alg, int64 output_size, typename Distribution::ResultElementType* output_data) { Var* var = nullptr; TF_RETURN_IF_ERROR( @@ -89,7 +89,7 @@ Status UpdateVariableAndFill( TF_RETURN_IF_ERROR(PrepareToUpdateVariable( ctx, var_tensor, var->copy_on_read_mode.load())); UpdateVariableAndFill_Philox()( - ctx, ctx->eigen_device(), output_size, alg_tag_skip, + ctx, ctx->eigen_device(), dist, output_size, alg_tag_skip, &state_var_guard, var_tensor, output_data); return Status::OK(); } else { @@ -99,8 +99,9 @@ Status UpdateVariableAndFill( // Preconditon: input(0) is an existing resource. 
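
Downstream of this precondition, the compute path reads the RNG state out of the resource variable. For the Philox algorithm the state is a 1-D int64 tensor — inferred from the seed correspondence in the tests, roughly [counter_low, counter_high, key] — and the checks exercised at the end of this patch (int64 dtype, rank 1, minimum size) enforce exactly that. The layout is also why those tests seed the new Generator with [0, seed2, seed1] to reproduce an old op called with seed=seed1, seed2=seed2. A sketch of the correspondence, using the module alias from the tests:

    from tensorflow.python.ops import stateful_random_ops as random

    SEED1, SEED2 = 79, 25               # the old-style (seed, seed2) pair
    state = [0, SEED2, SEED1]           # [counter_lo, counter_hi, key]
    gen = random.Generator(seed=state)  # continues the same Philox stream
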
template -void ComputeImpl(OpKernelContext* ctx, int state_input_idx, int shape_input_idx, - bool read_alg_from_state, Algorithm alg) { +void StatefulRandomCompute(OpKernelContext* ctx, Distribution dist, + int state_input_idx, int shape_input_idx, + bool read_alg_from_state, Algorithm alg) { using T = typename Distribution::ResultElementType; const Tensor& shape_t = ctx->input(shape_input_idx); TensorShape shape; @@ -108,8 +109,8 @@ void ComputeImpl(OpKernelContext* ctx, int state_input_idx, int shape_input_idx, Tensor* output; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape, &output)); auto output_flat = output->flat(); - OP_REQUIRES_OK(ctx, UpdateVariableAndFill( - ctx, state_input_idx, read_alg_from_state, alg, + OP_REQUIRES_OK(ctx, UpdateVariableAndFill( + ctx, dist, state_input_idx, read_alg_from_state, alg, output_flat.size(), output_flat.data())); } @@ -119,27 +120,89 @@ class StatefulRandomOp : public OpKernel { explicit StatefulRandomOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { - ComputeImpl(ctx, 0, 1, true, 0); + StatefulRandomCompute(ctx, Distribution(), 0, 1, true, 0); } }; +Status GetAlgorithm(OpKernelContext* ctx, int alg_input_idx, Algorithm* alg) { + const Tensor& alg_tensor = ctx->input(alg_input_idx); + if (alg_tensor.dims() != 0) { + return errors::InvalidArgument("algorithm must be of shape [], not ", + alg_tensor.shape().DebugString()); + } + if (alg_tensor.dtype() != ALGORITHM_DTYPE) { + return errors::InvalidArgument("algorithm's dtype must be ", + DataTypeString(ALGORITHM_DTYPE), ", not ", + DataTypeString(alg_tensor.dtype())); + } + *alg = alg_tensor.flat()(0); + return Status::OK(); +} + template class StatefulRandomOpV2 : public OpKernel { public: explicit StatefulRandomOpV2(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { - const Tensor& alg_tensor = ctx->input(1); - OP_REQUIRES(ctx, alg_tensor.dims() == 0, - errors::InvalidArgument("algorithm must be of shape [], not ", - alg_tensor.shape().DebugString())); + Algorithm alg; + OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, /*alg_input_idx=*/1, &alg)); + StatefulRandomCompute(ctx, Distribution(), /*state_input_idx=*/0, + /*shape_input_idx=*/2, + /*read_alg_from_state=*/false, alg); + } +}; + +template +class StatefulUniformIntOp : public OpKernel { + public: + explicit StatefulUniformIntOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + Algorithm alg; + OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, /*alg_input_idx=*/1, &alg)); + const Tensor& minval = ctx->input(3); + const Tensor& maxval = ctx->input(4); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval.shape()), + errors::InvalidArgument("minval must be 0-D, got shape ", + minval.shape().DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval.shape()), + errors::InvalidArgument("maxval must be 0-D, got shape ", + maxval.shape().DebugString())); + + // Verify that minval < maxval. This check intentionally happens after the + // early exit for empty output. Zero impossible things are fine. 
+ IntType lo = minval.scalar()(); + IntType hi = maxval.scalar()(); OP_REQUIRES( - ctx, alg_tensor.dtype() == ALGORITHM_DTYPE, - errors::InvalidArgument("algorithm's dtype must be ", - DataTypeString(ALGORITHM_DTYPE), ", not ", - DataTypeString(alg_tensor.dtype()))); - auto alg = alg_tensor.flat()(0); - ComputeImpl(ctx, 0, 2, false, alg); + ctx, lo < hi, + errors::InvalidArgument("Need minval < maxval, got ", lo, " >= ", hi)); + + // Build distribution + typedef random::UniformDistribution + Distribution; + Distribution dist(lo, hi); + + StatefulRandomCompute(ctx, dist, /*state_input_idx=*/0, + /*shape_input_idx=*/2, + /*read_alg_from_state=*/false, alg); + } +}; + +template +class StatefulUniformFullIntOp : public OpKernel { + public: + explicit StatefulUniformFullIntOp(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + Algorithm alg; + OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, /*alg_input_idx=*/1, &alg)); + StatefulRandomCompute( + ctx, + random::UniformFullIntDistribution(), + /*state_input_idx=*/0, /*shape_input_idx=*/2, + /*read_alg_from_state=*/false, alg); } }; @@ -213,14 +276,65 @@ TF_CALL_bfloat16(REGISTER_CPU); TF_CALL_float(REGISTER_CPU); TF_CALL_double(REGISTER_CPU); +#define REGISTER_StatefulUniformInt(DEVICE, TYPE) \ + REGISTER_KERNEL_BUILDER(Name("StatefulUniformInt") \ + .Device(DEVICE_##DEVICE) \ + .HostMemory("resource") \ + .HostMemory("algorithm") \ + .HostMemory("shape") \ + .HostMemory("minval") \ + .HostMemory("maxval") \ + .TypeConstraint("dtype"), \ + StatefulUniformIntOp); + +#define REGISTER_StatefulUniformInt_CPU(TYPE) \ + REGISTER_StatefulUniformInt(CPU, TYPE) +#define REGISTER_StatefulUniformInt_GPU(TYPE) \ + REGISTER_StatefulUniformInt(GPU, TYPE) + +TF_CALL_int32(REGISTER_StatefulUniformInt_CPU); +TF_CALL_int64(REGISTER_StatefulUniformInt_CPU); + +#define REGISTER_StatefulUniformFullInt(DEVICE, TYPE) \ + REGISTER_KERNEL_BUILDER(Name("StatefulUniformFullInt") \ + .Device(DEVICE_##DEVICE) \ + .HostMemory("resource") \ + .HostMemory("algorithm") \ + .HostMemory("shape") \ + .TypeConstraint("dtype"), \ + StatefulUniformFullIntOp); + +#define REGISTER_StatefulUniformFullInt_CPU(TYPE) \ + REGISTER_StatefulUniformFullInt(CPU, TYPE) +#define REGISTER_StatefulUniformFullInt_GPU(TYPE) \ + REGISTER_StatefulUniformFullInt(GPU, TYPE) + +TF_CALL_int32(REGISTER_StatefulUniformFullInt_CPU); +TF_CALL_int64(REGISTER_StatefulUniformFullInt_CPU); +TF_CALL_uint32(REGISTER_StatefulUniformFullInt_CPU); +TF_CALL_uint64(REGISTER_StatefulUniformFullInt_CPU); + #if GOOGLE_CUDA TF_CALL_half(REGISTER_GPU); +TF_CALL_bfloat16(REGISTER_GPU); TF_CALL_float(REGISTER_GPU); TF_CALL_double(REGISTER_GPU); +TF_CALL_int32(REGISTER_StatefulUniformInt_GPU); +TF_CALL_int64(REGISTER_StatefulUniformInt_GPU); +TF_CALL_int32(REGISTER_StatefulUniformFullInt_GPU); +TF_CALL_int64(REGISTER_StatefulUniformFullInt_GPU); +TF_CALL_uint32(REGISTER_StatefulUniformFullInt_GPU); +TF_CALL_uint64(REGISTER_StatefulUniformFullInt_GPU); #endif // GOOGLE_CUDA +#undef REGISTER_StatefulUniformFullInt_GPU +#undef REGISTER_StatefulUniformFullInt_CPU +#undef REGISTER_StatefulUniformFullInt +#undef REGISTER_StatefulUniformInt_GPU +#undef REGISTER_StatefulUniformInt_CPU +#undef REGISTER_StatefulUniformInt #undef REGISTER_GPU #undef REGISTER_CPU #undef REGISTER diff --git a/tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h b/tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h index 8dc72de3658..63d746fcdb7 100644 --- a/tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h +++ 
b/tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h
@@ -87,7 +87,7 @@ using GPUDevice = Eigen::GpuDevice;
 template <typename Distribution>
 struct UpdateVariableAndFill_Philox<GPUDevice, Distribution> {
   void operator()(OpKernelContext* ctx, const GPUDevice& device,
-                  int64 output_size, int64 alg_tag_skip,
+                  Distribution dist, int64 output_size, int64 alg_tag_skip,
                   ScopedUnlockUnrefVar* not_used, Tensor* state_tensor,
                   typename Distribution::ResultElementType* output_data);
 };
diff --git a/tensorflow/core/kernels/stateful_random_ops_gpu.cu.cc b/tensorflow/core/kernels/stateful_random_ops_gpu.cu.cc
index 99ce3e677d8..8d6243e3211 100644
--- a/tensorflow/core/kernels/stateful_random_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/stateful_random_ops_gpu.cu.cc
@@ -53,8 +53,9 @@ __global__ void FillKernel(
 template <typename Distribution>
 void UpdateVariableAndFill_Philox<GPUDevice, Distribution>::operator()(
-    OpKernelContext* ctx, const GPUDevice& d, int64 output_size,
-    int64 alg_tag_skip, ScopedUnlockUnrefVar* not_used, Tensor* state_tensor,
+    OpKernelContext* ctx, const GPUDevice& d, Distribution dist,
+    int64 output_size, int64 alg_tag_skip, ScopedUnlockUnrefVar* not_used,
+    Tensor* state_tensor,
     typename Distribution::ResultElementType* output_data) {
   OP_REQUIRES(
       ctx, alg_tag_skip == 0,
@@ -74,10 +75,9 @@ void UpdateVariableAndFill_Philox<GPUDevice, Distribution>::operator()(
   int zero = 0;
   cudaMemcpyToSymbol(thread_counter, &zero, sizeof(int));
-  TF_CHECK_OK(CudaLaunchKernel(FillKernel<Distribution>, cfg.block_count,
-                               cfg.thread_per_block, 0, d.stream(),
-                               Distribution(), state_size, output_size,
-                               state_data, output_data));
+  TF_CHECK_OK(CudaLaunchKernel(
+      FillKernel<Distribution>, cfg.block_count, cfg.thread_per_block, 0,
+      d.stream(), dist, state_size, output_size, state_data, output_data));
 }
 
 // Explicit instantiation of the GPU distributions functors.
@@ -86,10 +86,28 @@ void UpdateVariableAndFill_Philox<GPUDevice, Distribution>::operator()(
 // NVCC cannot handle ">>" properly
 template struct UpdateVariableAndFill_Philox<
     GPUDevice, random::NormalDistribution<random::PhiloxRandom, Eigen::half> >;
+template struct UpdateVariableAndFill_Philox<
+    GPUDevice, random::NormalDistribution<random::PhiloxRandom, bfloat16> >;
 template struct UpdateVariableAndFill_Philox<
     GPUDevice, random::NormalDistribution<random::PhiloxRandom, float> >;
 template struct UpdateVariableAndFill_Philox<
     GPUDevice, random::NormalDistribution<random::PhiloxRandom, double> >;
+template struct UpdateVariableAndFill_Philox<
+    GPUDevice, random::UniformDistribution<random::PhiloxRandom, int32> >;
+template struct UpdateVariableAndFill_Philox<
+    GPUDevice, random::UniformDistribution<random::PhiloxRandom, int64> >;
+template struct UpdateVariableAndFill_Philox<
+    GPUDevice, random::UniformFullIntDistribution<
+                   random::PhiloxRandom, int32> >;
+template struct UpdateVariableAndFill_Philox<
+    GPUDevice, random::UniformFullIntDistribution<
+                   random::PhiloxRandom, int64> >;
+template struct UpdateVariableAndFill_Philox<
+    GPUDevice, random::UniformFullIntDistribution<
+                   random::PhiloxRandom, uint32> >;
+template struct UpdateVariableAndFill_Philox<
+    GPUDevice, random::UniformFullIntDistribution<
+                   random::PhiloxRandom, uint64> >;
 // clang-format on
 
 }  // end namespace tensorflow
diff --git a/tensorflow/core/lib/random/random_distributions.h b/tensorflow/core/lib/random/random_distributions.h
index c3801a04128..102f9ba7ea8 100644
--- a/tensorflow/core/lib/random/random_distributions.h
+++ b/tensorflow/core/lib/random/random_distributions.h
@@ -16,12 +16,13 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_
 #define TENSORFLOW_CORE_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_
 
+#include <string.h>
+
 #define _USE_MATH_DEFINES
 #include <cmath>
 #include <math.h>
 #undef _USE_MATH_DEFINES
 
-#include <string.h>
 #include <algorithm>
 #include <type_traits>
@@ -236,6 +237,73 @@ class UniformDistribution {
   uint64 range_;
 };
 
+// Similar to `UniformDistribution`, except that instead of generating numbers
+// in the range [low, high), it generates numbers covering the whole range of
+// the integer type.
+template <typename Generator, typename IntType>
+class UniformFullIntDistribution;
+
+template <typename Generator, typename IntType>
+class UniformFullIntDistribution32 {
+ public:
+  // The number of elements that will be returned.
+  static const int kResultElementCount = Generator::kResultElementCount;
+  // Cost of generation of a single element (in cycles).
+  static const int kElementCost = 3;
+  // Indicate that this distribution may take variable number of samples
+  // during the runtime.
+  static const bool kVariableSamplesPerOutput = false;
+  typedef Array<IntType, kResultElementCount> ResultType;
+  typedef IntType ResultElementType;
+
+  PHILOX_DEVICE_INLINE
+  ResultType operator()(Generator* gen) {
+    typename Generator::ResultType sample = (*gen)();
+    ResultType result;
+    for (int i = 0; i < kResultElementCount; ++i) {
+      result[i] = sample[i];
+    }
+    return result;
+  }
+};
+
+template <typename Generator, typename IntType>
+class UniformFullIntDistribution64 {
+ public:
+  // The number of elements that will be returned.
+  static const int kResultElementCount = Generator::kResultElementCount / 2;
+  // Cost of generation of a single element (in cycles).
+  static const int kElementCost = 3;
+  // Indicate that this distribution may take variable number of samples
+  // during the runtime.
+  static const bool kVariableSamplesPerOutput = false;
+  typedef Array<IntType, kResultElementCount> ResultType;
+  typedef IntType ResultElementType;
+
+  PHILOX_DEVICE_INLINE
+  ResultType operator()(Generator* gen) {
+    typename Generator::ResultType sample = (*gen)();
+    ResultType result;
+    for (int i = 0; i < kResultElementCount; ++i) {
+      result[i] = sample[2 * i] | static_cast<uint64>(sample[2 * i + 1]) << 32;
+    }
+    return result;
+  }
+};
+
+template <typename Generator>
+class UniformFullIntDistribution<Generator, int32>
+    : public UniformFullIntDistribution32<Generator, int32> {};
+template <typename Generator>
+class UniformFullIntDistribution<Generator, uint32>
+    : public UniformFullIntDistribution32<Generator, uint32> {};
+template <typename Generator>
+class UniformFullIntDistribution<Generator, int64>
+    : public UniformFullIntDistribution64<Generator, int64> {};
+template <typename Generator>
+class UniformFullIntDistribution<Generator, uint64>
+    : public UniformFullIntDistribution64<Generator, uint64> {};
+
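
The 64-bit variants above assemble each output from two consecutive 32-bit generator words — result[i] = sample[2*i] | sample[2*i+1] << 32 — which is why they deliver half as many elements per generator call (kResultElementCount / 2). A numpy check of the packing:

    import numpy as np

    lo, hi = np.uint64(0x89abcdef), np.uint64(0x01234567)  # two 32-bit words
    assert lo | (hi << np.uint64(32)) == np.uint64(0x0123456789abcdef)
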
 // A class that adapts the underlying native multiple samples to return a
 // single sample at a time.
 template <class Generator>
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 8d658faa268..2ab68f3d744 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -2992,6 +2992,7 @@ cuda_py_test(
         ":client_testlib",
         ":logging_ops",
         ":random_ops_gen",
+        "//tensorflow/python/kernel_tests/random:util",
     ],
     xla_enable_strict_auto_jit = True,
 )
diff --git a/tensorflow/python/kernel_tests/random/util.py b/tensorflow/python/kernel_tests/random/util.py
index 67805c7f262..84e3df4278c 100644
--- a/tensorflow/python/kernel_tests/random/util.py
+++ b/tensorflow/python/kernel_tests/random/util.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import math
+
 import numpy as np
 
 
@@ -70,3 +72,26 @@ def test_moment_matching(
           total_variance)))
 
   return z_test_scores
+
+
+def chi_squared(x, bins):
+  """Pearson's Chi-squared test."""
+  x = np.ravel(x)
+  n = len(x)
+  histogram, _ = np.histogram(x, bins=bins, range=(0, 1))
+  expected = n / float(bins)
+  return np.sum(np.square(histogram - expected) / expected)
+
+
+def normal_cdf(x):
+  """Cumulative distribution function for a standard normal distribution."""
+  return 0.5 + 0.5 * np.vectorize(math.erf)(x / math.sqrt(2))
+
+
+def anderson_darling(x):
+  """Anderson-Darling test for a standard normal distribution."""
+  x = np.sort(np.ravel(x))
+  n = len(x)
+  i = np.linspace(1, n, n)
+  z = np.sum((2 * i - 1) * np.log(normal_cdf(x)) +
+             (2 * (n - i) + 1) * np.log(1 - normal_cdf(x)))
+  return -n - z / n
diff --git a/tensorflow/python/ops/stateful_random_ops.py b/tensorflow/python/ops/stateful_random_ops.py
index 91625ff968e..4d985fc18a5 100644
--- a/tensorflow/python/ops/stateful_random_ops.py
+++ b/tensorflow/python/ops/stateful_random_ops.py
@@ -160,7 +160,19 @@ def _shape_tensor(shape):
 class Generator(tracking.AutoTrackable):
   """Random-number generator.
 
-  It uses Variable to manage its internal state.
+  It uses Variable to manage its internal state, and allows choosing a
+  Random-Number-Generation (RNG) algorithm.
+
+  CPU and GPU with the same algorithm and seed will generate the same integer
+  random numbers. Floating-point results (such as the output of `normal`) may
+  have small numerical discrepancies between CPU and GPU.
+
+  Because of different counter-reservation schemes, TPU's integer random
+  numbers will be different from CPU/GPU even with the same algorithm and
+  seed. Also, TPU uses different sampling algorithms for some distributions
+  (e.g. using the inverse CDF to sample the normal distribution instead of
+  the Box-Muller transform used by CPU/GPU). Harmonizing TPU's RNG behavior
+  with CPU/GPU is work in progress.
""" def __init__(self, copy_from=None, seed=None, algorithm=None): diff --git a/tensorflow/python/ops/stateful_random_ops_test.py b/tensorflow/python/ops/stateful_random_ops_test.py index ea1cebd18fe..4b04906ab82 100644 --- a/tensorflow/python/ops/stateful_random_ops_test.py +++ b/tensorflow/python/ops/stateful_random_ops_test.py @@ -23,9 +23,11 @@ import numpy as np from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.kernel_tests.random import util as \ +random_test_util from tensorflow.python.ops import gen_random_ops from tensorflow.python.ops import gen_stateful_random_ops from tensorflow.python.ops import logging_ops @@ -189,33 +191,60 @@ class StatefulRandomOpsTest(test.TestCase): compare(True, True) compare(True, False) + def _sameAsOldRandomOps(self, device): + def compare(dtype, old, new): + seed1, seed2 = 79, 25 + # note how the two seeds for the old op correspond to the seed for the new + # op + with ops.device(device): + gen = random.Generator(seed=[0, seed2, seed1]) + + # create a graph for the old op in order to call it many times + @def_function.function + def run_old(): + with ops.device(device): + return old(dtype, seed1, seed2) + + def run_new(): + with ops.device(device): + return new(dtype, gen) + + for _ in range(100): + self.assertAllEqual(run_old(), run_new()) + + shape = constant_op.constant([4, 7]) + minval = 128 + maxval = 256 + + # passing `dtype` around to compress go/gpylint-faq#cell-var-from-loop and + # go/gpylint-faq#undefined-loop-variable + def old_normal(dtype, seed1, seed2): + return gen_random_ops.random_standard_normal( + shape, dtype=dtype, seed=seed1, seed2=seed2) + def new_normal(dtype, gen): + return gen._standard_normal(shape, dtype=dtype) + def old_uniform(dtype, seed1, seed2): + minval2 = constant_op.constant(minval, dtype=dtype) + maxval2 = constant_op.constant(maxval, dtype=dtype) + return gen_random_ops.random_uniform_int( + shape, minval=minval2, maxval=maxval2, seed=seed1, seed2=seed2) + def new_uniform(dtype, gen): + return gen.uniform( + shape, minval=minval, maxval=maxval, dtype=dtype) + + for dtype in (dtypes.float16, dtypes.bfloat16, dtypes.float32, + dtypes.float64): + compare(dtype, old_normal, new_normal) + for dtype in [dtypes.int32, dtypes.int64]: + compare(dtype, old_uniform, new_uniform) + @test_util.run_v2_only def testCPUSameAsOldRandomOps(self): """Tests that the generated numbers are the same as the old random_ops.py. The CPU version. 
""" - seed1, seed2 = 79, 25 - # note how the two seeds for the old op correspond to the seed for the new - # op - with ops.device("/device:CPU:0"): - random.reset_global_generator([0, seed2, seed1]) - shape = constant_op.constant([4, 7]) - dtype = dtypes.float64 - - # create a graph for the old op in order to call it many times - @def_function.function - def old(): - with ops.device("/device:CPU:0"): - return gen_random_ops.random_standard_normal( - shape, dtype=dtype, seed=seed1, seed2=seed2) - - def new(): - with ops.device("/device:CPU:0"): - return random.get_global_generator().normal(shape, dtype=dtype) - - for _ in range(100): - self.assertAllEqual(old(), new()) + self._sameAsOldRandomOps("/device:CPU:0") @test_util.run_v2_only @test_util.run_cuda_only @@ -224,28 +253,103 @@ class StatefulRandomOpsTest(test.TestCase): The GPU version. """ - seed1, seed2 = 79, 25 - with ops.device(test_util.gpu_device_name()): - random.reset_global_generator([0, seed2, seed1]) - shape = constant_op.constant([4, 7]) - dtype = dtypes.float64 + self._sameAsOldRandomOps(test_util.gpu_device_name()) - @def_function.function - def old(): - with ops.device(test_util.gpu_device_name()): - return gen_random_ops.random_standard_normal( - shape, dtype=dtype, seed=seed1, seed2=seed2) + @test_util.run_v2_only + def testUniformIntIsInRange(self): + minval = 2 + maxval = 33 + size = 1000 + gen = random.Generator(seed=1234) + for dtype in [dtypes.int32, dtypes.int64]: + x = gen.uniform( + shape=[size], dtype=dtype, minval=minval, maxval=maxval).numpy() + self.assertTrue(np.all(x >= minval)) + self.assertTrue(np.all(x < maxval)) - def new(): - with ops.device(test_util.gpu_device_name()): - return random.get_global_generator().normal(shape, dtype=dtype) + @test_util.run_v2_only + def testNormalIsFinite(self): + gen = random.Generator(seed=1234) + for dtype in [dtypes.float32]: + x = gen.normal(shape=[10000], dtype=dtype).numpy() + self.assertTrue(np.all(np.isfinite(x))) - for _ in range(100): - self.assertAllEqual(old(), new()) + @test_util.run_v2_only + def testDistributionOfUniform(self): + """Use Pearson's Chi-squared test to test for uniformity.""" + n = 1000 + seed = 12 + for dtype in [dtypes.int32, dtypes.int64]: + gen = random.Generator(seed=seed) + maxval = 1 + if dtype.is_integer: + maxval = 100 + x = gen.uniform(shape=[n], maxval=maxval, dtype=dtype).numpy() + if maxval > 1: + # Normalize y to range [0, 1). + x = x.astype(float) / maxval + # Tests that the values are distributed amongst 10 bins with equal + # probability. 16.92 is the Chi^2 value for 9 degrees of freedom with + # p=0.05. This test is probabilistic and would be flaky if the random + # seed were not fixed. + val = random_test_util.chi_squared(x, 10) + self.assertLess(val, 16.92) + + @test_util.run_v2_only + def testDistributionOfNormal(self): + """Use Anderson-Darling test to test distribution appears normal.""" + n = 1000 + for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]: + gen = random.Generator(seed=1234) + x = gen.normal(shape=[n], dtype=dtype).numpy() + # The constant 2.492 is the 5% critical value for the Anderson-Darling + # test where the mean and variance are known. This test is probabilistic + # so to avoid flakiness the seed is fixed. + self.assertLess( + random_test_util.anderson_darling(x.astype(float)), 2.492) + + @test_util.run_v2_only + def testErrors(self): + """Tests that proper errors are raised. 
+ """ + shape = [2, 3] + gen = random.Generator(seed=1234) + with self.assertRaisesWithPredicateMatch( + errors.InvalidArgumentError, + r"algorithm must be of shape \[\], not"): + gen_stateful_random_ops.stateful_standard_normal_v2( + gen.state.handle, [0, 0], shape) + with self.assertRaisesWithPredicateMatch( + TypeError, "Requested dtype: int64"): + gen_stateful_random_ops.stateful_standard_normal_v2( + gen.state.handle, 1.1, shape) + with self.assertRaisesWithPredicateMatch( + errors.InvalidArgumentError, + "Unsupported algorithm id"): + gen_stateful_random_ops.stateful_standard_normal_v2( + gen.state.handle, 123, shape) + var = variables.Variable([0, 0], dtype=dtypes.int32) + with self.assertRaisesWithPredicateMatch( + errors.InvalidArgumentError, + "dtype of RNG state variable must be int64, not"): + gen_stateful_random_ops.stateful_standard_normal_v2( + var.handle, random.RNG_ALG_PHILOX, shape) + var = variables.Variable([[0]], dtype=dtypes.int64) + with self.assertRaisesWithPredicateMatch( + errors.InvalidArgumentError, + "RNG state must have one and only one dimension, not"): + gen_stateful_random_ops.stateful_standard_normal_v2( + var.handle, random.RNG_ALG_PHILOX, shape) + var = variables.Variable([0], dtype=dtypes.int64) + with self.assertRaisesWithPredicateMatch( + errors.InvalidArgumentError, + "For the Philox algorithm, the size of state must be at least"): + gen_stateful_random_ops.stateful_standard_normal_v2( + var.handle, random.RNG_ALG_PHILOX, shape) @test_util.run_v2_only def testStatefulStandardNormal(self): - """Tests that op 'StatefulStandardNormal' still works. + """Tests that the deprecated op 'StatefulStandardNormal' still works. """ shape = constant_op.constant([4, 7]) dtype = dtypes.float64 @@ -277,7 +381,7 @@ class StatefulRandomOpsTest(test.TestCase): random.reset_global_generator(50) with self.assertRaisesWithPredicateMatch( - errors_impl.NotFoundError, "Resource .+ does not exist"): + errors.NotFoundError, "Resource .+ does not exist"): a = f() random.reset_global_generator(50) b = f()