Add batchwise KL to tf.contrib.distributions.

* Includes the registration decorator RegisterKL.
* Includes distributions.kl() for computing the divergence.
* Includes an implementation for KL(Normal || Normal).
* distributions.kl() accepts the argument allow_nan.  If False (the default),
  the output of the KL is checked for NaNs, and a runtime exception is raised
  if any are found.  If True, no runtime checks are performed.
Change: 125008908
Eugene Brevdo 2016-06-15 15:57:22 -08:00 committed by TensorFlower Gardener
parent a496df8788
commit 60b0f7014f
6 changed files with 278 additions and 0 deletions
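
At a glance, the new API is used like this — a minimal sketch against the interfaces added in this commit (the MyDist class and _kl_my_dist function are illustrative, not part of the change):

  import tensorflow as tf

  distributions = tf.contrib.distributions

  # Batchwise KL between two distributions of registered types.
  n_a = distributions.Normal(mu=[0.0, 1.0], sigma=[1.0, 2.0])
  n_b = distributions.Normal(mu=[0.5, 1.5], sigma=[1.0, 1.0])
  kl = distributions.kl(n_a, n_b)  # shape [2]; NaN-checked by default

  # Registering a KL implementation for a new pair of classes.
  class MyDist(distributions.Normal):
    pass

  @distributions.RegisterKL(MyDist, MyDist)
  def _kl_my_dist(dist_a, dist_b, name=None):
    return tf.zeros_like(dist_a.mu)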

tensorflow/contrib/distributions/BUILD

@@ -65,6 +65,17 @@ cuda_py_tests(
    ],
)

cuda_py_tests(
    name = "kullback_leibler_test",
    size = "small",
    srcs = ["python/kernel_tests/kullback_leibler_test.py"],
    additional_deps = [
        ":distributions_py",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:platform_test",
    ],
)

cuda_py_tests(
    name = "student_t_test",
    size = "small",

tensorflow/contrib/distributions/__init__.py

@@ -49,6 +49,12 @@ representing the posterior or posterior predictive.

@@normal_conjugates_known_sigma_posterior
@@normal_congugates_known_sigma_predictive

## Kullback Leibler Divergence

@@kl
@@RegisterKL
"""
from __future__ import absolute_import
from __future__ import division

@@ -62,6 +68,7 @@ from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import *
from tensorflow.contrib.distributions.python.ops.distribution import *
from tensorflow.contrib.distributions.python.ops.exponential import *
from tensorflow.contrib.distributions.python.ops.gamma import *
from tensorflow.contrib.distributions.python.ops.kullback_leibler import *
from tensorflow.contrib.distributions.python.ops.mvn import *
from tensorflow.contrib.distributions.python.ops.normal import *
from tensorflow.contrib.distributions.python.ops.normal_conjugate_posteriors import *

tensorflow/contrib/distributions/python/kernel_tests/kullback_leibler_test.py

@@ -0,0 +1,83 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for distributions KL mechanism."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
class KLTest(tf.test.TestCase):
def testRegistration(self):
class MyDist(tf.contrib.distributions.Normal):
pass
# Register KL to a lambda that spits out the name parameter
@tf.contrib.distributions.RegisterKL(MyDist, MyDist)
def _kl(unused_a, unused_b, name=None): # pylint: disable=unused-variable
return name
a = MyDist(mu=0.0, sigma=1.0)
# Run kl() with allow_nan=True because strings can't go through is_nan.
self.assertEqual(
"OK", tf.contrib.distributions.kl(a, a, allow_nan=True, name="OK"))

  def testDomainErrorExceptions(self):

    class MyDistException(tf.contrib.distributions.Normal):
      pass

    # Register KL to a lambda that spits out the name parameter
    @tf.contrib.distributions.RegisterKL(MyDistException, MyDistException)
    # pylint: disable=unused-variable
    def _kl(unused_a, unused_b, name=None):  # pylint: disable=unused-argument
      return tf.identity([float("nan")])
    # pylint: enable=unused-variable

    with self.test_session():
      a = MyDistException(mu=0.0, sigma=1.0)
      kl = tf.contrib.distributions.kl(a, a)
      with self.assertRaisesOpError(
          "KL calculation between .* and .* returned NaN values"):
        kl.eval()
      kl_ok = tf.contrib.distributions.kl(a, a, allow_nan=True)
      self.assertAllEqual([float("nan")], kl_ok.eval())

  def testRegistrationFailures(self):
    with self.assertRaisesRegexp(TypeError, "is not a subclass of"):
      tf.contrib.distributions.RegisterKL(
          tf.contrib.distributions.Normal, object)(lambda x: x)
    with self.assertRaisesRegexp(TypeError, "is not a subclass of"):
      tf.contrib.distributions.RegisterKL(
          object, tf.contrib.distributions.Normal)(lambda x: x)

    class MyDist(tf.contrib.distributions.Normal):
      pass

    with self.assertRaisesRegexp(TypeError, "must be callable"):
      tf.contrib.distributions.RegisterKL(MyDist, MyDist)("blah")

    # First registration is OK
    tf.contrib.distributions.RegisterKL(MyDist, MyDist)(lambda a, b: None)

    # Second registration fails
    with self.assertRaisesRegexp(ValueError, "has already been registered"):
      tf.contrib.distributions.RegisterKL(MyDist, MyDist)(lambda a, b: None)


if __name__ == "__main__":
  tf.test.main()

tensorflow/contrib/distributions/python/kernel_tests/normal_test.py

@@ -253,6 +253,28 @@ class NormalTest(tf.test.TestCase):
                   feed_dict={mu: 5.0, sigma: [1.0, 2.0]}),
          [2])

  def testNormalNormalKL(self):
    with self.test_session() as sess:
      batch_size = 6
      mu_a = np.array([3.0] * batch_size)
      sigma_a = np.array([1.0, 2.0, 3.0, 1.5, 2.5, 3.5])
      mu_b = np.array([-3.0] * batch_size)
      sigma_b = np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0])

      n_a = tf.contrib.distributions.Normal(mu=mu_a, sigma=sigma_a)
      n_b = tf.contrib.distributions.Normal(mu=mu_b, sigma=sigma_b)

      kl = tf.contrib.distributions.kl(n_a, n_b)
      kl_val = sess.run(kl)

      kl_expected = (
          (mu_a - mu_b)**2 / (2 * sigma_b**2)
          + 0.5 * ((sigma_a**2 / sigma_b**2) - 1
                   - 2 * np.log(sigma_a / sigma_b)))

      self.assertEqual(kl.get_shape(), (batch_size,))
      self.assertAllClose(kl_val, kl_expected)


if __name__ == '__main__':
  tf.test.main()
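
For reference, the closed form that testNormalNormalKL checks (and that _kl_normal_normal below implements) is the standard KL divergence between two univariate normals, written in LaTeX:

  \mathrm{KL}\bigl(\mathcal{N}(\mu_a, \sigma_a^2) \,\|\, \mathcal{N}(\mu_b, \sigma_b^2)\bigr)
    = \frac{(\mu_a - \mu_b)^2}{2\sigma_b^2}
      + \frac{1}{2}\left(\frac{\sigma_a^2}{\sigma_b^2} - 1 - \ln\frac{\sigma_a^2}{\sigma_b^2}\right)

Since \ln(\sigma_a^2/\sigma_b^2) = 2\ln(\sigma_a/\sigma_b), this matches the kl_expected expression above. As a spot check, the first batch entry (\mu_a = 3, \sigma_a = 1, \mu_b = -3, \sigma_b = 0.5) gives 6^2/(2 \cdot 0.25) + \tfrac{1}{2}(4 - 1 - \ln 4) \approx 72 + 0.807 = 72.807.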

tensorflow/contrib/distributions/python/ops/kullback_leibler.py

@@ -0,0 +1,130 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Registration and usage mechanisms for KL-divergences."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import distribution
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
_DIVERGENCES = {}
def kl(dist_a, dist_b, allow_nan=False, name=None):
"""Get the KL-divergence KL(dist_a || dist_b).
Args:
dist_a: instance of distributions.BaseDistribution.
dist_b: instance of distributions.BaseDistribution.
allow_nan: If False (default), a runtime error is raised
if the KL returns NaN values for any batch entry of the given
distributions. If True, the KL may return a NaN for the given entry.
name: (optional) Name scope to use for created operations.
Returns:
A Tensor with the batchwise KL-divergence between dist_a and dist_b.
Raises:
TypeError: If dist_a or dist_b is not an instance of BaseDistribution.
NotImplementedError: If no KL method is defined for distribution types
of dist_a and dist_b.
"""
if not isinstance(dist_a, distribution.BaseDistribution):
raise TypeError(
"dist_a is not an instance of BaseDistribution, received type: %s"
% type(dist_a))
if not isinstance(dist_b, distribution.BaseDistribution):
raise TypeError(
"dist_b is not an instance of BaseDistribution, received type: %s"
% type(dist_b))
kl_fn = _DIVERGENCES.get((type(dist_a), type(dist_b)), None)
if kl_fn is None:
raise NotImplementedError(
"No KL(dist_a || dist_b) registered for dist_a type %s and dist_b "
"type %s" % ((type(dist_a).__name__, type(dist_b).__name__)))
with ops.name_scope("KullbackLeibler"):
kl_t = kl_fn(dist_a, dist_b, name=name)
if allow_nan:
return kl_t
# Check KL for NaNs
kl_t = array_ops.identity(kl_t, name="kl")
with ops.control_dependencies([
logging_ops.Assert(
math_ops.logical_not(
math_ops.reduce_any(math_ops.is_nan(kl_t))),
["KL calculation between %s and %s returned NaN values "
"(and was called with allow_nan=False). Values:"
% (dist_a.name, dist_b.name), kl_t])]):
return array_ops.identity(kl_t, name="checked_kl")
class RegisterKL(object):
"""Decorator to register a KL divergence implementation function.
Usage:
@distributions.RegisterKL(distributions.Normal, distributions.Normal)
def _kl_normal_mvn(norm_a, norm_b):
# Return KL(norm_a || norm_b)
"""
def __init__(self, dist_cls_a, dist_cls_b):
"""Initialize the KL registrar.
Args:
dist_cls_a: the class of the first argument of the KL divergence.
dist_cls_b: the class of the second argument of the KL divergence.
Raises:
TypeError: if dist_cls_a or dist_cls_b are not subclasses of
BaseDistribution.
"""
if not issubclass(dist_cls_a, distribution.BaseDistribution):
raise TypeError("%s is not a subclass of BaseDistribution" % dist_cls_a)
if not issubclass(dist_cls_b, distribution.BaseDistribution):
raise TypeError("%s is not a subclass of BaseDistribution" % dist_cls_b)
self._key = (dist_cls_a, dist_cls_b)
def __call__(self, kl_fn):
"""Perform the KL registration.
Args:
kl_fn: The function to use for the KL divergence.
Returns:
kl_fn
Raises:
TypeError: if kl_fn is not a callable.
ValueError: if a KL divergence function has already been registered for
the given argument classes.
"""
if not callable(kl_fn):
raise TypeError("kl_fn must be callable, received: %s" % kl_fn)
if self._key in _DIVERGENCES:
raise ValueError("KL(%s || %s) has already been registered to: %s"
% (self._key[0].__name__, self._key[1].__name__,
_DIVERGENCES[self._key]))
_DIVERGENCES[self._key] = kl_fn
return kl_fn
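
The NaN check in kl() relies on a standard graph-mode pattern: gate the returned tensor on an Assert op via control_dependencies, so that evaluating the result forces the assertion to run first. A minimal standalone sketch of the same pattern, using the public TF 1.x-style API rather than the internal modules above (a sketch, not part of the commit):

  import tensorflow as tf

  x = tf.constant([1.0, float("nan")])
  # The Assert fires when any entry of x is NaN.
  assert_no_nan = tf.Assert(
      tf.logical_not(tf.reduce_any(tf.is_nan(x))),
      ["x contained NaN values:", x])
  with tf.control_dependencies([assert_no_nan]):
    checked_x = tf.identity(x)  # evaluating this runs the Assert first

  with tf.Session() as sess:
    sess.run(checked_x)  # raises InvalidArgumentError: x has a NaN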

tensorflow/contrib/distributions/python/ops/normal.py

@@ -21,6 +21,7 @@ from __future__ import print_function

import math

from tensorflow.contrib.distributions.python.ops import distribution  # pylint: disable=line-too-long
from tensorflow.contrib.distributions.python.ops import kullback_leibler  # pylint: disable=line-too-long
from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util  # pylint: disable=line-too-long
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes

@@ -317,3 +318,27 @@ class Normal(distribution.ContinuousDistribution):

  def _zeros(self):
    return array_ops.zeros_like(self._mu + self._sigma)


@kullback_leibler.RegisterKL(Normal, Normal)
def _kl_normal_normal(n_a, n_b, name=None):
  """Calculate the batched KL divergence KL(n_a || n_b) with n_a and n_b Normal.

  Args:
    n_a: instance of a Normal distribution object.
    n_b: instance of a Normal distribution object.
    name: (optional) Name to use for created operations;
      default is "kl_normal_normal".

  Returns:
    Batchwise KL(n_a || n_b)
  """
  with ops.op_scope([n_a.mu, n_b.mu], name, "kl_normal_normal"):
    one = constant_op.constant(1, dtype=n_a.dtype)
    two = constant_op.constant(2, dtype=n_a.dtype)
    half = constant_op.constant(0.5, dtype=n_a.dtype)
    s_a_squared = math_ops.square(n_a.sigma)
    s_b_squared = math_ops.square(n_b.sigma)
    ratio = s_a_squared / s_b_squared
    return (math_ops.square(n_a.mu - n_b.mu) / (two * s_b_squared)
            + half * (ratio - one - math_ops.log(ratio)))
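
As a quick cross-check of the closed form above, a standalone numpy sketch (a hypothetical snippet, not part of the commit) reproduces the first batch entry from testNormalNormalKL:

  import numpy as np

  # First batch entry from testNormalNormalKL.
  mu_a, sigma_a = 3.0, 1.0
  mu_b, sigma_b = -3.0, 0.5

  ratio = sigma_a**2 / sigma_b**2
  kl = ((mu_a - mu_b)**2 / (2 * sigma_b**2)
        + 0.5 * (ratio - 1 - np.log(ratio)))
  print(kl)  # ~72.807, matching the spot check above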