Breaking change: Rename tf.contrib.distributions.Independent parameter from
`reduce_batch_ndims` to `reinterpreted_batch_ndims`. Also change the default:
`reinterpreted_batch_ndims` now defaults to the semantics of
`tf.layers.flatten`, i.e., all batch dimensions except the first (batch axis 0)
are reinterpreted as being part of the event.

PiperOrigin-RevId: 173729585
parent: 5426a3c93d
commit: 80374a7b47
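For reference, a minimal before/after usage sketch (assuming the TF 1.x `tf.contrib.distributions` API; the constructor arguments mirror the docstring example in the diff below):

```python
import tensorflow as tf

ds = tf.contrib.distributions

# Before this change, the rightmost-batch-dims argument was spelled
# `reduce_batch_ndims`:
#   ind = ds.Independent(
#       distribution=ds.Normal(loc=[-1., 1], scale=[0.1, 0.5]),
#       reduce_batch_ndims=1)

# After this change, it is `reinterpreted_batch_ndims`:
ind = ds.Independent(
    distribution=ds.Normal(loc=[-1., 1], scale=[0.1, 0.5]),
    reinterpreted_batch_ndims=1)

# Omitting the argument now picks the flatten-like default: all batch
# dims except the leftmost (axis 0) become event dims. Here the base
# batch rank is 1, so the default reinterprets 0 dims and batch_shape
# stays [2].
ind_default = ds.Independent(
    distribution=ds.Normal(loc=[-1., 1], scale=[0.1, 0.5]))
```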
@@ -23,7 +23,10 @@ import numpy as np
 
 from tensorflow.contrib.distributions.python.ops import independent as independent_lib
 from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.distributions import bernoulli as bernoulli_lib
 from tensorflow.python.ops.distributions import normal as normal_lib
 from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging
@@ -42,13 +45,16 @@ stats = try_import("scipy.stats")
 
 class ProductDistributionTest(test.TestCase):
 
+  def setUp(self):
+    self._rng = np.random.RandomState(42)
+
   def testSampleAndLogProbUnivariate(self):
     loc = np.float32([-1., 1])
     scale = np.float32([0.1, 0.5])
     with self.test_session() as sess:
       ind = independent_lib.Independent(
           distribution=normal_lib.Normal(loc=loc, scale=scale),
-          reduce_batch_ndims=1)
+          reinterpreted_batch_ndims=1)
 
       x = ind.sample([4, 5])
       log_prob_x = ind.log_prob(x)
@@ -71,7 +77,7 @@ class ProductDistributionTest(test.TestCase):
           distribution=mvn_diag_lib.MultivariateNormalDiag(
               loc=loc,
               scale_identity_multiplier=scale),
-          reduce_batch_ndims=1)
+          reinterpreted_batch_ndims=1)
 
       x = ind.sample([4, 5])
       log_prob_x = ind.log_prob(x)
@@ -96,7 +102,7 @@ class ProductDistributionTest(test.TestCase):
           distribution=mvn_diag_lib.MultivariateNormalDiag(
               loc=loc,
               scale_identity_multiplier=scale),
-          reduce_batch_ndims=1)
+          reinterpreted_batch_ndims=1)
 
       x = ind.sample(int(n_samp), seed=42)
       sample_mean = math_ops.reduce_mean(x, axis=0)
@@ -120,6 +126,59 @@ class ProductDistributionTest(test.TestCase):
       self.assertAllClose(sample_entropy_, actual_entropy_, rtol=0.01, atol=0.)
       self.assertAllClose(loc, actual_mode_, rtol=1e-6, atol=0.)
 
+  def _testMnistLike(self, static_shape):
+    sample_shape = [4, 5]
+    batch_shape = [10]
+    image_shape = [28, 28, 1]
+    logits = 3 * self._rng.random_sample(
+        batch_shape + image_shape).astype(np.float32) - 1
+
+    def expected_log_prob(x, logits):
+      return (x * logits - np.log1p(np.exp(logits))).sum(-1).sum(-1).sum(-1)
+
+    with self.test_session() as sess:
+      logits_ph = array_ops.placeholder(
+          dtypes.float32, shape=logits.shape if static_shape else None)
+      ind = independent_lib.Independent(
+          distribution=bernoulli_lib.Bernoulli(logits=logits_ph))
+      x = ind.sample(sample_shape)
+      log_prob_x = ind.log_prob(x)
+      [
+          x_,
+          actual_log_prob_x,
+          ind_batch_shape,
+          ind_event_shape,
+          x_shape,
+          log_prob_x_shape,
+      ] = sess.run([
+          x,
+          log_prob_x,
+          ind.batch_shape_tensor(),
+          ind.event_shape_tensor(),
+          array_ops.shape(x),
+          array_ops.shape(log_prob_x),
+      ], feed_dict={logits_ph: logits})
+
+      if static_shape:
+        ind_batch_shape = ind.batch_shape
+        ind_event_shape = ind.event_shape
+        x_shape = x.shape
+        log_prob_x_shape = log_prob_x.shape
+
+      self.assertAllEqual(batch_shape, ind_batch_shape)
+      self.assertAllEqual(image_shape, ind_event_shape)
+      self.assertAllEqual(sample_shape + batch_shape + image_shape, x_shape)
+      self.assertAllEqual(sample_shape + batch_shape, log_prob_x_shape)
+      self.assertAllClose(expected_log_prob(x_, logits),
+                          actual_log_prob_x,
+                          rtol=1e-6, atol=0.)
+
+  def testMnistLikeStaticShape(self):
+    self._testMnistLike(static_shape=True)
+
+  def testMnistLikeDynamicShape(self):
+    self._testMnistLike(static_shape=False)
+
 
 if __name__ == "__main__":
   test.main()
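The new MNIST-like test above pins down the shape semantics of the default. A condensed sketch of the same behavior (a hypothetical standalone script, TF 1.x contrib API assumed):

```python
import numpy as np
import tensorflow as tf

ds = tf.contrib.distributions

# Bernoulli with batch_shape [10, 28, 28, 1], as in the test above.
logits = np.random.randn(10, 28, 28, 1).astype(np.float32)

# Default reinterpreted_batch_ndims: all batch dims except axis 0
# become event dims, analogous to tf.layers.flatten.
ind = ds.Independent(distribution=ds.Bernoulli(logits=logits))

print(ind.batch_shape)  # ==> (10,)
print(ind.event_shape)  # ==> (28, 28, 1)

x = ind.sample([4, 5])
print(ind.log_prob(x).shape)  # ==> (4, 5, 10): sample_shape + batch_shape
```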
@@ -45,24 +45,24 @@ class Independent(distribution_lib.Distribution):
   `p(x_1, ..., x_B) = p_1(x_1) * ... * p_B(x_B)` where `p_b(X_b)` is the
   probability of the `b`-th rv. More generally `B, E` can be arbitrary shapes.
 
-  Similarly, the `Independent` distribution specifies a distribution over
-  `[B, E]`-shaped events. It operates by reinterpreting the rightmost batch dims
-  as part of the event dimensions. The `reduce_batch_ndims` parameter controls
-  the number of batch dims which are absorbed as event dims;
-  `reduce_batch_ndims < len(batch_shape)`. For example, the `log_prob` function
-  entails a `reduce_sum` over the rightmost `reduce_batch_ndims` after calling
-  the base distribution's `log_prob`. In other words, since the batch
-  dimension(s) index independent distributions, the resultant multivariate will
-  have independent components.
+  Similarly, the `Independent` distribution specifies a distribution over `[B,
+  E]`-shaped events. It operates by reinterpreting the rightmost batch dims as
+  part of the event dimensions. The `reinterpreted_batch_ndims` parameter
+  controls the number of batch dims which are absorbed as event dims;
+  `reinterpreted_batch_ndims < len(batch_shape)`. For example, the `log_prob`
+  function entails a `reduce_sum` over the rightmost `reinterpreted_batch_ndims`
+  after calling the base distribution's `log_prob`. In other words, since the
+  batch dimension(s) index independent distributions, the resultant multivariate
+  will have independent components.
 
   #### Mathematical Details
 
   The probability function is,
 
   ```none
-  prob(x; reduce_batch_ndims) = tf.reduce_prod(
+  prob(x; reinterpreted_batch_ndims) = tf.reduce_prod(
       dist.prob(x),
-      axis=-1-range(reduce_batch_ndims))
+      axis=-1-range(reinterpreted_batch_ndims))
   ```
 
   #### Examples
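As a sanity check on the docstring formula above, a small NumPy/SciPy sketch (standalone, not the TF code; loc/scale values borrowed from the docstring example below):

```python
import numpy as np
from scipy import stats

# Two independent Normals: batch_shape [2], event_shape [].
loc = np.float32([-1., 1.])
scale = np.float32([0.1, 0.5])
x = np.float32([-1.2, 0.7])

# Base distribution: one prob per batch member.
base_prob = stats.norm(loc, scale).pdf(x)

# Independent with reinterpreted_batch_ndims=1: the product over the
# rightmost (here, only) batch dim yields a scalar prob for the
# [2]-shaped event.
ind_prob = np.prod(base_prob, axis=-1)
print(ind_prob)
```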
@@ -73,7 +73,7 @@ class Independent(distribution_lib.Distribution):
   # Make independent distribution from a 2-batch Normal.
   ind = ds.Independent(
       distribution=ds.Normal(loc=[-1., 1], scale=[0.1, 0.5]),
-      reduce_batch_ndims=1)
+      reinterpreted_batch_ndims=1)
 
   # All batch dims have been "absorbed" into event dims.
   ind.batch_shape  # ==> []
@@ -84,7 +84,7 @@ class Independent(distribution_lib.Distribution):
       distribution=ds.MultivariateNormalDiag(
           loc=[[-1., 1], [1, -1]],
           scale_identity_multiplier=[1., 0.5]),
-      reduce_batch_ndims=1)
+      reinterpreted_batch_ndims=1)
 
   # All batch dims have been "absorbed" into event dims.
   ind.batch_shape  # ==> []
@@ -94,14 +94,17 @@ class Independent(distribution_lib.Distribution):
   """
 
   def __init__(
-      self, distribution, reduce_batch_ndims=1, validate_args=False, name=None):
+      self, distribution, reinterpreted_batch_ndims=None,
+      validate_args=False, name=None):
     """Construct a `Independent` distribution.
 
     Args:
       distribution: The base distribution instance to transform. Typically an
        instance of `Distribution`.
-      reduce_batch_ndims: Scalar, integer number of rightmost batch dims which
-        will be regard as event dims.
+      reinterpreted_batch_ndims: Scalar, integer number of rightmost batch dims
+        which will be regarded as event dims. When `None` all but the first
+        batch axis (batch axis 0) will be transferred to event dimensions
+        (analogous to `tf.layers.flatten`).
       validate_args: Python `bool`. Whether to validate input with asserts.
        If `validate_args` is `False`, and the inputs are invalid,
        correct behavior is not guaranteed.
@@ -109,19 +112,25 @@ class Independent(distribution_lib.Distribution):
         Default value: `Independent + distribution.name`.
 
     Raises:
-      ValueError: if `reduce_batch_ndims` exceeds `distribution.batch_ndims`
+      ValueError: if `reinterpreted_batch_ndims` exceeds
+        `distribution.batch_ndims`
     """
     parameters = locals()
    name = name or "Independent" + distribution.name
    self._distribution = distribution
    with ops.name_scope(name):
-      reduce_batch_ndims = ops.convert_to_tensor(
-          reduce_batch_ndims, dtype=dtypes.int32, name="reduce_batch_ndims")
-      self._reduce_batch_ndims = reduce_batch_ndims
-      self._static_reduce_batch_ndims = tensor_util.constant_value(
-          reduce_batch_ndims)
-      if self._static_reduce_batch_ndims is not None:
-        self._reduce_batch_ndims = self._static_reduce_batch_ndims
+      if reinterpreted_batch_ndims is None:
+        reinterpreted_batch_ndims = self._get_default_reinterpreted_batch_ndims(
+            distribution)
+      reinterpreted_batch_ndims = ops.convert_to_tensor(
+          reinterpreted_batch_ndims,
+          dtype=dtypes.int32,
+          name="reinterpreted_batch_ndims")
+      self._reinterpreted_batch_ndims = reinterpreted_batch_ndims
+      self._static_reinterpreted_batch_ndims = tensor_util.constant_value(
+          reinterpreted_batch_ndims)
+      if self._static_reinterpreted_batch_ndims is not None:
+        self._reinterpreted_batch_ndims = self._static_reinterpreted_batch_ndims
       super(Independent, self).__init__(
           dtype=self._distribution.dtype,
           reparameterization_type=self._distribution.reparameterization_type,
@@ -129,19 +138,19 @@ class Independent(distribution_lib.Distribution):
           allow_nan_stats=self._distribution.allow_nan_stats,
           parameters=parameters,
           graph_parents=(
-              [reduce_batch_ndims] +
+              [reinterpreted_batch_ndims] +
               distribution._graph_parents),  # pylint: disable=protected-access
           name=name)
       self._runtime_assertions = self._make_runtime_assertions(
-          distribution, reduce_batch_ndims, validate_args)
+          distribution, reinterpreted_batch_ndims, validate_args)
 
   @property
   def distribution(self):
     return self._distribution
 
   @property
-  def reduce_batch_ndims(self):
-    return self._reduce_batch_ndims
+  def reinterpreted_batch_ndims(self):
+    return self._reinterpreted_batch_ndims
 
   def _batch_shape_tensor(self):
     with ops.control_dependencies(self._runtime_assertions):
@@ -149,13 +158,14 @@ class Independent(distribution_lib.Distribution):
       batch_ndims = (batch_shape.shape[0].value
                      if batch_shape.shape.with_rank_at_least(1)[0].value
                      else array_ops.shape(batch_shape)[0])
-      return batch_shape[:batch_ndims - self.reduce_batch_ndims]
+      return batch_shape[:batch_ndims - self.reinterpreted_batch_ndims]
 
   def _batch_shape(self):
     batch_shape = self.distribution.batch_shape
-    if self._static_reduce_batch_ndims is None or batch_shape.ndims is None:
+    if (self._static_reinterpreted_batch_ndims is None
+        or batch_shape.ndims is None):
       return tensor_shape.TensorShape(None)
-    d = batch_shape.ndims - self._static_reduce_batch_ndims
+    d = batch_shape.ndims - self._static_reinterpreted_batch_ndims
     return batch_shape[:d]
 
   def _event_shape_tensor(self):
@@ -165,15 +175,16 @@ class Independent(distribution_lib.Distribution):
                      if batch_shape.shape.with_rank_at_least(1)[0].value
                      else array_ops.shape(batch_shape)[0])
       return array_ops.concat([
-          batch_shape[batch_ndims - self.reduce_batch_ndims:],
+          batch_shape[batch_ndims - self.reinterpreted_batch_ndims:],
           self.distribution.event_shape_tensor(),
       ], axis=0)
 
   def _event_shape(self):
     batch_shape = self.distribution.batch_shape
-    if self._static_reduce_batch_ndims is None or batch_shape.ndims is None:
+    if (self._static_reinterpreted_batch_ndims is None
+        or batch_shape.ndims is None):
       return tensor_shape.TensorShape(None)
-    d = batch_shape.ndims - self._static_reduce_batch_ndims
+    d = batch_shape.ndims - self._static_reinterpreted_batch_ndims
     return batch_shape[d:].concatenate(self.distribution.event_shape)
 
   def _sample_n(self, n, seed):
@@ -205,15 +216,16 @@ class Independent(distribution_lib.Distribution):
       return self.distribution.mode()
 
   def _make_runtime_assertions(
-      self, distribution, reduce_batch_ndims, validate_args):
+      self, distribution, reinterpreted_batch_ndims, validate_args):
     assertions = []
-    static_reduce_batch_ndims = tensor_util.constant_value(reduce_batch_ndims)
+    static_reinterpreted_batch_ndims = tensor_util.constant_value(
+        reinterpreted_batch_ndims)
     batch_ndims = distribution.batch_shape.ndims
-    if batch_ndims is not None and static_reduce_batch_ndims is not None:
-      if static_reduce_batch_ndims > batch_ndims:
-        raise ValueError("reduce_batch_ndims({}) cannot exceed "
+    if batch_ndims is not None and static_reinterpreted_batch_ndims is not None:
+      if static_reinterpreted_batch_ndims > batch_ndims:
+        raise ValueError("reinterpreted_batch_ndims({}) cannot exceed "
                          "distribution.batch_ndims({})".format(
-                             static_reduce_batch_ndims, batch_ndims))
+                             static_reinterpreted_batch_ndims, batch_ndims))
     elif validate_args:
       batch_shape = distribution.batch_shape_tensor()
       batch_ndims = (
@@ -221,13 +233,24 @@ class Independent(distribution_lib.Distribution):
           if batch_shape.shape.with_rank_at_least(1)[0].value is not None
           else array_ops.shape(batch_shape)[0])
       assertions.append(check_ops.assert_less_equal(
-          reduce_batch_ndims, batch_ndims,
-          message="reduce_batch_ndims cannot exceed distribution.batch_ndims"))
+          reinterpreted_batch_ndims, batch_ndims,
+          message=("reinterpreted_batch_ndims cannot exceed "
+                   "distribution.batch_ndims")))
     return assertions
 
   def _reduce_sum(self, stat):
-    if self._static_reduce_batch_ndims is None:
-      range_ = array_ops.range(self._reduce_batch_ndims)
+    if self._static_reinterpreted_batch_ndims is None:
+      range_ = math_ops.range(self._reinterpreted_batch_ndims)
     else:
-      range_ = np.arange(self._static_reduce_batch_ndims)
+      range_ = np.arange(self._static_reinterpreted_batch_ndims)
     return math_ops.reduce_sum(stat, axis=-1-range_)
+
+  def _get_default_reinterpreted_batch_ndims(self, distribution):
+    """Computes the default value for reinterpreted_batch_ndim __init__ arg."""
+    ndims = distribution.batch_shape.ndims
+    if ndims is None:
+      which_maximum = math_ops.maximum
+      ndims = array_ops.shape(distribution.batch_shape_tensor())[0]
+    else:
+      which_maximum = np.maximum
+    return which_maximum(0, ndims - 1)
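To make the reduction and the new default concrete, here is a plain-NumPy sketch of the semantics of `_reduce_sum` and `_get_default_reinterpreted_batch_ndims` above (hypothetical stand-ins for statically known ranks, not the TF implementations):

```python
import numpy as np

def default_reinterpreted_batch_ndims(batch_ndims):
  # Mirrors _get_default_reinterpreted_batch_ndims for a statically known
  # rank: all batch dims except the leftmost (axis 0), clipped at zero.
  return max(0, batch_ndims - 1)

def reduce_sum_rightmost(stat, reinterpreted_batch_ndims):
  # Mirrors _reduce_sum: sum over axes -1, -2, ..., -reinterpreted_batch_ndims.
  axes = tuple(-1 - k for k in range(reinterpreted_batch_ndims))
  return np.sum(stat, axis=axes) if axes else stat

# A base log_prob over batch_shape [10, 28, 28, 1]:
base_log_prob = np.zeros([10, 28, 28, 1])
n = default_reinterpreted_batch_ndims(base_log_prob.ndim)  # ==> 3
print(reduce_sum_rightmost(base_log_prob, n).shape)        # ==> (10,)
```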