Breaking change: Rename the `tf.contrib.distributions.Independent` parameter from `reduce_batch_ndims` to `reinterpreted_batch_ndims`. Also change the default: when left unset, `reinterpreted_batch_ndims` now has the semantics of `tf.layers.flatten`, i.e., all batch dimensions except the first (batch axis 0) are interpreted as part of the event.

PiperOrigin-RevId: 173729585
commit 80374a7b47
parent 5426a3c93d
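Note (illustrative, not part of the commit): a minimal sketch of how call sites change and what the new `None` default does. It follows the `ds = tf.contrib.distributions` convention used in the `Independent` docstring examples below; the shapes mirror the commit message and the new MNIST-like test.

```python
import tensorflow as tf

ds = tf.contrib.distributions

# A 2-batch Normal: batch_shape=[2], event_shape=[].
normal = ds.Normal(loc=[-1., 1.], scale=[0.1, 0.5])

# Before this commit:
#   ind = ds.Independent(normal, reduce_batch_ndims=1)
# After this commit, the parameter is renamed:
ind = ds.Independent(normal, reinterpreted_batch_ndims=1)
# ind.batch_shape ==> [], ind.event_shape ==> [2]

# New default: leaving reinterpreted_batch_ndims unset absorbs all batch
# dims except the leftmost (batch axis 0), analogous to tf.layers.flatten.
bern = ds.Bernoulli(logits=tf.zeros([10, 28, 28, 1]))  # batch_shape=[10, 28, 28, 1]
ind_default = ds.Independent(bern)
# ind_default.batch_shape ==> [10], ind_default.event_shape ==> [28, 28, 1]
```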
@@ -23,7 +23,10 @@ import numpy as np
 from tensorflow.contrib.distributions.python.ops import independent as independent_lib
 from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.distributions import bernoulli as bernoulli_lib
 from tensorflow.python.ops.distributions import normal as normal_lib
 from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging
@@ -42,13 +45,16 @@ stats = try_import("scipy.stats")
 
 class ProductDistributionTest(test.TestCase):
 
+  def setUp(self):
+    self._rng = np.random.RandomState(42)
+
   def testSampleAndLogProbUnivariate(self):
     loc = np.float32([-1., 1])
     scale = np.float32([0.1, 0.5])
     with self.test_session() as sess:
       ind = independent_lib.Independent(
           distribution=normal_lib.Normal(loc=loc, scale=scale),
-          reduce_batch_ndims=1)
+          reinterpreted_batch_ndims=1)
 
       x = ind.sample([4, 5])
       log_prob_x = ind.log_prob(x)
@@ -71,7 +77,7 @@ class ProductDistributionTest(test.TestCase):
           distribution=mvn_diag_lib.MultivariateNormalDiag(
               loc=loc,
               scale_identity_multiplier=scale),
-          reduce_batch_ndims=1)
+          reinterpreted_batch_ndims=1)
 
       x = ind.sample([4, 5])
       log_prob_x = ind.log_prob(x)
@@ -96,7 +102,7 @@ class ProductDistributionTest(test.TestCase):
           distribution=mvn_diag_lib.MultivariateNormalDiag(
               loc=loc,
               scale_identity_multiplier=scale),
-          reduce_batch_ndims=1)
+          reinterpreted_batch_ndims=1)
 
       x = ind.sample(int(n_samp), seed=42)
       sample_mean = math_ops.reduce_mean(x, axis=0)
@@ -120,6 +126,59 @@ class ProductDistributionTest(test.TestCase):
     self.assertAllClose(sample_entropy_, actual_entropy_, rtol=0.01, atol=0.)
     self.assertAllClose(loc, actual_mode_, rtol=1e-6, atol=0.)
 
+  def _testMnistLike(self, static_shape):
+    sample_shape = [4, 5]
+    batch_shape = [10]
+    image_shape = [28, 28, 1]
+    logits = 3 * self._rng.random_sample(
+        batch_shape + image_shape).astype(np.float32) - 1
+
+    def expected_log_prob(x, logits):
+      return (x * logits - np.log1p(np.exp(logits))).sum(-1).sum(-1).sum(-1)
+
+    with self.test_session() as sess:
+      logits_ph = array_ops.placeholder(
+          dtypes.float32, shape=logits.shape if static_shape else None)
+      ind = independent_lib.Independent(
+          distribution=bernoulli_lib.Bernoulli(logits=logits_ph))
+      x = ind.sample(sample_shape)
+      log_prob_x = ind.log_prob(x)
+      [
+          x_,
+          actual_log_prob_x,
+          ind_batch_shape,
+          ind_event_shape,
+          x_shape,
+          log_prob_x_shape,
+      ] = sess.run([
+          x,
+          log_prob_x,
+          ind.batch_shape_tensor(),
+          ind.event_shape_tensor(),
+          array_ops.shape(x),
+          array_ops.shape(log_prob_x),
+      ], feed_dict={logits_ph: logits})
+
+      if static_shape:
+        ind_batch_shape = ind.batch_shape
+        ind_event_shape = ind.event_shape
+        x_shape = x.shape
+        log_prob_x_shape = log_prob_x.shape
+
+      self.assertAllEqual(batch_shape, ind_batch_shape)
+      self.assertAllEqual(image_shape, ind_event_shape)
+      self.assertAllEqual(sample_shape + batch_shape + image_shape, x_shape)
+      self.assertAllEqual(sample_shape + batch_shape, log_prob_x_shape)
+      self.assertAllClose(expected_log_prob(x_, logits),
+                          actual_log_prob_x,
+                          rtol=1e-6, atol=0.)
+
+  def testMnistLikeStaticShape(self):
+    self._testMnistLike(static_shape=True)
+
+  def testMnistLikeDynamicShape(self):
+    self._testMnistLike(static_shape=False)
+
 
 if __name__ == "__main__":
   test.main()
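Note (illustrative, not part of the commit): the new `_testMnistLike` test relies on the Bernoulli identity `log p(x) = x * logits - log(1 + exp(logits))`, with `Independent` summing the per-pixel terms over the reinterpreted event dims `[28, 28, 1]`. A NumPy-only sketch of that shape bookkeeping, assuming nothing beyond NumPy:

```python
import numpy as np

rng = np.random.RandomState(42)
logits = 3 * rng.random_sample([10, 28, 28, 1]).astype(np.float32) - 1
# Draw binary samples with shape sample_shape + batch_shape + event_shape.
probs = 1. / (1. + np.exp(-logits))
x = (rng.uniform(size=[4, 5, 10, 28, 28, 1]) < probs).astype(np.float32)

# Per-pixel Bernoulli log-prob: x * logits - log(1 + exp(logits)).
per_pixel = x * logits - np.log1p(np.exp(logits))

# Independent's log_prob sums over the reinterpreted event dims [28, 28, 1],
# leaving shape sample_shape + batch_shape = [4, 5, 10].
log_prob = per_pixel.sum(axis=(-1, -2, -3))
assert log_prob.shape == (4, 5, 10)
```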
@@ -45,24 +45,24 @@ class Independent(distribution_lib.Distribution):
   `p(x_1, ..., x_B) = p_1(x_1) * ... * p_B(x_B)` where `p_b(X_b)` is the
   probability of the `b`-th rv. More generally `B, E` can be arbitrary shapes.
 
-  Similarly, the `Independent` distribution specifies a distribution over
-  `[B, E]`-shaped events. It operates by reinterpreting the rightmost batch dims
-  as part of the event dimensions. The `reduce_batch_ndims` parameter controls
-  the number of batch dims which are absorbed as event dims;
-  `reduce_batch_ndims < len(batch_shape)`. For example, the `log_prob` function
-  entails a `reduce_sum` over the rightmost `reduce_batch_ndims` after calling
-  the base distribution's `log_prob`. In other words, since the batch
-  dimension(s) index independent distributions, the resultant multivariate will
-  have independent components.
+  Similarly, the `Independent` distribution specifies a distribution over `[B,
+  E]`-shaped events. It operates by reinterpreting the rightmost batch dims as
+  part of the event dimensions. The `reinterpreted_batch_ndims` parameter
+  controls the number of batch dims which are absorbed as event dims;
+  `reinterpreted_batch_ndims < len(batch_shape)`. For example, the `log_prob`
+  function entails a `reduce_sum` over the rightmost `reinterpreted_batch_ndims`
+  after calling the base distribution's `log_prob`. In other words, since the
+  batch dimension(s) index independent distributions, the resultant multivariate
+  will have independent components.
 
   #### Mathematical Details
 
   The probability function is,
 
   ```none
-  prob(x; reduce_batch_ndims) = tf.reduce_prod(
+  prob(x; reinterpreted_batch_ndims) = tf.reduce_prod(
       dist.prob(x),
-      axis=-1-range(reduce_batch_ndims))
+      axis=-1-range(reinterpreted_batch_ndims))
   ```
 
   #### Examples
@@ -73,7 +73,7 @@ class Independent(distribution_lib.Distribution):
   # Make independent distribution from a 2-batch Normal.
   ind = ds.Independent(
       distribution=ds.Normal(loc=[-1., 1], scale=[0.1, 0.5]),
-      reduce_batch_ndims=1)
+      reinterpreted_batch_ndims=1)
 
   # All batch dims have been "absorbed" into event dims.
   ind.batch_shape  # ==> []
@@ -84,7 +84,7 @@ class Independent(distribution_lib.Distribution):
       distribution=ds.MultivariateNormalDiag(
           loc=[[-1., 1], [1, -1]],
          scale_identity_multiplier=[1., 0.5]),
-      reduce_batch_ndims=1)
+      reinterpreted_batch_ndims=1)
 
   # All batch dims have been "absorbed" into event dims.
   ind.batch_shape  # ==> []
@@ -94,14 +94,17 @@ class Independent(distribution_lib.Distribution):
   """
 
   def __init__(
-      self, distribution, reduce_batch_ndims=1, validate_args=False, name=None):
+      self, distribution, reinterpreted_batch_ndims=None,
+      validate_args=False, name=None):
     """Construct a `Independent` distribution.
 
     Args:
       distribution: The base distribution instance to transform. Typically an
         instance of `Distribution`.
-      reduce_batch_ndims: Scalar, integer number of rightmost batch dims which
-        will be regard as event dims.
+      reinterpreted_batch_ndims: Scalar, integer number of rightmost batch dims
+        which will be regarded as event dims. When `None` all but the first
+        batch axis (batch axis 0) will be transferred to event dimensions
+        (analogous to `tf.layers.flatten`).
      validate_args: Python `bool`. Whether to validate input with asserts.
        If `validate_args` is `False`, and the inputs are invalid,
        correct behavior is not guaranteed.
@@ -109,19 +112,25 @@ class Independent(distribution_lib.Distribution):
        Default value: `Independent + distribution.name`.
 
     Raises:
-      ValueError: if `reduce_batch_ndims` exceeds `distribution.batch_ndims`
+      ValueError: if `reinterpreted_batch_ndims` exceeds
+        `distribution.batch_ndims`
     """
    parameters = locals()
    name = name or "Independent" + distribution.name
    self._distribution = distribution
    with ops.name_scope(name):
-      reduce_batch_ndims = ops.convert_to_tensor(
-          reduce_batch_ndims, dtype=dtypes.int32, name="reduce_batch_ndims")
-      self._reduce_batch_ndims = reduce_batch_ndims
-      self._static_reduce_batch_ndims = tensor_util.constant_value(
-          reduce_batch_ndims)
-      if self._static_reduce_batch_ndims is not None:
-        self._reduce_batch_ndims = self._static_reduce_batch_ndims
+      if reinterpreted_batch_ndims is None:
+        reinterpreted_batch_ndims = self._get_default_reinterpreted_batch_ndims(
+            distribution)
+      reinterpreted_batch_ndims = ops.convert_to_tensor(
+          reinterpreted_batch_ndims,
+          dtype=dtypes.int32,
+          name="reinterpreted_batch_ndims")
+      self._reinterpreted_batch_ndims = reinterpreted_batch_ndims
+      self._static_reinterpreted_batch_ndims = tensor_util.constant_value(
+          reinterpreted_batch_ndims)
+      if self._static_reinterpreted_batch_ndims is not None:
+        self._reinterpreted_batch_ndims = self._static_reinterpreted_batch_ndims
    super(Independent, self).__init__(
        dtype=self._distribution.dtype,
        reparameterization_type=self._distribution.reparameterization_type,
@@ -129,19 +138,19 @@ class Independent(distribution_lib.Distribution):
        allow_nan_stats=self._distribution.allow_nan_stats,
        parameters=parameters,
        graph_parents=(
-            [reduce_batch_ndims] +
+            [reinterpreted_batch_ndims] +
            distribution._graph_parents),  # pylint: disable=protected-access
        name=name)
    self._runtime_assertions = self._make_runtime_assertions(
-        distribution, reduce_batch_ndims, validate_args)
+        distribution, reinterpreted_batch_ndims, validate_args)
 
  @property
  def distribution(self):
    return self._distribution
 
  @property
-  def reduce_batch_ndims(self):
-    return self._reduce_batch_ndims
+  def reinterpreted_batch_ndims(self):
+    return self._reinterpreted_batch_ndims
 
  def _batch_shape_tensor(self):
    with ops.control_dependencies(self._runtime_assertions):
@@ -149,13 +158,14 @@ class Independent(distribution_lib.Distribution):
      batch_ndims = (batch_shape.shape[0].value
                     if batch_shape.shape.with_rank_at_least(1)[0].value
                     else array_ops.shape(batch_shape)[0])
-      return batch_shape[:batch_ndims - self.reduce_batch_ndims]
+      return batch_shape[:batch_ndims - self.reinterpreted_batch_ndims]
 
  def _batch_shape(self):
    batch_shape = self.distribution.batch_shape
-    if self._static_reduce_batch_ndims is None or batch_shape.ndims is None:
+    if (self._static_reinterpreted_batch_ndims is None
+        or batch_shape.ndims is None):
      return tensor_shape.TensorShape(None)
-    d = batch_shape.ndims - self._static_reduce_batch_ndims
+    d = batch_shape.ndims - self._static_reinterpreted_batch_ndims
    return batch_shape[:d]
 
  def _event_shape_tensor(self):
@@ -165,15 +175,16 @@ class Independent(distribution_lib.Distribution):
                     if batch_shape.shape.with_rank_at_least(1)[0].value
                     else array_ops.shape(batch_shape)[0])
      return array_ops.concat([
-          batch_shape[batch_ndims - self.reduce_batch_ndims:],
+          batch_shape[batch_ndims - self.reinterpreted_batch_ndims:],
          self.distribution.event_shape_tensor(),
      ], axis=0)
 
  def _event_shape(self):
    batch_shape = self.distribution.batch_shape
-    if self._static_reduce_batch_ndims is None or batch_shape.ndims is None:
+    if (self._static_reinterpreted_batch_ndims is None
+        or batch_shape.ndims is None):
      return tensor_shape.TensorShape(None)
-    d = batch_shape.ndims - self._static_reduce_batch_ndims
+    d = batch_shape.ndims - self._static_reinterpreted_batch_ndims
    return batch_shape[d:].concatenate(self.distribution.event_shape)
 
  def _sample_n(self, n, seed):
@@ -205,15 +216,16 @@ class Independent(distribution_lib.Distribution):
      return self.distribution.mode()
 
  def _make_runtime_assertions(
-      self, distribution, reduce_batch_ndims, validate_args):
+      self, distribution, reinterpreted_batch_ndims, validate_args):
    assertions = []
-    static_reduce_batch_ndims = tensor_util.constant_value(reduce_batch_ndims)
+    static_reinterpreted_batch_ndims = tensor_util.constant_value(
+        reinterpreted_batch_ndims)
    batch_ndims = distribution.batch_shape.ndims
-    if batch_ndims is not None and static_reduce_batch_ndims is not None:
-      if static_reduce_batch_ndims > batch_ndims:
-        raise ValueError("reduce_batch_ndims({}) cannot exceed "
+    if batch_ndims is not None and static_reinterpreted_batch_ndims is not None:
+      if static_reinterpreted_batch_ndims > batch_ndims:
+        raise ValueError("reinterpreted_batch_ndims({}) cannot exceed "
                         "distribution.batch_ndims({})".format(
-                             static_reduce_batch_ndims, batch_ndims))
+                             static_reinterpreted_batch_ndims, batch_ndims))
    elif validate_args:
      batch_shape = distribution.batch_shape_tensor()
      batch_ndims = (
@@ -221,13 +233,24 @@ class Independent(distribution_lib.Distribution):
          if batch_shape.shape.with_rank_at_least(1)[0].value is not None
          else array_ops.shape(batch_shape)[0])
      assertions.append(check_ops.assert_less_equal(
-          reduce_batch_ndims, batch_ndims,
-          message="reduce_batch_ndims cannot exceed distribution.batch_ndims"))
+          reinterpreted_batch_ndims, batch_ndims,
+          message=("reinterpreted_batch_ndims cannot exceed "
+                   "distribution.batch_ndims")))
    return assertions
 
  def _reduce_sum(self, stat):
-    if self._static_reduce_batch_ndims is None:
-      range_ = array_ops.range(self._reduce_batch_ndims)
+    if self._static_reinterpreted_batch_ndims is None:
+      range_ = math_ops.range(self._reinterpreted_batch_ndims)
    else:
-      range_ = np.arange(self._static_reduce_batch_ndims)
+      range_ = np.arange(self._static_reinterpreted_batch_ndims)
    return math_ops.reduce_sum(stat, axis=-1-range_)
+
+  def _get_default_reinterpreted_batch_ndims(self, distribution):
+    """Computes the default value for reinterpreted_batch_ndim __init__ arg."""
+    ndims = distribution.batch_shape.ndims
+    if ndims is None:
+      which_maximum = math_ops.maximum
+      ndims = array_ops.shape(distribution.batch_shape_tensor())[0]
+    else:
+      which_maximum = np.maximum
+    return which_maximum(0, ndims - 1)
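Note (illustrative, not part of the commit): in the static-shape branch, the new `_get_default_reinterpreted_batch_ndims` helper reduces to `max(0, batch_ndims - 1)` — keep batch axis 0, absorb the rest as event dims, mirroring `tf.layers.flatten`. A minimal sketch of that branch:

```python
import numpy as np

def default_reinterpreted_batch_ndims(batch_ndims):
  """Static-shape analogue: keep batch axis 0, absorb the rest as event dims."""
  return np.maximum(0, batch_ndims - 1)

# Bernoulli(logits with shape [10, 28, 28, 1]) has batch rank 4, so the
# default absorbs 3 dims: batch_shape -> [10], event_shape -> [28, 28, 1].
assert default_reinterpreted_batch_ndims(4) == 3
# A scalar- or 1-D-batch distribution keeps its batch shape unchanged.
assert default_reinterpreted_batch_ndims(1) == 0
assert default_reinterpreted_batch_ndims(0) == 0
```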