Added Bernoulli distribution.

Change: 125604684
Eugene Brevdo 2016-06-22 12:52:24 -08:00 committed by TensorFlower Gardener
parent e231a8b382
commit cb44b3a001
4 changed files with 406 additions and 0 deletions
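In short, this adds `tf.contrib.distributions.Bernoulli`. A minimal usage sketch against the API introduced in this commit (the probability values are illustrative):

import tensorflow as tf

p = [0.2, 0.7]  # probability of a positive event, one per distribution
dist = tf.contrib.distributions.Bernoulli(p)
with tf.Session() as sess:
    print(sess.run(dist.mean()))             # [0.2, 0.7], same as p
    print(sess.run(dist.pmf([1, 0])))        # [0.2, 0.3], i.e. [p[0], 1 - p[1]]
    print(sess.run(dist.sample(4, seed=0)))  # (4, 2) int32 tensor of 0s and 1s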

tensorflow/contrib/distributions/BUILD

@@ -107,6 +107,16 @@ cuda_py_tests(
],
)
cuda_py_tests(
name = "bernoulli_test",
size = "small",
srcs = ["python/kernel_tests/bernoulli_test.py"],
additional_deps = [
":distributions_py",
"//tensorflow/python:platform_test",
],
)
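With this rule in place, the new test can be run in isolation; a typical invocation from a TensorFlow source checkout would be `bazel test //tensorflow/contrib/distributions:bernoulli_test` (target label inferred from the rule name above).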
cuda_py_tests(
name = "mvn_test",
size = "small",

tensorflow/contrib/distributions/__init__.py

@@ -27,6 +27,7 @@ initialized with parameters that define the distributions.
### Univariate (scalar) distributions
@@Bernoulli
@@Categorical
@@Chi2
@@Exponential
@@ -62,6 +63,7 @@ from __future__ import print_function
# pylint: disable=unused-import,wildcard-import,line-too-long
from tensorflow.contrib.distributions.python.ops.bernoulli import *
from tensorflow.contrib.distributions.python.ops.categorical import *
from tensorflow.contrib.distributions.python.ops.chi2 import *
from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import *

tensorflow/contrib/distributions/python/kernel_tests/bernoulli_test.py

@@ -0,0 +1,163 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the Bernoulli distribution."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
def make_bernoulli(batch_shape, dtype=tf.int32):
p = np.random.uniform(size=list(batch_shape))
p = tf.constant(p, dtype=tf.float32)
return tf.contrib.distributions.Bernoulli(p, dtype=dtype)
def entropy(p):
q = 1. - p
return -q * np.log(q) - p * np.log(p)
class BernoulliTest(tf.test.TestCase):
def testP(self):
p = [0.2, 0.4]
dist = tf.contrib.distributions.Bernoulli(p)
with self.test_session():
self.assertAllClose(p, dist.p.eval())
def testInvalidP(self):
invalid_ps = [1.01, -0.01, 2., -3.]
for p in invalid_ps:
with self.test_session():
with self.assertRaisesOpError("x <= y"):
dist = tf.contrib.distributions.Bernoulli(p)
dist.p.eval()
valid_ps = [0.0, 0.5, 1.0]
for p in valid_ps:
with self.test_session():
dist = tf.contrib.distributions.Bernoulli(p)
self.assertEqual(p, dist.p.eval()) # Should not fail
def testShapes(self):
with self.test_session():
for batch_shape in ([], [1], [2, 3, 4]):
dist = make_bernoulli(batch_shape)
self.assertAllEqual(batch_shape, dist.get_batch_shape().as_list())
self.assertAllEqual(batch_shape, dist.batch_shape().eval())
self.assertAllEqual([], dist.get_event_shape().as_list())
self.assertAllEqual([], dist.event_shape().eval())
def testDtype(self):
dist = make_bernoulli([])
self.assertEqual(dist.dtype, tf.int32)
self.assertEqual(dist.dtype, dist.sample(5).dtype)
self.assertEqual(dist.dtype, dist.mode().dtype)
self.assertEqual(dist.p.dtype, dist.mean().dtype)
self.assertEqual(dist.p.dtype, dist.variance().dtype)
self.assertEqual(dist.p.dtype, dist.std().dtype)
self.assertEqual(dist.p.dtype, dist.entropy().dtype)
self.assertEqual(dist.p.dtype, dist.pmf(0).dtype)
self.assertEqual(dist.p.dtype, dist.log_pmf(0).dtype)
dist64 = make_bernoulli([], tf.int64)
self.assertEqual(dist64.dtype, tf.int64)
self.assertEqual(dist64.dtype, dist64.sample(5).dtype)
self.assertEqual(dist64.dtype, dist64.mode().dtype)
def testPmf(self):
p = [[0.2, 0.4], [0.3, 0.6]]
dist = tf.contrib.distributions.Bernoulli(p)
with self.test_session():
# pylint: disable=bad-continuation
xs = [
0,
[1],
[1, 0],
[[1, 0]],
[[1, 0], [1, 1]],
]
expected_pmfs = [
[[0.8, 0.6], [0.7, 0.4]],
[[0.2, 0.4], [0.3, 0.6]],
[[0.2, 0.6], [0.3, 0.4]],
[[0.2, 0.6], [0.3, 0.4]],
[[0.2, 0.6], [0.3, 0.6]],
]
# pylint: enable=bad-continuation
for x, expected_pmf in zip(xs, expected_pmfs):
self.assertAllClose(dist.pmf(x).eval(), expected_pmf)
self.assertAllClose(dist.log_pmf(x).eval(), np.log(expected_pmf))
def testBoundaryConditions(self):
with self.test_session():
dist = tf.contrib.distributions.Bernoulli(1.0)
self.assertEqual(-np.inf, dist.log_pmf(0).eval())
self.assertAllClose([0.0], [dist.log_pmf(1).eval()])
def testEntropyNoBatch(self):
p = 0.2
dist = tf.contrib.distributions.Bernoulli(p)
with self.test_session():
self.assertAllClose(dist.entropy().eval(), entropy(p))
def testEntropyWithBatch(self):
p = [[0.0, 0.7], [1.0, 0.6]]
dist = tf.contrib.distributions.Bernoulli(p, strict=False)
with self.test_session():
self.assertAllClose(dist.entropy().eval(), [[0.0, entropy(0.7)],
[0.0, entropy(0.6)]])
def testSample(self):
with self.test_session():
p = [0.2, 0.6]
dist = tf.contrib.distributions.Bernoulli(p)
n = 1000
samples = dist.sample(n, seed=123)
samples.set_shape([n, 2])
self.assertEqual(samples.dtype, tf.int32)
sample_values = samples.eval()
self.assertFalse(np.any(sample_values < 0))
self.assertFalse(np.any(sample_values > 1))
self.assertAllClose(p, np.mean(sample_values == 1, axis=0), atol=1e-2)
self.assertEqual(set([0, 1]), set(sample_values.flatten()))
def testMean(self):
with self.test_session():
p = np.array([[0.2, 0.7], [0.5, 0.4]], dtype=np.float32)
dist = tf.contrib.distributions.Bernoulli(p)
self.assertAllEqual(dist.mean().eval(), p)
def testVarianceAndStd(self):
var = lambda p: p * (1. - p)
with self.test_session():
p = [[0.2, 0.7], [0.5, 0.4]]
dist = tf.contrib.distributions.Bernoulli(p)
self.assertAllClose(dist.variance().eval(),
np.array([[var(0.2), var(0.7)], [var(0.5), var(0.4)]],
dtype=np.float32))
self.assertAllClose(dist.std().eval(),
np.array([[np.sqrt(var(0.2)), np.sqrt(var(0.7))],
[np.sqrt(var(0.5)), np.sqrt(var(0.4))]],
dtype=np.float32))
if __name__ == "__main__":
tf.test.main()
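Aside: testEntropyWithBatch above deliberately includes p values of exactly 0 and 1, where the naive formula -q*log(q) - p*log(p) evaluates to nan. The implementation in the next file sidesteps this by adding the smallest positive float before taking logs; a standalone numpy sketch of the same guard (the helper name is mine, not part of this commit):

import numpy as np

def safe_bernoulli_entropy(p):
    p = np.asarray(p, dtype=np.float32)
    q = 1. - p
    tiny = np.finfo(p.dtype).tiny  # smallest positive normal float32
    # log(q + tiny) and log(p + tiny) stay finite at the boundaries, so the
    # entropy comes out as 0.0 (not nan) at p == 0 and p == 1.
    return -q * np.log(q + tiny) - p * np.log(p + tiny)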

tensorflow/contrib/distributions/python/ops/bernoulli.py

@@ -0,0 +1,231 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Bernoulli distribution class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.distributions.python.ops import distribution
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
class Bernoulli(distribution.DiscreteDistribution):
"""Bernoulli distribution.
The Bernoulli distribution is parameterized by p, the probability of a
positive event.
Note that the following methods of the base class are not implemented:
* cdf
* log_cdf
"""
def __init__(self, p, dtype=dtypes.int32, strict=True, name="Bernoulli"):
"""Construct Bernoulli distributions.
Args:
p: An N-D `Tensor` representing the probability of a positive
event. Each entry in the `Tensor` parameterizes an independent
Bernoulli distribution.
dtype: The dtype for samples. Note that all other values take the dtype
of `p`.
strict: Whether to assert that `0 <= p <= 1`. If `strict` is `False`,
`log_pmf` may return NaNs for out-of-range `p`.
name: A name for this distribution.
"""
self._name = name
self._dtype = dtype
self._strict = strict
check_op = check_ops.assert_less_equal
with ops.op_scope([p], name):
with ops.control_dependencies(
[check_op(p, 1.), check_op(0., p)] if strict else []):
p = array_ops.identity(p, name="p")
self._p = p
self._q = array_ops.identity(1. - p, name="q")
self._batch_shape = array_ops.shape(self._p)
self._event_shape = array_ops.constant([], dtype=dtypes.int32)
@property
def name(self):
return self._name
@property
def dtype(self):
return self._dtype
@property
def is_reparameterized(self):
return False
def batch_shape(self, name="batch_shape"):
with ops.name_scope(self.name):
with ops.op_scope([self._batch_shape], name):
return array_ops.identity(self._batch_shape)
def get_batch_shape(self):
return self.p.get_shape()
def event_shape(self, name="event_shape"):
with ops.name_scope(self.name):
with ops.op_scope([self._batch_shape], name):
return array_ops.constant([], dtype=self._batch_shape.dtype)
def get_event_shape(self):
return tensor_shape.scalar()
@property
def p(self):
return self._p
@property
def q(self):
"""1-p."""
return self._q
def pmf(self, event, name="pmf"):
"""Probability mass function.
Args:
event: `int32` or `int64` binary Tensor; must be broadcastable with `p`.
name: A name for this operation.
Returns:
The probabilities of the events.
"""
with ops.name_scope(self.name):
with ops.op_scope([self.p, self.q, event], name):
event = ops.convert_to_tensor(event, name="event")
event = math_ops.cast(event, self.p.dtype)
# Since `event` is 0/1-valued, this selects `p` where event == 1 and
# `q` where event == 0.
return event * self.p + (1. - event) * self.q
def log_pmf(self, event, name="log_pmf"):
"""Log of the probability mass function.
Args:
event: `int32` or `int64` binary Tensor.
name: A name for this operation (optional).
Returns:
The log-probabilities of the events.
"""
return super(Bernoulli, self).log_pmf(event, name)
def sample(self, n, seed=None, name="sample"):
"""Generate `n` samples.
Args:
n: scalar. Number of samples to draw from each distribution.
seed: Python integer seed for RNG.
name: name to give to the op.
Returns:
samples: a `Tensor` of shape `(n,) + self.batch_shape` with values of type
`self.dtype`.
"""
with ops.name_scope(self.name):
with ops.op_scope([self.p, n], name):
n = ops.convert_to_tensor(n, name="n")
# Flatten the batch into a column of probabilities.
p_2d = array_ops.reshape(self.p, array_ops.pack([-1, 1]))
q_2d = 1. - p_2d
# Each row holds the class probabilities [q, p]; class 1 is a positive event.
probs = array_ops.concat(1, [q_2d, p_2d])
# Draw n categorical samples per row, then transpose and reshape back to
# (n,) + batch_shape.
samples = random_ops.multinomial(math_ops.log(probs), n, seed=seed)
ret = array_ops.reshape(
    array_ops.transpose(samples),
    array_ops.concat(0,
        [array_ops.expand_dims(n, 0), self.batch_shape()]))
ret.set_shape(tensor_shape.vector(tensor_util.constant_value(n))
.concatenate(self.get_batch_shape()))
return math_ops.cast(ret, self.dtype)
def entropy(self, name="entropy"):
"""Entropy of the distribution.
Args:
name: Name for the op.
Returns:
entropy: `Tensor` of the same type and shape as `p`.
"""
with ops.name_scope(self.name):
with ops.op_scope([self.q, self.p], name):
# Add the smallest representable positive float before taking logs so
# that entropy is 0.0 (rather than nan) at p == 0 and p == 1.
e = array_ops.constant(
    np.finfo(self.p.dtype.as_numpy_dtype).tiny,
    dtype=self.p.dtype)
return (-self.q * math_ops.log(self.q + e) - self.p *
math_ops.log(self.p + e))
def mean(self, name="mean"):
"""Mean of the distribution.
Args:
name: Name for the op.
Returns:
mean: `Tensor` of the same type and shape as `p`.
"""
with ops.name_scope(self.name):
with ops.op_scope([self.p], name):
return array_ops.identity(self.p)
def mode(self, name="mode"):
"""Mode of the distribution.
`1` if `p > 1-p`; `0` otherwise.
Args:
name: Name for the op.
Returns:
mode: binary `Tensor` of type self.dtype.
"""
with ops.name_scope(self.name):
with ops.op_scope([self.p, self.q], name):
return math_ops.cast(self.p > self.q, self.dtype)
def variance(self, name="variance"):
"""Variance of the distribution.
Args:
name: Name for the op.
Returns:
variance: `Tensor` of the same type and shape as `p`.
"""
with ops.name_scope(self.name):
with ops.op_scope([self.p, self.q], name):
return self.q * self.p
def std(self, name="std"):
"""Standard deviation of the distribution.
Args:
name: Name for the op.
Returns:
std: `Tensor` of the same type and shape as `p`.
"""
with ops.name_scope(self.name):
with ops.name_scope(name):
return math_ops.sqrt(self.variance())
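For readers puzzling over the reshape/transpose dance in sample() above: each batch entry is turned into a two-class categorical with probabilities [q, p], n categorical draws are made per entry, and the (batch, n) result is transposed back to (n,) + batch_shape. A rough numpy equivalent, with np.random.choice standing in for random_ops.multinomial (the function name is mine):

import numpy as np

def sample_bernoulli(p, n, rng=np.random):
    p_flat = np.reshape(p, [-1])                     # flatten batch to (k,)
    probs = np.stack([1. - p_flat, p_flat], axis=1)  # (k, 2) rows of [q, p]
    # One categorical draw per (entry, sample); class 1 fires with prob. p.
    draws = np.array([rng.choice(2, size=n, p=row) for row in probs])  # (k, n)
    return draws.T.reshape((n,) + np.shape(p))       # (n,) + batch_shape

Drawing, say, sample_bernoulli([[0.2, 0.7]], 10000).mean(axis=0) should land close to p, mirroring the tolerance check in testSample.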