Breaking change: Rename tf.contrib.distributions.Independent parameter from
`reduce_batch_ndims` to `reinterpreted_batch_ndims`. Also change the default:
when unspecified, `reinterpreted_batch_ndims` now has the semantics of
`tf.layers.flatten`, i.e., all batch dimensions except the first (batch axis 0)
are interpreted as being part of the event.
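A minimal before/after sketch of the rename and the new default (this assumes
the TF 1.x `tf.contrib.distributions` API of this commit; the `Normal` and its
shapes below are illustrative, not taken from the commit):

```python
import numpy as np
import tensorflow as tf

ds = tf.contrib.distributions

# A [2, 3, 4]-batch of scalar Normals: batch_shape == [2, 3, 4], event_shape == [].
normal = ds.Normal(loc=np.zeros([2, 3, 4], dtype=np.float32), scale=1.)

# Before this commit:
#   ind = ds.Independent(distribution=normal, reduce_batch_ndims=1)
# After it, the keyword is renamed:
ind = ds.Independent(distribution=normal, reinterpreted_batch_ndims=1)
# ind.batch_shape == [2, 3]; ind.event_shape == [4]

# New default: omitting reinterpreted_batch_ndims absorbs every batch dim
# except the first (batch axis 0) into the event, like tf.layers.flatten.
ind_default = ds.Independent(distribution=normal)
# ind_default.batch_shape == [2]; ind_default.event_shape == [3, 4]
```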

PiperOrigin-RevId: 173729585
Author: Joshua V. Dillon
Date: 2017-10-27 15:59:46 -07:00
Committed by: TensorFlower Gardener
Parent: 5426a3c93d
Commit: 80374a7b47
2 changed files with 130 additions and 48 deletions

File: independent_test.py

@@ -23,7 +23,10 @@ import numpy as np
 from tensorflow.contrib.distributions.python.ops import independent as independent_lib
 from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.distributions import bernoulli as bernoulli_lib
 from tensorflow.python.ops.distributions import normal as normal_lib
 from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging
@@ -42,13 +45,16 @@ stats = try_import("scipy.stats")
 class ProductDistributionTest(test.TestCase):

   def setUp(self):
     self._rng = np.random.RandomState(42)

   def testSampleAndLogProbUnivariate(self):
     loc = np.float32([-1., 1])
     scale = np.float32([0.1, 0.5])
     with self.test_session() as sess:
       ind = independent_lib.Independent(
           distribution=normal_lib.Normal(loc=loc, scale=scale),
-          reduce_batch_ndims=1)
+          reinterpreted_batch_ndims=1)

       x = ind.sample([4, 5])
       log_prob_x = ind.log_prob(x)
@@ -71,7 +77,7 @@ class ProductDistributionTest(test.TestCase):
           distribution=mvn_diag_lib.MultivariateNormalDiag(
               loc=loc,
               scale_identity_multiplier=scale),
-          reduce_batch_ndims=1)
+          reinterpreted_batch_ndims=1)

       x = ind.sample([4, 5])
       log_prob_x = ind.log_prob(x)
@@ -96,7 +102,7 @@ class ProductDistributionTest(test.TestCase):
           distribution=mvn_diag_lib.MultivariateNormalDiag(
               loc=loc,
               scale_identity_multiplier=scale),
-          reduce_batch_ndims=1)
+          reinterpreted_batch_ndims=1)

       x = ind.sample(int(n_samp), seed=42)
       sample_mean = math_ops.reduce_mean(x, axis=0)
@@ -120,6 +126,59 @@ class ProductDistributionTest(test.TestCase):
     self.assertAllClose(sample_entropy_, actual_entropy_, rtol=0.01, atol=0.)
     self.assertAllClose(loc, actual_mode_, rtol=1e-6, atol=0.)

+  def _testMnistLike(self, static_shape):
+    sample_shape = [4, 5]
+    batch_shape = [10]
+    image_shape = [28, 28, 1]
+    logits = 3 * self._rng.random_sample(
+        batch_shape + image_shape).astype(np.float32) - 1
+
+    def expected_log_prob(x, logits):
+      return (x * logits - np.log1p(np.exp(logits))).sum(-1).sum(-1).sum(-1)
+
+    with self.test_session() as sess:
+      logits_ph = array_ops.placeholder(
+          dtypes.float32, shape=logits.shape if static_shape else None)
+      ind = independent_lib.Independent(
+          distribution=bernoulli_lib.Bernoulli(logits=logits_ph))
+      x = ind.sample(sample_shape)
+      log_prob_x = ind.log_prob(x)
+      [
+          x_,
+          actual_log_prob_x,
+          ind_batch_shape,
+          ind_event_shape,
+          x_shape,
+          log_prob_x_shape,
+      ] = sess.run([
+          x,
+          log_prob_x,
+          ind.batch_shape_tensor(),
+          ind.event_shape_tensor(),
+          array_ops.shape(x),
+          array_ops.shape(log_prob_x),
+      ], feed_dict={logits_ph: logits})
+
+      if static_shape:
+        ind_batch_shape = ind.batch_shape
+        ind_event_shape = ind.event_shape
+        x_shape = x.shape
+        log_prob_x_shape = log_prob_x.shape
+
+      self.assertAllEqual(batch_shape, ind_batch_shape)
+      self.assertAllEqual(image_shape, ind_event_shape)
+      self.assertAllEqual(sample_shape + batch_shape + image_shape, x_shape)
+      self.assertAllEqual(sample_shape + batch_shape, log_prob_x_shape)
+      self.assertAllClose(expected_log_prob(x_, logits),
+                          actual_log_prob_x,
+                          rtol=1e-6, atol=0.)
+
+  def testMnistLikeStaticShape(self):
+    self._testMnistLike(static_shape=True)
+
+  def testMnistLikeDynamicShape(self):
+    self._testMnistLike(static_shape=False)


 if __name__ == "__main__":
   test.main()
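As a cross-check of the shape contract the new `_testMnistLike` test asserts,
here is a NumPy-only sketch of the same bookkeeping (it mirrors the test's
shapes and its `expected_log_prob`; no TensorFlow is involved):

```python
import numpy as np

rng = np.random.RandomState(42)
sample_shape, batch_shape, image_shape = [4, 5], [10], [28, 28, 1]
logits = 3 * rng.random_sample(batch_shape + image_shape).astype(np.float32) - 1

# Draw Bernoulli samples by thresholding uniforms against sigmoid(logits);
# logits broadcasts against the leading sample dims.
probs = 1. / (1. + np.exp(-logits))
x = (rng.random_sample(sample_shape + batch_shape + image_shape)
     < probs).astype(np.float32)

# Independent(Bernoulli(logits)) with the default reinterpreted_batch_ndims
# treats [28, 28, 1] as the event, so log_prob sums the per-pixel Bernoulli
# log-probs, x * logits - log(1 + exp(logits)), over the last three axes.
per_pixel = x * logits - np.log1p(np.exp(logits))
log_prob = per_pixel.sum(-1).sum(-1).sum(-1)
assert log_prob.shape == tuple(sample_shape + batch_shape)  # (4, 5, 10)
```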

File: tensorflow/contrib/distributions/python/ops/independent.py

@@ -45,24 +45,24 @@ class Independent(distribution_lib.Distribution):
   `p(x_1, ..., x_B) = p_1(x_1) * ... * p_B(x_B)` where `p_b(X_b)` is the
   probability of the `b`-th rv. More generally `B, E` can be arbitrary shapes.

-  Similarly, the `Independent` distribution specifies a distribution over
-  `[B, E]`-shaped events. It operates by reinterpreting the rightmost batch dims
-  as part of the event dimensions. The `reduce_batch_ndims` parameter controls
-  the number of batch dims which are absorbed as event dims;
-  `reduce_batch_ndims < len(batch_shape)`. For example, the `log_prob` function
-  entails a `reduce_sum` over the rightmost `reduce_batch_ndims` after calling
-  the base distribution's `log_prob`. In other words, since the batch
-  dimension(s) index independent distributions, the resultant multivariate will
-  have independent components.
+  Similarly, the `Independent` distribution specifies a distribution over `[B,
+  E]`-shaped events. It operates by reinterpreting the rightmost batch dims as
+  part of the event dimensions. The `reinterpreted_batch_ndims` parameter
+  controls the number of batch dims which are absorbed as event dims;
+  `reinterpreted_batch_ndims < len(batch_shape)`. For example, the `log_prob`
+  function entails a `reduce_sum` over the rightmost `reinterpreted_batch_ndims`
+  after calling the base distribution's `log_prob`. In other words, since the
+  batch dimension(s) index independent distributions, the resultant multivariate
+  will have independent components.

   #### Mathematical Details

   The probability function is,

   ```none
-  prob(x; reduce_batch_ndims) = tf.reduce_prod(
+  prob(x; reinterpreted_batch_ndims) = tf.reduce_prod(
       dist.prob(x),
-      axis=-1-range(reduce_batch_ndims))
+      axis=-1-range(reinterpreted_batch_ndims))
   ```

   #### Examples
@@ -73,7 +73,7 @@ class Independent(distribution_lib.Distribution):
   # Make independent distribution from a 2-batch Normal.
   ind = ds.Independent(
       distribution=ds.Normal(loc=[-1., 1], scale=[0.1, 0.5]),
-      reduce_batch_ndims=1)
+      reinterpreted_batch_ndims=1)

   # All batch dims have been "absorbed" into event dims.
   ind.batch_shape  # ==> []
@@ -84,7 +84,7 @@ class Independent(distribution_lib.Distribution):
       distribution=ds.MultivariateNormalDiag(
           loc=[[-1., 1], [1, -1]],
           scale_identity_multiplier=[1., 0.5]),
-      reduce_batch_ndims=1)
+      reinterpreted_batch_ndims=1)

   # All batch dims have been "absorbed" into event dims.
   ind.batch_shape  # ==> []
@@ -94,14 +94,17 @@ class Independent(distribution_lib.Distribution):
   """

   def __init__(
-      self, distribution, reduce_batch_ndims=1, validate_args=False, name=None):
+      self, distribution, reinterpreted_batch_ndims=None,
+      validate_args=False, name=None):
     """Construct a `Independent` distribution.

     Args:
       distribution: The base distribution instance to transform. Typically an
         instance of `Distribution`.
-      reduce_batch_ndims: Scalar, integer number of rightmost batch dims which
-        will be regard as event dims.
+      reinterpreted_batch_ndims: Scalar, integer number of rightmost batch dims
+        which will be regarded as event dims. When `None` all but the first
+        batch axis (batch axis 0) will be transferred to event dimensions
+        (analogous to `tf.layers.flatten`).
       validate_args: Python `bool`. Whether to validate input with asserts.
         If `validate_args` is `False`, and the inputs are invalid,
         correct behavior is not guaranteed.
@@ -109,19 +112,25 @@ class Independent(distribution_lib.Distribution):
         Default value: `Independent + distribution.name`.

     Raises:
-      ValueError: if `reduce_batch_ndims` exceeds `distribution.batch_ndims`
+      ValueError: if `reinterpreted_batch_ndims` exceeds
+        `distribution.batch_ndims`
     """
     parameters = locals()
     name = name or "Independent" + distribution.name
     self._distribution = distribution
     with ops.name_scope(name):
-      reduce_batch_ndims = ops.convert_to_tensor(
-          reduce_batch_ndims, dtype=dtypes.int32, name="reduce_batch_ndims")
-      self._reduce_batch_ndims = reduce_batch_ndims
-      self._static_reduce_batch_ndims = tensor_util.constant_value(
-          reduce_batch_ndims)
-      if self._static_reduce_batch_ndims is not None:
-        self._reduce_batch_ndims = self._static_reduce_batch_ndims
+      if reinterpreted_batch_ndims is None:
+        reinterpreted_batch_ndims = self._get_default_reinterpreted_batch_ndims(
+            distribution)
+      reinterpreted_batch_ndims = ops.convert_to_tensor(
+          reinterpreted_batch_ndims,
+          dtype=dtypes.int32,
+          name="reinterpreted_batch_ndims")
+      self._reinterpreted_batch_ndims = reinterpreted_batch_ndims
+      self._static_reinterpreted_batch_ndims = tensor_util.constant_value(
+          reinterpreted_batch_ndims)
+      if self._static_reinterpreted_batch_ndims is not None:
+        self._reinterpreted_batch_ndims = self._static_reinterpreted_batch_ndims
       super(Independent, self).__init__(
           dtype=self._distribution.dtype,
           reparameterization_type=self._distribution.reparameterization_type,
@@ -129,19 +138,19 @@ class Independent(distribution_lib.Distribution):
           allow_nan_stats=self._distribution.allow_nan_stats,
           parameters=parameters,
           graph_parents=(
-              [reduce_batch_ndims] +
+              [reinterpreted_batch_ndims] +
               distribution._graph_parents),  # pylint: disable=protected-access
           name=name)
       self._runtime_assertions = self._make_runtime_assertions(
-          distribution, reduce_batch_ndims, validate_args)
+          distribution, reinterpreted_batch_ndims, validate_args)

   @property
   def distribution(self):
     return self._distribution

   @property
-  def reduce_batch_ndims(self):
-    return self._reduce_batch_ndims
+  def reinterpreted_batch_ndims(self):
+    return self._reinterpreted_batch_ndims

   def _batch_shape_tensor(self):
     with ops.control_dependencies(self._runtime_assertions):
@@ -149,13 +158,14 @@ class Independent(distribution_lib.Distribution):
       batch_ndims = (batch_shape.shape[0].value
                      if batch_shape.shape.with_rank_at_least(1)[0].value
                      else array_ops.shape(batch_shape)[0])
-      return batch_shape[:batch_ndims - self.reduce_batch_ndims]
+      return batch_shape[:batch_ndims - self.reinterpreted_batch_ndims]

   def _batch_shape(self):
     batch_shape = self.distribution.batch_shape
-    if self._static_reduce_batch_ndims is None or batch_shape.ndims is None:
+    if (self._static_reinterpreted_batch_ndims is None
+        or batch_shape.ndims is None):
       return tensor_shape.TensorShape(None)
-    d = batch_shape.ndims - self._static_reduce_batch_ndims
+    d = batch_shape.ndims - self._static_reinterpreted_batch_ndims
     return batch_shape[:d]

   def _event_shape_tensor(self):
@@ -165,15 +175,16 @@ class Independent(distribution_lib.Distribution):
                      if batch_shape.shape.with_rank_at_least(1)[0].value
                      else array_ops.shape(batch_shape)[0])
       return array_ops.concat([
-          batch_shape[batch_ndims - self.reduce_batch_ndims:],
+          batch_shape[batch_ndims - self.reinterpreted_batch_ndims:],
           self.distribution.event_shape_tensor(),
       ], axis=0)

   def _event_shape(self):
     batch_shape = self.distribution.batch_shape
-    if self._static_reduce_batch_ndims is None or batch_shape.ndims is None:
+    if (self._static_reinterpreted_batch_ndims is None
+        or batch_shape.ndims is None):
       return tensor_shape.TensorShape(None)
-    d = batch_shape.ndims - self._static_reduce_batch_ndims
+    d = batch_shape.ndims - self._static_reinterpreted_batch_ndims
     return batch_shape[d:].concatenate(self.distribution.event_shape)

   def _sample_n(self, n, seed):
@@ -205,15 +216,16 @@ class Independent(distribution_lib.Distribution):
     return self.distribution.mode()

   def _make_runtime_assertions(
-      self, distribution, reduce_batch_ndims, validate_args):
+      self, distribution, reinterpreted_batch_ndims, validate_args):
     assertions = []
-    static_reduce_batch_ndims = tensor_util.constant_value(reduce_batch_ndims)
+    static_reinterpreted_batch_ndims = tensor_util.constant_value(
+        reinterpreted_batch_ndims)
     batch_ndims = distribution.batch_shape.ndims
-    if batch_ndims is not None and static_reduce_batch_ndims is not None:
-      if static_reduce_batch_ndims > batch_ndims:
-        raise ValueError("reduce_batch_ndims({}) cannot exceed "
+    if batch_ndims is not None and static_reinterpreted_batch_ndims is not None:
+      if static_reinterpreted_batch_ndims > batch_ndims:
+        raise ValueError("reinterpreted_batch_ndims({}) cannot exceed "
                          "distribution.batch_ndims({})".format(
-                             static_reduce_batch_ndims, batch_ndims))
+                             static_reinterpreted_batch_ndims, batch_ndims))
     elif validate_args:
       batch_shape = distribution.batch_shape_tensor()
       batch_ndims = (
@@ -221,13 +233,24 @@ class Independent(distribution_lib.Distribution):
           if batch_shape.shape.with_rank_at_least(1)[0].value is not None
           else array_ops.shape(batch_shape)[0])
       assertions.append(check_ops.assert_less_equal(
-          reduce_batch_ndims, batch_ndims,
-          message="reduce_batch_ndims cannot exceed distribution.batch_ndims"))
+          reinterpreted_batch_ndims, batch_ndims,
+          message=("reinterpreted_batch_ndims cannot exceed "
+                   "distribution.batch_ndims")))
     return assertions

   def _reduce_sum(self, stat):
-    if self._static_reduce_batch_ndims is None:
-      range_ = array_ops.range(self._reduce_batch_ndims)
+    if self._static_reinterpreted_batch_ndims is None:
+      range_ = math_ops.range(self._reinterpreted_batch_ndims)
     else:
-      range_ = np.arange(self._static_reduce_batch_ndims)
+      range_ = np.arange(self._static_reinterpreted_batch_ndims)
     return math_ops.reduce_sum(stat, axis=-1-range_)

+  def _get_default_reinterpreted_batch_ndims(self, distribution):
+    """Computes the default value for reinterpreted_batch_ndim __init__ arg."""
+    ndims = distribution.batch_shape.ndims
+    if ndims is None:
+      which_maximum = math_ops.maximum
+      ndims = array_ops.shape(distribution.batch_shape_tensor())[0]
+    else:
+      which_maximum = np.maximum
+    return which_maximum(0, ndims - 1)
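The default added above resolves to `max(0, batch_ndims - 1)`: statically via
`np.maximum` when the batch rank is known, and dynamically via
`tf.maximum`/`tf.shape` otherwise. A usage sketch, under the same TF 1.x
`tf.contrib.distributions` assumptions as this commit:

```python
import tensorflow as tf

ds = tf.contrib.distributions

# Static batch rank: batch_shape == [10, 28, 28, 1], so the default
# reinterpreted_batch_ndims is max(0, 4 - 1) = 3, computed with np.maximum
# at graph-construction time.
ind = ds.Independent(
    distribution=ds.Bernoulli(logits=tf.zeros([10, 28, 28, 1])))
print(ind.batch_shape)  # (10,)
print(ind.event_shape)  # (28, 28, 1)

# Unknown batch rank: the default falls back to the dynamic branch,
# tf.maximum(0, tf.shape(batch_shape_tensor())[0] - 1), resolved only at
# session run time, so the static batch/event shapes are unknown.
logits_ph = tf.placeholder(tf.float32, shape=None)
ind_dyn = ds.Independent(distribution=ds.Bernoulli(logits=logits_ph))
print(ind_dyn.batch_shape)  # <unknown>
```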