From a0d14f00b9fd5b55f80ed5a658b37f4a5c80022c Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Mon, 18 Apr 2016 09:47:02 -0800
Subject: [PATCH] Adding DirichletMultinomial class to contrib/distributions/

Class represents multi-indexed batches of Dirichlet Multinomial
distributions.  Initialized with parameters alpha, which broadcast to
arbitrary shapes to match arguments in e.g. dist.pdf(x).

Change: 120138028
---
 tensorflow/contrib/distributions/BUILD       |  12 +-
 tensorflow/contrib/distributions/__init__.py |   2 +-
 .../dirichlet_multinomial_test.py            | 192 +++++++++++++
 .../python/ops/dirichlet_multinomial.py      | 261 ++++++++++++++++++
 4 files changed, 465 insertions(+), 2 deletions(-)
 create mode 100644 tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py
 create mode 100644 tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py

diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 1d7d01fdae2..5feac79ecb0 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -1,5 +1,5 @@
 # Description:
-# Contains ops to train linear models on top of TensorFlow.
+# Contains ops for statistical distributions (with pdf, cdf, sample, etc...).
 # APIs here are meant to evolve over time.
 
 licenses(["notice"])  # Apache 2.0
@@ -16,6 +16,16 @@ py_library(
     srcs_version = "PY2AND3",
 )
 
+cuda_py_tests(
+    name = "dirichlet_multinomial_test",
+    srcs = ["python/kernel_tests/dirichlet_multinomial_test.py"],
+    additional_deps = [
+        ":distributions_py",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
 cuda_py_tests(
     name = "gaussian_test",
     size = "small",
diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py
index 46aae254a7a..2f9b8fcafb1 100644
--- a/tensorflow/contrib/distributions/__init__.py
+++ b/tensorflow/contrib/distributions/__init__.py
@@ -23,6 +23,6 @@ from __future__ import print_function
 # pylint: disable=unused-import,wildcard-import, line-too-long
 from tensorflow.contrib.distributions.python.ops import gaussian_conjugate_posteriors
+from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import *
 from tensorflow.contrib.distributions.python.ops.gaussian import *
 # from tensorflow.contrib.distributions.python.ops.dirichlet import *  # pylint: disable=line-too-long
-# from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import *  # pylint: disable=line-too-long
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py b/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py
new file mode 100644
index 00000000000..c83beddab4d
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py
@@ -0,0 +1,192 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+
+class DirichletMultinomialTest(tf.test.TestCase):
+
+  def test_num_classes(self):
+    with self.test_session():
+      for num_classes in range(3):
+        alpha = np.random.rand(3, num_classes)
+        dist = tf.contrib.distributions.DirichletMultinomial(alpha)
+        self.assertEqual([], dist.num_classes.get_shape())
+        self.assertEqual(num_classes, dist.num_classes.eval())
+
+  def test_alpha_property(self):
+    alpha = np.array([[1., 2, 3]])
+    with self.test_session():
+      dist = tf.contrib.distributions.DirichletMultinomial(alpha)
+      self.assertEqual([1, 3], dist.alpha.get_shape())
+      self.assertAllClose(alpha, dist.alpha.eval())
+
+  def test_empty_alpha_and_empty_counts_returns_empty(self):
+    with self.test_session():
+      alpha = [[]]
+      counts = [[]]
+      dist = tf.contrib.distributions.DirichletMultinomial(alpha)
+      self.assertAllEqual([], dist.pmf(counts).eval())
+      self.assertAllEqual([0], dist.pmf(counts).get_shape())
+      self.assertAllEqual([], dist.log_pmf(counts).eval())
+      self.assertAllEqual([0], dist.log_pmf(counts).get_shape())
+      self.assertAllEqual([[]], dist.mean.eval())
+      self.assertAllEqual([1, 0], dist.mean.get_shape())
+      self.assertAllEqual(0, dist.num_classes.eval())
+      self.assertAllEqual([], dist.num_classes.get_shape())
+
+  def test_pmf_both_zero_batches(self):
+    # The probability of one vote falling into class k is the mean for
+    # class k.
+    with self.test_session():
+      # Both zero-batches.  No broadcasting.
+      alpha = [1., 2]
+      counts = [1, 0.]
+      dist = tf.contrib.distributions.DirichletMultinomial(alpha)
+      pmf = dist.pmf(counts)
+      self.assertAllClose(1 / 3., pmf.eval())
+      self.assertEqual((), pmf.get_shape())
+
+  def test_pmf_alpha_stretched_in_broadcast_when_same_rank(self):
+    # The probability of one vote falling into class k is the mean for
+    # class k.
+    with self.test_session():
+      alpha = [[1., 2]]
+      counts = [[1, 0.], [0, 1.]]
+      dist = tf.contrib.distributions.DirichletMultinomial(alpha)
+      pmf = dist.pmf(counts)
+      self.assertAllClose([1 / 3., 2 / 3.], pmf.eval())
+      self.assertEqual((2,), pmf.get_shape())
+
+  def test_pmf_alpha_stretched_in_broadcast_when_lower_rank(self):
+    # The probability of one vote falling into class k is the mean for
+    # class k.
+    with self.test_session():
+      alpha = [1., 2]
+      counts = [[1, 0.], [0, 1.]]
+      pmf = tf.contrib.distributions.DirichletMultinomial(alpha).pmf(counts)
+      self.assertAllClose([1 / 3., 2 / 3.], pmf.eval())
+      self.assertEqual((2,), pmf.get_shape())
+
+  def test_pmf_counts_stretched_in_broadcast_when_same_rank(self):
+    # The probability of one vote falling into class k is the mean for
+    # class k.
+    with self.test_session():
+      alpha = [[1., 2], [2., 3]]
+      counts = [[1, 0.]]
+      pmf = tf.contrib.distributions.DirichletMultinomial(alpha).pmf(counts)
+      self.assertAllClose([1 / 3., 2 / 5.], pmf.eval())
+      self.assertEqual((2,), pmf.get_shape())
+
+  def test_pmf_counts_stretched_in_broadcast_when_lower_rank(self):
+    # The probability of one vote falling into class k is the mean for
+    # class k.
+    with self.test_session():
+      alpha = [[1., 2], [2., 3]]
+      counts = [1, 0.]
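+      # For a single draw, the pmf of the one-hot counts vector e_k reduces
+      # to alpha_k / sum(alpha), i.e. the class mean; after broadcasting
+      # counts against both alpha rows this gives [1/3, 2/5] below.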
+      pmf = tf.contrib.distributions.DirichletMultinomial(alpha).pmf(counts)
+      self.assertAllClose([1 / 3., 2 / 5.], pmf.eval())
+      self.assertEqual((2,), pmf.get_shape())
+
+  def test_pmf_for_one_vote_is_the_mean_with_one_record_input(self):
+    # The probability of one vote falling into class k is the mean for
+    # class k.
+    alpha = [1., 2, 3]
+    with self.test_session():
+      for class_num in range(3):
+        counts = np.zeros((3), dtype=np.float32)
+        counts[class_num] = 1.0
+        dist = tf.contrib.distributions.DirichletMultinomial(alpha)
+        mean = dist.mean.eval()
+        pmf = dist.pmf(counts).eval()
+
+        self.assertAllClose(mean[class_num], pmf)
+        self.assertTupleEqual((3,), mean.shape)
+        self.assertTupleEqual((), pmf.shape)
+
+  def test_zero_counts_results_in_pmf_equal_to_one(self):
+    # There is only one way for zero items to be selected, and this happens
+    # with probability 1.
+    alpha = [5, 0.5]
+    counts = [0., 0.]
+    with self.test_session():
+      dist = tf.contrib.distributions.DirichletMultinomial(alpha)
+      pmf = dist.pmf(counts)
+      self.assertAllClose(1.0, pmf.eval())
+      self.assertEqual((), pmf.get_shape())
+
+  def test_large_tau_gives_precise_probabilities(self):
+    # If tau is large, the Dirichlet prior concentrates tightly at mu, so we
+    # are effectively doing independent coin flips with probabilities mu.
+    mu = np.array([0.1, 0.1, 0.8], dtype=np.float32)
+    tau = np.array([100.], dtype=np.float32)
+    alpha = tau * mu
+
+    # One (three-sided) coin flip.  Prob[coin 3] = 0.8.
+    # Note that for a single flip the value of tau does not matter.
+    counts = [0., 0, 1]
+    with self.test_session():
+      dist = tf.contrib.distributions.DirichletMultinomial(alpha)
+      pmf = dist.pmf(counts)
+      self.assertAllClose(0.8, pmf.eval(), atol=1e-4)
+      self.assertEqual((), pmf.get_shape())
+
+    # Two (three-sided) coin flips.  Prob[coin 3] = 0.8.
+    counts = [0., 0, 2]
+    with self.test_session():
+      dist = tf.contrib.distributions.DirichletMultinomial(alpha)
+      pmf = dist.pmf(counts)
+      self.assertAllClose(0.8**2, pmf.eval(), atol=1e-2)
+      self.assertEqual((), pmf.get_shape())
+
+    # Three (three-sided) coin flips.
+    counts = [1., 0, 2]
+    with self.test_session():
+      dist = tf.contrib.distributions.DirichletMultinomial(alpha)
+      pmf = dist.pmf(counts)
+      self.assertAllClose(3 * 0.1 * 0.8 * 0.8, pmf.eval(), atol=1e-2)
+      self.assertEqual((), pmf.get_shape())
+
+  def test_small_tau_prefers_correlated_results(self):
+    # If tau is small, the correlation between draws is large, so draws that
+    # are both of the same class are more likely.
+    mu = np.array([0.5, 0.5], dtype=np.float32)
+    tau = np.array([0.1], dtype=np.float32)
+    alpha = tau * mu
+
+    # If there is only one draw, it is still a fair coin flip, even with
+    # small tau.
+    counts = [1, 0.]
+    with self.test_session():
+      dist = tf.contrib.distributions.DirichletMultinomial(alpha)
+      pmf = dist.pmf(counts)
+      self.assertAllClose(0.5, pmf.eval())
+      self.assertEqual((), pmf.get_shape())
+
+    # If there are two draws, it is much more likely that they fall in the
+    # same class.
+    counts_same = [2, 0.]
+    counts_different = [1, 1.]
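+    # By-hand cross-check: with alpha = [0.05, 0.05] and alpha_0 = sum(alpha),
+    # the compound pmf gives
+    #   P([2, 0]) = alpha_1 (alpha_1 + 1) / (alpha_0 (alpha_0 + 1)) ~= 0.477,
+    #   P([1, 1]) = 2 alpha_1 alpha_2 / (alpha_0 (alpha_0 + 1))     ~= 0.045,
+    # so pmf_same should exceed 5 * pmf_different.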
+    with self.test_session():
+      dist = tf.contrib.distributions.DirichletMultinomial(alpha)
+      pmf_same = dist.pmf(counts_same)
+      pmf_different = dist.pmf(counts_different)
+      self.assertLess(5 * pmf_different.eval(), pmf_same.eval())
+      self.assertEqual((), pmf_same.get_shape())
+
+
+if __name__ == '__main__':
+  tf.test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py b/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py
new file mode 100644
index 00000000000..358af118255
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py
@@ -0,0 +1,261 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The Dirichlet Multinomial distribution class.
+
+@@DirichletMultinomial
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import special_math_ops
+
+
+def _check_alpha(alpha):
+  """Check alpha for proper shape, values, then return tensor version."""
+  alpha = ops.convert_to_tensor(alpha, name='alpha_before_deps')
+  return control_flow_ops.with_dependencies(
+      [check_ops.assert_rank_at_least(alpha, 1),
+       check_ops.assert_positive(alpha)], alpha)
+
+
+def _log_combinations(counts, name='log_combinations'):
+  """Log number of ways counts could have come in."""
+  # First, a bit about the number of ways counts could have come in:
+  # E.g. if counts = [1, 2], then this is 3 choose 2.
+  # In general, this is (sum counts)! / prod(counts!)
+  # The sum should be along the last dimension of counts.  This is the
+  # "distribution" dimension.
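+  # For example, _log_combinations([1., 2.]) = log(3! / (1! * 2!)) = log(3),
+  # the number of orderings in which one class-1 draw and two class-2 draws
+  # can occur.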
+  with ops.op_scope([counts], name):
+    last_dim = array_ops.rank(counts) - 1
+    # To compute factorials, use the fact that Gamma(n + 1) = n!
+    # Compute two terms, each a sum over counts.  Compute each for each
+    # batch member.
+    # Log Gamma((sum counts) + 1) = Log((sum counts)!)
+    sum_of_counts = math_ops.reduce_sum(counts, reduction_indices=last_dim)
+    total_permutations = math_ops.lgamma(sum_of_counts + 1)
+    # sum(Log Gamma(counts + 1)) = Log prod(counts!)
+    counts_factorial = math_ops.lgamma(counts + 1)
+    redundant_permutations = math_ops.reduce_sum(counts_factorial,
+                                                 reduction_indices=last_dim)
+    return total_permutations - redundant_permutations
+
+
+class DirichletMultinomial(object):
+  """DirichletMultinomial mixture distribution.
+
+  The Dirichlet Multinomial is a distribution over k-class count data,
+  meaning for each k-tuple of non-negative integer `counts = [c_1,...,c_k]`,
+  we have a probability of these draws being made from the distribution.  The
+  distribution has hyperparameters `alpha = (alpha_1,...,alpha_k)`, and
+  probability mass function (pmf):
+
+  ```pmf(counts) = C! / (c_1!...c_k!) * Beta(alpha + c) / Beta(alpha)```
+
+  where above `C = sum_j c_j`, `N!` is `N` factorial, and
+  `Beta(x) = prod_j Gamma(x_j) / Gamma(sum_j x_j)` is the multivariate beta
+  function.
+
+  This is a mixture distribution in that `N` samples can be produced by:
+    1. Choose class probabilities `p = (p_1,...,p_k) ~ Dir(alpha)`
+    2. Draw integers `m = (m_1,...,m_k) ~ Multinomial(p, N)`
+
+  This class provides methods to create indexed batches of Dirichlet
+  Multinomial distributions.  If the provided `alpha` is rank 2 or higher,
+  for every fixed set of leading dimensions, the last dimension represents
+  one single Dirichlet Multinomial distribution.  When calling distribution
+  functions (e.g. `dist.pdf(counts)`), `alpha` and `counts` are broadcast to
+  the same shape (if possible).  In all cases, the last dimension of
+  alpha/counts represents single Dirichlet Multinomial distributions.
+
+  Examples:
+
+  ```python
+  alpha = [1, 2, 3]
+  dist = DirichletMultinomial(alpha)
+  ```
+
+  Creates a 3-class distribution, with the 3rd class most likely to be drawn.
+  The distribution functions can be evaluated on counts.
+
+  ```python
+  # counts same shape as alpha.
+  counts = [0, 2, 0]
+  dist.pdf(counts)  # Shape []
+
+  # alpha will be broadcast to [[1, 2, 3], [1, 2, 3]] to match counts.
+  counts = [[11, 22, 33], [44, 55, 66]]
+  dist.pdf(counts)  # Shape [2]
+
+  # alpha will be broadcast to shape [5, 7, 3] to match counts.
+  counts = [[...]]  # Shape [5, 7, 3]
+  dist.pdf(counts)  # Shape [5, 7]
+  ```
+
+  Creates a 2-batch of 3-class distributions.
+
+  ```python
+  alpha = [[1, 2, 3], [4, 5, 6]]  # Shape [2, 3]
+  dist = DirichletMultinomial(alpha)
+
+  # counts will be broadcast to [[11, 22, 33], [11, 22, 33]] to match alpha.
+  counts = [11, 22, 33]
+  dist.pdf(counts)  # Shape [2]
+  ```
+  """
+
+  # TODO(b/27419586) Change docstring for dtype of alpha once int allowed.
+  def __init__(self, alpha):
+    """Initialize a batch of DirichletMultinomial distributions.
+
+    Args:
+      alpha:  Shape `[N1,..., Nn, k]` positive `float` or `double` tensor
+        with `n >= 0`.  Defines this as a batch of `N1 x ... x Nn` different
+        `k` class Dirichlet multinomial distributions.
+
+    Examples:
+
+    ```python
+    # Define 1-batch of 2-class Dirichlet multinomial distribution,
+    # also known as a beta-binomial.
+    dist = DirichletMultinomial([1.1, 2.0])
+
+    # Define a 2-batch of 3-class distributions.
+    dist = DirichletMultinomial([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+    ```
+    """
+    # Broadcasting works because:
+    # * The broadcasting convention is to prepend dimensions of size [1],
+    #   and we use the last dimension for the distribution, whereas the
+    #   batch dimensions are the leading dimensions.  This forces the
+    #   distribution dimension to be defined explicitly (i.e. it cannot be
+    #   created automatically by prepending).
+    # * All calls involving `counts` eventually require a broadcast between
+    #   `counts` and `alpha`.
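+    # E.g. alpha of shape [2, 3] broadcasts against counts of shape [3] to
+    # give shape [2, 3], but a trailing distribution dimension can never be
+    # created implicitly.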
+    self._alpha = _check_alpha(alpha)
+
+    self._num_classes = self._get_num_classes()
+    self._dist_indices = self._get_dist_indices()
+
+  @property
+  def alpha(self):
+    """Parameters defining this distribution."""
+    return self._alpha
+
+  @property
+  def dtype(self):
+    return self._alpha.dtype
+
+  @property
+  def mean(self):
+    """Class means for every batch member."""
+    with ops.name_scope('mean'):
+      alpha_sum = math_ops.reduce_sum(self._alpha,
+                                      reduction_indices=self._dist_indices,
+                                      keep_dims=True)
+      mean = math_ops.truediv(self._alpha, alpha_sum)
+      mean.set_shape(self._alpha.get_shape())
+      return mean
+
+  def _get_dist_indices(self):
+    """Dimensions corresponding to individual distributions."""
+    # Reshape the scalar to a rank 1 tensor.
+    return array_ops.reshape(array_ops.rank(self._alpha) - 1, [-1])
+
+  def _get_num_classes(self):
+    return ops.convert_to_tensor(
+        array_ops.reverse(
+            array_ops.shape(self._alpha), [True])[0],
+        name='num_classes')
+
+  @property
+  def num_classes(self):
+    """Tensor providing number of classes in each batch member."""
+    return self._num_classes
+
+  def cdf(self, x):
+    raise NotImplementedError(
+        'DirichletMultinomial does not have a well-defined cdf.')
+
+  def log_cdf(self, x):
+    raise NotImplementedError(
+        'DirichletMultinomial does not have a well-defined cdf.')
+
+  def log_pmf(self, counts, name=None):
+    """`Log(P[counts])`, computed for every batch member.
+
+    For each batch of counts `[c_1,...,c_k]`, `P[counts]` is the probability
+    that after sampling `sum_j c_j` draws from this Dirichlet Multinomial
+    distribution, the number of draws falling in class `j` is `c_j`.  Note
+    that different sequences of draws can result in the same counts, thus the
+    probability includes a combinatorial coefficient.
+
+    Args:
+      counts:  Non-negative `float`, `double`, or `int` tensor whose shape
+        can be broadcast with `self.alpha`.  For fixed leading dimensions,
+        the last dimension represents counts for the corresponding Dirichlet
+        Multinomial distribution in `self.alpha`.
+      name:  Name to give this Op, defaults to "log_pmf".
+
+    Returns:
+      Log probabilities for each record, shape `[N1,...,Nn]`.
+    """
+    alpha = self._alpha
+    with ops.op_scope([alpha, counts], name, 'log_pmf'):
+      counts = self._check_counts(counts)
+      ordered_pmf = (special_math_ops.lbeta(alpha + counts) -
+                     special_math_ops.lbeta(alpha))
+      log_pmf = ordered_pmf + _log_combinations(counts)
+      # If alpha = counts = [[]], ordered_pmf carries the right shape, which
+      # is [].  However, since reduce_sum([[]]) = [0], log_combinations = [0],
+      # which is not correct.  Luckily, [] + [0] = [], so the sum is fine, but
+      # shape must be inferred from ordered_pmf.
+      # Note also that tf.constant([]).get_shape() =
+      # TensorShape([Dimension(0)])
+      log_pmf.set_shape(ordered_pmf.get_shape())
+      return log_pmf
+
+  def pmf(self, counts, name=None):
+    """`P[counts]`, computed for every batch member.
+
+    For each batch of counts `[c_1,...,c_k]`, `P[counts]` is the probability
+    that after sampling `sum_j c_j` draws from this Dirichlet Multinomial
+    distribution, the number of draws falling in class `j` is `c_j`.  Note
+    that different sequences of draws can result in the same counts, thus the
+    probability includes a combinatorial coefficient.
+
+    Args:
+      counts:  Non-negative `float`, `double`, or `int` tensor whose shape
+        can be broadcast with `self.alpha`.  For fixed leading dimensions,
+        the last dimension represents counts for the corresponding Dirichlet
+        Multinomial distribution in `self.alpha`.
+      name:  Name to give this Op, defaults to "pmf".
+
+    Returns:
+      Probabilities for each record, shape `[N1,...,Nn]`.
+    """
+    with ops.name_scope('pmf' if name is None else name):
+      return math_ops.exp(self.log_pmf(counts))
+
+  def _check_counts(self, counts):
+    """Check counts for proper shape, values, then return tensor version."""
+    counts = ops.convert_to_tensor(counts, name='counts_before_deps')
+    counts = math_ops.cast(counts, self.dtype)
+    return control_flow_ops.with_dependencies(
+        [check_ops.assert_non_negative(counts)], counts)
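
A quick usage sketch (not part of the patch itself; it assumes the patch is
applied and uses the graph/session API of this TensorFlow vintage, with
illustrative alpha and counts values):

```python
import numpy as np
import tensorflow as tf

# One 3-class Dirichlet Multinomial; alpha acts as a vector of pseudo-counts.
alpha = np.array([1., 2., 3.], dtype=np.float32)
dist = tf.contrib.distributions.DirichletMultinomial(alpha)

# Three total draws: two in class 2 and one in class 3.
counts = np.array([0., 2., 1.], dtype=np.float32)

with tf.Session() as sess:
  pmf, mean = sess.run([dist.pmf(counts), dist.mean])
  print(pmf)   # Scalar probability of exactly these counts.
  print(mean)  # alpha / sum(alpha) = [1/6, 1/3, 1/2].
```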