A continuing partial implementation of RFC "Random numbers in TensorFlow 2.0" (https://github.com/tensorflow/community/blob/master/rfcs/20181217-tf2-random-numbers.md):

In this change:
- XLA kernels for the ops 'StatefulUniform' and 'StatefulTruncatedNormal'.

To be done:
- ops for other distributions;
- other RNG algorithms;
- batch seeds;
- initializers ('RandomUniform', etc.).

PiperOrigin-RevId: 240658747
Peng Wang 2019-03-27 15:51:44 -07:00 committed by TensorFlower Gardener
parent b3be250ccf
commit 99c4d2ae1a
6 changed files with 166 additions and 60 deletions
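
For orientation (not part of the diff): a minimal sketch of how the new kernels are reached from Python, assuming the Generator API that the tests below exercise ('random.Generator', 'random.RNG_ALG_THREEFRY'). The module path and device string here are assumptions, not taken from this change.

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import stateful_random_ops as random  # assumed module path

with ops.device('/device:XLA_CPU:0'):  # hypothetical XLA device string
  gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)
  # Float dtypes now compile to the new 'StatefulUniform' XLA kernel:
  u = gen.uniform(shape=[2], dtype=dtypes.float32, minval=0., maxval=1.)
  # Compiles to the new 'StatefulTruncatedNormal' XLA kernel:
  t = gen.truncated_normal(shape=[2], dtype=dtypes.bfloat16)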

tensorflow/compiler/tests/stateful_random_ops_test.py

@@ -53,6 +53,9 @@ def xla_device_name():
 class StatefulRandomOpsTest(xla_test.XLATestCase):
   """Test cases for stateful random-number generator operators."""
 
+  _ints = [dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64]
+  _floats = [dtypes.bfloat16, dtypes.float32]
+
   @test_util.run_v2_only
   def testSimple(self):
     """A simple test.
@@ -147,7 +150,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase):
         maxval = 10000000
         return gen.uniform(shape=[2], dtype=dtype, maxval=maxval)
 
-      for dtype in {dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64}:
+      for dtype in self._ints + self._floats:
         self._testRngIsNotConstant(rng, dtype)
 
   @test_util.run_v2_only
@@ -157,27 +160,27 @@ class StatefulRandomOpsTest(xla_test.XLATestCase):
       def rng(dtype):
         return gen.normal(shape=[2], dtype=dtype)
 
-      for dtype in {dtypes.float32}:
+      for dtype in self._floats:
         self._testRngIsNotConstant(rng, dtype)
 
   @test_util.run_v2_only
-  def testUniformIntIsInRange(self):
+  def testUniformIsInRange(self):
     minval = 2
     maxval = 33
     size = 1000
     with ops.device(xla_device_name()):
-      gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)
-      for dtype in {dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64}:
+      for dtype in self._ints + self._floats:
+        gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)
         x = gen.uniform(
             shape=[size], dtype=dtype, minval=minval, maxval=maxval).numpy()
         self.assertTrue(np.all(x >= minval))
-        self.assertTrue(np.all(x < maxval))
+        self.assertTrue(np.all(x <= maxval))
 
   @test_util.run_v2_only
   def testNormalIsFinite(self):
     with ops.device(xla_device_name()):
       gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)
-      for dtype in {dtypes.float32}:
+      for dtype in self._floats:
         x = gen.normal(shape=[10000], dtype=dtype).numpy()
         self.assertTrue(np.all(np.isfinite(x)))
@@ -187,7 +190,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase):
     with ops.device(xla_device_name()):
       n = 1000
       seed = 12
-      for dtype in {dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64}:
+      for dtype in self._ints + self._floats:
         gen = random.Generator(seed=seed, algorithm=random.RNG_ALG_THREEFRY)
         maxval = 1
         if dtype.is_integer:
@@ -208,7 +211,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase):
     """Use Anderson-Darling test to test distribution appears normal."""
     with ops.device(xla_device_name()):
       n = 1000
-      for dtype in {dtypes.float32}:
+      for dtype in self._floats:
         gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)
         x = gen.normal(shape=[n], dtype=dtype).numpy()
         # The constant 2.492 is the 5% critical value for the Anderson-Darling
@@ -217,6 +220,15 @@ class StatefulRandomOpsTest(xla_test.XLATestCase):
       self.assertLess(
           random_test_util.anderson_darling(x.astype(float)), 2.492)
 
+  @test_util.run_v2_only
+  def testTruncatedNormal(self):
+    for dtype in self._floats:
+      gen = random.Generator(seed=123)
+      n = 10000000
+      y = gen.truncated_normal(shape=[n], dtype=dtype).numpy()
+      random_test_util.test_truncated_normal(
+          self.assertEqual, self.assertAllClose, dtype, n, y)
+
   @test_util.run_v2_only
   def testErrors(self):
     """Tests that proper errors are raised.

tensorflow/compiler/tests/stateless_random_ops_test.py

@@ -18,8 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import math
-
 import numpy as np
 
 from tensorflow.compiler.tests import xla_test
@@ -28,7 +26,6 @@ from tensorflow.python.kernel_tests.random import util as \
     random_test_util
 
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import stateless_random_ops as stateless
-from tensorflow.python.ops.distributions import special_math
 from tensorflow.python.platform import test
 
@@ -122,7 +119,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
       self.assertLess(
           random_test_util.anderson_darling(y.astype(float)), 2.492)
 
-  def testTruncatedNormalIsInRange(self):
+  def testTruncatedNormal(self):
     for dtype in self._random_types():
       with self.cached_session() as sess, self.test_scope():
         seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
@@ -130,47 +127,8 @@
         x = stateless.stateless_truncated_normal(
             shape=[n], seed=seed_t, dtype=dtype)
         y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]})
-
-        def normal_cdf(x):
-          return .5 * math.erfc(-x / math.sqrt(2))
-
-        def normal_pdf(x):
-          return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi)
-
-        def probit(x):
-          return self.evaluate(special_math.ndtri(x))
-
-        a = -2.
-        b = 2.
-        mu = 0.
-        sigma = 1.
-        alpha = (a - mu) / sigma
-        beta = (b - mu) / sigma
-        z = normal_cdf(beta) - normal_cdf(alpha)
-
-        self.assertEqual((y >= a).sum(), n)
-        self.assertEqual((y <= b).sum(), n)
-
-        # For more information on these calculations, see:
-        # Burkardt, John. "The Truncated Normal Distribution".
-        # Department of Scientific Computing website. Florida State University.
-        expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma
-        y = y.astype(float)
-        actual_mean = np.mean(y)
-        self.assertAllClose(actual_mean, expected_mean, atol=5e-4)
-
-        expected_median = mu + probit(
-            (normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma
-        actual_median = np.median(y)
-        self.assertAllClose(actual_median, expected_median, atol=8e-4)
-
-        expected_variance = sigma**2 * (1 + (
-            (alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - (
-            (normal_pdf(alpha) - normal_pdf(beta)) / z)**2)
-        actual_variance = np.var(y)
-        self.assertAllClose(actual_variance, expected_variance,
-                            rtol=5e-3 if dtype == dtypes.bfloat16 else 1e-3)
+        random_test_util.test_truncated_normal(
+            self.assertEqual, self.assertAllClose, dtype, n, y)
 
 
 if __name__ == '__main__':

tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc

@@ -112,11 +112,13 @@ std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniform(xla::XlaOp key,
                             counter);
     }
     default:
-      return std::make_pair(builder->ReportError(xla::Unimplemented(
-                                "Types other than F32, U32, S32, U64 and S64 "
-                                "are not implemented by "
-                                "StatefulRngUniform.")),
-                            counter);
+      return std::make_pair(
+          builder->ReportError(xla::Unimplemented(
+              "Types other than F32, U32, S32, U64 and S64 "
+              "are not implemented by "
+              "StatefulRngUniform; got: %s",
+              xla::primitive_util::LowercasePrimitiveTypeName(type))),
+          counter);
   }
 }
 
@@ -243,6 +245,46 @@ Status CompileImpl(XlaOpKernelContext* ctx, int state_input_idx,
   }
 }
 
+class StatefulUniformOp : public XlaOpKernel {
+ public:
+  explicit StatefulUniformOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
+  }
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    auto builder = ctx->builder();
+    auto sample_with_threefry = [builder, this](
+                                    xla::XlaOp counter, xla::XlaOp key,
+                                    TensorShape shape) -> sampler_return_type {
+      xla::Shape xla_shape;
+      TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
+      auto uniform_counter = StatefulRngUniform(
+          key, counter, xla_shape, xla::ConstantR0<float>(builder, 0.0),
+          xla::ConstantR0<float>(builder, 1.0));
+      auto uniform = uniform_counter.first;
+      counter = uniform_counter.second;
+      uniform = MaybeConvertF32ToBF16(uniform, dtype_);
+      return {{uniform, counter}};
+    };
+    OP_REQUIRES_OK(ctx,
+                   CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
+                               /*shape_input_idx=*/2, sample_with_threefry));
+  }
+
+ private:
+  DataType dtype_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(StatefulUniformOp);
+};
+
+// TODO(wangpeng): Support plain float16 and float64 to get rid of the
+//   `TypeConstraint`.
+REGISTER_XLA_OP(Name("StatefulUniform")
+                    .CompileTimeConstantInput("algorithm")
+                    .CompileTimeConstantInput("shape")
+                    .TypeConstraint("dtype", {DT_FLOAT, DT_BFLOAT16}),
+                StatefulUniformOp);
+
 class StatefulStandardNormalOp : public XlaOpKernel {
  public:
   explicit StatefulStandardNormalOp(OpKernelConstruction* ctx)
@@ -291,6 +333,51 @@ REGISTER_XLA_OP(Name("StatefulStandardNormalV2")
                     .TypeConstraint("dtype", {DT_FLOAT, DT_BFLOAT16}),
                 StatefulStandardNormalOp);
 
+class StatefulTruncatedNormalOp : public XlaOpKernel {
+ public:
+  explicit StatefulTruncatedNormalOp(OpKernelConstruction* ctx)
+      : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
+  }
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    auto builder = ctx->builder();
+    auto sample_with_threefry =
+        // Needs explicit lambda return type because it fails to be inferred.
+        [builder, this](xla::XlaOp counter, xla::XlaOp key,
+                        TensorShape shape) -> sampler_return_type {
+      xla::Shape xla_shape;
+      TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
+
+      auto uniform_counter = StatefulRngUniform(
+          key, counter, xla_shape,
+          xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
+          xla::One(builder, xla_shape.element_type()));
+      auto uniform = uniform_counter.first;
+      counter = uniform_counter.second;
+      xla::XlaOp truncated_normal = TruncatedNormal(uniform);
+      truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_);
+      return {{truncated_normal, counter}};
+    };
+    OP_REQUIRES_OK(ctx,
+                   CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
+                               /*shape_input_idx=*/2, sample_with_threefry));
+  }
+
+ private:
+  DataType dtype_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(StatefulTruncatedNormalOp);
+};
+
+// TODO(wangpeng): Support plain float16 and float64 to get rid of the
+//   `TypeConstraint`.
+REGISTER_XLA_OP(Name("StatefulTruncatedNormal")
+                    .CompileTimeConstantInput("algorithm")
+                    .CompileTimeConstantInput("shape")
+                    .TypeConstraint("dtype", {DT_FLOAT, DT_BFLOAT16}),
+                StatefulTruncatedNormalOp);
+
 class StatefulUniformIntOp : public XlaOpKernel {
  public:
   explicit StatefulUniformIntOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
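
Not part of the diff: the TruncatedNormal(uniform) call above maps uniforms in (0, 1] to a standard normal truncated at +/-2, the range the tests assert. A self-contained sketch of that inverse-CDF construction (illustrative only, not the literal XLA implementation):

from statistics import NormalDist

def truncated_normal_from_uniform(u, a=-2.0, b=2.0):
  nd = NormalDist()  # standard normal: mu=0, sigma=1
  # Squash u in (0, 1] into [Phi(a), Phi(b)], then invert the normal CDF.
  p = nd.cdf(a) + u * (nd.cdf(b) - nd.cdf(a))
  return nd.inv_cdf(p)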

tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc

@@ -55,7 +55,7 @@ xla::XlaOp Uniform2NormalUsingSqrtErfinv(xla::XlaOp uniform) {
 // values with uniform distribution in the range [minval, maxval) for the given
 // shape and given two 32-bit seeds. Currently only shapes of type F32, S32 and
 // S64 are implemented.
-xla::XlaOp StatelessRandomUniformImpl(const xla::Shape& shape, DataType dtype,
+xla::XlaOp StatelessRandomUniformImpl(const xla::Shape& shape, DataType unused,
                                       xla::XlaOp seed, xla::XlaOp minval,
                                       xla::XlaOp maxval) {
   xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});

tensorflow/compiler/tf2xla/resource_operation_table.cc

@@ -83,6 +83,8 @@ CreateResourceOpInfoMap() {
   add("ResourceScatterUpdate"      , kReadWrite, kVariable);
   add("ResourceStridedSliceAssign" , kReadWrite, kVariable);
   add("StatefulStandardNormalV2"   , kReadWrite, kVariable);
+  add("StatefulTruncatedNormal"    , kReadWrite, kVariable);
+  add("StatefulUniform"            , kReadWrite, kVariable);
   add("StatefulUniformFullInt"     , kReadWrite, kVariable);
   add("StatefulUniformInt"         , kReadWrite, kVariable);
   add("VarIsInitializedOp"         , kRead, kVariable);

tensorflow/python/kernel_tests/random/util.py

@@ -22,6 +22,9 @@ import math
 
 import numpy as np
 
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops.distributions import special_math
+
 
 def test_moment_matching(
     samples,
@@ -95,3 +98,47 @@ def anderson_darling(x):
   z = np.sum((2 * i - 1) * np.log(normal_cdf(x)) +
              (2 * (n - i) + 1) * np.log(1 - normal_cdf(x)))
   return -n - z / n
+
+
+def test_truncated_normal(assert_equal, assert_all_close, dtype, n, y):
+  """Tests truncated normal distribution's statistics."""
+
+  def _normal_cdf(x):
+    return .5 * math.erfc(-x / math.sqrt(2))
+
+  def normal_pdf(x):
+    return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi)
+
+  def probit(x):
+    return special_math.ndtri(x)
+
+  a = -2.
+  b = 2.
+  mu = 0.
+  sigma = 1.
+  alpha = (a - mu) / sigma
+  beta = (b - mu) / sigma
+  z = _normal_cdf(beta) - _normal_cdf(alpha)
+
+  assert_equal((y >= a).sum(), n)
+  assert_equal((y <= b).sum(), n)
+
+  # For more information on these calculations, see:
+  # Burkardt, John. "The Truncated Normal Distribution".
+  # Department of Scientific Computing website. Florida State University.
+  expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma
+  y = y.astype(float)
+  actual_mean = np.mean(y)
+  assert_all_close(actual_mean, expected_mean, atol=5e-4)
+
+  expected_median = mu + probit(
+      (_normal_cdf(alpha) + _normal_cdf(beta)) / 2.) * sigma
+  actual_median = np.median(y)
+  assert_all_close(actual_median, expected_median, atol=8e-4)
+
+  expected_variance = sigma**2 * (1 + (
+      (alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - (
+          (normal_pdf(alpha) - normal_pdf(beta)) / z)**2)
+  actual_variance = np.var(y)
+  assert_all_close(actual_variance, expected_variance,
+                   rtol=5e-3 if dtype == dtypes.bfloat16 else 1e-3)
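
For reference (not part of the diff), the closed-form statistics that test_truncated_normal checks are the standard truncated-normal moments from the Burkardt note cited in the code. Writing phi for normal_pdf, Phi for _normal_cdf, Phi^{-1} for probit, and Z = Phi(beta) - Phi(alpha):

\[
\text{mean} = \mu + \sigma\,\frac{\varphi(\alpha) - \varphi(\beta)}{Z},
\qquad
\text{median} = \mu + \sigma\,\Phi^{-1}\!\left(\frac{\Phi(\alpha) + \Phi(\beta)}{2}\right),
\]
\[
\text{variance} = \sigma^{2}\left[1 + \frac{\alpha\varphi(\alpha) - \beta\varphi(\beta)}{Z} - \left(\frac{\varphi(\alpha) - \varphi(\beta)}{Z}\right)^{2}\right].
\]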