Adding DirichletMultinomial class to contrib/distributions/

The class represents multi-indexed batches of Dirichlet Multinomial distributions. It is initialized with parameter alpha, which is broadcast to arbitrary shapes to match arguments in e.g. dist.pmf(counts).
Change: 120138028
A. Unique TensorFlower 2016-04-18 09:47:02 -08:00 committed by TensorFlower Gardener
parent 3402f51ecd
commit a0d14f00b9
4 changed files with 465 additions and 2 deletions

tensorflow/contrib/distributions/BUILD

@@ -1,5 +1,5 @@
# Description:
# Contains ops to train linear models on top of TensorFlow.
# Contains ops for statistical distributions (with pdf, cdf, sample, etc...).
# APIs here are meant to evolve over time.
licenses(["notice"]) # Apache 2.0
@@ -16,6 +16,16 @@ py_library(
srcs_version = "PY2AND3",
)
cuda_py_tests(
name = "dirichlet_multinomial_test",
srcs = ["python/kernel_tests/dirichlet_multinomial_test.py"],
additional_deps = [
":distributions_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
cuda_py_tests(
name = "gaussian_test",
size = "small",

tensorflow/contrib/distributions/__init__.py

@@ -23,6 +23,6 @@ from __future__ import print_function
# pylint: disable=unused-import,wildcard-import,line-too-long
from tensorflow.contrib.distributions.python.ops import gaussian_conjugate_posteriors
from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import *
from tensorflow.contrib.distributions.python.ops.gaussian import *
# from tensorflow.contrib.distributions.python.ops.dirichlet import * # pylint: disable=line-too-long
# from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import * # pylint: disable=line-too-long

tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py

@@ -0,0 +1,192 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
class DirichletMultinomialTest(tf.test.TestCase):
def test_num_classes(self):
with self.test_session():
for num_classes in range(3):
alpha = np.random.rand(3, num_classes)
dist = tf.contrib.distributions.DirichletMultinomial(alpha)
self.assertEqual([], dist.num_classes.get_shape())
self.assertEqual(num_classes, dist.num_classes.eval())
def test_alpha_property(self):
alpha = np.array([[1., 2, 3]])
with self.test_session():
dist = tf.contrib.distributions.DirichletMultinomial(alpha)
self.assertEqual([1, 3], dist.alpha.get_shape())
self.assertAllClose(alpha, dist.alpha.eval())
def test_empty_alpha_and_empty_counts_returns_empty(self):
with self.test_session():
alpha = [[]]
counts = [[]]
dist = tf.contrib.distributions.DirichletMultinomial(alpha)
self.assertAllEqual([], dist.pmf(counts).eval())
self.assertAllEqual([0], dist.pmf(counts).get_shape())
self.assertAllEqual([], dist.log_pmf(counts).eval())
self.assertAllEqual([0], dist.log_pmf(counts).get_shape())
self.assertAllEqual([[]], dist.mean.eval())
self.assertAllEqual([1, 0], dist.mean.get_shape())
self.assertAllEqual(0, dist.num_classes.eval())
self.assertAllEqual([], dist.num_classes.get_shape())
def test_pmf_both_zero_batches(self):
# The probability of one vote falling into class k is the mean for class k.
with self.test_session():
# Both zero-batches. No broadcast
alpha = [1., 2]
counts = [1, 0.]
dist = tf.contrib.distributions.DirichletMultinomial(alpha)
pmf = dist.pmf(counts)
self.assertAllClose(1 / 3., pmf.eval())
self.assertEqual((), pmf.get_shape())
def test_pmf_alpha_stretched_in_broadcast_when_same_rank(self):
# The probability of one vote falling into class k is the mean for class k.
with self.test_session():
alpha = [[1., 2]]
counts = [[1, 0.], [0, 1.]]
dist = tf.contrib.distributions.DirichletMultinomial(alpha)
pmf = dist.pmf(counts)
self.assertAllClose([1 / 3., 2 / 3.], pmf.eval())
self.assertEqual((2,), pmf.get_shape())
def test_pmf_alpha_stretched_in_broadcast_when_lower_rank(self):
# The probability of one vote falling into class k is the mean for class k.
with self.test_session():
alpha = [1., 2]
counts = [[1, 0.], [0, 1.]]
pmf = tf.contrib.distributions.DirichletMultinomial(alpha).pmf(counts)
self.assertAllClose([1 / 3., 2 / 3.], pmf.eval())
self.assertEqual((2,), pmf.get_shape())
def test_pmf_counts_stretched_in_broadcast_when_same_rank(self):
# The probability of one vote falling into class k is the mean for class k.
with self.test_session():
alpha = [[1., 2], [2., 3]]
counts = [[1, 0.]]
pmf = tf.contrib.distributions.DirichletMultinomial(alpha).pmf(counts)
self.assertAllClose([1 / 3., 2 / 5.], pmf.eval())
self.assertEqual((2,), pmf.get_shape())
def test_pmf_counts_stretched_in_broadcast_when_lower_rank(self):
# The probability of one vote falling into class k is the mean for class k.
with self.test_session():
alpha = [[1., 2], [2., 3]]
counts = [1, 0.]
pmf = tf.contrib.distributions.DirichletMultinomial(alpha).pmf(counts)
self.assertAllClose([1 / 3., 2 / 5.], pmf.eval())
self.assertEqual((2,), pmf.get_shape())
def test_pmf_for_one_vote_is_the_mean_with_one_record_input(self):
# The probability of one vote falling into class k is the mean for class k.
alpha = [1., 2, 3]
with self.test_session():
for class_num in range(3):
counts = np.zeros((3), dtype=np.float32)
counts[class_num] = 1.0
dist = tf.contrib.distributions.DirichletMultinomial(alpha)
mean = dist.mean.eval()
pmf = dist.pmf(counts).eval()
self.assertAllClose(mean[class_num], pmf)
self.assertTupleEqual((3,), mean.shape)
self.assertTupleEqual((), pmf.shape)
def test_zero_counts_results_in_pmf_equal_to_one(self):
# There is only one way for zero items to be selected, and this happens with
# probability 1.
alpha = [5, 0.5]
counts = [0., 0.]
with self.test_session():
dist = tf.contrib.distributions.DirichletMultinomial(alpha)
pmf = dist.pmf(counts)
self.assertAllClose(1.0, pmf.eval())
self.assertEqual((), pmf.get_shape())
def test_large_tau_gives_precise_probabilities(self):
# If tau is large, the Dirichlet prior concentrates at mu, so we are
# effectively doing independent coin flips with probabilities mu.
mu = np.array([0.1, 0.1, 0.8], dtype=np.float32)
tau = np.array([100.], dtype=np.float32)
alpha = tau * mu
# One (three-sided) coin flip. Prob[coin 3] = 0.8.
# Note that for a single flip the value of tau does not matter.
counts = [0., 0, 1]
with self.test_session():
dist = tf.contrib.distributions.DirichletMultinomial(alpha)
pmf = dist.pmf(counts)
self.assertAllClose(0.8, pmf.eval(), atol=1e-4)
self.assertEqual((), pmf.get_shape())
# Two (three-sided) coin flips. Prob[coin 3] = 0.8.
counts = [0., 0, 2]
with self.test_session():
dist = tf.contrib.distributions.DirichletMultinomial(alpha)
pmf = dist.pmf(counts)
self.assertAllClose(0.8**2, pmf.eval(), atol=1e-2)
self.assertEqual((), pmf.get_shape())
# Three (three-sided) coin flips.
counts = [1., 0, 2]
with self.test_session():
dist = tf.contrib.distributions.DirichletMultinomial(alpha)
pmf = dist.pmf(counts)
self.assertAllClose(3 * 0.1 * 0.8 * 0.8, pmf.eval(), atol=1e-2)
self.assertEqual((), pmf.get_shape())
def test_small_tau_prefers_correlated_results(self):
# If tau is small, then correlation between draws is large, so draws that
# are both of the same class are more likely.
mu = np.array([0.5, 0.5], dtype=np.float32)
tau = np.array([0.1], dtype=np.float32)
alpha = tau * mu
# If there is only one draw, it is still a coin flip, even with small tau.
counts = [1, 0.]
with self.test_session():
dist = tf.contrib.distributions.DirichletMultinomial(alpha)
pmf = dist.pmf(counts)
self.assertAllClose(0.5, pmf.eval())
self.assertEqual((), pmf.get_shape())
# If there are two draws, it is much more likely that they are the same.
counts_same = [2, 0.]
counts_different = [1, 1.]
with self.test_session():
dist = tf.contrib.distributions.DirichletMultinomial(alpha)
pmf_same = dist.pmf(counts_same)
pmf_different = dist.pmf(counts_different)
self.assertLess(5 * pmf_different.eval(), pmf_same.eval())
self.assertEqual((), pmf_same.get_shape())
if __name__ == '__main__':
tf.test.main()
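
As an aside, the expectation in `test_small_tau_prefers_correlated_results` can be checked by hand. Below is a minimal sketch (pure Python, not part of the commit) that evaluates the Dirichlet Multinomial pmf via `math.lgamma`; helper names are mine:

```python
from math import lgamma, exp

def log_beta(xs):
    # Multivariate beta: Beta(x) = prod_j Gamma(x_j) / Gamma(sum_j x_j).
    return sum(lgamma(x) for x in xs) - lgamma(sum(xs))

def dm_pmf(alpha, counts):
    # pmf(counts) = C!/(c_1!...c_k!) * Beta(alpha + counts) / Beta(alpha)
    log_comb = lgamma(sum(counts) + 1) - sum(lgamma(c + 1) for c in counts)
    return exp(log_comb
               + log_beta([a + c for a, c in zip(alpha, counts)])
               - log_beta(alpha))

alpha = [0.05, 0.05]  # tau * mu with tau = 0.1, mu = [0.5, 0.5]
same, different = dm_pmf(alpha, [2., 0.]), dm_pmf(alpha, [1., 1.])
print(same > 5 * different)  # True: same-class draws dominate for small tau
```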

tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py

@@ -0,0 +1,261 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Dirichlet Multinomial distribution class.
@@DirichletMultinomial
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import special_math_ops
def _check_alpha(alpha):
"""Check alpha for proper shape, values, then return tensor version."""
alpha = ops.convert_to_tensor(alpha, name='alpha_before_deps')
return control_flow_ops.with_dependencies(
[check_ops.assert_rank_at_least(alpha, 1),
check_ops.assert_positive(alpha)], alpha)
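
Note that the `with_dependencies` wiring means invalid parameters surface when the checked tensor is consumed rather than at construction. A rough sketch of that behavior (assuming the TF 0.x graph/session API used elsewhere in this commit; `InvalidArgumentError` as the failure type is my assumption):

```python
import tensorflow as tf

# Sketch: the asserts attached by with_dependencies run whenever the
# checked alpha tensor is consumed by another op.
dist = tf.contrib.distributions.DirichletMultinomial([1., -1.])
with tf.Session() as sess:
    try:
        sess.run(dist.mean)  # consuming alpha triggers assert_positive
    except tf.errors.InvalidArgumentError:
        print('alpha failed the positivity check')
```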
def _log_combinations(counts, name='log_combinations'):
"""Log number of ways counts could have come in."""
# First, a bit about the number of ways counts could have come in:
# E.g. if counts = [1, 2], then this is 3 choose 2.
# In general, this is (sum counts)! / prod(counts!)
# The sum should be along the last dimension of counts. This is the
# "distribution" dimension.
with ops.op_scope([counts], name):
last_dim = array_ops.rank(counts) - 1
# To compute factorials, use the fact that Gamma(n + 1) = n!
# Compute two terms, each a sum over counts. Compute each for each
# batch member.
# Log Gamma((sum counts) + 1) = Log((sum counts)!)
sum_of_counts = math_ops.reduce_sum(counts, reduction_indices=last_dim)
total_permutations = math_ops.lgamma(sum_of_counts + 1)
# sum(Log Gamma(counts + 1)) = Log prod(counts!)
counts_factorial = math_ops.lgamma(counts + 1)
redundant_permutations = math_ops.reduce_sum(counts_factorial,
reduction_indices=last_dim)
return total_permutations - redundant_permutations
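
As a quick cross-check of the identity used here (a sketch, independent of TensorFlow): for `counts = [1, 2]` the multinomial coefficient is `3!/(1!2!) = 3`, and the lgamma form agrees:

```python
from math import lgamma, log

# lgamma(sum(c) + 1) - sum(lgamma(c_j + 1)) == log((sum c)! / prod c_j!)
counts = [1.0, 2.0]
via_lgamma = lgamma(sum(counts) + 1) - sum(lgamma(c + 1) for c in counts)
print(abs(via_lgamma - log(3.0)) < 1e-12)  # True: 3!/(1! * 2!) = 3
```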
class DirichletMultinomial(object):
"""DirichletMultinomial mixture distribution.
The Dirichlet Multinomial is a distribution over k-class count data, meaning
for each k-tuple of non-negative integer `counts = [c_1,...,c_k]`, we have a
probability of these draws being made from the distribution. The distribution
has hyperparameters `alpha = (alpha_1,...,alpha_k)`, and probability mass
function (pmf):
```pmf(counts) = C! / (c_1!...c_k!) * Beta(alpha + c) / Beta(alpha)```
where `C = sum_j c_j` is the total number of draws, `!` denotes factorial, and
`Beta(x) = prod_j Gamma(x_j) / Gamma(sum_j x_j)` is the multivariate beta
function.
This is a mixture distribution in that `N` samples can be produced by:
1. Choose class probabilities `p = (p_1,...,p_k) ~ Dir(alpha)`
2. Draw integers `m = (m_1,...,m_k) ~ Multinomial(p, N)`
This class provides methods to create indexed batches of Dirichlet
Multinomial distributions. If the provided `alpha` is rank 2 or higher, for
every fixed set of leading dimensions, the last dimension represents one
single Dirichlet Multinomial distribution. When calling distribution
functions (e.g. `dist.pmf(counts)`), `alpha` and `counts` are broadcast to the
same shape (if possible). In all cases, the last dimension of `alpha`/`counts`
represents a single Dirichlet Multinomial distribution.
Examples:
```python
alpha = [1, 2, 3]
dist = DirichletMultinomial(alpha)
```
This creates a 3-class distribution in which the 3rd class is the most likely to be drawn.
The distribution functions can be evaluated on counts.
```python
# counts same shape as alpha.
counts = [0, 2, 0]
dist.pmf(counts) # Shape []
# alpha will be broadcast to [[1, 2, 3], [1, 2, 3]] to match counts.
counts = [[11, 22, 33], [44, 55, 66]]
dist.pmf(counts) # Shape [2]
# alpha will be broadcast to shape [5, 7, 3] to match counts.
counts = [[...]] # Shape [5, 7, 3]
dist.pmf(counts) # Shape [5, 7]
```
Creates a 2-batch of 3-class distributions.
```python
alpha = [[1, 2, 3], [4, 5, 6]] # Shape [2, 3]
dist = DirichletMultinomial(alpha)
# counts will be broadcast to [[11, 22, 33], [11, 22, 33]] to match alpha.
counts = [11, 22, 33]
dist.pmf(counts) # Shape [2]
```
"""
# TODO(b/27419586) Change docstring for dtype of alpha once int allowed.
def __init__(self, alpha):
"""Initialize a batch of DirichletMultinomial distributions.
Args:
alpha: Shape `[N1,..., Nn, k]` positive `float` or `double` tensor with
`n >= 0`. Defines this as a batch of `N1 x ... x Nn` different `k`
class Dirichlet multinomial distributions.
Examples:
```python
# Define 1-batch of 2-class Dirichlet multinomial distribution,
# also known as a beta-binomial.
dist = DirichletMultinomial([1.1, 2.0])
# Define a 2-batch of 3-class distributions.
dist = DirichletMultinomial([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
```
"""
# Broadcasting works because:
# * The broadcasting convention is to prepend dimensions of size [1], and
we use the last dimension for the distribution, whereas
# the batch dimensions are the leading dimensions, which forces the
# distribution dimension to be defined explicitly (i.e. it cannot be
created automatically by prepending). This forces enough explicitness.
# * All calls involving `counts` eventually require a broadcast between
# `counts` and alpha.
self._alpha = _check_alpha(alpha)
self._num_classes = self._get_num_classes()
self._dist_indices = self._get_dist_indices()
@property
def alpha(self):
"""Parameters defining this distribution."""
return self._alpha
@property
def dtype(self):
return self._alpha.dtype
@property
def mean(self):
"""Class means for every batch member."""
with ops.name_scope('mean'):
alpha_sum = math_ops.reduce_sum(self._alpha,
reduction_indices=self._dist_indices,
keep_dims=True)
mean = math_ops.truediv(self._alpha, alpha_sum)
mean.set_shape(self._alpha.get_shape())
return mean
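
The mean here is just `alpha` normalized along the class axis; the same computation in NumPy (a sketch, with values borrowed from the tests above):

```python
import numpy as np

# Normalize alpha along the last (class) axis; the result keeps alpha's
# shape, which is what the set_shape call above asserts.
alpha = np.array([[1., 2., 3.], [4., 5., 6.]])
mean = alpha / alpha.sum(axis=-1, keepdims=True)
print(mean[0])  # [0.16666667 0.33333333 0.5]
```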
def _get_dist_indices(self):
"""Dimensions corresponding to individual distributions."""
# Reshape the scalar to a rank 1 tensor.
return array_ops.reshape(array_ops.rank(self._alpha) - 1, [-1])
def _get_num_classes(self):
return ops.convert_to_tensor(
array_ops.reverse(
array_ops.shape(self._alpha), [True])[0],
name='num_classes')
@property
def num_classes(self):
"""Tensor providing number of classes in each batch member."""
return self._num_classes
def cdf(self, x):
raise NotImplementedError(
'DirichletMultinomial does not have a well-defined cdf.')
def log_cdf(self, x):
raise NotImplementedError(
'DirichletMultinomial does not have a well-defined cdf.')
def log_pmf(self, counts, name=None):
"""`Log(P[counts])`, computed for every batch member.
For each batch of counts `[c_1,...,c_k]`, `P[counts]` is the probability
that after sampling `sum_j c_j` draws from this Dirichlet Multinomial
distribution, the number of draws falling in class `j` is `c_j`. Note that
different sequences of draws can result in the same counts, thus the
probability includes a combinatorial coefficient.
Args:
counts: Non-negative `float`, `double`, or `int` tensor whose shape can
be broadcast with `self.alpha`. For fixed leading dimensions, the last
dimension represents counts for the corresponding Dirichlet Multinomial
distribution in `self.alpha`.
name: Name to give this Op, defaults to "log_pmf".
Returns:
Log probabilities for each record, shape `[N1,...,Nn]`.
"""
alpha = self._alpha
with ops.op_scope([alpha, counts], name, 'log_pmf'):
counts = self._check_counts(counts)
ordered_pmf = (special_math_ops.lbeta(alpha + counts) -
special_math_ops.lbeta(alpha))
log_pmf = ordered_pmf + _log_combinations(counts)
# If alpha = counts = [[]], ordered_pmf carries the right shape, which is
# []. However, since reduce_sum([[]]) = [0], log_combinations = [0],
# which is not correct. Luckily, [] + [0] = [], so the sum is fine, but
# shape must be inferred from ordered_pmf.
# Note also that tf.constant([]).get_shape() = TensorShape([Dimension(0)])
log_pmf.set_shape(ordered_pmf.get_shape())
return log_pmf
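
The empty-counts corner case the comment describes can be reproduced in NumPy terms (a sketch of the shape arithmetic only):

```python
import numpy as np

# reduce_sum over the last axis of [[]] yields [0.], not a scalar, so the
# combinatorial term would wrongly promote the result's shape; but
# broadcasting [] + [0.] collapses back to [], hence the set_shape fix.
print(np.sum([[]], axis=-1))          # [0.]
print(np.array([]) + np.array([0.]))  # []
```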
def pmf(self, counts, name=None):
"""`P[counts]`, computed for every batch member.
For each batch of counts `[c_1,...,c_k]`, `P[counts]` is the probability
that after sampling `sum_j c_j` draws from this Dirichlet Multinomial
distribution, the number of draws falling in class `j` is `c_j`. Note that
different sequences of draws can result in the same counts, thus the
probability includes a combinatorial coefficient.
Args:
counts: Non-negative `float`, `double`, or `int` tensor whose shape can
be broadcast with `self.alpha`. For fixed leading dimensions, the last
dimension represents counts for the corresponding Dirichlet Multinomial
distribution in `self.alpha`.
name: Name to give this Op, defaults to "pmf".
Returns:
Probabilities for each record, shape `[N1,...,Nn]`.
"""
with ops.name_scope('pmf' if name is None else name):
return math_ops.exp(self.log_pmf(counts))
def _check_counts(self, counts):
"""Check counts for proper shape, values, then return tensor version."""
counts = ops.convert_to_tensor(counts, name='counts_before_deps')
counts = math_ops.cast(counts, self.dtype)
return control_flow_ops.with_dependencies(
[check_ops.assert_non_negative(counts)], counts)
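
Finally, a short end-to-end usage sketch mirroring the broadcast tests above (TF 0.x session API assumed; expected values taken from `test_pmf_counts_stretched_in_broadcast_when_lower_rank`):

```python
import tensorflow as tf

# A 2-batch of 2-class distributions; one counts vector broadcast across it.
alpha = [[1., 2.], [2., 3.]]
counts = [1., 0.]
dist = tf.contrib.distributions.DirichletMultinomial(alpha)
with tf.Session() as sess:
    print(sess.run(dist.pmf(counts)))  # ~[0.3333, 0.4], i.e. [1/3, 2/5]
```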