A continuing partial implementation of the RFC "Random numbers in TensorFlow 2.0" (https://github.com/tensorflow/community/blob/master/rfcs/20181217-tf2-random-numbers.md).

In this change:
- XLA kernels for the ops 'StatefulUniform' and 'StatefulTruncatedNormal'.

To be done:
- ops for other distributions;
- other RNG algorithms;
- batch seeds;
- initializers ('RandomUniform', etc.).

PiperOrigin-RevId: 240658747
This commit is contained in:
parent
b3be250ccf
commit
99c4d2ae1a
@ -53,6 +53,9 @@ def xla_device_name():
|
||||
class StatefulRandomOpsTest(xla_test.XLATestCase):
|
||||
"""Test cases for stateful random-number generator operators."""
|
||||
|
||||
_ints = [dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64]
|
||||
_floats = [dtypes.bfloat16, dtypes.float32]
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testSimple(self):
|
||||
"""A simple test.
|
||||
@ -147,7 +150,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase):
|
||||
maxval = 10000000
|
||||
return gen.uniform(shape=[2], dtype=dtype, maxval=maxval)
|
||||
|
||||
for dtype in {dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64}:
|
||||
for dtype in self._ints + self._floats:
|
||||
self._testRngIsNotConstant(rng, dtype)
|
||||
|
||||
@test_util.run_v2_only
|
||||
@ -157,27 +160,27 @@ class StatefulRandomOpsTest(xla_test.XLATestCase):
|
||||
def rng(dtype):
|
||||
return gen.normal(shape=[2], dtype=dtype)
|
||||
|
||||
for dtype in {dtypes.float32}:
|
||||
for dtype in self._floats:
|
||||
self._testRngIsNotConstant(rng, dtype)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testUniformIntIsInRange(self):
|
||||
def testUniformIsInRange(self):
|
||||
minval = 2
|
||||
maxval = 33
|
||||
size = 1000
|
||||
with ops.device(xla_device_name()):
|
||||
gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)
|
||||
for dtype in {dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64}:
|
||||
for dtype in self._ints + self._floats:
|
||||
gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)
|
||||
x = gen.uniform(
|
||||
shape=[size], dtype=dtype, minval=minval, maxval=maxval).numpy()
|
||||
self.assertTrue(np.all(x >= minval))
|
||||
self.assertTrue(np.all(x < maxval))
|
||||
self.assertTrue(np.all(x <= maxval))
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testNormalIsFinite(self):
|
||||
with ops.device(xla_device_name()):
|
||||
gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)
|
||||
for dtype in {dtypes.float32}:
|
||||
for dtype in self._floats:
|
||||
x = gen.normal(shape=[10000], dtype=dtype).numpy()
|
||||
self.assertTrue(np.all(np.isfinite(x)))
|
||||
|
||||
@ -187,7 +190,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase):
|
||||
with ops.device(xla_device_name()):
|
||||
n = 1000
|
||||
seed = 12
|
||||
for dtype in {dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64}:
|
||||
for dtype in self._ints + self._floats:
|
||||
gen = random.Generator(seed=seed, algorithm=random.RNG_ALG_THREEFRY)
|
||||
maxval = 1
|
||||
if dtype.is_integer:
|
||||
@ -208,7 +211,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase):
|
||||
"""Use Anderson-Darling test to test distribution appears normal."""
|
||||
with ops.device(xla_device_name()):
|
||||
n = 1000
|
||||
for dtype in {dtypes.float32}:
|
||||
for dtype in self._floats:
|
||||
gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)
|
||||
x = gen.normal(shape=[n], dtype=dtype).numpy()
|
||||
# The constant 2.492 is the 5% critical value for the Anderson-Darling
|
||||
@ -217,6 +220,15 @@ class StatefulRandomOpsTest(xla_test.XLATestCase):
|
||||
self.assertLess(
|
||||
random_test_util.anderson_darling(x.astype(float)), 2.492)
|
||||
|
||||
@test_util.run_v2_only
def testTruncatedNormal(self):
  """Checks `Generator.truncated_normal` samples against reference statistics.

  For every dtype in `self._floats`, a fresh generator seeded with 123 draws
  10,000,000 samples, which are then handed to
  `random_test_util.test_truncated_normal` together with this test case's
  `assertEqual` and `assertAllClose` methods.
  """
  # NOTE(review): unlike the sibling tests in this class, this test neither
  # enters `ops.device(xla_device_name())` nor passes an explicit `algorithm`
  # to the generator -- confirm that this is intentional.
  num_samples = 10000000
  for float_dtype in self._floats:
    generator = random.Generator(seed=123)
    samples = generator.truncated_normal(
        shape=[num_samples], dtype=float_dtype).numpy()
    random_test_util.test_truncated_normal(
        self.assertEqual, self.assertAllClose, float_dtype, num_samples,
        samples)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testErrors(self):
|
||||
"""Tests that proper errors are raised.
|
||||
|
@ -18,8 +18,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
@ -28,7 +26,6 @@ from tensorflow.python.kernel_tests.random import util as \
|
||||
random_test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import stateless_random_ops as stateless
|
||||
from tensorflow.python.ops.distributions import special_math
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -122,7 +119,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
|
||||
self.assertLess(
|
||||
random_test_util.anderson_darling(y.astype(float)), 2.492)
|
||||
|
||||
def testTruncatedNormalIsInRange(self):
|
||||
def testTruncatedNormal(self):
|
||||
for dtype in self._random_types():
|
||||
with self.cached_session() as sess, self.test_scope():
|
||||
seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
|
||||
@ -130,47 +127,8 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
|
||||
x = stateless.stateless_truncated_normal(
|
||||
shape=[n], seed=seed_t, dtype=dtype)
|
||||
y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]})
|
||||
|
||||
def normal_cdf(x):
|
||||
return .5 * math.erfc(-x / math.sqrt(2))
|
||||
|
||||
def normal_pdf(x):
|
||||
return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi)
|
||||
|
||||
def probit(x):
|
||||
return self.evaluate(special_math.ndtri(x))
|
||||
|
||||
a = -2.
|
||||
b = 2.
|
||||
mu = 0.
|
||||
sigma = 1.
|
||||
|
||||
alpha = (a - mu) / sigma
|
||||
beta = (b - mu) / sigma
|
||||
z = normal_cdf(beta) - normal_cdf(alpha)
|
||||
|
||||
self.assertEqual((y >= a).sum(), n)
|
||||
self.assertEqual((y <= b).sum(), n)
|
||||
|
||||
# For more information on these calculations, see:
|
||||
# Burkardt, John. "The Truncated Normal Distribution".
|
||||
# Department of Scientific Computing website. Florida State University.
|
||||
expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma
|
||||
y = y.astype(float)
|
||||
actual_mean = np.mean(y)
|
||||
self.assertAllClose(actual_mean, expected_mean, atol=5e-4)
|
||||
|
||||
expected_median = mu + probit(
|
||||
(normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma
|
||||
actual_median = np.median(y)
|
||||
self.assertAllClose(actual_median, expected_median, atol=8e-4)
|
||||
|
||||
expected_variance = sigma**2 * (1 + (
|
||||
(alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - (
|
||||
(normal_pdf(alpha) - normal_pdf(beta)) / z)**2)
|
||||
actual_variance = np.var(y)
|
||||
self.assertAllClose(actual_variance, expected_variance,
|
||||
rtol=5e-3 if dtype == dtypes.bfloat16 else 1e-3)
|
||||
random_test_util.test_truncated_normal(
|
||||
self.assertEqual, self.assertAllClose, dtype, n, y)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -112,11 +112,13 @@ std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniform(xla::XlaOp key,
|
||||
counter);
|
||||
}
|
||||
default:
|
||||
return std::make_pair(builder->ReportError(xla::Unimplemented(
|
||||
"Types other than F32, U32, S32, U64 and S64 "
|
||||
"are not implemented by "
|
||||
"StatefulRngUniform.")),
|
||||
counter);
|
||||
return std::make_pair(
|
||||
builder->ReportError(xla::Unimplemented(
|
||||
"Types other than F32, U32, S32, U64 and S64 "
|
||||
"are not implemented by "
|
||||
"StatefulRngUniform; got: %s",
|
||||
xla::primitive_util::LowercasePrimitiveTypeName(type))),
|
||||
counter);
|
||||
}
|
||||
}
|
||||
|
||||
@ -243,6 +245,46 @@ Status CompileImpl(XlaOpKernelContext* ctx, int state_input_idx,
|
||||
}
|
||||
}
|
||||
|
||||
class StatefulUniformOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit StatefulUniformOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
auto builder = ctx->builder();
|
||||
auto sample_with_threefry = [builder, this](
|
||||
xla::XlaOp counter, xla::XlaOp key,
|
||||
TensorShape shape) -> sampler_return_type {
|
||||
xla::Shape xla_shape;
|
||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
|
||||
auto uniform_counter = StatefulRngUniform(
|
||||
key, counter, xla_shape, xla::ConstantR0<float>(builder, 0.0),
|
||||
xla::ConstantR0<float>(builder, 1.0));
|
||||
auto uniform = uniform_counter.first;
|
||||
counter = uniform_counter.second;
|
||||
uniform = MaybeConvertF32ToBF16(uniform, dtype_);
|
||||
return {{uniform, counter}};
|
||||
};
|
||||
OP_REQUIRES_OK(ctx,
|
||||
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
|
||||
/*shape_input_idx=*/2, sample_with_threefry));
|
||||
}
|
||||
|
||||
private:
|
||||
DataType dtype_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(StatefulUniformOp);
|
||||
};
|
||||
|
||||
// TODO(wangpeng): Support plain float16 and float64 to get rid of the
|
||||
// `TypeConstraint`.
|
||||
REGISTER_XLA_OP(Name("StatefulUniform")
|
||||
.CompileTimeConstantInput("algorithm")
|
||||
.CompileTimeConstantInput("shape")
|
||||
.TypeConstraint("dtype", {DT_FLOAT, DT_BFLOAT16}),
|
||||
StatefulUniformOp);
|
||||
|
||||
class StatefulStandardNormalOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit StatefulStandardNormalOp(OpKernelConstruction* ctx)
|
||||
@ -291,6 +333,51 @@ REGISTER_XLA_OP(Name("StatefulStandardNormalV2")
|
||||
.TypeConstraint("dtype", {DT_FLOAT, DT_BFLOAT16}),
|
||||
StatefulStandardNormalOp);
|
||||
|
||||
class StatefulTruncatedNormalOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit StatefulTruncatedNormalOp(OpKernelConstruction* ctx)
|
||||
: XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
auto builder = ctx->builder();
|
||||
auto sample_with_threefry =
|
||||
// Needs explicit lambda return type because it fails to be inferred.
|
||||
[builder, this](xla::XlaOp counter, xla::XlaOp key,
|
||||
TensorShape shape) -> sampler_return_type {
|
||||
xla::Shape xla_shape;
|
||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
|
||||
|
||||
auto uniform_counter = StatefulRngUniform(
|
||||
key, counter, xla_shape,
|
||||
xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
|
||||
xla::One(builder, xla_shape.element_type()));
|
||||
auto uniform = uniform_counter.first;
|
||||
counter = uniform_counter.second;
|
||||
xla::XlaOp truncated_normal = TruncatedNormal(uniform);
|
||||
truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_);
|
||||
return {{truncated_normal, counter}};
|
||||
};
|
||||
OP_REQUIRES_OK(ctx,
|
||||
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
|
||||
/*shape_input_idx=*/2, sample_with_threefry));
|
||||
}
|
||||
|
||||
private:
|
||||
DataType dtype_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(StatefulTruncatedNormalOp);
|
||||
};
|
||||
|
||||
// TODO(wangpeng): Support plain float16 and float64 to get rid of the
|
||||
// `TypeConstraint`.
|
||||
REGISTER_XLA_OP(Name("StatefulTruncatedNormal")
|
||||
.CompileTimeConstantInput("algorithm")
|
||||
.CompileTimeConstantInput("shape")
|
||||
.TypeConstraint("dtype", {DT_FLOAT, DT_BFLOAT16}),
|
||||
StatefulTruncatedNormalOp);
|
||||
|
||||
class StatefulUniformIntOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit StatefulUniformIntOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
|
@ -55,7 +55,7 @@ xla::XlaOp Uniform2NormalUsingSqrtErfinv(xla::XlaOp uniform) {
|
||||
// values with uniform distribution in the range [minval, maxval) for the given
|
||||
// shape and given two 32-bit seeds. Currently only shapes of type F32, S32 and
|
||||
// S64 are implemented.
|
||||
xla::XlaOp StatelessRandomUniformImpl(const xla::Shape& shape, DataType dtype,
|
||||
xla::XlaOp StatelessRandomUniformImpl(const xla::Shape& shape, DataType unused,
|
||||
xla::XlaOp seed, xla::XlaOp minval,
|
||||
xla::XlaOp maxval) {
|
||||
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
|
||||
|
@ -83,6 +83,8 @@ CreateResourceOpInfoMap() {
|
||||
add("ResourceScatterUpdate" , kReadWrite, kVariable);
|
||||
add("ResourceStridedSliceAssign" , kReadWrite, kVariable);
|
||||
add("StatefulStandardNormalV2" , kReadWrite, kVariable);
|
||||
add("StatefulTruncatedNormal" , kReadWrite, kVariable);
|
||||
add("StatefulUniform" , kReadWrite, kVariable);
|
||||
add("StatefulUniformFullInt" , kReadWrite, kVariable);
|
||||
add("StatefulUniformInt" , kReadWrite, kVariable);
|
||||
add("VarIsInitializedOp" , kRead, kVariable);
|
||||
|
@ -22,6 +22,9 @@ import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops.distributions import special_math
|
||||
|
||||
|
||||
def test_moment_matching(
|
||||
samples,
|
||||
@ -95,3 +98,47 @@ def anderson_darling(x):
|
||||
z = np.sum((2 * i - 1) * np.log(normal_cdf(x)) +
|
||||
(2 * (n - i) + 1) * np.log(1 - normal_cdf(x)))
|
||||
return -n - z / n
|
||||
|
||||
|
||||
def test_truncated_normal(assert_equal, assert_all_close, dtype, n, y):
  """Checks samples against a standard normal truncated to [-2, 2].

  Every sample must lie inside the truncation bounds, and the sample mean,
  median and variance must match the analytical values of the truncated
  distribution within fixed tolerances.

  Args:
    assert_equal: Callable taking `(actual, expected)`.
    assert_all_close: Callable taking `(actual, expected)` plus an `atol` or
      `rtol` keyword argument.
    dtype: Dtype of the samples; only compared against `dtypes.bfloat16` to
      select the variance tolerance.
    n: Expected number of samples in `y`.
    y: The samples. Must support elementwise comparison, `.sum()` and
      `.astype(float)` (presumably a NumPy array).
  """
  lower, upper = -2., 2.
  mu, sigma = 0., 1.

  def std_normal_cdf(t):
    return .5 * math.erfc(-t / math.sqrt(2))

  def std_normal_pdf(t):
    return math.exp(-(t**2) / 2.) / math.sqrt(2 * math.pi)

  # Standardized truncation points and the density/mass terms built on them.
  alpha = (lower - mu) / sigma
  beta = (upper - mu) / sigma
  cdf_alpha = std_normal_cdf(alpha)
  cdf_beta = std_normal_cdf(beta)
  pdf_alpha = std_normal_pdf(alpha)
  pdf_beta = std_normal_pdf(beta)
  z = cdf_beta - cdf_alpha

  # Range checks run on the samples in their original dtype.
  assert_equal((y >= lower).sum(), n)
  assert_equal((y <= upper).sum(), n)

  # For more information on these calculations, see:
  # Burkardt, John. "The Truncated Normal Distribution".
  # Department of Scientific Computing website. Florida State University.
  mean_shift = (pdf_alpha - pdf_beta) / z
  expected_mean = mu + mean_shift * sigma
  samples = y.astype(float)
  assert_all_close(np.mean(samples), expected_mean, atol=5e-4)

  # The median is the quantile at the midpoint of the two CDF values.
  expected_median = mu + special_math.ndtri(
      (cdf_alpha + cdf_beta) / 2.) * sigma
  assert_all_close(np.median(samples), expected_median, atol=8e-4)

  expected_variance = sigma**2 * (
      1 + (alpha * pdf_alpha - beta * pdf_beta) / z - mean_shift**2)
  actual_variance = np.var(samples)
  # bfloat16 gets a looser relative tolerance than the other dtypes.
  variance_rtol = 5e-3 if dtype == dtypes.bfloat16 else 1e-3
  assert_all_close(actual_variance, expected_variance, rtol=variance_rtol)
|
||||
|
Loading…
Reference in New Issue
Block a user