A continuing partial implementation of RFC "Random numbers in TensorFlow 2.0" (https://github.com/tensorflow/community/pull/38):
In this change: - XLA kernel (TF-XLA bridge) for op `StatefulStandardNormalV2` with ThreeFry algorithm. To be done: - ops for other distributions; - other RNG algorithms; - batch seeds; - initializers ('RandomUniform', etc.); - changing `non_deterministic_seed`'s implementation to be an op, and other changes to make most code in stateful_random_ops.py tf.function-compatible. PiperOrigin-RevId: 236699104
This commit is contained in:
parent
55240a22f1
commit
12606ff846
@ -835,6 +835,20 @@ tf_xla_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
# XLA test target for the stateful random-number ops; the test source is
# stateful_random_ops_test.py.
tf_xla_py_test(
    name = "stateful_random_ops_test",
    size = "small",
    srcs = ["stateful_random_ops_test.py"],
    # NOTE(review): "optonly" presumably restricts this test to optimized
    # builds -- confirm against the tag's definition.
    tags = ["optonly"],
    deps = [
        ":xla_test",
        "//tensorflow/python:framework",
        "//tensorflow/python:platform_test",
        "//tensorflow/python:standard_ops",
        "//tensorflow/python:stateful_random_ops",
    ],
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "stateless_random_ops_test",
|
||||
size = "small",
|
||||
|
282
tensorflow/compiler/tests/stateful_random_ops_test.py
Normal file
282
tensorflow/compiler/tests/stateful_random_ops_test.py
Normal file
@ -0,0 +1,282 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for stateful random-number generation ops."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.client import device_lib
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import gen_stateful_random_ops
|
||||
from tensorflow.python.ops import stateful_random_ops as \
|
||||
random
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def xla_device_name():
  """Returns the name of a local XLA-capable device.

  Device types are tried in the order TPU, XLA_GPU, XLA_CPU. For each type the
  first matching entry of `device_lib.list_local_devices()` is considered.

  Returns:
    The selected device name, converted with `str`.

  Raises:
    ValueError: If the search yields no device name at all.
  """
  local_devices = device_lib.list_local_devices()

  def first_name_of(kind):
    # Name of the first device whose type is `kind`, or None if there is none.
    return next(
        (dev.name for dev in local_devices if dev.device_type == kind), None)

  chosen = (first_name_of("TPU") or first_name_of("XLA_GPU") or
            first_name_of("XLA_CPU"))
  if chosen is None:
    raise ValueError(
        "Can't find any XLA device. Available devices:\n%s" % local_devices)
  return str(chosen)
|
||||
|
||||
|
||||
class StatefulRandomOpsTest(xla_test.XLATestCase):
  """Test cases for stateful random-number generator operators."""

  @test_util.run_v2_only
  def testSimple(self):
    """Smoke test: each sampling method can be called on an XLA device."""
    with ops.device(xla_device_name()):
      gen = random.Generator(seed=0, algorithm=random.RNG_ALG_THREEFRY)
      gen.normal(shape=(3,))
      gen.uniform(shape=(3,), minval=0, maxval=10, dtype=dtypes.uint32)
      gen.uniform_full_int(shape=(3,))

  @test_util.run_v2_only
  def testDefun(self):
    """Smoke test: the sampling methods work inside `def_function.function`."""
    with ops.device(xla_device_name()):
      gen = random.Generator(seed=0, algorithm=random.RNG_ALG_THREEFRY)

      @def_function.function
      def f():
        x = gen.normal(shape=(3,))
        y = gen.uniform(shape=(3,), minval=0, maxval=10, dtype=dtypes.uint32)
        z = gen.uniform_full_int(shape=(3,))
        return (x, y, z)

      f()

  @test_util.run_v2_only
  def testThreefry2x32(self):
    """Tests ThreeFry2x32 conforms to known results."""
    # Based on
    # https://github.com/google/jax/blob/8565a3486adf16beb388b2364c9cd930d7a0d92d/tests/random_test.py#L65-L85
    # which is in turn based on
    # https://github.com/DEShawResearch/Random123-Boost/blob/65e3d874b67aa7b3e02d5ad8306462f52d2079c0/libs/random/test/test_threefry.cpp#L30-L32

    def uint32s_to_uint64(a, b):
      # `a` becomes the low 32 bits and `b` the high 32 bits.
      return b << 32 | a

    def verify(counter1, counter2, key1, key2, expect1, expect2):
      counter = uint32s_to_uint64(counter1, counter2)
      key = uint32s_to_uint64(key1, key2)
      # First check the two expected words as a pair of uint32 values.
      random.get_global_generator().reset([counter, key])
      got = random.get_global_generator().uniform_full_int(
          shape=(2,), dtype=dtypes.uint32)
      expect = [expect1, expect2]
      self.assertAllEqual(expect, got)
      # Then, from the same state, check them packed into a single uint64.
      random.get_global_generator().reset([counter, key])
      got = random.get_global_generator().uniform_full_int(
          shape=(), dtype=dtypes.uint64)
      self.assertAllEqual(uint32s_to_uint64(*expect), got)

    with ops.device(xla_device_name()):
      random.reset_global_generator(seed=0, algorithm=random.RNG_ALG_THREEFRY)
      verify(0x00000000, 0x00000000, 0x00000000, 0x00000000,
             0x6b200159, 0x99ba4efe)
      verify(0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff,
             0x1cb996fc, 0xbb002be7)
      verify(0x243f6a88, 0x85a308d3, 0x13198a2e, 0x03707344,
             0xc4923a9c, 0x483df7a0)

  @test_util.run_v2_only
  def testNewState(self):
    """Tests that the new state is correct."""
    with ops.device(xla_device_name()):
      counter = 57
      key = 0x1234
      size = 46
      seed = [counter, key]
      gen = random.Generator(
          seed=seed, algorithm=random.RNG_ALG_THREEFRY)
      gen.uniform_full_int(shape=(size,), dtype=dtypes.uint32)
      # For uint32 output the counter is expected to advance by ceil(size / 2).
      self.assertAllEqual([counter+(size+1)//2, key], gen.state.read_value())
      gen.reset(seed=seed)
      gen.uniform_full_int(shape=(size,), dtype=dtypes.uint64)
      # For uint64 output the counter is expected to advance by `size`.
      self.assertAllEqual([counter+size, key], gen.state.read_value())

  def _testRngIsNotConstant(self, rng, dtype):
    """Asserts that two consecutive calls of `rng(dtype)` differ.

    A correctly working random-number generator should produce the same
    output twice in a row only with low probability.

    Args:
      rng: Callable taking a dtype and returning a tensor with `.numpy()`.
      dtype: The dtype forwarded to `rng`.
    """
    x = rng(dtype).numpy()
    y = rng(dtype).numpy()
    self.assertFalse(np.array_equal(x, y))

  @test_util.run_v2_only
  def testUniformIsNotConstant(self):
    """Checks that integer `uniform` sampling does not repeat itself."""
    with ops.device(xla_device_name()):
      gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)

      def rng(dtype):
        maxval = dtype.max
        # Workaround for b/125364959
        if dtype == dtypes.uint64:
          maxval = 10000000
        return gen.uniform(shape=[2], dtype=dtype, maxval=maxval)

      # A tuple (not a set) keeps the order in which the shared generator is
      # consumed fixed.
      for dtype in (dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64):
        self._testRngIsNotConstant(rng, dtype)

  @test_util.run_v2_only
  def testNormalIsNotConstant(self):
    """Checks that `normal` sampling does not repeat itself."""
    with ops.device(xla_device_name()):
      gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)

      def rng(dtype):
        return gen.normal(shape=[2], dtype=dtype)

      for dtype in (dtypes.float32,):
        self._testRngIsNotConstant(rng, dtype)

  @test_util.run_v2_only
  def testUniformIntIsInRange(self):
    """Checks that integer `uniform` samples lie in [minval, maxval)."""
    minval = 2
    maxval = 33
    size = 1000
    with ops.device(xla_device_name()):
      gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)
      # A tuple (not a set) keeps the order in which the shared generator is
      # consumed fixed.
      for dtype in (dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64):
        x = gen.uniform(
            shape=[size], dtype=dtype, minval=minval, maxval=maxval).numpy()
        self.assertTrue(np.all(x >= minval))
        self.assertTrue(np.all(x < maxval))

  @test_util.run_v2_only
  def testNormalIsFinite(self):
    """Checks that `normal` samples contain no inf or NaN."""
    with ops.device(xla_device_name()):
      gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)
      for dtype in (dtypes.float32,):
        x = gen.normal(shape=[10000], dtype=dtype).numpy()
        self.assertTrue(np.all(np.isfinite(x)))

  def _chi_squared(self, x, bins):
    """Pearson's Chi-squared statistic of `x` against a uniform histogram.

    Args:
      x: Array-like of samples; it is flattened and binned over (0, 1).
      bins: Number of equal-width histogram bins.

    Returns:
      The sum over bins of (observed - expected)^2 / expected, where
      expected is len(x) / bins.
    """
    x = np.ravel(x)
    n = len(x)
    histogram, _ = np.histogram(x, bins=bins, range=(0, 1))
    expected = n / float(bins)
    return np.sum(np.square(histogram - expected) / expected)

  @test_util.run_v2_only
  def testDistributionOfUniform(self):
    """Use Pearson's Chi-squared test to test for uniformity."""
    with ops.device(xla_device_name()):
      n = 1000
      seed = 12
      for dtype in (dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64):
        # A fresh generator per dtype: every dtype starts from the same seed.
        gen = random.Generator(seed=seed, algorithm=random.RNG_ALG_THREEFRY)
        maxval = 1
        if dtype.is_integer:
          maxval = 100
        x = gen.uniform(shape=[n], maxval=maxval, dtype=dtype).numpy()
        if maxval > 1:
          # Normalize x to range [0, 1).
          x = x.astype(float) / maxval
        # Tests that the values are distributed amongst 10 bins with equal
        # probability. 16.92 is the Chi^2 value for 9 degrees of freedom with
        # p=0.05. This test is probabilistic and would be flaky if the random
        # seed were not fixed.
        val = self._chi_squared(x, 10)
        self.assertLess(val, 16.92)

  def _normal_cdf(self, x):
    """Cumulative distribution function for a standard normal distribution."""
    return 0.5 + 0.5 * np.vectorize(math.erf)(x / math.sqrt(2))

  def _anderson_darling(self, x):
    """Anderson-Darling test for a standard normal distribution.

    Args:
      x: Array-like of samples; it is flattened and sorted before use.

    Returns:
      The Anderson-Darling statistic of `x` against the standard normal CDF.
    """
    x = np.sort(np.ravel(x))
    n = len(x)
    i = np.linspace(1, n, n)
    # Evaluate the CDF once; both logarithm terms use the same values.
    cdf = self._normal_cdf(x)
    z = np.sum((2 * i - 1) * np.log(cdf) +
               (2 * (n - i) + 1) * np.log(1 - cdf))
    return -n - z / n

  @test_util.run_v2_only
  def testDistributionOfNormal(self):
    """Use Anderson-Darling test to test distribution appears normal."""
    with ops.device(xla_device_name()):
      n = 1000
      for dtype in (dtypes.float32,):
        gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)
        x = gen.normal(shape=[n], dtype=dtype).numpy()
        # The constant 2.492 is the 5% critical value for the Anderson-Darling
        # test where the mean and variance are known. This test is probabilistic
        # so to avoid flakiness the seed is fixed.
        self.assertLess(self._anderson_darling(x.astype(float)), 2.492)

  @test_util.run_v2_only
  def testErrors(self):
    """Tests that proper errors are raised."""
    shape = [2, 3]
    with ops.device(xla_device_name()):
      gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)
      # `algorithm` must be a scalar.
      with self.assertRaisesWithPredicateMatch(
          errors_impl.InvalidArgumentError,
          r"algorithm must be of shape \[\], not"):
        gen_stateful_random_ops.stateful_standard_normal_v2(
            gen.state.handle, [0, 0], shape)
      # `algorithm` must be an integer.
      with self.assertRaisesWithPredicateMatch(
          TypeError, "Requested dtype: int64"):
        gen_stateful_random_ops.stateful_standard_normal_v2(
            gen.state.handle, 1.1, shape)
      # Unknown algorithm id.
      with self.assertRaisesWithPredicateMatch(
          errors_impl.InvalidArgumentError,
          "Unsupported algorithm id"):
        gen_stateful_random_ops.stateful_standard_normal_v2(
            gen.state.handle, 123, shape)
      # State variable of the wrong dtype.
      var = variables.Variable([0, 0], dtype=dtypes.uint32)
      with self.assertRaisesWithPredicateMatch(
          errors_impl.InvalidArgumentError,
          "Type mismatch for read of variable .* Expected int64; got"):
        gen_stateful_random_ops.stateful_standard_normal_v2(
            var.handle, random.RNG_ALG_THREEFRY, shape)
      # State variable of the wrong rank.
      var = variables.Variable([[0]], dtype=dtypes.int64)
      with self.assertRaisesWithPredicateMatch(
          errors_impl.InvalidArgumentError,
          "RNG state must have one and only one dimension, not"):
        gen_stateful_random_ops.stateful_standard_normal_v2(
            var.handle, random.RNG_ALG_THREEFRY, shape)
      # State variable that is too short.
      var = variables.Variable([0], dtype=dtypes.int64)
      with self.assertRaisesWithPredicateMatch(
          errors_impl.InvalidArgumentError,
          "For the ThreeFry algorithm, the size of state must be at least"):
        gen_stateful_random_ops.stateful_standard_normal_v2(
            var.handle, random.RNG_ALG_THREEFRY, shape)
