Added Bernoulli distribution.
Change: 125604684
This commit is contained in:
parent
e231a8b382
commit
cb44b3a001
@ -107,6 +107,16 @@ cuda_py_tests(
|
||||
],
|
||||
)
|
||||
|
||||
# Test target for the Bernoulli distribution; runs
# python/kernel_tests/bernoulli_test.py against the :distributions_py library.
cuda_py_tests(
    name = "bernoulli_test",
    size = "small",
    srcs = ["python/kernel_tests/bernoulli_test.py"],
    additional_deps = [
        ":distributions_py",
        "//tensorflow/python:platform_test",
    ],
)
|
||||
|
||||
cuda_py_tests(
|
||||
name = "mvn_test",
|
||||
size = "small",
|
||||
|
@ -27,6 +27,7 @@ initialized with parameters that define the distributions.
|
||||
|
||||
### Univariate (scalar) distributions
|
||||
|
||||
@@Bernoulli
|
||||
@@Categorical
|
||||
@@Chi2
|
||||
@@Exponential
|
||||
@ -62,6 +63,7 @@ from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,wildcard-import,line-too-long
|
||||
|
||||
from tensorflow.contrib.distributions.python.ops.bernoulli import *
|
||||
from tensorflow.contrib.distributions.python.ops.categorical import *
|
||||
from tensorflow.contrib.distributions.python.ops.chi2 import *
|
||||
from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import *
|
||||
|
@ -0,0 +1,163 @@
|
||||
# Copyright 2016 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the Bernoulli distribution."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def make_bernoulli(batch_shape, dtype=tf.int32):
  """Builds a Bernoulli whose probabilities are drawn uniformly at random.

  Args:
    batch_shape: Iterable of ints giving the shape of the `p` parameter.
    dtype: dtype forwarded to the distribution (used for its samples).

  Returns:
    A `tf.contrib.distributions.Bernoulli` with a float32 `p` of shape
    `batch_shape`.
  """
  shape = list(batch_shape)
  probs = tf.constant(np.random.uniform(size=shape), dtype=tf.float32)
  return tf.contrib.distributions.Bernoulli(probs, dtype=dtype)
|
||||
|
||||
|
||||
def entropy(p):
  """Reference (NumPy) entropy, in nats, of a Bernoulli with parameter `p`.

  Computes `-q * log(q) - p * log(p)` with `q = 1 - p`, using the convention
  `0 * log(0) = 0` so that the degenerate cases `p == 0` and `p == 1` yield 0
  instead of NaN.

  Args:
    p: Scalar or array-like of probabilities in `[0, 1]`.

  Returns:
    The entropy, with the same shape as `p`.
  """
  p = np.asarray(p)
  q = 1. - p

  def _xlogx(x):
    # Swap exact zeros for 1 before taking the log: 1 * log(1) == 0, which is
    # the limit of x * log(x) as x -> 0. Other values pass through untouched,
    # so out-of-range (negative) inputs still surface as NaN.
    safe_x = np.where(x == 0., 1., x)
    return safe_x * np.log(safe_x)

  return -(_xlogx(q) + _xlogx(p))
|
||||
|
||||
|
||||
class BernoulliTest(tf.test.TestCase):
  """Unit tests for `tf.contrib.distributions.Bernoulli`."""

  def testP(self):
    # The `p` property must echo the probabilities given to the constructor.
    p = [0.2, 0.4]
    dist = tf.contrib.distributions.Bernoulli(p)
    with self.test_session():
      self.assertAllClose(p, dist.p.eval())

  def testInvalidP(self):
    # With the default strict=True, probabilities outside [0, 1] must trip the
    # range assertion as soon as `p` is evaluated.
    invalid_ps = [1.01, -0.01, 2., -3.]
    for p in invalid_ps:
      with self.test_session():
        with self.assertRaisesOpError("x <= y"):
          dist = tf.contrib.distributions.Bernoulli(p)
          dist.p.eval()

    # The closed interval's end points are legal values.
    valid_ps = [0.0, 0.5, 1.0]
    for p in valid_ps:
      with self.test_session():
        dist = tf.contrib.distributions.Bernoulli(p)
        self.assertEqual(p, dist.p.eval())  # Should not fail

  def testShapes(self):
    # Batch shape follows the shape of `p` (static and dynamic); events are
    # scalars.
    with self.test_session():
      for batch_shape in ([], [1], [2, 3, 4]):
        dist = make_bernoulli(batch_shape)
        self.assertAllEqual(batch_shape, dist.get_batch_shape().as_list())
        self.assertAllEqual(batch_shape, dist.batch_shape().eval())
        self.assertAllEqual([], dist.get_event_shape().as_list())
        self.assertAllEqual([], dist.event_shape().eval())

  def testDtype(self):
    # Samples and the mode use the distribution's dtype; every other statistic
    # uses the dtype of `p`.
    dist = make_bernoulli([])
    self.assertEqual(dist.dtype, tf.int32)
    self.assertEqual(dist.dtype, dist.sample(5).dtype)
    self.assertEqual(dist.dtype, dist.mode().dtype)
    self.assertEqual(dist.p.dtype, dist.mean().dtype)
    self.assertEqual(dist.p.dtype, dist.variance().dtype)
    self.assertEqual(dist.p.dtype, dist.std().dtype)
    self.assertEqual(dist.p.dtype, dist.entropy().dtype)
    self.assertEqual(dist.p.dtype, dist.pmf(0).dtype)
    self.assertEqual(dist.p.dtype, dist.log_pmf(0).dtype)

    dist64 = make_bernoulli([], tf.int64)
    self.assertEqual(dist64.dtype, tf.int64)
    self.assertEqual(dist64.dtype, dist64.sample(5).dtype)
    self.assertEqual(dist64.dtype, dist64.mode().dtype)

  def testPmf(self):
    # Each event in `xs` is broadcast against the 2x2 `p`; the matching entry
    # of `expected_pmfs` is p where the event is 1 and 1 - p where it is 0.
    p = [[0.2, 0.4], [0.3, 0.6]]
    dist = tf.contrib.distributions.Bernoulli(p)
    with self.test_session():
      # pylint: disable=bad-continuation
      xs = [
          0,
          [1],
          [1, 0],
          [[1, 0]],
          [[1, 0], [1, 1]],
      ]
      expected_pmfs = [
          [[0.8, 0.6], [0.7, 0.4]],
          [[0.2, 0.4], [0.3, 0.6]],
          [[0.2, 0.6], [0.3, 0.4]],
          [[0.2, 0.6], [0.3, 0.4]],
          [[0.2, 0.6], [0.3, 0.6]],
      ]
      # pylint: enable=bad-continuation

      for x, expected_pmf in zip(xs, expected_pmfs):
        self.assertAllClose(dist.pmf(x).eval(), expected_pmf)
        self.assertAllClose(dist.log_pmf(x).eval(), np.log(expected_pmf))

  def testBoundaryConditions(self):
    # With p == 1 the event 0 is impossible (log-probability -inf) and the
    # event 1 is certain (log-probability 0).
    with self.test_session():
      dist = tf.contrib.distributions.Bernoulli(1.0)
      self.assertEqual(-np.inf, dist.log_pmf(0).eval())
      self.assertAllClose([0.0], [dist.log_pmf(1).eval()])

  def testEntropyNoBatch(self):
    p = 0.2
    dist = tf.contrib.distributions.Bernoulli(p)
    with self.test_session():
      self.assertAllClose(dist.entropy().eval(), entropy(p))

  def testEntropyWithBatch(self):
    # The degenerate entries (p == 0 and p == 1) must have exactly zero
    # entropy rather than NaN.
    p = [[0.0, 0.7], [1.0, 0.6]]
    dist = tf.contrib.distributions.Bernoulli(p, strict=False)
    with self.test_session():
      self.assertAllClose(dist.entropy().eval(), [[0.0, entropy(0.7)],
                                                  [0.0, entropy(0.6)]])

  def testSample(self):
    with self.test_session():
      p = [0.2, 0.6]
      dist = tf.contrib.distributions.Bernoulli(p)
      # The empirical mean has standard error sqrt(p * (1 - p) / n). With
      # n = 100000 that is about 0.0015 for p = 0.6, so the 1e-2 tolerance
      # below is several standard errors wide. A much smaller n (e.g. 1000,
      # standard error ~0.0155) would only pass for a lucky seed and RNG
      # stream.
      n = 100000
      samples = dist.sample(n, seed=123)
      samples.set_shape([n, 2])
      self.assertEqual(samples.dtype, tf.int32)
      sample_values = samples.eval()
      # Samples must be binary, and both outcomes must occur.
      self.assertFalse(np.any(sample_values < 0))
      self.assertFalse(np.any(sample_values > 1))
      self.assertAllClose(p, np.mean(sample_values == 1, axis=0), atol=1e-2)
      self.assertEqual(set([0, 1]), set(sample_values.flatten()))

  def testMean(self):
    # The mean of a Bernoulli is p itself.
    with self.test_session():
      p = np.array([[0.2, 0.7], [0.5, 0.4]], dtype=np.float32)
      dist = tf.contrib.distributions.Bernoulli(p)
      self.assertAllEqual(dist.mean().eval(), p)

  def testVarianceAndStd(self):
    # Variance is p * (1 - p); the standard deviation is its square root.
    var = lambda p: p * (1. - p)
    with self.test_session():
      p = [[0.2, 0.7], [0.5, 0.4]]
      dist = tf.contrib.distributions.Bernoulli(p)
      self.assertAllClose(dist.variance().eval(),
                          np.array([[var(0.2), var(0.7)], [var(0.5), var(0.4)]],
                                   dtype=np.float32))
      self.assertAllClose(dist.std().eval(),
                          np.array([[np.sqrt(var(0.2)), np.sqrt(var(0.7))],
                                    [np.sqrt(var(0.5)), np.sqrt(var(0.4))]],
                                   dtype=np.float32))
|
||||
|
||||
|
||||
# Run this module's test cases when the file is executed as a script.
if __name__ == "__main__":
  tf.test.main()
|
231
tensorflow/contrib/distributions/python/ops/bernoulli.py
Normal file
231
tensorflow/contrib/distributions/python/ops/bernoulli.py
Normal file
@ -0,0 +1,231 @@
|
||||
# Copyright 2016 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""The Bernoulli distribution class."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.distributions.python.ops import distribution
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
|
||||
|
||||
class Bernoulli(distribution.DiscreteDistribution):
  """Bernoulli distribution.

  The Bernoulli distribution is parameterized by p, the probability of a
  positive event.

  Each entry of `p` parameterizes its own independent distribution, so the
  batch shape is the shape of `p` and the event shape is scalar. `q` denotes
  `1 - p`.

  Note, the following methods of the base class aren't implemented:
    * cdf
    * log_cdf
  """

  def __init__(self, p, dtype=dtypes.int32, strict=True, name="Bernoulli"):
    """Construct Bernoulli distributions.

    Args:
      p: An N-D `Tensor` representing the probability of a positive
          event. Each entry in the `Tensor` parameterizes an independent
          Bernoulli distribution.
      dtype: dtype for samples (also used for the mode). Note that other
        values (mean, variance, std, entropy, pmf) will take the dtype of p.
      strict: Whether to assert that `0 <= p <= 1`. If not strict, `log_pmf` may
        return nans.
      name: A name for this distribution.
    """
    self._name = name
    self._dtype = dtype
    # NOTE(review): `_strict` is stored but not read by any method of this
    # class; it only affects the assertions attached just below.
    self._strict = strict
    check_op = check_ops.assert_less_equal
    with ops.op_scope([p], name):
      # When strict, make `p` depend on the two range checks (p <= 1 and
      # 0 <= p) so they run whenever `p` is evaluated.
      with ops.control_dependencies(
          [check_op(p, 1.), check_op(0., p)] if strict else []):
        p = array_ops.identity(p, name="p")
      self._p = p
      self._q = array_ops.identity(1. - p, name="q")
      self._batch_shape = array_ops.shape(self._p)
      # NOTE(review): `_event_shape` is not read elsewhere in this class;
      # `event_shape()` builds its own empty constant.
      self._event_shape = array_ops.constant([], dtype=dtypes.int32)

  @property
  def name(self):
    """Name given to this distribution at construction."""
    return self._name

  @property
  def dtype(self):
    """dtype of the samples, as given at construction."""
    return self._dtype

  @property
  def is_reparameterized(self):
    """Always `False` for this distribution."""
    return False

  def batch_shape(self, name="batch_shape"):
    """Batch shape as a `Tensor`: the dynamic shape of `p`.

    Args:
      name: A name for this operation.

    Returns:
      `Tensor` holding the shape of `p`.
    """
    with ops.name_scope(self.name):
      with ops.op_scope([self._batch_shape], name):
        return array_ops.identity(self._batch_shape)

  def get_batch_shape(self):
    """Static batch shape: the statically known shape of `p`."""
    return self.p.get_shape()

  def event_shape(self, name="event_shape"):
    """Event shape as a `Tensor`; always empty because events are scalars.

    Args:
      name: A name for this operation.

    Returns:
      An empty 1-D `Tensor` with the same dtype as the batch shape tensor.
    """
    with ops.name_scope(self.name):
      with ops.op_scope([self._batch_shape], name):
        return array_ops.constant([], dtype=self._batch_shape.dtype)

  def get_event_shape(self):
    """Static event shape; always scalar."""
    return tensor_shape.scalar()

  @property
  def p(self):
    """Probability of a positive event, as given at construction."""
    return self._p

  @property
  def q(self):
    """1-p."""
    return self._q

  def pmf(self, event, name="pmf"):
    """Probability mass function.

    Args:
      event: `int32` or `int64` binary Tensor; must be broadcastable with `p`.
      name: A name for this operation.

    Returns:
      The probabilities of the events, in the dtype of `p`.
    """
    with ops.name_scope(self.name):
      with ops.op_scope([self.p, self.q, event], name):
        event = ops.convert_to_tensor(event, name="event")
        # Cast so the 0/1 event can be combined arithmetically with p and q.
        event = math_ops.cast(event, self.p.dtype)
        # Evaluates to p where event == 1 and to q where event == 0.
        return event * self.p + (1. - event) * self.q

  def log_pmf(self, event, name="log_pmf"):
    """Log of the probability mass function.

    Delegates to the base class implementation.

    Args:
      event: `int32` or `int64` binary Tensor.
      name: A name for this operation (optional).

    Returns:
      The log-probabilities of the events.
    """
    return super(Bernoulli, self).log_pmf(event, name)

  def sample(self, n, seed=None, name="sample"):
    """Generate `n` samples.

    Args:
      n: scalar. Number of samples to draw from each distribution.
      seed: Python integer seed for RNG.
      name: name to give to the op.

    Returns:
      samples: a `Tensor` of shape `(n,) + self.batch_shape` with values of type
          `self.dtype`.
    """
    with ops.name_scope(self.name):
      with ops.op_scope([self.p, n], name):
        n = ops.convert_to_tensor(n, name="n")
        # Flatten the batch into one row per distribution, then build a
        # two-column table: column 0 is q (outcome 0), column 1 is p
        # (outcome 1).
        p_2d = array_ops.reshape(self.p, array_ops.pack([-1, 1]))
        q_2d = 1. - p_2d
        probs = array_ops.concat(1, [q_2d, p_2d])
        # Presumably `multinomial` returns sampled column indices with one row
        # per distribution and `n` columns; the transpose/reshape below relies
        # on that layout - TODO confirm against the op's documentation.
        samples = random_ops.multinomial(math_ops.log(probs), n, seed=seed)
        # Put the sample dimension first and restore the original batch shape.
        ret = array_ops.reshape(
            array_ops.transpose(samples),
            array_ops.concat(0,
                             [array_ops.expand_dims(n, 0), self.batch_shape()]))
        # Attach whatever static shape information is available for `n` and
        # the batch shape.
        ret.set_shape(tensor_shape.vector(tensor_util.constant_value(n))
                      .concatenate(self.get_batch_shape()))
        return math_ops.cast(ret, self.dtype)

  def entropy(self, name="entropy"):
    """Entropy of the distribution.

    Computed as `-q * log(q) - p * log(p)`.

    Args:
      name: Name for the op.

    Returns:
      entropy: `Tensor` of the same type and shape as `p`.
    """
    with ops.name_scope(self.name):
      with ops.op_scope([self.q, self.p], name):
        # Smallest positive normal number of p's dtype; added inside the logs
        # so they are never evaluated at exactly 0 when p or q is 0 (the
        # multiplying factor is 0 there).
        e = array_ops.constant(
            np.finfo(self.p.dtype.as_numpy_dtype).tiny,
            dtype=self.p.dtype)
        return (-self.q * math_ops.log(self.q + e) - self.p *
                math_ops.log(self.p + e))

  def mean(self, name="mean"):
    """Mean of the distribution.

    The mean is `p` itself.

    Args:
      name: Name for the op.

    Returns:
      mean: `Tensor` of the same type and shape as `p`.
    """
    with ops.name_scope(self.name):
      with ops.op_scope([self.p], name):
        return array_ops.identity(self.p)

  def mode(self, name="mode"):
    """Mode of the distribution.

    1 if p > 1-p. 0 otherwise (including the tie p == 1-p).

    Args:
      name: Name for the op.

    Returns:
      mode: binary `Tensor` of type self.dtype.
    """
    with ops.name_scope(self.name):
      with ops.op_scope([self.p, self.q], name):
        return math_ops.cast(self.p > self.q, self.dtype)

  def variance(self, name="variance"):
    """Variance of the distribution.

    Computed as `q * p`.

    Args:
      name: Name for the op.

    Returns:
      variance: `Tensor` of the same type and shape as `p`.
    """
    with ops.name_scope(self.name):
      with ops.op_scope([self.p, self.q], name):
        return self.q * self.p

  def std(self, name="std"):
    """Standard deviation of the distribution.

    Computed as the square root of `variance()`.

    Args:
      name: Name for the op.

    Returns:
      std: `Tensor` of the same type and shape as `p`.
    """
    with ops.name_scope(self.name):
      with ops.name_scope(name):
        return math_ops.sqrt(self.variance())
|
Loading…
Reference in New Issue
Block a user