From cb44b3a001766f7ab633108014c38dadc4adab25 Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Wed, 22 Jun 2016 12:52:24 -0800 Subject: [PATCH] Added Bernoulli distribution. Change: 125604684 --- tensorflow/contrib/distributions/BUILD | 10 + tensorflow/contrib/distributions/__init__.py | 2 + .../python/kernel_tests/bernoulli_test.py | 163 ++++++++++++ .../distributions/python/ops/bernoulli.py | 231 ++++++++++++++++++ 4 files changed, 406 insertions(+) create mode 100644 tensorflow/contrib/distributions/python/kernel_tests/bernoulli_test.py create mode 100644 tensorflow/contrib/distributions/python/ops/bernoulli.py diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 704b5dd2c4b..5cdf4b92e43 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -107,6 +107,16 @@ cuda_py_tests( ], ) +cuda_py_tests( + name = "bernoulli_test", + size = "small", + srcs = ["python/kernel_tests/bernoulli_test.py"], + additional_deps = [ + ":distributions_py", + "//tensorflow/python:platform_test", + ], +) + cuda_py_tests( name = "mvn_test", size = "small", diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index 7e19c82c1e7..edc8c78e099 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -27,6 +27,7 @@ initialized with parameters that define the distributions. 
### Univariate (scalar) distributions +@@Bernoulli @@Categorical @@Chi2 @@Exponential @@ -62,6 +63,7 @@ from __future__ import print_function # pylint: disable=unused-import,wildcard-import,line-too-long +from tensorflow.contrib.distributions.python.ops.bernoulli import * from tensorflow.contrib.distributions.python.ops.categorical import * from tensorflow.contrib.distributions.python.ops.chi2 import * from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import *
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bernoulli_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bernoulli_test.py
new file mode 100644
index 00000000000..def9224e11e
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bernoulli_test.py
@@ -0,0 +1,163 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the Bernoulli distribution."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+
+# Test helper: builds a Bernoulli whose `p` is drawn uniformly at random with
+# shape `batch_shape` and stored as float32; `dtype` is forwarded unchanged to
+# the distribution constructor.
+def make_bernoulli(batch_shape, dtype=tf.int32):
+  p = np.random.uniform(size=list(batch_shape))
+  p = tf.constant(p, dtype=tf.float32)
+  return tf.contrib.distributions.Bernoulli(p, dtype=dtype)
+
+
+# NumPy reference value used by the entropy tests:
+# -q * log(q) - p * log(p) with q = 1 - p. There is no guard for p == 0 or
+# p == 1 here, so the tests below only call it with interior probabilities.
+def entropy(p):
+  q = 1. - p
+  return -q * np.log(q) - p * np.log(p)
+
+
+# Exercises construction, shape/dtype reporting, pmf/log_pmf, entropy,
+# sampling and the moments of tf.contrib.distributions.Bernoulli.
+class BernoulliTest(tf.test.TestCase):
+
+  # The `p` property should evaluate to the values passed to the constructor.
+  def testP(self):
+    p = [0.2, 0.4]
+    dist = tf.contrib.distributions.Bernoulli(p)
+    with self.test_session():
+      self.assertAllClose(p, dist.p.eval())
+
+  # With the default constructor arguments, probabilities outside [0, 1] must
+  # raise an op error once `p` is evaluated, while the boundary values 0 and 1
+  # are accepted.
+  def testInvalidP(self):
+    invalid_ps = [1.01, -0.01, 2., -3.]
+    for p in invalid_ps:
+      with self.test_session():
+        # NOTE(review): "x <= y" presumably matches the message emitted by the
+        # less-equal assertion guarding `p` -- confirm against check_ops.
+        with self.assertRaisesOpError("x <= y"):
+          dist = tf.contrib.distributions.Bernoulli(p)
+          dist.p.eval()
+
+    valid_ps = [0.0, 0.5, 1.0]
+    for p in valid_ps:
+      with self.test_session():
+        dist = tf.contrib.distributions.Bernoulli(p)
+        self.assertEqual(p, dist.p.eval())  # Should not fail
+
+  # Static and dynamic batch shapes must equal the shape of `p`; the event
+  # shape must be empty in both its static and dynamic forms.
+  def testShapes(self):
+    with self.test_session():
+      for batch_shape in ([], [1], [2, 3, 4]):
+        dist = make_bernoulli(batch_shape)
+        self.assertAllEqual(batch_shape, dist.get_batch_shape().as_list())
+        self.assertAllEqual(batch_shape, dist.batch_shape().eval())
+        self.assertAllEqual([], dist.get_event_shape().as_list())
+        self.assertAllEqual([], dist.event_shape().eval())
+
+  # `sample` and `mode` must use the distribution's dtype; mean, variance, std,
+  # entropy, pmf and log_pmf must use the dtype of `p`.
+  def testDtype(self):
+    dist = make_bernoulli([])
+    self.assertEqual(dist.dtype, tf.int32)
+    self.assertEqual(dist.dtype, dist.sample(5).dtype)
+    self.assertEqual(dist.dtype, dist.mode().dtype)
+    self.assertEqual(dist.p.dtype, dist.mean().dtype)
+    self.assertEqual(dist.p.dtype, dist.variance().dtype)
+    self.assertEqual(dist.p.dtype, dist.std().dtype)
+    self.assertEqual(dist.p.dtype, dist.entropy().dtype)
+    self.assertEqual(dist.p.dtype, dist.pmf(0).dtype)
+    self.assertEqual(dist.p.dtype, dist.log_pmf(0).dtype)
+
+    # Same checks for the sample-typed outputs with an explicit int64 dtype.
+    dist64 = make_bernoulli([], tf.int64)
+    self.assertEqual(dist64.dtype, tf.int64)
+    self.assertEqual(dist64.dtype, dist64.sample(5).dtype)
+    self.assertEqual(dist64.dtype, dist64.mode().dtype)
+
+  # Events of several ranks (scalar, vector, row matrix, full matrix) are
+  # evaluated against a 2x2 `p`; every expected result below is 2x2.
+  def testPmf(self):
+    p = [[0.2, 0.4], [0.3, 0.6]]
+    dist = tf.contrib.distributions.Bernoulli(p)
+    with self.test_session():
+      # pylint: disable=bad-continuation
+      xs = [
+          0,
+          [1],
+          [1, 0],
+          [[1, 0]],
+          [[1, 0], [1, 1]],
+      ]
+      expected_pmfs = [
+          [[0.8, 0.6], [0.7, 0.4]],
+          [[0.2, 0.4], [0.3, 0.6]],
+          [[0.2, 0.6], [0.3, 0.4]],
+          [[0.2, 0.6], [0.3, 0.4]],
+          [[0.2, 0.6], [0.3, 0.6]],
+      ]
+      # pylint: enable=bad-continuation
+
+      # log_pmf is checked against the log of the same expected table.
+      for x, expected_pmf in zip(xs, expected_pmfs):
+        self.assertAllClose(dist.pmf(x).eval(), expected_pmf)
+        self.assertAllClose(dist.log_pmf(x).eval(), np.log(expected_pmf))
+
+  # With p == 1 the impossible event 0 has log-probability -inf and the
+  # certain event 1 has log-probability 0.
+  def testBoundaryConditions(self):
+    with self.test_session():
+      dist = tf.contrib.distributions.Bernoulli(1.0)
+      self.assertEqual(-np.inf, dist.log_pmf(0).eval())
+      self.assertAllClose([0.0], [dist.log_pmf(1).eval()])
+
+  # Scalar `p`: entropy must match the NumPy reference.
+  def testEntropyNoBatch(self):
+    p = 0.2
+    dist = tf.contrib.distributions.Bernoulli(p)
+    with self.test_session():
+      self.assertAllClose(dist.entropy().eval(), entropy(p))
+
+  # Batched `p` that includes the degenerate values 0 and 1, for which the
+  # expected entropy is written out as 0.0 instead of using the reference.
+  def testEntropyWithBatch(self):
+    p = [[0.0, 0.7], [1.0, 0.6]]
+    dist = tf.contrib.distributions.Bernoulli(p, strict=False)
+    with self.test_session():
+      self.assertAllClose(dist.entropy().eval(), [[0.0, entropy(0.7)],
+                                                  [0.0, entropy(0.6)]])
+
+  # Draws 1000 samples per distribution with a fixed seed and checks the dtype,
+  # that only the values 0 and 1 occur (and both do occur), and that the
+  # per-distribution frequency of 1 is within atol=1e-2 of `p`.
+  def testSample(self):
+    with self.test_session():
+      p = [0.2, 0.6]
+      dist = tf.contrib.distributions.Bernoulli(p)
+      n = 1000
+      samples = dist.sample(n, seed=123)
+      samples.set_shape([n, 2])
+      self.assertEqual(samples.dtype, tf.int32)
+      sample_values = samples.eval()
+      self.assertFalse(np.any(sample_values < 0))
+      self.assertFalse(np.any(sample_values > 1))
+      self.assertAllClose(p, np.mean(sample_values == 1, axis=0), atol=1e-2)
+      self.assertEqual(set([0, 1]), set(sample_values.flatten()))
+
+  # The mean must equal `p` exactly (assertAllEqual, not assertAllClose).
+  def testMean(self):
+    with self.test_session():
+      p = np.array([[0.2, 0.7], [0.5, 0.4]], dtype=np.float32)
+      dist = tf.contrib.distributions.Bernoulli(p)
+      self.assertAllEqual(dist.mean().eval(), p)
+
+  # Variance must be p * (1 - p) and std its square root, element-wise.
+  def testVarianceAndStd(self):
+    var = lambda p: p * (1. - p)
+    with self.test_session():
+      p = [[0.2, 0.7], [0.5, 0.4]]
+      dist = tf.contrib.distributions.Bernoulli(p)
+      self.assertAllClose(dist.variance().eval(),
+                          np.array([[var(0.2), var(0.7)], [var(0.5), var(0.4)]],
+                                   dtype=np.float32))
+      self.assertAllClose(dist.std().eval(),
+                          np.array([[np.sqrt(var(0.2)), np.sqrt(var(0.7))],
+                                    [np.sqrt(var(0.5)), np.sqrt(var(0.4))]],
+                                   dtype=np.float32))
+
+
+if __name__ == "__main__":
+  tf.test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/bernoulli.py b/tensorflow/contrib/distributions/python/ops/bernoulli.py
new file mode 100644
index 00000000000..79a0852d964
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/bernoulli.py
@@ -0,0 +1,231 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The Bernoulli distribution class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.distributions.python.ops import distribution
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+
+
+class Bernoulli(distribution.DiscreteDistribution):
+  """Bernoulli distribution.
+
+  The Bernoulli distribution is parameterized by p, the probability of a
+  positive event.
+
+  Note, the following methods of the base class aren't implemented:
+    * cdf
+    * log_cdf
+  """
+
+  def __init__(self, p, dtype=dtypes.int32, strict=True, name="Bernoulli"):
+    """Construct Bernoulli distributions.
+
+    Args:
+      p: An N-D `Tensor` representing the probability of a positive
+          event. Each entry in the `Tensor` parameterizes an independent
+          Bernoulli distribution.
+      dtype: dtype for samples and for `mode`. The other values (`mean`,
+        `variance`, `std`, `entropy`, `pmf`) take the dtype of `p`.
+      strict: Whether to assert that `0 <= p <= 1`. If not strict, `log_pmf` may
+        return nans.
+      name: A name for this distribution.
+    """
+    self._name = name
+    self._dtype = dtype
+    self._strict = strict
+    check_op = check_ops.assert_less_equal
+    with ops.op_scope([p], name):
+      # When strict, the stored `p` depends on both range assertions, so any
+      # value derived from it triggers the checks at run time.
+      with ops.control_dependencies(
+          [check_op(p, 1.), check_op(0., p)] if strict else []):
+        p = array_ops.identity(p, name="p")
+      self._p = p
+      self._q = array_ops.identity(1. - p, name="q")
+      self._batch_shape = array_ops.shape(self._p)
+      self._event_shape = array_ops.constant([], dtype=dtypes.int32)
+
+  @property
+  def name(self):
+    """Name given to this distribution; used as the outer name scope."""
+    return self._name
+
+  @property
+  def dtype(self):
+    """dtype of the tensors returned by `sample` and `mode`."""
+    return self._dtype
+
+  @property
+  def is_reparameterized(self):
+    """Always `False` for this distribution."""
+    return False
+
+  def batch_shape(self, name="batch_shape"):
+    """Dynamic shape of `p`, as a 1-D `Tensor`.
+
+    Args:
+      name: A name for this operation.
+
+    Returns:
+      `Tensor` holding the run-time shape of `p`.
+    """
+    with ops.name_scope(self.name):
+      with ops.op_scope([self._batch_shape], name):
+        return array_ops.identity(self._batch_shape)
+
+  def get_batch_shape(self):
+    """Static shape of `p`, as a `TensorShape`."""
+    return self.p.get_shape()
+
+  def event_shape(self, name="event_shape"):
+    """Dynamic shape of a single event; always an empty `int32` `Tensor`.
+
+    Args:
+      name: A name for this operation.
+
+    Returns:
+      Empty 1-D `int32` `Tensor` (each event is a scalar).
+    """
+    with ops.name_scope(self.name):
+      # Mirror `batch_shape`: return the tensor created in the constructor
+      # instead of building a new constant scoped on the batch shape.
+      with ops.op_scope([self._event_shape], name):
+        return array_ops.identity(self._event_shape)
+
+  def get_event_shape(self):
+    """Static shape of a single event: the scalar `TensorShape`."""
+    return tensor_shape.scalar()
+
+  @property
+  def p(self):
+    """Probability of a positive event."""
+    return self._p
+
+  @property
+  def q(self):
+    """1-p."""
+    return self._q
+
+  def pmf(self, event, name="pmf"):
+    """Probability mass function.
+
+    Args:
+      event: `int32` or `int64` binary Tensor; must be broadcastable with `p`.
+      name: A name for this operation.
+
+    Returns:
+      The probabilities of the events.
+    """
+    with ops.name_scope(self.name):
+      with ops.op_scope([self.p, self.q, event], name):
+        event = ops.convert_to_tensor(event, name="event")
+        event = math_ops.cast(event, self.p.dtype)
+        # For binary events this selects p where event == 1 and q where
+        # event == 0.
+        return event * self.p + (1. - event) * self.q
+
+  def log_pmf(self, event, name="log_pmf"):
+    """Log of the probability mass function.
+
+    Args:
+      event: `int32` or `int64` binary Tensor.
+      name: A name for this operation (optional).
+
+    Returns:
+      The log-probabilities of the events.
+    """
+    return super(Bernoulli, self).log_pmf(event, name)
+
+  def sample(self, n, seed=None, name="sample"):
+    """Generate `n` samples.
+
+    Args:
+      n: scalar. Number of samples to draw from each distribution.
+      seed: Python integer seed for RNG.
+      name: name to give to the op.
+
+    Returns:
+      samples: a `Tensor` of shape `(n,) + self.batch_shape` with values of type
+          `self.dtype`.
+    """
+    with ops.name_scope(self.name):
+      with ops.op_scope([self.p, n], name):
+        n = ops.convert_to_tensor(n, name="n")
+        # Flatten the batch so that every row is one distribution's [q, p]
+        # pair; column index 1 then corresponds to the positive event.
+        p_2d = array_ops.reshape(self.p, array_ops.pack([-1, 1]))
+        q_2d = 1. - p_2d
+        probs = array_ops.concat(1, [q_2d, p_2d])
+        samples = random_ops.multinomial(math_ops.log(probs), n, seed=seed)
+        # Transpose so the sample dimension comes first, then restore the
+        # batch dimensions.
+        ret = array_ops.reshape(
+            array_ops.transpose(samples),
+            array_ops.concat(0,
+                             [array_ops.expand_dims(n, 0), self.batch_shape()]))
+        # Attach whatever static shape information is available.
+        ret.set_shape(tensor_shape.vector(tensor_util.constant_value(n))
+                      .concatenate(self.get_batch_shape()))
+        return math_ops.cast(ret, self.dtype)
+
+  def entropy(self, name="entropy"):
+    """Entropy of the distribution.
+
+    Args:
+      name: Name for the op.
+
+    Returns:
+      entropy: `Tensor` of the same type and shape as `p`.
+    """
+    with ops.name_scope(self.name):
+      with ops.op_scope([self.q, self.p], name):
+        # Smallest positive normal value of p's dtype, added inside the logs
+        # so that p == 0 or q == 0 contributes 0 instead of 0 * log(0).
+        e = array_ops.constant(
+            np.finfo(self.p.dtype.as_numpy_dtype).tiny,
+            dtype=self.p.dtype)
+        return (-self.q * math_ops.log(self.q + e) - self.p *
+                math_ops.log(self.p + e))
+
+  def mean(self, name="mean"):
+    """Mean of the distribution.
+
+    Args:
+      name: Name for the op.
+
+    Returns:
+      mean: `Tensor` of the same type and shape as `p`.
+    """
+    with ops.name_scope(self.name):
+      with ops.op_scope([self.p], name):
+        return array_ops.identity(self.p)
+
+  def mode(self, name="mode"):
+    """Mode of the distribution.
+
+    1 if p > 1-p. 0 otherwise (including the tie p == 1-p).
+
+    Args:
+      name: Name for the op.
+
+    Returns:
+      mode: binary `Tensor` of type self.dtype.
+    """
+    with ops.name_scope(self.name):
+      with ops.op_scope([self.p, self.q], name):
+        return math_ops.cast(self.p > self.q, self.dtype)
+
+  def variance(self, name="variance"):
+    """Variance of the distribution.
+
+    Args:
+      name: Name for the op.
+
+    Returns:
+      variance: `Tensor` of the same type and shape as `p`.
+    """
+    with ops.name_scope(self.name):
+      with ops.op_scope([self.p, self.q], name):
+        return self.q * self.p
+
+  def std(self, name="std"):
+    """Standard deviation of the distribution.
+
+    Args:
+      name: Name for the op.
+
+    Returns:
+      std: `Tensor` of the same type and shape as `p`.
+    """
+    with ops.name_scope(self.name):
+      # Use op_scope like the other methods so the op is created in (and
+      # validated against) the graph that owns `p` and `q`.
+      with ops.op_scope([self.p, self.q], name):
+        return math_ops.sqrt(self.variance())