|
||||
|
||||
|
||||
# Run the tests in this file when it is executed as a script.
if __name__ == "__main__":
  test.main()
|
@ -64,6 +64,7 @@ tf_kernel_library(
|
||||
"qr_op.cc",
|
||||
"quantize_and_dequantize_op.cc",
|
||||
"random_ops.cc",
|
||||
"random_ops_util.h",
|
||||
"reduce_window_op.cc",
|
||||
"reduction_ops.cc",
|
||||
"reduction_ops.h",
|
||||
@ -89,6 +90,7 @@ tf_kernel_library(
|
||||
"sparse_to_dense_op.cc",
|
||||
"split_op.cc",
|
||||
"stack_ops.cc",
|
||||
"stateful_random_ops.cc",
|
||||
"stateless_random_ops.cc",
|
||||
"strided_slice_op.cc",
|
||||
"tensor_array_ops.cc",
|
||||
@ -167,6 +169,7 @@ tf_kernel_library(
|
||||
"//tensorflow/core:sparse_ops_op_lib",
|
||||
"//tensorflow/core:spectral_ops_op_lib",
|
||||
"//tensorflow/core:state_ops_op_lib",
|
||||
"//tensorflow/core:stateful_random_ops_op_lib",
|
||||
"//tensorflow/core:stateless_random_ops_op_lib",
|
||||
"//tensorflow/core:training_ops_op_lib",
|
||||
"//tensorflow/core/kernels:constant_op",
|
||||
@ -179,6 +182,7 @@ tf_kernel_library(
|
||||
"//tensorflow/core/kernels:sendrecv_ops",
|
||||
"//tensorflow/core/kernels:sparse_to_dense_op",
|
||||
"//tensorflow/core/kernels:stack_ops",
|
||||
"//tensorflow/core/kernels:stateful_random_ops",
|
||||
"//tensorflow/core/kernels:training_ops",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/strings",
|
||||
|
34
tensorflow/compiler/tf2xla/kernels/random_ops_util.h
Normal file
34
tensorflow/compiler/tf2xla/kernels/random_ops_util.h
Normal file
@ -0,0 +1,34 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_RANDOM_OPS_UTIL_H_
|
||||
#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_RANDOM_OPS_UTIL_H_
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Converts `input` to bfloat16 if `dtype` equals DT_BFLOAT16, no-op otherwise.
// It masks the last 16 bits. With normal rounding, values near "maxval" would
// be converted to "maxval", which is out of the range ["minval", "maxval"). In
// addition, with rounding the distribution near the limit would not be
// uniform.
xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_RANDOM_OPS_UTIL_H_
|
362
tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc
Normal file
362
tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc
Normal file
@ -0,0 +1,362 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/lib/random.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/math.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/prng.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/kernels/stateful_random_ops.h"
|
||||
#include "tensorflow/core/lib/math/math_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Builds the ThreeFry inputs for `size` counter values starting at `counter`.
//
// Returns a pair of:
//   - `Iota(U64, size) + counter`, split into two uint32 halves by
//     xla::Uint64ToUint32s; and
//   - the advanced counter, `counter + size`.
std::pair<xla::ThreeFry2x32State, xla::XlaOp> GetInputsFromCounter(
    xla::XlaOp counter, const int64 size) {
  xla::XlaBuilder* const b = counter.builder();
  const xla::XlaOp counters = Iota(b, xla::U64, size) + counter;
  const xla::XlaOp next_counter = counter + xla::ConstantR0<uint64>(b, size);
  return std::make_pair(xla::Uint64ToUint32s(counters), next_counter);
}
|
||||
|
||||
// `StatelessRngUniformU32` uses ThreeFry2x32’s counter space too
|
||||
// wastefully, only able to generate 2^32*2 int32 numbers for each key, while
|
||||
// the real capacity is 2^64*2. Counter-space efficiency is important for
|
||||
// stateful ops, hence the following 2 new functions.
|
||||
std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniformU32(
|
||||
xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) {
|
||||
auto builder = key.builder();
|
||||
const int64 size = xla::ShapeUtil::ElementsIn(shape);
|
||||
const int64 half_size = xla::CeilOfRatio<int64>(size, 2);
|
||||
const bool size_is_odd = (half_size * 2 != size);
|
||||
auto inputs_counter = GetInputsFromCounter(counter, half_size);
|
||||
auto inputs = inputs_counter.first;
|
||||
counter = inputs_counter.second;
|
||||
auto outputs = xla::ThreeFry2x32(inputs, xla::Uint64ToUint32s(key));
|
||||
if (size_is_odd) {
|
||||
outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1});
|
||||
}
|
||||
auto result = ConcatInDim(builder, outputs, 0);
|
||||
return std::make_pair(Reshape(result, xla::AsInt64Slice(shape.dimensions())),
|
||||
counter);
|
||||
}
|
||||
|
||||
// Generates uint64 random bits with the dimensions of `shape`.
// Returns {bits, advanced counter}; one counter value is consumed per output
// element, and the two uint32 halves of each ThreeFry2x32 block are combined
// by Uint32sToUint64.
std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniformU64(
    xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) {
  const int64 num_elements = xla::ShapeUtil::ElementsIn(shape);
  std::pair<xla::ThreeFry2x32State, xla::XlaOp> inputs_and_counter =
      GetInputsFromCounter(counter, num_elements);
  xla::ThreeFry2x32State halves =
      ThreeFry2x32(inputs_and_counter.first, Uint64ToUint32s(key));
  xla::XlaOp flat = Uint32sToUint64(halves);
  return std::make_pair(Reshape(flat, xla::AsInt64Slice(shape.dimensions())),
                        inputs_and_counter.second);
}
|
||||
|
||||
// Samples values of `shape`'s element type, bounded by `minval` and `maxval`,
// from the ThreeFry stream identified by `key` and `counter`.
//
// F32, U32 and S32 draw 32-bit random words; U64 and S64 draw 64-bit words.
// The words are then mapped into the requested range by
// xla::StatelessRngUniformF32 / xla::StatelessRngUniformInt.
//
// Returns {samples, advanced counter}. For any other element type an
// Unimplemented error is reported on the builder and `counter` is returned
// unchanged.
std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniform(xla::XlaOp key,
                                                     xla::XlaOp counter,
                                                     const xla::Shape& shape,
                                                     xla::XlaOp minval,
                                                     xla::XlaOp maxval) {
  const xla::PrimitiveType type = shape.element_type();
  switch (type) {
    case xla::F32: {
      const auto bits_and_counter = StatefulRngUniformU32(key, counter, shape);
      return std::make_pair(
          xla::StatelessRngUniformF32(bits_and_counter.first, minval, maxval),
          bits_and_counter.second);
    }
    case xla::U32:  // fall through
    case xla::S32: {
      const auto bits_and_counter = StatefulRngUniformU32(key, counter, shape);
      return std::make_pair(
          xla::StatelessRngUniformInt(bits_and_counter.first, minval, maxval,
                                      type, xla::U32),
          bits_and_counter.second);
    }
    case xla::U64:  // fall through
    case xla::S64: {
      const auto bits_and_counter = StatefulRngUniformU64(key, counter, shape);
      return std::make_pair(
          xla::StatelessRngUniformInt(bits_and_counter.first, minval, maxval,
                                      type, xla::U64),
          bits_and_counter.second);
    }
    default:
      return std::make_pair(
          key.builder()->ReportError(xla::Unimplemented(
              "Types other than F32, U32, S32, U64 and S64 "
              "are not implemented by "
              "StatefulRngUniform.")),
          counter);
  }
}
|
||||
|
||||
// Applies `f` to the first element of `p` and keeps the second element as is.
//
// `f` may be any callable accepting `p.first` (a lambda, a function pointer,
// a std::function, ...); the type of the new first element is whatever `f`
// returns. Returns the pair {f(p.first), p.second}.
template <typename F, typename A, typename B>
auto map_first(F f, const std::pair<A, B>& p)
    -> std::pair<decltype(f(p.first)), B> {
  return std::make_pair(f(p.first), p.second);
}
|
||||
|
||||
// Generates random integers of `shape`'s element type from raw random bits,
// with no range bounds applied.
//
// U32 and U64 return the random bits directly; S32 and S64 bitcast the
// unsigned bits of the same width to the signed type.
//
// Returns {samples, advanced counter}. For any other element type an
// Unimplemented error is reported on the builder and `counter` is returned
// unchanged.
std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniformFullInt(
    xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) {
  // The lambdas are wrapped in an explicit std::function type, which avoids a
  // type-inference failure when they are passed to map_first.
  using BitsTransform = std::function<xla::XlaOp(xla::XlaOp)>;

  const xla::PrimitiveType type = shape.element_type();
  switch (type) {
    case xla::U32:
      return StatefulRngUniformU32(key, counter, shape);
    case xla::S32: {
      const BitsTransform as_s32 = [](xla::XlaOp bits) {
        return BitcastConvertType(bits, xla::S32);
      };
      return map_first(as_s32, StatefulRngUniformU32(key, counter, shape));
    }
    case xla::U64:
      return StatefulRngUniformU64(key, counter, shape);
    case xla::S64: {
      const BitsTransform as_s64 = [](xla::XlaOp bits) {
        return BitcastConvertType(bits, xla::S64);
      };
      return map_first(as_s64, StatefulRngUniformU64(key, counter, shape));
    }
    default:
      return std::make_pair(
          key.builder()->ReportError(xla::Unimplemented(
              "Types other than U32, S32, U64 and S64 are not implemented by "
              "StatefulRngUniformFullInt; got: %s",
              xla::primitive_util::LowercasePrimitiveTypeName(type))),
          counter);
  }
}
|
||||
|
||||
// Applies `f` to every element of `list_a`, in iteration order, and collects
// the results into a default-constructed `ListB` via push_back.
// `ListB` must be given explicitly, e.g. Map<std::vector<T>>(f, xs).
template <typename ListB, typename ListA, typename F>
ListB Map(F f, ListA const& list_a) {
  ListB mapped;
  for (auto element : list_a) {
    mapped.push_back(f(element));
  }
  return mapped;
}
|
||||
|
||||
// Reshapes each op in `scalars` to shape {1} and concatenates the results
// along dimension 0, in the given order.
xla::XlaOp ConcatScalars(xla::XlaBuilder* builder,
                         absl::Span<const xla::XlaOp> scalars) {
  const auto to_rank1 = [](xla::XlaOp scalar) {
    return xla::Reshape(scalar, {1});
  };
  std::vector<xla::XlaOp> rank1_ops =
      Map<std::vector<xla::XlaOp>>(to_rank1, scalars);
  return ConcatInDim(builder, rank1_ops, 0);
}
|
||||
|
||||
using sampler_return_type = xla::StatusOr<std::pair<xla::XlaOp, xla::XlaOp>>;
|
||||
|
||||
// A helper function containing the common part of several kernels below.
|
||||
// Precondition: 'algorithm' and 'shape' are compile-time constants.
|
||||
Status CompileImpl(XlaOpKernelContext* ctx, int state_input_idx,
|
||||
int alg_input_idx, int shape_input_idx,
|
||||
std::function<sampler_return_type(xla::XlaOp, xla::XlaOp,
|
||||
TensorShape)> const&
|
||||
sample_with_threefry) {
|
||||
auto alg_shape = ctx->InputShape(alg_input_idx);
|
||||
if (alg_shape.dims() != 0) {
|
||||
return errors::InvalidArgument("algorithm must be of shape [], not ",
|
||||
alg_shape.DebugString());
|
||||
}
|
||||
xla::Literal alg_literal;
|
||||
TF_RETURN_IF_ERROR(ctx->ConstantInput(alg_input_idx, &alg_literal));
|
||||
auto alg = alg_literal.Get<Algorithm>({});
|
||||
|
||||
if (alg == RNG_ALG_THREEFRY) {
|
||||
xla::XlaOp var;
|
||||
TensorShape var_shape;
|
||||
TF_RETURN_IF_ERROR(ctx->ReadVariableInput(
|
||||
state_input_idx, STATE_ELEMENT_DTYPE, &var_shape, &var));
|
||||
if (var_shape.dims() != 1) {
|
||||
return errors::InvalidArgument(
|
||||
"RNG state must have one and only one dimension, not ",
|
||||
var_shape.dims());
|
||||
}
|
||||
auto state_size = var_shape.dim_size(0);
|
||||
if (state_size < THREEFRY_MIN_STATE_SIZE) {
|
||||
return errors::InvalidArgument(
|
||||
"For the ThreeFry algorithm, the size of state"
|
||||
" must be at least ",
|
||||
THREEFRY_MIN_STATE_SIZE, "; got ", state_size);
|
||||
}
|
||||
TensorShape shape;
|
||||
TF_RETURN_IF_ERROR(ctx->ConstantInputAsShape(shape_input_idx, &shape));
|
||||
|
||||
static constexpr int COUNTER_SIZE = 1;
|
||||
auto counter = BitcastConvertType(
|
||||
xla::Reshape(xla::Slice(var, {0}, {COUNTER_SIZE}, {1}), {}), xla::U64);
|
||||
auto key = BitcastConvertType(
|
||||
xla::Reshape(xla::Slice(var, {COUNTER_SIZE}, {COUNTER_SIZE + 1}, {1}),
|
||||
{}),
|
||||
xla::U64);
|
||||
|
||||
auto status_or_value = sample_with_threefry(counter, key, shape);
|
||||
if (!status_or_value.ok()) {
|
||||
return status_or_value.status();
|
||||
}
|
||||
auto output_counter = status_or_value.ConsumeValueOrDie();
|
||||
auto output = output_counter.first;
|
||||
counter = output_counter.second;
|
||||
ctx->SetOutput(0, output);
|
||||
auto builder = ctx->builder();
|
||||
var = ConcatScalars(builder, {counter, key});
|
||||
xla::PrimitiveType state_element_type;
|
||||
TF_RETURN_IF_ERROR(
|
||||
DataTypeToPrimitiveType(STATE_ELEMENT_DTYPE, &state_element_type));
|
||||
var = BitcastConvertType(var, state_element_type);
|
||||
TF_RETURN_IF_ERROR(
|
||||
ctx->AssignVariable(state_input_idx, STATE_ELEMENT_DTYPE, var));
|
||||
return Status::OK();
|
||||
} else {
|
||||
return errors::InvalidArgument("Unsupported algorithm id: ", alg);
|
||||
}
|
||||
}
|
||||
|
||||
class StatefulStandardNormalOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit StatefulStandardNormalOp(OpKernelConstruction* ctx)
|
||||
: XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
auto builder = ctx->builder();
|
||||
auto sample_with_threefry =
|
||||
// Needs explicit lambda return type because it fails to be inferred.
|
||||
[builder, this](xla::XlaOp counter, xla::XlaOp key,
|
||||
TensorShape shape) -> sampler_return_type {
|
||||
xla::Shape xla_shape;
|
||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
|
||||
|
||||
auto uniform_counter = StatefulRngUniform(
|
||||
key, counter, xla_shape,
|
||||
xla::ConstantR0<float>(builder, std::nextafter(-1.0f, 0.0f)),
|
||||
xla::ConstantR0<float>(builder, 1.0));
|
||||
auto uniform = uniform_counter.first;
|
||||
counter = uniform_counter.second;
|
||||
// Convert uniform distribution to normal distribution by computing
|
||||
// sqrt(2) * erfinv(x)
|
||||
auto normal =
|
||||
xla::ScalarLike(uniform, std::sqrt(2.0)) * xla::ErfInv(uniform);
|
||||
normal = MaybeConvertF32ToBF16(normal, dtype_);
|
||||
return {{normal, counter}};
|
||||
};
|
||||
OP_REQUIRES_OK(ctx,
|
||||
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
|
||||
/*shape_input_idx=*/2, sample_with_threefry));
|
||||
}
|
||||
|
||||
private:
|
||||
DataType dtype_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(StatefulStandardNormalOp);
|
||||
};
|
||||
|
||||
// TODO(wangpeng): Support plain float16 and float64 to get rid of the
|
||||
// `TypeConstraint`.
|
||||
REGISTER_XLA_OP(Name("StatefulStandardNormalV2")
|
||||
.CompileTimeConstantInput("algorithm")
|
||||
.CompileTimeConstantInput("shape")
|
||||
.TypeConstraint("dtype", {DT_FLOAT, DT_BFLOAT16}),
|
||||
StatefulStandardNormalOp);
|
||||
|
||||
class StatefulUniformIntOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit StatefulUniformIntOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
xla::XlaOp minval = ctx->Input(3);
|
||||
xla::XlaOp maxval = ctx->Input(4);
|
||||
auto sample_with_threefry = [minval, maxval, this](
|
||||
xla::XlaOp counter, xla::XlaOp key,
|
||||
TensorShape shape) -> sampler_return_type {
|
||||
xla::Shape xla_shape;
|
||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape));
|
||||
return StatefulRngUniform(key, counter, xla_shape, minval, maxval);
|
||||
};
|
||||
OP_REQUIRES_OK(ctx,
|
||||
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
|
||||
/*shape_input_idx=*/2, sample_with_threefry));
|
||||
}
|
||||
|
||||
private:
|
||||
DataType dtype_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(StatefulUniformIntOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("StatefulUniformInt")
|
||||
.CompileTimeConstantInput("algorithm")
|
||||
.CompileTimeConstantInput("shape")
|
||||
.TypeConstraint("dtype",
|
||||
{DT_INT32, DT_UINT32, DT_INT64, DT_UINT64}),
|
||||
StatefulUniformIntOp);
|
||||
|
||||
class StatefulUniformFullIntOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit StatefulUniformFullIntOp(OpKernelConstruction* ctx)
|
||||
: XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
auto sample_with_threefry = [this](
|
||||
xla::XlaOp counter, xla::XlaOp key,
|
||||
TensorShape shape) -> sampler_return_type {
|
||||
xla::Shape xla_shape;
|
||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape));
|
||||
return StatefulRngUniformFullInt(key, counter, xla_shape);
|
||||
};
|
||||
OP_REQUIRES_OK(ctx,
|
||||
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
|
||||
/*shape_input_idx=*/2, sample_with_threefry));
|
||||
}
|
||||
|
||||
private:
|
||||
DataType dtype_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(StatefulUniformFullIntOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("StatefulUniformFullInt")
|
||||
.CompileTimeConstantInput("algorithm")
|
||||
.CompileTimeConstantInput("shape")
|
||||
.TypeConstraint("dtype",
|
||||
{DT_INT32, DT_UINT32, DT_INT64, DT_UINT64}),
|
||||
StatefulUniformFullIntOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/lib/random.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
@ -31,12 +32,8 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/math/math_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) {
|
||||
// Mask the last 16 bit. With normal rounding, values near "maxval" would be
|
||||
// converted to "maxval" which is out of range ["minval", "maxval"). In
|
||||
// addition, the distribution near the limit is not uniform.
|
||||
if (dtype == DT_BFLOAT16) {
|
||||
xla::XlaBuilder* builder = input.builder();
|
||||
auto output = xla::BitcastConvertType(input, xla::U32) &
|
||||
@ -48,6 +45,8 @@ xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) {
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
class StatelessRandomUniformOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit StatelessRandomUniformOp(OpKernelConstruction* ctx)
|
||||
|
@ -82,6 +82,9 @@ CreateResourceOpInfoMap() {
|
||||
add("ResourceScatterSub" , kReadWrite, kVariable);
|
||||
add("ResourceScatterUpdate" , kReadWrite, kVariable);
|
||||
add("ResourceStridedSliceAssign" , kReadWrite, kVariable);
|
||||
add("StatefulStandardNormalV2" , kReadWrite, kVariable);
|
||||
add("StatefulUniformFullInt" , kReadWrite, kVariable);
|
||||
add("StatefulUniformInt" , kReadWrite, kVariable);
|
||||
add("VarIsInitializedOp" , kRead, kVariable);
|
||||
add("VariableShape" , kRead, kVariable);
|
||||
|
||||
|
@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/client/lib/prng.h"
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "absl/base/casts.h"
|
||||
@ -30,11 +32,8 @@ XlaOp RotateLeftU32(XlaOp v, int distance) {
|
||||
ShiftRightLogical(v, ConstantR0<uint32>(v.builder(), 32 - distance));
|
||||
}
|
||||
|
||||
using ThreeFry2x32State = std::array<XlaOp, 2>;
|
||||
} // namespace
|
||||
|
||||
// Implements the ThreeFry counter-based PRNG algorithm.
|
||||
// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
|
||||
// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
|
||||
ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) {
|
||||
XlaBuilder* builder = input[0].builder();
|
||||
key[0] = BitcastConvertType(key[0], U32);
|
||||
@ -127,15 +126,28 @@ XlaOp StatelessRngUniformU32(std::array<XlaOp, 2> key, const Shape& shape) {
|
||||
return Reshape(result, AsInt64Slice(shape.dimensions()));
|
||||
}
|
||||
|
||||
// Splits a U64 value into two U32 values. Element 0 is `u64` converted to
// U32; element 1 is `u64` logically shifted right by 32 and then converted.
ThreeFry2x32State Uint64ToUint32s(XlaOp u64) {
  XlaBuilder* builder = u64.builder();
  XlaOp shift_by_32 = ConstantR0WithType(builder, U64, 32);
  XlaOp low_word = ConvertElementType(u64, U32);
  XlaOp high_word =
      ConvertElementType(ShiftRightLogical(u64, shift_by_32), U32);
  return {low_word, high_word};
}
|
||||
|
||||
// Combines two U32 values into one U64: `u32s[0]` widened to U64, OR-ed with
// `u32s[1]` widened to U64 and shifted left by 32.
XlaOp Uint32sToUint64(ThreeFry2x32State u32s) {
  XlaBuilder* builder = u32s[0].builder();
  XlaOp low_part = ConvertElementType(u32s[0], U64);
  XlaOp high_part = ShiftLeft(ConvertElementType(u32s[1], U64),
                              ConstantR0WithType(builder, U64, 32));
  return low_part | high_part;
}
|
||||
|
||||
// Generates U64 random bits for `shape` from the two-element `key`.
//
// One ThreeFry2x32 input pair is obtained per output element via GetInputs,
// run through ThreeFry2x32 with `key`, and the two 32-bit outputs are packed
// into a single 64-bit value before reshaping to `shape`'s dimensions.
XlaOp StatelessRngUniformU64(std::array<XlaOp, 2> key, const Shape& shape) {
  XlaBuilder* builder = key[0].builder();
  const int64 size = ShapeUtil::ElementsIn(shape);
  ThreeFry2x32State inputs = GetInputs(size, builder);
  ThreeFry2x32State outputs = ThreeFry2x32(inputs, key);
  // Low 32 bits: outputs[0]; high 32 bits: outputs[1]. `result` is declared
  // exactly once, through the shared helper, so the packing convention lives
  // in a single place.
  auto result = Uint32sToUint64(outputs);
  return Reshape(result, AsInt64Slice(shape.dimensions()));
}
|
||||
|
||||
@ -161,10 +173,6 @@ XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) {
|
||||
XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
|
||||
PrimitiveType type, PrimitiveType unsigned_type) {
|
||||
XlaBuilder* builder = bits.builder();
|
||||
// TODO(b/72573764): Generate real uniform integer distribution.
|
||||
// The following algorithm is the same one that TF uses right now, but it's
|
||||
// uniform only when maxval - minval is a divisor of the range that bits is
|
||||
// generated from.
|
||||
auto range = BitcastConvertType(maxval, unsigned_type) -
|
||||
BitcastConvertType(minval, unsigned_type);
|
||||
auto dist = Rem(bits, range);
|
||||
@ -175,8 +183,6 @@ XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
|
||||
BitcastConvertType(dist - dist_div_2, type);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
XlaOp StatelessRngUniform(std::array<XlaOp, 2> seeds, const Shape& shape,
|
||||
XlaOp minval, XlaOp maxval) {
|
||||
XlaBuilder* builder = seeds[0].builder();
|
||||
|
@ -23,12 +23,38 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Implements the ThreeFry counter-based PRNG algorithm.
|
||||
// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
|
||||
// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
|
||||
using ThreeFry2x32State = std::array<XlaOp, 2>;
|
||||
ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key);
|
||||
|
||||
// Returns a tensor containing 'shape' random values uniformly distributed in
|
||||
// the range [minval, maxval). Requires 2 32-bit integer seeds.
|
||||
// Currently only 'shape's of type F32, S32 and S64 are implemented.
|
||||
XlaOp StatelessRngUniform(std::array<XlaOp, 2> seeds, const Shape& shape,
|
||||
XlaOp minval, XlaOp maxval);
|
||||
|
||||
// Converts a 32-bit (signed or unsigned) integer random number `bits` into a
|
||||
// float32 in the range [minval, maxval).
|
||||
XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval);
|
||||
|
||||
// Converts an integer random number 'bits' of type 'type' to a random number
|
||||
// in the range [minval, maxval), of the same type. 'unsigned_type' is the
|
||||
// unsigned version of 'type' (could be the same) with the same bit width.
|
||||
// The algorithm is the same one that TF uses right now, but it's
|
||||
// uniform only when maxval - minval is a divisor of the range that bits is
|
||||
// generated from.
|
||||
// TODO(b/72573764): Generate real uniform integer distribution.
|
||||
XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
|
||||
PrimitiveType type, PrimitiveType unsigned_type);
|
||||
|
||||
// The following 2 functions, for converting between one uint64 and two uint32s,
|
||||
// use the contract "lower 32 bits for the first uint32, higher 32 bits for the
|
||||
// second".
|
||||
ThreeFry2x32State Uint64ToUint32s(XlaOp u64);
|
||||
XlaOp Uint32sToUint64(ThreeFry2x32State u32s);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_PRNG_H_
|
||||
|
@ -0,0 +1,38 @@
|
||||
op {
|
||||
graph_op_name: "StatefulUniformFullInt"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "resource"
|
||||
description: <<END
|
||||
The handle of the resource variable that stores the state of the RNG.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "algorithm"
|
||||
description: <<END
|
||||
The RNG algorithm.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "shape"
|
||||
description: <<END
|
||||
The shape of the output tensor.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "output"
|
||||
description: <<END
|
||||
Random values with specified shape.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "dtype"
|
||||
description: <<END
|
||||
The type of the output.
|
||||
END
|
||||
}
|
||||
summary: "Outputs random integers from a uniform distribution."
|
||||
description: <<END
|
||||
The generated values are uniform integers covering the whole range of `dtype`.
|
||||
END
|
||||
}
|
@ -0,0 +1,56 @@
|
||||
op {
|
||||
graph_op_name: "StatefulUniformInt"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "resource"
|
||||
description: <<END
|
||||
The handle of the resource variable that stores the state of the RNG.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "algorithm"
|
||||
description: <<END
|
||||
The RNG algorithm.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "shape"
|
||||
description: <<END
|
||||
The shape of the output tensor.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "minval"
|
||||
description: <<END
|
||||
Minimum value (inclusive, scalar).
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "maxval"
|
||||
description: <<END
|
||||
Maximum value (exclusive, scalar).
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "output"
|
||||
description: <<END
|
||||
Random values with specified shape.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "dtype"
|
||||
description: <<END
|
||||
The type of the output.
|
||||
END
|
||||
}
|
||||
summary: "Outputs random integers from a uniform distribution."
|
||||
description: <<END
|
||||
The generated values are uniform integers in the range `[minval, maxval)`.
|
||||
The lower bound `minval` is included in the range, while the upper bound
|
||||
`maxval` is excluded.
|
||||
|
||||
The random integers are slightly biased unless `maxval - minval` is an exact
|
||||
power of two. The bias is small for values of `maxval - minval` significantly
|
||||
smaller than the range of the output (either `2^32` or `2^64`).
|
||||
END
|
||||
}
|
@ -95,6 +95,31 @@ class Var : public ResourceBase {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(Var);
|
||||
};
|
||||
|
||||
// Does unlock and unref automatically when going out of scope, and also
|
||||
// supports early manual release.
|
||||
class ScopedUnlockUnrefVar {
|
||||
public:
|
||||
explicit ScopedUnlockUnrefVar(Var* var) : var_(var) {
|
||||
if (var_) {
|
||||
var_->mu()->lock();
|
||||
}
|
||||
}
|
||||
void Release() {
|
||||
if (var_) {
|
||||
var_->mu()->unlock();
|
||||
var_->Unref();
|
||||
var_ = nullptr;
|
||||
}
|
||||
}
|
||||
~ScopedUnlockUnrefVar() { Release(); }
|
||||
|
||||
private:
|
||||
Var* var_;
|
||||
|
||||
ScopedUnlockUnrefVar(const ScopedUnlockUnrefVar&) = delete;
|
||||
void operator=(const ScopedUnlockUnrefVar&) = delete;
|
||||
};
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_VAR_H_
|
||||
|
@ -15,8 +15,8 @@ limitations under the License.
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "tensorflow/core/kernels/stateful_random_ops.h"
|
||||
#include "tensorflow/core/kernels/random_op.h"
|
||||
#include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h"
|
||||
#include "tensorflow/core/kernels/training_op_helpers.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -25,7 +25,7 @@ template <typename Distribution>
|
||||
struct UpdateVariableAndFill_Philox<CPUDevice, Distribution> {
|
||||
void operator()(OpKernelContext* ctx, const CPUDevice& device,
|
||||
int64 output_size, int64 alg_tag_skip,
|
||||
ScopedUnlockUnref* state_var_guard, Tensor* state_tensor,
|
||||
ScopedUnlockUnrefVar* state_var_guard, Tensor* state_tensor,
|
||||
typename Distribution::ResultElementType* output_data) {
|
||||
auto state_tensor_flat = state_tensor->flat<StateElementType>();
|
||||
auto state_data = state_tensor_flat.data();
|
||||
@ -47,11 +47,11 @@ Status UpdateVariableAndFill(
|
||||
Var* var = nullptr;
|
||||
TF_RETURN_IF_ERROR(
|
||||
LookupResource(ctx, HandleFromInput(ctx, state_input_idx), &var));
|
||||
// Use `ScopedUnlockUnref` here instead of `mutex_lock` and `ScopedUnref`
|
||||
// Use `ScopedUnlockUnrefVar` here instead of `mutex_lock` and `ScopedUnref`
|
||||
// because the former supports early releasing which is needed by
|
||||
// `UpdateVariableAndFill_Philox<CPU>` to avoid holding the lock while
|
||||
// filling.
|
||||
ScopedUnlockUnref state_var_guard(var);
|
||||
ScopedUnlockUnrefVar state_var_guard(var);
|
||||
Tensor* var_tensor = var->tensor();
|
||||
if (var_tensor->dtype() != STATE_ELEMENT_DTYPE) {
|
||||
return errors::InvalidArgument("dtype of RNG state variable must be ",
|
||||
@ -80,7 +80,7 @@ Status UpdateVariableAndFill(
|
||||
"PhiloxRandom::ResultElementType must be uint32");
|
||||
if (var_tensor_flat.size() < alg_tag_skip + PHILOX_MIN_STATE_SIZE) {
|
||||
return errors::InvalidArgument(
|
||||
"For Philox algorithm, the size of state"
|
||||
"For the Philox algorithm, the size of state"
|
||||
" must be at least ",
|
||||
alg_tag_skip + PHILOX_MIN_STATE_SIZE, "; got ",
|
||||
var_tensor_flat.size());
|
||||
@ -132,6 +132,11 @@ class StatefulRandomOpV2 : public OpKernel {
|
||||
OP_REQUIRES(ctx, alg_tensor.dims() == 0,
|
||||
errors::InvalidArgument("algorithm must be of shape [], not ",
|
||||
alg_tensor.shape().DebugString()));
|
||||
OP_REQUIRES(
|
||||
ctx, alg_tensor.dtype() == ALGORITHM_DTYPE,
|
||||
errors::InvalidArgument("algorithm's dtype must be ",
|
||||
DataTypeString(ALGORITHM_DTYPE), ", not ",
|
||||
DataTypeString(alg_tensor.dtype())));
|
||||
auto alg = alg_tensor.flat<Algorithm>()(0);
|
||||
ComputeImpl<Device, Distribution>(ctx, 0, 2, false, alg);
|
||||
}
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_H_
|
||||
|
||||
#include "tensorflow/core/framework/resource_var.h"
|
||||
// #include "tensorflow/core/framework/resource_var.h"
|
||||
#include "tensorflow/core/lib/random/philox_random.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -28,7 +28,9 @@ using StateElementType = int64;
|
||||
static constexpr DataType STATE_ELEMENT_DTYPE = DT_INT64;
|
||||
|
||||
using Algorithm = StateElementType;
|
||||
static constexpr DataType ALGORITHM_DTYPE = STATE_ELEMENT_DTYPE;
|
||||
static constexpr Algorithm RNG_ALG_PHILOX = 1;
|
||||
static constexpr Algorithm RNG_ALG_THREEFRY = 2;
|
||||
|
||||
using random::PhiloxRandom;
|
||||
|
||||
@ -36,109 +38,7 @@ static constexpr int64 PHILOX_MIN_STATE_SIZE =
|
||||
(PhiloxRandom::ResultType::kElementCount +
|
||||
PhiloxRandom::Key::kElementCount) /
|
||||
2;
|
||||
|
||||
// The following 5 functions are made templates to avoid duplicate symbols when
|
||||
// linking.
|
||||
|
||||
// The following two functions use the contract "lower 32 bits for the first
|
||||
// uint32, higher 32 bits for the second". Note that this is endian-neutral,
|
||||
// unlike a direct memory copy `memcpy(output, &input, 8)`.
|
||||
template <typename INT64>
|
||||
PHILOX_DEVICE_FUNC void Int64ToUint32s(INT64 input, uint32* output1,
|
||||
uint32* output2) {
|
||||
auto u64 = static_cast<uint64>(input);
|
||||
*output1 = static_cast<uint32>(u64);
|
||||
*output2 = static_cast<uint32>(u64 >> 32);
|
||||
}
|
||||
|
||||
template <typename UINT32>
|
||||
PHILOX_DEVICE_FUNC int64 Uint32sToInt64(UINT32 input1, UINT32 input2) {
|
||||
auto u64_1 = static_cast<uint64>(input1);
|
||||
auto u64_2 = static_cast<uint64>(input2);
|
||||
return static_cast<int64>(u64_1 | (u64_2 << 32));
|
||||
}
|
||||
|
||||
template <typename STATE_ELEMENT_TYPE>
|
||||
PHILOX_DEVICE_FUNC PhiloxRandom
|
||||
GetPhiloxRandomFromMem(STATE_ELEMENT_TYPE const* ptr) {
|
||||
PhiloxRandom::ResultType counter;
|
||||
PhiloxRandom::Key key;
|
||||
Int64ToUint32s(ptr[0], &counter[0], &counter[1]);
|
||||
Int64ToUint32s(ptr[1], &counter[2], &counter[3]);
|
||||
Int64ToUint32s(ptr[2], &key[0], &key[1]);
|
||||
return PhiloxRandom(counter, key);
|
||||
}
|
||||
|
||||
template <typename PHILOX_RANDOM>
|
||||
PHILOX_DEVICE_FUNC void WritePhiloxRandomToMem(PHILOX_RANDOM const& philox,
|
||||
StateElementType* ptr) {
|
||||
PhiloxRandom::ResultType const& counter = philox.counter();
|
||||
PhiloxRandom::Key const& key = philox.key();
|
||||
ptr[0] = Uint32sToInt64(counter[0], counter[1]);
|
||||
ptr[1] = Uint32sToInt64(counter[2], counter[3]);
|
||||
ptr[2] = Uint32sToInt64(key[0], key[1]);
|
||||
}
|
||||
|
||||
template <typename PHILOX_RANDOM>
|
||||
PHILOX_DEVICE_FUNC void UpdateMemWithPhiloxRandom(PHILOX_RANDOM const& philox,
|
||||
int64 output_size,
|
||||
StateElementType* ptr) {
|
||||
auto new_philox = philox;
|
||||
// Multiplier 256 is the same as in `FillPhiloxRandomTask`; do not change
|
||||
// it just here.
|
||||
auto delta = output_size * 256;
|
||||
new_philox.Skip(delta); // do the actual increasing
|
||||
WritePhiloxRandomToMem(new_philox, ptr);
|
||||
}
|
||||
|
||||
// Does unlock and unref automatically when going out of scope, and also
|
||||
// supports early manual release.
|
||||
class ScopedUnlockUnref {
|
||||
public:
|
||||
explicit ScopedUnlockUnref(Var* var) : var_(var) {
|
||||
if (var_) {
|
||||
var_->mu()->lock();
|
||||
}
|
||||
}
|
||||
void Release() {
|
||||
if (var_) {
|
||||
var_->mu()->unlock();
|
||||
var_->Unref();
|
||||
var_ = nullptr;
|
||||
}
|
||||
}
|
||||
~ScopedUnlockUnref() { Release(); }
|
||||
|
||||
private:
|
||||
Var* var_;
|
||||
|
||||
ScopedUnlockUnref(const ScopedUnlockUnref&) = delete;
|
||||
void operator=(const ScopedUnlockUnref&) = delete;
|
||||
};
|
||||
|
||||
// A per-device helper function that does the actual work for
|
||||
// `UpdateVariableAndFill`.
|
||||
// Reason to use functor: C++ doesn't allow function-template partial
|
||||
// specialization.
|
||||
template <typename Device, typename Distribution>
|
||||
struct UpdateVariableAndFill_Philox;
|
||||
|
||||
using CPUDevice = Eigen::ThreadPoolDevice;
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
|
||||
using GPUDevice = Eigen::GpuDevice;
|
||||
|
||||
// Declares the partially GPU-specialized functor struct.
|
||||
template <typename Distribution>
|
||||
struct UpdateVariableAndFill_Philox<GPUDevice, Distribution> {
|
||||
void operator()(OpKernelContext* ctx, const GPUDevice& device,
|
||||
int64 output_size, int64 alg_tag_skip,
|
||||
ScopedUnlockUnref* not_used, Tensor* state_tensor,
|
||||
typename Distribution::ResultElementType* output_data);
|
||||
};
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
static constexpr int64 THREEFRY_MIN_STATE_SIZE = 2;
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
|
104
tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h
Normal file
104
tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h
Normal file
@ -0,0 +1,104 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_CPU_GPU_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_CPU_GPU_H_
|
||||
|
||||
#include "tensorflow/core/framework/resource_var.h"
|
||||
#include "tensorflow/core/kernels/stateful_random_ops.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// The following 5 functions are made templates to avoid duplicate symbols when
|
||||
// linking.
|
||||
|
||||
// The following 2 functions use the contract "lower 32 bits for the first
|
||||
// uint32, higher 32 bits for the second". Note that this is endian-neutral,
|
||||
// unlike a direct memory copy `memcpy(output, &input, 8)`.
|
||||
template <typename INT64>
|
||||
PHILOX_DEVICE_FUNC void Int64ToUint32s(INT64 input, uint32* output1,
|
||||
uint32* output2) {
|
||||
auto u64 = static_cast<uint64>(input);
|
||||
*output1 = static_cast<uint32>(u64);
|
||||
*output2 = static_cast<uint32>(u64 >> 32);
|
||||
}
|
||||
|
||||
template <typename UINT32>
|
||||
PHILOX_DEVICE_FUNC int64 Uint32sToInt64(UINT32 input1, UINT32 input2) {
|
||||
auto u64_1 = static_cast<uint64>(input1);
|
||||
auto u64_2 = static_cast<uint64>(input2);
|
||||
return static_cast<int64>(u64_1 | (u64_2 << 32));
|
||||
}
|
||||
|
||||
template <typename STATE_ELEMENT_TYPE>
|
||||
PHILOX_DEVICE_FUNC PhiloxRandom
|
||||
GetPhiloxRandomFromMem(STATE_ELEMENT_TYPE const* ptr) {
|
||||
PhiloxRandom::ResultType counter;
|
||||
PhiloxRandom::Key key;
|
||||
Int64ToUint32s(ptr[0], &counter[0], &counter[1]);
|
||||
Int64ToUint32s(ptr[1], &counter[2], &counter[3]);
|
||||
Int64ToUint32s(ptr[2], &key[0], &key[1]);
|
||||
return PhiloxRandom(counter, key);
|
||||
}
|
||||
|
||||
template <typename PHILOX_RANDOM>
|
||||
PHILOX_DEVICE_FUNC void WritePhiloxRandomToMem(PHILOX_RANDOM const& philox,
|
||||
StateElementType* ptr) {
|
||||
PhiloxRandom::ResultType const& counter = philox.counter();
|
||||
PhiloxRandom::Key const& key = philox.key();
|
||||
ptr[0] = Uint32sToInt64(counter[0], counter[1]);
|
||||
ptr[1] = Uint32sToInt64(counter[2], counter[3]);
|
||||
ptr[2] = Uint32sToInt64(key[0], key[1]);
|
||||
}
|
||||
|
||||
template <typename PHILOX_RANDOM>
|
||||
PHILOX_DEVICE_FUNC void UpdateMemWithPhiloxRandom(PHILOX_RANDOM const& philox,
|
||||
int64 output_size,
|
||||
StateElementType* ptr) {
|
||||
auto new_philox = philox;
|
||||
// Multiplier 256 is the same as in `FillPhiloxRandomTask`; do not change
|
||||
// it just here.
|
||||
auto delta = output_size * 256;
|
||||
new_philox.Skip(delta); // do the actual increasing
|
||||
WritePhiloxRandomToMem(new_philox, ptr);
|
||||
}
|
||||
|
||||
// A per-device helper function that does the actual work for
|
||||
// `UpdateVariableAndFill`.
|
||||
// Reason to use functor: C++ doesn't allow function-template partial
|
||||
// specialization.
|
||||
template <typename Device, typename Distribution>
|
||||
struct UpdateVariableAndFill_Philox;
|
||||
|
||||
using CPUDevice = Eigen::ThreadPoolDevice;
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
|
||||
using GPUDevice = Eigen::GpuDevice;
|
||||
|
||||
// Declares the partially GPU-specialized functor struct.
|
||||
template <typename Distribution>
|
||||
struct UpdateVariableAndFill_Philox<GPUDevice, Distribution> {
|
||||
void operator()(OpKernelContext* ctx, const GPUDevice& device,
|
||||
int64 output_size, int64 alg_tag_skip,
|
||||
ScopedUnlockUnrefVar* not_used, Tensor* state_tensor,
|
||||
typename Distribution::ResultElementType* output_data);
|
||||
};
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_CPU_GPU_H_
|
@ -18,7 +18,7 @@ limitations under the License.
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "tensorflow/core/kernels/random_op_gpu.h"
|
||||
#include "tensorflow/core/kernels/stateful_random_ops.h"
|
||||
#include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h"
|
||||
#include "tensorflow/core/util/cuda_launch_config.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -54,7 +54,7 @@ __global__ void FillKernel(
|
||||
template <typename Distribution>
|
||||
void UpdateVariableAndFill_Philox<GPUDevice, Distribution>::operator()(
|
||||
OpKernelContext* ctx, const GPUDevice& d, int64 output_size,
|
||||
int64 alg_tag_skip, ScopedUnlockUnref* not_used, Tensor* state_tensor,
|
||||
int64 alg_tag_skip, ScopedUnlockUnrefVar* not_used, Tensor* state_tensor,
|
||||
typename Distribution::ResultElementType* output_data) {
|
||||
OP_REQUIRES(
|
||||
ctx, alg_tag_skip == 0,
|
||||
|
@ -20,11 +20,9 @@ namespace tensorflow {
|
||||
|
||||
Status StatefulRandomShape(shape_inference::InferenceContext* c) {
|
||||
using shape_inference::ShapeHandle;
|
||||
|
||||
// Check algorithm shape
|
||||
ShapeHandle unused;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
|
||||
|
||||
// Set output shape
|
||||
ShapeHandle out;
|
||||
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &out));
|
||||
@ -32,16 +30,45 @@ Status StatefulRandomShape(shape_inference::InferenceContext* c) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
REGISTER_OP("StatefulStandardNormalV2")
|
||||
#define REGISTER_STATEFUL_OP(name, default_dtype) \
|
||||
REGISTER_OP(name) \
|
||||
.Input("resource: resource") \
|
||||
.Input("algorithm: int64") \
|
||||
.Input("shape: shape_dtype") \
|
||||
.Output("output: dtype") \
|
||||
.Attr("dtype : type = " #default_dtype) \
|
||||
.Attr("shape_dtype : type = DT_INT64") \
|
||||
.SetShapeFn(StatefulRandomShape);
|
||||
|
||||
REGISTER_STATEFUL_OP("StatefulUniformFullInt", DT_UINT64);
|
||||
REGISTER_STATEFUL_OP("StatefulStandardNormalV2", DT_FLOAT);
|
||||
|
||||
REGISTER_OP("StatefulUniformInt")
|
||||
.Input("resource: resource")
|
||||
.Input("algorithm: int64")
|
||||
.Input("shape: shape_dtype")
|
||||
.Input("minval: dtype")
|
||||
.Input("maxval: dtype")
|
||||
.Output("output: dtype")
|
||||
.Attr("dtype : type = DT_FLOAT")
|
||||
.Attr("dtype : type = DT_INT64")
|
||||
.Attr("shape_dtype : type = DT_INT64")
|
||||
.SetShapeFn(StatefulRandomShape);
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
using shape_inference::ShapeHandle;
|
||||
// Check inputs
|
||||
ShapeHandle unused;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
|
||||
// Set output
|
||||
ShapeHandle out;
|
||||
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &out));
|
||||
c->set_output(0, out);
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
// Register the old 'StatefulStandardNormal' op. This op is a short-lived
// version where the 'resource' variable also contains the algorithm tag.
// It is deprecated in favor of 'StatefulStandardNormalV2'.
|
||||
REGISTER_OP("StatefulStandardNormal")
|
||||
.Input("resource: resource")
|
||||
.Input("shape: shape_dtype")
|
||||
|
@ -53,9 +53,14 @@ SEED_SIZE = 16 # in units of SEED_TYPE
|
||||
STATE_TYPE = SEED_TYPE
|
||||
ALGORITHM_TYPE = STATE_TYPE
|
||||
RNG_ALG_PHILOX = 1
|
||||
RNG_ALG_THREEFRY = 2
|
||||
DEFAULT_ALGORITHM = RNG_ALG_PHILOX
|
||||
|
||||
|
||||
PHILOX_STATE_SIZE = 3
|
||||
THREEFRY_STATE_SIZE = 2
|
||||
|
||||
|
||||
def non_deterministic_seed():
|
||||
"""Makes a non-deterministic seed.
|
||||
|
||||
@ -75,7 +80,40 @@ def _uint_to_int(n):
|
||||
return n
|
||||
|
||||
|
||||
PHILOX_STATE_SIZE = 3
|
||||
def _make_1d_state(state_size, seed):
  """Makes a 1-D RNG state.

  Args:
    state_size: an integer.
    seed: an integer or 1-D tensor.

  Returns:
    a 1-D tensor of shape [state_size] and dtype STATE_TYPE.

  Raises:
    ValueError: if `seed`, once converted to an array, is not 1-D.
  """
  int_types = (int,) if sys.version_info >= (3, 0) else (int, long)
  if isinstance(seed, int_types):
    # A Python integer has unbounded precision: peel off `state_size` chunks
    # of SEED_TYPE_BITS bits each, least-significant chunk first.
    remaining = seed
    chunks = []
    for _ in range(state_size):
      chunks.append(remaining & SEED_BIT_MASK)
      remaining >>= SEED_TYPE_BITS
    seed = chunks
  # Map every element through _uint_to_int first, to avoid an overflow error
  # from np.asarray.
  seed = [_uint_to_int(word) for word in seed]
  seed = np.asarray(seed, dtype=STATE_TYPE)
  if len(seed.shape) != 1:
    # The shape tuple is wrapped in a 1-tuple: passing it directly to `%`
    # would treat its elements as separate format arguments and raise
    # TypeError instead of the intended ValueError.
    raise ValueError(
        "seed should only have one dimension; got shape: %s" % (seed.shape,))
  seed = seed[0:state_size]
  # Pad with zeros on the right if the seed supplied too few elements.
  seed_size = seed.shape[0]
  if seed_size < state_size:
    seed = np.pad(
        seed, [(0, state_size - seed_size)],
        mode="constant",
        constant_values=0)
  assert seed.shape == (state_size,), "Wrong seed.shape: %s" % (seed.shape,)
  return seed
|
||||
|
||||
|
||||
def _make_philox_state(seed):
|
||||
@ -87,35 +125,26 @@ def _make_philox_state(seed):
|
||||
Returns:
|
||||
a 1-D tensor.
|
||||
"""
|
||||
int_types = (int,) if sys.version_info >= (3, 0) else (int, long)
|
||||
if isinstance(seed, int_types):
|
||||
# chop the Python integer (infinite precision) into chunks of SEED_TYPE
|
||||
ls = []
|
||||
for _ in range(PHILOX_STATE_SIZE):
|
||||
ls.append(seed & SEED_BIT_MASK)
|
||||
seed >>= SEED_TYPE_BITS
|
||||
seed = ls
|
||||
# to avoid overflow error from np.asarray
|
||||
seed = list(map(_uint_to_int, seed))
|
||||
seed = np.asarray(seed, dtype=STATE_TYPE)
|
||||
if len(seed.shape) != 1:
|
||||
raise ValueError(
|
||||
"seed should only have one dimension; got shape: %s" % seed.shape)
|
||||
seed = seed[0:PHILOX_STATE_SIZE]
|
||||
# Padding with zeros on the right if too short
|
||||
seed_size = seed.shape[0]
|
||||
if seed_size < PHILOX_STATE_SIZE:
|
||||
seed = np.pad(
|
||||
seed, [(0, PHILOX_STATE_SIZE - seed_size)],
|
||||
mode="constant",
|
||||
constant_values=0)
|
||||
assert seed.shape == (PHILOX_STATE_SIZE,), "Wrong seed.shape: %s" % seed.shape
|
||||
return seed
|
||||
return _make_1d_state(PHILOX_STATE_SIZE, seed)
|
||||
|
||||
|
||||
def _make_threefry_state(seed):
  """Makes an RNG state for the ThreeFry algorithm.

  Delegates to `_make_1d_state` with a state size of THREEFRY_STATE_SIZE.

  Args:
    seed: an integer or 1-D tensor.

  Returns:
    a 1-D tensor.
  """
  state = _make_1d_state(THREEFRY_STATE_SIZE, seed)
  return state
|
||||
|
||||
|
||||
def _make_state_from_seed(seed, algorithm):
  """Builds the RNG state for `seed` under the given algorithm.

  Args:
    seed: the seed, forwarded to the algorithm-specific state builder.
    algorithm: the algorithm id; RNG_ALG_PHILOX or RNG_ALG_THREEFRY.

  Returns:
    The state built by `_make_philox_state` or `_make_threefry_state`.

  Raises:
    ValueError: if `algorithm` is neither of the supported ids.
  """
  if algorithm == RNG_ALG_PHILOX:
    return _make_philox_state(seed)
  if algorithm == RNG_ALG_THREEFRY:
    return _make_threefry_state(seed)
  raise ValueError("Unsupported algorithm id: %s" % algorithm)
|
||||
|
||||
@ -168,26 +197,19 @@ class Generator(tracking.AutoTrackable):
|
||||
algorithm = DEFAULT_ALGORITHM
|
||||
state = create_rng_state(seed, algorithm)
|
||||
self._state_var = variables.Variable(state, dtype=STATE_TYPE)
|
||||
self._alg_var = variables.Variable(initial_value=algorithm,
|
||||
dtype=ALGORITHM_TYPE)
|
||||
self._alg_var = algorithm
|
||||
else:
|
||||
assert seed is None
|
||||
self._state_var = variables.Variable(copy_from.state, dtype=STATE_TYPE)
|
||||
self._alg_var = variables.Variable(initial_value=copy_from.algorithm,
|
||||
dtype=ALGORITHM_TYPE)
|
||||
self._alg_var = copy_from.algorithm
|
||||
|
||||
def reset(self, seed):
  """Resets the generator.

  Builds a new state from `seed` for the generator's current algorithm and
  assigns it to the state variable. The algorithm itself is left unchanged.

  Args:
    seed: the seed to reset the RNG to.
  """
  # Build the state exactly once, directly from `self.algorithm`.
  state = create_rng_state(seed, self.algorithm)
  self._state_var.assign(state)
|
||||
|
||||
@property
|
||||
@ -200,21 +222,59 @@ class Generator(tracking.AutoTrackable):
|
||||
|
||||
# The following functions return a tensor and as a side effect update
|
||||
# self._state_var.
|
||||
def standard_normal(self, shape, dtype=dtypes.float32):
|
||||
output = gen_stateful_random_ops.stateful_standard_normal_v2(
|
||||
self.state.handle, self.algorithm, shape, dtype)
|
||||
return output
|
||||
|
||||
def normal(self, shape, mean=0.0, stddev=1.0, dtype=dtypes.float32,
           name=None):
  """Returns samples from the standard-normal op, scaled and shifted.

  The result is `rnd * stddev + mean`, where `rnd` is the output of the
  `StatefulStandardNormalV2` op for this generator's state and algorithm.

  Args:
    shape: the shape of the output; converted with `_shape_tensor`.
    mean: the value added to the scaled samples; converted to `dtype`.
    stddev: the factor the samples are multiplied by; converted to `dtype`.
    dtype: the type of the output.
    name: (optional) the name of the node.

  Returns:
    A tensor of the required shape.
  """
  with ops.name_scope(name, "stateful_normal", [shape, mean, stddev]) as name:
    shape = _shape_tensor(shape)
    mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
    stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
    # Sample exactly once: a second sampling call would update the state
    # variable a second time and throw the first batch away.
    rnd = gen_stateful_random_ops.stateful_standard_normal_v2(
        self.state.handle, self.algorithm, shape, dtype=dtype)
    return math_ops.add(rnd * stddev, mean, name=name)
|
||||
|
||||
# TODO(wangpeng): implement other distributions (`uniform`,
|
||||
def uniform(self, shape, minval=0, maxval=None,
            dtype=dtypes.float32, name=None):
  """Returns integers drawn by the `StatefulUniformInt` op.

  Only integer dtypes are supported at the moment; any other dtype raises
  after the bounds have been converted to tensors.

  Args:
    shape: the shape of the output; converted with `_shape_tensor`.
    minval: the lower bound, converted to `dtype`.
    maxval: the upper bound, converted to `dtype`. Must be given for integer
      dtypes; defaults to 1 otherwise.
    dtype: the type of the output.
    name: (optional) the name of the node.

  Returns:
    A tensor of random integers of the required shape.

  Raises:
    ValueError: if `dtype` is an integer type and `maxval` is None, or if
      `dtype` is not an integer type.
  """
  dtype = dtypes.as_dtype(dtype)
  wants_integers = dtype.is_integer
  if maxval is None:
    if wants_integers:
      raise ValueError("Must specify maxval for integer dtype %r" % dtype)
    maxval = 1
  with ops.name_scope(name, "stateful_uniform",
                      [shape, minval, maxval]) as name:
    shape = _shape_tensor(shape)
    minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
    maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
    if not dtype.is_integer:
      # TODO(wangpeng): implement uniform for floats
      raise ValueError("uniform for floats not implemented yet")
    return gen_stateful_random_ops.stateful_uniform_int(
        self.state.handle, self.algorithm, shape=shape,
        minval=minval, maxval=maxval, name=name)
|
||||
|
||||
def uniform_full_int(self, shape, dtype=dtypes.uint64, name=None):
  """Samples uniformly over the whole range of an integer type.

  `uniform` draws from the half-open interval [minval, maxval), and since
  `maxval` itself must be representable in `dtype`, that interval can never
  span the type's full range. This method fills that gap.

  Args:
    shape: the shape of the output.
    dtype: (optional) the integer type, default to uint64.
    name: (optional) the name of the node.

  Returns:
    A tensor of random numbers of the required shape.
  """
  int_dtype = dtypes.as_dtype(dtype)
  with ops.name_scope(name, "stateful_uniform_full_int", [shape]) as scope:
    dims = _shape_tensor(shape)
    samples = gen_stateful_random_ops.stateful_uniform_full_int(
        self.state.handle, self.algorithm, shape=dims,
        dtype=int_dtype, name=scope)
    return samples
|
||||
|
||||
# TODO(wangpeng): implement other distributions (
|
||||
# `truncated_normal`, etc.)
|
||||
# TODO(wangpeng): implement `make_seeds`
|
||||
# TODO(wangpeng): implement `make_generators`
|
||||
|
@ -23,6 +23,7 @@ import numpy as np
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import gen_random_ops
|
||||
@ -138,7 +139,7 @@ class StatefulRandomOpsTest(test.TestCase):
|
||||
|
||||
def new():
|
||||
with ops.device("/device:CPU:0"):
|
||||
return random.get_global_generator().standard_normal(shape, dtype=dtype)
|
||||
return random.get_global_generator().normal(shape, dtype=dtype)
|
||||
|
||||
for _ in range(100):
|
||||
self.assertAllEqual(old(), new())
|
||||
@ -164,7 +165,7 @@ class StatefulRandomOpsTest(test.TestCase):
|
||||
|
||||
def new():
|
||||
with ops.device(test_util.gpu_device_name()):
|
||||
return random.get_global_generator().standard_normal(shape, dtype=dtype)
|
||||
return random.get_global_generator().normal(shape, dtype=dtype)
|
||||
|
||||
for _ in range(100):
|
||||
self.assertAllEqual(old(), new())
|
||||
@ -203,7 +204,7 @@ class StatefulRandomOpsTest(test.TestCase):
|
||||
|
||||
random.reset_global_generator(50)
|
||||
with self.assertRaisesWithPredicateMatch(
|
||||
AssertionError, "variable.*deleted"):
|
||||
errors_impl.NotFoundError, "Resource .+ does not exist"):
|
||||
a = f()
|
||||
random.reset_global_generator(50)
|
||||
b = f()
|
||||
|
Loading…
Reference in New Issue
Block a user