STT-tensorflow/tensorflow/contrib/distributions/python/ops/beta.py
2016-10-13 13:19:41 -07:00

296 lines
10 KiB
Python

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Beta distribution class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.distributions.python.ops import distribution
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
# Shared note about the `x` argument; it is attached to method docstrings via
# `distribution_util.AppendDocstring(_beta_prob_note)`.
_beta_prob_note = """
Note that the argument `x` must be a floating point tensor whose shape can
be broadcast with `self.a` and `self.b`. Each entry of `x` is evaluated
under the corresponding Beta distribution in `self.a` and `self.b`. `x` is
only legal if `0 < x < 1`.
"""
class Beta(distribution.Distribution):
  """Beta distribution.

  This distribution is parameterized by `a` and `b` which are shape
  parameters.

  #### Mathematical details

  The Beta is a distribution over the interval (0, 1).
  The distribution has hyperparameters `a` and `b` and
  probability density function (pdf):

  ```pdf(x) = 1 / Beta(a, b) * x^(a - 1) * (1 - x)^(b - 1)```

  where `Beta(a, b) = Gamma(a) * Gamma(b) / Gamma(a + b)`
  is the beta function.

  This class provides methods to create indexed batches of Beta
  distributions. One entry of the broadcasted shape of `a` and `b`
  represents one single Beta distribution. When calling distribution
  functions (e.g. `dist.pdf(x)`), `a`, `b` and `x` are broadcast to the
  same shape (if possible). Every entry in a/b/x corresponds to a single
  Beta distribution.

  #### Examples

  Creates 3 distributions.
  The distribution functions can be evaluated on x.

  ```python
  a = [1., 2., 3.]
  b = [1., 2., 3.]
  dist = Beta(a, b)
  ```

  ```python
  # x same shape as a.
  x = [.2, .3, .7]
  dist.pdf(x)  # Shape [3]

  # a/b will be broadcast to [[1., 2., 3.], [1., 2., 3.]] to match x.
  x = [[.1, .4, .5], [.2, .3, .5]]
  dist.pdf(x)  # Shape [2, 3]

  # a/b will be broadcast to shape [5, 7, 3] to match x.
  x = [[...]]  # Shape [5, 7, 3]
  dist.pdf(x)  # Shape [5, 7, 3]
  ```

  Creates a batch of distributions with batch shape [2, 3].

  ```python
  a = [[1., 2., 3.], [4., 5., 6.]]  # Shape [2, 3]
  b = 5.  # Shape []
  dist = Beta(a, b)

  # x will be broadcast to [[.2, .3, .9], [.2, .3, .9]] to match a/b.
  x = [.2, .3, .9]
  dist.pdf(x)  # Shape [2, 3]
  ```

  """

  def __init__(self, a, b, validate_args=False, allow_nan_stats=True,
               name="Beta"):
    """Initialize a batch of Beta distributions.

    Args:
      a:  Positive floating point tensor with shape broadcastable to
        `[N1,..., Nm]` `m >= 0`.  Defines this as a batch of `N1 x ... x Nm`
         different Beta distributions. This also defines the
         dtype of the distribution.
      b:  Positive floating point tensor with shape broadcastable to
        `[N1,..., Nm]` `m >= 0`.  Defines this as a batch of `N1 x ... x Nm`
         different Beta distributions.
      validate_args: `Boolean`, default `False`.  Whether to assert valid
        values for parameters `a`, `b`, and `x` in `prob` and `log_prob`.
        If `False` and inputs are invalid, correct behavior is not guaranteed.
      allow_nan_stats: `Boolean`, default `True`.  If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member.  If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: The name to prefix Ops created by this distribution class.

    Examples:

    ```python
    # Define 1-batch.
    dist = Beta(1.1, 2.0)

    # Define a 2-batch.
    dist = Beta([1.0, 2.0], [4.0, 5.0])
    ```

    """
    with ops.name_scope(name, values=[a, b]) as ns:
      # The positivity assertions are only wired in when validate_args is set;
      # otherwise the parameters are taken as-is.
      with ops.control_dependencies([
          check_ops.assert_positive(a),
          check_ops.assert_positive(b),
      ] if validate_args else []):
        self._a = array_ops.identity(a, name="a")
        self._b = array_ops.identity(b, name="b")
        contrib_tensor_util.assert_same_float_dtype((self._a, self._b))
        # Used for mean/mode/variance/entropy/sampling computations
        self._a_b_sum = self._a + self._b
        super(Beta, self).__init__(
            dtype=self._a_b_sum.dtype,
            parameters={"a": self._a, "b": self._b, "a_b_sum": self._a_b_sum},
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            is_continuous=True,
            is_reparameterized=False,
            name=ns)

  @staticmethod
  def _param_shapes(sample_shape):
    # Both `a` and `b` map to the same int32 tensor built from `sample_shape`.
    return dict(
        zip(("a", "b"), ([ops.convert_to_tensor(
            sample_shape, dtype=dtypes.int32)] * 2)))

  @property
  def a(self):
    """Shape parameter."""
    return self._a

  @property
  def b(self):
    """Shape parameter."""
    return self._b

  @property
  def a_b_sum(self):
    """Sum of parameters."""
    return self._a_b_sum

  def _batch_shape(self):
    # `a + b` already carries the broadcast shape of both parameters.
    return array_ops.shape(self.a_b_sum)

  def _get_batch_shape(self):
    return self.a_b_sum.get_shape()

  def _event_shape(self):
    # Events are scalars, so the event shape is an empty int32 vector.
    return constant_op.constant([], dtype=dtypes.int32)

  def _get_event_shape(self):
    return tensor_shape.scalar()

  def _sample_n(self, n, seed=None):
    # Multiplying by ones_like(a + b) expands each parameter to the full
    # broadcast batch shape before sampling.
    a = array_ops.ones_like(self.a_b_sum, dtype=self.dtype) * self.a
    b = array_ops.ones_like(self.a_b_sum, dtype=self.dtype) * self.b
    # A Beta sample is formed as G1 / (G1 + G2) from two gamma draws with
    # concentrations `a` and `b`; the second draw uses a seed derived from
    # `seed` rather than `seed` itself.
    gamma1_sample = random_ops.random_gamma(
        [n,], a, dtype=self.dtype, seed=seed)
    gamma2_sample = random_ops.random_gamma(
        [n,], b, dtype=self.dtype,
        seed=distribution_util.gen_new_seed(seed, "beta"))
    beta_sample = gamma1_sample / (gamma1_sample + gamma2_sample)
    return beta_sample

  @distribution_util.AppendDocstring(_beta_prob_note)
  def _log_prob(self, x):
    x = self._assert_valid_sample(x)
    # log pdf(x) = (a - 1) log(x) + (b - 1) log(1 - x) - log Beta(a, b).
    log_unnormalized_prob = ((self.a - 1.) * math_ops.log(x) +
                             (self.b - 1.) * math_ops.log(1. - x))
    # log Beta(a, b) = lgamma(a) + lgamma(b) - lgamma(a + b).
    log_normalization = (math_ops.lgamma(self.a) +
                         math_ops.lgamma(self.b) -
                         math_ops.lgamma(self.a_b_sum))
    return log_unnormalized_prob - log_normalization

  @distribution_util.AppendDocstring(_beta_prob_note)
  def _prob(self, x):
    return math_ops.exp(self._log_prob(x))

  @distribution_util.AppendDocstring(_beta_prob_note)
  def _log_cdf(self, x):
    return math_ops.log(self._cdf(x))

  @distribution_util.AppendDocstring(_beta_prob_note)
  def _cdf(self, x):
    # NOTE: unlike `_log_prob`, `x` is not passed through
    # `_assert_valid_sample` here.
    return math_ops.betainc(self.a, self.b, x)

  def _entropy(self):
    # log Beta(a, b) - (a - 1) digamma(a) - (b - 1) digamma(b)
    #   + (a + b - 2) digamma(a + b), with log Beta expanded into lgamma terms.
    return (math_ops.lgamma(self.a) -
            (self.a - 1.) * math_ops.digamma(self.a) +
            math_ops.lgamma(self.b) -
            (self.b - 1.) * math_ops.digamma(self.b) -
            math_ops.lgamma(self.a_b_sum) +
            (self.a_b_sum - 2.) * math_ops.digamma(self.a_b_sum))

  def _mean(self):
    return self.a / self.a_b_sum

  def _variance(self):
    # a * b / ((a + b)^2 * (a + b + 1)).
    return (self.a * self.b) / (self.a_b_sum**2. * (self.a_b_sum + 1.))

  def _std(self):
    return math_ops.sqrt(self.variance())

  @distribution_util.AppendDocstring(
      """Note that the mode for the Beta distribution is only defined
      when `a > 1`, `b > 1`. This returns the mode when `a > 1` and `b > 1`,
      and `NaN` otherwise. If `self.allow_nan_stats` is `False`, an exception
      will be raised rather than returning `NaN`.""")
  def _mode(self):
    mode = (self.a - 1.)/ (self.a_b_sum - 2.)

    if self.allow_nan_stats:
      # Entries where a <= 1 or b <= 1 are replaced by NaN of the
      # distribution's dtype.
      nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype())
      return math_ops.select(
          math_ops.logical_and(
              math_ops.greater(self.a, 1.),
              math_ops.greater(self.b, 1.)),
          mode,
          array_ops.fill(self.batch_shape(), nan, name="nan"))
    else:
      # With NaN stats disallowed, assert 1 < a and 1 < b before the mode
      # tensor can be evaluated.
      return control_flow_ops.with_dependencies([
          check_ops.assert_less(
              array_ops.ones((), dtype=self.dtype), self.a,
              message="Mode not defined for components of a <= 1."),
          check_ops.assert_less(
              array_ops.ones((), dtype=self.dtype), self.b,
              message="Mode not defined for components of b <= 1."),
      ], mode)

  def _assert_valid_sample(self, x):
    """Check x for proper shape, values, then return tensor version."""
    # Validation is skipped entirely unless validate_args was requested.
    if not self.validate_args: return x
    # Enforce the open-interval support: x > 0 and x < 1.
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(
            x,
            message="Negative events lie outside Beta distribution support."),
        check_ops.assert_less(
            x, array_ops.ones((), self.dtype),
            message="Event>=1 lies outside Beta distribution support."),
    ], x)
class BetaWithSoftplusAB(Beta):
  """Beta with softplus transform on `a` and `b`."""

  def __init__(self,
               a,
               b,
               validate_args=False,
               allow_nan_stats=True,
               name="BetaWithSoftplusAB"):
    """Build a `Beta` whose parameters are `softplus(a)` and `softplus(b)`.

    Args:
      a: Tensor passed through `nn.softplus` before being used as the first
        shape parameter of the underlying `Beta`.
      b: Tensor passed through `nn.softplus` before being used as the second
        shape parameter of the underlying `Beta`.
      validate_args: Forwarded unchanged to `Beta.__init__`.
      allow_nan_stats: Forwarded unchanged to `Beta.__init__`.
      name: Name scope under which the softplus ops and the `Beta` ops are
        created.
    """
    with ops.name_scope(name, values=[a, b]) as scope:
      # Transform both raw parameters inside this scope, `a` first, then
      # hand them to the parent constructor.
      softplus_a = nn.softplus(a)
      softplus_b = nn.softplus(b)
      super(BetaWithSoftplusAB, self).__init__(
          a=softplus_a,
          b=softplus_b,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=scope)