An explicit constructor call is no less clear and matches what we export via the public API. The functions will be removed once all the internal users are migrated. PiperOrigin-RevId: 259620054
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Categorical distribution class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export

def _broadcast_cat_event_and_params(event, params, base_dtype):
  """Broadcasts the event or distribution parameters."""
  if event.dtype.is_integer:
    pass
  elif event.dtype.is_floating:
    # When `validate_args=True` we've already ensured int/float casting
    # is closed.
    event = math_ops.cast(event, dtype=dtypes.int32)
  else:
    raise TypeError("`value` should have integer `dtype` or "
                    "`self.dtype` ({})".format(base_dtype))
  shape_known_statically = (
      params.shape.ndims is not None and
      params.shape[:-1].is_fully_defined() and
      event.shape.is_fully_defined())
  if not shape_known_statically or params.shape[:-1] != event.shape:
    params *= array_ops.ones_like(event[..., array_ops.newaxis],
                                  dtype=params.dtype)
    params_shape = array_ops.shape(params)[:-1]
    event *= array_ops.ones(params_shape, dtype=event.dtype)
    if params.shape.ndims is not None:
      event.set_shape(tensor_shape.TensorShape(params.shape[:-1]))

  return event, params

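
# A hypothetical illustration of `_broadcast_cat_event_and_params` (a sketch,
# not an API guarantee): given a single 3-class distribution and a vector of
# two events, the parameters are tiled so each event indexes its own row of
# parameters.
#
#   event = constant_op.constant([0, 2])            # shape [2]
#   params = constant_op.constant([0.1, 0.5, 0.4])  # shape [3]
#   event, params = _broadcast_cat_event_and_params(
#       event, params, base_dtype=dtypes.float32)
#   # event has shape [2]; params has shape [2, 3].
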
@tf_export(v1=["distributions.Categorical"])
class Categorical(distribution.Distribution):
  """Categorical distribution.

  The Categorical distribution is parameterized by either probabilities or
  log-probabilities of a set of `K` classes. It is defined over the integers
  `{0, 1, ..., K-1}`.

  The Categorical distribution is closely related to the `OneHotCategorical`
  and `Multinomial` distributions. The Categorical distribution can be
  intuited as generating samples according to
  `argmax{ OneHotCategorical(probs) }`, itself identical to
  `argmax{ Multinomial(probs, total_count=1) }`.

  #### Mathematical Details

  The probability mass function (pmf) is,

  ```none
  pmf(k; pi) = prod_j pi_j**[k == j]
  ```

  #### Pitfalls

  The number of classes, `K`, must not exceed:
  - the largest integer representable by `self.dtype`, i.e.,
    `2**(mantissa_bits+1)` (IEEE 754),
  - the maximum `Tensor` index, i.e., `2**31-1`.

  In other words,

  ```python
  K <= min(2**31-1, {
      tf.float16: 2**11,
      tf.float32: 2**24,
      tf.float64: 2**53}[param.dtype])
  ```

  Note: This condition is validated only when `self.validate_args = True`.

  #### Examples

  Creates a 3-class distribution with the 2nd class being most likely.

  ```python
  dist = Categorical(probs=[0.1, 0.5, 0.4])
  n = 1e4
  empirical_prob = tf.cast(
      tf.histogram_fixed_width(
          dist.sample(int(n)),
          [0., 2],
          nbins=3),
      dtype=tf.float32) / n
  # ==> array([ 0.1005,  0.5037,  0.3958], dtype=float32)
  ```

  Creates a 3-class distribution with the 2nd class being most likely.
  Parameterized by [logits](https://en.wikipedia.org/wiki/Logit) rather than
  probabilities.

  ```python
  dist = Categorical(logits=np.log([0.1, 0.5, 0.4]))
  n = 1e4
  empirical_prob = tf.cast(
      tf.histogram_fixed_width(
          dist.sample(int(n)),
          [0., 2],
          nbins=3),
      dtype=tf.float32) / n
  # ==> array([0.1045,  0.5047, 0.3908], dtype=float32)
  ```

  Creates a 3-class distribution with the 3rd class being most likely.
  The distribution functions can be evaluated on counts.

  ```python
  # counts is a scalar.
  p = [0.1, 0.4, 0.5]
  dist = Categorical(probs=p)
  dist.prob(0)  # Shape []

  # p will be broadcast to [[0.1, 0.4, 0.5], [0.1, 0.4, 0.5]] to match counts.
  counts = [1, 0]
  dist.prob(counts)  # Shape [2]

  # p will be broadcast to shape [5, 7, 3, 3] to match counts.
  counts = [[...]]  # Shape [5, 7, 3]
  dist.prob(counts)  # Shape [5, 7, 3]
  ```

  """
  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(
      self,
      logits=None,
      probs=None,
      dtype=dtypes.int32,
      validate_args=False,
      allow_nan_stats=True,
      name="Categorical"):
    """Initialize Categorical distributions using class log-probabilities.

    Args:
      logits: An N-D `Tensor`, `N >= 1`, representing the log probabilities
        of a set of Categorical distributions. The first `N - 1` dimensions
        index into a batch of independent distributions and the last dimension
        represents a vector of logits for each class. Only one of `logits` or
        `probs` should be passed in.
      probs: An N-D `Tensor`, `N >= 1`, representing the probabilities
        of a set of Categorical distributions. The first `N - 1` dimensions
        index into a batch of independent distributions and the last dimension
        represents a vector of probabilities for each class. Only one of
        `logits` or `probs` should be passed in.
      dtype: The type of the event samples (default: int32).
      validate_args: Python `bool`, default `False`. When `True`, distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False`, invalid inputs may silently render
        incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = dict(locals())
    with ops.name_scope(name, values=[logits, probs]) as name:
      self._logits, self._probs = distribution_util.get_logits_and_probs(
          logits=logits,
          probs=probs,
          validate_args=validate_args,
          multidimensional=True,
          name=name)

      if validate_args:
        self._logits = distribution_util.embed_check_categorical_event_shape(
            self._logits)

      logits_shape_static = self._logits.get_shape().with_rank_at_least(1)
      if logits_shape_static.ndims is not None:
        self._batch_rank = ops.convert_to_tensor(
            logits_shape_static.ndims - 1,
            dtype=dtypes.int32,
            name="batch_rank")
      else:
        with ops.name_scope(name="batch_rank"):
          self._batch_rank = array_ops.rank(self._logits) - 1

      logits_shape = array_ops.shape(self._logits, name="logits_shape")
      if tensor_shape.dimension_value(logits_shape_static[-1]) is not None:
        self._event_size = ops.convert_to_tensor(
            logits_shape_static.dims[-1].value,
            dtype=dtypes.int32,
            name="event_size")
      else:
        with ops.name_scope(name="event_size"):
          self._event_size = logits_shape[self._batch_rank]

      if logits_shape_static[:-1].is_fully_defined():
        self._batch_shape_val = constant_op.constant(
            logits_shape_static[:-1].as_list(),
            dtype=dtypes.int32,
            name="batch_shape")
      else:
        with ops.name_scope(name="batch_shape"):
          self._batch_shape_val = logits_shape[:-1]
    super(Categorical, self).__init__(
        dtype=dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._logits,
                       self._probs],
        name=name)

  @property
  def event_size(self):
    """Scalar `int32` tensor: the number of classes."""
    return self._event_size

  @property
  def logits(self):
    """Vector of coordinatewise logits."""
    return self._logits

  @property
  def probs(self):
    """Vector of coordinatewise probabilities."""
    return self._probs

  def _batch_shape_tensor(self):
    return array_ops.identity(self._batch_shape_val)

  def _batch_shape(self):
    return self.logits.get_shape()[:-1]

  def _event_shape_tensor(self):
    return constant_op.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    return tensor_shape.TensorShape([])

  def _sample_n(self, n, seed=None):
    if self.logits.get_shape().ndims == 2:
      logits_2d = self.logits
    else:
      logits_2d = array_ops.reshape(self.logits, [-1, self.event_size])
    sample_dtype = dtypes.int64 if self.dtype.size > 4 else dtypes.int32
    draws = random_ops.multinomial(
        logits_2d, n, seed=seed, output_dtype=sample_dtype)
    draws = array_ops.reshape(
        array_ops.transpose(draws),
        array_ops.concat([[n], self.batch_shape_tensor()], 0))
    return math_ops.cast(draws, self.dtype)

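  # Shape walk-through for `_sample_n` above (a hypothetical example): with
  # `logits` of shape [5, 7, 3] (batch_shape [5, 7], 3 classes), `logits_2d`
  # is reshaped to [35, 3], `multinomial` returns draws of shape [35, n], and
  # the transpose/reshape yields samples of shape [n, 5, 7].
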
  def _cdf(self, k):
    k = ops.convert_to_tensor(k, name="k")
    if self.validate_args:
      k = distribution_util.embed_check_integer_casting_closed(
          k, target_dtype=dtypes.int32)

    k, probs = _broadcast_cat_event_and_params(
        k, self.probs, base_dtype=self.dtype.base_dtype)

    # Batch-flatten everything in order to use `sequence_mask()`.
    batch_flattened_probs = array_ops.reshape(probs,
                                              (-1, self._event_size))
    batch_flattened_k = array_ops.reshape(k, [-1])

    to_sum_over = array_ops.where(
        array_ops.sequence_mask(batch_flattened_k, self._event_size),
        batch_flattened_probs,
        array_ops.zeros_like(batch_flattened_probs))
    batch_flattened_cdf = math_ops.reduce_sum(to_sum_over, axis=-1)
    # Reshape back to the shape of the argument.
    return array_ops.reshape(batch_flattened_cdf, array_ops.shape(k))

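  # Worked illustration of the `sequence_mask` trick in `_cdf` above (a
  # hypothetical example): for `probs = [0.1, 0.4, 0.5]` and `k = 2`,
  # `sequence_mask(2, 3)` is [True, True, False], so the masked sum keeps
  # `probs[0] + probs[1] = 0.5`, i.e. the total probability of the first
  # `k` classes.
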
  def _log_prob(self, k):
    k = ops.convert_to_tensor(k, name="k")
    if self.validate_args:
      k = distribution_util.embed_check_integer_casting_closed(
          k, target_dtype=dtypes.int32)
    k, logits = _broadcast_cat_event_and_params(
        k, self.logits, base_dtype=self.dtype.base_dtype)

    return -nn_ops.sparse_softmax_cross_entropy_with_logits(labels=k,
                                                            logits=logits)

  def _entropy(self):
    return -math_ops.reduce_sum(
        nn_ops.log_softmax(self.logits) * self.probs, axis=-1)

  def _mode(self):
    ret = math_ops.argmax(self.logits, axis=self._batch_rank)
    ret = math_ops.cast(ret, self.dtype)
    ret.set_shape(self.batch_shape)
    return ret

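
# A minimal usage sketch of the class above (hypothetical values):
#
#   dist = Categorical(probs=[0.1, 0.5, 0.4])
#   dist.prob(1)     # ==> 0.5
#   dist.mode()      # ==> 1
#   dist.entropy()   # ==> -sum_j p_j * log(p_j) ~= 0.943
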
@kullback_leibler.RegisterKL(Categorical, Categorical)
def _kl_categorical_categorical(a, b, name=None):
  """Calculate the batched KL divergence KL(a || b) with a and b Categorical.

  Args:
    a: instance of a Categorical distribution object.
    b: instance of a Categorical distribution object.
    name: (optional) Name to use for created operations.
      Default is "kl_categorical_categorical".

  Returns:
    Batchwise KL(a || b).
  """
  with ops.name_scope(name, "kl_categorical_categorical",
                      values=[a.logits, b.logits]):
    # sum(p_a * (log p_a - log p_b)), computed stably via log_softmax.
    delta_log_probs1 = (nn_ops.log_softmax(a.logits) -
                        nn_ops.log_softmax(b.logits))
    return math_ops.reduce_sum(nn_ops.softmax(a.logits) * delta_log_probs1,
                               axis=-1)
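
# A hypothetical check of the registered KL (values rounded):
#
#   a = Categorical(probs=[0.5, 0.5])
#   b = Categorical(probs=[0.9, 0.1])
#   kullback_leibler.kl_divergence(a, b)
#   # ==> 0.5 * log(0.5 / 0.9) + 0.5 * log(0.5 / 0.1) ~= 0.511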