From 80374a7b47dddb591f711b6240ea0896fbe90d29 Mon Sep 17 00:00:00 2001 From: "Joshua V. Dillon" Date: Fri, 27 Oct 2017 15:59:46 -0700 Subject: [PATCH] Breaking change: Rename `tf.contrib.distributions.Independent` parameter from `reduce_batch_ndims` to `reinterpreted_batch_ndims`. Also change the default: `reinterpreted_batch_ndims` now defaults to the semantics of `tf.layers.flatten`, i.e., all batch dimensions except the first (batch axis 0) are interpreted as being part of the event. PiperOrigin-RevId: 173729585 --- .../python/kernel_tests/independent_test.py | 65 +++++++++- .../distributions/python/ops/independent.py | 113 +++++++++++------- 2 files changed, 130 insertions(+), 48 deletions(-) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py b/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py index dcc66e89720..8e23a3ab8fd 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py @@ -23,7 +23,10 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import independent as independent_lib from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bernoulli as bernoulli_lib from tensorflow.python.ops.distributions import normal as normal_lib from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging @@ -42,13 +45,16 @@ stats = try_import("scipy.stats") class ProductDistributionTest(test.TestCase): + def setUp(self): + self._rng = np.random.RandomState(42) + def testSampleAndLogProbUnivariate(self): loc = np.float32([-1., 1]) scale = np.float32([0.1, 0.5]) with self.test_session() as sess: ind = independent_lib.Independent( distribution=normal_lib.Normal(loc=loc, scale=scale), - reduce_batch_ndims=1) + reinterpreted_batch_ndims=1) x = ind.sample([4, 5]) log_prob_x = ind.log_prob(x) @@ -71,7 +77,7 @@ class ProductDistributionTest(test.TestCase): distribution=mvn_diag_lib.MultivariateNormalDiag( loc=loc, scale_identity_multiplier=scale), - reduce_batch_ndims=1) + reinterpreted_batch_ndims=1) x = ind.sample([4, 5]) log_prob_x = ind.log_prob(x) @@ -96,7 +102,7 @@ class ProductDistributionTest(test.TestCase): distribution=mvn_diag_lib.MultivariateNormalDiag( loc=loc, scale_identity_multiplier=scale), - reduce_batch_ndims=1) + reinterpreted_batch_ndims=1) x = ind.sample(int(n_samp), seed=42) sample_mean = math_ops.reduce_mean(x, axis=0) @@ -120,6 +126,59 @@ class ProductDistributionTest(test.TestCase): self.assertAllClose(sample_entropy_, actual_entropy_, rtol=0.01, atol=0.) self.assertAllClose(loc, actual_mode_, rtol=1e-6, atol=0.)
+ def _testMnistLike(self, static_shape): + sample_shape = [4, 5] + batch_shape = [10] + image_shape = [28, 28, 1] + logits = 3 * self._rng.random_sample( + batch_shape + image_shape).astype(np.float32) - 1 + + def expected_log_prob(x, logits): + return (x * logits - np.log1p(np.exp(logits))).sum(-1).sum(-1).sum(-1) + + with self.test_session() as sess: + logits_ph = array_ops.placeholder( + dtypes.float32, shape=logits.shape if static_shape else None) + ind = independent_lib.Independent( + distribution=bernoulli_lib.Bernoulli(logits=logits_ph)) + x = ind.sample(sample_shape) + log_prob_x = ind.log_prob(x) + [ + x_, + actual_log_prob_x, + ind_batch_shape, + ind_event_shape, + x_shape, + log_prob_x_shape, + ] = sess.run([ + x, + log_prob_x, + ind.batch_shape_tensor(), + ind.event_shape_tensor(), + array_ops.shape(x), + array_ops.shape(log_prob_x), + ], feed_dict={logits_ph: logits}) + + if static_shape: + ind_batch_shape = ind.batch_shape + ind_event_shape = ind.event_shape + x_shape = x.shape + log_prob_x_shape = log_prob_x.shape + + self.assertAllEqual(batch_shape, ind_batch_shape) + self.assertAllEqual(image_shape, ind_event_shape) + self.assertAllEqual(sample_shape + batch_shape + image_shape, x_shape) + self.assertAllEqual(sample_shape + batch_shape, log_prob_x_shape) + self.assertAllClose(expected_log_prob(x_, logits), + actual_log_prob_x, + rtol=1e-6, atol=0.) + + def testMnistLikeStaticShape(self): + self._testMnistLike(static_shape=True) + + def testMnistLikeDynamicShape(self): + self._testMnistLike(static_shape=False) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/ops/independent.py b/tensorflow/contrib/distributions/python/ops/independent.py index 393c0082424..6a74ca9a0ae 100644 --- a/tensorflow/contrib/distributions/python/ops/independent.py +++ b/tensorflow/contrib/distributions/python/ops/independent.py @@ -45,24 +45,24 @@ class Independent(distribution_lib.Distribution): `p(x_1, ..., x_B) = p_1(x_1) * ... * p_B(x_B)` where `p_b(X_b)` is the probability of the `b`-th rv. More generally `B, E` can be arbitrary shapes. - Similarly, the `Independent` distribution specifies a distribution over - `[B, E]`-shaped events. It operates by reinterpreting the rightmost batch dims - as part of the event dimensions. The `reduce_batch_ndims` parameter controls - the number of batch dims which are absorbed as event dims; - `reduce_batch_ndims < len(batch_shape)`. For example, the `log_prob` function - entails a `reduce_sum` over the rightmost `reduce_batch_ndims` after calling - the base distribution's `log_prob`. In other words, since the batch - dimension(s) index independent distributions, the resultant multivariate will - have independent components. + Similarly, the `Independent` distribution specifies a distribution over `[B, + E]`-shaped events. It operates by reinterpreting the rightmost batch dims as + part of the event dimensions. The `reinterpreted_batch_ndims` parameter + controls the number of batch dims which are absorbed as event dims; + `reinterpreted_batch_ndims < len(batch_shape)`. For example, the `log_prob` + function entails a `reduce_sum` over the rightmost `reinterpreted_batch_ndims` + after calling the base distribution's `log_prob`. In other words, since the + batch dimension(s) index independent distributions, the resultant multivariate + will have independent components. 
#### Mathematical Details The probability function is, ```none - prob(x; reduce_batch_ndims) = tf.reduce_prod( + prob(x; reinterpreted_batch_ndims) = tf.reduce_prod( dist.prob(x), - axis=-1-range(reduce_batch_ndims)) + axis=-1-range(reinterpreted_batch_ndims)) ``` #### Examples @@ -73,7 +73,7 @@ class Independent(distribution_lib.Distribution): # Make independent distribution from a 2-batch Normal. ind = ds.Independent( distribution=ds.Normal(loc=[-1., 1], scale=[0.1, 0.5]), - reduce_batch_ndims=1) + reinterpreted_batch_ndims=1) # All batch dims have been "absorbed" into event dims. ind.batch_shape # ==> [] @@ -84,7 +84,7 @@ class Independent(distribution_lib.Distribution): distribution=ds.MultivariateNormalDiag( loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5]), - reduce_batch_ndims=1) + reinterpreted_batch_ndims=1) # All batch dims have been "absorbed" into event dims. ind.batch_shape # ==> [] @@ -94,14 +94,17 @@ class Independent(distribution_lib.Distribution): """ def __init__( - self, distribution, reduce_batch_ndims=1, validate_args=False, name=None): + self, distribution, reinterpreted_batch_ndims=None, + validate_args=False, name=None): """Construct an `Independent` distribution. Args: distribution: The base distribution instance to transform. Typically an instance of `Distribution`. - reduce_batch_ndims: Scalar, integer number of rightmost batch dims which - will be regard as event dims. + reinterpreted_batch_ndims: Scalar, integer number of rightmost batch dims + which will be regarded as event dims. When `None`, all but the first + batch axis (batch axis 0) will be transferred to event dimensions + (analogous to `tf.layers.flatten`). validate_args: Python `bool`. Whether to validate input with asserts. If `validate_args` is `False`, and the inputs are invalid, correct behavior is not guaranteed. @@ -109,19 +112,25 @@ class Independent(distribution_lib.Distribution): Default value: `Independent + distribution.name`. 
Raises: - ValueError: if `reduce_batch_ndims` exceeds `distribution.batch_ndims` + ValueError: if `reinterpreted_batch_ndims` exceeds + `distribution.batch_ndims` """ parameters = locals() name = name or "Independent" + distribution.name self._distribution = distribution with ops.name_scope(name): - reduce_batch_ndims = ops.convert_to_tensor( - reduce_batch_ndims, dtype=dtypes.int32, name="reduce_batch_ndims") - self._reduce_batch_ndims = reduce_batch_ndims - self._static_reduce_batch_ndims = tensor_util.constant_value( - reduce_batch_ndims) - if self._static_reduce_batch_ndims is not None: - self._reduce_batch_ndims = self._static_reduce_batch_ndims + if reinterpreted_batch_ndims is None: + reinterpreted_batch_ndims = self._get_default_reinterpreted_batch_ndims( + distribution) + reinterpreted_batch_ndims = ops.convert_to_tensor( + reinterpreted_batch_ndims, + dtype=dtypes.int32, + name="reinterpreted_batch_ndims") + self._reinterpreted_batch_ndims = reinterpreted_batch_ndims + self._static_reinterpreted_batch_ndims = tensor_util.constant_value( + reinterpreted_batch_ndims) + if self._static_reinterpreted_batch_ndims is not None: + self._reinterpreted_batch_ndims = self._static_reinterpreted_batch_ndims super(Independent, self).__init__( dtype=self._distribution.dtype, reparameterization_type=self._distribution.reparameterization_type, @@ -129,19 +138,19 @@ class Independent(distribution_lib.Distribution): allow_nan_stats=self._distribution.allow_nan_stats, parameters=parameters, graph_parents=( - [reduce_batch_ndims] + + [reinterpreted_batch_ndims] + distribution._graph_parents), # pylint: disable=protected-access name=name) self._runtime_assertions = self._make_runtime_assertions( - distribution, reduce_batch_ndims, validate_args) + distribution, reinterpreted_batch_ndims, validate_args) @property def distribution(self): return self._distribution @property - def reduce_batch_ndims(self): - return self._reduce_batch_ndims + def reinterpreted_batch_ndims(self): + return self._reinterpreted_batch_ndims def _batch_shape_tensor(self): with ops.control_dependencies(self._runtime_assertions): @@ -149,13 +158,14 @@ class Independent(distribution_lib.Distribution): batch_ndims = (batch_shape.shape[0].value if batch_shape.shape.with_rank_at_least(1)[0].value else array_ops.shape(batch_shape)[0]) - return batch_shape[:batch_ndims - self.reduce_batch_ndims] + return batch_shape[:batch_ndims - self.reinterpreted_batch_ndims] def _batch_shape(self): batch_shape = self.distribution.batch_shape - if self._static_reduce_batch_ndims is None or batch_shape.ndims is None: + if (self._static_reinterpreted_batch_ndims is None + or batch_shape.ndims is None): return tensor_shape.TensorShape(None) - d = batch_shape.ndims - self._static_reduce_batch_ndims + d = batch_shape.ndims - self._static_reinterpreted_batch_ndims return batch_shape[:d] def _event_shape_tensor(self): @@ -165,15 +175,16 @@ class Independent(distribution_lib.Distribution): if batch_shape.shape.with_rank_at_least(1)[0].value else array_ops.shape(batch_shape)[0]) return array_ops.concat([ - batch_shape[batch_ndims - self.reduce_batch_ndims:], + batch_shape[batch_ndims - self.reinterpreted_batch_ndims:], self.distribution.event_shape_tensor(), ], axis=0) def _event_shape(self): batch_shape = self.distribution.batch_shape - if self._static_reduce_batch_ndims is None or batch_shape.ndims is None: + if (self._static_reinterpreted_batch_ndims is None + or batch_shape.ndims is None): return tensor_shape.TensorShape(None) - d = batch_shape.ndims - 
self._static_reduce_batch_ndims + d = batch_shape.ndims - self._static_reinterpreted_batch_ndims return batch_shape[d:].concatenate(self.distribution.event_shape) def _sample_n(self, n, seed): @@ -205,15 +216,16 @@ class Independent(distribution_lib.Distribution): return self.distribution.mode() def _make_runtime_assertions( - self, distribution, reduce_batch_ndims, validate_args): + self, distribution, reinterpreted_batch_ndims, validate_args): assertions = [] - static_reduce_batch_ndims = tensor_util.constant_value(reduce_batch_ndims) + static_reinterpreted_batch_ndims = tensor_util.constant_value( + reinterpreted_batch_ndims) batch_ndims = distribution.batch_shape.ndims - if batch_ndims is not None and static_reduce_batch_ndims is not None: - if static_reduce_batch_ndims > batch_ndims: - raise ValueError("reduce_batch_ndims({}) cannot exceed " + if batch_ndims is not None and static_reinterpreted_batch_ndims is not None: + if static_reinterpreted_batch_ndims > batch_ndims: + raise ValueError("reinterpreted_batch_ndims({}) cannot exceed " "distribution.batch_ndims({})".format( - static_reduce_batch_ndims, batch_ndims)) + static_reinterpreted_batch_ndims, batch_ndims)) elif validate_args: batch_shape = distribution.batch_shape_tensor() batch_ndims = ( batch_shape.shape[0].value if batch_shape.shape.with_rank_at_least(1)[0].value is not None else array_ops.shape(batch_shape)[0]) assertions.append(check_ops.assert_less_equal( - reduce_batch_ndims, batch_ndims, - message="reduce_batch_ndims cannot exceed distribution.batch_ndims")) + reinterpreted_batch_ndims, batch_ndims, + message=("reinterpreted_batch_ndims cannot exceed " + "distribution.batch_ndims"))) return assertions def _reduce_sum(self, stat): - if self._static_reduce_batch_ndims is None: - range_ = array_ops.range(self._reduce_batch_ndims) + if self._static_reinterpreted_batch_ndims is None: + range_ = math_ops.range(self._reinterpreted_batch_ndims) else: - range_ = np.arange(self._static_reduce_batch_ndims) + range_ = np.arange(self._static_reinterpreted_batch_ndims) return math_ops.reduce_sum(stat, axis=-1-range_) + + def _get_default_reinterpreted_batch_ndims(self, distribution): + """Computes the default value for reinterpreted_batch_ndims __init__ arg.""" + ndims = distribution.batch_shape.ndims + if ndims is None: + which_maximum = math_ops.maximum + ndims = array_ops.shape(distribution.batch_shape_tensor())[0] + else: + which_maximum = np.maximum + return which_maximum(0, ndims - 1)
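For reference, a minimal usage sketch of the renamed parameter and the new default, assuming the TF 1.x `tf.contrib.distributions` endpoint used in the docstring examples above (the `[10, 28, 28]` shape is illustrative, echoing the MNIST-like test; all shapes below are static, so no session is needed):

```python
import tensorflow as tf

ds = tf.contrib.distributions

# Explicit form: absorb the rightmost batch dim into the event
# (mirrors the first docstring example above).
ind = ds.Independent(
    distribution=ds.Normal(loc=[-1., 1], scale=[0.1, 0.5]),
    reinterpreted_batch_ndims=1)
ind.batch_shape  # ==> []
ind.event_shape  # ==> [2]

# New default: reinterpreted_batch_ndims=None keeps only batch axis 0 as a
# batch dim and absorbs the remaining batch dims into the event, analogous
# to tf.layers.flatten.
ind_default = ds.Independent(
    distribution=ds.Normal(loc=tf.zeros([10, 28, 28]), scale=1.))
ind_default.batch_shape  # ==> [10]
ind_default.event_shape  # ==> [28, 28]
```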