diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 145b9495ff4..b2c641f8ab3 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -204,6 +204,24 @@ cuda_py_test(
     ],
 )
 
+cuda_py_test(
+    name = "half_normal_test",
+    size = "medium",
+    srcs = ["python/kernel_tests/half_normal_test.py"],
+    additional_deps = [
+        ":distributions_py",
+        "//third_party/py/numpy",
+        "//tensorflow/python:client",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:gradients",
+        "//tensorflow/python:nn_ops",
+        "//tensorflow/python:platform_test",
+        "//tensorflow/python:variables",
+    ],
+)
+
 cuda_py_test(
     name = "inverse_gamma_test",
     srcs = ["python/kernel_tests/inverse_gamma_test.py"],
diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py
index 0d12d838932..66827179e9f 100644
--- a/tensorflow/contrib/distributions/__init__.py
+++ b/tensorflow/contrib/distributions/__init__.py
@@ -36,6 +36,7 @@ from tensorflow.contrib.distributions.python.ops.distribution_util import softpl
 from tensorflow.contrib.distributions.python.ops.distribution_util import tridiag
 from tensorflow.contrib.distributions.python.ops.estimator import *
 from tensorflow.contrib.distributions.python.ops.geometric import *
+from tensorflow.contrib.distributions.python.ops.half_normal import *
 from tensorflow.contrib.distributions.python.ops.independent import *
 from tensorflow.contrib.distributions.python.ops.inverse_gamma import *
 from tensorflow.contrib.distributions.python.ops.logistic import *
@@ -107,6 +108,7 @@ _allowed_symbols = [
     'Gamma',
     'GammaWithSoftplusConcentrationRate',
     'Geometric',
+    'HalfNormal',
     'Independent',
     'InverseGamma',
     'InverseGammaWithSoftplusConcentrationRate',
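Taken together, the two hunks above wire the new distribution into the public contrib surface. A minimal usage sketch of the exported symbol (not part of the patch; assumes a TF 1.x build with this change applied):

```python
# Illustrative only: exercises the HalfNormal symbol that the
# __init__.py change above re-exports.
import tensorflow as tf

dist = tf.contrib.distributions.HalfNormal(scale=3.0)
samples = dist.sample(5)                  # shape: (5,)
log_prob = dist.log_prob([0.5, 1.0, 2.0])

with tf.Session() as sess:
  print(sess.run([samples, log_prob]))
```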
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/half_normal_test.py b/tensorflow/contrib/distributions/python/kernel_tests/half_normal_test.py
new file mode 100644
index 00000000000..a7571806f29
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/half_normal_test.py
@@ -0,0 +1,320 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the HalfNormal distribution."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import importlib
+import numpy as np
+
+from tensorflow.contrib.distributions.python.ops import half_normal as hn_lib
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+
+def try_import(name):  # pylint: disable=invalid-name
+  module = None
+  try:
+    module = importlib.import_module(name)
+  except ImportError as e:
+    tf_logging.warning("Could not import %s: %s" % (name, str(e)))
+  return module
+
+stats = try_import("scipy.stats")
+
+
+class HalfNormalTest(test.TestCase):
+
+  def setUp(self):
+    self._rng = np.random.RandomState(123)
+
+  def assertAllFinite(self, tensor):
+    is_finite = np.isfinite(tensor.eval())
+    all_true = np.ones_like(is_finite, dtype=np.bool)
+    self.assertAllEqual(all_true, is_finite)
+
+  def _testParamShapes(self, sample_shape, expected):
+    with self.test_session():
+      param_shapes = hn_lib.HalfNormal.param_shapes(sample_shape)
+      scale_shape = param_shapes["scale"]
+      self.assertAllEqual(expected, scale_shape.eval())
+      scale = array_ops.ones(scale_shape)
+      self.assertAllEqual(
+          expected,
+          array_ops.shape(hn_lib.HalfNormal(scale).sample()).eval())
+
+  def _testParamStaticShapes(self, sample_shape, expected):
+    param_shapes = hn_lib.HalfNormal.param_static_shapes(sample_shape)
+    scale_shape = param_shapes["scale"]
+    self.assertEqual(expected, scale_shape)
+
+  def _testBatchShapes(self, dist, tensor):
+    self.assertAllEqual(dist.batch_shape_tensor().eval(), tensor.shape)
+    self.assertAllEqual(dist.batch_shape_tensor().eval(), tensor.eval().shape)
+    self.assertAllEqual(dist.batch_shape, tensor.shape)
+    self.assertAllEqual(dist.batch_shape, tensor.eval().shape)
+
+  def testParamShapes(self):
+    sample_shape = [10, 3, 4]
+    self._testParamShapes(sample_shape, sample_shape)
+    self._testParamShapes(constant_op.constant(sample_shape), sample_shape)
+
+  def testParamStaticShapes(self):
+    sample_shape = [10, 3, 4]
+    self._testParamStaticShapes(sample_shape, sample_shape)
+    self._testParamStaticShapes(
+        tensor_shape.TensorShape(sample_shape), sample_shape)
+
+  def testHalfNormalLogPDF(self):
+    with self.test_session():
+      batch_size = 6
+      scale = constant_op.constant([3.0] * batch_size)
+      x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32)
+      halfnorm = hn_lib.HalfNormal(scale=scale)
+
+      log_pdf = halfnorm.log_prob(x)
+      self._testBatchShapes(halfnorm, log_pdf)
+
+      pdf = halfnorm.prob(x)
+      self._testBatchShapes(halfnorm, pdf)
+
+      if not stats:
+        return
+      expected_log_pdf = stats.halfnorm(scale=scale.eval()).logpdf(x)
+      self.assertAllClose(expected_log_pdf, log_pdf.eval())
+      self.assertAllClose(np.exp(expected_log_pdf), pdf.eval())
+
+  def testHalfNormalLogPDFMultidimensional(self):
+    with self.test_session():
+      batch_size = 6
+      scale = constant_op.constant([[3.0, 1.0]] * batch_size)
+      x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T
+      halfnorm = hn_lib.HalfNormal(scale=scale)
+
+      log_pdf = halfnorm.log_prob(x)
+      self._testBatchShapes(halfnorm, log_pdf)
+
+      pdf = halfnorm.prob(x)
+      self._testBatchShapes(halfnorm, pdf)
+
+      if not stats:
+        return
+      expected_log_pdf = stats.halfnorm(scale=scale.eval()).logpdf(x)
+      self.assertAllClose(expected_log_pdf, log_pdf.eval())
+      self.assertAllClose(np.exp(expected_log_pdf), pdf.eval())
+
+  def testHalfNormalCDF(self):
+    with self.test_session():
+      batch_size = 50
+      scale = self._rng.rand(batch_size) + 1.0
+      x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)
+      halfnorm = hn_lib.HalfNormal(scale=scale)
+
+      cdf = halfnorm.cdf(x)
+      self._testBatchShapes(halfnorm, cdf)
+
+      log_cdf = halfnorm.log_cdf(x)
+      self._testBatchShapes(halfnorm, log_cdf)
+
+      if not stats:
+        return
+      expected_logcdf = stats.halfnorm(scale=scale).logcdf(x)
+      self.assertAllClose(expected_logcdf, log_cdf.eval(), atol=0)
+      self.assertAllClose(np.exp(expected_logcdf), cdf.eval(), atol=0)
+
+  def testHalfNormalSurvivalFunction(self):
+    with self.test_session():
+      batch_size = 50
+      scale = self._rng.rand(batch_size) + 1.0
+      x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)
+      halfnorm = hn_lib.HalfNormal(scale=scale)
+
+      sf = halfnorm.survival_function(x)
+      self._testBatchShapes(halfnorm, sf)
+
+      log_sf = halfnorm.log_survival_function(x)
+      self._testBatchShapes(halfnorm, log_sf)
+
+      if not stats:
+        return
+      expected_logsf = stats.halfnorm(scale=scale).logsf(x)
+      self.assertAllClose(expected_logsf, log_sf.eval(), atol=0)
+      self.assertAllClose(np.exp(expected_logsf), sf.eval(), atol=0)
+
+  def testHalfNormalQuantile(self):
+    with self.test_session():
+      batch_size = 50
+      scale = self._rng.rand(batch_size) + 1.0
+      p = np.linspace(0., 1.0, batch_size).astype(np.float64)
+
+      halfnorm = hn_lib.HalfNormal(scale=scale)
+      x = halfnorm.quantile(p)
+      self._testBatchShapes(halfnorm, x)
+
+      if not stats:
+        return
+      expected_x = stats.halfnorm(scale=scale).ppf(p)
+      self.assertAllClose(expected_x, x.eval(), atol=0)
+
+  def testFiniteGradients(self):
+    for dtype in [np.float32, np.float64]:
+      g = ops.Graph()
+      with g.as_default():
+        scale = variables.Variable(dtype(3.0))
+        dist = hn_lib.HalfNormal(scale=scale)
+        x = np.array([0.01, 0.1, 1., 5., 10.]).astype(dtype)
+        for func in [
+            dist.cdf, dist.log_cdf, dist.survival_function,
+            dist.log_prob, dist.prob, dist.log_survival_function,
+        ]:
+          print(func.__name__)
+          value = func(x)
+          grads = gradients_impl.gradients(value, [scale])
+          with self.test_session(graph=g):
+            variables.global_variables_initializer().run()
+            self.assertAllFinite(value)
+            self.assertAllFinite(grads[0])
+
+  def testHalfNormalEntropy(self):
+    with self.test_session():
+      scale = np.array([[1.0, 2.0, 3.0]])
+      halfnorm = hn_lib.HalfNormal(scale=scale)
+
+      # See https://en.wikipedia.org/wiki/Half-normal_distribution for the
+      # entropy formula used here.
+      expected_entropy = 0.5 * np.log(np.pi * scale ** 2.0 / 2.0) + 0.5
+
+      entropy = halfnorm.entropy()
+      self._testBatchShapes(halfnorm, entropy)
+      self.assertAllClose(expected_entropy, entropy.eval())
+
+  def testHalfNormalMeanAndMode(self):
+    with self.test_session():
+      scale = np.array([11., 12., 13.])
+
+      halfnorm = hn_lib.HalfNormal(scale=scale)
+      expected_mean = scale * np.sqrt(2.0) / np.sqrt(np.pi)
+
+      self.assertAllEqual((3,), halfnorm.mean().eval().shape)
+      self.assertAllEqual(expected_mean, halfnorm.mean().eval())
+
+      self.assertAllEqual((3,), halfnorm.mode().eval().shape)
+      self.assertAllEqual([0., 0., 0.], halfnorm.mode().eval())
+
+  def testHalfNormalVariance(self):
+    with self.test_session():
+      scale = np.array([7., 7., 7.])
+      halfnorm = hn_lib.HalfNormal(scale=scale)
+      expected_variance = scale ** 2.0 * (1.0 - 2.0 / np.pi)
+
+      self.assertAllEqual((3,), halfnorm.variance().eval().shape)
+      self.assertAllEqual(expected_variance, halfnorm.variance().eval())
+
+  def testHalfNormalStandardDeviation(self):
+    with self.test_session():
+      scale = np.array([7., 7., 7.])
+      halfnorm = hn_lib.HalfNormal(scale=scale)
+      expected_variance = scale ** 2.0 * (1.0 - 2.0 / np.pi)
+
+      self.assertAllEqual((3,), halfnorm.stddev().shape)
+      self.assertAllEqual(np.sqrt(expected_variance), halfnorm.stddev().eval())
+
+  def testHalfNormalSample(self):
+    with self.test_session():
+      scale = constant_op.constant(3.0)
+      n = constant_op.constant(100000)
+      halfnorm = hn_lib.HalfNormal(scale=scale)
+
+      sample = halfnorm.sample(n)
+
+      self.assertEqual(sample.eval().shape, (100000,))
+      self.assertAllClose(sample.eval().mean(),
+                          3.0 * np.sqrt(2.0) / np.sqrt(np.pi), atol=1e-1)
+
+      expected_shape = tensor_shape.TensorShape([n.eval()]).concatenate(
+          tensor_shape.TensorShape(halfnorm.batch_shape_tensor().eval()))
+      self.assertAllEqual(expected_shape, sample.shape)
+      self.assertAllEqual(expected_shape, sample.eval().shape)
+
+      expected_shape_static = (tensor_shape.TensorShape(
+          [n.eval()]).concatenate(halfnorm.batch_shape))
+      self.assertAllEqual(expected_shape_static, sample.shape)
+      self.assertAllEqual(expected_shape_static, sample.eval().shape)
+
+  def testHalfNormalSampleMultiDimensional(self):
+    with self.test_session():
+      batch_size = 2
+      scale = constant_op.constant([[2.0, 3.0]] * batch_size)
+      n = constant_op.constant(100000)
+      halfnorm = hn_lib.HalfNormal(scale=scale)
+
+      sample = halfnorm.sample(n)
+      self.assertEqual(sample.shape, (100000, batch_size, 2))
+      self.assertAllClose(sample.eval()[:, 0, 0].mean(),
+                          2.0 * np.sqrt(2.0) / np.sqrt(np.pi), atol=1e-1)
+      self.assertAllClose(sample.eval()[:, 0, 1].mean(),
+                          3.0 * np.sqrt(2.0) / np.sqrt(np.pi), atol=1e-1)
+
+      expected_shape = tensor_shape.TensorShape([n.eval()]).concatenate(
+          tensor_shape.TensorShape(halfnorm.batch_shape_tensor().eval()))
+      self.assertAllEqual(expected_shape, sample.shape)
+      self.assertAllEqual(expected_shape, sample.eval().shape)
+
+      expected_shape_static = (tensor_shape.TensorShape(
+          [n.eval()]).concatenate(halfnorm.batch_shape))
+      self.assertAllEqual(expected_shape_static, sample.shape)
+      self.assertAllEqual(expected_shape_static, sample.eval().shape)
+
+  def testNegativeSigmaFails(self):
+    with self.test_session():
+      halfnorm = hn_lib.HalfNormal(scale=[-5.], validate_args=True, name="G")
+      with self.assertRaisesOpError("Condition x > 0 did not hold"):
+        halfnorm.mean().eval()
+
+  def testHalfNormalShape(self):
+    with self.test_session():
+      scale = constant_op.constant([6.0] * 5)
+      halfnorm = hn_lib.HalfNormal(scale=scale)
+
+      self.assertEqual(halfnorm.batch_shape_tensor().eval(), [5])
+      self.assertEqual(halfnorm.batch_shape, tensor_shape.TensorShape([5]))
+      self.assertAllEqual(halfnorm.event_shape_tensor().eval(), [])
+      self.assertEqual(halfnorm.event_shape, tensor_shape.TensorShape([]))
+
+  def testHalfNormalShapeWithPlaceholders(self):
+    scale = array_ops.placeholder(dtype=dtypes.float32)
+    halfnorm = hn_lib.HalfNormal(scale=scale)
+
+    with self.test_session() as sess:
+      # batch_shape should be an unknown TensorShape here.
+      self.assertEqual(halfnorm.batch_shape, tensor_shape.TensorShape(None))
+      self.assertEqual(halfnorm.event_shape, ())
+      self.assertAllEqual(halfnorm.event_shape_tensor().eval(), [])
+      self.assertAllEqual(
+          sess.run(halfnorm.batch_shape_tensor(),
+                   feed_dict={scale: [1.0, 2.0]}), [2])
+
+
+if __name__ == "__main__":
+  test.main()
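The moment and sampling tests above all hinge on the identity `Y = |X|` for `X ~ Normal(0, scale)`. A standalone NumPy sketch of the moments those tests assert (illustrative only; the constants are the closed forms checked in `testHalfNormalMeanAndMode` and `testHalfNormalVariance`):

```python
# Empirically check the half-normal moments asserted by the tests:
# mean = scale * sqrt(2/pi), variance = scale**2 * (1 - 2/pi).
import numpy as np

rng = np.random.RandomState(123)
scale = 3.0
y = np.abs(rng.normal(loc=0.0, scale=scale, size=100000))

print(y.mean(), scale * np.sqrt(2.0 / np.pi))     # ~2.39 vs 2.3937
print(y.var(), scale ** 2 * (1.0 - 2.0 / np.pi))  # ~3.27 vs 3.2704
```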
diff --git a/tensorflow/contrib/distributions/python/ops/half_normal.py b/tensorflow/contrib/distributions/python/ops/half_normal.py
new file mode 100644
index 00000000000..12059b6a9e1
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/half_normal.py
@@ -0,0 +1,170 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The Half Normal distribution class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops.distributions import distribution
+from tensorflow.python.ops.distributions import special_math
+
+
+__all__ = [
+    "HalfNormal",
+]
+
+
+class HalfNormal(distribution.Distribution):
+  """The Half Normal distribution with scale `scale`.
+
+  #### Mathematical details
+
+  The half normal is a transformation of a centered normal distribution.
+  If a random variable `X` has a normal distribution,
+  ```none
+  X ~ Normal(0.0, scale)
+  Y = |X|
+  ```
+  then `Y` has a half normal distribution. The probability density
+  function (pdf) is:
+
+  ```none
+  pdf(x; scale, x > 0) = sqrt(2) / (scale * sqrt(pi)) *
+                         exp(- 1/2 * (x / scale) ** 2)
+  ```
+  where `scale = sigma` is the standard deviation of the underlying normal
+  distribution.
+
+  #### Examples
+
+  Examples of initialization of one or a batch of distributions.
+
+  ```python
+  # Define a single scalar HalfNormal distribution.
+  dist = tf.contrib.distributions.HalfNormal(scale=3.0)
+
+  # Evaluate the cdf at 1, returning a scalar.
+  dist.cdf(1.)
+
+  # Define a batch of two scalar valued HalfNormals.
+  # The first has scale 11.0, the second 22.0.
+  dist = tf.contrib.distributions.HalfNormal(scale=[11.0, 22.0])
+
+  # Evaluate the pdf of the first distribution at 1.0, and the second at 1.5,
+  # returning a length-two tensor.
+  dist.prob([1.0, 1.5])
+
+  # Get 3 samples, returning a 3 x 2 tensor.
+  dist.sample([3])
+  ```
+
+  """
+
+  def __init__(self,
+               scale,
+               validate_args=False,
+               allow_nan_stats=True,
+               name="HalfNormal"):
+    """Construct HalfNormals with scale `scale`.
+
+    Args:
+      scale: Floating point tensor; the scales of the distribution(s).
+        Must contain only positive values.
+      validate_args: Python `bool`, default `False`. When `True` distribution
+        parameters are checked for validity despite possibly degrading runtime
+        performance. When `False` invalid inputs may silently render incorrect
+        outputs.
+      allow_nan_stats: Python `bool`, default `True`. When `True`,
+        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
+        indicate the result is undefined. When `False`, an exception is raised
+        if one or more of the statistic's batch members are undefined.
+      name: Python `str` name prefixed to Ops created by this class.
+    """
+    parameters = locals()
+    with ops.name_scope(name, values=[scale]):
+      with ops.control_dependencies([check_ops.assert_positive(scale)] if
+                                    validate_args else []):
+        self._scale = array_ops.identity(scale, name="scale")
+    super(HalfNormal, self).__init__(
+        dtype=self._scale.dtype,
+        reparameterization_type=distribution.FULLY_REPARAMETERIZED,
+        validate_args=validate_args,
+        allow_nan_stats=allow_nan_stats,
+        parameters=parameters,
+        graph_parents=[self._scale],
+        name=name)
+
+  @staticmethod
+  def _param_shapes(sample_shape):
+    return {"scale": ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)}
+
+  @property
+  def scale(self):
+    """Distribution parameter for the scale."""
+    return self._scale
+
+  def _batch_shape_tensor(self):
+    return array_ops.shape(self.scale)
+
+  def _batch_shape(self):
+    return self.scale.shape
+
+  def _event_shape_tensor(self):
+    return constant_op.constant([], dtype=dtypes.int32)
+
+  def _event_shape(self):
+    return tensor_shape.scalar()
+
+  def _sample_n(self, n, seed=None):
+    shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)
+    sampled = random_ops.random_normal(
+        shape=shape, mean=0., stddev=1., dtype=self.dtype, seed=seed)
+    return math_ops.abs(sampled * self.scale)
+
+  def _prob(self, x):
+    coeff = np.sqrt(2) / self.scale / np.sqrt(np.pi)
+    pdf = coeff * math_ops.exp(- 0.5 * (x / self.scale) ** 2)
+    return pdf * math_ops.cast(x >= 0, self.dtype)
+
+  def _cdf(self, x):
+    truncated_x = nn.relu(x)
+    return math_ops.erf(truncated_x / self.scale / np.sqrt(2.0))
+
+  def _entropy(self):
+    return 0.5 * math_ops.log(np.pi * self.scale ** 2.0 / 2.0) + 0.5
+
+  def _mean(self):
+    return self.scale * np.sqrt(2.0) / np.sqrt(np.pi)
+
+  def _quantile(self, p):
+    return np.sqrt(2.0) * self.scale * special_math.erfinv(p)
+
+  def _mode(self):
+    return array_ops.zeros(self.batch_shape_tensor())
+
+  def _variance(self):
+    return self.scale ** 2.0 * (1.0 - 2.0 / np.pi)
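The closed forms in `_prob`, `_cdf`, and `_entropy` can be sanity-checked against `scipy.stats.halfnorm`, the same reference the kernel tests use. A quick illustrative sketch, assuming SciPy is available:

```python
# Verify pdf(x) = sqrt(2/pi)/scale * exp(-x**2 / (2*scale**2)),
# cdf(x) = erf(x / (scale*sqrt(2))), and the entropy closed form
# against scipy's reference half-normal distribution.
import numpy as np
from scipy import special, stats

scale = 2.5
x = np.linspace(0.0, 10.0, 50)

pdf = np.sqrt(2.0) / (scale * np.sqrt(np.pi)) * np.exp(-0.5 * (x / scale) ** 2)
cdf = special.erf(x / (scale * np.sqrt(2.0)))

np.testing.assert_allclose(pdf, stats.halfnorm(scale=scale).pdf(x))
np.testing.assert_allclose(cdf, stats.halfnorm(scale=scale).cdf(x))
np.testing.assert_allclose(0.5 * np.log(np.pi * scale ** 2 / 2.0) + 0.5,
                           stats.halfnorm(scale=scale).entropy())
```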
diff --git a/tensorflow/python/kernel_tests/distributions/special_math_test.py b/tensorflow/python/kernel_tests/distributions/special_math_test.py
index 9441cdbe39e..2d434a39c29 100644
--- a/tensorflow/python/kernel_tests/distributions/special_math_test.py
+++ b/tensorflow/python/kernel_tests/distributions/special_math_test.py
@@ -332,6 +332,32 @@ class LogNdtrGradientTest(NdtrGradientTest):
   _use_log = True
 
 
+class ErfInvTest(test.TestCase):
+
+  def testErfInvValues(self):
+    with self.test_session():
+      if not special:
+        return
+
+      x = np.linspace(0., 1.0, 50).astype(np.float64)
+
+      expected_x = special.erfinv(x)
+      x = special_math.erfinv(x)
+      self.assertAllClose(expected_x, x.eval(), atol=0.)
+
+  def testErfInvIntegerInput(self):
+    with self.test_session():
+
+      with self.assertRaises(TypeError):
+        x = np.array([1, 2, 3]).astype(np.int32)
+        special_math.erfinv(x)
+
+      with self.assertRaises(TypeError):
+        x = np.array([1, 2, 3]).astype(np.int64)
+        special_math.erfinv(x)
+
+
+
 class LogCDFLaplaceTest(test.TestCase):
   # Note that scipy.stats.laplace does not have a stable Log CDF, so we cannot
   # rely on scipy to cross check the extreme values.
diff --git a/tensorflow/python/ops/distributions/special_math.py b/tensorflow/python/ops/distributions/special_math.py
index 222a39ad828..bed4cbb2c1a 100644
--- a/tensorflow/python/ops/distributions/special_math.py
+++ b/tensorflow/python/ops/distributions/special_math.py
@@ -27,6 +27,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 
 __all__ = [
+    "erfinv",
     "ndtr",
     "ndtri",
     "log_ndtr",
@@ -350,6 +351,29 @@ def _log_ndtr_asymptotic_series(x, series_order):
   return 1. + even_sum - odd_sum
 
 
+def erfinv(x, name="erfinv"):
+  """The inverse function for erf, the error function.
+
+  Args:
+    x: `Tensor` of type `float32`, `float64`.
+    name: Python string. A name for the operation (default="erfinv").
+
+  Returns:
+    x: `Tensor` with `dtype=x.dtype`.
+
+  Raises:
+    TypeError: if `x` is not floating-type.
+  """
+
+  with ops.name_scope(name, values=[x]):
+    x = ops.convert_to_tensor(x, name="x")
+    if x.dtype.as_numpy_dtype not in [np.float32, np.float64]:
+      raise TypeError(
+          "x.dtype=%s is not handled, see docstring for supported types."
+          % x.dtype)
+    return ndtri((x + 1.0) / 2.0) / np.sqrt(2)
+
+
 def _double_factorial(n):
   """The double factorial function for small Python integer `n`."""
   return np.prod(np.arange(n, 1, -2))
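The `erfinv` implementation relies on the identity `erfinv(x) = ndtri((x + 1) / 2) / sqrt(2)`, where `ndtri` is the inverse of the standard normal CDF. A short NumPy/SciPy sketch of why that holds (the SciPy functions are real; the check itself is only illustrative):

```python
# The normal CDF and erf are related by Phi(z) = (1 + erf(z / sqrt(2))) / 2.
# Inverting both sides gives erfinv(x) = ndtri((x + 1) / 2) / sqrt(2),
# which is exactly the formula used in special_math.erfinv above.
import numpy as np
from scipy import special

x = np.linspace(-0.99, 0.99, 9)
lhs = special.erfinv(x)
rhs = special.ndtri((x + 1.0) / 2.0) / np.sqrt(2.0)
np.testing.assert_allclose(lhs, rhs)
```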