Sergei Lebedev bca5e7385f Inlined tensor_shape.{scalar,vector,matrix}
An explicit constructor call is no less clear and matches what we export via
the public API.

The functions will be removed once all the internal users are migrated.

PiperOrigin-RevId: 259620054
2019-07-23 15:33:44 -07:00

348 lines
12 KiB
Python

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Categorical distribution class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
def _broadcast_cat_event_and_params(event, params, base_dtype):
  """Makes `event` and the distribution `params` mutually broadcast.

  Args:
    event: `Tensor` of class indices. Integer tensors are used as is;
      floating tensors are cast to `int32`.
    params: `Tensor` of per-class parameters; every axis except the last is
      treated as a batch axis.
    base_dtype: dtype shown in the error message raised for an unsupported
      `event` dtype.

  Returns:
    An `(event, params)` tuple. When the static shapes are not both fully
    known, or the batch part of `params` differs from the shape of `event`,
    both tensors are expanded by multiplying with ones.

  Raises:
    TypeError: if `event` has neither an integer nor a floating dtype.
  """
  if not event.dtype.is_integer:
    if not event.dtype.is_floating:
      raise TypeError("`value` should have integer `dtype` or "
                      "`self.dtype` ({})".format(base_dtype))
    # With `validate_args=True` the caller has already checked that the
    # int/float cast is closed, so the cast below loses nothing.
    event = math_ops.cast(event, dtype=dtypes.int32)

  params_static = params.shape
  static_shapes_known = (
      params_static.ndims is not None and
      params_static[:-1].is_fully_defined() and
      event.shape.is_fully_defined())

  if not static_shapes_known or params_static[:-1] != event.shape:
    # Give `event` a trailing unit axis so that it lines up with the class
    # axis of `params`, then broadcast each tensor against the other.
    event_with_class_axis = event[..., array_ops.newaxis]
    params = params * array_ops.ones_like(event_with_class_axis,
                                          dtype=params.dtype)
    dynamic_batch_shape = array_ops.shape(params)[:-1]
    event = event * array_ops.ones(dynamic_batch_shape, dtype=event.dtype)
    if params.shape.ndims is not None:
      # Re-attach whatever static batch shape survived the broadcast.
      event.set_shape(tensor_shape.TensorShape(params.shape[:-1]))

  return event, params
