Add batchwise KL to tf.contrib.distributions.

* Includes the registration decorator RegisterKL.
* Includes distributions.kl() for computing the divergence.
* Includes an implementation for KL(Normal || Normal).
* distributions.kl() accepts an allow_nan argument. If False (the default),
  the output of the KL is checked for NaNs and a runtime exception is raised
  if any are found. If True, no runtime checks are performed.

Change: 125008908
parent a496df8788
commit 60b0f7014f
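For context, a minimal usage sketch of the API this commit adds (TF1-era
tf.contrib.distributions); the parameter values below are illustrative and
not taken from the change itself:

    import tensorflow as tf

    ds = tf.contrib.distributions

    # Two batches of scalar Normals; the mu/sigma values are made up.
    n_a = ds.Normal(mu=[0.0, 1.0], sigma=[1.0, 2.0])
    n_b = ds.Normal(mu=[0.5, 0.0], sigma=[1.5, 1.0])

    # Batchwise KL(n_a || n_b).  By default the result is checked for NaNs
    # at run time; pass allow_nan=True to skip the check.
    kl = ds.kl(n_a, n_b)
    kl_unchecked = ds.kl(n_a, n_b, allow_nan=True)

    with tf.Session() as sess:
      print(sess.run(kl))  # One KL value per batch entry.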
tensorflow/contrib/distributions/BUILD
@@ -65,6 +65,17 @@ cuda_py_tests(
     ],
 )
 
+cuda_py_tests(
+    name = "kullback_leibler_test",
+    size = "small",
+    srcs = ["python/kernel_tests/kullback_leibler_test.py"],
+    additional_deps = [
+        ":distributions_py",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
 cuda_py_tests(
     name = "student_t_test",
     size = "small",
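Assuming the usual behavior of the cuda_py_tests macro (a generated py_test
per entry under the given name; the exact generated target names may differ),
the new test would then be run with something like:

    bazel test //tensorflow/contrib/distributions:kullback_leibler_test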
tensorflow/contrib/distributions/__init__.py
@@ -49,6 +49,12 @@ representing the posterior or posterior predictive.
 
 @@normal_conjugates_known_sigma_posterior
 @@normal_congugates_known_sigma_predictive
+
+## Kullback Leibler Divergence
+
+@@kl
+@@RegisterKL
+
 """
 from __future__ import absolute_import
 from __future__ import division
@@ -62,6 +68,7 @@ from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import *
 from tensorflow.contrib.distributions.python.ops.distribution import *
 from tensorflow.contrib.distributions.python.ops.exponential import *
 from tensorflow.contrib.distributions.python.ops.gamma import *
+from tensorflow.contrib.distributions.python.ops.kullback_leibler import *
 from tensorflow.contrib.distributions.python.ops.mvn import *
 from tensorflow.contrib.distributions.python.ops.normal import *
 from tensorflow.contrib.distributions.python.ops.normal_conjugate_posteriors import *
tensorflow/contrib/distributions/python/kernel_tests/kullback_leibler_test.py (new file)
@@ -0,0 +1,83 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for distributions KL mechanism."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+class KLTest(tf.test.TestCase):
+
+  def testRegistration(self):
+    class MyDist(tf.contrib.distributions.Normal):
+      pass
+
+    # Register KL to a lambda that spits out the name parameter
+    @tf.contrib.distributions.RegisterKL(MyDist, MyDist)
+    def _kl(unused_a, unused_b, name=None):  # pylint: disable=unused-variable
+      return name
+
+    a = MyDist(mu=0.0, sigma=1.0)
+    # Run kl() with allow_nan=True because strings can't go through is_nan.
+    self.assertEqual(
+        "OK", tf.contrib.distributions.kl(a, a, allow_nan=True, name="OK"))
+
+  def testDomainErrorExceptions(self):
+    class MyDistException(tf.contrib.distributions.Normal):
+      pass
+
+    # Register KL to a function that returns NaN values
+    @tf.contrib.distributions.RegisterKL(MyDistException, MyDistException)
+    # pylint: disable=unused-variable
+    def _kl(unused_a, unused_b, name=None):  # pylint: disable=unused-argument
+      return tf.identity([float("nan")])
+    # pylint: enable=unused-variable
+
+    with self.test_session():
+      a = MyDistException(mu=0.0, sigma=1.0)
+      kl = tf.contrib.distributions.kl(a, a)
+      with self.assertRaisesOpError(
+          "KL calculation between .* and .* returned NaN values"):
+        kl.eval()
+      kl_ok = tf.contrib.distributions.kl(a, a, allow_nan=True)
+      self.assertAllEqual([float("nan")], kl_ok.eval())
+
+  def testRegistrationFailures(self):
+    with self.assertRaisesRegexp(TypeError, "is not a subclass of"):
+      tf.contrib.distributions.RegisterKL(
+          tf.contrib.distributions.Normal, object)(lambda x: x)
+    with self.assertRaisesRegexp(TypeError, "is not a subclass of"):
+      tf.contrib.distributions.RegisterKL(
+          object, tf.contrib.distributions.Normal)(lambda x: x)
+
+    class MyDist(tf.contrib.distributions.Normal):
+      pass
+
+    with self.assertRaisesRegexp(TypeError, "must be callable"):
+      tf.contrib.distributions.RegisterKL(MyDist, MyDist)("blah")
+
+    # First registration is OK
+    tf.contrib.distributions.RegisterKL(MyDist, MyDist)(lambda a, b: None)
+
+    # Second registration fails
+    with self.assertRaisesRegexp(ValueError, "has already been registered"):
+      tf.contrib.distributions.RegisterKL(MyDist, MyDist)(lambda a, b: None)
+
+
+if __name__ == "__main__":
+  tf.test.main()
tensorflow/contrib/distributions/python/kernel_tests/normal_test.py
@@ -253,6 +253,28 @@ class NormalTest(tf.test.TestCase):
                       feed_dict={mu: 5.0, sigma: [1.0, 2.0]}),
                   [2])
 
+  def testNormalNormalKL(self):
+    with self.test_session() as sess:
+      batch_size = 6
+      mu_a = np.array([3.0] * batch_size)
+      sigma_a = np.array([1.0, 2.0, 3.0, 1.5, 2.5, 3.5])
+      mu_b = np.array([-3.0] * batch_size)
+      sigma_b = np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0])
+
+      n_a = tf.contrib.distributions.Normal(mu=mu_a, sigma=sigma_a)
+      n_b = tf.contrib.distributions.Normal(mu=mu_b, sigma=sigma_b)
+
+      kl = tf.contrib.distributions.kl(n_a, n_b)
+      kl_val = sess.run(kl)
+
+      kl_expected = (
+          (mu_a - mu_b)**2 / (2 * sigma_b**2)
+          + 0.5 * ((sigma_a**2 / sigma_b**2) -
+                   1 - 2 * np.log(sigma_a / sigma_b)))
+
+      self.assertEqual(kl.get_shape(), (batch_size,))
+      self.assertAllClose(kl_val, kl_expected)
+
 
 if __name__ == '__main__':
   tf.test.main()
tensorflow/contrib/distributions/python/ops/kullback_leibler.py (new file, 130 lines)
@@ -0,0 +1,130 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Registration and usage mechanisms for KL-divergences."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.distributions.python.ops import distribution
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import math_ops
+
+
+_DIVERGENCES = {}
+
+
+def kl(dist_a, dist_b, allow_nan=False, name=None):
+  """Get the KL-divergence KL(dist_a || dist_b).
+
+  Args:
+    dist_a: instance of distributions.BaseDistribution.
+    dist_b: instance of distributions.BaseDistribution.
+    allow_nan: If False (default), a runtime error is raised
+      if the KL returns NaN values for any batch entry of the given
+      distributions.  If True, the KL may return a NaN for the given entry.
+    name: (optional) Name scope to use for created operations.
+
+  Returns:
+    A Tensor with the batchwise KL-divergence between dist_a and dist_b.
+
+  Raises:
+    TypeError: If dist_a or dist_b is not an instance of BaseDistribution.
+    NotImplementedError: If no KL method is defined for distribution types
+      of dist_a and dist_b.
+  """
+  if not isinstance(dist_a, distribution.BaseDistribution):
+    raise TypeError(
+        "dist_a is not an instance of BaseDistribution, received type: %s"
+        % type(dist_a))
+  if not isinstance(dist_b, distribution.BaseDistribution):
+    raise TypeError(
+        "dist_b is not an instance of BaseDistribution, received type: %s"
+        % type(dist_b))
+  kl_fn = _DIVERGENCES.get((type(dist_a), type(dist_b)), None)
+  if kl_fn is None:
+    raise NotImplementedError(
+        "No KL(dist_a || dist_b) registered for dist_a type %s and dist_b "
+        "type %s" % (type(dist_a).__name__, type(dist_b).__name__))
+  with ops.name_scope("KullbackLeibler"):
+    kl_t = kl_fn(dist_a, dist_b, name=name)
+    if allow_nan:
+      return kl_t
+
+    # Check KL for NaNs
+    kl_t = array_ops.identity(kl_t, name="kl")
+
+    with ops.control_dependencies([
+        logging_ops.Assert(
+            math_ops.logical_not(
+                math_ops.reduce_any(math_ops.is_nan(kl_t))),
+            ["KL calculation between %s and %s returned NaN values "
+             "(and was called with allow_nan=False).  Values:"
+             % (dist_a.name, dist_b.name), kl_t])]):
+      return array_ops.identity(kl_t, name="checked_kl")
+
+
+class RegisterKL(object):
+  """Decorator to register a KL divergence implementation function.
+
+  Usage:
+
+  @distributions.RegisterKL(distributions.Normal, distributions.Normal)
+  def _kl_normal_mvn(norm_a, norm_b):
+    # Return KL(norm_a || norm_b)
+  """
+
+  def __init__(self, dist_cls_a, dist_cls_b):
+    """Initialize the KL registrar.
+
+    Args:
+      dist_cls_a: the class of the first argument of the KL divergence.
+      dist_cls_b: the class of the second argument of the KL divergence.
+
+    Raises:
+      TypeError: if dist_cls_a or dist_cls_b are not subclasses of
+        BaseDistribution.
+    """
+
+    if not issubclass(dist_cls_a, distribution.BaseDistribution):
+      raise TypeError("%s is not a subclass of BaseDistribution" % dist_cls_a)
+    if not issubclass(dist_cls_b, distribution.BaseDistribution):
+      raise TypeError("%s is not a subclass of BaseDistribution" % dist_cls_b)
+    self._key = (dist_cls_a, dist_cls_b)
+
+  def __call__(self, kl_fn):
+    """Perform the KL registration.
+
+    Args:
+      kl_fn: The function to use for the KL divergence.
+
+    Returns:
+      kl_fn
+
+    Raises:
+      TypeError: if kl_fn is not a callable.
+      ValueError: if a KL divergence function has already been registered for
+        the given argument classes.
+    """
+    if not callable(kl_fn):
+      raise TypeError("kl_fn must be callable, received: %s" % kl_fn)
+    if self._key in _DIVERGENCES:
+      raise ValueError("KL(%s || %s) has already been registered to: %s"
+                       % (self._key[0].__name__, self._key[1].__name__,
+                          _DIVERGENCES[self._key]))
+    _DIVERGENCES[self._key] = kl_fn
+    return kl_fn
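The NaN guard in kl() above uses a standard TF1 idiom: attach an Assert op to
the result via control_dependencies, so that fetching the checked tensor
forces the assertion to run first. A minimal standalone sketch of the same
pattern, written against the public tf.* aliases of the ops imported above
(the input tensor and message here are illustrative):

    import tensorflow as tf

    def _checked_no_nan(x, message):
      # The Assert op fires at run time if any entry of x is NaN.
      assert_op = tf.Assert(
          tf.logical_not(tf.reduce_any(tf.is_nan(x))), [message, x])
      with tf.control_dependencies([assert_op]):
        # Fetching this identity forces the assertion to execute first.
        return tf.identity(x, name="checked")

    x = tf.constant([1.0, float("nan")])
    with tf.Session() as sess:
      try:
        sess.run(_checked_no_nan(x, "found NaN values"))
      except tf.errors.InvalidArgumentError as e:
        print("Assert fired:", e.message)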
tensorflow/contrib/distributions/python/ops/normal.py
@@ -21,6 +21,7 @@ from __future__ import print_function
 import math
 
 from tensorflow.contrib.distributions.python.ops import distribution  # pylint: disable=line-too-long
+from tensorflow.contrib.distributions.python.ops import kullback_leibler  # pylint: disable=line-too-long
 from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util  # pylint: disable=line-too-long
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes

@@ -317,3 +318,27 @@ class Normal(distribution.ContinuousDistribution):
 
   def _zeros(self):
     return array_ops.zeros_like(self._mu + self._sigma)
+
+
+@kullback_leibler.RegisterKL(Normal, Normal)
+def _kl_normal_normal(n_a, n_b, name=None):
+  """Calculate the batched KL divergence KL(n_a || n_b) with n_a and n_b Normal.
+
+  Args:
+    n_a: instance of a Normal distribution object.
+    n_b: instance of a Normal distribution object.
+    name: (optional) Name to use for created operations.
+      Default is "kl_normal_normal".
+
+  Returns:
+    Batchwise KL(n_a || n_b)
+  """
+  with ops.op_scope([n_a.mu, n_b.mu], name, "kl_normal_normal"):
+    one = constant_op.constant(1, dtype=n_a.dtype)
+    two = constant_op.constant(2, dtype=n_a.dtype)
+    half = constant_op.constant(0.5, dtype=n_a.dtype)
+    s_a_squared = math_ops.square(n_a.sigma)
+    s_b_squared = math_ops.square(n_b.sigma)
+    ratio = s_a_squared / s_b_squared
+    return (math_ops.square(n_a.mu - n_b.mu) / (two * s_b_squared)
+            + half * (ratio - one - math_ops.log(ratio)))
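For reference, the closed form that _kl_normal_normal above and the test's
kl_expected both compute is the standard Normal-Normal KL identity, in LaTeX:

    KL\big(\mathcal{N}(\mu_a, \sigma_a^2) \,\|\, \mathcal{N}(\mu_b, \sigma_b^2)\big)
      = \frac{(\mu_a - \mu_b)^2}{2\sigma_b^2}
        + \frac{1}{2}\left(\frac{\sigma_a^2}{\sigma_b^2} - 1
                           - \log\frac{\sigma_a^2}{\sigma_b^2}\right)

Here `ratio` in the code is \sigma_a^2 / \sigma_b^2, and the test's
2\log(\sigma_a/\sigma_b) equals \log of that same ratio.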