The Gamma distribution and the distributions derived from it (Beta, Dirichlet, Student's t, inverse Gamma) are now fully reparameterized.

For every distribution, the changes are:
* Set `reparameterization_type` to `FULLY_REPARAMETERIZED`.
* Add a note about reparameterization and a gradient example to the docstring.
* Add a test that ensures that gradients of samples with respect to the distribution parameters exist.

Additional changes:
* Fix the docstring and test in TFP that assume Gamma is not reparameterized; we simply replace Gamma with Bernoulli :)
* Fix module paths in docstrings.
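For reference, a minimal sketch of the usage pattern this enables, mirroring the docstring examples added below and assuming the TF 1.x `tf.distributions` API in graph mode:

```python
import tensorflow as tf

concentration = tf.constant(2.0)
rate = tf.constant(3.0)
dist = tf.distributions.Gamma(concentration=concentration, rate=rate)
samples = dist.sample(1000)                 # Shape [1000]
loss = tf.reduce_mean(tf.square(samples))   # Arbitrary loss function
# Because sampling is now FULLY_REPARAMETERIZED, these are defined tensors
# rather than None: unbiased pathwise gradients of the loss.
grads = tf.gradients(loss, [concentration, rate])
```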

PiperOrigin-RevId: 201691205
Authored by A. Unique TensorFlower on 2018-06-22 08:50:45 -07:00; committed by TensorFlower Gardener
parent fcb519a4a3
commit b894f6844a
10 changed files with 135 additions and 55 deletions


@@ -29,7 +29,6 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution as distribution_lib
from tensorflow.python.ops.distributions import gamma as gamma_lib
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import normal as normal_lib
from tensorflow.python.platform import test
@@ -256,50 +255,6 @@ class ExpectationTest(test.TestCase):
gradq_approx_kl_normal_normal_,
rtol=0.01, atol=0.)
def test_docstring_example_gamma(self):
with self.test_session() as sess:
num_draws = int(1e5)
concentration_p = constant_op.constant(1.)
concentration_q = constant_op.constant(2.)
p = gamma_lib.Gamma(concentration=concentration_p, rate=1.)
q = gamma_lib.Gamma(concentration=concentration_q, rate=3.)
approx_kl_gamma_gamma = monte_carlo_lib.expectation(
f=lambda x: p.log_prob(x) - q.log_prob(x),
samples=p.sample(num_draws, seed=42),
log_prob=p.log_prob,
use_reparametrization=(p.reparameterization_type
== distribution_lib.FULLY_REPARAMETERIZED))
exact_kl_gamma_gamma = kullback_leibler.kl_divergence(p, q)
[exact_kl_gamma_gamma_, approx_kl_gamma_gamma_] = sess.run([
exact_kl_gamma_gamma, approx_kl_gamma_gamma])
self.assertEqual(
False,
p.reparameterization_type == distribution_lib.FULLY_REPARAMETERIZED)
self.assertAllClose(exact_kl_gamma_gamma_, approx_kl_gamma_gamma_,
rtol=0.01, atol=0.)
# Compare gradients. (Not present in `docstring`.)
gradp = lambda fp: gradients_impl.gradients(fp, concentration_p)[0]
gradq = lambda fq: gradients_impl.gradients(fq, concentration_q)[0]
[
gradp_exact_kl_gamma_gamma_,
gradq_exact_kl_gamma_gamma_,
gradp_approx_kl_gamma_gamma_,
gradq_approx_kl_gamma_gamma_,
] = sess.run([
gradp(exact_kl_gamma_gamma),
gradq(exact_kl_gamma_gamma),
gradp(approx_kl_gamma_gamma),
gradq(approx_kl_gamma_gamma),
])
# Notice that variance (i.e., `rtol`) is higher when using score-trick.
self.assertAllClose(gradp_exact_kl_gamma_gamma_,
gradp_approx_kl_gamma_gamma_,
rtol=0.05, atol=0.)
self.assertAllClose(gradq_exact_kl_gamma_gamma_,
gradq_approx_kl_gamma_gamma_,
rtol=0.03, atol=0.)
if __name__ == '__main__':
test.main()


@@ -21,6 +21,7 @@ import importlib
import numpy as np
from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_shape
@@ -282,6 +283,18 @@ class BetaTest(test.TestCase):
self.assertAllClose(
np.cov(sample_values, rowvar=0), stats.beta.var(a, b), atol=1e-1)
def testBetaFullyReparameterized(self):
a = constant_op.constant(1.0)
b = constant_op.constant(2.0)
with backprop.GradientTape() as tape:
tape.watch(a)
tape.watch(b)
beta = beta_lib.Beta(a, b)
samples = beta.sample(100)
grad_a, grad_b = tape.gradient(samples, [a, b])
self.assertIsNotNone(grad_a)
self.assertIsNotNone(grad_b)
# Test that sampling with the same seed twice gives the same results.
def testBetaSampleMultipleTimes(self):
with self.test_session():


@@ -20,6 +20,7 @@ import importlib
import numpy as np
from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
@@ -264,6 +265,15 @@ class DirichletTest(test.TestCase):
a=1., b=2.).cdf)[0],
0.01)
def testDirichletFullyReparameterized(self):
alpha = constant_op.constant([1.0, 2.0, 3.0])
with backprop.GradientTape() as tape:
tape.watch(alpha)
dirichlet = dirichlet_lib.Dirichlet(alpha)
samples = dirichlet.sample(100)
grad_alpha = tape.gradient(samples, alpha)
self.assertIsNotNone(grad_alpha)
def testDirichletDirichletKL(self):
conc1 = np.array([[1., 2., 3., 1.5, 2.5, 3.5],
[1.5, 2.5, 3.5, 4.5, 5.5, 6.5]])


@@ -21,6 +21,7 @@ import importlib
import numpy as np
from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
@@ -265,6 +266,18 @@ class GammaTest(test.TestCase):
stats.gamma.var(alpha_v, scale=1 / beta_v),
atol=.15)
def testGammaFullyReparameterized(self):
alpha = constant_op.constant(4.0)
beta = constant_op.constant(3.0)
with backprop.GradientTape() as tape:
tape.watch(alpha)
tape.watch(beta)
gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
samples = gamma.sample(100)
grad_alpha, grad_beta = tape.gradient(samples, [alpha, beta])
self.assertIsNotNone(grad_alpha)
self.assertIsNotNone(grad_beta)
def testGammaSampleMultiDimensional(self):
with self.test_session():
alpha_v = np.array([np.arange(1, 101, dtype=np.float32)]) # 1 x 100


@@ -23,6 +23,7 @@ import math
import numpy as np
from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
@@ -454,6 +455,21 @@ class StudentTTest(test.TestCase):
return
self.assertNear(stats.t.pdf(np.pi, 3., loc=np.pi), mean_pdf_val, err=1e-6)
def testFullyReparameterized(self):
df = constant_op.constant(2.0)
mu = constant_op.constant(1.0)
sigma = constant_op.constant(3.0)
with backprop.GradientTape() as tape:
tape.watch(df)
tape.watch(mu)
tape.watch(sigma)
student = student_t.StudentT(df=df, loc=mu, scale=sigma)
samples = student.sample(100)
grad_df, grad_mu, grad_sigma = tape.gradient(samples, [df, mu, sigma])
self.assertIsNotNone(grad_df)
self.assertIsNotNone(grad_mu)
self.assertIsNotNone(grad_sigma)
def testPdfOfSampleMultiDims(self):
student = student_t.StudentT(df=[7., 11.], loc=[[5.], [6.]], scale=3.)
self.assertAllEqual([], student.event_shape)


@@ -89,13 +89,19 @@ class Beta(distribution.Distribution):
Make sure to round the samples to `np.finfo(dtype).tiny` before computing the
density.
Samples of this distribution are reparameterized (pathwise differentiable).
The derivatives are computed using the approach described in the paper
[Michael Figurnov, Shakir Mohamed, Andriy Mnih.
Implicit Reparameterization Gradients, 2018](https://arxiv.org/abs/1805.08498)
#### Examples
```python
# Create a batch of three Beta distributions.
alpha = [1, 2, 3]
beta = [1, 2, 3]
dist = Beta(alpha, beta)
dist = tf.distributions.Beta(alpha, beta)
dist.sample([4, 5]) # Shape [4, 5, 3]
@@ -111,7 +117,7 @@ class Beta(distribution.Distribution):
# Create batch_shape=[2, 3] via parameter broadcast:
alpha = [[1.], [2]] # Shape [2, 1]
beta = [3., 4, 5] # Shape [3]
dist = Beta(alpha, beta)
dist = tf.distributions.Beta(alpha, beta)
# alpha broadcast as: [[1., 1, 1,],
# [2, 2, 2]]
@@ -127,6 +133,18 @@ class Beta(distribution.Distribution):
dist.prob(x) # Shape [2, 3]
```
Compute the gradients of samples w.r.t. the parameters:
```python
alpha = tf.constant(1.0)
beta = tf.constant(2.0)
dist = tf.distributions.Beta(alpha, beta)
samples = dist.sample(5) # Shape [5]
loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function
# Unbiased stochastic gradients of the loss function
grads = tf.gradients(loss, [alpha, beta])
```
"""
def __init__(self,
@@ -170,7 +188,7 @@ class Beta(distribution.Distribution):
dtype=self._total_concentration.dtype,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
reparameterization_type=distribution.NOT_REPARAMETERIZED,
reparameterization_type=distribution.FULLY_REPARAMETERIZED,
parameters=parameters,
graph_parents=[self._concentration1,
self._concentration0,


@@ -95,13 +95,19 @@ class Dirichlet(distribution.Distribution):
Make sure to round the samples to `np.finfo(dtype).tiny` before computing the
density.
Samples of this distribution are reparameterized (pathwise differentiable).
The derivatives are computed using the approach described in the paper
[Michael Figurnov, Shakir Mohamed, Andriy Mnih.
Implicit Reparameterization Gradients, 2018](https://arxiv.org/abs/1805.08498)
#### Examples
```python
# Create a single trivariate Dirichlet, with the 3rd class being three times
# more frequent than the first. I.e., batch_shape=[], event_shape=[3].
alpha = [1., 2, 3]
dist = Dirichlet(alpha)
dist = tf.distributions.Dirichlet(alpha)
dist.sample([4, 5]) # shape: [4, 5, 3]
@@ -123,7 +129,7 @@ class Dirichlet(distribution.Distribution):
# Create batch_shape=[2], event_shape=[3]:
alpha = [[1., 2, 3],
[4, 5, 6]] # shape: [2, 3]
dist = Dirichlet(alpha)
dist = tf.distributions.Dirichlet(alpha)
dist.sample([4, 5]) # shape: [4, 5, 2, 3]
@@ -134,6 +140,17 @@ class Dirichlet(distribution.Distribution):
dist.prob(x) # shape: [2]
```
Compute the gradients of samples w.r.t. the parameters:
```python
alpha = tf.constant([1.0, 2.0, 3.0])
dist = tf.distributions.Dirichlet(alpha)
samples = dist.sample(5) # Shape [5, 3]
loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function
# Unbiased stochastic gradients of the loss function
grads = tf.gradients(loss, alpha)
```
"""
def __init__(self,
@@ -170,7 +187,7 @@ class Dirichlet(distribution.Distribution):
dtype=self._concentration.dtype,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
reparameterization_type=distribution.NOT_REPARAMETERIZED,
reparameterization_type=distribution.FULLY_REPARAMETERIZED,
parameters=parameters,
graph_parents=[self._concentration,
self._total_concentration],


@@ -91,11 +91,29 @@ class Gamma(distribution.Distribution):
This should only be noticeable when the `concentration` is very small, or the
`rate` is very large. See note in `tf.random_gamma` docstring.
Samples of this distribution are reparameterized (pathwise differentiable).
The derivatives are computed using the approach described in the paper
[Michael Figurnov, Shakir Mohamed, Andriy Mnih.
Implicit Reparameterization Gradients, 2018](https://arxiv.org/abs/1805.08498)
#### Examples
```python
dist = Gamma(concentration=3.0, rate=2.0)
dist2 = Gamma(concentration=[3.0, 4.0], rate=[2.0, 3.0])
dist = tf.distributions.Gamma(concentration=3.0, rate=2.0)
dist2 = tf.distributions.Gamma(concentration=[3.0, 4.0], rate=[2.0, 3.0])
```
Compute the gradients of samples w.r.t. the parameters:
```python
concentration = tf.constant(3.0)
rate = tf.constant(2.0)
dist = tf.distributions.Gamma(concentration, rate)
samples = dist.sample(5) # Shape [5]
loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function
# Unbiased stochastic gradients of the loss function
grads = tf.gradients(loss, [concentration, rate])
```
"""
@@ -144,7 +162,7 @@ class Gamma(distribution.Distribution):
dtype=self._concentration.dtype,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
reparameterization_type=distribution.NOT_REPARAMETERIZED,
reparameterization_type=distribution.FULLY_REPARAMETERIZED,
parameters=parameters,
graph_parents=[self._concentration,
self._rate],


@@ -80,6 +80,12 @@ class StudentT(distribution.Distribution):
variance. However it is not actually the std. deviation; the Student's
t-distribution std. dev. is `scale sqrt(df / (df - 2))` when `df > 2`.
Samples of this distribution are reparameterized (pathwise differentiable).
The derivatives are computed using the approach described in the paper
[Michael Figurnov, Shakir Mohamed, Andriy Mnih.
Implicit Reparameterization Gradients, 2018](https://arxiv.org/abs/1805.08498)
#### Examples
Examples of initialization of one or a batch of distributions.
@@ -118,6 +124,19 @@ class StudentT(distribution.Distribution):
dist.prob(3.0)
```
Compute the gradients of samples w.r.t. the parameters:
```python
df = tf.constant(2.0)
loc = tf.constant(2.0)
scale = tf.constant(11.0)
dist = tf.distributions.StudentT(df=df, loc=loc, scale=scale)
samples = dist.sample(5) # Shape [5]
loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function
# Unbiased stochastic gradients of the loss function
grads = tf.gradients(loss, [df, loc, scale])
```
"""
# pylint: enable=line-too-long
@@ -168,7 +187,7 @@ class StudentT(distribution.Distribution):
(self._df, self._loc, self._scale))
super(StudentT, self).__init__(
dtype=self._scale.dtype,
reparameterization_type=distribution.NOT_REPARAMETERIZED,
reparameterization_type=distribution.FULLY_REPARAMETERIZED,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
parameters=parameters,


@@ -49,6 +49,7 @@ from tensorflow.python.ops import logging_ops  # pylint: disable=unused-import
from tensorflow.python.ops import manip_grad # pylint: disable=unused-import
from tensorflow.python.ops import math_grad # pylint: disable=unused-import
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_grad # pylint: disable=unused-import
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import spectral_grad # pylint: disable=unused-import
from tensorflow.python.ops import tensor_array_ops
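The `random_grad` import added in the hunk above is there for its side effect of registering the gradient for the underlying gamma sampling op, so pathwise gradients with respect to `alpha` should also flow when `tf.random_gamma` is used directly. A minimal sketch, assuming TF 1.x graph mode:

```python
import tensorflow as tf

alpha = tf.constant([1.0, 2.0])
samples = tf.random_gamma(shape=[100], alpha=alpha)  # Shape [100, 2]
# With the RandomGamma gradient registered, this is a defined tensor
# rather than None.
grad_alpha = tf.gradients(tf.reduce_sum(samples), alpha)[0]
```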