A continuing partial implementation of RFC "Random numbers in TensorFlow 2.0" (https://github.com/tensorflow/community/pull/38):

In this change:
- XLA kernel (TF-XLA bridge) for op `StatefulStandardNormalV2` with ThreeFry algorithm.

To be done:
- ops for other distributions;
- other RNG algorithms;
- batch seeds;
- initializers ('RandomUniform', etc.);
- changing `non_deterministic_seed`'s implementation to be an op, and other changes to make most code in stateful_random_ops.py tf.function-compatible.
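For reference, a minimal usage sketch of the path this change enables (mirroring `stateful_random_ops_test.py` below; the XLA_CPU device string is an assumption, any device found by the test's `xla_device_name()` works):

    from tensorflow.python.framework import dtypes
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import stateful_random_ops as random

    with ops.device("/device:XLA_CPU:0"):  # assumed device; see xla_device_name()
      gen = random.Generator(seed=0, algorithm=random.RNG_ALG_THREEFRY)
      gen.normal(shape=(3,))  # lowers to the new StatefulStandardNormalV2 XLA kernel
      gen.uniform(shape=(3,), minval=0, maxval=10, dtype=dtypes.uint32)
      gen.uniform_full_int(shape=(3,))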

PiperOrigin-RevId: 236699104
Peng Wang 2019-03-04 12:06:05 -08:00 committed by TensorFlower Gardener
parent 55240a22f1
commit 12606ff846
19 changed files with 1125 additions and 179 deletions

View File

@@ -835,6 +835,20 @@ tf_xla_py_test(
     ],
 )

+tf_xla_py_test(
+    name = "stateful_random_ops_test",
+    size = "small",
+    srcs = ["stateful_random_ops_test.py"],
+    tags = ["optonly"],
+    deps = [
+        ":xla_test",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:platform_test",
+        "//tensorflow/python:standard_ops",
+        "//tensorflow/python:stateful_random_ops",
+    ],
+)
+
 tf_xla_py_test(
     name = "stateless_random_ops_test",
     size = "small",

View File

@@ -0,0 +1,282 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for stateful random-number generation ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.client import device_lib
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gen_stateful_random_ops
from tensorflow.python.ops import stateful_random_ops as \
random
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
def xla_device_name():
devices = device_lib.list_local_devices()
def find_type(device_type):
for d in devices:
if d.device_type == device_type:
return d.name
return None
name = find_type("TPU") or find_type("XLA_GPU") or find_type("XLA_CPU")
if name is None:
raise ValueError(
"Can't find any XLA device. Available devices:\n%s" % devices)
return str(name)
class StatefulRandomOpsTest(xla_test.XLATestCase):
"""Test cases for stateful random-number generator operators."""
@test_util.run_v2_only
def testSimple(self):
"""A simple test.
"""
with ops.device(xla_device_name()):
gen = random.Generator(seed=0, algorithm=random.RNG_ALG_THREEFRY)
gen.normal(shape=(3,))
gen.uniform(shape=(3,), minval=0, maxval=10, dtype=dtypes.uint32)
gen.uniform_full_int(shape=(3,))
@test_util.run_v2_only
def testDefun(self):
"""Test for defun.
"""
with ops.device(xla_device_name()):
gen = random.Generator(seed=0, algorithm=random.RNG_ALG_THREEFRY)
@def_function.function
def f():
x = gen.normal(shape=(3,))
y = gen.uniform(shape=(3,), minval=0, maxval=10, dtype=dtypes.uint32)
z = gen.uniform_full_int(shape=(3,))
return (x, y, z)
f()
@test_util.run_v2_only
def testThreefry2x32(self):
"""Tests ThreeFry2x32 conforms to known results.
"""
# Based on
# https://github.com/google/jax/blob/8565a3486adf16beb388b2364c9cd930d7a0d92d/tests/random_test.py#L65-L85
# which is in turn based on
# https://github.com/DEShawResearch/Random123-Boost/blob/65e3d874b67aa7b3e02d5ad8306462f52d2079c0/libs/random/test/test_threefry.cpp#L30-L32
def uint32s_to_uint64(a, b):
return b << 32 | a
def verify(counter1, counter2, key1, key2, expect1, expect2):
counter = uint32s_to_uint64(counter1, counter2)
key = uint32s_to_uint64(key1, key2)
random.get_global_generator().reset([counter, key])
got = random.get_global_generator().uniform_full_int(
shape=(2,), dtype=dtypes.uint32)
expect = [expect1, expect2]
self.assertAllEqual(expect, got)
random.get_global_generator().reset([counter, key])
got = random.get_global_generator().uniform_full_int(
shape=(), dtype=dtypes.uint64)
self.assertAllEqual(uint32s_to_uint64(*expect), got)
with ops.device(xla_device_name()):
random.reset_global_generator(seed=0, algorithm=random.RNG_ALG_THREEFRY)
verify(0x00000000, 0x00000000, 0x00000000, 0x00000000,
0x6b200159, 0x99ba4efe)
verify(0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff,
0x1cb996fc, 0xbb002be7)
verify(0x243f6a88, 0x85a308d3, 0x13198a2e, 0x03707344,
0xc4923a9c, 0x483df7a0)
@test_util.run_v2_only
def testNewState(self):
"""Tests that the new state is correct.
"""
with ops.device(xla_device_name()):
counter = 57
key = 0x1234
size = 46
seed = [counter, key]
gen = random.Generator(
seed=seed, algorithm=random.RNG_ALG_THREEFRY)
gen.uniform_full_int(shape=(size,), dtype=dtypes.uint32)
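      # Each ThreeFry2x32 call in the XLA kernel consumes one 64-bit counter
      # value and yields two uint32s, so generating `size` uint32s advances
      # the counter by (size + 1) // 2, while generating uint64s (below)
      # advances it by `size`.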
self.assertAllEqual([counter+(size+1)//2, key], gen.state.read_value())
gen.reset(seed=seed)
gen.uniform_full_int(shape=(size,), dtype=dtypes.uint64)
self.assertAllEqual([counter+size, key], gen.state.read_value())
def _testRngIsNotConstant(self, rng, dtype):
    # Tests that 'rng' does not always return the same value.
    # A correctly working random-number generator should produce the same
    # output multiple times only with low probability.
x = rng(dtype).numpy()
y = rng(dtype).numpy()
self.assertFalse(np.array_equal(x, y))
@test_util.run_v2_only
def testUniformIsNotConstant(self):
with ops.device(xla_device_name()):
gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)
def rng(dtype):
maxval = dtype.max
# Workaround for b/125364959
if dtype == dtypes.uint64:
maxval = 10000000
return gen.uniform(shape=[2], dtype=dtype, maxval=maxval)
for dtype in {dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64}:
self._testRngIsNotConstant(rng, dtype)
@test_util.run_v2_only
def testNormalIsNotConstant(self):
with ops.device(xla_device_name()):
gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)
def rng(dtype):
return gen.normal(shape=[2], dtype=dtype)
for dtype in {dtypes.float32}:
self._testRngIsNotConstant(rng, dtype)
@test_util.run_v2_only
def testUniformIntIsInRange(self):
minval = 2
maxval = 33
size = 1000
with ops.device(xla_device_name()):
gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)
for dtype in {dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64}:
x = gen.uniform(
shape=[size], dtype=dtype, minval=minval, maxval=maxval).numpy()
self.assertTrue(np.all(x >= minval))
self.assertTrue(np.all(x < maxval))
@test_util.run_v2_only
def testNormalIsFinite(self):
with ops.device(xla_device_name()):
gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)
for dtype in {dtypes.float32}:
x = gen.normal(shape=[10000], dtype=dtype).numpy()
self.assertTrue(np.all(np.isfinite(x)))
def _chi_squared(self, x, bins):
"""Pearson's Chi-squared test."""
x = np.ravel(x)
n = len(x)
histogram, _ = np.histogram(x, bins=bins, range=(0, 1))
expected = n / float(bins)
return np.sum(np.square(histogram - expected) / expected)
@test_util.run_v2_only
def testDistributionOfUniform(self):
"""Use Pearson's Chi-squared test to test for uniformity."""
with ops.device(xla_device_name()):
n = 1000
seed = 12
for dtype in {dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64}:
gen = random.Generator(seed=seed, algorithm=random.RNG_ALG_THREEFRY)
maxval = 1
if dtype.is_integer:
maxval = 100
x = gen.uniform(shape=[n], maxval=maxval, dtype=dtype).numpy()
if maxval > 1:
          # Normalize x to the range [0, 1).
x = x.astype(float) / maxval
# Tests that the values are distributed amongst 10 bins with equal
# probability. 16.92 is the Chi^2 value for 9 degrees of freedom with
# p=0.05. This test is probabilistic and would be flaky if the random
# seed were not fixed.
val = self._chi_squared(x, 10)
self.assertLess(val, 16.92)
def _normal_cdf(self, x):
"""Cumulative distribution function for a standard normal distribution."""
return 0.5 + 0.5 * np.vectorize(math.erf)(x / math.sqrt(2))
def _anderson_darling(self, x):
"""Anderson-Darling test for a standard normal distribution."""
x = np.sort(np.ravel(x))
n = len(x)
i = np.linspace(1, n, n)
z = np.sum((2 * i - 1) * np.log(self._normal_cdf(x)) +
(2 * (n - i) + 1) * np.log(1 - self._normal_cdf(x)))
return -n - z / n
@test_util.run_v2_only
def testDistributionOfNormal(self):
"""Use Anderson-Darling test to test distribution appears normal."""
with ops.device(xla_device_name()):
n = 1000
for dtype in {dtypes.float32}:
gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)
x = gen.normal(shape=[n], dtype=dtype).numpy()
        # The constant 2.492 is the 5% critical value for the Anderson-Darling
        # test where the mean and variance are known. This test is
        # probabilistic, so the seed is fixed to avoid flakiness.
self.assertLess(self._anderson_darling(x.astype(float)), 2.492)
@test_util.run_v2_only
def testErrors(self):
"""Tests that proper errors are raised.
"""
shape = [2, 3]
with ops.device(xla_device_name()):
gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)
with self.assertRaisesWithPredicateMatch(
errors_impl.InvalidArgumentError,
r"algorithm must be of shape \[\], not"):
gen_stateful_random_ops.stateful_standard_normal_v2(
gen.state.handle, [0, 0], shape)
with self.assertRaisesWithPredicateMatch(
TypeError, "Requested dtype: int64"):
gen_stateful_random_ops.stateful_standard_normal_v2(
gen.state.handle, 1.1, shape)
with self.assertRaisesWithPredicateMatch(
errors_impl.InvalidArgumentError,
"Unsupported algorithm id"):
gen_stateful_random_ops.stateful_standard_normal_v2(
gen.state.handle, 123, shape)
var = variables.Variable([0, 0], dtype=dtypes.uint32)
with self.assertRaisesWithPredicateMatch(
errors_impl.InvalidArgumentError,
"Type mismatch for read of variable .* Expected int64; got"):
gen_stateful_random_ops.stateful_standard_normal_v2(
var.handle, random.RNG_ALG_THREEFRY, shape)
var = variables.Variable([[0]], dtype=dtypes.int64)
with self.assertRaisesWithPredicateMatch(
errors_impl.InvalidArgumentError,
"RNG state must have one and only one dimension, not"):
gen_stateful_random_ops.stateful_standard_normal_v2(
var.handle, random.RNG_ALG_THREEFRY, shape)
var = variables.Variable([0], dtype=dtypes.int64)
with self.assertRaisesWithPredicateMatch(
errors_impl.InvalidArgumentError,
"For the ThreeFry algorithm, the size of state must be at least"):
gen_stateful_random_ops.stateful_standard_normal_v2(
var.handle, random.RNG_ALG_THREEFRY, shape)
if __name__ == "__main__":
test.main()

View File

@@ -64,6 +64,7 @@ tf_kernel_library(
         "qr_op.cc",
         "quantize_and_dequantize_op.cc",
         "random_ops.cc",
+        "random_ops_util.h",
         "reduce_window_op.cc",
         "reduction_ops.cc",
         "reduction_ops.h",
@@ -89,6 +90,7 @@ tf_kernel_library(
         "sparse_to_dense_op.cc",
         "split_op.cc",
         "stack_ops.cc",
+        "stateful_random_ops.cc",
         "stateless_random_ops.cc",
         "strided_slice_op.cc",
         "tensor_array_ops.cc",
@@ -167,6 +169,7 @@ tf_kernel_library(
         "//tensorflow/core:sparse_ops_op_lib",
         "//tensorflow/core:spectral_ops_op_lib",
         "//tensorflow/core:state_ops_op_lib",
+        "//tensorflow/core:stateful_random_ops_op_lib",
         "//tensorflow/core:stateless_random_ops_op_lib",
         "//tensorflow/core:training_ops_op_lib",
         "//tensorflow/core/kernels:constant_op",
@@ -179,6 +182,7 @@ tf_kernel_library(
         "//tensorflow/core/kernels:sendrecv_ops",
         "//tensorflow/core/kernels:sparse_to_dense_op",
         "//tensorflow/core/kernels:stack_ops",
+        "//tensorflow/core/kernels:stateful_random_ops",
         "//tensorflow/core/kernels:training_ops",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/strings",

View File

@@ -0,0 +1,34 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_RANDOM_OPS_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_RANDOM_OPS_UTIL_H_
#include <cmath>
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/tensor.h"
namespace tensorflow {
// Converts to bfloat16 if `dtype` equals DT_BFLOAT16, no-op otherwise.
// It masks the last 16 bits instead of rounding: with normal rounding, values
// near "maxval" would be rounded up to "maxval", which is out of the range
// ["minval", "maxval"). In addition, the distribution near the limit would
// not be uniform.
xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_RANDOM_OPS_UTIL_H_

View File

@@ -0,0 +1,362 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cmath>
#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h"
#include "tensorflow/compiler/tf2xla/lib/random.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/prng.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/stateful_random_ops.h"
#include "tensorflow/core/lib/math/math_util.h"
namespace tensorflow {
namespace {
std::pair<xla::ThreeFry2x32State, xla::XlaOp> GetInputsFromCounter(
xla::XlaOp counter, const int64 size) {
auto builder = counter.builder();
auto input_u64 = Iota(builder, xla::U64, size);
input_u64 = input_u64 + counter;
counter = counter + xla::ConstantR0<uint64>(builder, size);
return std::make_pair(xla::Uint64ToUint32s(input_u64), counter);
}
// `StatelessRngUniformU32` uses ThreeFry2x32's counter space too
// wastefully: it can only generate 2^32*2 int32 numbers for each key, while
// the real capacity is 2^64*2. Counter-space efficiency is important for
// stateful ops, hence the following two new functions.
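// (Each ThreeFry2x32 invocation maps one 64-bit counter value to two uint32
// outputs; threading the full 64-bit counter through `GetInputsFromCounter`
// is what recovers the full 2^64*2 capacity per key.)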
std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniformU32(
xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) {
auto builder = key.builder();
const int64 size = xla::ShapeUtil::ElementsIn(shape);
const int64 half_size = xla::CeilOfRatio<int64>(size, 2);
const bool size_is_odd = (half_size * 2 != size);
auto inputs_counter = GetInputsFromCounter(counter, half_size);
auto inputs = inputs_counter.first;
counter = inputs_counter.second;
auto outputs = xla::ThreeFry2x32(inputs, xla::Uint64ToUint32s(key));
if (size_is_odd) {
outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1});
}
auto result = ConcatInDim(builder, outputs, 0);
return std::make_pair(Reshape(result, xla::AsInt64Slice(shape.dimensions())),
counter);
}
std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniformU64(
xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) {
const int64 size = xla::ShapeUtil::ElementsIn(shape);
auto inputs_counter = GetInputsFromCounter(counter, size);
auto inputs = inputs_counter.first;
counter = inputs_counter.second;
auto outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
auto result = Uint32sToUint64(outputs);
return std::make_pair(Reshape(result, xla::AsInt64Slice(shape.dimensions())),
counter);
}
std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniform(xla::XlaOp key,
xla::XlaOp counter,
const xla::Shape& shape,
xla::XlaOp minval,
xla::XlaOp maxval) {
auto builder = key.builder();
xla::PrimitiveType type = shape.element_type();
switch (type) {
case xla::F32: {
auto bits_counter = StatefulRngUniformU32(key, counter, shape);
auto bits = bits_counter.first;
counter = bits_counter.second;
return std::make_pair(xla::StatelessRngUniformF32(bits, minval, maxval),
counter);
}
case xla::U32: // fall through
case xla::S32: {
auto bits_counter = StatefulRngUniformU32(key, counter, shape);
auto bits = bits_counter.first;
counter = bits_counter.second;
return std::make_pair(
xla::StatelessRngUniformInt(bits, minval, maxval, type, xla::U32),
counter);
}
case xla::U64: // fall through
case xla::S64: {
auto bits_counter = StatefulRngUniformU64(key, counter, shape);
auto bits = bits_counter.first;
counter = bits_counter.second;
return std::make_pair(
xla::StatelessRngUniformInt(bits, minval, maxval, type, xla::U64),
counter);
}
default:
return std::make_pair(builder->ReportError(xla::Unimplemented(
"Types other than F32, U32, S32, U64 and S64 "
"are not implemented by "
"StatefulRngUniform.")),
counter);
}
}
template <typename A, typename B, typename A2>
std::pair<A2, B> map_first(std::function<A2(A)> f, std::pair<A, B> p) {
return std::make_pair(f(p.first), p.second);
}
std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniformFullInt(
xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) {
xla::PrimitiveType type = shape.element_type();
switch (type) {
case xla::U32:
return StatefulRngUniformU32(key, counter, shape);
case xla::S32: {
// Needs explicit function type because of type-inference failure.
std::function<xla::XlaOp(xla::XlaOp)> f = [](xla::XlaOp x) {
return BitcastConvertType(x, xla::S32);
};
return map_first(f, StatefulRngUniformU32(key, counter, shape));
}
case xla::U64:
return StatefulRngUniformU64(key, counter, shape);
case xla::S64: {
std::function<xla::XlaOp(xla::XlaOp)> f = [](xla::XlaOp x) {
return BitcastConvertType(x, xla::S64);
};
return map_first(f, StatefulRngUniformU64(key, counter, shape));
}
default:
auto builder = key.builder();
return std::make_pair(
builder->ReportError(xla::Unimplemented(
"Types other than U32, S32, U64 and S64 are not implemented by "
"StatefulRngUniformFullInt; got: %s",
xla::primitive_util::LowercasePrimitiveTypeName(type))),
counter);
}
}
template <typename ListB, typename ListA, typename F>
ListB Map(F f, ListA const& list_a) {
ListB list_b;
for (auto a : list_a) {
list_b.push_back(f(a));
}
return list_b;
}
xla::XlaOp ConcatScalars(xla::XlaBuilder* builder,
absl::Span<const xla::XlaOp> scalars) {
return ConcatInDim(
builder,
Map<std::vector<xla::XlaOp>>(
[](xla::XlaOp x) { return xla::Reshape(x, {1}); }, scalars),
0);
}
using sampler_return_type = xla::StatusOr<std::pair<xla::XlaOp, xla::XlaOp>>;
// A helper function containing the common part of several kernels below.
// Precondition: 'algorithm' and 'shape' are compile-time constants.
Status CompileImpl(XlaOpKernelContext* ctx, int state_input_idx,
int alg_input_idx, int shape_input_idx,
std::function<sampler_return_type(xla::XlaOp, xla::XlaOp,
TensorShape)> const&
sample_with_threefry) {
auto alg_shape = ctx->InputShape(alg_input_idx);
if (alg_shape.dims() != 0) {
return errors::InvalidArgument("algorithm must be of shape [], not ",
alg_shape.DebugString());
}
xla::Literal alg_literal;
TF_RETURN_IF_ERROR(ctx->ConstantInput(alg_input_idx, &alg_literal));
auto alg = alg_literal.Get<Algorithm>({});
if (alg == RNG_ALG_THREEFRY) {
xla::XlaOp var;
TensorShape var_shape;
TF_RETURN_IF_ERROR(ctx->ReadVariableInput(
state_input_idx, STATE_ELEMENT_DTYPE, &var_shape, &var));
if (var_shape.dims() != 1) {
return errors::InvalidArgument(
"RNG state must have one and only one dimension, not ",
var_shape.dims());
}
auto state_size = var_shape.dim_size(0);
if (state_size < THREEFRY_MIN_STATE_SIZE) {
return errors::InvalidArgument(
"For the ThreeFry algorithm, the size of state"
" must be at least ",
THREEFRY_MIN_STATE_SIZE, "; got ", state_size);
}
TensorShape shape;
TF_RETURN_IF_ERROR(ctx->ConstantInputAsShape(shape_input_idx, &shape));
static constexpr int COUNTER_SIZE = 1;
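    // The ThreeFry state vector is laid out as [counter, key]: element 0
    // holds the 64-bit counter and element 1 the 64-bit key (which is why
    // THREEFRY_MIN_STATE_SIZE == 2).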
auto counter = BitcastConvertType(
xla::Reshape(xla::Slice(var, {0}, {COUNTER_SIZE}, {1}), {}), xla::U64);
auto key = BitcastConvertType(
xla::Reshape(xla::Slice(var, {COUNTER_SIZE}, {COUNTER_SIZE + 1}, {1}),
{}),
xla::U64);
auto status_or_value = sample_with_threefry(counter, key, shape);
if (!status_or_value.ok()) {
return status_or_value.status();
}
auto output_counter = status_or_value.ConsumeValueOrDie();
auto output = output_counter.first;
counter = output_counter.second;
ctx->SetOutput(0, output);
auto builder = ctx->builder();
var = ConcatScalars(builder, {counter, key});
xla::PrimitiveType state_element_type;
TF_RETURN_IF_ERROR(
DataTypeToPrimitiveType(STATE_ELEMENT_DTYPE, &state_element_type));
var = BitcastConvertType(var, state_element_type);
TF_RETURN_IF_ERROR(
ctx->AssignVariable(state_input_idx, STATE_ELEMENT_DTYPE, var));
return Status::OK();
} else {
return errors::InvalidArgument("Unsupported algorithm id: ", alg);
}
}
class StatefulStandardNormalOp : public XlaOpKernel {
public:
explicit StatefulStandardNormalOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
}
void Compile(XlaOpKernelContext* ctx) override {
auto builder = ctx->builder();
auto sample_with_threefry =
// Needs explicit lambda return type because it fails to be inferred.
[builder, this](xla::XlaOp counter, xla::XlaOp key,
TensorShape shape) -> sampler_return_type {
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
auto uniform_counter = StatefulRngUniform(
key, counter, xla_shape,
xla::ConstantR0<float>(builder, std::nextafter(-1.0f, 0.0f)),
xla::ConstantR0<float>(builder, 1.0));
auto uniform = uniform_counter.first;
counter = uniform_counter.second;
// Convert uniform distribution to normal distribution by computing
// sqrt(2) * erfinv(x)
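      // If U is uniform on (-1, 1), then P(sqrt(2)*erfinv(U) <= t)
      // = P(U <= erf(t/sqrt(2))) = (1 + erf(t/sqrt(2)))/2, which is the
      // standard normal CDF. The nextafter(-1, 0) lower bound below keeps
      // erfinv (and thus the output) finite.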
auto normal =
xla::ScalarLike(uniform, std::sqrt(2.0)) * xla::ErfInv(uniform);
normal = MaybeConvertF32ToBF16(normal, dtype_);
return {{normal, counter}};
};
OP_REQUIRES_OK(ctx,
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
/*shape_input_idx=*/2, sample_with_threefry));
}
private:
DataType dtype_;
TF_DISALLOW_COPY_AND_ASSIGN(StatefulStandardNormalOp);
};
// TODO(wangpeng): Support plain float16 and float64 to get rid of the
// `TypeConstraint`.
REGISTER_XLA_OP(Name("StatefulStandardNormalV2")
.CompileTimeConstantInput("algorithm")
.CompileTimeConstantInput("shape")
.TypeConstraint("dtype", {DT_FLOAT, DT_BFLOAT16}),
StatefulStandardNormalOp);
class StatefulUniformIntOp : public XlaOpKernel {
public:
explicit StatefulUniformIntOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
}
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaOp minval = ctx->Input(3);
xla::XlaOp maxval = ctx->Input(4);
auto sample_with_threefry = [minval, maxval, this](
xla::XlaOp counter, xla::XlaOp key,
TensorShape shape) -> sampler_return_type {
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape));
return StatefulRngUniform(key, counter, xla_shape, minval, maxval);
};
OP_REQUIRES_OK(ctx,
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
/*shape_input_idx=*/2, sample_with_threefry));
}
private:
DataType dtype_;
TF_DISALLOW_COPY_AND_ASSIGN(StatefulUniformIntOp);
};
REGISTER_XLA_OP(Name("StatefulUniformInt")
.CompileTimeConstantInput("algorithm")
.CompileTimeConstantInput("shape")
.TypeConstraint("dtype",
{DT_INT32, DT_UINT32, DT_INT64, DT_UINT64}),
StatefulUniformIntOp);
class StatefulUniformFullIntOp : public XlaOpKernel {
public:
explicit StatefulUniformFullIntOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
}
void Compile(XlaOpKernelContext* ctx) override {
auto sample_with_threefry = [this](
xla::XlaOp counter, xla::XlaOp key,
TensorShape shape) -> sampler_return_type {
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape));
return StatefulRngUniformFullInt(key, counter, xla_shape);
};
OP_REQUIRES_OK(ctx,
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
/*shape_input_idx=*/2, sample_with_threefry));
}
private:
DataType dtype_;
TF_DISALLOW_COPY_AND_ASSIGN(StatefulUniformFullIntOp);
};
REGISTER_XLA_OP(Name("StatefulUniformFullInt")
.CompileTimeConstantInput("algorithm")
.CompileTimeConstantInput("shape")
.TypeConstraint("dtype",
{DT_INT32, DT_UINT32, DT_INT64, DT_UINT64}),
StatefulUniformFullIntOp);
} // namespace
} // namespace tensorflow

View File

@@ -15,6 +15,7 @@ limitations under the License.

#include <cmath>

+#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h"
#include "tensorflow/compiler/tf2xla/lib/random.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
@@ -31,12 +32,8 @@ limitations under the License.
#include "tensorflow/core/lib/math/math_util.h"

namespace tensorflow {
-namespace {

xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) {
-  // Mask the last 16 bit. With normal rounding, values near "maxval" would be
-  // converted to "maxval" which is out of range ["minval", "maxval"). In
-  // addition, the distribution near the limit is not uniform.
  if (dtype == DT_BFLOAT16) {
    xla::XlaBuilder* builder = input.builder();
    auto output = xla::BitcastConvertType(input, xla::U32) &
@@ -48,6 +45,8 @@ xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) {
  }
}

+namespace {
+
class StatelessRandomUniformOp : public XlaOpKernel {
 public:
  explicit StatelessRandomUniformOp(OpKernelConstruction* ctx)

View File

@@ -82,6 +82,9 @@ CreateResourceOpInfoMap() {
  add("ResourceScatterSub"          , kReadWrite, kVariable);
  add("ResourceScatterUpdate"       , kReadWrite, kVariable);
  add("ResourceStridedSliceAssign"  , kReadWrite, kVariable);
+  add("StatefulStandardNormalV2"    , kReadWrite, kVariable);
+  add("StatefulUniformFullInt"      , kReadWrite, kVariable);
+  add("StatefulUniformInt"          , kReadWrite, kVariable);
  add("VarIsInitializedOp"          , kRead,      kVariable);
  add("VariableShape"               , kRead,      kVariable);

View File

@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

+#include "tensorflow/compiler/xla/client/lib/prng.h"
+
#include <cmath>

#include "absl/base/casts.h"
@@ -30,11 +32,8 @@ XlaOp RotateLeftU32(XlaOp v, int distance) {
         ShiftRightLogical(v, ConstantR0<uint32>(v.builder(), 32 - distance));
}

-using ThreeFry2x32State = std::array<XlaOp, 2>;
-
-// Implements the ThreeFry counter-based PRNG algorithm.
-// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
-// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
+}  // namespace
+
ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) {
  XlaBuilder* builder = input[0].builder();
  key[0] = BitcastConvertType(key[0], U32);
@@ -127,15 +126,28 @@ XlaOp StatelessRngUniformU32(std::array<XlaOp, 2> key, const Shape& shape) {
  return Reshape(result, AsInt64Slice(shape.dimensions()));
}

+ThreeFry2x32State Uint64ToUint32s(XlaOp u64) {
+  auto builder = u64.builder();
+  auto const32 = ConstantR0WithType(builder, U64, 32);
+  auto fst = ConvertElementType(u64, U32);
+  auto snd = ConvertElementType(ShiftRightLogical(u64, const32), U32);
+  return {fst, snd};
+}
+
+XlaOp Uint32sToUint64(ThreeFry2x32State u32s) {
+  auto builder = u32s[0].builder();
+  return ConvertElementType(u32s[0], U64) |
+         ShiftLeft(ConvertElementType(u32s[1], U64),
+                   ConstantR0WithType(builder, U64, 32));
+}
+
XlaOp StatelessRngUniformU64(std::array<XlaOp, 2> key, const Shape& shape) {
  XlaBuilder* builder = key[0].builder();
  const int64 size = ShapeUtil::ElementsIn(shape);
  ThreeFry2x32State inputs = GetInputs(size, builder);
  ThreeFry2x32State outputs = ThreeFry2x32(inputs, key);
  // low 32 bit: outputs[0], high 32 bit: outputs[1]
-  auto result = ConvertElementType(outputs[0], U64) |
-                ShiftLeft(ConvertElementType(outputs[1], U64),
-                          ConstantR0WithType(builder, U64, 32));
+  auto result = Uint32sToUint64(outputs);
  return Reshape(result, AsInt64Slice(shape.dimensions()));
}
@@ -161,10 +173,6 @@ XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) {
XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
                             PrimitiveType type, PrimitiveType unsigned_type) {
  XlaBuilder* builder = bits.builder();
-  // TODO(b/72573764): Generate real uniform integer distribution.
-  // The following algorithm is the same one that TF uses right now, but it's
-  // uniform only when maxval - minval is a divisor of the range that bits is
-  // generated from.
  auto range = BitcastConvertType(maxval, unsigned_type) -
               BitcastConvertType(minval, unsigned_type);
  auto dist = Rem(bits, range);
@@ -175,8 +183,6 @@ XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
         BitcastConvertType(dist - dist_div_2, type);
}

-}  // namespace
-
XlaOp StatelessRngUniform(std::array<XlaOp, 2> seeds, const Shape& shape,
                          XlaOp minval, XlaOp maxval) {
  XlaBuilder* builder = seeds[0].builder();

View File

@@ -23,12 +23,38 @@ limitations under the License.

namespace xla {

+// Implements the ThreeFry counter-based PRNG algorithm.
+// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
+// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
+using ThreeFry2x32State = std::array<XlaOp, 2>;
+ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key);
+
// Returns a tensor containing 'shape' random values uniformly distributed in
// the range [minval, maxval). Requires 2 32-bit integer seeds.
// Currently only 'shape's of type F32, S32 and S64 are implemented.
XlaOp StatelessRngUniform(std::array<XlaOp, 2> seeds, const Shape& shape,
                          XlaOp minval, XlaOp maxval);

+// Converts a 32-bit (signed or unsigned) integer random number `bits` into a
+// float32 in the range [minval, maxval).
+XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval);
+
+// Converts an integer random number 'bits' of type 'type' to a random number
+// in the range [minval, maxval), of the same type. 'unsigned_type' is the
+// unsigned version of 'type' (could be the same) with the same bit width.
+// The algorithm is the same one that TF uses right now, but it's uniform
+// only when maxval - minval is a divisor of the range that bits is
+// generated from.
+// TODO(b/72573764): Generate a real uniform integer distribution.
+XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
+                             PrimitiveType type, PrimitiveType unsigned_type);
+
+// The following 2 functions, for converting between one uint64 and two
+// uint32s, use the contract "lower 32 bits for the first uint32, higher 32
+// bits for the second".
+ThreeFry2x32State Uint64ToUint32s(XlaOp u64);
+XlaOp Uint32sToUint64(ThreeFry2x32State u32s);
+
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_PRNG_H_

View File

@@ -0,0 +1,38 @@
op {
graph_op_name: "StatefulUniformFullInt"
visibility: HIDDEN
in_arg {
name: "resource"
description: <<END
The handle of the resource variable that stores the state of the RNG.
END
}
in_arg {
name: "algorithm"
description: <<END
The RNG algorithm.
END
}
in_arg {
name: "shape"
description: <<END
The shape of the output tensor.
END
}
out_arg {
name: "output"
description: <<END
Random values with specified shape.
END
}
attr {
name: "dtype"
description: <<END
The type of the output.
END
}
summary: "Outputs random integers from a uniform distribution."
description: <<END
The generated values are uniform integers covering the whole range of `dtype`.
END
}

View File

@@ -0,0 +1,56 @@
op {
graph_op_name: "StatefulUniformInt"
visibility: HIDDEN
in_arg {
name: "resource"
description: <<END
The handle of the resource variable that stores the state of the RNG.
END
}
in_arg {
name: "algorithm"
description: <<END
The RNG algorithm.
END
}
in_arg {
name: "shape"
description: <<END
The shape of the output tensor.
END
}
in_arg {
name: "minval"
description: <<END
Minimum value (inclusive, scalar).
END
}
in_arg {
name: "maxval"
description: <<END
Maximum value (exclusive, scalar).
END
}
out_arg {
name: "output"
description: <<END
Random values with specified shape.
END
}
attr {
name: "dtype"
description: <<END
The type of the output.
END
}
summary: "Outputs random integers from a uniform distribution."
description: <<END
The generated values are uniform integers in the range `[minval, maxval)`.
The lower bound `minval` is included in the range, while the upper bound
`maxval` is excluded.
The random integers are slightly biased unless `maxval - minval` is an exact
power of two. The bias is small for values of `maxval - minval` significantly
smaller than the range of the output (either `2^32` or `2^64`).
END
}
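The bias described above comes from reducing uniformly distributed bits modulo `maxval - minval`; a quick plain-Python sketch (hypothetical 8-bit bit source and a range of 6) makes the arithmetic concrete:

    from collections import Counter

    # 256 = 6 * 42 + 4, so residues 0..3 each occur 43 times while 4 and 5
    # occur only 42 times; the bias shrinks as the range becomes small
    # relative to the bit source (2^32 or 2^64 for this op).
    print(Counter(bits % 6 for bits in range(256)))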

View File

@@ -95,6 +95,31 @@ class Var : public ResourceBase {
  TF_DISALLOW_COPY_AND_ASSIGN(Var);
};

+// Does unlock and unref automatically when going out of scope, and also
+// supports early manual release.
+class ScopedUnlockUnrefVar {
+ public:
+  explicit ScopedUnlockUnrefVar(Var* var) : var_(var) {
+    if (var_) {
+      var_->mu()->lock();
+    }
+  }
+  void Release() {
+    if (var_) {
+      var_->mu()->unlock();
+      var_->Unref();
+      var_ = nullptr;
+    }
+  }
+  ~ScopedUnlockUnrefVar() { Release(); }
+
+ private:
+  Var* var_;
+
+  ScopedUnlockUnrefVar(const ScopedUnlockUnrefVar&) = delete;
+  void operator=(const ScopedUnlockUnrefVar&) = delete;
+};
+
}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_VAR_H_

View File

@@ -15,8 +15,8 @@ limitations under the License.

#define EIGEN_USE_THREADS

-#include "tensorflow/core/kernels/stateful_random_ops.h"
#include "tensorflow/core/kernels/random_op.h"
+#include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h"
#include "tensorflow/core/kernels/training_op_helpers.h"

namespace tensorflow {
@@ -25,7 +25,7 @@ template <typename Distribution>
struct UpdateVariableAndFill_Philox<CPUDevice, Distribution> {
  void operator()(OpKernelContext* ctx, const CPUDevice& device,
                  int64 output_size, int64 alg_tag_skip,
-                 ScopedUnlockUnref* state_var_guard, Tensor* state_tensor,
+                 ScopedUnlockUnrefVar* state_var_guard, Tensor* state_tensor,
                  typename Distribution::ResultElementType* output_data) {
    auto state_tensor_flat = state_tensor->flat<StateElementType>();
    auto state_data = state_tensor_flat.data();
@@ -47,11 +47,11 @@ Status UpdateVariableAndFill(
  Var* var = nullptr;
  TF_RETURN_IF_ERROR(
      LookupResource(ctx, HandleFromInput(ctx, state_input_idx), &var));
-  // Use `ScopedUnlockUnref` here instead of `mutex_lock` and `ScopedUnref`
+  // Use `ScopedUnlockUnrefVar` here instead of `mutex_lock` and `ScopedUnref`
  // because the former supports early releasing which is needed by
  // `UpdateVariableAndFill_Philox<CPU>` to avoid holding the lock while
  // filling.
-  ScopedUnlockUnref state_var_guard(var);
+  ScopedUnlockUnrefVar state_var_guard(var);
  Tensor* var_tensor = var->tensor();
  if (var_tensor->dtype() != STATE_ELEMENT_DTYPE) {
    return errors::InvalidArgument("dtype of RNG state variable must be ",
@@ -80,7 +80,7 @@ Status UpdateVariableAndFill(
                "PhiloxRandom::ResultElementType must be uint32");
  if (var_tensor_flat.size() < alg_tag_skip + PHILOX_MIN_STATE_SIZE) {
    return errors::InvalidArgument(
-        "For Philox algorithm, the size of state"
+        "For the Philox algorithm, the size of state"
        " must be at least ",
        alg_tag_skip + PHILOX_MIN_STATE_SIZE, "; got ",
        var_tensor_flat.size());
@@ -132,6 +132,11 @@ class StatefulRandomOpV2 : public OpKernel {
    OP_REQUIRES(ctx, alg_tensor.dims() == 0,
                errors::InvalidArgument("algorithm must be of shape [], not ",
                                        alg_tensor.shape().DebugString()));
+    OP_REQUIRES(
+        ctx, alg_tensor.dtype() == ALGORITHM_DTYPE,
+        errors::InvalidArgument("algorithm's dtype must be ",
+                                DataTypeString(ALGORITHM_DTYPE), ", not ",
+                                DataTypeString(alg_tensor.dtype())));
    auto alg = alg_tensor.flat<Algorithm>()(0);
    ComputeImpl<Device, Distribution>(ctx, 0, 2, false, alg);
  }

View File

@@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_H_
#define TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_H_

-#include "tensorflow/core/framework/resource_var.h"
+// #include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/lib/random/philox_random.h"

namespace tensorflow {
@@ -28,7 +28,9 @@ using StateElementType = int64;
static constexpr DataType STATE_ELEMENT_DTYPE = DT_INT64;

using Algorithm = StateElementType;
+static constexpr DataType ALGORITHM_DTYPE = STATE_ELEMENT_DTYPE;
static constexpr Algorithm RNG_ALG_PHILOX = 1;
+static constexpr Algorithm RNG_ALG_THREEFRY = 2;

using random::PhiloxRandom;
@@ -36,109 +38,7 @@ static constexpr int64 PHILOX_MIN_STATE_SIZE =
    (PhiloxRandom::ResultType::kElementCount +
     PhiloxRandom::Key::kElementCount) /
    2;
+static constexpr int64 THREEFRY_MIN_STATE_SIZE = 2;
-
-// The following 5 functions are made templates to avoid duplicate symbols when
-// linking.
-
-// The following two functions use the contract "lower 32 bits for the first
-// uint32, higher 32 bits for the second". Note that this is endian-neutral,
-// unlike a direct memory copy `memcpy(output, &input, 8)`.
-template <typename INT64>
-PHILOX_DEVICE_FUNC void Int64ToUint32s(INT64 input, uint32* output1,
-                                       uint32* output2) {
-  auto u64 = static_cast<uint64>(input);
-  *output1 = static_cast<uint32>(u64);
-  *output2 = static_cast<uint32>(u64 >> 32);
-}
-
-template <typename UINT32>
-PHILOX_DEVICE_FUNC int64 Uint32sToInt64(UINT32 input1, UINT32 input2) {
-  auto u64_1 = static_cast<uint64>(input1);
-  auto u64_2 = static_cast<uint64>(input2);
-  return static_cast<int64>(u64_1 | (u64_2 << 32));
-}
-
-template <typename STATE_ELEMENT_TYPE>
-PHILOX_DEVICE_FUNC PhiloxRandom
-GetPhiloxRandomFromMem(STATE_ELEMENT_TYPE const* ptr) {
-  PhiloxRandom::ResultType counter;
-  PhiloxRandom::Key key;
-  Int64ToUint32s(ptr[0], &counter[0], &counter[1]);
-  Int64ToUint32s(ptr[1], &counter[2], &counter[3]);
-  Int64ToUint32s(ptr[2], &key[0], &key[1]);
-  return PhiloxRandom(counter, key);
-}
-
-template <typename PHILOX_RANDOM>
-PHILOX_DEVICE_FUNC void WritePhiloxRandomToMem(PHILOX_RANDOM const& philox,
-                                               StateElementType* ptr) {
-  PhiloxRandom::ResultType const& counter = philox.counter();
-  PhiloxRandom::Key const& key = philox.key();
-  ptr[0] = Uint32sToInt64(counter[0], counter[1]);
-  ptr[1] = Uint32sToInt64(counter[2], counter[3]);
-  ptr[2] = Uint32sToInt64(key[0], key[1]);
-}
-
-template <typename PHILOX_RANDOM>
-PHILOX_DEVICE_FUNC void UpdateMemWithPhiloxRandom(PHILOX_RANDOM const& philox,
-                                                  int64 output_size,
-                                                  StateElementType* ptr) {
-  auto new_philox = philox;
-  // Multiplier 256 is the same as in `FillPhiloxRandomTask`; do not change
-  // it just here.
-  auto delta = output_size * 256;
-  new_philox.Skip(delta);  // do the actual increasing
-  WritePhiloxRandomToMem(new_philox, ptr);
-}
-
-// Does unlock and unref automatically when going out of scope, and also
-// supports early manual release.
-class ScopedUnlockUnref {
- public:
-  explicit ScopedUnlockUnref(Var* var) : var_(var) {
-    if (var_) {
-      var_->mu()->lock();
-    }
-  }
-  void Release() {
-    if (var_) {
-      var_->mu()->unlock();
-      var_->Unref();
-      var_ = nullptr;
-    }
-  }
-  ~ScopedUnlockUnref() { Release(); }
-
- private:
-  Var* var_;
-
-  ScopedUnlockUnref(const ScopedUnlockUnref&) = delete;
-  void operator=(const ScopedUnlockUnref&) = delete;
-};
-
-// A per-device helper function that does the actual work for
-// `UpdateVariableAndFill`.
-// Reason to use functor: C++ doesn't allow function-template partial
-// specialization.
-template <typename Device, typename Distribution>
-struct UpdateVariableAndFill_Philox;
-
-using CPUDevice = Eigen::ThreadPoolDevice;
-
-#if GOOGLE_CUDA
-
-using GPUDevice = Eigen::GpuDevice;
-
-// Declares the partially GPU-specialized functor struct.
-template <typename Distribution>
-struct UpdateVariableAndFill_Philox<GPUDevice, Distribution> {
-  void operator()(OpKernelContext* ctx, const GPUDevice& device,
-                  int64 output_size, int64 alg_tag_skip,
-                  ScopedUnlockUnref* not_used, Tensor* state_tensor,
-                  typename Distribution::ResultElementType* output_data);
-};
-
-#endif  // GOOGLE_CUDA
-
}  // end namespace tensorflow

View File

@@ -0,0 +1,104 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_CPU_GPU_H_
#define TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_CPU_GPU_H_
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/kernels/stateful_random_ops.h"
namespace tensorflow {
// The following 5 functions are made templates to avoid duplicate symbols when
// linking.
// The following 2 functions use the contract "lower 32 bits for the first
// uint32, higher 32 bits for the second". Note that this is endian-neutral,
// unlike a direct memory copy `memcpy(output, &input, 8)`.
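// For example, Int64ToUint32s(0x0123456789abcdef, &lo, &hi) yields
// lo == 0x89abcdef and hi == 0x01234567 on any host byte order.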
template <typename INT64>
PHILOX_DEVICE_FUNC void Int64ToUint32s(INT64 input, uint32* output1,
uint32* output2) {
auto u64 = static_cast<uint64>(input);
*output1 = static_cast<uint32>(u64);
*output2 = static_cast<uint32>(u64 >> 32);
}
template <typename UINT32>
PHILOX_DEVICE_FUNC int64 Uint32sToInt64(UINT32 input1, UINT32 input2) {
auto u64_1 = static_cast<uint64>(input1);
auto u64_2 = static_cast<uint64>(input2);
return static_cast<int64>(u64_1 | (u64_2 << 32));
}
template <typename STATE_ELEMENT_TYPE>
PHILOX_DEVICE_FUNC PhiloxRandom
GetPhiloxRandomFromMem(STATE_ELEMENT_TYPE const* ptr) {
PhiloxRandom::ResultType counter;
PhiloxRandom::Key key;
Int64ToUint32s(ptr[0], &counter[0], &counter[1]);
Int64ToUint32s(ptr[1], &counter[2], &counter[3]);
Int64ToUint32s(ptr[2], &key[0], &key[1]);
return PhiloxRandom(counter, key);
}
template <typename PHILOX_RANDOM>
PHILOX_DEVICE_FUNC void WritePhiloxRandomToMem(PHILOX_RANDOM const& philox,
StateElementType* ptr) {
PhiloxRandom::ResultType const& counter = philox.counter();
PhiloxRandom::Key const& key = philox.key();
ptr[0] = Uint32sToInt64(counter[0], counter[1]);
ptr[1] = Uint32sToInt64(counter[2], counter[3]);
ptr[2] = Uint32sToInt64(key[0], key[1]);
}
template <typename PHILOX_RANDOM>
PHILOX_DEVICE_FUNC void UpdateMemWithPhiloxRandom(PHILOX_RANDOM const& philox,
int64 output_size,
StateElementType* ptr) {
auto new_philox = philox;
  // The multiplier 256 must be kept the same as in `FillPhiloxRandomTask`;
  // do not change it in one place without the other.
  auto delta = output_size * 256;
  new_philox.Skip(delta);  // advance the counter
WritePhiloxRandomToMem(new_philox, ptr);
}
// A per-device helper function that does the actual work for
// `UpdateVariableAndFill`.
// Reason to use functor: C++ doesn't allow function-template partial
// specialization.
template <typename Device, typename Distribution>
struct UpdateVariableAndFill_Philox;
using CPUDevice = Eigen::ThreadPoolDevice;
#if GOOGLE_CUDA
using GPUDevice = Eigen::GpuDevice;
// Declares the partially GPU-specialized functor struct.
template <typename Distribution>
struct UpdateVariableAndFill_Philox<GPUDevice, Distribution> {
void operator()(OpKernelContext* ctx, const GPUDevice& device,
int64 output_size, int64 alg_tag_skip,
ScopedUnlockUnrefVar* not_used, Tensor* state_tensor,
typename Distribution::ResultElementType* output_data);
};
#endif // GOOGLE_CUDA
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_CPU_GPU_H_

View File

@@ -18,7 +18,7 @@ limitations under the License.
#define EIGEN_USE_GPU

#include "tensorflow/core/kernels/random_op_gpu.h"
-#include "tensorflow/core/kernels/stateful_random_ops.h"
+#include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h"
#include "tensorflow/core/util/cuda_launch_config.h"

namespace tensorflow {
@@ -54,7 +54,7 @@ __global__ void FillKernel(
template <typename Distribution>
void UpdateVariableAndFill_Philox<GPUDevice, Distribution>::operator()(
    OpKernelContext* ctx, const GPUDevice& d, int64 output_size,
-    int64 alg_tag_skip, ScopedUnlockUnref* not_used, Tensor* state_tensor,
+    int64 alg_tag_skip, ScopedUnlockUnrefVar* not_used, Tensor* state_tensor,
    typename Distribution::ResultElementType* output_data) {
  OP_REQUIRES(
      ctx, alg_tag_skip == 0,

View File

@@ -20,11 +20,9 @@ namespace tensorflow {
Status StatefulRandomShape(shape_inference::InferenceContext* c) {
  using shape_inference::ShapeHandle;
-
  // Check algorithm shape
  ShapeHandle unused;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
-
  // Set output shape
  ShapeHandle out;
  TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &out));
@@ -32,16 +30,45 @@ Status StatefulRandomShape(shape_inference::InferenceContext* c) {
  return Status::OK();
}

-REGISTER_OP("StatefulStandardNormalV2")
-    .Input("resource: resource")
-    .Input("algorithm: int64")
-    .Input("shape: shape_dtype")
-    .Output("output: dtype")
-    .Attr("dtype : type = DT_FLOAT")
-    .Attr("shape_dtype : type = DT_INT64")
-    .SetShapeFn(StatefulRandomShape);
+#define REGISTER_STATEFUL_OP(name, default_dtype) \
+  REGISTER_OP(name)                               \
+      .Input("resource: resource")                \
+      .Input("algorithm: int64")                  \
+      .Input("shape: shape_dtype")                \
+      .Output("output: dtype")                    \
+      .Attr("dtype : type = " #default_dtype)     \
+      .Attr("shape_dtype : type = DT_INT64")      \
+      .SetShapeFn(StatefulRandomShape);
+
+REGISTER_STATEFUL_OP("StatefulUniformFullInt", DT_UINT64);
+REGISTER_STATEFUL_OP("StatefulStandardNormalV2", DT_FLOAT);
+
+REGISTER_OP("StatefulUniformInt")
+    .Input("resource: resource")
+    .Input("algorithm: int64")
+    .Input("shape: shape_dtype")
+    .Input("minval: dtype")
+    .Input("maxval: dtype")
+    .Output("output: dtype")
+    .Attr("dtype : type = DT_INT64")
+    .Attr("shape_dtype : type = DT_INT64")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      using shape_inference::ShapeHandle;
+      // Check inputs
+      ShapeHandle unused;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
+      // Set output
+      ShapeHandle out;
+      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &out));
+      c->set_output(0, out);
+      return Status::OK();
+    });

-// Register the old 'StatefulStandardNormal' op
+// Register the old 'StatefulStandardNormal' op. This op is a short-lived
+// version where the 'resource' variable also contains the algorithm tag.
+// It is deprecated in favor of 'StatefulStandardNormalV2'.
REGISTER_OP("StatefulStandardNormal")
    .Input("resource: resource")
    .Input("shape: shape_dtype")
View File

@@ -53,9 +53,14 @@ SEED_SIZE = 16  # in units of SEED_TYPE
STATE_TYPE = SEED_TYPE
ALGORITHM_TYPE = STATE_TYPE
RNG_ALG_PHILOX = 1
+RNG_ALG_THREEFRY = 2
DEFAULT_ALGORITHM = RNG_ALG_PHILOX

+PHILOX_STATE_SIZE = 3
+THREEFRY_STATE_SIZE = 2
+

def non_deterministic_seed():
  """Makes a non-deterministic seed.
@@ -75,7 +80,40 @@ def _uint_to_int(n):
  return n

-PHILOX_STATE_SIZE = 3
+def _make_1d_state(state_size, seed):
+  """Makes a 1-D RNG state.
+
+  Args:
+    state_size: an integer.
+    seed: an integer or 1-D tensor.
+
+  Returns:
+    a 1-D tensor of shape [state_size] and dtype STATE_TYPE.
+  """
+  int_types = (int,) if sys.version_info >= (3, 0) else (int, long)
+  if isinstance(seed, int_types):
+    # chop the Python integer (infinite precision) into chunks of SEED_TYPE
+    ls = []
+    for _ in range(state_size):
+      ls.append(seed & SEED_BIT_MASK)
+      seed >>= SEED_TYPE_BITS
+    seed = ls
+  # to avoid overflow error from np.asarray
+  seed = list(map(_uint_to_int, seed))
+  seed = np.asarray(seed, dtype=STATE_TYPE)
+  if len(seed.shape) != 1:
+    raise ValueError(
+        "seed should only have one dimension; got shape: %s" % seed.shape)
+  seed = seed[0:state_size]
+  # Padding with zeros on the right if too short
+  seed_size = seed.shape[0]
+  if seed_size < state_size:
+    seed = np.pad(
+        seed, [(0, state_size - seed_size)],
+        mode="constant",
+        constant_values=0)
+  assert seed.shape == (state_size,), "Wrong seed.shape: %s" % seed.shape
+  return seed

def _make_philox_state(seed):
@@ -87,35 +125,26 @@ def _make_philox_state(seed):
  Returns:
    a 1-D tensor.
  """
-  int_types = (int,) if sys.version_info >= (3, 0) else (int, long)
-  if isinstance(seed, int_types):
-    # chop the Python integer (infinite precision) into chunks of SEED_TYPE
-    ls = []
-    for _ in range(PHILOX_STATE_SIZE):
-      ls.append(seed & SEED_BIT_MASK)
-      seed >>= SEED_TYPE_BITS
-    seed = ls
-  # to avoid overflow error from np.asarray
-  seed = list(map(_uint_to_int, seed))
-  seed = np.asarray(seed, dtype=STATE_TYPE)
-  if len(seed.shape) != 1:
-    raise ValueError(
-        "seed should only have one dimension; got shape: %s" % seed.shape)
-  seed = seed[0:PHILOX_STATE_SIZE]
-  # Padding with zeros on the right if too short
-  seed_size = seed.shape[0]
-  if seed_size < PHILOX_STATE_SIZE:
-    seed = np.pad(
-        seed, [(0, PHILOX_STATE_SIZE - seed_size)],
-        mode="constant",
-        constant_values=0)
-  assert seed.shape == (PHILOX_STATE_SIZE,), "Wrong seed.shape: %s" % seed.shape
-  return seed
+  return _make_1d_state(PHILOX_STATE_SIZE, seed)
+
+
+def _make_threefry_state(seed):
+  """Makes a RNG state for ThreeFry algorithm.
+
+  Args:
+    seed: an integer or 1-D tensor.
+
+  Returns:
+    a 1-D tensor.
+  """
+  return _make_1d_state(THREEFRY_STATE_SIZE, seed)

def _make_state_from_seed(seed, algorithm):
  if algorithm == RNG_ALG_PHILOX:
    return _make_philox_state(seed)
+  elif algorithm == RNG_ALG_THREEFRY:
+    return _make_threefry_state(seed)
  else:
    raise ValueError("Unsupported algorithm id: %s" % algorithm)
@@ -168,26 +197,19 @@ class Generator(tracking.AutoTrackable):
        algorithm = DEFAULT_ALGORITHM
      state = create_rng_state(seed, algorithm)
      self._state_var = variables.Variable(state, dtype=STATE_TYPE)
-      self._alg_var = variables.Variable(initial_value=algorithm,
-                                         dtype=ALGORITHM_TYPE)
+      self._alg_var = algorithm
    else:
      assert seed is None
      self._state_var = variables.Variable(copy_from.state, dtype=STATE_TYPE)
-      self._alg_var = variables.Variable(initial_value=copy_from.algorithm,
-                                         dtype=ALGORITHM_TYPE)
+      self._alg_var = copy_from.algorithm

  def reset(self, seed):
    """Resets the generator.

-    This function is not thread-safe: if it is run concurrently with a call to
-    sampling, the latter might see the new algorithm but the old state or vice
-    versa.
-
    Args:
      seed: the seed to reset the RNG to.
    """
-    algorithm = int(self.algorithm)
-    state = create_rng_state(seed, algorithm)
+    state = create_rng_state(seed, self.algorithm)
    self._state_var.assign(state)

  @property
@@ -200,21 +222,59 @@
  # The following functions return a tensor and as a side effect update
  # self._state_var.

-  def standard_normal(self, shape, dtype=dtypes.float32):
-    output = gen_stateful_random_ops.stateful_standard_normal_v2(
-        self.state.handle, self.algorithm, shape, dtype)
-    return output
-
  def normal(self, shape, mean=0.0, stddev=1.0, dtype=dtypes.float32,
             name=None):
    with ops.name_scope(name, "stateful_normal", [shape, mean, stddev]) as name:
      shape = _shape_tensor(shape)
      mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
      stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
-      rnd = self.standard_normal(shape, dtype)
+      rnd = gen_stateful_random_ops.stateful_standard_normal_v2(
+          self.state.handle, self.algorithm, shape, dtype=dtype)
      return math_ops.add(rnd * stddev, mean, name=name)

+  def uniform(self, shape, minval=0, maxval=None,
+              dtype=dtypes.float32, name=None):
+    dtype = dtypes.as_dtype(dtype)
+    if maxval is None:
+      if dtype.is_integer:
+        raise ValueError("Must specify maxval for integer dtype %r" % dtype)
+      maxval = 1
+    with ops.name_scope(name, "stateful_uniform",
+                        [shape, minval, maxval]) as name:
+      shape = _shape_tensor(shape)
+      minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
+      maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
+      if dtype.is_integer:
+        return gen_stateful_random_ops.stateful_uniform_int(
+            self.state.handle, self.algorithm, shape=shape,
+            minval=minval, maxval=maxval, name=name)
+      else:
+        # TODO(wangpeng): implement uniform for floats
+        raise ValueError("uniform for floats not implemented yet")
+
+  def uniform_full_int(self, shape, dtype=dtypes.uint64, name=None):
+    """Uniform distribution on an integer type's entire range.
+
+    The other method `uniform` only covers the range [minval, maxval), which
+    cannot be `dtype`'s full range because `maxval` is of type `dtype`.
+
+    Args:
+      shape: the shape of the output.
+      dtype: (optional) the integer type, default to uint64.
+      name: (optional) the name of the node.
+
+    Returns:
+      A tensor of random numbers of the required shape.
+    """
+    dtype = dtypes.as_dtype(dtype)
+    with ops.name_scope(name, "stateful_uniform_full_int",
+                        [shape]) as name:
+      shape = _shape_tensor(shape)
+      return gen_stateful_random_ops.stateful_uniform_full_int(
+          self.state.handle, self.algorithm, shape=shape,
+          dtype=dtype, name=name)
+
-  # TODO(wangpeng): implement other distributions (`uniform`,
+  # TODO(wangpeng): implement other distributions (
  #   `truncated_normal`, etc.)
  # TODO(wangpeng): implement `make_seeds`
  # TODO(wangpeng): implement `make_generators`
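As a sketch of what the new `_make_1d_state` helper does to an oversized Python integer seed (assuming 64-bit SEED_TYPE chunks, as elsewhere in this module):

    state_size, seed = 2, 2**64 + 5
    ls = []
    for _ in range(state_size):
      ls.append(seed & (2**64 - 1))  # low 64 bits become one state element
      seed >>= 64
    print(ls)  # [5, 1]: the arbitrary-precision int is chopped into int64 chunks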

View File

@@ -23,6 +23,7 @@ import numpy as np
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gen_random_ops
@@ -138,7 +139,7 @@ class StatefulRandomOpsTest(test.TestCase):
    def new():
      with ops.device("/device:CPU:0"):
-        return random.get_global_generator().standard_normal(shape, dtype=dtype)
+        return random.get_global_generator().normal(shape, dtype=dtype)

    for _ in range(100):
      self.assertAllEqual(old(), new())
@@ -164,7 +165,7 @@ class StatefulRandomOpsTest(test.TestCase):
    def new():
      with ops.device(test_util.gpu_device_name()):
-        return random.get_global_generator().standard_normal(shape, dtype=dtype)
+        return random.get_global_generator().normal(shape, dtype=dtype)

    for _ in range(100):
      self.assertAllEqual(old(), new())
@@ -203,7 +204,7 @@ class StatefulRandomOpsTest(test.TestCase):
    random.reset_global_generator(50)
    with self.assertRaisesWithPredicateMatch(
-        AssertionError, "variable.*deleted"):
+        errors_impl.NotFoundError, "Resource .+ does not exist"):
      a = f()
    random.reset_global_generator(50)
    b = f()