diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py
index 1ab819d7976..cbc66b6dc13 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for Hamiltonian Monte Carlo.
-"""
+"""Tests for Hamiltonian Monte Carlo."""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -27,6 +26,7 @@
 from tensorflow.contrib.bayesflow.python.ops import hmc
 from tensorflow.python.framework import random_seed
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.platform import test
@@ -46,6 +46,9 @@ class HMCTest(test.TestCase):
     random_seed.set_random_seed(10003)
     np.random.seed(10003)
 
+  def assertAllFinite(self, x):
+    self.assertAllEqual(np.ones_like(x).astype(bool), np.isfinite(x))
+
   def _log_gamma_log_prob(self, x, event_dims=()):
     """Computes log-pdf of a log-gamma random variable.
 
@@ -375,5 +378,67 @@ class HMCTest(test.TestCase):
       self.assertAllEqual(initial_x_val, updated_x_val)
       self.assertEqual(acceptance_probs_val, 0.)
 
+  def testNanFromGradsDontPropagate(self):
+    """Test that update with NaN gradients does not cause NaN in results."""
+    def _nan_log_prob_with_nan_gradient(x):
+      return np.nan * math_ops.reduce_sum(x)
+
+    with self.test_session() as sess:
+      initial_x = math_ops.linspace(0.01, 5, 10)
+      updated_x, acceptance_probs, new_log_prob, new_grad = hmc.kernel(
+          2., 5, initial_x, _nan_log_prob_with_nan_gradient, [0])
+      initial_x_val, updated_x_val, acceptance_probs_val = sess.run(
+          [initial_x, updated_x, acceptance_probs])
+
+      logging.vlog(1, 'initial_x = {}'.format(initial_x_val))
+      logging.vlog(1, 'updated_x = {}'.format(updated_x_val))
+      logging.vlog(1, 'acceptance_probs = {}'.format(acceptance_probs_val))
+
+      self.assertAllEqual(initial_x_val, updated_x_val)
+      self.assertEqual(acceptance_probs_val, 0.)
+
+      self.assertAllFinite(
+          gradients_impl.gradients(updated_x, initial_x)[0].eval())
+      self.assertTrue(
+          gradients_impl.gradients(new_grad, initial_x)[0] is None)
+
+      # Gradients of the acceptance probs and new log prob are not finite.
+      _ = new_log_prob  # Prevent unused arg error.
+      # self.assertAllFinite(
+      #     gradients_impl.gradients(acceptance_probs, initial_x)[0].eval())
+      # self.assertAllFinite(
+      #     gradients_impl.gradients(new_log_prob, initial_x)[0].eval())
+
+  def testChainWorksIn64Bit(self):
+    def log_prob(x):
+      return - math_ops.reduce_sum(x * x, axis=-1)
+    states, acceptance_probs = hmc.chain(
+        n_iterations=10,
+        step_size=np.float64(0.01),
+        n_leapfrog_steps=10,
+        initial_x=np.zeros(5).astype(np.float64),
+        target_log_prob_fn=log_prob,
+        event_dims=[-1])
+    with self.test_session() as sess:
+      states_, acceptance_probs_ = sess.run([states, acceptance_probs])
+    self.assertEqual(np.float64, states_.dtype)
+    self.assertEqual(np.float64, acceptance_probs_.dtype)
+
+  def testChainWorksIn16Bit(self):
+    def log_prob(x):
+      return - math_ops.reduce_sum(x * x, axis=-1)
+    states, acceptance_probs = hmc.chain(
+        n_iterations=10,
+        step_size=np.float16(0.01),
+        n_leapfrog_steps=10,
+        initial_x=np.zeros(5).astype(np.float16),
+        target_log_prob_fn=log_prob,
+        event_dims=[-1])
+    with self.test_session() as sess:
+      states_, acceptance_probs_ = sess.run([states, acceptance_probs])
+    self.assertEqual(np.float16, states_.dtype)
+    self.assertEqual(np.float16, acceptance_probs_.dtype)
+
+
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py
index da788be3db0..5685a942e98 100644
--- a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py
@@ -27,6 +27,7 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
@@ -174,9 +175,11 @@ def chain(n_iterations, step_size, n_leapfrog_steps, initial_x,
   potential_and_grad = _make_potential_and_grad(target_log_prob_fn)
   potential, grad = potential_and_grad(initial_x)
-  return functional_ops.scan(body, array_ops.zeros(n_iterations),
-                             (initial_x, array_ops.zeros(non_event_shape),
-                              -potential, -grad))[:2]
+  return functional_ops.scan(
+      body, array_ops.zeros(n_iterations, dtype=initial_x.dtype),
+      (initial_x,
+       array_ops.zeros(non_event_shape, dtype=initial_x.dtype),
+       -potential, -grad))[:2]
 
 
 def ais_chain(n_iterations, step_size, n_leapfrog_steps, initial_x,
@@ -298,8 +301,9 @@ def ais_chain(n_iterations, step_size, n_leapfrog_steps, initial_x,
     return updated_x, acceptance_probs, w
 
   x, acceptance_probs, w = functional_ops.scan(
-      _body, beta_series, (initial_x, array_ops.zeros(non_event_shape),
-                           array_ops.zeros(non_event_shape)))
+      _body, beta_series,
+      (initial_x, array_ops.zeros(non_event_shape, dtype=initial_x.dtype),
+       array_ops.zeros(non_event_shape, dtype=initial_x.dtype)))
   return w[-1], x[-1], acceptance_probs[-1]
 
 
@@ -446,9 +450,10 @@ def kernel(step_size, n_leapfrog_steps, x, target_log_prob_fn, event_dims=(),
   """
   with ops.name_scope(name, 'hmc_kernel', [step_size, n_leapfrog_steps, x]):
     potential_and_grad = _make_potential_and_grad(target_log_prob_fn)
+    x = ops.convert_to_tensor(x, name='x')
 
     x_shape = array_ops.shape(x)
-    m = random_ops.random_normal(x_shape)
+    m = random_ops.random_normal(x_shape, dtype=x.dtype)
 
     kinetic_0 = 0.5 * math_ops.reduce_sum(math_ops.square(m), event_dims)
 
@@ -475,23 +480,26 @@ def kernel(step_size, n_leapfrog_steps, x, target_log_prob_fn, event_dims=(),
                                     array_ops.fill(array_ops.shape(energy_change),
                                                    energy_change.dtype.as_numpy_dtype(np.inf)),
                                     energy_change)
-    acceptance_probs = math_ops.exp(math_ops.minimum(0., -energy_change))
-    accepted = math_ops.cast(
-        random_ops.random_uniform(array_ops.shape(acceptance_probs)) <
-        acceptance_probs, log_potential_0.dtype)
-    new_log_prob = (-log_potential_0 * (1. - accepted) -
-                    log_potential_1 * accepted)
+    acceptance_probs = math_ops.exp(math_ops.minimum(-energy_change, 0.))
+    accepted = (
+        random_ops.random_uniform(
+            array_ops.shape(acceptance_probs), dtype=x.dtype)
+        < acceptance_probs)
+    new_log_prob = -array_ops.where(accepted, log_potential_1, log_potential_0)
 
     # TODO(b/65738010): This should work, but it doesn't for now.
     # reduced_shape = math_ops.reduced_shape(x_shape, event_dims)
     reduced_shape = array_ops.shape(math_ops.reduce_sum(x, event_dims,
                                                         keep_dims=True))
     accepted = array_ops.reshape(accepted, reduced_shape)
-    accepted = math_ops.cast(accepted, x.dtype)
-    new_x = x * (1. - accepted) + new_x * accepted
-    accepted = math_ops.cast(accepted, accepted.dtype)
-    new_grad = -grad_0 * (1. - accepted) - grad_1 * accepted
+    accepted = math_ops.logical_or(
+        accepted, math_ops.cast(array_ops.zeros_like(x), dtypes.bool))
+    new_x = array_ops.where(accepted, new_x, x)
+    new_grad = -array_ops.where(accepted, grad_1, grad_0)
 
+    # TODO(langmore) Gradients of acceptance_probs and new_log_prob with respect
+    # to initial_x will propagate NaNs (see testNanFromGradsDontPropagate). This
+    # should be fixed.
     return new_x, acceptance_probs, new_log_prob, new_grad
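
Note on the accept/reject rewrite above (commentary, not part of the patch): replacing the float blend `x * (1. - accepted) + new_x * accepted` with `array_ops.where(accepted, new_x, x)` is what keeps a NaN proposal from leaking into rejected states, since `nan * 0.` is still `nan` under IEEE-754 arithmetic, while a select never touches the unselected branch. A minimal NumPy sketch of the same effect:

    import numpy as np

    x = np.linspace(0.01, 5.0, 10)           # current states
    new_x = np.full_like(x, np.nan)          # proposal poisoned by NaN
    accepted = np.zeros_like(x, dtype=bool)  # every proposal rejected

    # Float blending leaks NaN even though nothing was accepted,
    # because nan * 0.0 == nan:
    blended = x * (1.0 - accepted) + new_x * accepted
    assert not np.isfinite(blended).all()

    # Selecting leaves the rejected branch untouched:
    selected = np.where(accepted, new_x, x)
    assert np.isfinite(selected).all()

The same distinction shows up in the gradients: differentiating a multiply keeps the NaN factor, while differentiating a select only routes gradient through the chosen branch. That is why testNanFromGradsDontPropagate can assert that the gradient of `updated_x` with respect to `initial_x` is finite, while the gradients of `acceptance_probs` and `new_log_prob`, which still depend on `energy_change` arithmetically, are not (see the TODO(langmore) comment).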
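Similarly, the dtype threading (`dtype=initial_x.dtype` on every `zeros`, `dtype=x.dtype` on `random_normal` and `random_uniform`, and boolean masks instead of float casts) is what lets testChainWorksIn64Bit and testChainWorksIn16Bit pass: TensorFlow ops default to float32 and would otherwise clash with a float16 or float64 chain state. A rough NumPy sketch of the pattern, using a hypothetical random-walk `mh_step` helper for illustration rather than the real `hmc.kernel`:

    import numpy as np

    def mh_step(x, log_prob, step_size, rng):
      # Hypothetical helper, illustration only: a random-walk Metropolis
      # step, not leapfrog HMC, but it threads dtype the same way.
      noise = rng.standard_normal(x.shape).astype(x.dtype)
      proposal = x + np.asarray(step_size, dtype=x.dtype) * noise
      log_accept = np.minimum(log_prob(proposal) - log_prob(x), 0.0)
      accepted = np.log(rng.uniform(size=x.shape)) < log_accept
      return np.where(accepted, proposal, x)  # select preserves x.dtype

    rng = np.random.default_rng(0)
    for dt in (np.float16, np.float32, np.float64):
      out = mh_step(np.zeros(5, dtype=dt), lambda z: -z * z, 0.01, rng)
      assert out.dtype == dt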