From 12606ff846566dd842444f27f7fed4f911a58580 Mon Sep 17 00:00:00 2001 From: Peng Wang Date: Mon, 4 Mar 2019 12:06:05 -0800 Subject: [PATCH] A continuing partial implementation of RFC "Random numbers in TensorFlow 2.0" (https://github.com/tensorflow/community/pull/38): In this change: - XLA kernel (TF-XLA bridge) for op `StatefulStandardNormalV2` with ThreeFry algorithm. To be done: - ops for other distributions; - other RNG algorithms; - batch seeds; - initializers ('RandomUniform', etc.); - changing `non_deterministic_seed`'s implementation to be an op, and other changes to make most code in stateful_random_ops.py tf.function-compatible. PiperOrigin-RevId: 236699104 --- tensorflow/compiler/tests/BUILD | 14 + .../tests/stateful_random_ops_test.py | 282 ++++++++++++++ tensorflow/compiler/tf2xla/kernels/BUILD | 4 + .../compiler/tf2xla/kernels/random_ops_util.h | 34 ++ .../tf2xla/kernels/stateful_random_ops.cc | 362 ++++++++++++++++++ .../tf2xla/kernels/stateless_random_ops.cc | 7 +- .../tf2xla/resource_operation_table.cc | 3 + tensorflow/compiler/xla/client/lib/prng.cc | 32 +- tensorflow/compiler/xla/client/lib/prng.h | 26 ++ .../api_def_StatefulUniformFullInt.pbtxt | 38 ++ .../base_api/api_def_StatefulUniformInt.pbtxt | 56 +++ tensorflow/core/framework/resource_var.h | 25 ++ .../core/kernels/stateful_random_ops.cc | 15 +- tensorflow/core/kernels/stateful_random_ops.h | 108 +----- .../kernels/stateful_random_ops_cpu_gpu.h | 104 +++++ .../kernels/stateful_random_ops_gpu.cu.cc | 4 +- tensorflow/core/ops/stateful_random_ops.cc | 39 +- tensorflow/python/ops/stateful_random_ops.py | 144 +++++-- .../python/ops/stateful_random_ops_test.py | 7 +- 19 files changed, 1125 insertions(+), 179 deletions(-) create mode 100644 tensorflow/compiler/tests/stateful_random_ops_test.py create mode 100644 tensorflow/compiler/tf2xla/kernels/random_ops_util.h create mode 100644 tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc create mode 100644 
tensorflow/core/api_def/base_api/api_def_StatefulUniformFullInt.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_StatefulUniformInt.pbtxt create mode 100644 tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 00c7d74e7d9..65f5fba269c 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -835,6 +835,20 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "stateful_random_ops_test", + size = "small", + srcs = ["stateful_random_ops_test.py"], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + "//tensorflow/python:standard_ops", + "//tensorflow/python:stateful_random_ops", + ], +) + tf_xla_py_test( name = "stateless_random_ops_test", size = "small", diff --git a/tensorflow/compiler/tests/stateful_random_ops_test.py b/tensorflow/compiler/tests/stateful_random_ops_test.py new file mode 100644 index 00000000000..f0535579bf2 --- /dev/null +++ b/tensorflow/compiler/tests/stateful_random_ops_test.py @@ -0,0 +1,282 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for stateful random-number generation ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.client import device_lib +from tensorflow.python.eager import def_function +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import gen_stateful_random_ops +from tensorflow.python.ops import stateful_random_ops as \ +random +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +def xla_device_name(): + devices = device_lib.list_local_devices() + def find_type(device_type): + for d in devices: + if d.device_type == device_type: + return d.name + return None + name = find_type("TPU") or find_type("XLA_GPU") or find_type("XLA_CPU") + if name is None: + raise ValueError( + "Can't find any XLA device. Available devices:\n%s" % devices) + return str(name) + + +class StatefulRandomOpsTest(xla_test.XLATestCase): + """Test cases for stateful random-number generator operators.""" + + @test_util.run_v2_only + def testSimple(self): + """A simple test. + """ + with ops.device(xla_device_name()): + gen = random.Generator(seed=0, algorithm=random.RNG_ALG_THREEFRY) + gen.normal(shape=(3,)) + gen.uniform(shape=(3,), minval=0, maxval=10, dtype=dtypes.uint32) + gen.uniform_full_int(shape=(3,)) + + @test_util.run_v2_only + def testDefun(self): + """Test for defun. 
+ """ + with ops.device(xla_device_name()): + gen = random.Generator(seed=0, algorithm=random.RNG_ALG_THREEFRY) + @def_function.function + def f(): + x = gen.normal(shape=(3,)) + y = gen.uniform(shape=(3,), minval=0, maxval=10, dtype=dtypes.uint32) + z = gen.uniform_full_int(shape=(3,)) + return (x, y, z) + f() + + @test_util.run_v2_only + def testThreefry2x32(self): + """Tests ThreeFry2x32 conforms to known results. + """ + # Based on + # https://github.com/google/jax/blob/8565a3486adf16beb388b2364c9cd930d7a0d92d/tests/random_test.py#L65-L85 + # which is in turn based on + # https://github.com/DEShawResearch/Random123-Boost/blob/65e3d874b67aa7b3e02d5ad8306462f52d2079c0/libs/random/test/test_threefry.cpp#L30-L32 + + def uint32s_to_uint64(a, b): + return b << 32 | a + + def verify(counter1, counter2, key1, key2, expect1, expect2): + counter = uint32s_to_uint64(counter1, counter2) + key = uint32s_to_uint64(key1, key2) + random.get_global_generator().reset([counter, key]) + got = random.get_global_generator().uniform_full_int( + shape=(2,), dtype=dtypes.uint32) + expect = [expect1, expect2] + self.assertAllEqual(expect, got) + random.get_global_generator().reset([counter, key]) + got = random.get_global_generator().uniform_full_int( + shape=(), dtype=dtypes.uint64) + self.assertAllEqual(uint32s_to_uint64(*expect), got) + + with ops.device(xla_device_name()): + random.reset_global_generator(seed=0, algorithm=random.RNG_ALG_THREEFRY) + verify(0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0x6b200159, 0x99ba4efe) + verify(0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, + 0x1cb996fc, 0xbb002be7) + verify(0x243f6a88, 0x85a308d3, 0x13198a2e, 0x03707344, + 0xc4923a9c, 0x483df7a0) + + @test_util.run_v2_only + def testNewState(self): + """Tests that the new state is correct. 
+ """ + with ops.device(xla_device_name()): + counter = 57 + key = 0x1234 + size = 46 + seed = [counter, key] + gen = random.Generator( + seed=seed, algorithm=random.RNG_ALG_THREEFRY) + gen.uniform_full_int(shape=(size,), dtype=dtypes.uint32) + self.assertAllEqual([counter+(size+1)//2, key], gen.state.read_value()) + gen.reset(seed=seed) + gen.uniform_full_int(shape=(size,), dtype=dtypes.uint64) + self.assertAllEqual([counter+size, key], gen.state.read_value()) + + def _testRngIsNotConstant(self, rng, dtype): + # Tests that 'rng' does not always return the same value. + # The random-number generator, if working correctly, should produce the + # same output multiple times with low probability. + x = rng(dtype).numpy() + y = rng(dtype).numpy() + self.assertFalse(np.array_equal(x, y)) + + @test_util.run_v2_only + def testUniformIsNotConstant(self): + with ops.device(xla_device_name()): + gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY) + def rng(dtype): + maxval = dtype.max + # Workaround for b/125364959 + if dtype == dtypes.uint64: + maxval = 10000000 + return gen.uniform(shape=[2], dtype=dtype, maxval=maxval) + + for dtype in {dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64}: + self._testRngIsNotConstant(rng, dtype) + + @test_util.run_v2_only + def testNormalIsNotConstant(self): + with ops.device(xla_device_name()): + gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY) + def rng(dtype): + return gen.normal(shape=[2], dtype=dtype) + + for dtype in {dtypes.float32}: + self._testRngIsNotConstant(rng, dtype) + + @test_util.run_v2_only + def testUniformIntIsInRange(self): + minval = 2 + maxval = 33 + size = 1000 + with ops.device(xla_device_name()): + gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY) + for dtype in {dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64}: + x = gen.uniform( + shape=[size], dtype=dtype, minval=minval, maxval=maxval).numpy() + self.assertTrue(np.all(x >= minval)) + 
self.assertTrue(np.all(x < maxval)) + + @test_util.run_v2_only + def testNormalIsFinite(self): + with ops.device(xla_device_name()): + gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY) + for dtype in {dtypes.float32}: + x = gen.normal(shape=[10000], dtype=dtype).numpy() + self.assertTrue(np.all(np.isfinite(x))) + + def _chi_squared(self, x, bins): + """Pearson's Chi-squared test.""" + x = np.ravel(x) + n = len(x) + histogram, _ = np.histogram(x, bins=bins, range=(0, 1)) + expected = n / float(bins) + return np.sum(np.square(histogram - expected) / expected) + + @test_util.run_v2_only + def testDistributionOfUniform(self): + """Use Pearson's Chi-squared test to test for uniformity.""" + with ops.device(xla_device_name()): + n = 1000 + seed = 12 + for dtype in {dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64}: + gen = random.Generator(seed=seed, algorithm=random.RNG_ALG_THREEFRY) + maxval = 1 + if dtype.is_integer: + maxval = 100 + x = gen.uniform(shape=[n], maxval=maxval, dtype=dtype).numpy() + if maxval > 1: + # Normalize x to range [0, 1). + x = x.astype(float) / maxval + # Tests that the values are distributed amongst 10 bins with equal + # probability. 16.92 is the Chi^2 value for 9 degrees of freedom with + # p=0.05. This test is probabilistic and would be flaky if the random + # seed were not fixed. 
+ val = self._chi_squared(x, 10) + self.assertLess(val, 16.92) + + def _normal_cdf(self, x): + """Cumulative distribution function for a standard normal distribution.""" + return 0.5 + 0.5 * np.vectorize(math.erf)(x / math.sqrt(2)) + + def _anderson_darling(self, x): + """Anderson-Darling test for a standard normal distribution.""" + x = np.sort(np.ravel(x)) + n = len(x) + i = np.linspace(1, n, n) + z = np.sum((2 * i - 1) * np.log(self._normal_cdf(x)) + + (2 * (n - i) + 1) * np.log(1 - self._normal_cdf(x))) + return -n - z / n + + @test_util.run_v2_only + def testDistributionOfNormal(self): + """Use Anderson-Darling test to test distribution appears normal.""" + with ops.device(xla_device_name()): + n = 1000 + for dtype in {dtypes.float32}: + gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY) + x = gen.normal(shape=[n], dtype=dtype).numpy() + # The constant 2.492 is the 5% critical value for the Anderson-Darling + # test where the mean and variance are known. This test is probabilistic + # so to avoid flakiness the seed is fixed. + self.assertLess(self._anderson_darling(x.astype(float)), 2.492) + + @test_util.run_v2_only + def testErrors(self): + """Tests that proper errors are raised. 
+ """ + shape = [2, 3] + with ops.device(xla_device_name()): + gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY) + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + r"algorithm must be of shape \[\], not"): + gen_stateful_random_ops.stateful_standard_normal_v2( + gen.state.handle, [0, 0], shape) + with self.assertRaisesWithPredicateMatch( + TypeError, "Requested dtype: int64"): + gen_stateful_random_ops.stateful_standard_normal_v2( + gen.state.handle, 1.1, shape) + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + "Unsupported algorithm id"): + gen_stateful_random_ops.stateful_standard_normal_v2( + gen.state.handle, 123, shape) + var = variables.Variable([0, 0], dtype=dtypes.uint32) + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + "Type mismatch for read of variable .* Expected int64; got"): + gen_stateful_random_ops.stateful_standard_normal_v2( + var.handle, random.RNG_ALG_THREEFRY, shape) + var = variables.Variable([[0]], dtype=dtypes.int64) + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + "RNG state must have one and only one dimension, not"): + gen_stateful_random_ops.stateful_standard_normal_v2( + var.handle, random.RNG_ALG_THREEFRY, shape) + var = variables.Variable([0], dtype=dtypes.int64) + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + "For the ThreeFry algorithm, the size of state must be at least"): + gen_stateful_random_ops.stateful_standard_normal_v2( + var.handle, random.RNG_ALG_THREEFRY, shape) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 934b28cd688..c790c4c6723 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -64,6 +64,7 @@ tf_kernel_library( "qr_op.cc", "quantize_and_dequantize_op.cc", "random_ops.cc", + 
"random_ops_util.h", "reduce_window_op.cc", "reduction_ops.cc", "reduction_ops.h", @@ -89,6 +90,7 @@ tf_kernel_library( "sparse_to_dense_op.cc", "split_op.cc", "stack_ops.cc", + "stateful_random_ops.cc", "stateless_random_ops.cc", "strided_slice_op.cc", "tensor_array_ops.cc", @@ -167,6 +169,7 @@ tf_kernel_library( "//tensorflow/core:sparse_ops_op_lib", "//tensorflow/core:spectral_ops_op_lib", "//tensorflow/core:state_ops_op_lib", + "//tensorflow/core:stateful_random_ops_op_lib", "//tensorflow/core:stateless_random_ops_op_lib", "//tensorflow/core:training_ops_op_lib", "//tensorflow/core/kernels:constant_op", @@ -179,6 +182,7 @@ tf_kernel_library( "//tensorflow/core/kernels:sendrecv_ops", "//tensorflow/core/kernels:sparse_to_dense_op", "//tensorflow/core/kernels:stack_ops", + "//tensorflow/core/kernels:stateful_random_ops", "//tensorflow/core/kernels:training_ops", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops_util.h b/tensorflow/compiler/tf2xla/kernels/random_ops_util.h new file mode 100644 index 00000000000..d107be6f13c --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/random_ops_util.h @@ -0,0 +1,34 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_RANDOM_OPS_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_RANDOM_OPS_UTIL_H_ + +#include + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { + +// Converts to bfloat16 if `dtype` equals DT_BFLOAT16, no-op otherwise. +// It masks the last 16 bits. With normal rounding, values near "maxval" would be +// converted to "maxval" which is out of range ["minval", "maxval"). In +// addition, the distribution near the limit is not uniform. +xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_RANDOM_OPS_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc new file mode 100644 index 00000000000..f1d68835e12 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc @@ -0,0 +1,362 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h" +#include "tensorflow/compiler/tf2xla/lib/random.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/prng.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/stateful_random_ops.h" +#include "tensorflow/core/lib/math/math_util.h" + +namespace tensorflow { +namespace { + +std::pair GetInputsFromCounter( + xla::XlaOp counter, const int64 size) { + auto builder = counter.builder(); + auto input_u64 = Iota(builder, xla::U64, size); + input_u64 = input_u64 + counter; + counter = counter + xla::ConstantR0(builder, size); + return std::make_pair(xla::Uint64ToUint32s(input_u64), counter); +} + +// `StatelessRngUniformU32` uses ThreeFry2x32’s counter space too +// wastefully, only able to generate 2^32*2 int32 numbers for each key, while +// the real capacity is 2^64*2. Counter-space efficiency is important for +// stateful ops, hence the following 2 new functions. 
+std::pair StatefulRngUniformU32( + xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) { + auto builder = key.builder(); + const int64 size = xla::ShapeUtil::ElementsIn(shape); + const int64 half_size = xla::CeilOfRatio(size, 2); + const bool size_is_odd = (half_size * 2 != size); + auto inputs_counter = GetInputsFromCounter(counter, half_size); + auto inputs = inputs_counter.first; + counter = inputs_counter.second; + auto outputs = xla::ThreeFry2x32(inputs, xla::Uint64ToUint32s(key)); + if (size_is_odd) { + outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1}); + } + auto result = ConcatInDim(builder, outputs, 0); + return std::make_pair(Reshape(result, xla::AsInt64Slice(shape.dimensions())), + counter); +} + +std::pair StatefulRngUniformU64( + xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) { + const int64 size = xla::ShapeUtil::ElementsIn(shape); + auto inputs_counter = GetInputsFromCounter(counter, size); + auto inputs = inputs_counter.first; + counter = inputs_counter.second; + auto outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key)); + auto result = Uint32sToUint64(outputs); + return std::make_pair(Reshape(result, xla::AsInt64Slice(shape.dimensions())), + counter); +} + +std::pair StatefulRngUniform(xla::XlaOp key, + xla::XlaOp counter, + const xla::Shape& shape, + xla::XlaOp minval, + xla::XlaOp maxval) { + auto builder = key.builder(); + xla::PrimitiveType type = shape.element_type(); + switch (type) { + case xla::F32: { + auto bits_counter = StatefulRngUniformU32(key, counter, shape); + auto bits = bits_counter.first; + counter = bits_counter.second; + return std::make_pair(xla::StatelessRngUniformF32(bits, minval, maxval), + counter); + } + case xla::U32: // fall through + case xla::S32: { + auto bits_counter = StatefulRngUniformU32(key, counter, shape); + auto bits = bits_counter.first; + counter = bits_counter.second; + return std::make_pair( + xla::StatelessRngUniformInt(bits, minval, maxval, type, xla::U32), + counter); + 
} + case xla::U64: // fall through + case xla::S64: { + auto bits_counter = StatefulRngUniformU64(key, counter, shape); + auto bits = bits_counter.first; + counter = bits_counter.second; + return std::make_pair( + xla::StatelessRngUniformInt(bits, minval, maxval, type, xla::U64), + counter); + } + default: + return std::make_pair(builder->ReportError(xla::Unimplemented( + "Types other than F32, U32, S32, U64 and S64 " + "are not implemented by " + "StatefulRngUniform.")), + counter); + } +} + +template +std::pair map_first(std::function f, std::pair p) { + return std::make_pair(f(p.first), p.second); +} + +std::pair StatefulRngUniformFullInt( + xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) { + xla::PrimitiveType type = shape.element_type(); + switch (type) { + case xla::U32: + return StatefulRngUniformU32(key, counter, shape); + case xla::S32: { + // Needs explicit function type because of type-inference failure. + std::function f = [](xla::XlaOp x) { + return BitcastConvertType(x, xla::S32); + }; + return map_first(f, StatefulRngUniformU32(key, counter, shape)); + } + case xla::U64: + return StatefulRngUniformU64(key, counter, shape); + case xla::S64: { + std::function f = [](xla::XlaOp x) { + return BitcastConvertType(x, xla::S64); + }; + return map_first(f, StatefulRngUniformU64(key, counter, shape)); + } + default: + auto builder = key.builder(); + return std::make_pair( + builder->ReportError(xla::Unimplemented( + "Types other than U32, S32, U64 and S64 are not implemented by " + "StatefulRngUniformFullInt; got: %s", + xla::primitive_util::LowercasePrimitiveTypeName(type))), + counter); + } +} + +template +ListB Map(F f, ListA const& list_a) { + ListB list_b; + for (auto a : list_a) { + list_b.push_back(f(a)); + } + return list_b; +} + +xla::XlaOp ConcatScalars(xla::XlaBuilder* builder, + absl::Span scalars) { + return ConcatInDim( + builder, + Map>( + [](xla::XlaOp x) { return xla::Reshape(x, {1}); }, scalars), + 0); +} + +using 
sampler_return_type = xla::StatusOr>; + +// A helper function containing the common part of several kernels below. +// Precondition: 'algorithm' and 'shape' are compile-time constants. +Status CompileImpl(XlaOpKernelContext* ctx, int state_input_idx, + int alg_input_idx, int shape_input_idx, + std::function const& + sample_with_threefry) { + auto alg_shape = ctx->InputShape(alg_input_idx); + if (alg_shape.dims() != 0) { + return errors::InvalidArgument("algorithm must be of shape [], not ", + alg_shape.DebugString()); + } + xla::Literal alg_literal; + TF_RETURN_IF_ERROR(ctx->ConstantInput(alg_input_idx, &alg_literal)); + auto alg = alg_literal.Get({}); + + if (alg == RNG_ALG_THREEFRY) { + xla::XlaOp var; + TensorShape var_shape; + TF_RETURN_IF_ERROR(ctx->ReadVariableInput( + state_input_idx, STATE_ELEMENT_DTYPE, &var_shape, &var)); + if (var_shape.dims() != 1) { + return errors::InvalidArgument( + "RNG state must have one and only one dimension, not ", + var_shape.dims()); + } + auto state_size = var_shape.dim_size(0); + if (state_size < THREEFRY_MIN_STATE_SIZE) { + return errors::InvalidArgument( + "For the ThreeFry algorithm, the size of state" + " must be at least ", + THREEFRY_MIN_STATE_SIZE, "; got ", state_size); + } + TensorShape shape; + TF_RETURN_IF_ERROR(ctx->ConstantInputAsShape(shape_input_idx, &shape)); + + static constexpr int COUNTER_SIZE = 1; + auto counter = BitcastConvertType( + xla::Reshape(xla::Slice(var, {0}, {COUNTER_SIZE}, {1}), {}), xla::U64); + auto key = BitcastConvertType( + xla::Reshape(xla::Slice(var, {COUNTER_SIZE}, {COUNTER_SIZE + 1}, {1}), + {}), + xla::U64); + + auto status_or_value = sample_with_threefry(counter, key, shape); + if (!status_or_value.ok()) { + return status_or_value.status(); + } + auto output_counter = status_or_value.ConsumeValueOrDie(); + auto output = output_counter.first; + counter = output_counter.second; + ctx->SetOutput(0, output); + auto builder = ctx->builder(); + var = ConcatScalars(builder, {counter, 
key}); + xla::PrimitiveType state_element_type; + TF_RETURN_IF_ERROR( + DataTypeToPrimitiveType(STATE_ELEMENT_DTYPE, &state_element_type)); + var = BitcastConvertType(var, state_element_type); + TF_RETURN_IF_ERROR( + ctx->AssignVariable(state_input_idx, STATE_ELEMENT_DTYPE, var)); + return Status::OK(); + } else { + return errors::InvalidArgument("Unsupported algorithm id: ", alg); + } +} + +class StatefulStandardNormalOp : public XlaOpKernel { + public: + explicit StatefulStandardNormalOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + auto builder = ctx->builder(); + auto sample_with_threefry = + // Needs explicit lambda return type because it fails to be inferred. + [builder, this](xla::XlaOp counter, xla::XlaOp key, + TensorShape shape) -> sampler_return_type { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape)); + + auto uniform_counter = StatefulRngUniform( + key, counter, xla_shape, + xla::ConstantR0(builder, std::nextafter(-1.0f, 0.0f)), + xla::ConstantR0(builder, 1.0)); + auto uniform = uniform_counter.first; + counter = uniform_counter.second; + // Convert uniform distribution to normal distribution by computing + // sqrt(2) * erfinv(x) + auto normal = + xla::ScalarLike(uniform, std::sqrt(2.0)) * xla::ErfInv(uniform); + normal = MaybeConvertF32ToBF16(normal, dtype_); + return {{normal, counter}}; + }; + OP_REQUIRES_OK(ctx, + CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1, + /*shape_input_idx=*/2, sample_with_threefry)); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(StatefulStandardNormalOp); +}; + +// TODO(wangpeng): Support plain float16 and float64 to get rid of the +// `TypeConstraint`. 
+REGISTER_XLA_OP(Name("StatefulStandardNormalV2") + .CompileTimeConstantInput("algorithm") + .CompileTimeConstantInput("shape") + .TypeConstraint("dtype", {DT_FLOAT, DT_BFLOAT16}), + StatefulStandardNormalOp); + +class StatefulUniformIntOp : public XlaOpKernel { + public: + explicit StatefulUniformIntOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp minval = ctx->Input(3); + xla::XlaOp maxval = ctx->Input(4); + auto sample_with_threefry = [minval, maxval, this]( + xla::XlaOp counter, xla::XlaOp key, + TensorShape shape) -> sampler_return_type { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape)); + return StatefulRngUniform(key, counter, xla_shape, minval, maxval); + }; + OP_REQUIRES_OK(ctx, + CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1, + /*shape_input_idx=*/2, sample_with_threefry)); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(StatefulUniformIntOp); +}; + +REGISTER_XLA_OP(Name("StatefulUniformInt") + .CompileTimeConstantInput("algorithm") + .CompileTimeConstantInput("shape") + .TypeConstraint("dtype", + {DT_INT32, DT_UINT32, DT_INT64, DT_UINT64}), + StatefulUniformIntOp); + +class StatefulUniformFullIntOp : public XlaOpKernel { + public: + explicit StatefulUniformFullIntOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + auto sample_with_threefry = [this]( + xla::XlaOp counter, xla::XlaOp key, + TensorShape shape) -> sampler_return_type { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape)); + return StatefulRngUniformFullInt(key, counter, xla_shape); + }; + OP_REQUIRES_OK(ctx, + CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1, + /*shape_input_idx=*/2, sample_with_threefry)); + } + + private: 
+ DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(StatefulUniformFullIntOp); +}; + +REGISTER_XLA_OP(Name("StatefulUniformFullInt") + .CompileTimeConstantInput("algorithm") + .CompileTimeConstantInput("shape") + .TypeConstraint("dtype", + {DT_INT32, DT_UINT32, DT_INT64, DT_UINT64}), + StatefulUniformFullIntOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index 17f067e0dfc..43255452cc3 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -15,6 +15,7 @@ limitations under the License. #include +#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h" #include "tensorflow/compiler/tf2xla/lib/random.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -31,12 +32,8 @@ limitations under the License. #include "tensorflow/core/lib/math/math_util.h" namespace tensorflow { -namespace { xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) { - // Mask the last 16 bit. With normal rounding, values near "maxval" would be - // converted to "maxval" which is out of range ["minval", "maxval"). In - // addition, the distribution near the limit is not uniform. 
if (dtype == DT_BFLOAT16) { xla::XlaBuilder* builder = input.builder(); auto output = xla::BitcastConvertType(input, xla::U32) & @@ -48,6 +45,8 @@ xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) { } } +namespace { + class StatelessRandomUniformOp : public XlaOpKernel { public: explicit StatelessRandomUniformOp(OpKernelConstruction* ctx) diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc index c20d6a5fd1f..29ebf46e4bf 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table.cc +++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc @@ -82,6 +82,9 @@ CreateResourceOpInfoMap() { add("ResourceScatterSub" , kReadWrite, kVariable); add("ResourceScatterUpdate" , kReadWrite, kVariable); add("ResourceStridedSliceAssign" , kReadWrite, kVariable); + add("StatefulStandardNormalV2" , kReadWrite, kVariable); + add("StatefulUniformFullInt" , kReadWrite, kVariable); + add("StatefulUniformInt" , kReadWrite, kVariable); add("VarIsInitializedOp" , kRead, kVariable); add("VariableShape" , kRead, kVariable); diff --git a/tensorflow/compiler/xla/client/lib/prng.cc b/tensorflow/compiler/xla/client/lib/prng.cc index 85b9e1827dc..63b3b07ddc2 100644 --- a/tensorflow/compiler/xla/client/lib/prng.cc +++ b/tensorflow/compiler/xla/client/lib/prng.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/xla/client/lib/prng.h" + #include #include "absl/base/casts.h" @@ -30,11 +32,8 @@ XlaOp RotateLeftU32(XlaOp v, int distance) { ShiftRightLogical(v, ConstantR0(v.builder(), 32 - distance)); } -using ThreeFry2x32State = std::array; +} // namespace -// Implements the ThreeFry counter-based PRNG algorithm. -// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. 
-// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) { XlaBuilder* builder = input[0].builder(); key[0] = BitcastConvertType(key[0], U32); @@ -127,15 +126,28 @@ XlaOp StatelessRngUniformU32(std::array key, const Shape& shape) { return Reshape(result, AsInt64Slice(shape.dimensions())); } +ThreeFry2x32State Uint64ToUint32s(XlaOp u64) { + auto builder = u64.builder(); + auto const32 = ConstantR0WithType(builder, U64, 32); + auto fst = ConvertElementType(u64, U32); + auto snd = ConvertElementType(ShiftRightLogical(u64, const32), U32); + return {fst, snd}; +} + +XlaOp Uint32sToUint64(ThreeFry2x32State u32s) { + auto builder = u32s[0].builder(); + return ConvertElementType(u32s[0], U64) | + ShiftLeft(ConvertElementType(u32s[1], U64), + ConstantR0WithType(builder, U64, 32)); +} + XlaOp StatelessRngUniformU64(std::array key, const Shape& shape) { XlaBuilder* builder = key[0].builder(); const int64 size = ShapeUtil::ElementsIn(shape); ThreeFry2x32State inputs = GetInputs(size, builder); ThreeFry2x32State outputs = ThreeFry2x32(inputs, key); // low 32 bit: outputs[0], high 32 bit: outputs[1] - auto result = ConvertElementType(outputs[0], U64) | - ShiftLeft(ConvertElementType(outputs[1], U64), - ConstantR0WithType(builder, U64, 32)); + auto result = Uint32sToUint64(outputs); return Reshape(result, AsInt64Slice(shape.dimensions())); } @@ -161,10 +173,6 @@ XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) { XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval, PrimitiveType type, PrimitiveType unsigned_type) { XlaBuilder* builder = bits.builder(); - // TODO(b/72573764): Generate real uniform integer distribution. - // The following algorithm is the same one that TF uses right now, but it's - // uniform only when maxval - minval is a divisor of the range that bits is - // generated from. 
auto range = BitcastConvertType(maxval, unsigned_type) - BitcastConvertType(minval, unsigned_type); auto dist = Rem(bits, range); @@ -175,8 +183,6 @@ XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval, BitcastConvertType(dist - dist_div_2, type); } -} // namespace - XlaOp StatelessRngUniform(std::array<XlaOp, 2> seeds, const Shape& shape, XlaOp minval, XlaOp maxval) { XlaBuilder* builder = seeds[0].builder(); diff --git a/tensorflow/compiler/xla/client/lib/prng.h b/tensorflow/compiler/xla/client/lib/prng.h index 2603818de26..7b0b4c2439e 100644 --- a/tensorflow/compiler/xla/client/lib/prng.h +++ b/tensorflow/compiler/xla/client/lib/prng.h @@ -23,12 +23,38 @@ limitations under the License. namespace xla { +// Implements the ThreeFry counter-based PRNG algorithm. +// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. +// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf +using ThreeFry2x32State = std::array<XlaOp, 2>; +ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key); + // Returns a tensor containing 'shape' random values uniformly distributed in // the range [minval, maxval). Requires 2 32-bit integer seeds. // Currently only 'shape's of type F32, S32 and S64 are implemented. XlaOp StatelessRngUniform(std::array<XlaOp, 2> seeds, const Shape& shape, XlaOp minval, XlaOp maxval); +// Converts a 32-bit (signed or unsigned) integer random number `bits` into a +// float32 in the range [minval, maxval). +XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval); + +// Converts an integer random number 'bits' of type 'type' to a random number +// in the range [minval, maxval), of the same type. 'unsigned_type' is the +// unsigned version of 'type' (could be the same) with the same bit width. +// The algorithm is the same one that TF uses right now, but it's +// uniform only when maxval - minval is a divisor of the range that bits is +// generated from.
+// TODO(b/72573764): Generate real uniform integer distribution. +XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval, + PrimitiveType type, PrimitiveType unsigned_type); + +// The following 2 functions, for converting between one uint64 and two uint32s, +// use the contract "lower 32 bits for the first uint32, higher 32 bits for the +// second". +ThreeFry2x32State Uint64ToUint32s(XlaOp u64); +XlaOp Uint32sToUint64(ThreeFry2x32State u32s); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_PRNG_H_ diff --git a/tensorflow/core/api_def/base_api/api_def_StatefulUniformFullInt.pbtxt b/tensorflow/core/api_def/base_api/api_def_StatefulUniformFullInt.pbtxt new file mode 100644 index 00000000000..6d576052c0a --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_StatefulUniformFullInt.pbtxt @@ -0,0 +1,38 @@ +op { + graph_op_name: "StatefulUniformFullInt" + visibility: HIDDEN + in_arg { + name: "resource" + description: <