Added Binomial and Multinomial distributions.
- Refactored some common asserts into a distribution_util library.
- Changed some documentation for distributions (in particular providing more
  helpful error messages, properly escaping values in comments, etc.).

Change: 129280447
parent bdad5cdcbe
commit ed4300da87
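A quick, illustrative sketch of the user-facing API this commit introduces, pieced together from the docstrings and tests in the diff below (contrib-era, session-based TensorFlow; not part of the commit itself):

```python
# Illustrative only -- inferred from the docstrings/tests in this diff,
# not part of the commit. Uses the session-based contrib API of the era.
import tensorflow as tf

with tf.Session():
  # Binomial: number of successes in n trials, each with success prob p.
  binom = tf.contrib.distributions.Binomial(n=5., p=.5)
  print(binom.pmf(2.).eval())  # P[2 heads in 5 fair flips] = 0.3125

  # Multinomial: counts over k classes after n draws with class probs p.
  multinom = tf.contrib.distributions.Multinomial(n=5., p=[.1, .2, .7])
  print(multinom.pmf([1., 1., 3.]).eval())
```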
tensorflow/contrib/distributions
@@ -99,7 +99,16 @@ cuda_py_tests(
     srcs = ["python/kernel_tests/beta_test.py"],
     additional_deps = [
         ":distributions_py",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:platform_test",
     ],
 )
+
+cuda_py_tests(
+    name = "binomial_test",
+    size = "small",
+    srcs = ["python/kernel_tests/binomial_test.py"],
+    additional_deps = [
+        ":distributions_py",
+        "//tensorflow/python:platform_test",
+    ],
+    tags = ["notsan"],
@@ -179,9 +188,8 @@ cuda_py_tests(
 )
 
 cuda_py_tests(
-    name = "kullback_leibler_test",
-    size = "small",
-    srcs = ["python/kernel_tests/kullback_leibler_test.py"],
+    name = "laplace_test",
+    srcs = ["python/kernel_tests/laplace_test.py"],
     additional_deps = [
         ":distributions_py",
         "//tensorflow/python:framework_test_lib",
@@ -190,13 +198,14 @@ cuda_py_tests(
 )
 
 cuda_py_tests(
-    name = "laplace_test",
-    srcs = ["python/kernel_tests/laplace_test.py"],
+    name = "multinomial_test",
+    srcs = ["python/kernel_tests/multinomial_test.py"],
     additional_deps = [
         ":distributions_py",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:platform_test",
     ],
+    tags = ["notsan"],
 )
 
 cuda_py_tests(
@@ -239,6 +248,15 @@ cuda_py_tests(
     srcs = ["python/kernel_tests/uniform_test.py"],
     additional_deps = [
         ":distributions_py",
         "//tensorflow/python:framework_test_lib",
     ],
 )
+
+cuda_py_tests(
+    name = "kullback_leibler_test",
+    size = "small",
+    srcs = ["python/kernel_tests/kullback_leibler_test.py"],
+    additional_deps = [
+        "//tensorflow/python:platform_test",
+    ],
+)
tensorflow/contrib/distributions/__init__.py
@@ -25,6 +25,7 @@ initialized with parameters that define the distributions.
 
 ### Univariate (scalar) distributions
 
+@@Binomial
 @@Bernoulli
 @@Beta
 @@Categorical
@@ -50,6 +51,7 @@ initialized with parameters that define the distributions.
 
 @@Dirichlet
 @@DirichletMultinomial
+@@Multinomial
 
 ### Transformed distributions
 
@@ -79,6 +81,7 @@ from __future__ import print_function
 
 from tensorflow.contrib.distributions.python.ops.bernoulli import *
 from tensorflow.contrib.distributions.python.ops.beta import *
+from tensorflow.contrib.distributions.python.ops.binomial import *
 from tensorflow.contrib.distributions.python.ops.categorical import *
 from tensorflow.contrib.distributions.python.ops.chi2 import *
 from tensorflow.contrib.distributions.python.ops.dirichlet import *
@@ -89,6 +92,7 @@ from tensorflow.contrib.distributions.python.ops.gamma import *
 from tensorflow.contrib.distributions.python.ops.inverse_gamma import *
 from tensorflow.contrib.distributions.python.ops.kullback_leibler import *
 from tensorflow.contrib.distributions.python.ops.laplace import *
+from tensorflow.contrib.distributions.python.ops.multinomial import *
 from tensorflow.contrib.distributions.python.ops.mvn import *
 from tensorflow.contrib.distributions.python.ops.normal import *
 from tensorflow.contrib.distributions.python.ops.normal_conjugate_posteriors import *
tensorflow/contrib/distributions/python/kernel_tests/bernoulli_test.py
@@ -57,10 +57,17 @@ class BernoulliTest(tf.test.TestCase):
       self.assertAllClose(scipy.special.logit(p), dist.logits.eval())
 
   def testInvalidP(self):
-    invalid_ps = [1.01, -0.01, 2., -3.]
+    invalid_ps = [1.01, 2.]
     for p in invalid_ps:
       with self.test_session():
-        with self.assertRaisesOpError("x <= y"):
+        with self.assertRaisesOpError("p has components greater than 1"):
           dist = tf.contrib.distributions.Bernoulli(p=p)
           dist.p.eval()
+
+    invalid_ps = [-0.01, -3.]
+    for p in invalid_ps:
+      with self.test_session():
+        with self.assertRaisesOpError("Condition x >= 0"):
+          dist = tf.contrib.distributions.Bernoulli(p=p)
+          dist.p.eval()
 
tensorflow/contrib/distributions/python/kernel_tests/binomial_test.py (new file)
@@ -0,0 +1,173 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from scipy import stats
import tensorflow as tf


class BinomialTest(tf.test.TestCase):

  def testSimpleShapes(self):
    with self.test_session():
      p = np.float32(np.random.beta(1, 1))
      binom = tf.contrib.distributions.Binomial(n=1., p=p)
      self.assertAllEqual([], binom.event_shape().eval())
      self.assertAllEqual([], binom.batch_shape().eval())
      self.assertEqual(tf.TensorShape([]), binom.get_event_shape())
      self.assertEqual(tf.TensorShape([]), binom.get_batch_shape())

  def testComplexShapes(self):
    with self.test_session():
      p = np.random.beta(1, 1, size=(3, 2)).astype(np.float32)
      n = [[3., 2], [4, 5], [6, 7]]
      binom = tf.contrib.distributions.Binomial(n=n, p=p)
      self.assertAllEqual([], binom.event_shape().eval())
      self.assertAllEqual([3, 2], binom.batch_shape().eval())
      self.assertEqual(tf.TensorShape([]), binom.get_event_shape())
      self.assertEqual(tf.TensorShape([3, 2]), binom.get_batch_shape())

  def testNProperty(self):
    p = [[0.1, 0.2, 0.7], [0.2, 0.3, 0.5]]
    n = [[3.], [4]]
    with self.test_session():
      binom = tf.contrib.distributions.Binomial(n=n, p=p)
      self.assertEqual((2, 1), binom.n.get_shape())
      self.assertAllClose(n, binom.n.eval())

  def testPProperty(self):
    p = [[0.1, 0.2, 0.7]]
    with self.test_session():
      binom = tf.contrib.distributions.Binomial(n=3., p=p)
      self.assertEqual((1, 3), binom.p.get_shape())
      self.assertEqual((1, 3), binom.logits.get_shape())
      self.assertAllClose(p, binom.p.eval())

  def testLogitsProperty(self):
    logits = [[0., 9., -0.5]]
    with self.test_session():
      binom = tf.contrib.distributions.Binomial(n=3., logits=logits)
      self.assertEqual((1, 3), binom.p.get_shape())
      self.assertEqual((1, 3), binom.logits.get_shape())
      self.assertAllClose(logits, binom.logits.eval())

  def testPmfNandCountsAgree(self):
    p = [[0.1, 0.2, 0.7]]
    n = [[5.]]
    with self.test_session():
      binom = tf.contrib.distributions.Binomial(n=n, p=p)
      binom.pmf([2., 3, 2]).eval()
      binom.pmf([3., 1, 2]).eval()
      with self.assertRaisesOpError('Condition x >= 0.*'):
        binom.pmf([-1., 4, 2]).eval()
      with self.assertRaisesOpError('Condition x <= y.*'):
        binom.pmf([7., 3, 0]).eval()

  def testPmf_non_integer_counts(self):
    p = [[0.1, 0.2, 0.7]]
    n = [[5.]]
    with self.test_session():
      # No errors with integer n.
      binom = tf.contrib.distributions.Binomial(n=n, p=p)
      binom.pmf([2., 3, 2]).eval()
      binom.pmf([3., 1, 2]).eval()
      # Both equality and integer checking fail.
      with self.assertRaisesOpError('Condition x == y.*'):
        binom.pmf([1.0, 2.5, 1.5]).eval()

      binom = tf.contrib.distributions.Binomial(n=n, p=p, validate_args=False)
      binom.pmf([1., 2., 3.]).eval()
      # Non-integer arguments work.
      binom.pmf([1.0, 2.5, 1.5]).eval()

  def testPmfBothZeroBatches(self):
    with self.test_session():
      # Both zero-batches. No broadcast
      p = 0.5
      counts = 1.
      pmf = tf.contrib.distributions.Binomial(n=1., p=p).pmf(counts)
      self.assertAllClose(0.5, pmf.eval())
      self.assertEqual((), pmf.get_shape())

  def testPmfBothZeroBatchesNontrivialN(self):
    with self.test_session():
      # Both zero-batches. No broadcast
      p = 0.1
      counts = 3.
      binom = tf.contrib.distributions.Binomial(n=5., p=p)
      pmf = binom.pmf(counts)
      self.assertAllClose(stats.binom.pmf(counts, n=5., p=p), pmf.eval())
      self.assertEqual((), pmf.get_shape())

  def testPmfPStretchedInBroadcastWhenSameRank(self):
    with self.test_session():
      p = [[0.1, 0.9]]
      counts = [[1., 2.]]
      pmf = tf.contrib.distributions.Binomial(n=3., p=p).pmf(counts)
      self.assertAllClose(stats.binom.pmf(counts, n=3., p=p), pmf.eval())
      self.assertEqual((1, 2), pmf.get_shape())

  def testPmfPStretchedInBroadcastWhenLowerRank(self):
    with self.test_session():
      p = [0.1, 0.4]
      counts = [[1.], [0.]]
      pmf = tf.contrib.distributions.Binomial(n=1., p=p).pmf(counts)
      self.assertAllClose([[0.1, 0.4], [0.9, 0.6]], pmf.eval())
      self.assertEqual((2, 2), pmf.get_shape())

  def testBinomialMean(self):
    with self.test_session():
      n = 5.
      p = [0.1, 0.2, 0.7]
      binom = tf.contrib.distributions.Binomial(n=n, p=p)
      expected_means = stats.binom.mean(n, p)
      self.assertEqual((3,), binom.mean().get_shape())
      self.assertAllClose(expected_means, binom.mean().eval())

  def testBinomialVariance(self):
    with self.test_session():
      n = 5.
      p = [0.1, 0.2, 0.7]
      binom = tf.contrib.distributions.Binomial(n=n, p=p)
      expected_variances = stats.binom.var(n, p)
      self.assertEqual((3,), binom.variance().get_shape())
      self.assertAllClose(expected_variances, binom.variance().eval())

  def testBinomialMode(self):
    with self.test_session():
      n = 5.
      p = [0.1, 0.2, 0.7]
      binom = tf.contrib.distributions.Binomial(n=n, p=p)
      expected_modes = [0., 1, 4]
      self.assertEqual((3,), binom.mode().get_shape())
      self.assertAllClose(expected_modes, binom.mode().eval())

  def testBinomialMultipleMode(self):
    with self.test_session():
      n = 9.
      p = [0.1, 0.2, 0.7]
      binom = tf.contrib.distributions.Binomial(n=n, p=p)
      # For the case where (n + 1) * p is an integer, the modes are:
      # (n + 1) * p and (n + 1) * p - 1. In this case, we get back
      # the larger of the two modes.
      expected_modes = [1., 2, 7]
      self.assertEqual((3,), binom.mode().get_shape())
      self.assertAllClose(expected_modes, binom.mode().eval())


if __name__ == '__main__':
  tf.test.main()
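The expected modes in the two mode tests above follow the formula the new class uses, floor((n + 1) * p), which returns the larger of the two modes when (n + 1) * p is an integer. A quick NumPy check (illustrative, not part of the commit):

```python
# Illustrative check of the mode values used in testBinomialMode and
# testBinomialMultipleMode: mode = floor((n + 1) * p).
import numpy as np

p = np.array([0.1, 0.2, 0.7])
print(np.floor((5. + 1) * p))  # [0. 1. 4.]  -- matches testBinomialMode
print(np.floor((9. + 1) * p))  # [1. 2. 7.]  -- (n+1)*p integral: larger mode
```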
tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py
@@ -65,7 +65,7 @@ class DirichletMultinomialTest(tf.test.TestCase):
       dist.pmf([3., 0, 2]).eval()
       with self.assertRaisesOpError('Condition x >= 0.*'):
         dist.pmf([-1., 4, 2]).eval()
-      with self.assertRaisesOpError('Condition x == y.*'):
+      with self.assertRaisesOpError('counts do not sum to n'):
         dist.pmf([3., 3, 0]).eval()
 
   def testPmf_non_integer_counts(self):
tensorflow/contrib/distributions/python/kernel_tests/multinomial_test.py (new file)
@@ -0,0 +1,226 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf


class MultinomialTest(tf.test.TestCase):

  def testSimpleShapes(self):
    with self.test_session():
      p = [.1, .3, .6]
      dist = tf.contrib.distributions.Multinomial(n=1., p=p)
      self.assertEqual(3, dist.event_shape().eval())
      self.assertAllEqual([], dist.batch_shape().eval())
      self.assertEqual(tf.TensorShape([3]), dist.get_event_shape())
      self.assertEqual(tf.TensorShape([]), dist.get_batch_shape())

  def testComplexShapes(self):
    with self.test_session():
      p = 0.5 * np.ones([3, 2, 2], dtype=np.float32)
      n = [[3., 2], [4, 5], [6, 7]]
      dist = tf.contrib.distributions.Multinomial(n=n, p=p)
      self.assertEqual(2, dist.event_shape().eval())
      self.assertAllEqual([3, 2], dist.batch_shape().eval())
      self.assertEqual(tf.TensorShape([2]), dist.get_event_shape())
      self.assertEqual(tf.TensorShape([3, 2]), dist.get_batch_shape())

  def testNProperty(self):
    p = [[0.1, 0.2, 0.7], [0.2, 0.3, 0.5]]
    n = [[3.], [4]]
    with self.test_session():
      dist = tf.contrib.distributions.Multinomial(n=n, p=p)
      self.assertEqual((2, 1), dist.n.get_shape())
      self.assertAllClose(n, dist.n.eval())

  def testPProperty(self):
    p = [[0.1, 0.2, 0.7]]
    with self.test_session():
      dist = tf.contrib.distributions.Multinomial(n=3., p=p)
      self.assertEqual((1, 3), dist.p.get_shape())
      self.assertEqual((1, 3), dist.logits.get_shape())
      self.assertAllClose(p, dist.p.eval())

  def testLogitsProperty(self):
    logits = [[0., 9., -0.5]]
    with self.test_session():
      multinom = tf.contrib.distributions.Multinomial(n=3., logits=logits)
      self.assertEqual((1, 3), multinom.p.get_shape())
      self.assertEqual((1, 3), multinom.logits.get_shape())
      self.assertAllClose(logits, multinom.logits.eval())

  def testPmfNandCountsAgree(self):
    p = [[0.1, 0.2, 0.7]]
    n = [[5.]]
    with self.test_session():
      dist = tf.contrib.distributions.Multinomial(n=n, p=p)
      dist.pmf([2., 3, 0]).eval()
      dist.pmf([3., 0, 2]).eval()
      with self.assertRaisesOpError('Condition x >= 0.*'):
        dist.pmf([-1., 4, 2]).eval()
      with self.assertRaisesOpError('counts do not sum to n'):
        dist.pmf([3., 3, 0]).eval()

  def testPmf_non_integer_counts(self):
    p = [[0.1, 0.2, 0.7]]
    n = [[5.]]
    with self.test_session():
      # No errors with integer n.
      multinom = tf.contrib.distributions.Multinomial(n=n, p=p)
      multinom.pmf([2., 1, 2]).eval()
      multinom.pmf([3., 0, 2]).eval()
      # Counts don't sum to n.
      with self.assertRaisesOpError('counts do not sum to n'):
        multinom.pmf([2., 3, 2]).eval()
      # Counts are non-integers.
      with self.assertRaisesOpError('Condition x == y.*'):
        multinom.pmf([1.0, 2.5, 1.5]).eval()

      multinom = tf.contrib.distributions.Multinomial(
          n=n, p=p, validate_args=False)
      multinom.pmf([1., 2., 2.]).eval()
      # Non-integer arguments work.
      multinom.pmf([1.0, 2.5, 1.5]).eval()

  def testPmfBothZeroBatches(self):
    with self.test_session():
      # Both zero-batches. No broadcast
      p = [0.5, 0.5]
      counts = [1., 0]
      pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts)
      self.assertAllClose(0.5, pmf.eval())
      self.assertEqual((), pmf.get_shape())

  def testPmfBothZeroBatchesNontrivialN(self):
    with self.test_session():
      # Both zero-batches. No broadcast
      p = [0.1, 0.9]
      counts = [3., 2]
      dist = tf.contrib.distributions.Multinomial(n=5., p=p)
      pmf = dist.pmf(counts)
      # 5 choose 3 = 5 choose 2 = 10. 10 * (.9)^2 * (.1)^3 = 81/10000.
      self.assertAllClose(81./10000, pmf.eval())
      self.assertEqual((), pmf.get_shape())

  def testPmfPStretchedInBroadcastWhenSameRank(self):
    with self.test_session():
      p = [[0.1, 0.9]]
      counts = [[1., 0], [0, 1]]
      pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts)
      self.assertAllClose([0.1, 0.9], pmf.eval())
      self.assertEqual((2), pmf.get_shape())

  def testPmfPStretchedInBroadcastWhenLowerRank(self):
    with self.test_session():
      p = [0.1, 0.9]
      counts = [[1., 0], [0, 1]]
      pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts)
      self.assertAllClose([0.1, 0.9], pmf.eval())
      self.assertEqual((2), pmf.get_shape())

  def testPmfCountsStretchedInBroadcastWhenSameRank(self):
    with self.test_session():
      p = [[0.1, 0.9], [0.7, 0.3]]
      counts = [[1., 0]]
      pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts)
      self.assertAllClose(pmf.eval(), [0.1, 0.7])
      self.assertEqual((2), pmf.get_shape())

  def testPmfCountsStretchedInBroadcastWhenLowerRank(self):
    with self.test_session():
      p = [[0.1, 0.9], [0.7, 0.3]]
      counts = [1., 0]
      pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts)
      self.assertAllClose(pmf.eval(), [0.1, 0.7])
      self.assertEqual(pmf.get_shape(), (2))

  def testPmfShapeCountsStretched_N(self):
    with self.test_session():
      # [2, 2, 2]
      p = [[[0.1, 0.9], [0.1, 0.9]], [[0.7, 0.3], [0.7, 0.3]]]
      # [2, 2]
      n = [[3., 3], [3, 3]]
      # [2]
      counts = [2., 1]
      pmf = tf.contrib.distributions.Multinomial(n=n, p=p).pmf(counts)
      pmf.eval()
      self.assertEqual(pmf.get_shape(), (2, 2))

  def testPmfShapeCountsPStretched_N(self):
    with self.test_session():
      p = [0.1, 0.9]
      counts = [3., 2]
      n = np.full([4, 3], 5., dtype=np.float32)
      pmf = tf.contrib.distributions.Multinomial(n=n, p=p).pmf(counts)
      pmf.eval()
      self.assertEqual((4, 3), pmf.get_shape())

  def testMultinomialMean(self):
    with self.test_session():
      n = 5.
      p = [0.1, 0.2, 0.7]
      dist = tf.contrib.distributions.Multinomial(n=n, p=p)
      expected_means = 5 * np.array(p, dtype=np.float32)
      self.assertEqual((3,), dist.mean().get_shape())
      self.assertAllClose(expected_means, dist.mean().eval())

  def testMultinomialVariance(self):
    with self.test_session():
      n = 5.
      p = [0.1, 0.2, 0.7]
      dist = tf.contrib.distributions.Multinomial(n=n, p=p)
      expected_variances = [
          [9./20, -1/10, -7/20], [-1/10, 4/5, -7/10], [-7/20, -7/10, 21/20]]
      self.assertEqual((3, 3), dist.variance().get_shape())
      self.assertAllClose(expected_variances, dist.variance().eval())

  def testMultinomialVariance_batch(self):
    with self.test_session():
      # Shape [2]
      n = [5.] * 2
      # Shape [4, 1, 2]
      p = [[[0.1, 0.9]], [[0.1, 0.9]]] * 2
      dist = tf.contrib.distributions.Multinomial(n=n, p=p)
      # Shape [2, 2]
      inner_var = [[9./20, -9/20], [-9/20, 9/20]]
      # Shape [4, 2, 2, 2]
      expected_variances = [[inner_var, inner_var]] * 4
      self.assertEqual((4, 2, 2, 2), dist.variance().get_shape())
      self.assertAllClose(expected_variances, dist.variance().eval())

  def testVariance_multidimensional(self):
    # Shape [3, 5, 4]
    p = np.random.dirichlet([.25, .25, .25, .25], [3, 5]).astype(np.float32)
    # Shape [6, 3, 3]
    p2 = np.random.dirichlet([.3, .3, .4], [6, 3]).astype(np.float32)

    ns = np.random.randint(low=1, high=11, size=[3, 5]).astype(np.float32)
    ns2 = np.random.randint(low=1, high=11, size=[6, 1]).astype(np.float32)

    with self.test_session():
      dist = tf.contrib.distributions.Multinomial(ns, p)
      dist2 = tf.contrib.distributions.Multinomial(ns2, p2)

      variance = dist.variance()
      variance2 = dist2.variance()
      self.assertEqual((3, 5, 4, 4), variance.get_shape())
      self.assertEqual((6, 3, 3, 3), variance2.get_shape())


if __name__ == '__main__':
  tf.test.main()
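The expected_variances entries in the variance tests above come from the standard multinomial covariance, Cov(X_i, X_j) = n * (p_i * delta_ij - p_i * p_j). A quick NumPy check of the numbers used in testMultinomialVariance (illustrative, not part of the commit):

```python
# Illustrative check: multinomial covariance Cov = n * (diag(p) - p p^T).
import numpy as np

n = 5.
p = np.array([0.1, 0.2, 0.7])
print(n * (np.diag(p) - np.outer(p, p)))
# [[ 0.45 -0.1  -0.35]
#  [-0.1   0.8  -0.7 ]
#  [-0.35 -0.7   1.05]]
# i.e. [[9/20, -1/10, -7/20], [-1/10, 4/5, -7/10], [-7/20, -7/10, 21/20]].
```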
tensorflow/contrib/distributions/python/ops/bernoulli.py
@@ -19,15 +19,13 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.contrib.distributions.python.ops import distribution
+from tensorflow.contrib.distributions.python.ops import distribution_util
 from tensorflow.contrib.distributions.python.ops import kullback_leibler  # pylint: disable=line-too-long
-from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
-from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
 from tensorflow.python.ops import random_ops
@@ -38,10 +36,6 @@ class Bernoulli(distribution.Distribution):
 
   The Bernoulli distribution is parameterized by p, the probability of a
   positive event.
-
-  Note, the following methods of the base class aren't implemented:
-    * cdf
-    * log_cdf
   """
 
   def __init__(self,
@@ -64,10 +58,10 @@ class Bernoulli(distribution.Distribution):
       dtype: dtype for samples.
       validate_args: Whether to assert that `0 <= p <= 1`. If not validate_args,
         `log_pmf` may return nans.
-      allow_nan_stats: Boolean, default False. If False, raise an exception if
-        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-        If True, batch members with valid parameters leading to undefined
-        statistics will return NaN for this statistic.
+      allow_nan_stats: Boolean, default `False`. If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member. If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
       name: A name for this distribution.
 
     Raises:
@@ -77,27 +71,8 @@ class Bernoulli(distribution.Distribution):
     self._name = name
     self._dtype = dtype
     self._validate_args = validate_args
-    check_op = check_ops.assert_less_equal
-    if p is None and logits is None:
-      raise ValueError("Must pass p or logits.")
-    elif p is not None and logits is not None:
-      raise ValueError("Must pass either p or logits, not both.")
-    elif p is None:
-      with ops.op_scope([logits], name):
-        self._logits = array_ops.identity(logits, name="logits")
-        with ops.name_scope(name):
-          with ops.name_scope("p"):
-            self._p = math_ops.sigmoid(self._logits)
-    elif logits is None:
-      with ops.name_scope(name):
-        with ops.name_scope("p"):
-          p = array_ops.identity(p)
-          one = constant_op.constant(1., p.dtype)
-          zero = constant_op.constant(0., p.dtype)
-          self._p = control_flow_ops.with_dependencies(
-              [check_op(p, one), check_op(zero, p)] if validate_args else [], p)
-      with ops.name_scope("logits"):
-        self._logits = math_ops.log(self._p) - math_ops.log(1. - self._p)
+    self._logits, self._p = distribution_util.get_logits_and_prob(
+        name=name, logits=logits, p=p, validate_args=validate_args)
     with ops.name_scope(name):
       with ops.name_scope("q"):
         self._q = 1. - self._p
@@ -184,8 +159,12 @@ class Bernoulli(distribution.Distribution):
       event = ops.convert_to_tensor(event, name="event")
       event = math_ops.cast(event, self.logits.dtype)
       logits = self.logits
-      if ((event.get_shape().ndims is not None) or
-          (logits.get_shape().ndims is not None) or
+      # sigmoid_cross_entropy_with_logits doesn't broadcast shape,
+      # so we do this here.
+      # TODO(b/30637701): Check dynamic shape, and don't broadcast if the
+      # dynamic shapes are the same.
+      if (not event.get_shape().is_fully_defined() or
+          not logits.get_shape().is_fully_defined() or
           event.get_shape() != logits.get_shape()):
         logits = array_ops.ones_like(event) * logits
         event = array_ops.ones_like(logits) * event
@@ -206,8 +185,7 @@ class Bernoulli(distribution.Distribution):
     with ops.name_scope(self.name):
       with ops.op_scope([self.p, n], name):
         n = ops.convert_to_tensor(n, name="n")
-        new_shape = array_ops.concat(
-            0, [array_ops.expand_dims(n, 0), self.batch_shape()])
+        new_shape = array_ops.concat(0, ([n], self.batch_shape()))
        uniform = random_ops.random_uniform(
            new_shape, seed=seed, dtype=dtypes.float32)
        sample = math_ops.less(uniform, self.p)
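The constructor above now delegates the p/logits bookkeeping to the new distribution_util.get_logits_and_prob helper, whose body is not shown in this view. A hypothetical reconstruction from the inline Bernoulli code removed above (the error message matches bernoulli_test; the real helper may differ):

```python
# Hypothetical sketch of distribution_util.get_logits_and_prob, pieced
# together from the inline code this commit removes -- not the actual helper.
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops


def get_logits_and_prob(name, logits=None, p=None, validate_args=True):
  if p is None and logits is None:
    raise ValueError("Must pass p or logits.")
  elif p is not None and logits is not None:
    raise ValueError("Must pass either p or logits, not both.")
  elif p is None:
    # Parameterized by logits: derive p = sigmoid(logits).
    with ops.op_scope([logits], name):
      logits = array_ops.identity(logits, name="logits")
      p = math_ops.sigmoid(logits)
  else:
    # Parameterized by p: optionally check 0 <= p <= 1, then derive logits.
    with ops.op_scope([p], name):
      p = array_ops.identity(p)
      one = constant_op.constant(1., p.dtype)
      dependencies = [
          check_ops.assert_less_equal(
              p, one, message="p has components greater than 1."),
          check_ops.assert_non_negative(p),
      ] if validate_args else []
      p = control_flow_ops.with_dependencies(dependencies, p)
      logits = math_ops.log(p) - math_ops.log(1. - p)
  return logits, p
```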
tensorflow/contrib/distributions/python/ops/beta.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """The Beta distribution class."""
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -95,6 +96,7 @@ class Beta(distribution.Distribution):
   x = [.2, .3, .9]
   dist.pdf(x)  # Shape [2]
   ```
+
   """
 
   def __init__(self, a, b, validate_args=True, allow_nan_stats=False,
@@ -102,20 +104,20 @@ class Beta(distribution.Distribution):
     """Initialize a batch of Beta distributions.
 
     Args:
-      a: Positive `float` or `double` tensor with shape broadcastable to
+      a: Positive floating point tensor with shape broadcastable to
         `[N1,..., Nm]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm`
         different Beta distributions. This also defines the
         dtype of the distribution.
-      b: Positive `float` or `double` tensor with shape broadcastable to
+      b: Positive floating point tensor with shape broadcastable to
         `[N1,..., Nm]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm`
         different Beta distributions.
       validate_args: Whether to assert valid values for parameters `a` and `b`,
-        and `x` in `prob` and `log_prob`. If False, correct behavior is not
+        and `x` in `prob` and `log_prob`. If `False`, correct behavior is not
         guaranteed.
-      allow_nan_stats: Boolean, default False. If False, raise an exception if
-        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-        If True, batch members with valid parameters leading to undefined
-        statistics will return NaN for this statistic.
+      allow_nan_stats: Boolean, default `False`. If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member. If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
       name: The name to prefix Ops created by this distribution class.
 
     Examples:
@@ -127,6 +129,7 @@ class Beta(distribution.Distribution):
     # Define a 2-batch.
     dist = Beta([1.0, 2.0], [4.0, 5.0])
     ```
+
     """
     with ops.op_scope([a, b], name):
       with ops.control_dependencies([
@@ -276,8 +279,14 @@ class Beta(distribution.Distribution):
             array_ops.ones_like(a_b_sum, dtype=self.dtype)))
       else:
         return control_flow_ops.with_dependencies([
-            check_ops.assert_less(one, a),
-            check_ops.assert_less(one, b)], mode)
+            check_ops.assert_less(
+                one, a,
+                message="mode not defined for components of a <= 1"
+            ),
+            check_ops.assert_less(
+                one, b,
+                message="mode not defined for components of b <= 1"
+            )], mode)
 
   def entropy(self, name="entropy"):
     """Entropy of the distribution in nats."""
@@ -306,7 +315,7 @@ class Beta(distribution.Distribution):
     """`Log(P[counts])`, computed for every batch member.
 
     Args:
-      x: Non-negative `float` or `double`, tensor whose shape can
+      x: Non-negative floating point tensor whose shape can
         be broadcast with `self.a` and `self.b`. For fixed leading
         dimensions, the last dimension represents counts for the corresponding
         Beta distribution in `self.a` and `self.b`. `x` is only legal if
@@ -334,7 +343,7 @@ class Beta(distribution.Distribution):
     """`P[x]`, computed for every batch member.
 
     Args:
-      x: Non-negative `float`, `double` tensor whose shape can
+      x: Non-negative floating point tensor whose shape can
         be broadcast with `self.a` and `self.b`. For fixed leading
         dimensions, the last dimension represents x for the corresponding Beta
         distribution in `self.a` and `self.b`. `x` is only legal if is
tensorflow/contrib/distributions/python/ops/binomial.py (new file, 340 lines)
@@ -0,0 +1,340 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Binomial distribution class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=line-too-long

from tensorflow.contrib.distributions.python.ops import distribution
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops

# pylint: enable=line-too-long


class Binomial(distribution.Distribution):
  """Binomial distribution.

  This distribution is parameterized by a vector `p` of probabilities and `n`,
  the total counts.

  #### Mathematical details

  The Binomial is a distribution over the number of successes in `n` independent
  trials, with each trial having the same probability of success `p`.
  The probability mass function (pmf):

  ```pmf(k) = n! / (k! * (n - k)!) * (p)^k * (1 - p)^(n - k)```

  #### Examples

  Create a single distribution, corresponding to 5 coin flips.

  ```python
  dist = Binomial(n=5., p=.5)
  ```

  Create a single distribution (using logits), corresponding to 5 coin flips.

  ```python
  dist = Binomial(n=5., logits=0.)
  ```

  Creates 3 distributions with the third distribution most likely to have
  successes.

  ```python
  p = [.2, .3, .8]
  # n will be broadcast to [4., 4., 4.], to match p.
  dist = Binomial(n=4., p=p)
  ```

  The distribution functions can be evaluated on counts.

  ```python
  # counts same shape as p.
  counts = [1., 2, 3]
  dist.prob(counts)  # Shape [3]

  # p will be broadcast to [[.2, .3, .8], [.2, .3, .8]] to match counts.
  counts = [[1., 2, 1], [2, 2, 4]]
  dist.prob(counts)  # Shape [2, 3]

  # p will be broadcast to shape [5, 7, 3] to match counts.
  counts = [[...]]  # Shape [5, 7, 3]
  dist.prob(counts)  # Shape [5, 7, 3]
  ```
  """

  def __init__(self,
               n,
               logits=None,
               p=None,
               validate_args=True,
               allow_nan_stats=False,
               name="Binomial"):
    """Initialize a batch of Binomial distributions.

    Args:
      n: Non-negative floating point tensor with shape broadcastable to
        `[N1,..., Nm]` with `m >= 0` and the same dtype as `p` or `logits`.
        Defines this as a batch of `N1 x ... x Nm` different Binomial
        distributions. Its components should be equal to integer values.
      logits: Floating point tensor representing the log-odds of a
        positive event with shape broadcastable to `[N1,..., Nm]` `m >= 0`, and
        the same dtype as `n`. Each entry represents logits for the probability
        of success for independent Binomial distributions.
      p: Positive floating point tensor with shape broadcastable to
        `[N1,..., Nm]` `m >= 0`, `p in [0, 1]`. Each entry represents the
        probability of success for independent Binomial distributions.
      validate_args: Whether to assert valid values for parameters `n` and `p`,
        and `x` in `prob` and `log_prob`. If `False`, correct behavior is not
        guaranteed.
      allow_nan_stats: Boolean, default `False`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: The name to prefix Ops created by this distribution class.

    Examples:

    ```python
    # Define 1-batch of a binomial distribution.
    dist = Binomial(n=2., p=.9)

    # Define a 2-batch.
    dist = Binomial(n=[4., 5], p=[.1, .3])
    ```
    """

    self._logits, self._p = distribution_util.get_logits_and_prob(
        name=name, logits=logits, p=p, validate_args=validate_args)

    with ops.op_scope([n], name):
      with ops.control_dependencies([
          check_ops.assert_non_negative(
              n, message="n has negative components."),
          distribution_util.assert_integer_form(
              n, message="n has non-integer components."
          )] if validate_args else []):
        self._n = array_ops.identity(n, name="convert_n")

    self._name = name
    self._validate_args = validate_args
    self._allow_nan_stats = allow_nan_stats

    self._mean = self._n * self._p
    self._get_batch_shape = self._mean.get_shape()
    self._get_event_shape = tensor_shape.TensorShape([])

  @property
  def name(self):
    """Name to prepend to all ops."""
    return self._name

  @property
  def dtype(self):
    """dtype of samples from this distribution."""
    return self._p.dtype

  @property
  def validate_args(self):
    """Boolean describing behavior on invalid input."""
    return self._validate_args

  @property
  def allow_nan_stats(self):
    """Boolean describing behavior when a stat is undefined for batch member."""
    return self._allow_nan_stats

  def batch_shape(self, name="batch_shape"):
    """Batch dimensions of this instance as a 1-D int32 `Tensor`.

    The product of the dimensions of the `batch_shape` is the number of
    independent distributions of this kind the instance represents.

    Args:
      name: name to give to the op

    Returns:
      `Tensor` `batch_shape`
    """
    return array_ops.shape(self._mean)

  def get_batch_shape(self):
    """`TensorShape` available at graph construction time.

    Same meaning as `batch_shape`. May be only partially defined.

    Returns:
      batch shape
    """
    return self._get_batch_shape

  def event_shape(self, name="event_shape"):
    """Shape of a sample from a single distribution as a 1-D int32 `Tensor`.

    Args:
      name: name to give to the op

    Returns:
      `Tensor` `event_shape`
    """
    with ops.name_scope(self.name):
      with ops.op_scope([], name):
        return constant_op.constant([], name=name, dtype=dtypes.int32)

  def get_event_shape(self):
    """`TensorShape` available at graph construction time.

    Same meaning as `event_shape`. May be only partially defined.

    Returns:
      event shape
    """
    return self._get_event_shape

  @property
  def n(self):
    """Number of trials."""
    return self._n

  @property
  def logits(self):
    """Log-odds."""
    return self._logits

  @property
  def p(self):
    """Probability of success."""
    return self._p

  def mean(self, name="mean"):
    """Mean of the distribution."""
    with ops.name_scope(self.name):
      return array_ops.identity(self._mean, name=name)

  def variance(self, name="variance"):
    """Variance of the distribution."""
    with ops.name_scope(self.name):
      with ops.op_scope([self._n, self._p], name):
        return self._n * self._p * (1 - self._p)

  def std(self, name="std"):
    """Standard deviation of the distribution."""
    with ops.name_scope(self.name):
      with ops.op_scope([self._n, self._p], name):
        return math_ops.sqrt(self.variance())

  def mode(self, name="mode"):
    """Mode of the distribution.

    Note that when `(n + 1) * p` is an integer, there are actually two modes.
    Namely, `(n + 1) * p` and `(n + 1) * p - 1` are both modes. Here we return
    only the larger of the two modes.

    Args:
      name: The name for this op.

    Returns:
      The mode of the Binomial distribution.
    """
    with ops.name_scope(self.name):
      with ops.op_scope([self._n, self._p], name):
        return math_ops.floor((self._n + 1) * self._p)

  def log_prob(self, counts, name="log_prob"):
    """`Log(P[counts])`, computed for every batch member.

    For each batch member of counts `k`, `P[counts]` is the probability that
    after sampling `n` draws from this Binomial distribution, the number of
    successes is `k`. Note that different sequences of draws can result in the
    same counts, thus the probability includes a combinatorial coefficient.

    Args:
      counts: Non-negative tensor with dtype `dtype` and whose shape can be
        broadcast with `self.p` and `self.n`. `counts` is only legal if it is
        less than or equal to `n` and its components are equal to integer
        values.
      name: Name to give this Op, defaults to "log_prob".

    Returns:
      Log probabilities for each record, shape `[N1,...,Nm]`.
    """
    n = self._n
    p = self._p
    with ops.name_scope(self.name):
      with ops.op_scope([self._n, self._p, counts], name):
        counts = self._check_counts(counts)

        prob_prob = counts * math_ops.log(p) + (
            n - counts) * math_ops.log(1 - p)

        combinations = math_ops.lgamma(n + 1) - math_ops.lgamma(
            counts + 1) - math_ops.lgamma(n - counts + 1)
        log_prob = prob_prob + combinations
        return log_prob

  def prob(self, counts, name="prob"):
    """`P[counts]`, computed for every batch member.

    For each batch member of counts `k`, `P[counts]` is the probability that
    after sampling `n` draws from this Binomial distribution, the number of
    successes is `k`. Note that different sequences of draws can result in the
    same counts, thus the probability includes a combinatorial coefficient.

    Args:
      counts: Non-negative tensor with dtype `dtype` and whose shape can be
        broadcast with `self.p` and `self.n`. `counts` is only legal if it is
        less than or equal to `n` and its components are equal to integer
        values.
      name: Name to give this Op, defaults to "prob".

    Returns:
      Probabilities for each record, shape `[N1,...,Nm]`.
    """
    return super(Binomial, self).prob(counts, name=name)

  @property
  def is_continuous(self):
    return False

  @property
  def is_reparameterized(self):
    return False

  def _check_counts(self, counts):
    """Check counts for proper shape, values, then return tensor version."""
    counts = ops.convert_to_tensor(counts, name="counts_before_deps")
    if not self.validate_args:
      return counts
    return control_flow_ops.with_dependencies([
        check_ops.assert_non_negative(
            counts, message="counts has negative components."),
        check_ops.assert_less_equal(
            counts, self._n, message="counts are not less than or equal to n."),
        distribution_util.assert_integer_form(
            counts, message="counts have non-integer components.")], counts)
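log_prob above evaluates the combinatorial coefficient via lgamma rather than explicit factorials, which stays finite for large n. An illustrative check against scipy (not part of the commit):

```python
# Illustrative check: the lgamma-based log-pmf used in Binomial.log_prob
# matches scipy.stats.binom.logpmf.
import numpy as np
from scipy import stats
from scipy.special import gammaln

n, p = 5., 0.3
k = np.array([0., 1., 2., 3., 4., 5.])
log_prob = (k * np.log(p) + (n - k) * np.log(1 - p)
            + gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1))
print(np.allclose(log_prob, stats.binom.logpmf(k, n, p)))  # True
```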
tensorflow/contrib/distributions/python/ops/categorical.py
@@ -34,11 +34,6 @@ class Categorical(distribution.Distribution):
 
   The categorical distribution is parameterized by the log-probabilities
   of a set of classes.
-
-  Note, the following methods of the base class aren't implemented:
-    * mean
-    * cdf
-    * log_cdf
   """
 
   def __init__(
@@ -57,10 +52,10 @@ class Categorical(distribution.Distribution):
         indexes into the classes.
       dtype: The type of the event samples (default: int32).
       validate_args: Unused in this distribution.
-      allow_nan_stats: Boolean, default False. If False, raise an exception if
-        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-        If True, batch members with valid parameters leading to undefined
-        statistics will return NaN for this statistic.
+      allow_nan_stats: Boolean, default `False`. If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member. If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
       name: A name for this distribution (optional).
     """
     self._allow_nan_stats = allow_nan_stats
@@ -177,8 +172,7 @@ class Categorical(distribution.Distribution):
       samples = math_ops.cast(samples, self._dtype)
       ret = array_ops.reshape(
           array_ops.transpose(samples),
-          array_ops.concat(
-              0, [array_ops.expand_dims(n, 0), self.batch_shape()]))
+          array_ops.concat(0, ([n], self.batch_shape())))
       ret.set_shape(tensor_shape.vector(tensor_util.constant_value(n))
                     .concatenate(self.get_batch_shape()))
       return ret
tensorflow/contrib/distributions/python/ops/chi2.py
@@ -42,15 +42,15 @@ class Chi2(gamma.Gamma):
     """Construct Chi2 distributions with parameter `df`.
 
     Args:
-      df: `float` or `double` tensor, the degrees of freedom of the
+      df: Floating point tensor, the degrees of freedom of the
         distribution(s). `df` must contain only positive values.
       validate_args: Whether to assert that `df > 0`, and that `x > 0` in the
-        methods `prob(x)` and `log_prob(x)`. If `validate_args` is False
+        methods `prob(x)` and `log_prob(x)`. If `validate_args` is `False`
        and the inputs are invalid, correct behavior is not guaranteed.
-      allow_nan_stats: Boolean, default False. If False, raise an exception if
-        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-        If True, batch members with valid parameters leading to undefined
-        statistics will return NaN for this statistic.
+      allow_nan_stats: Boolean, default `False`. If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member. If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
       name: The name to prepend to all ops created by this distribution.
     """
     # Even though all stats of chi2 are defined for valid parameters, this is
tensorflow/contrib/distributions/python/ops/dirichlet.py
@@ -19,9 +19,8 @@ from __future__ import print_function
 
 # pylint: disable=line-too-long
 
-import numpy as np
-
 from tensorflow.contrib.distributions.python.ops import distribution
+from tensorflow.contrib.distributions.python.ops import distribution_util
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
@@ -29,7 +28,6 @@ from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import logging_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import special_math_ops
@@ -37,24 +35,6 @@ from tensorflow.python.ops import special_math_ops
 # pylint: enable=line-too-long
 
 
-def _assert_close(x, y, data=None, summarize=None, name=None):
-  if x.dtype.is_integer:
-    return check_ops.assert_equal(
-        x, y, data=data, summarize=summarize, name=name)
-
-  with ops.op_scope([x, y, data], name, "assert_close"):
-    x = ops.convert_to_tensor(x, name="x")
-    y = ops.convert_to_tensor(y, name="y")
-    tol = np.finfo(x.dtype.as_numpy_dtype).resolution
-    if data is None:
-      data = [
-          "Condition x ~= y did not hold element-wise: x = ", x.name, x, "y = ",
-          y.name, y
-      ]
-    condition = math_ops.reduce_all(math_ops.less_equal(math_ops.abs(x-y), tol))
-    return logging_ops.Assert(condition, data, summarize=summarize)
-
-
 class Dirichlet(distribution.Distribution):
   """Dirichlet distribution.
 
@@ -117,6 +97,7 @@ class Dirichlet(distribution.Distribution):
   x = [.2, .3, .5]
   dist.prob(x)  # Shape [2]
   ```
+
   """
 
  def __init__(self,
@@ -127,16 +108,16 @@ class Dirichlet(distribution.Distribution):
     """Initialize a batch of Dirichlet distributions.
 
     Args:
-      alpha: Positive `float` or `double` tensor with shape broadcastable to
+      alpha: Positive floating point tensor with shape broadcastable to
         `[N1,..., Nm, k]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm`
         different `k` class Dirichlet distributions.
       validate_args: Whether to assert valid values for parameters `alpha` and
-        `x` in `prob` and `log_prob`. If False, correct behavior is not
+        `x` in `prob` and `log_prob`. If `False`, correct behavior is not
         guaranteed.
-      allow_nan_stats: Boolean, default False. If False, raise an exception if
-        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-        If True, batch members with valid parameters leading to undefined
-        statistics will return NaN for this statistic.
+      allow_nan_stats: Boolean, default `False`. If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member. If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
       name: The name to prefix Ops created by this distribution class.
 
     Examples:
@@ -149,6 +130,7 @@ class Dirichlet(distribution.Distribution):
     # Define a 2-batch of 3-class distributions.
     dist = Dirichlet([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
     ```
+
     """
     with ops.op_scope([alpha], name):
       alpha = ops.convert_to_tensor(alpha, name="alpha_before_deps")
@@ -302,7 +284,9 @@ class Dirichlet(distribution.Distribution):
             array_ops.ones_like(self._alpha, dtype=self.dtype)))
       else:
         return control_flow_ops.with_dependencies([
-            check_ops.assert_less(one, self._alpha)
+            check_ops.assert_less(
+                one, self._alpha,
+                message="mode not defined for components of alpha <= 1")
         ], mode)
 
   def entropy(self, name="entropy"):
@@ -334,7 +318,7 @@ class Dirichlet(distribution.Distribution):
     """`Log(P[counts])`, computed for every batch member.
 
     Args:
-      x: Non-negative `float` or `double`, tensor whose shape can
+      x: Non-negative tensor with dtype `dtype` and whose shape can
         be broadcast with `self.alpha`. For fixed leading dimensions, the last
         dimension represents counts for the corresponding Dirichlet distribution
         in `self.alpha`. `x` is only legal if it sums up to one.
@@ -359,7 +343,7 @@ class Dirichlet(distribution.Distribution):
     """`P[x]`, computed for every batch member.
 
     Args:
-      x: Non-negative `float`, `double` tensor whose shape can
+      x: Non-negative tensor with dtype `dtype` and whose shape can
         be broadcast with `self.alpha`. For fixed leading dimensions, the last
         dimension represents x for the corresponding Dirichlet distribution in
         `self.alpha` and `self.beta`. `x` is only legal if it sums up to one.
@@ -407,7 +391,8 @@ class Dirichlet(distribution.Distribution):
     x = ops.convert_to_tensor(x, name="x_before_deps")
     candidate_one = math_ops.reduce_sum(x, reduction_indices=[-1])
     one = constant_op.constant(1., self.dtype)
-    dependencies = [check_ops.assert_positive(x), check_ops.assert_less(x, one),
-                    _assert_close(one, candidate_one)
+    dependencies = [check_ops.assert_positive(x), check_ops.assert_less(
+        x, one, message="x has components greater than or equal to 1"),
+                    distribution_util.assert_close(one, candidate_one)
                    ] if self.validate_args else []
     return control_flow_ops.with_dependencies(dependencies, x)
@ -13,13 +13,15 @@
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""The Dirichlet Multinomial distribution class."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=line-too-long
|
||||
|
||||
from tensorflow.contrib.distributions.python.ops import distribution # pylint: disable=line-too-long
|
||||
from tensorflow.contrib.distributions.python.ops import distribution
|
||||
from tensorflow.contrib.distributions.python.ops import distribution_util
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
@ -30,34 +32,6 @@ from tensorflow.python.ops import special_math_ops
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
|
||||
def _assert_integer_form(x):
|
||||
"""Check x for integer components (or floats that are equal to integers)."""
|
||||
x = ops.convert_to_tensor(x, name='x')
|
||||
casted_x = math_ops.to_int64(x)
|
||||
return check_ops.assert_equal(x, math_ops.cast(
|
||||
math_ops.round(casted_x), x.dtype))
|
||||
|
||||
|
||||
def _log_combinations(n, counts, name='log_combinations'):
|
||||
"""Log number of ways counts could have come in."""
|
||||
# First a bit about the number of ways counts could have come in:
|
||||
# E.g. if counts = [1, 2], then this is 3 choose 2.
|
||||
# In general, this is (sum counts)! / sum(counts!)
|
||||
# The sum should be along the last dimension of counts. This is the
|
||||
# "distribution" dimension. Here n a priori represents the sum of counts.
|
||||
with ops.op_scope([counts], name):
|
||||
# To compute factorials, use the fact that Gamma(n + 1) = n!
|
||||
# Compute two terms, each a sum over counts. Compute each for each
|
||||
# batch member.
|
||||
# Log Gamma((sum counts) + 1) = Log((sum counts)!)
|
||||
total_permutations = math_ops.lgamma(n + 1)
|
||||
# sum(Log Gamma(counts + 1)) = Log sum(counts!)
|
||||
counts_factorial = math_ops.lgamma(counts + 1)
|
||||
redundant_permutations = math_ops.reduce_sum(counts_factorial,
|
||||
reduction_indices=[-1])
|
||||
return total_permutations - redundant_permutations
|
||||
|
||||
|
||||
class DirichletMultinomial(distribution.Distribution):
|
||||
"""DirichletMultinomial mixture distribution.
|
||||
|
||||
@ -126,6 +100,7 @@ class DirichletMultinomial(distribution.Distribution):
|
||||
counts = [2, 1, 0]
|
||||
dist.pmf(counts) # Shape [2]
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
# TODO(b/27419586) Change docstring for dtype of alpha once int allowed.
|
||||
@ -134,26 +109,26 @@ class DirichletMultinomial(distribution.Distribution):
|
||||
alpha,
|
||||
validate_args=True,
|
||||
allow_nan_stats=False,
|
||||
name='DirichletMultinomial'):
|
||||
name="DirichletMultinomial"):
|
||||
"""Initialize a batch of DirichletMultinomial distributions.
|
||||
|
||||
Args:
|
||||
n: Non-negative `float` or `double` tensor, whose dtype is the same as
|
||||
n: Non-negative floating point tensor, whose dtype is the same as
|
||||
`alpha`. The shape is broadcastable to `[N1,..., Nm]` with `m >= 0`.
|
||||
Defines this as a batch of `N1 x ... x Nm` different Dirichlet
|
||||
multinomial distributions. Its components should be equal to integral
|
||||
multinomial distributions. Its components should be equal to integer
|
||||
values.
|
||||
alpha: Positive `float` or `double` tensor, whose dtype is the same as
|
||||
alpha: Positive floating point tensor, whose dtype is the same as
|
||||
`n` with shape broadcastable to `[N1,..., Nm, k]` `m >= 0`. Defines
|
||||
this as a batch of `N1 x ... x Nm` different `k` class Dirichlet
|
||||
multinomial distributions.
|
||||
validate_args: Whether to assert valid values for parameters `alpha` and
|
||||
`n`, and `x` in `prob` and `log_prob`. If False, correct behavior is
|
||||
`n`, and `x` in `prob` and `log_prob`. If `False`, correct behavior is
|
||||
not guaranteed.
|
||||
allow_nan_stats: Boolean, default False. If False, raise an exception if
|
||||
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
|
||||
If True, batch members with valid parameters leading to undefined
|
||||
statistics will return NaN for this statistic.
|
||||
allow_nan_stats: Boolean, default `False`. If `False`, raise an
|
||||
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
|
||||
batch member. If `True`, batch members with valid parameters leading to
|
||||
undefined statistics will return NaN for this statistic.
|
||||
name: The name to prefix Ops created by this distribution class.
|
||||
|
||||
Examples:
|
||||
@ -166,6 +141,7 @@ class DirichletMultinomial(distribution.Distribution):
|
||||
# Define a 2-batch of 3-class distributions.
|
||||
dist = DirichletMultinomial([3., 4], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
|
||||
```
|
||||
|
||||
"""
|
||||
self._allow_nan_stats = allow_nan_stats
|
||||
self._validate_args = validate_args
|
||||
@ -221,7 +197,7 @@ class DirichletMultinomial(distribution.Distribution):
    """dtype of samples from this distribution."""
    return self._alpha.dtype

  def mean(self, name='mean'):
  def mean(self, name="mean"):
    """Class means for every batch member."""
    alpha = self._alpha
    alpha_sum = self._alpha_sum
@ -231,7 +207,7 @@ class DirichletMultinomial(distribution.Distribution):
    mean_no_n = alpha / array_ops.expand_dims(alpha_sum, -1)
    return array_ops.expand_dims(n, -1) * mean_no_n

  def variance(self, name='mean'):
  def variance(self, name="mean"):
    """Class variances for every batch member.

    The variance for each batch member is defined as the following:
@ -273,7 +249,7 @@ class DirichletMultinomial(distribution.Distribution):
    variance *= array_ops.expand_dims(shared_factor, -1)
    return variance

  def batch_shape(self, name='batch_shape'):
  def batch_shape(self, name="batch_shape"):
    """Batch dimensions of this instance as a 1-D int32 `Tensor`.

    The product of the dimensions of the `batch_shape` is the number of
@ -299,7 +275,7 @@ class DirichletMultinomial(distribution.Distribution):
    """
    return self._get_batch_shape

  def event_shape(self, name='event_shape'):
  def event_shape(self, name="event_shape"):
    """Shape of a sample from a single distribution as a 1-D int32 `Tensor`.

    Args:
@ -322,15 +298,15 @@ class DirichletMultinomial(distribution.Distribution):
    """
    return self._get_event_shape

  def cdf(self, x, name='cdf'):
  def cdf(self, x, name="cdf"):
    raise NotImplementedError(
        'DirichletMultinomial does not have a well-defined cdf.')
        "DirichletMultinomial does not have a well-defined cdf.")

  def log_cdf(self, x, name='log_cdf'):
  def log_cdf(self, x, name="log_cdf"):
    raise NotImplementedError(
        'DirichletMultinomial does not have a well-defined cdf.')
        "DirichletMultinomial does not have a well-defined cdf.")

  def log_prob(self, counts, name='log_prob'):
  def log_prob(self, counts, name="log_prob"):
    """`Log(P[counts])`, computed for every batch member.

    For each batch of counts `[n_1,...,n_k]`, `P[counts]` is the probability
@ -340,12 +316,11 @@ class DirichletMultinomial(distribution.Distribution):
      probability includes a combinatorial coefficient.

    Args:
      counts: Non-negative `float` or `double` tensor whose dtype is the same
        `self` and whose shape can be broadcast with `self.alpha`. For fixed
        leading dimensions, the last dimension represents counts for the
        corresponding Dirichlet Multinomial distribution in `self.alpha`.
        `counts` is only legal if it sums up to `n` and its components are
        equal to integral values.
      counts: Non-negative tensor with dtype `dtype` and whose shape can be
        broadcast with `self.alpha`. For fixed leading dimensions, the last
        dimension represents counts for the corresponding Dirichlet Multinomial
        distribution in `self.alpha`. `counts` is only legal if it sums up to
        `n` and its components are equal to integer values.
      name: Name to give this Op, defaults to "log_prob".

    Returns:
@ -359,20 +334,11 @@ class DirichletMultinomial(distribution.Distribution):

        ordered_prob = (special_math_ops.lbeta(alpha + counts) -
                        special_math_ops.lbeta(alpha))
        log_prob = ordered_prob + _log_combinations(n, counts)
        # If alpha = counts = [[]], ordered_prob carries the right shape, which
        # is []. However, since reduce_sum([[]]) = [0], log_combinations = [0],
        # which is not correct. Luckily, [] + [0] = [], so the sum is fine, but
        # shape must be inferred from ordered_prob. We must also make this
        # broadcastable with n, so this is multiplied by n to ensure the shape
        # is correctly inferred.
        # Note also that tf.constant([]).get_shape() =
        # TensorShape([Dimension(0)])
        broadcasted_tensor = ordered_prob * n
        log_prob.set_shape(broadcasted_tensor.get_shape())
        log_prob = ordered_prob + distribution_util.log_combinations(
            n, counts)
        return log_prob

  def prob(self, counts, name='prob'):
  def prob(self, counts, name="prob"):
    """`P[counts]`, computed for every batch member.

    For each batch of counts `[c_1,...,c_k]`, `P[counts]` is the probability
@ -382,12 +348,11 @@ class DirichletMultinomial(distribution.Distribution):
      probability includes a combinatorial coefficient.

    Args:
      counts: Non-negative `float` or `double` tensor whose dtype is the same
        `self` and whose shape can be broadcast with `self.alpha`. For fixed
        leading dimensions, the last dimension represents counts for the
        corresponding Dirichlet Multinomial distribution in `self.alpha`.
        `counts` is only legal if it sums up to `n` and its components are
        equal to integral values.
      counts: Non-negative tensor with dtype `dtype` and whose shape can be
        broadcast with `self.alpha`. For fixed leading dimensions, the last
        dimension represents counts for the corresponding Dirichlet Multinomial
        distribution in `self.alpha`. `counts` is only legal if it sums up to
        `n` and its components are equal to integer values.
      name: Name to give this Op, defaults to "prob".

    Returns:
@ -397,18 +362,21 @@ class DirichletMultinomial(distribution.Distribution):

  def _check_counts(self, counts):
    """Check counts for proper shape, values, then return tensor version."""
    counts = ops.convert_to_tensor(counts, name='counts')
    counts = ops.convert_to_tensor(counts, name="counts")
    if not self.validate_args:
      return counts
    candidate_n = math_ops.reduce_sum(counts, reduction_indices=[-1])

    return control_flow_ops.with_dependencies([
        check_ops.assert_non_negative(counts),
        check_ops.assert_equal(self._n, candidate_n),
        _assert_integer_form(counts)], counts)
        check_ops.assert_equal(
            self._n, candidate_n,
            message="counts do not sum to n"
        ),
        distribution_util.assert_integer_form(counts)], counts)

  def _check_alpha(self, alpha):
    alpha = ops.convert_to_tensor(alpha, name='alpha')
    alpha = ops.convert_to_tensor(alpha, name="alpha")
    if not self.validate_args:
      return alpha
    return control_flow_ops.with_dependencies(
@ -416,11 +384,12 @@ class DirichletMultinomial(distribution.Distribution):
        check_ops.assert_positive(alpha)], alpha)

  def _check_n(self, n):
    n = ops.convert_to_tensor(n, name='n')
    n = ops.convert_to_tensor(n, name="n")
    if not self.validate_args:
      return n
    return control_flow_ops.with_dependencies(
        [check_ops.assert_non_negative(n), _assert_integer_form(n)], n)
        [check_ops.assert_non_negative(n),
         distribution_util.assert_integer_form(n)], n)

  @property
  def is_continuous(self):

tensorflow/contrib/distributions/python/ops/distribution_util.py
@ -0,0 +1,177 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for probability distributions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops


def assert_close(
    x, y, data=None, summarize=None, message=None, name="assert_close"):
  """Assert that x and y are within machine epsilon of each other.

  Args:
    x: Numeric `Tensor`
    y: Numeric `Tensor`
    data: The tensors to print out if the condition is `False`. Defaults to
      error message and first few entries of `x` and `y`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).

  Returns:
    Op raising `InvalidArgumentError` if |x - y| > machine epsilon.
  """
  message = message or ""
  x = ops.convert_to_tensor(x, name="x")
  y = ops.convert_to_tensor(y, name="y")

  if x.dtype.is_integer:
    return check_ops.assert_equal(
        x, y, data=data, summarize=summarize, message=message, name=name)

  with ops.op_scope([x, y, data], name, "assert_close"):
    tol = np.finfo(x.dtype.as_numpy_dtype).resolution
    if data is None:
      data = [
          message,
          "Condition x ~= y did not hold element-wise: x = ", x.name, x,
          "y = ", y.name, y
      ]
    condition = math_ops.reduce_all(
        math_ops.less_equal(math_ops.abs(x - y), tol))
    return logging_ops.Assert(
        condition, data, summarize=summarize)
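Editor's note: the tolerance rule this assert enforces can be reproduced in plain NumPy (an illustrative sketch, not commit code): the difference must stay within the `resolution` of the float dtype.

```python
import numpy as np

x = np.float32(1.0)
y = np.float32(1.0) + np.float32(1e-7)  # difference below float32 resolution
tol = np.finfo(np.float32).resolution   # 1e-6 for float32
print(np.all(np.abs(x - y) <= tol))     # True -> assert_close would pass
```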


def assert_integer_form(
    x, data=None, summarize=None, message=None, name="assert_integer_form"):
  """Assert that x has integer components (or floats equal to integers).

  Args:
    x: Numeric `Tensor`
    data: The tensors to print out if the condition is `False`. Defaults to
      error message and first few entries of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).

  Returns:
    Op raising `InvalidArgumentError` if round(x) != x.
  """
  message = message or "x has non-integer components"
  x = ops.convert_to_tensor(x, name="x")
  casted_x = math_ops.to_int64(x)
  return check_ops.assert_equal(
      x, math_ops.cast(math_ops.round(casted_x), x.dtype),
      data=data, summarize=summarize, message=message, name=name)
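Editor's note: the check passes exactly when every component is a whole number, whatever the float dtype. A NumPy equivalent (`has_integer_form` is a hypothetical helper, not commit code):

```python
import numpy as np

def has_integer_form(x):
  # True iff every component equals its own rounding, i.e. is integer-valued.
  x = np.asarray(x)
  return bool(np.all(x == np.round(x)))

print(has_integer_form([1., 0., 4.]))  # True  -> the assert would pass
print(has_integer_form([1.5, 2.0]))    # False -> InvalidArgumentError
```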


def get_logits_and_prob(
    logits=None, p=None, multidimensional=False, validate_args=True, name=None):
  """Converts logits to probabilities and vice-versa, and returns both.

  Args:
    logits: Numeric `Tensor` representing log-odds.
    p: Numeric `Tensor` representing probabilities.
    multidimensional: Given `p`, a `[N1, N2, ... k]` dimensional tensor,
      whether the last dimension represents a probability distribution over
      `k` classes. This will additionally assert that the values in the last
      dimension sum to one. If `False`, will instead assert that each value is
      in `[0, 1]`.
    validate_args: Whether to assert `0 <= p <= 1` if multidimensional is
      `False`, otherwise that the last dimension of `p` sums to one.
    name: A name for this operation (optional).

  Returns:
    Tuple with `logits` and `p`. If `p` has an entry that is `0` or `1`, then
    the corresponding entry in the returned logits will be `-Inf` and `Inf`
    respectively.

  Raises:
    ValueError: if neither `p` nor `logits` were passed in, or both were.
  """
  if p is None and logits is None:
    raise ValueError("Must pass p or logits.")
  elif p is not None and logits is not None:
    raise ValueError("Must pass either p or logits, not both.")
  elif p is None:
    with ops.op_scope([logits], name):
      logits = array_ops.identity(logits, name="logits")
    with ops.name_scope(name):
      with ops.name_scope("p"):
        p = math_ops.sigmoid(logits)
  elif logits is None:
    with ops.name_scope(name):
      with ops.name_scope("p"):
        p = array_ops.identity(p)
        if validate_args:
          one = constant_op.constant(1., p.dtype)
          dependencies = [check_ops.assert_non_negative(p)]
          if multidimensional:
            dependencies += [assert_close(
                math_ops.reduce_sum(p, reduction_indices=[-1]),
                one, message="p does not sum to 1.")]
          else:
            dependencies += [check_ops.assert_less_equal(
                p, one, message="p has components greater than 1.")]
          p = control_flow_ops.with_dependencies(dependencies, p)
      with ops.name_scope("logits"):
        logits = math_ops.log(p) - math_ops.log(1. - p)
  return (logits, p)
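Editor's note: the two directions are the sigmoid and its inverse, the logit. A NumPy round trip makes the relationship concrete (a sketch, not commit code):

```python
import numpy as np

p = np.array([0.2, 0.5, 0.9])
logits = np.log(p) - np.log(1. - p)   # same formula as above
p_back = 1. / (1. + np.exp(-logits))  # sigmoid recovers p
print(np.allclose(p, p_back))         # True
with np.errstate(divide="ignore"):
  print(np.log(0.) - np.log(1.))      # p = 0 maps to -inf, as documented
```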


def log_combinations(n, counts, name="log_combinations"):
  """Multinomial coefficient.

  Given `n` and `counts`, where `counts` has last dimension `k`, we compute
  the multinomial coefficient as:

  ```n! / prod_i n_i!```

  where `i` runs over all `k` classes.

  Args:
    n: Numeric `Tensor` broadcastable with `counts`. This represents `n`
      outcomes.
    counts: Numeric `Tensor` broadcastable with `n`. This represents counts
      in `k` classes, where `k` is the last dimension of the tensor.
    name: A name for this operation (optional).

  Returns:
    `Tensor` representing the multinomial coefficient between `n` and `counts`.
  """
  # First a bit about the number of ways counts could have come in:
  # E.g. if counts = [1, 2], then this is 3 choose 2.
  # In general, this is (sum counts)! / prod(counts!)
  # The sum should be along the last dimension of counts. This is the
  # "distribution" dimension. Here n a priori represents the sum of counts.
  with ops.op_scope([n, counts], name):
    total_permutations = math_ops.lgamma(n + 1)
    counts_factorial = math_ops.lgamma(counts + 1)
    redundant_permutations = math_ops.reduce_sum(counts_factorial,
                                                 reduction_indices=[-1])
    return total_permutations - redundant_permutations
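Editor's note: the `reduce_sum` over the last axis is what makes this batch-aware; each row of `counts` yields one coefficient. A NumPy check (illustrative only; assumes SciPy is available for the log-gamma):

```python
import numpy as np
from scipy.special import gammaln  # log Gamma, the NumPy analogue of lgamma

counts = np.array([[1., 2., 1.], [0., 0., 4.]])  # two batch members, k = 3
n = counts.sum(axis=-1)                          # [4., 4.]
log_comb = gammaln(n + 1) - gammaln(counts + 1).sum(axis=-1)
print(np.exp(log_comb))  # [12., 1.]: 4!/(1!2!1!) = 12 and 4!/4! = 1
```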
@ -46,15 +46,15 @@ class Exponential(gamma.Gamma):
    """Construct Exponential distribution with parameter `lam`.

    Args:
      lam: `float` or `double` tensor, the rate of the distribution(s).
      lam: Floating point tensor, the rate of the distribution(s).
        `lam` must contain only positive values.
      validate_args: Whether to assert that `lam > 0`, and that `x > 0` in the
        methods `prob(x)` and `log_prob(x)`. If `validate_args` is False
        methods `prob(x)` and `log_prob(x)`. If `validate_args` is `False`
        and the inputs are invalid, correct behavior is not guaranteed.
      allow_nan_stats: Boolean, default False. If False, raise an exception if
        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
        If True, batch members with valid parameters leading to undefined
        statistics will return NaN for this statistic.
      allow_nan_stats: Boolean, default `False`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: The name to prepend to all ops created by this distribution.
    """
    # Even though all statistics are defined for valid inputs, this is not
@ -95,8 +95,7 @@ class Exponential(gamma.Gamma):
      broadcast_shape = self._lam.get_shape()
      with ops.op_scope([self.lam, n], name, "ExponentialSample"):
        n = ops.convert_to_tensor(n, name="n")
        shape = array_ops.concat(
            0, [array_ops.pack([n]), array_ops.shape(self._lam)])
        shape = array_ops.concat(0, ([n], array_ops.shape(self._lam)))
        # Sample uniformly-at-random from the open-interval (0, 1).
        sampled = random_ops.random_uniform(
            shape, minval=np.nextafter(
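Editor's note: the open interval matters because the sampler feeds the uniform draw into a logarithm. The idea behind this (truncated) hunk, sketched in NumPy under that assumption, not the commit's code:

```python
import numpy as np

lam = 2.0
# Uniform on the open interval (0, 1); np.nextafter(0., 1.) is the smallest
# positive float, so log(u) is always finite.
u = np.random.uniform(np.nextafter(0., 1.), 1., size=100000)
samples = -np.log(u) / lam  # inverse-CDF transform: samples ~ Exponential(lam)
print(samples.mean())       # ~ 1 / lam = 0.5
```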
@ -69,19 +69,19 @@ class Gamma(distribution.Distribution):
    broadcasting (e.g. `alpha + beta` is a valid operation).

    Args:
      alpha: `float` or `double` tensor, the shape params of the
      alpha: Floating point tensor, the shape params of the
        distribution(s).
        alpha must contain only positive values.
      beta: `float` or `double` tensor, the inverse scale params of the
      beta: Floating point tensor, the inverse scale params of the
        distribution(s).
        beta must contain only positive values.
      validate_args: Whether to assert that `a > 0, b > 0`, and that `x > 0` in
        the methods `prob(x)` and `log_prob(x)`. If `validate_args` is False
        the methods `prob(x)` and `log_prob(x)`. If `validate_args` is `False`
        and the inputs are invalid, correct behavior is not guaranteed.
      allow_nan_stats: Boolean, default False. If False, raise an exception if
        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
        If True, batch members with valid parameters leading to undefined
        statistics will return NaN for this statistic.
      allow_nan_stats: Boolean, default `False`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: The name to prepend to all ops created by this distribution.

    Raises:
@ -213,9 +213,12 @@ class Gamma(distribution.Distribution):
        nan = np.nan * self._ones()
        return math_ops.select(alpha_ge_1, mode_if_defined, nan)
      else:
        one = ops.convert_to_tensor(1.0, dtype=self.dtype)
        one = constant_op.constant(1.0, dtype=self.dtype)
        return control_flow_ops.with_dependencies(
            [check_ops.assert_less(one, alpha)], mode_if_defined)
            [check_ops.assert_less(
                one, alpha,
                message="mode not defined for components of alpha <= 1"
            )], mode_if_defined)

  def variance(self, name="variance"):
    """Variance of each batch member."""

@ -69,18 +69,18 @@ class InverseGamma(distribution.Distribution):
    broadcasting (e.g. `alpha + beta` is a valid operation).

    Args:
      alpha: `float` or `double` tensor, the shape params of the
      alpha: Floating point tensor, the shape params of the
        distribution(s).
        alpha must contain only positive values.
      beta: `float` or `double` tensor, the scale params of the distribution(s).
      beta: Floating point tensor, the scale params of the distribution(s).
        beta must contain only positive values.
      validate_args: Whether to assert that `a > 0, b > 0`, and that `x > 0` in
        the methods `prob(x)` and `log_prob(x)`. If `validate_args` is False
        the methods `prob(x)` and `log_prob(x)`. If `validate_args` is `False`
        and the inputs are invalid, correct behavior is not guaranteed.
      allow_nan_stats: Boolean, default False. If False, raise an exception if
        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
        If True, batch members with valid parameters leading to undefined
        statistics will return NaN for this statistic.
      allow_nan_stats: Boolean, default `False`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: The name to prepend to all ops created by this distribution.

    Raises:
@ -206,9 +206,12 @@ class InverseGamma(distribution.Distribution):
        nan = np.nan * self._ones()
        return math_ops.select(alpha_gt_1, mean_if_defined, nan)
      else:
        one = ops.convert_to_tensor(1.0, dtype=self.dtype)
        one = constant_op.constant(1.0, dtype=self.dtype)
        return control_flow_ops.with_dependencies(
            [check_ops.assert_less(one, alpha)], mean_if_defined)
            [check_ops.assert_less(
                one, alpha,
                message="mean not defined for components of alpha <= 1")],
            mean_if_defined)

  def mode(self, name="mode"):
    """Mode of each batch member.
@ -250,9 +253,12 @@ class InverseGamma(distribution.Distribution):
        nan = np.nan * self._ones()
        return math_ops.select(alpha_gt_2, var_if_defined, nan)
      else:
        two = ops.convert_to_tensor(2.0, dtype=self.dtype)
        two = constant_op.constant(2.0, dtype=self.dtype)
        return control_flow_ops.with_dependencies(
            [check_ops.assert_less(two, alpha)], var_if_defined)
            [check_ops.assert_less(
                two, alpha,
                message="variance not defined for components of alpha <= 2")],
            var_if_defined)
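Editor's note: both guards follow standard closed forms (a sketch, not commit code). The InverseGamma mean `beta / (alpha - 1)` needs `alpha > 1`, and the variance `beta^2 / ((alpha - 1)^2 (alpha - 2))` needs `alpha > 2`.

```python
import numpy as np

alpha, beta = 4.0, 6.0
mean = beta / (alpha - 1.)                             # defined for alpha > 1
variance = beta**2 / ((alpha - 1.)**2 * (alpha - 2.))  # defined for alpha > 2
print(mean, variance)  # 2.0 2.0
```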

  def log_prob(self, x, name="log_prob"):
    """Log prob of observations in `x` under these InverseGamma distribution(s).

@ -34,9 +34,9 @@ def kl(dist_a, dist_b, allow_nan=False, name=None):
  Args:
    dist_a: instance of distributions.Distribution.
    dist_b: instance of distributions.Distribution.
    allow_nan: If False (default), a runtime error is raised
    allow_nan: If `False` (default), a runtime error is raised
      if the KL returns NaN values for any batch entry of the given
      distributions. If True, the KL may return a NaN for the given entry.
      distributions. If `True`, the KL may return a NaN for the given entry.
    name: (optional) Name scope to use for created operations.

  Returns:

@ -60,17 +60,17 @@ class Laplace(distribution.Distribution):
    broadcasting (e.g., `loc / scale` is a valid operation).

    Args:
      loc: `float` or `double` tensor which characterizes the location (center)
      loc: Floating point tensor which characterizes the location (center)
        of the distribution.
      scale: `float` or `double`, positive-valued tensor which characterzes the
        spread of the distribution.
      scale: Positive floating point tensor which characterizes the spread of
        the distribution.
      validate_args: Whether to validate input with asserts. If `validate_args`
        is `False`, and the inputs are invalid, correct behavior is not
        guaranteed.
      allow_nan_stats: Boolean, default False. If False, raise an exception if
        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
        If True, batch members with valid parameters leading to undefined
        statistics will return NaN for this statistic.
      allow_nan_stats: Boolean, default `False`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.

    Raises:
@ -294,8 +294,7 @@ class Laplace(distribution.Distribution):
    with ops.op_scope([self._loc, self._scale, n], name):
      n = ops.convert_to_tensor(n)
      n_val = tensor_util.constant_value(n)
      shape = array_ops.concat(
          0, [array_ops.pack([n]), self.batch_shape()])
      shape = array_ops.concat(0, ([n], self.batch_shape()))
      # Sample uniformly-at-random from the open-interval (-1, 1).
      uniform_samples = random_ops.random_uniform(
          shape=shape,
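Editor's note: as with Exponential, the open interval keeps the logarithm finite. The inverse-CDF idea behind this sampler, in NumPy (an illustrative sketch, not the commit's code):

```python
import numpy as np

loc, scale = 0.0, 1.0
u = np.random.uniform(-1., 1., size=100000)                # open interval (-1, 1)
samples = loc - scale * np.sign(u) * np.log1p(-np.abs(u))  # ~ Laplace(loc, scale)
print(np.mean(np.abs(samples - loc)))  # ~ scale: the Laplace mean absolute deviation
```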

tensorflow/contrib/distributions/python/ops/multinomial.py
@ -0,0 +1,343 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Multinomial distribution class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=line-too-long

from tensorflow.contrib.distributions.python.ops import distribution
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops

# pylint: enable=line-too-long


class Multinomial(distribution.Distribution):
  """Multinomial distribution.

  This distribution is parameterized by a vector `p` of probability
  parameters for `k` classes and `n`, the total number of counts.

  #### Mathematical details

  The Multinomial is a distribution over k-class count data, meaning
  for each k-tuple of non-negative integer `counts = [n_1,...,n_k]`, we have a
  probability of these draws being made from the distribution. The distribution
  has hyperparameters `p = (p_1,...,p_k)`, and probability mass
  function (pmf):

  ```pmf(counts) = n! / (n_1!...n_k!) * (p_1)^n_1*(p_2)^n_2*...(p_k)^n_k```

  where above `n = sum_j n_j` and `n!` is `n` factorial.

  #### Examples

  Create a 3-class distribution, with the 3rd class the most likely to be
  drawn, using logits.

  ```python
  logits = [-50., -43, 0]
  dist = Multinomial(n=4., logits=logits)
  ```

  Create a 3-class distribution, with the 3rd class the most likely to be
  drawn.

  ```python
  p = [.2, .3, .5]
  dist = Multinomial(n=4., p=p)
  ```

  The distribution functions can be evaluated on counts.

  ```python
  # counts same shape as p.
  counts = [1., 0, 3]
  dist.prob(counts)  # Shape []

  # p will be broadcast to [[.2, .3, .5], [.2, .3, .5]] to match counts.
  counts = [[1., 2, 1], [2, 2, 0]]
  dist.prob(counts)  # Shape [2]

  # p will be broadcast to shape [5, 7, 3] to match counts.
  counts = [[...]]  # Shape [5, 7, 3]
  dist.prob(counts)  # Shape [5, 7]
  ```

  Create a 2-batch of 3-class distributions.

  ```python
  p = [[.1, .2, .7], [.3, .3, .4]]  # Shape [2, 3]
  dist = Multinomial(n=[4., 5], p=p)

  counts = [[2., 1, 1], [3, 1, 1]]
  dist.prob(counts)  # Shape [2]
  ```
  """

  def __init__(self,
               n,
               logits=None,
               p=None,
               validate_args=True,
               allow_nan_stats=False,
               name="Multinomial"):
    """Initialize a batch of Multinomial distributions.

    Args:
      n: Non-negative floating point tensor with shape broadcastable to
        `[N1,..., Nm]` with `m >= 0`. Defines this as a batch of
        `N1 x ... x Nm` different Multinomial distributions. Its components
        should be equal to integer values.
      logits: Floating point tensor representing the log-odds of a
        positive event with shape broadcastable to `[N1,..., Nm, k], m >= 0`,
        and the same dtype as `n`. Defines this as a batch of `N1 x ... x Nm`
        different `k` class Multinomial distributions.
      p: Positive floating point tensor with shape broadcastable to
        `[N1,..., Nm, k]` `m >= 0` and same dtype as `n`. Defines this as
        a batch of `N1 x ... x Nm` different `k` class Multinomial
        distributions. `p`'s components in the last portion of its shape should
        sum up to 1.
      validate_args: Whether to assert valid values for parameters `n` and `p`,
        and `x` in `prob` and `log_prob`. If `False`, correct behavior is not
        guaranteed.
      allow_nan_stats: Boolean, default `False`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: The name to prefix Ops created by this distribution class.

    Examples:

    ```python
    # Define 1-batch of 2-class multinomial distribution,
    # also known as a Binomial distribution.
    dist = Multinomial(n=2., p=[.1, .9])

    # Define a 2-batch of 3-class distributions.
    dist = Multinomial(n=[4., 5], p=[[.1, .3, .6], [.4, .05, .55]])
    ```

    """

    self._logits, self._p = distribution_util.get_logits_and_prob(
        name=name, logits=logits, p=p, validate_args=validate_args,
        multidimensional=True)
    with ops.op_scope([n, self._p], name):
      with ops.control_dependencies([
          check_ops.assert_non_negative(
              n, message="n has negative components."),
          distribution_util.assert_integer_form(
              n, message="n has non-integer components."
          )] if validate_args else []):
        self._n = array_ops.identity(n, name="convert_n")
        self._name = name

        self._validate_args = validate_args
        self._allow_nan_stats = allow_nan_stats

        self._mean = array_ops.expand_dims(n, -1) * self._p
        # Only used for inferring shape.
        self._broadcast_shape = math_ops.reduce_sum(self._mean,
                                                    reduction_indices=[-1],
                                                    keep_dims=False)

        self._get_batch_shape = self._broadcast_shape.get_shape()
        self._get_event_shape = (
            self._mean.get_shape().with_rank_at_least(1)[-1:])

  @property
  def n(self):
    """Number of trials."""
    return self._n

  @property
  def p(self):
    """Event probabilities."""
    return self._p

  @property
  def logits(self):
    """Log-odds."""
    return self._logits

  @property
  def name(self):
    """Name to prepend to all ops."""
    return self._name

  @property
  def dtype(self):
    """dtype of samples from this distribution."""
    return self._p.dtype

  @property
  def validate_args(self):
    """Boolean describing behavior on invalid input."""
    return self._validate_args

  @property
  def allow_nan_stats(self):
    """Boolean describing behavior when a stat is undefined for batch member."""
    return self._allow_nan_stats

  def batch_shape(self, name="batch_shape"):
    """Batch dimensions of this instance as a 1-D int32 `Tensor`.

    The product of the dimensions of the `batch_shape` is the number of
    independent distributions of this kind the instance represents.

    Args:
      name: name to give to the op

    Returns:
      `Tensor` `batch_shape`
    """
    with ops.name_scope(self.name):
      with ops.op_scope([self._broadcast_shape], name):
        return array_ops.shape(self._broadcast_shape)

  def get_batch_shape(self):
    """`TensorShape` available at graph construction time.

    Same meaning as `batch_shape`. May be only partially defined.

    Returns:
      batch shape
    """
    return self._get_batch_shape

  def event_shape(self, name="event_shape"):
    """Shape of a sample from a single distribution as a 1-D int32 `Tensor`.

    Args:
      name: name to give to the op

    Returns:
      `Tensor` `event_shape`
    """
    with ops.name_scope(self.name):
      with ops.op_scope([self._mean], name):
        return array_ops.gather(array_ops.shape(self._mean),
                                [array_ops.rank(self._mean) - 1])

  def get_event_shape(self):
    """`TensorShape` available at graph construction time.

    Same meaning as `event_shape`. May be only partially defined.

    Returns:
      event shape
    """
    return self._get_event_shape

  def mean(self, name="mean"):
    """Mean of the distribution."""
    with ops.name_scope(self.name):
      return array_ops.identity(self._mean, name=name)

  def variance(self, name="variance"):
    """Variance of the distribution."""
    with ops.name_scope(self.name):
      with ops.op_scope([self._n, self._p, self._mean], name):
        p = array_ops.expand_dims(
            self._p * array_ops.expand_dims(
                array_ops.ones_like(self._n), -1), -1)
        variance = -math_ops.batch_matmul(
            array_ops.expand_dims(self._mean, -1), p, adj_y=True)
        variance += array_ops.batch_matrix_diag(self._mean)
        return variance
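Editor's note: the two terms implement the standard Multinomial covariance `diag(n p) - n p p^T` batch-wise. A NumPy check for a single batch member (a sketch, not commit code):

```python
import numpy as np

n, p = 4., np.array([.2, .3, .5])
cov = np.diag(n * p) - n * np.outer(p, p)
print(np.diag(cov))  # n * p * (1 - p) = [0.64 0.84 1.  ]
```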

  def log_prob(self, counts, name="log_prob"):
    """`Log(P[counts])`, computed for every batch member.

    For each batch of counts `[n_1,...,n_k]`, `P[counts]` is the probability
    that after sampling `n` draws from this Multinomial distribution, the
    number of draws falling in class `j` is `n_j`. Note that different
    sequences of draws can result in the same counts, thus the probability
    includes a combinatorial coefficient.

    Args:
      counts: Non-negative tensor with dtype `dtype` and whose shape can
        be broadcast with `self.p` and `self.n`. For fixed leading dimensions,
        the last dimension represents counts for the corresponding Multinomial
        distribution in `self.p`. `counts` is only legal if it sums up to `n`
        and its components are equal to integer values.
      name: Name to give this Op, defaults to "log_prob".

    Returns:
      Log probabilities for each record, shape `[N1,...,Nm]`.
    """
    n = self._n
    p = self._p
    with ops.name_scope(self.name):
      with ops.op_scope([n, p, counts], name):
        counts = self._check_counts(counts)

        prob_prob = math_ops.reduce_sum(counts * math_ops.log(self._p),
                                        reduction_indices=[-1])
        log_prob = prob_prob + distribution_util.log_combinations(
            n, counts)
        return log_prob

  def prob(self, counts, name="prob"):
    """`P[counts]`, computed for every batch member.

    For each batch of counts `[n_1,...,n_k]`, `P[counts]` is the probability
    that after sampling `n` draws from this Multinomial distribution, the
    number of draws falling in class `j` is `n_j`. Note that different
    sequences of draws can result in the same counts, thus the probability
    includes a combinatorial coefficient.

    Args:
      counts: Non-negative tensor with dtype `dtype` and whose shape can
        be broadcast with `self.p` and `self.n`. For fixed leading dimensions,
        the last dimension represents counts for the corresponding Multinomial
        distribution in `self.p`. `counts` is only legal if it sums up to `n`
        and its components are equal to integer values.
      name: Name to give this Op, defaults to "prob".

    Returns:
      Probabilities for each record, shape `[N1,...,Nm]`.
    """
    return super(Multinomial, self).prob(counts, name=name)

  @property
  def is_continuous(self):
    return False

  @property
  def is_reparameterized(self):
    return False

  def _check_counts(self, counts):
    """Check counts for proper shape, values, then return tensor version."""
    counts = ops.convert_to_tensor(counts, name="counts_before_deps")
    candidate_n = math_ops.reduce_sum(counts, reduction_indices=[-1])
    if not self.validate_args:
      return counts

    return control_flow_ops.with_dependencies([
        check_ops.assert_non_negative(
            counts, message="counts has negative components."),
        check_ops.assert_equal(
            self._n, candidate_n, message="counts do not sum to n."),
        distribution_util.assert_integer_form(
            counts, message="counts have non-integer components.")], counts)

@ -105,9 +105,9 @@ class MultivariateNormalOperatorPD(distribution.Distribution):
    which determines the covariance.

    Args:
      mu: `float` or `double` tensor with shape `[N1,...,Nb, k]`, `b >= 0`.
      cov: `float` or `double` instance of `OperatorPDBase` with same `dtype`
        as `mu` and shape `[N1,...,Nb, k, k]`.
      mu: Floating point tensor with shape `[N1,...,Nb, k]`, `b >= 0`.
      cov: Instance of `OperatorPDBase` with same `dtype` as `mu` and shape
        `[N1,...,Nb, k, k]`.
      validate_args: Whether to validate input with asserts. If `validate_args`
        is `False`, and the inputs are invalid, correct behavior is not
        guaranteed.
@ -466,7 +466,7 @@ class MultivariateNormalDiag(MultivariateNormalOperatorPD):
    The mean of `X_i` is `mu[i]`, and the standard deviation is `diag_stdev[i]`.

    Args:
      mu: Rank `N + 1` `float` or `double` tensor with shape `[N1,...,Nb, k]`,
      mu: Rank `N + 1` floating point tensor with shape `[N1,...,Nb, k]`,
        `b >= 0`.
      diag_stdev: Rank `N + 1` `Tensor` with same `dtype` and shape as `mu`,
        representing the standard deviations. Must be positive.
@ -581,13 +581,13 @@ class MultivariateNormalDiagPlusVDVT(MultivariateNormalOperatorPD):
    ```

    Args:
      mu: Rank `n + 1` `float` or `double` tensor with shape `[N1,...,Nn, k]`,
      mu: Rank `n + 1` floating point tensor with shape `[N1,...,Nn, k]`,
        `n >= 0`. The means.
      diag_large: Optional rank `n + 1` `float` or `double` tensor, shape
      diag_large: Optional rank `n + 1` floating point tensor, shape
        `[N1,...,Nn, k]` `n >= 0`. Defines the diagonal matrix `M`.
      v: Rank `n + 1` `float` or `double` tensor, shape `[N1,...,Nn, k, r]`
      v: Rank `n + 1` floating point tensor, shape `[N1,...,Nn, k, r]`
        `n >= 0`. Defines the matrix `V`.
      diag_small: Rank `n + 1` `float` or `double` tensor, shape
      diag_small: Rank `n + 1` floating point tensor, shape
        `[N1,...,Nn, k]` `n >= 0`. Defines the diagonal matrix `D`. Default
        is `None`, which means `D` will be the identity matrix.
      validate_args: Whether to validate input with asserts. If `validate_args`
@ -670,7 +670,7 @@ class MultivariateNormalCholesky(MultivariateNormalOperatorPD):
    factors, such that the covariance of each batch member is `chol chol^T`.

    Args:
      mu: `(N+1)-D` `float` or `double` tensor with shape `[N1,...,Nb, k]`,
      mu: `(N+1)-D` floating point tensor with shape `[N1,...,Nb, k]`,
        `b >= 0`.
      chol: `(N+2)-D` `Tensor` with same `dtype` as `mu` and shape
        `[N1,...,Nb, k, k]`. The upper triangular part is ignored (treated as
@ -750,7 +750,7 @@ class MultivariateNormalFull(MultivariateNormalOperatorPD):
    User must provide means `mu` and `sigma`, the mean and covariance.

    Args:
      mu: `(N+1)-D` `float` or `double` tensor with shape `[N1,...,Nb, k]`,
      mu: `(N+1)-D` floating point tensor with shape `[N1,...,Nb, k]`,
        `b >= 0`.
      sigma: `(N+2)-D` `Tensor` with same `dtype` as `mu` and shape
        `[N1,...,Nb, k, k]`. Each batch member must be positive definite.
@ -92,15 +92,15 @@ class Normal(distribution.Distribution):
    broadcasting (e.g. `mu + sigma` is a valid operation).

    Args:
      mu: `float` or `double` tensor, the means of the distribution(s).
      sigma: `float` or `double` tensor, the stddevs of the distribution(s).
      mu: Floating point tensor, the means of the distribution(s).
      sigma: Floating point tensor, the stddevs of the distribution(s).
        sigma must contain only positive values.
      validate_args: Whether to assert that `sigma > 0`. If `validate_args` is
        False, correct output is not guaranteed when input is invalid.
      allow_nan_stats: Boolean, default False. If False, raise an exception if
        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
        If True, batch members with valid parameters leading to undefined
        statistics will return NaN for this statistic.
        `False`, correct output is not guaranteed when input is invalid.
      allow_nan_stats: Boolean, default `False`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.

    Raises:
@ -321,8 +321,7 @@ class Normal(distribution.Distribution):
    with ops.op_scope([self._mu, self._sigma, n], name):
      broadcast_shape = (self._mu + self._sigma).get_shape()
      n = ops.convert_to_tensor(n)
      shape = array_ops.concat(
          0, [array_ops.pack([n]), array_ops.shape(self.mean())])
      shape = array_ops.concat(0, ([n], array_ops.shape(self.mean())))
      sampled = random_ops.random_normal(
          shape=shape, mean=0, stddev=1, dtype=self._mu.dtype, seed=seed)
@ -82,6 +82,7 @@ class StudentT(distribution.Distribution):
  # returning a length 2 tensor.
  dist.pdf(3.0)
  ```

  """

  def __init__(self,
@ -99,19 +100,19 @@ class StudentT(distribution.Distribution):
    broadcasting (e.g. `df + mu + sigma` is a valid operation).

    Args:
      df: `float` or `double` tensor, the degrees of freedom of the
      df: Floating point tensor, the degrees of freedom of the
        distribution(s). `df` must contain only positive values.
      mu: `float` or `double` tensor, the means of the distribution(s).
      sigma: `float` or `double` tensor, the scaling factor for the
      mu: Floating point tensor, the means of the distribution(s).
      sigma: Floating point tensor, the scaling factor for the
        distribution(s). `sigma` must contain only positive values.
        Note that `sigma` is not the standard deviation of this distribution.
      validate_args: Whether to assert that `df > 0, sigma > 0`. If
        `validate_args` is False and inputs are invalid, correct behavior is not
        guaranteed.
      allow_nan_stats: Boolean, default False. If False, raise an exception if
        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
        If True, batch members with valid parameters leading to undefined
        statistics will return NaN for this statistic.
        `validate_args` is `False` and inputs are invalid, correct behavior is
        not guaranteed.
      allow_nan_stats: Boolean, default `False`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.

    Raises:
@ -185,9 +186,12 @@ class StudentT(distribution.Distribution):
        nan = np.nan + self._zeros()
        return math_ops.select(df_gt_1, result_if_defined, nan)
      else:
        one = ops.convert_to_tensor(1.0, dtype=self.dtype)
        one = constant_op.constant(1.0, dtype=self.dtype)
        return control_flow_ops.with_dependencies(
            [check_ops.assert_less(one, self._df)], result_if_defined)
            [check_ops.assert_less(
                one, self._df,
                message="mean not defined for components of df <= 1"
            )], result_if_defined)

  def mode(self, name="mode"):
    with ops.name_scope(self.name):
@ -232,9 +236,12 @@ class StudentT(distribution.Distribution):
            result_where_defined,
            self._zeros() + np.nan)
      else:
        one = ops.convert_to_tensor(1.0, self.dtype)
        one = constant_op.constant(1.0, dtype=self.dtype)
        return control_flow_ops.with_dependencies(
            [check_ops.assert_less(one, self._df)], result_where_defined)
            [check_ops.assert_less(
                one, self._df,
                message="variance not defined for components of df <= 1"
            )], result_where_defined)

  def std(self, name="std"):
    with ops.name_scope(self.name):

@ -348,8 +355,7 @@ class StudentT(distribution.Distribution):
      # Let X = R*cos(theta), and let Y = R*sin(theta).
      # Then X ~ t_df and Y ~ t_df.
      # The variates X and Y are not independent.
      shape = array_ops.concat(0, [array_ops.pack([2, n]),
                                   self.batch_shape()])
      shape = array_ops.concat(0, ([2, n], self.batch_shape()))
      uniform = random_ops.random_uniform(shape=shape,
                                          dtype=self.dtype,
                                          seed=seed)

@ -57,6 +57,7 @@ class TransformedDistribution(distribution.Distribution):
      name="LogitNormalTransformedDistribution"
  )
  ```

  """

  def __init__(self,
|
||||
```
|
||||
|
||||
Args:
|
||||
a: `float` or `double` tensor, the minimum endpoint.
|
||||
b: `float` or `double` tensor, the maximum endpoint. Must be > `a`.
|
||||
validate_args: Whether to assert that `a > b`. If `validate_args` is False
|
||||
and inputs are invalid, correct behavior is not guaranteed.
|
||||
allow_nan_stats: Boolean, default False. If False, raise an exception if
|
||||
a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
|
||||
If True, batch members with valid parameters leading to undefined
|
||||
statistics will return NaN for this statistic.
|
||||
a: Floating point tensor, the minimum endpoint.
|
||||
b: Floating point tensor, the maximum endpoint. Must be > `a`.
|
||||
validate_args: Whether to assert that `a > b`. If `validate_args` is
|
||||
`False` and inputs are invalid, correct behavior is not guaranteed.
|
||||
allow_nan_stats: Boolean, default `False`. If `False`, raise an
|
||||
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
|
||||
batch member. If `True`, batch members with valid parameters leading to
|
||||
undefined statistics will return NaN for this statistic.
|
||||
name: The name to prefix Ops created by this distribution class.
|
||||
|
||||
Raises:
|
||||
@ -83,8 +83,9 @@ class Uniform(distribution.Distribution):
|
||||
self._allow_nan_stats = allow_nan_stats
|
||||
self._validate_args = validate_args
|
||||
with ops.op_scope([a, b], name):
|
||||
with ops.control_dependencies([check_ops.assert_less(a, b)] if
|
||||
validate_args else []):
|
||||
with ops.control_dependencies([check_ops.assert_less(
|
||||
a, b, message="uniform not defined when a > b.")] if validate_args
|
||||
else []):
|
||||
a = array_ops.identity(a, name="a")
|
||||
b = array_ops.identity(b, name="b")
|
||||
|
||||
@ -228,7 +229,7 @@ class Uniform(distribution.Distribution):
|
||||
n = ops.convert_to_tensor(n, name="n")
|
||||
n_val = tensor_util.constant_value(n)
|
||||
|
||||
shape = array_ops.concat(0, [array_ops.pack([n]), self.batch_shape()])
|
||||
shape = array_ops.concat(0, ([n], self.batch_shape()))
|
||||
samples = random_ops.random_uniform(shape=shape,
|
||||
dtype=self.dtype,
|
||||
seed=seed)
|
||||