Half Normal Distribution (and inverse error function) (#14056)
* foldednormal docstring * folded __init__ method * prob, log_prob methods * rewrote halfnormal docstring * initial implementation of dist methods * halfnormal unit tests * registered HalfNormal to contrib.distributions * added erfinv function * unit tests for erfinv * registered erfinv symbol * cdf, pdf now deal with x < 0 correctly * pylint fixes * cuda_py test reference in BUILD * erfinv fixes * corrections to scipy reference tests * Added reference to entropy test case.
This commit is contained in:
parent
3bf2f35c71
commit
ec4d31e82c
tensorflow
contrib/distributions
python
@ -204,6 +204,24 @@ cuda_py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cuda_py_test(
|
||||||
|
name = "half_normal_test",
|
||||||
|
size = "medium",
|
||||||
|
srcs = ["python/kernel_tests/half_normal_test.py"],
|
||||||
|
additional_deps = [
|
||||||
|
":distributions_py",
|
||||||
|
"//third_party/py/numpy",
|
||||||
|
"//tensorflow/python:client",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
|
"//tensorflow/python:framework_test_lib",
|
||||||
|
"//tensorflow/python:gradients",
|
||||||
|
"//tensorflow/python:nn_ops",
|
||||||
|
"//tensorflow/python:platform_test",
|
||||||
|
"//tensorflow/python:variables",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
name = "inverse_gamma_test",
|
name = "inverse_gamma_test",
|
||||||
srcs = ["python/kernel_tests/inverse_gamma_test.py"],
|
srcs = ["python/kernel_tests/inverse_gamma_test.py"],
|
||||||
|
@ -36,6 +36,7 @@ from tensorflow.contrib.distributions.python.ops.distribution_util import softpl
|
|||||||
from tensorflow.contrib.distributions.python.ops.distribution_util import tridiag
|
from tensorflow.contrib.distributions.python.ops.distribution_util import tridiag
|
||||||
from tensorflow.contrib.distributions.python.ops.estimator import *
|
from tensorflow.contrib.distributions.python.ops.estimator import *
|
||||||
from tensorflow.contrib.distributions.python.ops.geometric import *
|
from tensorflow.contrib.distributions.python.ops.geometric import *
|
||||||
|
from tensorflow.contrib.distributions.python.ops.half_normal import *
|
||||||
from tensorflow.contrib.distributions.python.ops.independent import *
|
from tensorflow.contrib.distributions.python.ops.independent import *
|
||||||
from tensorflow.contrib.distributions.python.ops.inverse_gamma import *
|
from tensorflow.contrib.distributions.python.ops.inverse_gamma import *
|
||||||
from tensorflow.contrib.distributions.python.ops.logistic import *
|
from tensorflow.contrib.distributions.python.ops.logistic import *
|
||||||
@ -107,6 +108,7 @@ _allowed_symbols = [
|
|||||||
'Gamma',
|
'Gamma',
|
||||||
'GammaWithSoftplusConcentrationRate',
|
'GammaWithSoftplusConcentrationRate',
|
||||||
'Geometric',
|
'Geometric',
|
||||||
|
'HalfNormal',
|
||||||
'Independent',
|
'Independent',
|
||||||
'InverseGamma',
|
'InverseGamma',
|
||||||
'InverseGammaWithSoftplusConcentrationRate',
|
'InverseGammaWithSoftplusConcentrationRate',
|
||||||
|
@ -0,0 +1,320 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for initializers."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import gradients_impl
|
||||||
|
from tensorflow.python.ops import variables
|
||||||
|
from tensorflow.contrib.distributions.python.ops import half_normal as hn_lib
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
from tensorflow.python.platform import tf_logging
|
||||||
|
|
||||||
|
|
||||||
|
def try_import(name): # pylint: disable=invalid-name
|
||||||
|
module = None
|
||||||
|
try:
|
||||||
|
module = importlib.import_module(name)
|
||||||
|
except ImportError as e:
|
||||||
|
tf_logging.warning("Could not import %s: %s" % (name, str(e)))
|
||||||
|
return module
|
||||||
|
|
||||||
|
stats = try_import("scipy.stats")
|
||||||
|
|
||||||
|
|
||||||
|
class HalfNormalTest(test.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self._rng = np.random.RandomState(123)
|
||||||
|
|
||||||
|
def assertAllFinite(self, tensor):
|
||||||
|
is_finite = np.isfinite(tensor.eval())
|
||||||
|
all_true = np.ones_like(is_finite, dtype=np.bool)
|
||||||
|
self.assertAllEqual(all_true, is_finite)
|
||||||
|
|
||||||
|
def _testParamShapes(self, sample_shape, expected):
|
||||||
|
with self.test_session():
|
||||||
|
param_shapes = hn_lib.HalfNormal.param_shapes(sample_shape)
|
||||||
|
scale_shape = param_shapes["scale"]
|
||||||
|
self.assertAllEqual(expected, scale_shape.eval())
|
||||||
|
scale = array_ops.ones(scale_shape)
|
||||||
|
self.assertAllEqual(
|
||||||
|
expected,
|
||||||
|
array_ops.shape(hn_lib.HalfNormal(scale).sample()).eval())
|
||||||
|
|
||||||
|
def _testParamStaticShapes(self, sample_shape, expected):
|
||||||
|
param_shapes = hn_lib.HalfNormal.param_static_shapes(sample_shape)
|
||||||
|
scale_shape = param_shapes["scale"]
|
||||||
|
self.assertEqual(expected, scale_shape)
|
||||||
|
|
||||||
|
def _testBatchShapes(self, dist, tensor):
|
||||||
|
self.assertAllEqual(dist.batch_shape_tensor().eval(), tensor.shape)
|
||||||
|
self.assertAllEqual(dist.batch_shape_tensor().eval(), tensor.eval().shape)
|
||||||
|
self.assertAllEqual(dist.batch_shape, tensor.shape)
|
||||||
|
self.assertAllEqual(dist.batch_shape, tensor.eval().shape)
|
||||||
|
|
||||||
|
def testParamShapes(self):
|
||||||
|
sample_shape = [10, 3, 4]
|
||||||
|
self._testParamShapes(sample_shape, sample_shape)
|
||||||
|
self._testParamShapes(constant_op.constant(sample_shape), sample_shape)
|
||||||
|
|
||||||
|
def testParamStaticShapes(self):
|
||||||
|
sample_shape = [10, 3, 4]
|
||||||
|
self._testParamStaticShapes(sample_shape, sample_shape)
|
||||||
|
self._testParamStaticShapes(
|
||||||
|
tensor_shape.TensorShape(sample_shape), sample_shape)
|
||||||
|
|
||||||
|
def testHalfNormalLogPDF(self):
|
||||||
|
with self.test_session():
|
||||||
|
batch_size = 6
|
||||||
|
scale = constant_op.constant([3.0] * batch_size)
|
||||||
|
x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32)
|
||||||
|
halfnorm = hn_lib.HalfNormal(scale=scale)
|
||||||
|
|
||||||
|
log_pdf = halfnorm.log_prob(x)
|
||||||
|
self._testBatchShapes(halfnorm, log_pdf)
|
||||||
|
|
||||||
|
pdf = halfnorm.prob(x)
|
||||||
|
self._testBatchShapes(halfnorm, pdf)
|
||||||
|
|
||||||
|
if not stats:
|
||||||
|
return
|
||||||
|
expected_log_pdf = stats.halfnorm(scale=scale.eval()).logpdf(x)
|
||||||
|
self.assertAllClose(expected_log_pdf, log_pdf.eval())
|
||||||
|
self.assertAllClose(np.exp(expected_log_pdf), pdf.eval())
|
||||||
|
|
||||||
|
def testHalfNormalLogPDFMultidimensional(self):
|
||||||
|
with self.test_session():
|
||||||
|
batch_size = 6
|
||||||
|
scale = constant_op.constant([[3.0, 1.0]] * batch_size)
|
||||||
|
x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T
|
||||||
|
halfnorm = hn_lib.HalfNormal(scale=scale)
|
||||||
|
|
||||||
|
log_pdf = halfnorm.log_prob(x)
|
||||||
|
self._testBatchShapes(halfnorm, log_pdf)
|
||||||
|
|
||||||
|
pdf = halfnorm.prob(x)
|
||||||
|
self._testBatchShapes(halfnorm, pdf)
|
||||||
|
|
||||||
|
if not stats:
|
||||||
|
return
|
||||||
|
expected_log_pdf = stats.halfnorm(scale=scale.eval()).logpdf(x)
|
||||||
|
self.assertAllClose(expected_log_pdf, log_pdf.eval())
|
||||||
|
self.assertAllClose(np.exp(expected_log_pdf), pdf.eval())
|
||||||
|
|
||||||
|
def testHalfNormalCDF(self):
|
||||||
|
with self.test_session():
|
||||||
|
batch_size = 50
|
||||||
|
scale = self._rng.rand(batch_size) + 1.0
|
||||||
|
x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)
|
||||||
|
halfnorm = hn_lib.HalfNormal(scale=scale)
|
||||||
|
|
||||||
|
cdf = halfnorm.cdf(x)
|
||||||
|
self._testBatchShapes(halfnorm, cdf)
|
||||||
|
|
||||||
|
log_cdf = halfnorm.log_cdf(x)
|
||||||
|
self._testBatchShapes(halfnorm, log_cdf)
|
||||||
|
|
||||||
|
if not stats:
|
||||||
|
return
|
||||||
|
expected_logcdf = stats.halfnorm(scale=scale).logcdf(x)
|
||||||
|
self.assertAllClose(expected_logcdf, log_cdf.eval(), atol=0)
|
||||||
|
self.assertAllClose(np.exp(expected_logcdf), cdf.eval(), atol=0)
|
||||||
|
|
||||||
|
def testHalfNormalSurvivalFunction(self):
|
||||||
|
with self.test_session():
|
||||||
|
batch_size = 50
|
||||||
|
scale = self._rng.rand(batch_size) + 1.0
|
||||||
|
x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)
|
||||||
|
halfnorm = hn_lib.HalfNormal(scale=scale)
|
||||||
|
|
||||||
|
sf = halfnorm.survival_function(x)
|
||||||
|
self._testBatchShapes(halfnorm, sf)
|
||||||
|
|
||||||
|
log_sf = halfnorm.log_survival_function(x)
|
||||||
|
self._testBatchShapes(halfnorm, log_sf)
|
||||||
|
|
||||||
|
if not stats:
|
||||||
|
return
|
||||||
|
expected_logsf = stats.halfnorm(scale=scale).logsf(x)
|
||||||
|
self.assertAllClose(expected_logsf, log_sf.eval(), atol=0)
|
||||||
|
self.assertAllClose(np.exp(expected_logsf), sf.eval(), atol=0)
|
||||||
|
|
||||||
|
def testHalfNormalQuantile(self):
|
||||||
|
with self.test_session():
|
||||||
|
batch_size = 50
|
||||||
|
scale = self._rng.rand(batch_size) + 1.0
|
||||||
|
p = np.linspace(0., 1.0, batch_size).astype(np.float64)
|
||||||
|
|
||||||
|
halfnorm = hn_lib.HalfNormal(scale=scale)
|
||||||
|
x = halfnorm.quantile(p)
|
||||||
|
self._testBatchShapes(halfnorm, x)
|
||||||
|
|
||||||
|
if not stats:
|
||||||
|
return
|
||||||
|
expected_x = stats.halfnorm(scale=scale).ppf(p)
|
||||||
|
self.assertAllClose(expected_x, x.eval(), atol=0)
|
||||||
|
|
||||||
|
def testFiniteGradients(self):
|
||||||
|
for dtype in [np.float32, np.float64]:
|
||||||
|
g = ops.Graph()
|
||||||
|
with g.as_default():
|
||||||
|
scale = variables.Variable(dtype(3.0))
|
||||||
|
dist = hn_lib.HalfNormal(scale=scale)
|
||||||
|
x = np.array([0.01, 0.1, 1., 5., 10.]).astype(dtype)
|
||||||
|
for func in [
|
||||||
|
dist.cdf, dist.log_cdf, dist.survival_function,
|
||||||
|
dist.log_prob, dist.prob, dist.log_survival_function,
|
||||||
|
]:
|
||||||
|
print(func.__name__)
|
||||||
|
value = func(x)
|
||||||
|
grads = gradients_impl.gradients(value, [scale])
|
||||||
|
with self.test_session(graph=g):
|
||||||
|
variables.global_variables_initializer().run()
|
||||||
|
self.assertAllFinite(value)
|
||||||
|
self.assertAllFinite(grads[0])
|
||||||
|
|
||||||
|
def testHalfNormalEntropy(self):
|
||||||
|
with self.test_session():
|
||||||
|
scale = np.array([[1.0, 2.0, 3.0]])
|
||||||
|
halfnorm = hn_lib.HalfNormal(scale=scale)
|
||||||
|
|
||||||
|
# See https://en.wikipedia.org/wiki/Half-normal_distribution for the
|
||||||
|
# entropy formula used here.
|
||||||
|
expected_entropy = 0.5 * np.log(np.pi * scale ** 2.0 / 2.0) + 0.5
|
||||||
|
|
||||||
|
entropy = halfnorm.entropy()
|
||||||
|
self._testBatchShapes(halfnorm, entropy)
|
||||||
|
self.assertAllClose(expected_entropy, entropy.eval())
|
||||||
|
|
||||||
|
def testHalfNormalMeanAndMode(self):
|
||||||
|
with self.test_session():
|
||||||
|
scale = np.array([11., 12., 13.])
|
||||||
|
|
||||||
|
halfnorm = hn_lib.HalfNormal(scale=scale)
|
||||||
|
expected_mean = scale * np.sqrt(2.0) / np.sqrt(np.pi)
|
||||||
|
|
||||||
|
self.assertAllEqual((3,), halfnorm.mean().eval().shape)
|
||||||
|
self.assertAllEqual(expected_mean, halfnorm.mean().eval())
|
||||||
|
|
||||||
|
self.assertAllEqual((3,), halfnorm.mode().eval().shape)
|
||||||
|
self.assertAllEqual([0., 0., 0.], halfnorm.mode().eval())
|
||||||
|
|
||||||
|
def testHalfNormalVariance(self):
|
||||||
|
with self.test_session():
|
||||||
|
scale = np.array([7., 7., 7.])
|
||||||
|
halfnorm = hn_lib.HalfNormal(scale=scale)
|
||||||
|
expected_variance = scale ** 2.0 * (1.0 - 2.0 / np.pi)
|
||||||
|
|
||||||
|
self.assertAllEqual((3,), halfnorm.variance().eval().shape)
|
||||||
|
self.assertAllEqual(expected_variance, halfnorm.variance().eval())
|
||||||
|
|
||||||
|
def testHalfNormalStandardDeviation(self):
|
||||||
|
with self.test_session():
|
||||||
|
scale = np.array([7., 7., 7.])
|
||||||
|
halfnorm = hn_lib.HalfNormal(scale=scale)
|
||||||
|
expected_variance = scale ** 2.0 * (1.0 - 2.0 / np.pi)
|
||||||
|
|
||||||
|
self.assertAllEqual((3,), halfnorm.stddev().shape)
|
||||||
|
self.assertAllEqual(np.sqrt(expected_variance), halfnorm.stddev().eval())
|
||||||
|
|
||||||
|
def testHalfNormalSample(self):
|
||||||
|
with self.test_session():
|
||||||
|
scale = constant_op.constant(3.0)
|
||||||
|
n = constant_op.constant(100000)
|
||||||
|
halfnorm = hn_lib.HalfNormal(scale=scale)
|
||||||
|
|
||||||
|
sample = halfnorm.sample(n)
|
||||||
|
|
||||||
|
self.assertEqual(sample.eval().shape, (100000,))
|
||||||
|
self.assertAllClose(sample.eval().mean(),
|
||||||
|
3.0 * np.sqrt(2.0) / np.sqrt(np.pi), atol=1e-1)
|
||||||
|
|
||||||
|
expected_shape = tensor_shape.TensorShape([n.eval()]).concatenate(
|
||||||
|
tensor_shape.TensorShape(halfnorm.batch_shape_tensor().eval()))
|
||||||
|
self.assertAllEqual(expected_shape, sample.shape)
|
||||||
|
self.assertAllEqual(expected_shape, sample.eval().shape)
|
||||||
|
|
||||||
|
expected_shape_static = (tensor_shape.TensorShape(
|
||||||
|
[n.eval()]).concatenate(halfnorm.batch_shape))
|
||||||
|
self.assertAllEqual(expected_shape_static, sample.shape)
|
||||||
|
self.assertAllEqual(expected_shape_static, sample.eval().shape)
|
||||||
|
|
||||||
|
def testHalfNormalSampleMultiDimensional(self):
|
||||||
|
with self.test_session():
|
||||||
|
batch_size = 2
|
||||||
|
scale = constant_op.constant([[2.0, 3.0]] * batch_size)
|
||||||
|
n = constant_op.constant(100000)
|
||||||
|
halfnorm = hn_lib.HalfNormal(scale=scale)
|
||||||
|
|
||||||
|
sample = halfnorm.sample(n)
|
||||||
|
self.assertEqual(sample.shape, (100000, batch_size, 2))
|
||||||
|
self.assertAllClose(sample.eval()[:, 0, 0].mean(),
|
||||||
|
2.0 * np.sqrt(2.0) / np.sqrt(np.pi), atol=1e-1)
|
||||||
|
self.assertAllClose(sample.eval()[:, 0, 1].mean(),
|
||||||
|
3.0 * np.sqrt(2.0) / np.sqrt(np.pi), atol=1e-1)
|
||||||
|
|
||||||
|
expected_shape = tensor_shape.TensorShape([n.eval()]).concatenate(
|
||||||
|
tensor_shape.TensorShape(halfnorm.batch_shape_tensor().eval()))
|
||||||
|
self.assertAllEqual(expected_shape, sample.shape)
|
||||||
|
self.assertAllEqual(expected_shape, sample.eval().shape)
|
||||||
|
|
||||||
|
expected_shape_static = (tensor_shape.TensorShape(
|
||||||
|
[n.eval()]).concatenate(halfnorm.batch_shape))
|
||||||
|
self.assertAllEqual(expected_shape_static, sample.shape)
|
||||||
|
self.assertAllEqual(expected_shape_static, sample.eval().shape)
|
||||||
|
|
||||||
|
def testNegativeSigmaFails(self):
|
||||||
|
with self.test_session():
|
||||||
|
halfnorm = hn_lib.HalfNormal(scale=[-5.], validate_args=True, name="G")
|
||||||
|
with self.assertRaisesOpError("Condition x > 0 did not hold"):
|
||||||
|
halfnorm.mean().eval()
|
||||||
|
|
||||||
|
def testHalfNormalShape(self):
|
||||||
|
with self.test_session():
|
||||||
|
scale = constant_op.constant([6.0] * 5)
|
||||||
|
halfnorm = hn_lib.HalfNormal(scale=scale)
|
||||||
|
|
||||||
|
self.assertEqual(halfnorm.batch_shape_tensor().eval(), [5])
|
||||||
|
self.assertEqual(halfnorm.batch_shape, tensor_shape.TensorShape([5]))
|
||||||
|
self.assertAllEqual(halfnorm.event_shape_tensor().eval(), [])
|
||||||
|
self.assertEqual(halfnorm.event_shape, tensor_shape.TensorShape([]))
|
||||||
|
|
||||||
|
def testHalfNormalShapeWithPlaceholders(self):
|
||||||
|
scale = array_ops.placeholder(dtype=dtypes.float32)
|
||||||
|
halfnorm = hn_lib.HalfNormal(scale=scale)
|
||||||
|
|
||||||
|
with self.test_session() as sess:
|
||||||
|
# get_batch_shape should return an "<unknown>" tensor.
|
||||||
|
self.assertEqual(halfnorm.batch_shape, tensor_shape.TensorShape(None))
|
||||||
|
self.assertEqual(halfnorm.event_shape, ())
|
||||||
|
self.assertAllEqual(halfnorm.event_shape_tensor().eval(), [])
|
||||||
|
self.assertAllEqual(
|
||||||
|
sess.run(halfnorm.batch_shape_tensor(),
|
||||||
|
feed_dict={scale: [1.0, 2.0]}), [2])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test.main()
|
170
tensorflow/contrib/distributions/python/ops/half_normal.py
Normal file
170
tensorflow/contrib/distributions/python/ops/half_normal.py
Normal file
@ -0,0 +1,170 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""The Half Normal distribution class."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import check_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops import nn
|
||||||
|
from tensorflow.python.ops import random_ops
|
||||||
|
from tensorflow.python.ops.distributions import distribution
|
||||||
|
from tensorflow.python.ops.distributions import special_math
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"HalfNormal",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class HalfNormal(distribution.Distribution):
|
||||||
|
"""The Half Normal distribution with scale `scale`.
|
||||||
|
|
||||||
|
#### Mathematical details
|
||||||
|
|
||||||
|
The half normal is a transformation of a centered normal distribution.
|
||||||
|
If some random variable `X` has normal distribution,
|
||||||
|
```none
|
||||||
|
X ~ Normal(0.0, scale)
|
||||||
|
Y = |X|
|
||||||
|
```
|
||||||
|
Then `Y` will have half normal distribution. The probability density
|
||||||
|
function (pdf) is:
|
||||||
|
|
||||||
|
```none
|
||||||
|
pdf(x; scale, x > 0) = sqrt(2) / (scale * sqrt(pi)) *
|
||||||
|
exp(- 1/2 * (x / scale) ** 2)
|
||||||
|
)
|
||||||
|
```
|
||||||
|
Where `scale = sigma` is the standard deviation of the underlying normal
|
||||||
|
distribution.
|
||||||
|
|
||||||
|
#### Examples
|
||||||
|
|
||||||
|
Examples of initialization of one or a batch of distributions.
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Define a single scalar HalfNormal distribution.
|
||||||
|
dist = tf.contrib.distributions.HalfNormal(scale=3.0)
|
||||||
|
|
||||||
|
# Evaluate the cdf at 1, returning a scalar.
|
||||||
|
dist.cdf(1.)
|
||||||
|
|
||||||
|
# Define a batch of two scalar valued HalfNormals.
|
||||||
|
# The first has scale 11.0, the second 22.0
|
||||||
|
dist = tf.contrib.distributions.HalfNormal(scale=[11.0, 22.0])
|
||||||
|
|
||||||
|
# Evaluate the pdf of the first distribution on 1.0, and the second on 1.5,
|
||||||
|
# returning a length two tensor.
|
||||||
|
dist.prob([1.0, 1.5])
|
||||||
|
|
||||||
|
# Get 3 samples, returning a 3 x 2 tensor.
|
||||||
|
dist.sample([3])
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
scale,
|
||||||
|
validate_args=False,
|
||||||
|
allow_nan_stats=True,
|
||||||
|
name="HalfNormal"):
|
||||||
|
"""Construct HalfNormals with scale `scale`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scale: Floating point tensor; the scales of the distribution(s).
|
||||||
|
Must contain only positive values.
|
||||||
|
validate_args: Python `bool`, default `False`. When `True` distribution
|
||||||
|
parameters are checked for validity despite possibly degrading runtime
|
||||||
|
performance. When `False` invalid inputs may silently render incorrect
|
||||||
|
outputs.
|
||||||
|
allow_nan_stats: Python `bool`, default `True`. When `True`,
|
||||||
|
statistics (e.g., mean, mode, variance) use the value "`NaN`" to
|
||||||
|
indicate the result is undefined. When `False`, an exception is raised
|
||||||
|
if one or more of the statistic's batch members are undefined.
|
||||||
|
name: Python `str` name prefixed to Ops created by this class.
|
||||||
|
"""
|
||||||
|
parameters = locals()
|
||||||
|
with ops.name_scope(name, values=[scale]):
|
||||||
|
with ops.control_dependencies([check_ops.assert_positive(scale)] if
|
||||||
|
validate_args else []):
|
||||||
|
self._scale = array_ops.identity(scale, name="scale")
|
||||||
|
super(HalfNormal, self).__init__(
|
||||||
|
dtype=self._scale.dtype,
|
||||||
|
reparameterization_type=distribution.FULLY_REPARAMETERIZED,
|
||||||
|
validate_args=validate_args,
|
||||||
|
allow_nan_stats=allow_nan_stats,
|
||||||
|
parameters=parameters,
|
||||||
|
graph_parents=[self._scale],
|
||||||
|
name=name)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _param_shapes(sample_shape):
|
||||||
|
return {'scale': ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def scale(self):
|
||||||
|
"""Distribution parameter for the scale."""
|
||||||
|
return self._scale
|
||||||
|
|
||||||
|
def _batch_shape_tensor(self):
|
||||||
|
return array_ops.shape(self.scale)
|
||||||
|
|
||||||
|
def _batch_shape(self):
|
||||||
|
return self.scale.shape
|
||||||
|
|
||||||
|
def _event_shape_tensor(self):
|
||||||
|
return constant_op.constant([], dtype=dtypes.int32)
|
||||||
|
|
||||||
|
def _event_shape(self):
|
||||||
|
return tensor_shape.scalar()
|
||||||
|
|
||||||
|
def _sample_n(self, n, seed=None):
|
||||||
|
shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)
|
||||||
|
sampled = random_ops.random_normal(
|
||||||
|
shape=shape, mean=0., stddev=1., dtype=self.dtype, seed=seed)
|
||||||
|
return math_ops.abs(sampled * self.scale)
|
||||||
|
|
||||||
|
def _prob(self, x):
|
||||||
|
coeff = np.sqrt(2) / self.scale / np.sqrt(np.pi)
|
||||||
|
pdf = coeff * math_ops.exp(- 0.5 * (x / self.scale) ** 2)
|
||||||
|
return pdf * math_ops.cast(x >= 0, self.dtype)
|
||||||
|
|
||||||
|
def _cdf(self, x):
|
||||||
|
truncated_x = nn.relu(x)
|
||||||
|
return math_ops.erf(truncated_x / self.scale / np.sqrt(2.0))
|
||||||
|
|
||||||
|
def _entropy(self):
|
||||||
|
return 0.5 * math_ops.log(np.pi * self.scale ** 2.0 / 2.0) + 0.5
|
||||||
|
|
||||||
|
def _mean(self):
|
||||||
|
return self.scale * np.sqrt(2.0) / np.sqrt(np.pi)
|
||||||
|
|
||||||
|
def _quantile(self, p):
|
||||||
|
return np.sqrt(2.0) * self.scale * special_math.erfinv(p)
|
||||||
|
|
||||||
|
def _mode(self):
|
||||||
|
return array_ops.zeros(self.batch_shape_tensor())
|
||||||
|
|
||||||
|
def _variance(self):
|
||||||
|
return self.scale ** 2.0 * (1.0 - 2.0 / np.pi)
|
@ -332,6 +332,32 @@ class LogNdtrGradientTest(NdtrGradientTest):
|
|||||||
_use_log = True
|
_use_log = True
|
||||||
|
|
||||||
|
|
||||||
|
class ErfInvTest(test.TestCase):
|
||||||
|
|
||||||
|
def testErfInvValues(self):
|
||||||
|
with self.test_session():
|
||||||
|
if not special:
|
||||||
|
return
|
||||||
|
|
||||||
|
x = np.linspace(0., 1.0, 50).astype(np.float64)
|
||||||
|
|
||||||
|
expected_x = special.erfinv(x)
|
||||||
|
x = special_math.erfinv(x)
|
||||||
|
self.assertAllClose(expected_x, x.eval(), atol=0.)
|
||||||
|
|
||||||
|
def testErfInvIntegerInput(self):
|
||||||
|
with self.test_session():
|
||||||
|
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
x = np.array([1, 2, 3]).astype(np.int32)
|
||||||
|
special_math.erfinv(x)
|
||||||
|
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
x = np.array([1, 2, 3]).astype(np.int64)
|
||||||
|
special_math.erfinv(x)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class LogCDFLaplaceTest(test.TestCase):
|
class LogCDFLaplaceTest(test.TestCase):
|
||||||
# Note that scipy.stats.laplace does not have a stable Log CDF, so we cannot
|
# Note that scipy.stats.laplace does not have a stable Log CDF, so we cannot
|
||||||
# rely on scipy to cross check the extreme values.
|
# rely on scipy to cross check the extreme values.
|
||||||
|
@ -27,6 +27,7 @@ from tensorflow.python.ops import array_ops
|
|||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"erfinv",
|
||||||
"ndtr",
|
"ndtr",
|
||||||
"ndtri",
|
"ndtri",
|
||||||
"log_ndtr",
|
"log_ndtr",
|
||||||
@ -350,6 +351,29 @@ def _log_ndtr_asymptotic_series(x, series_order):
|
|||||||
return 1. + even_sum - odd_sum
|
return 1. + even_sum - odd_sum
|
||||||
|
|
||||||
|
|
||||||
|
def erfinv(x, name="erfinv"):
|
||||||
|
"""The inverse function for erf, the error function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: `Tensor` of type `float32`, `float64`.
|
||||||
|
name: Python string. A name for the operation (default="erfinv").
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
x: `Tensor` with `dtype=x.dtype`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: if `x` is not floating-type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
with ops.name_scope(name, values=[x]):
|
||||||
|
x = ops.convert_to_tensor(x, name="x")
|
||||||
|
if x.dtype.as_numpy_dtype not in [np.float32, np.float64]:
|
||||||
|
raise TypeError(
|
||||||
|
"x.dtype=%s is not handled, see docstring for supported types."
|
||||||
|
% x.dtype)
|
||||||
|
return ndtri((x + 1.0) / 2.0) / np.sqrt(2)
|
||||||
|
|
||||||
|
|
||||||
def _double_factorial(n):
|
def _double_factorial(n):
|
||||||
"""The double factorial function for small Python integer `n`."""
|
"""The double factorial function for small Python integer `n`."""
|
||||||
return np.prod(np.arange(n, 1, -2))
|
return np.prod(np.arange(n, 1, -2))
|
||||||
|
Loading…
Reference in New Issue
Block a user