diff --git a/tensorflow/compiler/tests/stateful_random_ops_test.py b/tensorflow/compiler/tests/stateful_random_ops_test.py index fd1f69789ae..1992a6e9c0a 100644 --- a/tensorflow/compiler/tests/stateful_random_ops_test.py +++ b/tensorflow/compiler/tests/stateful_random_ops_test.py @@ -53,6 +53,9 @@ def xla_device_name(): class StatefulRandomOpsTest(xla_test.XLATestCase): """Test cases for stateful random-number generator operators.""" + _ints = [dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64] + _floats = [dtypes.bfloat16, dtypes.float32] + @test_util.run_v2_only def testSimple(self): """A simple test. @@ -147,7 +150,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase): maxval = 10000000 return gen.uniform(shape=[2], dtype=dtype, maxval=maxval) - for dtype in {dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64}: + for dtype in self._ints + self._floats: self._testRngIsNotConstant(rng, dtype) @test_util.run_v2_only @@ -157,27 +160,27 @@ class StatefulRandomOpsTest(xla_test.XLATestCase): def rng(dtype): return gen.normal(shape=[2], dtype=dtype) - for dtype in {dtypes.float32}: + for dtype in self._floats: self._testRngIsNotConstant(rng, dtype) @test_util.run_v2_only - def testUniformIntIsInRange(self): + def testUniformIsInRange(self): minval = 2 maxval = 33 size = 1000 with ops.device(xla_device_name()): - gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY) - for dtype in {dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64}: + for dtype in self._ints + self._floats: + gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY) x = gen.uniform( shape=[size], dtype=dtype, minval=minval, maxval=maxval).numpy() self.assertTrue(np.all(x >= minval)) - self.assertTrue(np.all(x < maxval)) + self.assertTrue(np.all(x <= maxval)) @test_util.run_v2_only def testNormalIsFinite(self): with ops.device(xla_device_name()): gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY) - for dtype in {dtypes.float32}: + for dtype in self._floats: x = gen.normal(shape=[10000], dtype=dtype).numpy() self.assertTrue(np.all(np.isfinite(x))) @@ -187,7 +190,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase): with ops.device(xla_device_name()): n = 1000 seed = 12 - for dtype in {dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64}: + for dtype in self._ints + self._floats: gen = random.Generator(seed=seed, algorithm=random.RNG_ALG_THREEFRY) maxval = 1 if dtype.is_integer: @@ -208,7 +211,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase): """Use Anderson-Darling test to test distribution appears normal.""" with ops.device(xla_device_name()): n = 1000 - for dtype in {dtypes.float32}: + for dtype in self._floats: gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY) x = gen.normal(shape=[n], dtype=dtype).numpy() # The constant 2.492 is the 5% critical value for the Anderson-Darling @@ -217,6 +220,15 @@ class StatefulRandomOpsTest(xla_test.XLATestCase): self.assertLess( random_test_util.anderson_darling(x.astype(float)), 2.492) + @test_util.run_v2_only + def testTruncatedNormal(self): + for dtype in self._floats: + gen = random.Generator(seed=123) + n = 10000000 + y = gen.truncated_normal(shape=[n], dtype=dtype).numpy() + random_test_util.test_truncated_normal( + self.assertEqual, self.assertAllClose, dtype, n, y) + @test_util.run_v2_only def testErrors(self): """Tests that proper errors are raised. diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index 3fb3176ee00..93c7d7fbf09 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import math - import numpy as np from tensorflow.compiler.tests import xla_test @@ -28,7 +26,6 @@ from tensorflow.python.kernel_tests.random import util as \ random_test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import stateless_random_ops as stateless -from tensorflow.python.ops.distributions import special_math from tensorflow.python.platform import test @@ -122,7 +119,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): self.assertLess( random_test_util.anderson_darling(y.astype(float)), 2.492) - def testTruncatedNormalIsInRange(self): + def testTruncatedNormal(self): for dtype in self._random_types(): with self.cached_session() as sess, self.test_scope(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) @@ -130,47 +127,8 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): x = stateless.stateless_truncated_normal( shape=[n], seed=seed_t, dtype=dtype) y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) - - def normal_cdf(x): - return .5 * math.erfc(-x / math.sqrt(2)) - - def normal_pdf(x): - return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi) - - def probit(x): - return self.evaluate(special_math.ndtri(x)) - - a = -2. - b = 2. - mu = 0. - sigma = 1. - - alpha = (a - mu) / sigma - beta = (b - mu) / sigma - z = normal_cdf(beta) - normal_cdf(alpha) - - self.assertEqual((y >= a).sum(), n) - self.assertEqual((y <= b).sum(), n) - - # For more information on these calculations, see: - # Burkardt, John. "The Truncated Normal Distribution". - # Department of Scientific Computing website. Florida State University. - expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma - y = y.astype(float) - actual_mean = np.mean(y) - self.assertAllClose(actual_mean, expected_mean, atol=5e-4) - - expected_median = mu + probit( - (normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma - actual_median = np.median(y) - self.assertAllClose(actual_median, expected_median, atol=8e-4) - - expected_variance = sigma**2 * (1 + ( - (alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - ( - (normal_pdf(alpha) - normal_pdf(beta)) / z)**2) - actual_variance = np.var(y) - self.assertAllClose(actual_variance, expected_variance, - rtol=5e-3 if dtype == dtypes.bfloat16 else 1e-3) + random_test_util.test_truncated_normal( + self.assertEqual, self.assertAllClose, dtype, n, y) if __name__ == '__main__': diff --git a/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc index f1d68835e12..ccdd1194916 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc @@ -112,11 +112,13 @@ std::pair StatefulRngUniform(xla::XlaOp key, counter); } default: - return std::make_pair(builder->ReportError(xla::Unimplemented( - "Types other than F32, U32, S32, U64 and S64 " - "are not implemented by " - "StatefulRngUniform.")), - counter); + return std::make_pair( + builder->ReportError(xla::Unimplemented( + "Types other than F32, U32, S32, U64 and S64 " + "are not implemented by " + "StatefulRngUniform; got: %s", + xla::primitive_util::LowercasePrimitiveTypeName(type))), + counter); } } @@ -243,6 +245,46 @@ Status CompileImpl(XlaOpKernelContext* ctx, int state_input_idx, } } +class StatefulUniformOp : public XlaOpKernel { + public: + explicit StatefulUniformOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + auto builder = ctx->builder(); + auto sample_with_threefry = [builder, this]( + xla::XlaOp counter, xla::XlaOp key, + TensorShape shape) -> sampler_return_type { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape)); + auto uniform_counter = StatefulRngUniform( + key, counter, xla_shape, xla::ConstantR0(builder, 0.0), + xla::ConstantR0(builder, 1.0)); + auto uniform = uniform_counter.first; + counter = uniform_counter.second; + uniform = MaybeConvertF32ToBF16(uniform, dtype_); + return {{uniform, counter}}; + }; + OP_REQUIRES_OK(ctx, + CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1, + /*shape_input_idx=*/2, sample_with_threefry)); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(StatefulUniformOp); +}; + +// TODO(wangpeng): Support plain float16 and float64 to get rid of the +// `TypeConstraint`. +REGISTER_XLA_OP(Name("StatefulUniform") + .CompileTimeConstantInput("algorithm") + .CompileTimeConstantInput("shape") + .TypeConstraint("dtype", {DT_FLOAT, DT_BFLOAT16}), + StatefulUniformOp); + class StatefulStandardNormalOp : public XlaOpKernel { public: explicit StatefulStandardNormalOp(OpKernelConstruction* ctx) @@ -291,6 +333,51 @@ REGISTER_XLA_OP(Name("StatefulStandardNormalV2") .TypeConstraint("dtype", {DT_FLOAT, DT_BFLOAT16}), StatefulStandardNormalOp); +class StatefulTruncatedNormalOp : public XlaOpKernel { + public: + explicit StatefulTruncatedNormalOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + auto builder = ctx->builder(); + auto sample_with_threefry = + // Needs explicit lambda return type because it fails to be inferred. + [builder, this](xla::XlaOp counter, xla::XlaOp key, + TensorShape shape) -> sampler_return_type { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape)); + + auto uniform_counter = StatefulRngUniform( + key, counter, xla_shape, + xla::MinPositiveNormalValue(builder, xla_shape.element_type()), + xla::One(builder, xla_shape.element_type())); + auto uniform = uniform_counter.first; + counter = uniform_counter.second; + xla::XlaOp truncated_normal = TruncatedNormal(uniform); + truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_); + return {{truncated_normal, counter}}; + }; + OP_REQUIRES_OK(ctx, + CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1, + /*shape_input_idx=*/2, sample_with_threefry)); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(StatefulTruncatedNormalOp); +}; + +// TODO(wangpeng): Support plain float16 and float64 to get rid of the +// `TypeConstraint`. +REGISTER_XLA_OP(Name("StatefulTruncatedNormal") + .CompileTimeConstantInput("algorithm") + .CompileTimeConstantInput("shape") + .TypeConstraint("dtype", {DT_FLOAT, DT_BFLOAT16}), + StatefulTruncatedNormalOp); + class StatefulUniformIntOp : public XlaOpKernel { public: explicit StatefulUniformIntOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index e143a711730..91230de0029 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -55,7 +55,7 @@ xla::XlaOp Uniform2NormalUsingSqrtErfinv(xla::XlaOp uniform) { // values with uniform distribution in the range [minval, maxval) for the given // shape and given two 32-bit seeds. Currently only shapes of type F32, S32 and // S64 are implemented. -xla::XlaOp StatelessRandomUniformImpl(const xla::Shape& shape, DataType dtype, +xla::XlaOp StatelessRandomUniformImpl(const xla::Shape& shape, DataType unused, xla::XlaOp seed, xla::XlaOp minval, xla::XlaOp maxval) { xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc index 29ebf46e4bf..1243e31a047 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table.cc +++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc @@ -83,6 +83,8 @@ CreateResourceOpInfoMap() { add("ResourceScatterUpdate" , kReadWrite, kVariable); add("ResourceStridedSliceAssign" , kReadWrite, kVariable); add("StatefulStandardNormalV2" , kReadWrite, kVariable); + add("StatefulTruncatedNormal" , kReadWrite, kVariable); + add("StatefulUniform" , kReadWrite, kVariable); add("StatefulUniformFullInt" , kReadWrite, kVariable); add("StatefulUniformInt" , kReadWrite, kVariable); add("VarIsInitializedOp" , kRead, kVariable); diff --git a/tensorflow/python/kernel_tests/random/util.py b/tensorflow/python/kernel_tests/random/util.py index 84e3df4278c..d8ece405cf5 100644 --- a/tensorflow/python/kernel_tests/random/util.py +++ b/tensorflow/python/kernel_tests/random/util.py @@ -22,6 +22,9 @@ import math import numpy as np +from tensorflow.python.framework import dtypes +from tensorflow.python.ops.distributions import special_math + def test_moment_matching( samples, @@ -95,3 +98,47 @@ def anderson_darling(x): z = np.sum((2 * i - 1) * np.log(normal_cdf(x)) + (2 * (n - i) + 1) * np.log(1 - normal_cdf(x))) return -n - z / n + + +def test_truncated_normal(assert_equal, assert_all_close, dtype, n, y): + """Tests truncated normal distribution's statistics.""" + def _normal_cdf(x): + return .5 * math.erfc(-x / math.sqrt(2)) + + def normal_pdf(x): + return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi) + + def probit(x): + return special_math.ndtri(x) + + a = -2. + b = 2. + mu = 0. + sigma = 1. + + alpha = (a - mu) / sigma + beta = (b - mu) / sigma + z = _normal_cdf(beta) - _normal_cdf(alpha) + + assert_equal((y >= a).sum(), n) + assert_equal((y <= b).sum(), n) + + # For more information on these calculations, see: + # Burkardt, John. "The Truncated Normal Distribution". + # Department of Scientific Computing website. Florida State University. + expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma + y = y.astype(float) + actual_mean = np.mean(y) + assert_all_close(actual_mean, expected_mean, atol=5e-4) + + expected_median = mu + probit( + (_normal_cdf(alpha) + _normal_cdf(beta)) / 2.) * sigma + actual_median = np.median(y) + assert_all_close(actual_median, expected_median, atol=8e-4) + + expected_variance = sigma**2 * (1 + ( + (alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - ( + (normal_pdf(alpha) - normal_pdf(beta)) / z)**2) + actual_variance = np.var(y) + assert_all_close(actual_variance, expected_variance, + rtol=5e-3 if dtype == dtypes.bfloat16 else 1e-3)