BUGFIX: Ensure that rejected states don't propagate NaN. Make float64 work.

PiperOrigin-RevId: 178307112
Ian Langmore 2017-12-07 15:39:40 -08:00 committed by TensorFlower Gardener
parent 6bb91a0712
commit 7b0458a789
2 changed files with 91 additions and 18 deletions
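The core of the NaN fix is replacing the mask-multiply update new_x = x * (1. - accepted) + new_x * accepted with a select via array_ops.where. With the multiply form, a rejected proposal whose value is NaN still contaminates the result, because 0. * nan evaluates to nan; a where-based select never touches the unselected branch in the forward pass. A minimal NumPy sketch of that behavior (illustration only, not part of the commit; the arrays below are hypothetical stand-ins for the kernel's x, new_x, and accepted):

import numpy as np

x = np.array([1.0, 2.0])          # current state
new_x = np.array([np.nan, 3.0])   # proposal; first component diverged to NaN
accepted = np.array([0.0, 1.0])   # first proposal rejected, second accepted

# Old-style blend: the rejected NaN proposal still poisons the result,
# because 0. * nan == nan.
blend = x * (1. - accepted) + new_x * accepted
print(blend)                      # [nan  3.]

# where-style select: a rejected component keeps the old state and stays finite.
select = np.where(accepted.astype(bool), new_x, x)
print(select)                     # [1. 3.]

The kernel() hunk below makes the same change for new_x, new_grad, and new_log_prob, and switches accepted to a boolean mask so it can be fed to where directly.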


@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for Hamiltonian Monte Carlo.
-"""
+"""Tests for Hamiltonian Monte Carlo."""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -27,6 +26,7 @@ from tensorflow.contrib.bayesflow.python.ops import hmc
 from tensorflow.python.framework import random_seed
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.platform import test
@@ -46,6 +46,9 @@ class HMCTest(test.TestCase):
     random_seed.set_random_seed(10003)
     np.random.seed(10003)
 
+  def assertAllFinite(self, x):
+    self.assertAllEqual(np.ones_like(x).astype(bool), np.isfinite(x))
+
   def _log_gamma_log_prob(self, x, event_dims=()):
     """Computes log-pdf of a log-gamma random variable.
@@ -375,5 +378,67 @@
       self.assertAllEqual(initial_x_val, updated_x_val)
       self.assertEqual(acceptance_probs_val, 0.)
 
+  def testNanFromGradsDontPropagate(self):
+    """Test that update with NaN gradients does not cause NaN in results."""
+    def _nan_log_prob_with_nan_gradient(x):
+      return np.nan * math_ops.reduce_sum(x)
+
+    with self.test_session() as sess:
+      initial_x = math_ops.linspace(0.01, 5, 10)
+      updated_x, acceptance_probs, new_log_prob, new_grad = hmc.kernel(
+          2., 5, initial_x, _nan_log_prob_with_nan_gradient, [0])
+      initial_x_val, updated_x_val, acceptance_probs_val = sess.run(
+          [initial_x, updated_x, acceptance_probs])
+
+      logging.vlog(1, 'initial_x = {}'.format(initial_x_val))
+      logging.vlog(1, 'updated_x = {}'.format(updated_x_val))
+      logging.vlog(1, 'acceptance_probs = {}'.format(acceptance_probs_val))
+
+      self.assertAllEqual(initial_x_val, updated_x_val)
+      self.assertEqual(acceptance_probs_val, 0.)
+
+      self.assertAllFinite(
+          gradients_impl.gradients(updated_x, initial_x)[0].eval())
+      self.assertTrue(
+          gradients_impl.gradients(new_grad, initial_x)[0] is None)
+
+      # Gradients of the acceptance probs and new log prob are not finite.
+      _ = new_log_prob  # Prevent unused arg error.
+      # self.assertAllFinite(
+      #     gradients_impl.gradients(acceptance_probs, initial_x)[0].eval())
+      # self.assertAllFinite(
+      #     gradients_impl.gradients(new_log_prob, initial_x)[0].eval())
+
+  def testChainWorksIn64Bit(self):
+    def log_prob(x):
+      return - math_ops.reduce_sum(x * x, axis=-1)
+
+    states, acceptance_probs = hmc.chain(
+        n_iterations=10,
+        step_size=np.float64(0.01),
+        n_leapfrog_steps=10,
+        initial_x=np.zeros(5).astype(np.float64),
+        target_log_prob_fn=log_prob,
+        event_dims=[-1])
+    with self.test_session() as sess:
+      states_, acceptance_probs_ = sess.run([states, acceptance_probs])
+    self.assertEqual(np.float64, states_.dtype)
+    self.assertEqual(np.float64, acceptance_probs_.dtype)
+
+  def testChainWorksIn16Bit(self):
+    def log_prob(x):
+      return - math_ops.reduce_sum(x * x, axis=-1)
+
+    states, acceptance_probs = hmc.chain(
+        n_iterations=10,
+        step_size=np.float16(0.01),
+        n_leapfrog_steps=10,
+        initial_x=np.zeros(5).astype(np.float16),
+        target_log_prob_fn=log_prob,
+        event_dims=[-1])
+    with self.test_session() as sess:
+      states_, acceptance_probs_ = sess.run([states, acceptance_probs])
+    self.assertEqual(np.float16, states_.dtype)
+    self.assertEqual(np.float16, acceptance_probs_.dtype)
+
 if __name__ == '__main__':
   test.main()


@@ -27,6 +27,7 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
@@ -174,9 +175,11 @@ def chain(n_iterations, step_size, n_leapfrog_steps, initial_x,
 
     potential_and_grad = _make_potential_and_grad(target_log_prob_fn)
     potential, grad = potential_and_grad(initial_x)
-    return functional_ops.scan(body, array_ops.zeros(n_iterations),
-                               (initial_x, array_ops.zeros(non_event_shape),
-                                -potential, -grad))[:2]
+    return functional_ops.scan(
+        body, array_ops.zeros(n_iterations, dtype=initial_x.dtype),
+        (initial_x,
+         array_ops.zeros(non_event_shape, dtype=initial_x.dtype),
+         -potential, -grad))[:2]
 
 
 def ais_chain(n_iterations, step_size, n_leapfrog_steps, initial_x,
@@ -298,8 +301,9 @@ def ais_chain(n_iterations, step_size, n_leapfrog_steps, initial_x,
       return updated_x, acceptance_probs, w
 
     x, acceptance_probs, w = functional_ops.scan(
-        _body, beta_series, (initial_x, array_ops.zeros(non_event_shape),
-                             array_ops.zeros(non_event_shape)))
+        _body, beta_series,
+        (initial_x, array_ops.zeros(non_event_shape, dtype=initial_x.dtype),
+         array_ops.zeros(non_event_shape, dtype=initial_x.dtype)))
     return w[-1], x[-1], acceptance_probs[-1]
@@ -446,9 +450,10 @@ def kernel(step_size, n_leapfrog_steps, x, target_log_prob_fn, event_dims=(),
   """
   with ops.name_scope(name, 'hmc_kernel', [step_size, n_leapfrog_steps, x]):
     potential_and_grad = _make_potential_and_grad(target_log_prob_fn)
+    x = ops.convert_to_tensor(x, name='x')
     x_shape = array_ops.shape(x)
-    m = random_ops.random_normal(x_shape)
+    m = random_ops.random_normal(x_shape, dtype=x.dtype)
 
     kinetic_0 = 0.5 * math_ops.reduce_sum(math_ops.square(m), event_dims)
@@ -475,23 +480,26 @@
         array_ops.fill(array_ops.shape(energy_change),
                        energy_change.dtype.as_numpy_dtype(np.inf)),
         energy_change)
-    acceptance_probs = math_ops.exp(math_ops.minimum(0., -energy_change))
-    accepted = math_ops.cast(
-        random_ops.random_uniform(array_ops.shape(acceptance_probs)) <
-        acceptance_probs, log_potential_0.dtype)
-    new_log_prob = (-log_potential_0 * (1. - accepted) -
-                    log_potential_1 * accepted)
+    acceptance_probs = math_ops.exp(math_ops.minimum(-energy_change, 0.))
+    accepted = (
+        random_ops.random_uniform(
+            array_ops.shape(acceptance_probs), dtype=x.dtype)
+        < acceptance_probs)
+    new_log_prob = -array_ops.where(accepted, log_potential_1, log_potential_0)
 
     # TODO(b/65738010): This should work, but it doesn't for now.
     # reduced_shape = math_ops.reduced_shape(x_shape, event_dims)
     reduced_shape = array_ops.shape(math_ops.reduce_sum(x, event_dims,
                                                         keep_dims=True))
     accepted = array_ops.reshape(accepted, reduced_shape)
-    accepted = math_ops.cast(accepted, x.dtype)
-    new_x = x * (1. - accepted) + new_x * accepted
-    accepted = math_ops.cast(accepted, accepted.dtype)
-    new_grad = -grad_0 * (1. - accepted) - grad_1 * accepted
+    accepted = math_ops.logical_or(
+        accepted, math_ops.cast(array_ops.zeros_like(x), dtypes.bool))
+
+    new_x = array_ops.where(accepted, new_x, x)
+    new_grad = -array_ops.where(accepted, grad_1, grad_0)
 
+  # TODO(langmore) Gradients of acceptance_probs and new_log_prob with respect
+  # to initial_x will propagate NaNs (see testNanFromGradsDontPropagate). This
+  # should be fixed.
   return new_x, acceptance_probs, new_log_prob, new_grad
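
The TODO(langmore) above concerns the backward pass: even though array_ops.where keeps NaN out of the computed values, the gradient routed to the unselected branch is multiplied by zero, and 0 * nan is still nan, so gradients of acceptance_probs and new_log_prob with respect to the initial state can still come out NaN (which is why the matching assertions in testNanFromGradsDontPropagate stay commented out). A minimal sketch of that mechanism, written against the public TF 1.x graph API rather than the internal modules imported above, and not part of the commit:

import numpy as np
import tensorflow as tf  # assumes the TF 1.x graph API of this era

x = tf.constant([0.5, 2.0])
bad = np.nan * x                          # stands in for a NaN energy_change
safe = tf.where(tf.is_nan(bad),           # forward pass: NaNs replaced by +inf
                tf.fill(tf.shape(bad), np.inf),
                bad)

# Backward pass: the gradient routed to the unselected 'bad' branch is zero,
# but the chain rule multiplies that zero by d(bad)/dx = nan, and 0 * nan = nan.
grad = tf.gradients(safe, x)[0]

with tf.Session() as sess:
  print(sess.run(safe))                   # [inf inf]
  print(sess.run(grad))                   # [nan nan]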