diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 8e8fc886017..704b5dd2c4b 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -65,6 +65,17 @@ cuda_py_tests(
     ],
 )
 
+cuda_py_tests(
+    name = "kullback_leibler_test",
+    size = "small",
+    srcs = ["python/kernel_tests/kullback_leibler_test.py"],
+    additional_deps = [
+        ":distributions_py",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
 cuda_py_tests(
     name = "student_t_test",
     size = "small",
diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py
index 496e3bbb2d6..7e19c82c1e7 100644
--- a/tensorflow/contrib/distributions/__init__.py
+++ b/tensorflow/contrib/distributions/__init__.py
@@ -49,6 +49,12 @@ representing the posterior or posterior predictive.
 
 @@normal_conjugates_known_sigma_posterior
 @@normal_congugates_known_sigma_predictive
+
+## Kullback-Leibler Divergence
+
+@@kl
+@@RegisterKL
+
 """
 from __future__ import absolute_import
 from __future__ import division
@@ -62,6 +68,7 @@ from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import *
 from tensorflow.contrib.distributions.python.ops.distribution import *
 from tensorflow.contrib.distributions.python.ops.exponential import *
 from tensorflow.contrib.distributions.python.ops.gamma import *
+from tensorflow.contrib.distributions.python.ops.kullback_leibler import *
 from tensorflow.contrib.distributions.python.ops.mvn import *
 from tensorflow.contrib.distributions.python.ops.normal import *
 from tensorflow.contrib.distributions.python.ops.normal_conjugate_posteriors import *
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/kullback_leibler_test.py b/tensorflow/contrib/distributions/python/kernel_tests/kullback_leibler_test.py
new file mode 100644
index 00000000000..ea1395eb9d2
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/kullback_leibler_test.py
@@ -0,0 +1,83 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the distributions KL mechanism."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+class KLTest(tf.test.TestCase):
+
+  def testRegistration(self):
+    class MyDist(tf.contrib.distributions.Normal):
+      pass
+
+    # Register a KL function that simply returns the name parameter.
+    @tf.contrib.distributions.RegisterKL(MyDist, MyDist)
+    def _kl(unused_a, unused_b, name=None):  # pylint: disable=unused-variable
+      return name
+
+    a = MyDist(mu=0.0, sigma=1.0)
+    # Run kl() with allow_nan=True because strings can't go through is_nan.
+    self.assertEqual(
+        "OK", tf.contrib.distributions.kl(a, a, allow_nan=True, name="OK"))
+
+  def testDomainErrorExceptions(self):
+    class MyDistException(tf.contrib.distributions.Normal):
+      pass
+
+    # Register a KL function that returns NaN for every batch entry.
+    @tf.contrib.distributions.RegisterKL(MyDistException, MyDistException)
+    # pylint: disable=unused-variable
+    def _kl(unused_a, unused_b, name=None):  # pylint: disable=unused-argument
+      return tf.identity([float("nan")])
+    # pylint: enable=unused-variable
+
+    with self.test_session():
+      a = MyDistException(mu=0.0, sigma=1.0)
+      kl = tf.contrib.distributions.kl(a, a)
+      with self.assertRaisesOpError(
+          "KL calculation between .* and .* returned NaN values"):
+        kl.eval()
+      kl_ok = tf.contrib.distributions.kl(a, a, allow_nan=True)
+      self.assertAllEqual([float("nan")], kl_ok.eval())
+
+  def testRegistrationFailures(self):
+    with self.assertRaisesRegexp(TypeError, "is not a subclass of"):
+      tf.contrib.distributions.RegisterKL(
+          tf.contrib.distributions.Normal, object)(lambda x: x)
+    with self.assertRaisesRegexp(TypeError, "is not a subclass of"):
+      tf.contrib.distributions.RegisterKL(
+          object, tf.contrib.distributions.Normal)(lambda x: x)
+
+    class MyDist(tf.contrib.distributions.Normal):
+      pass
+
+    with self.assertRaisesRegexp(TypeError, "must be callable"):
+      tf.contrib.distributions.RegisterKL(MyDist, MyDist)("blah")
+
+    # First registration is OK
+    tf.contrib.distributions.RegisterKL(MyDist, MyDist)(lambda a, b: None)
+
+    # Second registration fails
+    with self.assertRaisesRegexp(ValueError, "has already been registered"):
+      tf.contrib.distributions.RegisterKL(MyDist, MyDist)(lambda a, b: None)
+
+
+if __name__ == "__main__":
+  tf.test.main()
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/normal_test.py b/tensorflow/contrib/distributions/python/kernel_tests/normal_test.py
index fc9766e22e4..46ea200ae23 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/normal_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/normal_test.py
@@ -253,6 +253,28 @@ class NormalTest(tf.test.TestCase):
             feed_dict={mu: 5.0, sigma: [1.0, 2.0]}),
         [2])
 
+  def testNormalNormalKL(self):
+    with self.test_session() as sess:
+      batch_size = 6
+      mu_a = np.array([3.0] * batch_size)
+      sigma_a = np.array([1.0, 2.0, 3.0, 1.5, 2.5, 3.5])
+      mu_b = np.array([-3.0] * batch_size)
+      sigma_b = np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0])
+
+      n_a = tf.contrib.distributions.Normal(mu=mu_a, sigma=sigma_a)
+      n_b = tf.contrib.distributions.Normal(mu=mu_b, sigma=sigma_b)
+
+      kl = tf.contrib.distributions.kl(n_a, n_b)
+      kl_val = sess.run(kl)
+
+      kl_expected = (
+          (mu_a - mu_b)**2 / (2 * sigma_b**2) +
+          0.5 * ((sigma_a**2/sigma_b**2) -
+                 1 - 2 * np.log(sigma_a / sigma_b)))
+
+      self.assertEqual(kl.get_shape(), (batch_size,))
+      self.assertAllClose(kl_val, kl_expected)
+
 
 if __name__ == '__main__':
   tf.test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/kullback_leibler.py b/tensorflow/contrib/distributions/python/ops/kullback_leibler.py
new file mode 100644
index 00000000000..ac3801bbe27
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/kullback_leibler.py
@@ -0,0 +1,130 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Registration and usage mechanisms for KL-divergences."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.distributions.python.ops import distribution
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import math_ops
+
+
+_DIVERGENCES = {}
+
+
+def kl(dist_a, dist_b, allow_nan=False, name=None):
+  """Get the KL-divergence KL(dist_a || dist_b).
+
+  Args:
+    dist_a: instance of distributions.BaseDistribution.
+    dist_b: instance of distributions.BaseDistribution.
+    allow_nan: If False (default), a runtime error is raised
+      if the KL returns NaN values for any batch entry of the given
+      distributions. If True, the KL may return a NaN for the given entry.
+    name: (optional) Name scope to use for created operations.
+
+  Returns:
+    A Tensor with the batchwise KL-divergence between dist_a and dist_b.
+
+  Raises:
+    TypeError: If dist_a or dist_b is not an instance of BaseDistribution.
+    NotImplementedError: If no KL method is defined for distribution types
+      of dist_a and dist_b.
+  """
+  if not isinstance(dist_a, distribution.BaseDistribution):
+    raise TypeError(
+        "dist_a is not an instance of BaseDistribution, received type: %s"
+        % type(dist_a))
+  if not isinstance(dist_b, distribution.BaseDistribution):
+    raise TypeError(
+        "dist_b is not an instance of BaseDistribution, received type: %s"
+        % type(dist_b))
+  kl_fn = _DIVERGENCES.get((type(dist_a), type(dist_b)), None)
+  if kl_fn is None:
+    raise NotImplementedError(
+        "No KL(dist_a || dist_b) registered for dist_a type %s and dist_b "
+        "type %s" % (type(dist_a).__name__, type(dist_b).__name__))
+  with ops.name_scope("KullbackLeibler"):
+    kl_t = kl_fn(dist_a, dist_b, name=name)
+    if allow_nan:
+      return kl_t
+
+    # Check KL for NaNs
+    kl_t = array_ops.identity(kl_t, name="kl")
+
+    with ops.control_dependencies([
+        logging_ops.Assert(
+            math_ops.logical_not(
+                math_ops.reduce_any(math_ops.is_nan(kl_t))),
+            ["KL calculation between %s and %s returned NaN values "
+             "(and was called with allow_nan=False). Values:"
+             % (dist_a.name, dist_b.name), kl_t])]):
+      return array_ops.identity(kl_t, name="checked_kl")
+
+
+class RegisterKL(object):
+  """Decorator to register a KL divergence implementation function.
+
+  Usage:
+
+  @distributions.RegisterKL(distributions.Normal, distributions.Normal)
+  def _kl_normal_normal(norm_a, norm_b):
+    # Return KL(norm_a || norm_b)
+  """
+
+  def __init__(self, dist_cls_a, dist_cls_b):
+    """Initialize the KL registrar.
+
+    Args:
+      dist_cls_a: the class of the first argument of the KL divergence.
+      dist_cls_b: the class of the second argument of the KL divergence.
+
+    Raises:
+      TypeError: if dist_cls_a or dist_cls_b are not subclasses of
+        BaseDistribution.
+ """ + + if not issubclass(dist_cls_a, distribution.BaseDistribution): + raise TypeError("%s is not a subclass of BaseDistribution" % dist_cls_a) + if not issubclass(dist_cls_b, distribution.BaseDistribution): + raise TypeError("%s is not a subclass of BaseDistribution" % dist_cls_b) + self._key = (dist_cls_a, dist_cls_b) + + def __call__(self, kl_fn): + """Perform the KL registration. + + Args: + kl_fn: The function to use for the KL divergence. + + Returns: + kl_fn + + Raises: + TypeError: if kl_fn is not a callable. + ValueError: if a KL divergence function has already been registered for + the given argument classes. + """ + if not callable(kl_fn): + raise TypeError("kl_fn must be callable, received: %s" % kl_fn) + if self._key in _DIVERGENCES: + raise ValueError("KL(%s || %s) has already been registered to: %s" + % (self._key[0].__name__, self._key[1].__name__, + _DIVERGENCES[self._key])) + _DIVERGENCES[self._key] = kl_fn + return kl_fn diff --git a/tensorflow/contrib/distributions/python/ops/normal.py b/tensorflow/contrib/distributions/python/ops/normal.py index bfa8e474c4d..c233bb8a7db 100644 --- a/tensorflow/contrib/distributions/python/ops/normal.py +++ b/tensorflow/contrib/distributions/python/ops/normal.py @@ -21,6 +21,7 @@ from __future__ import print_function import math from tensorflow.contrib.distributions.python.ops import distribution # pylint: disable=line-too-long +from tensorflow.contrib.distributions.python.ops import kullback_leibler # pylint: disable=line-too-long from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util # pylint: disable=line-too-long from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -317,3 +318,27 @@ class Normal(distribution.ContinuousDistribution): def _zeros(self): return array_ops.zeros_like(self._mu + self._sigma) + + +@kullback_leibler.RegisterKL(Normal, Normal) +def _kl_normal_normal(n_a, n_b, name=None): + """Calculate the batched KL divergence KL(n_a || n_b) with n_a and n_b Normal. + + Args: + n_a: instance of a Normal distribution object. + n_b: instance of a Normal distribution object. + name: (optional) Name to use for created operations. + default is "kl_normal_normal". + + Returns: + Batchwise KL(n_a || n_b) + """ + with ops.op_scope([n_a.mu, n_b.mu], name, "kl_normal_normal"): + one = constant_op.constant(1, dtype=n_a.dtype) + two = constant_op.constant(2, dtype=n_a.dtype) + half = constant_op.constant(0.5, dtype=n_a.dtype) + s_a_squared = math_ops.square(n_a.sigma) + s_b_squared = math_ops.square(n_b.sigma) + ratio = s_a_squared / s_b_squared + return (math_ops.square(n_a.mu - n_b.mu) / (two * s_b_squared) + + half * (ratio - one - math_ops.log(ratio)))