A continuing partial implementation of RFC "Random numbers in TensorFlow 2.0"
(https://github.com/tensorflow/community/pull/38).

In this change:
- XLA kernel (TF-XLA bridge) for op `StatefulStandardNormalV2` with the ThreeFry algorithm.

To be done:
- ops for other distributions;
- other RNG algorithms;
- batch seeds;
- initializers (`RandomUniform`, etc.);
- changing `non_deterministic_seed`'s implementation to be an op, and other changes to make most code in stateful_random_ops.py tf.function-compatible.

PiperOrigin-RevId: 236699104
Commit: 12606ff846
Parent: 55240a22f1
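For context, a minimal usage sketch of what this change enables, mirroring the new test file below (the explicit device string and the internal import path are taken from the tests added in this commit; they are not a public API):

    from tensorflow.python.framework import ops
    from tensorflow.python.ops import stateful_random_ops as random

    with ops.device("/device:XLA_CPU:0"):  # any XLA device: TPU, XLA_GPU, XLA_CPU
      gen = random.Generator(seed=0, algorithm=random.RNG_ALG_THREEFRY)
      gen.normal(shape=(3,))  # compiled via the new StatefulStandardNormalV2 XLA kernel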
tensorflow/compiler/tests/BUILD
@@ -835,6 +835,20 @@ tf_xla_py_test(
     ],
 )
 
+tf_xla_py_test(
+    name = "stateful_random_ops_test",
+    size = "small",
+    srcs = ["stateful_random_ops_test.py"],
+    tags = ["optonly"],
+    deps = [
+        ":xla_test",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:platform_test",
+        "//tensorflow/python:standard_ops",
+        "//tensorflow/python:stateful_random_ops",
+    ],
+)
+
 tf_xla_py_test(
     name = "stateless_random_ops_test",
     size = "small",
tensorflow/compiler/tests/stateful_random_ops_test.py (new file, 282 lines)
@@ -0,0 +1,282 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for stateful random-number generation ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.python.client import device_lib
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gen_stateful_random_ops
from tensorflow.python.ops import stateful_random_ops as random
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


def xla_device_name():
  devices = device_lib.list_local_devices()
  def find_type(device_type):
    for d in devices:
      if d.device_type == device_type:
        return d.name
    return None
  name = find_type("TPU") or find_type("XLA_GPU") or find_type("XLA_CPU")
  if name is None:
    raise ValueError(
        "Can't find any XLA device. Available devices:\n%s" % devices)
  return str(name)


class StatefulRandomOpsTest(xla_test.XLATestCase):
  """Test cases for stateful random-number generator operators."""

  @test_util.run_v2_only
  def testSimple(self):
    """A simple test."""
    with ops.device(xla_device_name()):
      gen = random.Generator(seed=0, algorithm=random.RNG_ALG_THREEFRY)
      gen.normal(shape=(3,))
      gen.uniform(shape=(3,), minval=0, maxval=10, dtype=dtypes.uint32)
      gen.uniform_full_int(shape=(3,))

  @test_util.run_v2_only
  def testDefun(self):
    """Test for defun."""
    with ops.device(xla_device_name()):
      gen = random.Generator(seed=0, algorithm=random.RNG_ALG_THREEFRY)
      @def_function.function
      def f():
        x = gen.normal(shape=(3,))
        y = gen.uniform(shape=(3,), minval=0, maxval=10, dtype=dtypes.uint32)
        z = gen.uniform_full_int(shape=(3,))
        return (x, y, z)
      f()

  @test_util.run_v2_only
  def testThreefry2x32(self):
    """Tests ThreeFry2x32 conforms to known results."""
    # Based on
    # https://github.com/google/jax/blob/8565a3486adf16beb388b2364c9cd930d7a0d92d/tests/random_test.py#L65-L85
    # which is in turn based on
    # https://github.com/DEShawResearch/Random123-Boost/blob/65e3d874b67aa7b3e02d5ad8306462f52d2079c0/libs/random/test/test_threefry.cpp#L30-L32

    def uint32s_to_uint64(a, b):
      return b << 32 | a

    def verify(counter1, counter2, key1, key2, expect1, expect2):
      counter = uint32s_to_uint64(counter1, counter2)
      key = uint32s_to_uint64(key1, key2)
      random.get_global_generator().reset([counter, key])
      got = random.get_global_generator().uniform_full_int(
          shape=(2,), dtype=dtypes.uint32)
      expect = [expect1, expect2]
      self.assertAllEqual(expect, got)
      random.get_global_generator().reset([counter, key])
      got = random.get_global_generator().uniform_full_int(
          shape=(), dtype=dtypes.uint64)
      self.assertAllEqual(uint32s_to_uint64(*expect), got)

    with ops.device(xla_device_name()):
      random.reset_global_generator(seed=0, algorithm=random.RNG_ALG_THREEFRY)
      verify(0x00000000, 0x00000000, 0x00000000, 0x00000000,
             0x6b200159, 0x99ba4efe)
      verify(0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff,
             0x1cb996fc, 0xbb002be7)
      verify(0x243f6a88, 0x85a308d3, 0x13198a2e, 0x03707344,
             0xc4923a9c, 0x483df7a0)

  @test_util.run_v2_only
  def testNewState(self):
    """Tests that the new state is correct."""
    with ops.device(xla_device_name()):
      counter = 57
      key = 0x1234
      size = 46
      seed = [counter, key]
      gen = random.Generator(
          seed=seed, algorithm=random.RNG_ALG_THREEFRY)
      gen.uniform_full_int(shape=(size,), dtype=dtypes.uint32)
      self.assertAllEqual([counter+(size+1)//2, key], gen.state.read_value())
      gen.reset(seed=seed)
      gen.uniform_full_int(shape=(size,), dtype=dtypes.uint64)
      self.assertAllEqual([counter+size, key], gen.state.read_value())

  def _testRngIsNotConstant(self, rng, dtype):
    # Tests that 'rng' does not always return the same value.
    # The random-number generator, if working correctly, should produce the
    # same output multiple times with low probability.
    x = rng(dtype).numpy()
    y = rng(dtype).numpy()
    self.assertFalse(np.array_equal(x, y))

  @test_util.run_v2_only
  def testUniformIsNotConstant(self):
    with ops.device(xla_device_name()):
      gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)
      def rng(dtype):
        maxval = dtype.max
        # Workaround for b/125364959
        if dtype == dtypes.uint64:
          maxval = 10000000
        return gen.uniform(shape=[2], dtype=dtype, maxval=maxval)

      for dtype in {dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64}:
        self._testRngIsNotConstant(rng, dtype)

  @test_util.run_v2_only
  def testNormalIsNotConstant(self):
    with ops.device(xla_device_name()):
      gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)
      def rng(dtype):
        return gen.normal(shape=[2], dtype=dtype)

      for dtype in {dtypes.float32}:
        self._testRngIsNotConstant(rng, dtype)

  @test_util.run_v2_only
  def testUniformIntIsInRange(self):
    minval = 2
    maxval = 33
    size = 1000
    with ops.device(xla_device_name()):
      gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)
      for dtype in {dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64}:
        x = gen.uniform(
            shape=[size], dtype=dtype, minval=minval, maxval=maxval).numpy()
        self.assertTrue(np.all(x >= minval))
        self.assertTrue(np.all(x < maxval))

  @test_util.run_v2_only
  def testNormalIsFinite(self):
    with ops.device(xla_device_name()):
      gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)
      for dtype in {dtypes.float32}:
        x = gen.normal(shape=[10000], dtype=dtype).numpy()
        self.assertTrue(np.all(np.isfinite(x)))

  def _chi_squared(self, x, bins):
    """Pearson's Chi-squared test."""
    x = np.ravel(x)
    n = len(x)
    histogram, _ = np.histogram(x, bins=bins, range=(0, 1))
    expected = n / float(bins)
    return np.sum(np.square(histogram - expected) / expected)

  @test_util.run_v2_only
  def testDistributionOfUniform(self):
    """Use Pearson's Chi-squared test to test for uniformity."""
    with ops.device(xla_device_name()):
      n = 1000
      seed = 12
      for dtype in {dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64}:
        gen = random.Generator(seed=seed, algorithm=random.RNG_ALG_THREEFRY)
        maxval = 1
        if dtype.is_integer:
          maxval = 100
        x = gen.uniform(shape=[n], maxval=maxval, dtype=dtype).numpy()
        if maxval > 1:
          # Normalize x to the range [0, 1).
          x = x.astype(float) / maxval
        # Tests that the values are distributed amongst 10 bins with equal
        # probability. 16.92 is the Chi^2 value for 9 degrees of freedom with
        # p=0.05. This test is probabilistic and would be flaky if the random
        # seed were not fixed.
        val = self._chi_squared(x, 10)
        self.assertLess(val, 16.92)

  def _normal_cdf(self, x):
    """Cumulative distribution function for a standard normal distribution."""
    return 0.5 + 0.5 * np.vectorize(math.erf)(x / math.sqrt(2))

  def _anderson_darling(self, x):
    """Anderson-Darling test for a standard normal distribution."""
    x = np.sort(np.ravel(x))
    n = len(x)
    i = np.linspace(1, n, n)
    z = np.sum((2 * i - 1) * np.log(self._normal_cdf(x)) +
               (2 * (n - i) + 1) * np.log(1 - self._normal_cdf(x)))
    return -n - z / n

  @test_util.run_v2_only
  def testDistributionOfNormal(self):
    """Use Anderson-Darling test to test distribution appears normal."""
    with ops.device(xla_device_name()):
      n = 1000
      for dtype in {dtypes.float32}:
        gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)
        x = gen.normal(shape=[n], dtype=dtype).numpy()
        # The constant 2.492 is the 5% critical value for the Anderson-Darling
        # test where the mean and variance are known. This test is
        # probabilistic, so to avoid flakiness the seed is fixed.
        self.assertLess(self._anderson_darling(x.astype(float)), 2.492)

  @test_util.run_v2_only
  def testErrors(self):
    """Tests that proper errors are raised."""
    shape = [2, 3]
    with ops.device(xla_device_name()):
      gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)
      with self.assertRaisesWithPredicateMatch(
          errors_impl.InvalidArgumentError,
          r"algorithm must be of shape \[\], not"):
        gen_stateful_random_ops.stateful_standard_normal_v2(
            gen.state.handle, [0, 0], shape)
      with self.assertRaisesWithPredicateMatch(
          TypeError, "Requested dtype: int64"):
        gen_stateful_random_ops.stateful_standard_normal_v2(
            gen.state.handle, 1.1, shape)
      with self.assertRaisesWithPredicateMatch(
          errors_impl.InvalidArgumentError,
          "Unsupported algorithm id"):
        gen_stateful_random_ops.stateful_standard_normal_v2(
            gen.state.handle, 123, shape)
      var = variables.Variable([0, 0], dtype=dtypes.uint32)
      with self.assertRaisesWithPredicateMatch(
          errors_impl.InvalidArgumentError,
          "Type mismatch for read of variable .* Expected int64; got"):
        gen_stateful_random_ops.stateful_standard_normal_v2(
            var.handle, random.RNG_ALG_THREEFRY, shape)
      var = variables.Variable([[0]], dtype=dtypes.int64)
      with self.assertRaisesWithPredicateMatch(
          errors_impl.InvalidArgumentError,
          "RNG state must have one and only one dimension, not"):
        gen_stateful_random_ops.stateful_standard_normal_v2(
            var.handle, random.RNG_ALG_THREEFRY, shape)
      var = variables.Variable([0], dtype=dtypes.int64)
      with self.assertRaisesWithPredicateMatch(
          errors_impl.InvalidArgumentError,
          "For the ThreeFry algorithm, the size of state must be at least"):
        gen_stateful_random_ops.stateful_standard_normal_v2(
            var.handle, random.RNG_ALG_THREEFRY, shape)


if __name__ == "__main__":
  test.main()
tensorflow/compiler/tf2xla/kernels/BUILD
@@ -64,6 +64,7 @@ tf_kernel_library(
         "qr_op.cc",
         "quantize_and_dequantize_op.cc",
         "random_ops.cc",
+        "random_ops_util.h",
         "reduce_window_op.cc",
         "reduction_ops.cc",
         "reduction_ops.h",
@@ -89,6 +90,7 @@ tf_kernel_library(
         "sparse_to_dense_op.cc",
         "split_op.cc",
         "stack_ops.cc",
+        "stateful_random_ops.cc",
         "stateless_random_ops.cc",
         "strided_slice_op.cc",
         "tensor_array_ops.cc",
@@ -167,6 +169,7 @@ tf_kernel_library(
         "//tensorflow/core:sparse_ops_op_lib",
         "//tensorflow/core:spectral_ops_op_lib",
         "//tensorflow/core:state_ops_op_lib",
+        "//tensorflow/core:stateful_random_ops_op_lib",
         "//tensorflow/core:stateless_random_ops_op_lib",
         "//tensorflow/core:training_ops_op_lib",
         "//tensorflow/core/kernels:constant_op",
@@ -179,6 +182,7 @@ tf_kernel_library(
         "//tensorflow/core/kernels:sendrecv_ops",
         "//tensorflow/core/kernels:sparse_to_dense_op",
         "//tensorflow/core/kernels:stack_ops",
+        "//tensorflow/core/kernels:stateful_random_ops",
         "//tensorflow/core/kernels:training_ops",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/strings",
tensorflow/compiler/tf2xla/kernels/random_ops_util.h (new file, 34 lines)
@@ -0,0 +1,34 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_RANDOM_OPS_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_RANDOM_OPS_UTIL_H_

#include <cmath>

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/tensor.h"

namespace tensorflow {

// Converts to bfloat16 if `dtype` equals DT_BFLOAT16; a no-op otherwise.
// It masks the last 16 bits. With normal rounding, values near "maxval" would
// be converted to "maxval", which is out of the range ["minval", "maxval").
// In addition, the distribution near the limit is not uniform.
xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_KERNELS_RANDOM_OPS_UTIL_H_
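A NumPy sketch (not part of the commit) of why the conversion declared above truncates instead of rounding: masking the low 16 bits always rounds toward zero, whereas round-to-nearest could push a float32 just below an exclusive upper bound onto the bound itself:

    import numpy as np

    x = np.array([np.nextafter(np.float32(1.0), np.float32(0.0))],
                 dtype=np.float32)                       # largest float32 below 1.0
    bits = x.view(np.uint32)                             # 0x3F7FFFFF
    masked = (bits & np.uint32(0xFFFF0000)).view(np.float32)
    assert masked[0] < 1.0                               # truncation stays below maxval
    # Round-to-nearest would instead map 0x3F7FFFFF to bfloat16 1.0 (0x3F800000).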
tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc (new file, 362 lines)
@@ -0,0 +1,362 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cmath>

#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h"
#include "tensorflow/compiler/tf2xla/lib/random.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/prng.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/stateful_random_ops.h"
#include "tensorflow/core/lib/math/math_util.h"

namespace tensorflow {
namespace {

std::pair<xla::ThreeFry2x32State, xla::XlaOp> GetInputsFromCounter(
    xla::XlaOp counter, const int64 size) {
  auto builder = counter.builder();
  auto input_u64 = Iota(builder, xla::U64, size);
  input_u64 = input_u64 + counter;
  counter = counter + xla::ConstantR0<uint64>(builder, size);
  return std::make_pair(xla::Uint64ToUint32s(input_u64), counter);
}

// `StatelessRngUniformU32` uses ThreeFry2x32's counter space too
// wastefully, only able to generate 2^32*2 int32 numbers for each key, while
// the real capacity is 2^64*2. Counter-space efficiency is important for
// stateful ops, hence the following 2 new functions.
std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniformU32(
    xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) {
  auto builder = key.builder();
  const int64 size = xla::ShapeUtil::ElementsIn(shape);
  const int64 half_size = xla::CeilOfRatio<int64>(size, 2);
  const bool size_is_odd = (half_size * 2 != size);
  auto inputs_counter = GetInputsFromCounter(counter, half_size);
  auto inputs = inputs_counter.first;
  counter = inputs_counter.second;
  auto outputs = xla::ThreeFry2x32(inputs, xla::Uint64ToUint32s(key));
  if (size_is_odd) {
    outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1});
  }
  auto result = ConcatInDim(builder, outputs, 0);
  return std::make_pair(Reshape(result, xla::AsInt64Slice(shape.dimensions())),
                        counter);
}

std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniformU64(
    xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) {
  const int64 size = xla::ShapeUtil::ElementsIn(shape);
  auto inputs_counter = GetInputsFromCounter(counter, size);
  auto inputs = inputs_counter.first;
  counter = inputs_counter.second;
  auto outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
  auto result = Uint32sToUint64(outputs);
  return std::make_pair(Reshape(result, xla::AsInt64Slice(shape.dimensions())),
                        counter);
}

std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniform(xla::XlaOp key,
                                                     xla::XlaOp counter,
                                                     const xla::Shape& shape,
                                                     xla::XlaOp minval,
                                                     xla::XlaOp maxval) {
  auto builder = key.builder();
  xla::PrimitiveType type = shape.element_type();
  switch (type) {
    case xla::F32: {
      auto bits_counter = StatefulRngUniformU32(key, counter, shape);
      auto bits = bits_counter.first;
      counter = bits_counter.second;
      return std::make_pair(xla::StatelessRngUniformF32(bits, minval, maxval),
                            counter);
    }
    case xla::U32:  // fall through
    case xla::S32: {
      auto bits_counter = StatefulRngUniformU32(key, counter, shape);
      auto bits = bits_counter.first;
      counter = bits_counter.second;
      return std::make_pair(
          xla::StatelessRngUniformInt(bits, minval, maxval, type, xla::U32),
          counter);
    }
    case xla::U64:  // fall through
    case xla::S64: {
      auto bits_counter = StatefulRngUniformU64(key, counter, shape);
      auto bits = bits_counter.first;
      counter = bits_counter.second;
      return std::make_pair(
          xla::StatelessRngUniformInt(bits, minval, maxval, type, xla::U64),
          counter);
    }
    default:
      return std::make_pair(builder->ReportError(xla::Unimplemented(
                                "Types other than F32, U32, S32, U64 and S64 "
                                "are not implemented by "
                                "StatefulRngUniform.")),
                            counter);
  }
}

template <typename A, typename B, typename A2>
std::pair<A2, B> map_first(std::function<A2(A)> f, std::pair<A, B> p) {
  return std::make_pair(f(p.first), p.second);
}

std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniformFullInt(
    xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) {
  xla::PrimitiveType type = shape.element_type();
  switch (type) {
    case xla::U32:
      return StatefulRngUniformU32(key, counter, shape);
    case xla::S32: {
      // Needs explicit function type because of type-inference failure.
      std::function<xla::XlaOp(xla::XlaOp)> f = [](xla::XlaOp x) {
        return BitcastConvertType(x, xla::S32);
      };
      return map_first(f, StatefulRngUniformU32(key, counter, shape));
    }
    case xla::U64:
      return StatefulRngUniformU64(key, counter, shape);
    case xla::S64: {
      std::function<xla::XlaOp(xla::XlaOp)> f = [](xla::XlaOp x) {
        return BitcastConvertType(x, xla::S64);
      };
      return map_first(f, StatefulRngUniformU64(key, counter, shape));
    }
    default:
      auto builder = key.builder();
      return std::make_pair(
          builder->ReportError(xla::Unimplemented(
              "Types other than U32, S32, U64 and S64 are not implemented by "
              "StatefulRngUniformFullInt; got: %s",
              xla::primitive_util::LowercasePrimitiveTypeName(type))),
          counter);
  }
}

template <typename ListB, typename ListA, typename F>
ListB Map(F f, ListA const& list_a) {
  ListB list_b;
  for (auto a : list_a) {
    list_b.push_back(f(a));
  }
  return list_b;
}

xla::XlaOp ConcatScalars(xla::XlaBuilder* builder,
                         absl::Span<const xla::XlaOp> scalars) {
  return ConcatInDim(
      builder,
      Map<std::vector<xla::XlaOp>>(
          [](xla::XlaOp x) { return xla::Reshape(x, {1}); }, scalars),
      0);
}

using sampler_return_type = xla::StatusOr<std::pair<xla::XlaOp, xla::XlaOp>>;

// A helper function containing the common part of several kernels below.
// Precondition: 'algorithm' and 'shape' are compile-time constants.
Status CompileImpl(XlaOpKernelContext* ctx, int state_input_idx,
                   int alg_input_idx, int shape_input_idx,
                   std::function<sampler_return_type(xla::XlaOp, xla::XlaOp,
                                                     TensorShape)> const&
                       sample_with_threefry) {
  auto alg_shape = ctx->InputShape(alg_input_idx);
  if (alg_shape.dims() != 0) {
    return errors::InvalidArgument("algorithm must be of shape [], not ",
                                   alg_shape.DebugString());
  }
  xla::Literal alg_literal;
  TF_RETURN_IF_ERROR(ctx->ConstantInput(alg_input_idx, &alg_literal));
  auto alg = alg_literal.Get<Algorithm>({});

  if (alg == RNG_ALG_THREEFRY) {
    xla::XlaOp var;
    TensorShape var_shape;
    TF_RETURN_IF_ERROR(ctx->ReadVariableInput(
        state_input_idx, STATE_ELEMENT_DTYPE, &var_shape, &var));
    if (var_shape.dims() != 1) {
      return errors::InvalidArgument(
          "RNG state must have one and only one dimension, not ",
          var_shape.dims());
    }
    auto state_size = var_shape.dim_size(0);
    if (state_size < THREEFRY_MIN_STATE_SIZE) {
      return errors::InvalidArgument(
          "For the ThreeFry algorithm, the size of state"
          " must be at least ",
          THREEFRY_MIN_STATE_SIZE, "; got ", state_size);
    }
    TensorShape shape;
    TF_RETURN_IF_ERROR(ctx->ConstantInputAsShape(shape_input_idx, &shape));

    static constexpr int COUNTER_SIZE = 1;
    auto counter = BitcastConvertType(
        xla::Reshape(xla::Slice(var, {0}, {COUNTER_SIZE}, {1}), {}), xla::U64);
    auto key = BitcastConvertType(
        xla::Reshape(xla::Slice(var, {COUNTER_SIZE}, {COUNTER_SIZE + 1}, {1}),
                     {}),
        xla::U64);

    auto status_or_value = sample_with_threefry(counter, key, shape);
    if (!status_or_value.ok()) {
      return status_or_value.status();
    }
    auto output_counter = status_or_value.ConsumeValueOrDie();
    auto output = output_counter.first;
    counter = output_counter.second;
    ctx->SetOutput(0, output);
    auto builder = ctx->builder();
    var = ConcatScalars(builder, {counter, key});
    xla::PrimitiveType state_element_type;
    TF_RETURN_IF_ERROR(
        DataTypeToPrimitiveType(STATE_ELEMENT_DTYPE, &state_element_type));
    var = BitcastConvertType(var, state_element_type);
    TF_RETURN_IF_ERROR(
        ctx->AssignVariable(state_input_idx, STATE_ELEMENT_DTYPE, var));
    return Status::OK();
  } else {
    return errors::InvalidArgument("Unsupported algorithm id: ", alg);
  }
}

class StatefulStandardNormalOp : public XlaOpKernel {
 public:
  explicit StatefulStandardNormalOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    auto builder = ctx->builder();
    auto sample_with_threefry =
        // Needs explicit lambda return type because it fails to be inferred.
        [builder, this](xla::XlaOp counter, xla::XlaOp key,
                        TensorShape shape) -> sampler_return_type {
      xla::Shape xla_shape;
      TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));

      auto uniform_counter = StatefulRngUniform(
          key, counter, xla_shape,
          xla::ConstantR0<float>(builder, std::nextafter(-1.0f, 0.0f)),
          xla::ConstantR0<float>(builder, 1.0));
      auto uniform = uniform_counter.first;
      counter = uniform_counter.second;
      // Convert uniform distribution to normal distribution by computing
      // sqrt(2) * erfinv(x)
      auto normal =
          xla::ScalarLike(uniform, std::sqrt(2.0)) * xla::ErfInv(uniform);
      normal = MaybeConvertF32ToBF16(normal, dtype_);
      return {{normal, counter}};
    };
    OP_REQUIRES_OK(ctx,
                   CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
                               /*shape_input_idx=*/2, sample_with_threefry));
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatefulStandardNormalOp);
};

// TODO(wangpeng): Support plain float16 and float64 to get rid of the
// `TypeConstraint`.
REGISTER_XLA_OP(Name("StatefulStandardNormalV2")
                    .CompileTimeConstantInput("algorithm")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype", {DT_FLOAT, DT_BFLOAT16}),
                StatefulStandardNormalOp);

class StatefulUniformIntOp : public XlaOpKernel {
 public:
  explicit StatefulUniformIntOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaOp minval = ctx->Input(3);
    xla::XlaOp maxval = ctx->Input(4);
    auto sample_with_threefry = [minval, maxval, this](
                                    xla::XlaOp counter, xla::XlaOp key,
                                    TensorShape shape) -> sampler_return_type {
      xla::Shape xla_shape;
      TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape));
      return StatefulRngUniform(key, counter, xla_shape, minval, maxval);
    };
    OP_REQUIRES_OK(ctx,
                   CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
                               /*shape_input_idx=*/2, sample_with_threefry));
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatefulUniformIntOp);
};

REGISTER_XLA_OP(Name("StatefulUniformInt")
                    .CompileTimeConstantInput("algorithm")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype",
                                    {DT_INT32, DT_UINT32, DT_INT64, DT_UINT64}),
                StatefulUniformIntOp);

class StatefulUniformFullIntOp : public XlaOpKernel {
 public:
  explicit StatefulUniformFullIntOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    auto sample_with_threefry = [this](
                                    xla::XlaOp counter, xla::XlaOp key,
                                    TensorShape shape) -> sampler_return_type {
      xla::Shape xla_shape;
      TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape));
      return StatefulRngUniformFullInt(key, counter, xla_shape);
    };
    OP_REQUIRES_OK(ctx,
                   CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
                               /*shape_input_idx=*/2, sample_with_threefry));
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatefulUniformFullIntOp);
};

REGISTER_XLA_OP(Name("StatefulUniformFullInt")
                    .CompileTimeConstantInput("algorithm")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype",
                                    {DT_INT32, DT_UINT32, DT_INT64, DT_UINT64}),
                StatefulUniformFullIntOp);

}  // namespace
}  // namespace tensorflow
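A back-of-envelope sketch (not from the commit) of the counter bookkeeping these kernels implement: each ThreeFry2x32 invocation consumes one counter value and produces 64 bits of output, so generating n uint32s advances the counter by ceil(n/2), and n uint64s advance it by n, which is exactly what testNewState in the test file above asserts:

    def counter_after(counter, n_outputs, bits_per_output):
      calls = -(-n_outputs * bits_per_output // 64)  # ceil(total bits / 64)
      return counter + calls

    assert counter_after(57, 46, 32) == 57 + (46 + 1) // 2  # uint32 case
    assert counter_after(57, 46, 64) == 57 + 46             # uint64 case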
tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include <cmath>
 
+#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h"
 #include "tensorflow/compiler/tf2xla/lib/random.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
@@ -31,12 +32,8 @@ limitations under the License.
 #include "tensorflow/core/lib/math/math_util.h"
 
 namespace tensorflow {
-namespace {
 
 xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) {
-  // Mask the last 16 bit. With normal rounding, values near "maxval" would be
-  // converted to "maxval" which is out of range ["minval", "maxval"). In
-  // addition, the distribution near the limit is not uniform.
   if (dtype == DT_BFLOAT16) {
     xla::XlaBuilder* builder = input.builder();
     auto output = xla::BitcastConvertType(input, xla::U32) &
@@ -48,6 +45,8 @@ xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) {
   }
 }
 
+namespace {
+
 class StatelessRandomUniformOp : public XlaOpKernel {
  public:
   explicit StatelessRandomUniformOp(OpKernelConstruction* ctx)
tensorflow/compiler/tf2xla/resource_operation_table.cc
@@ -82,6 +82,9 @@ CreateResourceOpInfoMap() {
   add("ResourceScatterSub"          , kReadWrite, kVariable);
   add("ResourceScatterUpdate"       , kReadWrite, kVariable);
   add("ResourceStridedSliceAssign"  , kReadWrite, kVariable);
+  add("StatefulStandardNormalV2"    , kReadWrite, kVariable);
+  add("StatefulUniformFullInt"      , kReadWrite, kVariable);
+  add("StatefulUniformInt"          , kReadWrite, kVariable);
   add("VarIsInitializedOp"          , kRead,      kVariable);
   add("VariableShape"               , kRead,      kVariable);
tensorflow/compiler/xla/client/lib/prng.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "tensorflow/compiler/xla/client/lib/prng.h"
+
 #include <cmath>
 
 #include "absl/base/casts.h"
@@ -30,11 +32,8 @@ XlaOp RotateLeftU32(XlaOp v, int distance) {
          ShiftRightLogical(v, ConstantR0<uint32>(v.builder(), 32 - distance));
 }
 
-using ThreeFry2x32State = std::array<XlaOp, 2>;
+}  // namespace
 
-// Implements the ThreeFry counter-based PRNG algorithm.
-// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
-// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
 ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) {
   XlaBuilder* builder = input[0].builder();
   key[0] = BitcastConvertType(key[0], U32);
@@ -127,15 +126,28 @@ XlaOp StatelessRngUniformU32(std::array<XlaOp, 2> key, const Shape& shape) {
   return Reshape(result, AsInt64Slice(shape.dimensions()));
 }
 
+ThreeFry2x32State Uint64ToUint32s(XlaOp u64) {
+  auto builder = u64.builder();
+  auto const32 = ConstantR0WithType(builder, U64, 32);
+  auto fst = ConvertElementType(u64, U32);
+  auto snd = ConvertElementType(ShiftRightLogical(u64, const32), U32);
+  return {fst, snd};
+}
+
+XlaOp Uint32sToUint64(ThreeFry2x32State u32s) {
+  auto builder = u32s[0].builder();
+  return ConvertElementType(u32s[0], U64) |
+         ShiftLeft(ConvertElementType(u32s[1], U64),
+                   ConstantR0WithType(builder, U64, 32));
+}
+
 XlaOp StatelessRngUniformU64(std::array<XlaOp, 2> key, const Shape& shape) {
   XlaBuilder* builder = key[0].builder();
   const int64 size = ShapeUtil::ElementsIn(shape);
   ThreeFry2x32State inputs = GetInputs(size, builder);
   ThreeFry2x32State outputs = ThreeFry2x32(inputs, key);
   // low 32 bit: outputs[0], high 32 bit: outputs[1]
-  auto result = ConvertElementType(outputs[0], U64) |
-                ShiftLeft(ConvertElementType(outputs[1], U64),
-                          ConstantR0WithType(builder, U64, 32));
+  auto result = Uint32sToUint64(outputs);
   return Reshape(result, AsInt64Slice(shape.dimensions()));
 }
@@ -161,10 +173,6 @@ XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) {
 XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
                              PrimitiveType type, PrimitiveType unsigned_type) {
   XlaBuilder* builder = bits.builder();
-  // TODO(b/72573764): Generate real uniform integer distribution.
-  // The following algorithm is the same one that TF uses right now, but it's
-  // uniform only when maxval - minval is a divisor of the range that bits is
-  // generated from.
   auto range = BitcastConvertType(maxval, unsigned_type) -
                BitcastConvertType(minval, unsigned_type);
   auto dist = Rem(bits, range);
@@ -175,8 +183,6 @@ XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
          BitcastConvertType(dist - dist_div_2, type);
 }
 
-}  // namespace
-
 XlaOp StatelessRngUniform(std::array<XlaOp, 2> seeds, const Shape& shape,
                           XlaOp minval, XlaOp maxval) {
   XlaBuilder* builder = seeds[0].builder();
tensorflow/compiler/xla/client/lib/prng.h
@@ -23,12 +23,38 @@ limitations under the License.
 
 namespace xla {
 
+// Implements the ThreeFry counter-based PRNG algorithm.
+// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
+// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
+using ThreeFry2x32State = std::array<XlaOp, 2>;
+ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key);
+
 // Returns a tensor containing 'shape' random values uniformly distributed in
 // the range [minval, maxval). Requires 2 32-bit integer seeds.
 // Currently only 'shape's of type F32, S32 and S64 are implemented.
 XlaOp StatelessRngUniform(std::array<XlaOp, 2> seeds, const Shape& shape,
                           XlaOp minval, XlaOp maxval);
 
+// Converts a 32-bit (signed or unsigned) integer random number `bits` into a
+// float32 in the range [minval, maxval).
+XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval);
+
+// Converts an integer random number 'bits' of type 'type' to a random number
+// in the range [minval, maxval), of the same type. 'unsigned_type' is the
+// unsigned version of 'type' (could be the same) with the same bit width.
+// The algorithm is the same one that TF uses right now, but it's
+// uniform only when maxval - minval is a divisor of the range that bits is
+// generated from.
+// TODO(b/72573764): Generate real uniform integer distribution.
+XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
+                             PrimitiveType type, PrimitiveType unsigned_type);
+
+// The following 2 functions, for converting between one uint64 and two
+// uint32s, use the contract "lower 32 bits for the first uint32, higher 32
+// bits for the second".
+ThreeFry2x32State Uint64ToUint32s(XlaOp u64);
+XlaOp Uint32sToUint64(ThreeFry2x32State u32s);
+
 }  // namespace xla
 
 #endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_PRNG_H_
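The packing contract of the two new conversion functions, illustrated with a NumPy sketch (not part of the commit):

    import numpy as np

    def uint64_to_uint32s(u64):                    # mirrors Uint64ToUint32s
      u64 = np.uint64(u64)
      lo = np.uint32(u64 & np.uint64(0xFFFFFFFF))  # first uint32 = lower 32 bits
      hi = np.uint32(u64 >> np.uint64(32))         # second uint32 = upper 32 bits
      return lo, hi

    def uint32s_to_uint64(lo, hi):                 # mirrors Uint32sToUint64
      return np.uint64(lo) | (np.uint64(hi) << np.uint64(32))

    lo, hi = uint64_to_uint32s(0x0123456789ABCDEF)
    assert uint32s_to_uint64(lo, hi) == np.uint64(0x0123456789ABCDEF)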
tensorflow/core/api_def/base_api/api_def_StatefulUniformFullInt.pbtxt (new file, 38 lines)
@@ -0,0 +1,38 @@
op {
  graph_op_name: "StatefulUniformFullInt"
  visibility: HIDDEN
  in_arg {
    name: "resource"
    description: <<END
The handle of the resource variable that stores the state of the RNG.
END
  }
  in_arg {
    name: "algorithm"
    description: <<END
The RNG algorithm.
END
  }
  in_arg {
    name: "shape"
    description: <<END
The shape of the output tensor.
END
  }
  out_arg {
    name: "output"
    description: <<END
Random values with specified shape.
END
  }
  attr {
    name: "dtype"
    description: <<END
The type of the output.
END
  }
  summary: "Outputs random integers from a uniform distribution."
  description: <<END
The generated values are uniform integers covering the whole range of `dtype`.
END
}
tensorflow/core/api_def/base_api/api_def_StatefulUniformInt.pbtxt (new file, 56 lines)
@@ -0,0 +1,56 @@
op {
  graph_op_name: "StatefulUniformInt"
  visibility: HIDDEN
  in_arg {
    name: "resource"
    description: <<END
The handle of the resource variable that stores the state of the RNG.
END
  }
  in_arg {
    name: "algorithm"
    description: <<END
The RNG algorithm.
END
  }
  in_arg {
    name: "shape"
    description: <<END
The shape of the output tensor.
END
  }
  in_arg {
    name: "minval"
    description: <<END
Minimum value (inclusive, scalar).
END
  }
  in_arg {
    name: "maxval"
    description: <<END
Maximum value (exclusive, scalar).
END
  }
  out_arg {
    name: "output"
    description: <<END
Random values with specified shape.
END
  }
  attr {
    name: "dtype"
    description: <<END
The type of the output.
END
  }
  summary: "Outputs random integers from a uniform distribution."
  description: <<END
The generated values are uniform integers in the range `[minval, maxval)`.
The lower bound `minval` is included in the range, while the upper bound
`maxval` is excluded.

The random integers are slightly biased unless `maxval - minval` is an exact
power of two. The bias is small for values of `maxval - minval` significantly
smaller than the range of the output (either `2^32` or `2^64`).
END
}
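The bias statement in the description above can be checked with a few lines of arithmetic (a sketch of mine, not part of the commit): reducing 32 uniform bits modulo range = maxval - minval over-represents the first 2**32 % range residues by one count each:

    for rng_range in (3, 2**16, 2**31 + 1):
      extra = 2**32 % rng_range  # residues that occur one extra time
      print(rng_range, extra)
    # range 3        -> 1 residue is over-represented (relative bias ~2e-10)
    # range 2**16    -> 0 (exact power of two: no bias)
    # range 2**31+1  -> 2**31 - 1 residues occur twice as often as the rest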
tensorflow/core/framework/resource_var.h
@@ -95,6 +95,31 @@ class Var : public ResourceBase {
   TF_DISALLOW_COPY_AND_ASSIGN(Var);
 };
 
+// Does unlock and unref automatically when going out of scope, and also
+// supports early manual release.
+class ScopedUnlockUnrefVar {
+ public:
+  explicit ScopedUnlockUnrefVar(Var* var) : var_(var) {
+    if (var_) {
+      var_->mu()->lock();
+    }
+  }
+  void Release() {
+    if (var_) {
+      var_->mu()->unlock();
+      var_->Unref();
+      var_ = nullptr;
+    }
+  }
+  ~ScopedUnlockUnrefVar() { Release(); }
+
+ private:
+  Var* var_;
+
+  ScopedUnlockUnrefVar(const ScopedUnlockUnrefVar&) = delete;
+  void operator=(const ScopedUnlockUnrefVar&) = delete;
+};
+
 }  // end namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_VAR_H_
tensorflow/core/kernels/stateful_random_ops.cc
@@ -15,8 +15,8 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
-#include "tensorflow/core/kernels/stateful_random_ops.h"
 #include "tensorflow/core/kernels/random_op.h"
+#include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h"
 #include "tensorflow/core/kernels/training_op_helpers.h"
 
 namespace tensorflow {
@@ -25,7 +25,7 @@ template <typename Distribution>
 struct UpdateVariableAndFill_Philox<CPUDevice, Distribution> {
   void operator()(OpKernelContext* ctx, const CPUDevice& device,
                   int64 output_size, int64 alg_tag_skip,
-                  ScopedUnlockUnref* state_var_guard, Tensor* state_tensor,
+                  ScopedUnlockUnrefVar* state_var_guard, Tensor* state_tensor,
                   typename Distribution::ResultElementType* output_data) {
     auto state_tensor_flat = state_tensor->flat<StateElementType>();
     auto state_data = state_tensor_flat.data();
@@ -47,11 +47,11 @@ Status UpdateVariableAndFill(
   Var* var = nullptr;
   TF_RETURN_IF_ERROR(
       LookupResource(ctx, HandleFromInput(ctx, state_input_idx), &var));
-  // Use `ScopedUnlockUnref` here instead of `mutex_lock` and `ScopedUnref`
+  // Use `ScopedUnlockUnrefVar` here instead of `mutex_lock` and `ScopedUnref`
   // because the former supports early releasing which is needed by
   // `UpdateVariableAndFill_Philox<CPU>` to avoid holding the lock while
   // filling.
-  ScopedUnlockUnref state_var_guard(var);
+  ScopedUnlockUnrefVar state_var_guard(var);
   Tensor* var_tensor = var->tensor();
   if (var_tensor->dtype() != STATE_ELEMENT_DTYPE) {
     return errors::InvalidArgument("dtype of RNG state variable must be ",
@@ -80,7 +80,7 @@ Status UpdateVariableAndFill(
                 "PhiloxRandom::ResultElementType must be uint32");
   if (var_tensor_flat.size() < alg_tag_skip + PHILOX_MIN_STATE_SIZE) {
     return errors::InvalidArgument(
-        "For Philox algorithm, the size of state"
+        "For the Philox algorithm, the size of state"
         " must be at least ",
         alg_tag_skip + PHILOX_MIN_STATE_SIZE, "; got ",
         var_tensor_flat.size());
@@ -132,6 +132,11 @@ class StatefulRandomOpV2 : public OpKernel {
     OP_REQUIRES(ctx, alg_tensor.dims() == 0,
                 errors::InvalidArgument("algorithm must be of shape [], not ",
                                         alg_tensor.shape().DebugString()));
+    OP_REQUIRES(
+        ctx, alg_tensor.dtype() == ALGORITHM_DTYPE,
+        errors::InvalidArgument("algorithm's dtype must be ",
+                                DataTypeString(ALGORITHM_DTYPE), ", not ",
+                                DataTypeString(alg_tensor.dtype())));
     auto alg = alg_tensor.flat<Algorithm>()(0);
     ComputeImpl<Device, Distribution>(ctx, 0, 2, false, alg);
   }
@ -16,7 +16,7 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_H_
|
#ifndef TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_H_
|
||||||
#define TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_H_
|
#define TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_H_
|
||||||
|
|
||||||
#include "tensorflow/core/framework/resource_var.h"
|
// #include "tensorflow/core/framework/resource_var.h"
|
||||||
#include "tensorflow/core/lib/random/philox_random.h"
|
#include "tensorflow/core/lib/random/philox_random.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -28,7 +28,9 @@ using StateElementType = int64;
|
|||||||
static constexpr DataType STATE_ELEMENT_DTYPE = DT_INT64;
|
static constexpr DataType STATE_ELEMENT_DTYPE = DT_INT64;
|
||||||
|
|
||||||
using Algorithm = StateElementType;
|
using Algorithm = StateElementType;
|
||||||
|
static constexpr DataType ALGORITHM_DTYPE = STATE_ELEMENT_DTYPE;
|
||||||
static constexpr Algorithm RNG_ALG_PHILOX = 1;
|
static constexpr Algorithm RNG_ALG_PHILOX = 1;
|
||||||
|
static constexpr Algorithm RNG_ALG_THREEFRY = 2;
|
||||||
|
|
||||||
using random::PhiloxRandom;
|
using random::PhiloxRandom;
|
||||||
|
|
||||||
@@ -36,109 +38,7 @@ static constexpr int64 PHILOX_MIN_STATE_SIZE =
     (PhiloxRandom::ResultType::kElementCount +
      PhiloxRandom::Key::kElementCount) /
     2;
+static constexpr int64 THREEFRY_MIN_STATE_SIZE = 2;
-
-// The following 5 functions are made templates to avoid duplicate symbols when
-// linking.
-
-// The following two functions use the contract "lower 32 bits for the first
-// uint32, higher 32 bits for the second". Note that this is endian-neutral,
-// unlike a direct memory copy `memcpy(output, &input, 8)`.
-template <typename INT64>
-PHILOX_DEVICE_FUNC void Int64ToUint32s(INT64 input, uint32* output1,
-                                       uint32* output2) {
-  auto u64 = static_cast<uint64>(input);
-  *output1 = static_cast<uint32>(u64);
-  *output2 = static_cast<uint32>(u64 >> 32);
-}
-
-template <typename UINT32>
-PHILOX_DEVICE_FUNC int64 Uint32sToInt64(UINT32 input1, UINT32 input2) {
-  auto u64_1 = static_cast<uint64>(input1);
-  auto u64_2 = static_cast<uint64>(input2);
-  return static_cast<int64>(u64_1 | (u64_2 << 32));
-}
-
-template <typename STATE_ELEMENT_TYPE>
-PHILOX_DEVICE_FUNC PhiloxRandom
-GetPhiloxRandomFromMem(STATE_ELEMENT_TYPE const* ptr) {
-  PhiloxRandom::ResultType counter;
-  PhiloxRandom::Key key;
-  Int64ToUint32s(ptr[0], &counter[0], &counter[1]);
-  Int64ToUint32s(ptr[1], &counter[2], &counter[3]);
-  Int64ToUint32s(ptr[2], &key[0], &key[1]);
-  return PhiloxRandom(counter, key);
-}
-
-template <typename PHILOX_RANDOM>
-PHILOX_DEVICE_FUNC void WritePhiloxRandomToMem(PHILOX_RANDOM const& philox,
-                                               StateElementType* ptr) {
-  PhiloxRandom::ResultType const& counter = philox.counter();
-  PhiloxRandom::Key const& key = philox.key();
-  ptr[0] = Uint32sToInt64(counter[0], counter[1]);
-  ptr[1] = Uint32sToInt64(counter[2], counter[3]);
-  ptr[2] = Uint32sToInt64(key[0], key[1]);
-}
-
-template <typename PHILOX_RANDOM>
-PHILOX_DEVICE_FUNC void UpdateMemWithPhiloxRandom(PHILOX_RANDOM const& philox,
-                                                  int64 output_size,
-                                                  StateElementType* ptr) {
-  auto new_philox = philox;
-  // Multiplier 256 is the same as in `FillPhiloxRandomTask`; do not change
-  // it just here.
-  auto delta = output_size * 256;
-  new_philox.Skip(delta);  // do the actual increasing
-  WritePhiloxRandomToMem(new_philox, ptr);
-}
-
-// Does unlock and unref automatically when going out of scope, and also
-// supports early manual release.
-class ScopedUnlockUnref {
- public:
-  explicit ScopedUnlockUnref(Var* var) : var_(var) {
-    if (var_) {
-      var_->mu()->lock();
-    }
-  }
-  void Release() {
-    if (var_) {
-      var_->mu()->unlock();
-      var_->Unref();
-      var_ = nullptr;
-    }
-  }
-  ~ScopedUnlockUnref() { Release(); }
-
- private:
-  Var* var_;
-
-  ScopedUnlockUnref(const ScopedUnlockUnref&) = delete;
-  void operator=(const ScopedUnlockUnref&) = delete;
-};
-
-// A per-device helper function that does the actual work for
-// `UpdateVariableAndFill`.
-// Reason to use functor: C++ doesn't allow function-template partial
-// specialization.
-template <typename Device, typename Distribution>
-struct UpdateVariableAndFill_Philox;
-
-using CPUDevice = Eigen::ThreadPoolDevice;
-
-#if GOOGLE_CUDA
-
-using GPUDevice = Eigen::GpuDevice;
-
-// Declares the partially GPU-specialized functor struct.
-template <typename Distribution>
-struct UpdateVariableAndFill_Philox<GPUDevice, Distribution> {
-  void operator()(OpKernelContext* ctx, const GPUDevice& device,
-                  int64 output_size, int64 alg_tag_skip,
-                  ScopedUnlockUnref* not_used, Tensor* state_tensor,
-                  typename Distribution::ResultElementType* output_data);
-};
-
-#endif  // GOOGLE_CUDA
 
 }  // end namespace tensorflow
 
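(The helpers deleted above reappear verbatim in the new shared header below, so nothing is lost in the move. As a quick sanity check of the 32/64-bit packing contract they rely on, and of why PHILOX_MIN_STATE_SIZE works out to (4 + 2) / 2 = 3 int64s — four uint32 counter words plus two uint32 key words — here is a Python sketch; it is illustration only, not part of the commit, and int64 wrap-around is emulated by hand:

MASK32 = 0xFFFFFFFF
MASK64 = 0xFFFFFFFFFFFFFFFF

def int64_to_uint32s(x):
  # Mirrors Int64ToUint32s: lower 32 bits first, higher 32 bits second.
  u64 = x & MASK64  # reinterpret the signed int64 as uint64
  return u64 & MASK32, (u64 >> 32) & MASK32

def uint32s_to_int64(lo, hi):
  # Mirrors Uint32sToInt64, then converts back to a signed int64.
  u64 = (lo & MASK32) | ((hi & MASK32) << 32)
  return u64 - (1 << 64) if u64 >= (1 << 63) else u64

# The round trip is exact for any int64, including negative values:
for x in (0, 1, -1, 2**62, -2**63):
  assert uint32s_to_int64(*int64_to_uint32s(x)) == x
)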
104
tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h
Normal file
@@ -0,0 +1,104 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_CPU_GPU_H_
#define TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_CPU_GPU_H_

#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/kernels/stateful_random_ops.h"

namespace tensorflow {

// The following 5 functions are made templates to avoid duplicate symbols when
// linking.

// The following 2 functions use the contract "lower 32 bits for the first
// uint32, higher 32 bits for the second". Note that this is endian-neutral,
// unlike a direct memory copy `memcpy(output, &input, 8)`.
template <typename INT64>
PHILOX_DEVICE_FUNC void Int64ToUint32s(INT64 input, uint32* output1,
                                       uint32* output2) {
  auto u64 = static_cast<uint64>(input);
  *output1 = static_cast<uint32>(u64);
  *output2 = static_cast<uint32>(u64 >> 32);
}

template <typename UINT32>
PHILOX_DEVICE_FUNC int64 Uint32sToInt64(UINT32 input1, UINT32 input2) {
  auto u64_1 = static_cast<uint64>(input1);
  auto u64_2 = static_cast<uint64>(input2);
  return static_cast<int64>(u64_1 | (u64_2 << 32));
}

template <typename STATE_ELEMENT_TYPE>
PHILOX_DEVICE_FUNC PhiloxRandom
GetPhiloxRandomFromMem(STATE_ELEMENT_TYPE const* ptr) {
  PhiloxRandom::ResultType counter;
  PhiloxRandom::Key key;
  Int64ToUint32s(ptr[0], &counter[0], &counter[1]);
  Int64ToUint32s(ptr[1], &counter[2], &counter[3]);
  Int64ToUint32s(ptr[2], &key[0], &key[1]);
  return PhiloxRandom(counter, key);
}

template <typename PHILOX_RANDOM>
PHILOX_DEVICE_FUNC void WritePhiloxRandomToMem(PHILOX_RANDOM const& philox,
                                               StateElementType* ptr) {
  PhiloxRandom::ResultType const& counter = philox.counter();
  PhiloxRandom::Key const& key = philox.key();
  ptr[0] = Uint32sToInt64(counter[0], counter[1]);
  ptr[1] = Uint32sToInt64(counter[2], counter[3]);
  ptr[2] = Uint32sToInt64(key[0], key[1]);
}

template <typename PHILOX_RANDOM>
PHILOX_DEVICE_FUNC void UpdateMemWithPhiloxRandom(PHILOX_RANDOM const& philox,
                                                  int64 output_size,
                                                  StateElementType* ptr) {
  auto new_philox = philox;
  // Multiplier 256 is the same as in `FillPhiloxRandomTask`; do not change
  // it just here.
  auto delta = output_size * 256;
  new_philox.Skip(delta);  // do the actual increasing
  WritePhiloxRandomToMem(new_philox, ptr);
}

// A per-device helper function that does the actual work for
// `UpdateVariableAndFill`.
// Reason to use functor: C++ doesn't allow function-template partial
// specialization.
template <typename Device, typename Distribution>
struct UpdateVariableAndFill_Philox;

using CPUDevice = Eigen::ThreadPoolDevice;

#if GOOGLE_CUDA

using GPUDevice = Eigen::GpuDevice;

// Declares the partially GPU-specialized functor struct.
template <typename Distribution>
struct UpdateVariableAndFill_Philox<GPUDevice, Distribution> {
  void operator()(OpKernelContext* ctx, const GPUDevice& device,
                  int64 output_size, int64 alg_tag_skip,
                  ScopedUnlockUnrefVar* not_used, Tensor* state_tensor,
                  typename Distribution::ResultElementType* output_data);
};

#endif  // GOOGLE_CUDA

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_CPU_GPU_H_
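(A note on `UpdateMemWithPhiloxRandom` above: `PhiloxRandom::Skip(delta)` advances the 128-bit counter by `delta`, and the `output_size * 256` delta matches the per-task stride in `FillPhiloxRandomTask`, so successive fills never draw from overlapping counter ranges. A rough Python model of that reservation, not part of the commit — the 128-bit counter is modeled as a plain integer:

def update_mem_with_philox(counter, output_size):
  # Models UpdateMemWithPhiloxRandom: reserve enough of the 128-bit
  # counter space for this fill before the next caller runs.
  delta = output_size * 256  # same multiplier as FillPhiloxRandomTask
  return (counter + delta) % (1 << 128)

# Consecutive fills reserve disjoint, strictly increasing counter ranges
# (until the 128-bit counter wraps, which is astronomically far away):
c = 0
for n in (1000, 1, 50):
  next_c = update_mem_with_philox(c, n)
  assert next_c - c == n * 256
  c = next_c
)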
@@ -18,7 +18,7 @@ limitations under the License.
 #define EIGEN_USE_GPU
 
 #include "tensorflow/core/kernels/random_op_gpu.h"
-#include "tensorflow/core/kernels/stateful_random_ops.h"
+#include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h"
 #include "tensorflow/core/util/cuda_launch_config.h"
 
 namespace tensorflow {
@@ -54,7 +54,7 @@ __global__ void FillKernel(
 template <typename Distribution>
 void UpdateVariableAndFill_Philox<GPUDevice, Distribution>::operator()(
     OpKernelContext* ctx, const GPUDevice& d, int64 output_size,
-    int64 alg_tag_skip, ScopedUnlockUnref* not_used, Tensor* state_tensor,
+    int64 alg_tag_skip, ScopedUnlockUnrefVar* not_used, Tensor* state_tensor,
     typename Distribution::ResultElementType* output_data) {
   OP_REQUIRES(
       ctx, alg_tag_skip == 0,
@@ -20,11 +20,9 @@ namespace tensorflow {
 
 Status StatefulRandomShape(shape_inference::InferenceContext* c) {
   using shape_inference::ShapeHandle;
-
   // Check algorithm shape
   ShapeHandle unused;
   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
-
   // Set output shape
   ShapeHandle out;
   TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &out));
@@ -32,16 +30,45 @@ Status StatefulRandomShape(shape_inference::InferenceContext* c) {
   return Status::OK();
 }
 
-REGISTER_OP("StatefulStandardNormalV2")
-    .Input("resource: resource")
-    .Input("algorithm: int64")
-    .Input("shape: shape_dtype")
-    .Output("output: dtype")
-    .Attr("dtype : type = DT_FLOAT")
-    .Attr("shape_dtype : type = DT_INT64")
-    .SetShapeFn(StatefulRandomShape);
+#define REGISTER_STATEFUL_OP(name, default_dtype) \
+  REGISTER_OP(name)                               \
+      .Input("resource: resource")                \
+      .Input("algorithm: int64")                  \
+      .Input("shape: shape_dtype")                \
+      .Output("output: dtype")                    \
+      .Attr("dtype : type = " #default_dtype)     \
+      .Attr("shape_dtype : type = DT_INT64")      \
+      .SetShapeFn(StatefulRandomShape);
+
+REGISTER_STATEFUL_OP("StatefulUniformFullInt", DT_UINT64);
+REGISTER_STATEFUL_OP("StatefulStandardNormalV2", DT_FLOAT);
+
+REGISTER_OP("StatefulUniformInt")
+    .Input("resource: resource")
+    .Input("algorithm: int64")
+    .Input("shape: shape_dtype")
+    .Input("minval: dtype")
+    .Input("maxval: dtype")
+    .Output("output: dtype")
+    .Attr("dtype : type = DT_INT64")
+    .Attr("shape_dtype : type = DT_INT64")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      using shape_inference::ShapeHandle;
+      // Check inputs
+      ShapeHandle unused;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
+      // Set output
+      ShapeHandle out;
+      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &out));
+      c->set_output(0, out);
+      return Status::OK();
+    });
 
-// Register the old 'StatefulStandardNormal' op
+// Register the old 'StatefulStandardNormal' op. This op is a short-lived
+// version where the 'resource' variable also contains the algorithm tag.
+// It is deprecated in favor of 'StatefulStandardNormalV2'.
 REGISTER_OP("StatefulStandardNormal")
     .Input("resource: resource")
     .Input("shape: shape_dtype")
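(For orientation, the macro-registered ops above surface in Python via generated bindings, with `dtype` defaulting to uint64, float32, and int64 respectively. A usage sketch, with the caveat that the binding module name `gen` and the exact keyword spelling are assumptions here; `state_var` must already be an int64 resource variable holding a well-formed Philox state, and `RNG_ALG_PHILOX = 1` as in the headers above:

# Sketch only; `gen` stands for the generated stateful_random_ops module.
samples = gen.stateful_standard_normal_v2(
    state_var.handle, RNG_ALG_PHILOX, shape=[2, 3])  # dtype defaults to float32
raw_bits = gen.stateful_uniform_full_int(
    state_var.handle, RNG_ALG_PHILOX, shape=[2, 3])  # dtype defaults to uint64
ints = gen.stateful_uniform_int(
    state_var.handle, RNG_ALG_PHILOX, shape=[10],
    minval=0, maxval=100)                            # dtype defaults to int64
)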
@@ -53,9 +53,14 @@ SEED_SIZE = 16  # in units of SEED_TYPE
 STATE_TYPE = SEED_TYPE
 ALGORITHM_TYPE = STATE_TYPE
 RNG_ALG_PHILOX = 1
+RNG_ALG_THREEFRY = 2
 DEFAULT_ALGORITHM = RNG_ALG_PHILOX
 
 
+PHILOX_STATE_SIZE = 3
+THREEFRY_STATE_SIZE = 2
+
+
 def non_deterministic_seed():
   """Makes a non-deterministic seed.
 
|
|||||||
return n
|
return n
|
||||||
|
|
||||||
|
|
||||||
PHILOX_STATE_SIZE = 3
|
def _make_1d_state(state_size, seed):
|
||||||
|
"""Makes a 1-D RNG state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_size: an integer.
|
||||||
|
seed: an integer or 1-D tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a 1-D tensor of shape [state_size] and dtype STATE_TYPE.
|
||||||
|
"""
|
||||||
|
int_types = (int,) if sys.version_info >= (3, 0) else (int, long)
|
||||||
|
if isinstance(seed, int_types):
|
||||||
|
# chop the Python integer (infinite precision) into chunks of SEED_TYPE
|
||||||
|
ls = []
|
||||||
|
for _ in range(state_size):
|
||||||
|
ls.append(seed & SEED_BIT_MASK)
|
||||||
|
seed >>= SEED_TYPE_BITS
|
||||||
|
seed = ls
|
||||||
|
# to avoid overflow error from np.asarray
|
||||||
|
seed = list(map(_uint_to_int, seed))
|
||||||
|
seed = np.asarray(seed, dtype=STATE_TYPE)
|
||||||
|
if len(seed.shape) != 1:
|
||||||
|
raise ValueError(
|
||||||
|
"seed should only have one dimension; got shape: %s" % seed.shape)
|
||||||
|
seed = seed[0:state_size]
|
||||||
|
# Padding with zeros on the right if too short
|
||||||
|
seed_size = seed.shape[0]
|
||||||
|
if seed_size < state_size:
|
||||||
|
seed = np.pad(
|
||||||
|
seed, [(0, state_size - seed_size)],
|
||||||
|
mode="constant",
|
||||||
|
constant_values=0)
|
||||||
|
assert seed.shape == (state_size,), "Wrong seed.shape: %s" % seed.shape
|
||||||
|
return seed
|
||||||
|
|
||||||
|
|
||||||
def _make_philox_state(seed):
|
def _make_philox_state(seed):
|
||||||
@@ -87,35 +125,26 @@ def _make_philox_state(seed):
   Returns:
     a 1-D tensor.
   """
-  int_types = (int,) if sys.version_info >= (3, 0) else (int, long)
-  if isinstance(seed, int_types):
-    # chop the Python integer (infinite precision) into chunks of SEED_TYPE
-    ls = []
-    for _ in range(PHILOX_STATE_SIZE):
-      ls.append(seed & SEED_BIT_MASK)
-      seed >>= SEED_TYPE_BITS
-    seed = ls
-  # to avoid overflow error from np.asarray
-  seed = list(map(_uint_to_int, seed))
-  seed = np.asarray(seed, dtype=STATE_TYPE)
-  if len(seed.shape) != 1:
-    raise ValueError(
-        "seed should only have one dimension; got shape: %s" % seed.shape)
-  seed = seed[0:PHILOX_STATE_SIZE]
-  # Padding with zeros on the right if too short
-  seed_size = seed.shape[0]
-  if seed_size < PHILOX_STATE_SIZE:
-    seed = np.pad(
-        seed, [(0, PHILOX_STATE_SIZE - seed_size)],
-        mode="constant",
-        constant_values=0)
-  assert seed.shape == (PHILOX_STATE_SIZE,), "Wrong seed.shape: %s" % seed.shape
-  return seed
+  return _make_1d_state(PHILOX_STATE_SIZE, seed)
+
+
+def _make_threefry_state(seed):
+  """Makes a RNG state for ThreeFry algorithm.
+
+  Args:
+    seed: an integer or 1-D tensor.
+
+  Returns:
+    a 1-D tensor.
+  """
+  return _make_1d_state(THREEFRY_STATE_SIZE, seed)
 
 
 def _make_state_from_seed(seed, algorithm):
   if algorithm == RNG_ALG_PHILOX:
     return _make_philox_state(seed)
+  elif algorithm == RNG_ALG_THREEFRY:
+    return _make_threefry_state(seed)
   else:
     raise ValueError("Unsupported algorithm id: %s" % algorithm)
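(A worked example of the seed chopping in `_make_1d_state`, assuming SEED_TYPE_BITS = 64 and SEED_BIT_MASK = 0xFFFFFFFFFFFFFFFF as elsewhere in this file: a Python integer wider than 64 bits is split low-word-first and zero-padded up to the state size. Illustration only, not part of the commit:

seed = (1 << 64) + 5   # needs two 64-bit chunks
chunks = []
for _ in range(3):     # PHILOX_STATE_SIZE
  chunks.append(seed & 0xFFFFFFFFFFFFFFFF)
  seed >>= 64
assert chunks == [5, 1, 0]  # low word first, zero-padded to the state size
# With THREEFRY_STATE_SIZE = 2 the same seed would keep only [5, 1].
)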
@@ -168,26 +197,19 @@ class Generator(tracking.AutoTrackable):
       algorithm = DEFAULT_ALGORITHM
       state = create_rng_state(seed, algorithm)
       self._state_var = variables.Variable(state, dtype=STATE_TYPE)
-      self._alg_var = variables.Variable(initial_value=algorithm,
-                                         dtype=ALGORITHM_TYPE)
+      self._alg_var = algorithm
     else:
       assert seed is None
       self._state_var = variables.Variable(copy_from.state, dtype=STATE_TYPE)
-      self._alg_var = variables.Variable(initial_value=copy_from.algorithm,
-                                         dtype=ALGORITHM_TYPE)
+      self._alg_var = copy_from.algorithm
 
   def reset(self, seed):
     """Resets the generator.
 
-    This function is not thread-safe: if it is run concurrently with a call to
-    sampling, the latter might see the new algorithm but the old state or vice
-    versa.
-
     Args:
       seed: the seed to reset the RNG to.
     """
-    algorithm = int(self.algorithm)
-    state = create_rng_state(seed, algorithm)
+    state = create_rng_state(seed, self.algorithm)
     self._state_var.assign(state)
 
   @property
@@ -200,21 +222,59 @@ class Generator(tracking.AutoTrackable):
 
   # The following functions return a tensor and as a side effect update
   # self._state_var.
-  def standard_normal(self, shape, dtype=dtypes.float32):
-    output = gen_stateful_random_ops.stateful_standard_normal_v2(
-        self.state.handle, self.algorithm, shape, dtype)
-    return output
-
   def normal(self, shape, mean=0.0, stddev=1.0, dtype=dtypes.float32,
              name=None):
     with ops.name_scope(name, "stateful_normal", [shape, mean, stddev]) as name:
       shape = _shape_tensor(shape)
       mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
       stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
-      rnd = self.standard_normal(shape, dtype)
+      rnd = gen_stateful_random_ops.stateful_standard_normal_v2(
+          self.state.handle, self.algorithm, shape, dtype=dtype)
       return math_ops.add(rnd * stddev, mean, name=name)
 
-  # TODO(wangpeng): implement other distributions (`uniform`,
+  def uniform(self, shape, minval=0, maxval=None,
+              dtype=dtypes.float32, name=None):
+    dtype = dtypes.as_dtype(dtype)
+    if maxval is None:
+      if dtype.is_integer:
+        raise ValueError("Must specify maxval for integer dtype %r" % dtype)
+      maxval = 1
+    with ops.name_scope(name, "stateful_uniform",
+                        [shape, minval, maxval]) as name:
+      shape = _shape_tensor(shape)
+      minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
+      maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
+      if dtype.is_integer:
+        return gen_stateful_random_ops.stateful_uniform_int(
+            self.state.handle, self.algorithm, shape=shape,
+            minval=minval, maxval=maxval, name=name)
+      else:
+        # TODO(wangpeng): implement uniform for floats
+        raise ValueError("uniform for floats not implemented yet")
+
+  def uniform_full_int(self, shape, dtype=dtypes.uint64, name=None):
+    """Uniform distribution on an integer type's entire range.
+
+    The other method `uniform` only covers the range [minval, maxval), which
+    cannot be `dtype`'s full range because `maxval` is of type `dtype`.
+
+    Args:
+      shape: the shape of the output.
+      dtype: (optional) the integer type, default to uint64.
+      name: (optional) the name of the node.
+
+    Returns:
+      A tensor of random numbers of the required shape.
+    """
+    dtype = dtypes.as_dtype(dtype)
+    with ops.name_scope(name, "stateful_uniform_full_int",
+                        [shape]) as name:
+      shape = _shape_tensor(shape)
+      return gen_stateful_random_ops.stateful_uniform_full_int(
+          self.state.handle, self.algorithm, shape=shape,
+          dtype=dtype, name=name)
+
+  # TODO(wangpeng): implement other distributions (
   # `truncated_normal`, etc.)
   # TODO(wangpeng): implement `make_seeds`
   # TODO(wangpeng): implement `make_generators`
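(Taken together, the methods above give `Generator` a small sampling API. A usage sketch, assuming eager mode and that the constructor accepts a `seed` keyword as the `__init__` changes above suggest:

g = random.Generator(seed=1234)  # Philox state by default
n = g.normal(shape=[2, 3], mean=0.0, stddev=2.0)
i = g.uniform(shape=[10], minval=0, maxval=100, dtype=dtypes.int64)
u = g.uniform_full_int(shape=[4])  # covers the whole uint64 range
# g.uniform(shape=[4], dtype=dtypes.float32) would raise for now:
# uniform for floats is not implemented yet, per the TODO above.
)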
@@ -23,6 +23,7 @@ import numpy as np
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import gen_random_ops
@@ -138,7 +139,7 @@ class StatefulRandomOpsTest(test.TestCase):
 
     def new():
       with ops.device("/device:CPU:0"):
-        return random.get_global_generator().standard_normal(shape, dtype=dtype)
+        return random.get_global_generator().normal(shape, dtype=dtype)
 
     for _ in range(100):
       self.assertAllEqual(old(), new())
@@ -164,7 +165,7 @@ class StatefulRandomOpsTest(test.TestCase):
 
     def new():
       with ops.device(test_util.gpu_device_name()):
-        return random.get_global_generator().standard_normal(shape, dtype=dtype)
+        return random.get_global_generator().normal(shape, dtype=dtype)
 
     for _ in range(100):
       self.assertAllEqual(old(), new())
@@ -203,7 +204,7 @@ class StatefulRandomOpsTest(test.TestCase):
 
     random.reset_global_generator(50)
     with self.assertRaisesWithPredicateMatch(
-        AssertionError, "variable.*deleted"):
+        errors_impl.NotFoundError, "Resource .+ does not exist"):
       a = f()
     random.reset_global_generator(50)
     b = f()
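(The expected error above changes because `reset_global_generator` now replaces the generator's state variable outright: a previously traced `tf.function` still holds the old variable's resource handle, so the runtime raises `NotFoundError` ("Resource ... does not exist") instead of the earlier Python-level `AssertionError`. A minimal repro sketch of that behavior:

@def_function.function
def f():
  return random.get_global_generator().normal([])

f()                                # traces f, capturing the current state variable
random.reset_global_generator(50)  # replaces the variable; old resource is freed
f()                                # now raises errors_impl.NotFoundError
)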