@tf_export(v1=["distributions.Categorical"])
class Categorical(distribution.Distribution):
  """Categorical distribution.

  The Categorical distribution is parameterized by either probabilities or
  log-probabilities of a set of `K` classes. It is defined over the integers
  `{0, 1, ..., K-1}`.

  The Categorical distribution is closely related to the `OneHotCategorical` and
  `Multinomial` distributions.  The Categorical distribution can be intuited as
  generating samples according to `argmax{ OneHotCategorical(probs) }` itself
  being identical to `argmax{ Multinomial(probs, total_count=1) }`.

  #### Mathematical Details

  The probability mass function (pmf) is,

  ```none
  pmf(k; pi) = prod_j pi_j**[k == j]
  ```

  #### Pitfalls

  The number of classes, `K`, must not exceed:

  - the largest integer representable by `self.dtype`, i.e.,
    `2**(mantissa_bits+1)` (IEEE 754),
  - the maximum `Tensor` index, i.e., `2**31-1`.

  In other words,

  ```python
  K <= min(2**31-1, {
    tf.float16: 2**11,
    tf.float32: 2**24,
    tf.float64: 2**53 }[param.dtype])
  ```

  Note: This condition is validated only when `self.validate_args = True`.

  #### Examples

  Creates a 3-class distribution with the 2nd class being most likely.

  ```python
  dist = Categorical(probs=[0.1, 0.5, 0.4])
  n = 1e4
  empirical_prob = tf.cast(
      tf.histogram_fixed_width(
        dist.sample(int(n)),
        [0., 2],
        nbins=3),
      dtype=tf.float32) / n
  # ==> array([ 0.1005,  0.5037,  0.3958], dtype=float32)
  ```

  Creates a 3-class distribution with the 2nd class being most likely.
  Parameterized by [logits](https://en.wikipedia.org/wiki/Logit) rather than
  probabilities.

  ```python
  dist = Categorical(logits=np.log([0.1, 0.5, 0.4]))
  n = 1e4
  empirical_prob = tf.cast(
      tf.histogram_fixed_width(
        dist.sample(int(n)),
        [0., 2],
        nbins=3),
      dtype=tf.float32) / n
  # ==> array([0.1045,  0.5047, 0.3908], dtype=float32)
  ```

  Creates a 3-class distribution with the 3rd class being most likely.
  The distribution functions can be evaluated on counts.

  ```python
  # counts is a scalar.
  p = [0.1, 0.4, 0.5]
  dist = Categorical(probs=p)
  dist.prob(0)  # Shape []

  # p will be broadcast to [[0.1, 0.4, 0.5], [0.1, 0.4, 0.5]] to match counts.
  counts = [1, 0]
  dist.prob(counts)  # Shape [2]

  # p will be broadcast to shape [5, 7, 3, 3] to match counts.
  counts = [[...]]  # Shape [5, 7, 3]
  dist.prob(counts)  # Shape [5, 7, 3]
  ```

  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(
      self,
      logits=None,
      probs=None,
      dtype=dtypes.int32,
      validate_args=False,
      allow_nan_stats=True,
      name="Categorical"):
    """Initialize Categorical distributions using class log-probabilities.

    Args:
      logits: An N-D `Tensor`, `N >= 1`, representing the log probabilities
        of a set of Categorical distributions. The first `N - 1` dimensions
        index into a batch of independent distributions and the last dimension
        represents a vector of logits for each class. Only one of `logits` or
        `probs` should be passed in.
      probs: An N-D `Tensor`, `N >= 1`, representing the probabilities
        of a set of Categorical distributions. The first `N - 1` dimensions
        index into a batch of independent distributions and the last dimension
        represents a vector of probabilities for each class. Only one of
        `logits` or `probs` should be passed in.
      dtype: The type of the event samples (default: int32).
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Keep this as the first statement: `locals()` must be captured before any
    # other local name is bound so that it holds only the constructor
    # arguments.
    parameters = dict(locals())
    with ops.name_scope(name, values=[logits, probs]) as name:
      self._logits, self._probs = distribution_util.get_logits_and_probs(
          logits=logits,
          probs=probs,
          validate_args=validate_args,
          multidimensional=True,
          name=name)

      if validate_args:
        self._logits = distribution_util.embed_check_categorical_event_shape(
            self._logits)

      logits_shape_static = self._logits.get_shape().with_rank_at_least(1)

      # Batch rank: a constant when the rank is known statically, otherwise
      # computed from the runtime rank of the logits.
      if logits_shape_static.ndims is not None:
        self._batch_rank = ops.convert_to_tensor(
            logits_shape_static.ndims - 1,
            dtype=dtypes.int32,
            name="batch_rank")
      else:
        with ops.name_scope(name="batch_rank"):
          self._batch_rank = array_ops.rank(self._logits) - 1

      logits_shape = array_ops.shape(self._logits, name="logits_shape")

      # Event size (number of classes): read the static value of the last
      # dimension once and fall back to the dynamic shape when it is unknown.
      event_size_static = tensor_shape.dimension_value(logits_shape_static[-1])
      if event_size_static is not None:
        self._event_size = ops.convert_to_tensor(
            event_size_static,
            dtype=dtypes.int32,
            name="event_size")
      else:
        with ops.name_scope(name="event_size"):
          self._event_size = logits_shape[self._batch_rank]

      # Batch shape: every dimension of the logits except the last.
      if logits_shape_static[:-1].is_fully_defined():
        self._batch_shape_val = constant_op.constant(
            logits_shape_static[:-1].as_list(),
            dtype=dtypes.int32,
            name="batch_shape")
      else:
        with ops.name_scope(name="batch_shape"):
          self._batch_shape_val = logits_shape[:-1]

    # The explicit two-argument `super` form is kept on purpose: the module
    # uses `__future__` imports for Python 2 compatibility.
    super(Categorical, self).__init__(
        dtype=dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._logits,
                       self._probs],
        name=name)

  @property
  def event_size(self):
    """Scalar `int32` tensor: the number of classes."""
    return self._event_size

  @property
  def logits(self):
    """Vector of coordinatewise logits."""
    return self._logits

  @property
  def probs(self):
    """Vector of coordinatewise probabilities."""
    return self._probs

  # NOTE: the private methods below are documented with `#` comments rather
  # than docstrings, so that nothing here can leak into generated public API
  # documentation.

  # Dynamic batch shape, as computed in `__init__`.
  def _batch_shape_tensor(self):
    return array_ops.identity(self._batch_shape_val)

  # Static batch shape: the logits' static shape without the class axis.
  def _batch_shape(self):
    return self.logits.get_shape()[:-1]

  # Events are scalar class indices, so the event shape is empty.
  def _event_shape_tensor(self):
    return constant_op.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    return tensor_shape.TensorShape([])

  # Draws `n` samples per batch member; result has shape `[n] + batch_shape`.
  def _sample_n(self, n, seed=None):
    # `random_ops.multinomial` is fed 2-D logits, so flatten the batch axes
    # unless the logits are already 2-D.
    if self.logits.get_shape().ndims == 2:
      logits_2d = self.logits
    else:
      logits_2d = array_ops.reshape(self.logits, [-1, self.event_size])
    # Sample in int64 only when the requested dtype is wider than 4 bytes.
    sample_dtype = dtypes.int64 if self.dtype.size > 4 else dtypes.int32
    draws = random_ops.multinomial(
        logits_2d, n, seed=seed, output_dtype=sample_dtype)
    # Transpose so the sample axis comes first, then restore the batch shape.
    draws = array_ops.reshape(
        array_ops.transpose(draws),
        array_ops.concat([[n], self.batch_shape_tensor()], 0))
    return math_ops.cast(draws, self.dtype)

  # Sums, for each batch member, the probabilities selected by
  # `sequence_mask(k, event_size)`.
  # NOTE(review): `sequence_mask` selects class indices strictly below `k`, so
  # this looks like P(X < k) rather than the conventional P(X <= k). Left
  # unchanged because callers may rely on it; confirm the intended semantics.
  def _cdf(self, k):
    k = ops.convert_to_tensor(k, name="k")
    if self.validate_args:
      k = distribution_util.embed_check_integer_casting_closed(
          k, target_dtype=dtypes.int32)

    k, probs = _broadcast_cat_event_and_params(
        k, self.probs, base_dtype=self.dtype.base_dtype)

    # batch-flatten everything in order to use `sequence_mask()`.
    batch_flattened_probs = array_ops.reshape(probs,
                                              (-1, self._event_size))
    batch_flattened_k = array_ops.reshape(k, [-1])

    to_sum_over = array_ops.where(
        array_ops.sequence_mask(batch_flattened_k, self._event_size),
        batch_flattened_probs,
        array_ops.zeros_like(batch_flattened_probs))
    batch_flattened_cdf = math_ops.reduce_sum(to_sum_over, axis=-1)
    # Reshape back to the shape of the argument.
    return array_ops.reshape(batch_flattened_cdf, array_ops.shape(k))

  # Log-probability of class `k`, computed as the negated sparse softmax
  # cross-entropy between `k` and the logits.
  def _log_prob(self, k):
    k = ops.convert_to_tensor(k, name="k")
    if self.validate_args:
      k = distribution_util.embed_check_integer_casting_closed(
          k, target_dtype=dtypes.int32)
    k, logits = _broadcast_cat_event_and_params(
        k, self.logits, base_dtype=self.dtype.base_dtype)

    return -nn_ops.sparse_softmax_cross_entropy_with_logits(labels=k,
                                                            logits=logits)

  # Entropy: -sum_k probs_k * log_softmax(logits)_k over the class axis.
  def _entropy(self):
    return -math_ops.reduce_sum(
        nn_ops.log_softmax(self.logits) * self.probs, axis=-1)

  # Mode: index of the largest logit along the class axis, cast to
  # `self.dtype`.
  def _mode(self):
    ret = math_ops.argmax(self.logits, axis=self._batch_rank)
    ret = math_ops.cast(ret, self.dtype)
    ret.set_shape(self.batch_shape)
    return ret
@kullback_leibler.RegisterKL(Categorical, Categorical)
def _kl_categorical_categorical(a, b, name=None):
  """Calculate the batched KL divergence KL(a || b) with a and b Categorical.

  Args:
    a: instance of a Categorical distribution object.
    b: instance of a Categorical distribution object.
    name: (optional) Name to use for created operations.
      default is "kl_categorical_categorical".

  Returns:
    Batchwise KL(a || b)
  """
  with ops.name_scope(name, "kl_categorical_categorical",
                      values=[a.logits, b.logits]):
    # KL(a || b) = sum_k p_a(k) * (log p_a(k) - log p_b(k)), where the
    # normalized log-probabilities come from `log_softmax` of each
    # distribution's logits and the sum runs over the class (last) axis.
    log_probs_a = nn_ops.log_softmax(a.logits)
    log_probs_b = nn_ops.log_softmax(b.logits)
    log_ratio = log_probs_a - log_probs_b
    probs_a = nn_ops.softmax(a.logits)
    return math_ops.reduce_sum(probs_a * log_ratio, axis=-1)