BUGFIX: Ensure that rejected states don't propagate NaN. Make float64 work.
PiperOrigin-RevId: 178307112
parent 6bb91a0712
commit 7b0458a789
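Background note (explanatory sketch, not part of the original commit): the old kernel selected between the current state and the proposal by multiplicative blending, x * (1 - accepted) + new_x * accepted. Under IEEE-754 arithmetic NaN * 0 is still NaN, so a rejected proposal whose log-prob or gradient came out NaN contaminated the kept state. A minimal NumPy illustration (names are illustrative only):

import numpy as np

x = np.linspace(0.01, 5.0, 10)        # current (finite) state
proposal = np.full_like(x, np.nan)    # proposal whose log-prob produced NaN
accepted = 0.0                        # the NaN proposal is rejected

# Multiplicative blending still propagates NaN, because NaN * 0.0 == NaN:
blended = x * (1.0 - accepted) + proposal * accepted
print(np.isfinite(blended).any())     # False -- the NaN leaked through

# Selecting with where keeps the rejected state intact:
selected = np.where(np.zeros_like(x, dtype=bool), proposal, x)
print(np.isfinite(selected).all())    # True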
@@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Hamiltonian Monte Carlo.
"""
"""Tests for Hamiltonian Monte Carlo."""

from __future__ import absolute_import
from __future__ import division
@@ -27,6 +26,7 @@ from tensorflow.contrib.bayesflow.python.ops import hmc

from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
@@ -46,6 +46,9 @@ class HMCTest(test.TestCase):
    random_seed.set_random_seed(10003)
    np.random.seed(10003)

  def assertAllFinite(self, x):
    self.assertAllEqual(np.ones_like(x).astype(bool), np.isfinite(x))

  def _log_gamma_log_prob(self, x, event_dims=()):
    """Computes log-pdf of a log-gamma random variable.

@@ -375,5 +378,67 @@ class HMCTest(test.TestCase):
      self.assertAllEqual(initial_x_val, updated_x_val)
      self.assertEqual(acceptance_probs_val, 0.)

  def testNanFromGradsDontPropagate(self):
    """Test that update with NaN gradients does not cause NaN in results."""
    def _nan_log_prob_with_nan_gradient(x):
      return np.nan * math_ops.reduce_sum(x)

    with self.test_session() as sess:
      initial_x = math_ops.linspace(0.01, 5, 10)
      updated_x, acceptance_probs, new_log_prob, new_grad = hmc.kernel(
          2., 5, initial_x, _nan_log_prob_with_nan_gradient, [0])
      initial_x_val, updated_x_val, acceptance_probs_val = sess.run(
          [initial_x, updated_x, acceptance_probs])

      logging.vlog(1, 'initial_x = {}'.format(initial_x_val))
      logging.vlog(1, 'updated_x = {}'.format(updated_x_val))
      logging.vlog(1, 'acceptance_probs = {}'.format(acceptance_probs_val))

      self.assertAllEqual(initial_x_val, updated_x_val)
      self.assertEqual(acceptance_probs_val, 0.)

      self.assertAllFinite(
          gradients_impl.gradients(updated_x, initial_x)[0].eval())
      self.assertTrue(
          gradients_impl.gradients(new_grad, initial_x)[0] is None)

      # Gradients of the acceptance probs and new log prob are not finite.
      _ = new_log_prob  # Prevent unused arg error.
      # self.assertAllFinite(
      #     gradients_impl.gradients(acceptance_probs, initial_x)[0].eval())
      # self.assertAllFinite(
      #     gradients_impl.gradients(new_log_prob, initial_x)[0].eval())

  def testChainWorksIn64Bit(self):
    def log_prob(x):
      return - math_ops.reduce_sum(x * x, axis=-1)
    states, acceptance_probs = hmc.chain(
        n_iterations=10,
        step_size=np.float64(0.01),
        n_leapfrog_steps=10,
        initial_x=np.zeros(5).astype(np.float64),
        target_log_prob_fn=log_prob,
        event_dims=[-1])
    with self.test_session() as sess:
      states_, acceptance_probs_ = sess.run([states, acceptance_probs])
    self.assertEqual(np.float64, states_.dtype)
    self.assertEqual(np.float64, acceptance_probs_.dtype)

  def testChainWorksIn16Bit(self):
    def log_prob(x):
      return - math_ops.reduce_sum(x * x, axis=-1)
    states, acceptance_probs = hmc.chain(
        n_iterations=10,
        step_size=np.float16(0.01),
        n_leapfrog_steps=10,
        initial_x=np.zeros(5).astype(np.float16),
        target_log_prob_fn=log_prob,
        event_dims=[-1])
    with self.test_session() as sess:
      states_, acceptance_probs_ = sess.run([states, acceptance_probs])
    self.assertEqual(np.float16, states_.dtype)
    self.assertEqual(np.float16, acceptance_probs_.dtype)


if __name__ == '__main__':
  test.main()
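Note on the new dtype tests (illustrative sketch, not part of the diff): testChainWorksIn64Bit and testChainWorksIn16Bit fail against the old kernel because several ops it builds internally default to float32. A minimal sketch of the failure mode, using the public TF 1.x API rather than the internal modules imported above:

import numpy as np
import tensorflow as tf  # TF 1.x graph-mode API assumed for this sketch

x = tf.constant(np.zeros(5, dtype=np.float64))

# random_normal defaults to float32; mixing it with a float64 state raises a
# type error in TF 1.x, which is why the kernel now threads dtype=x.dtype
# through (see the random_ops.random_normal change below).
# bad_momentum = tf.random_normal(tf.shape(x))                 # float32
good_momentum = tf.random_normal(tf.shape(x), dtype=x.dtype)   # float64
kinetic = 0.5 * tf.reduce_sum(tf.square(good_momentum))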
@@ -27,6 +27,7 @@ from __future__ import print_function

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -174,9 +175,11 @@ def chain(n_iterations, step_size, n_leapfrog_steps, initial_x,

    potential_and_grad = _make_potential_and_grad(target_log_prob_fn)
    potential, grad = potential_and_grad(initial_x)
    return functional_ops.scan(body, array_ops.zeros(n_iterations),
                               (initial_x, array_ops.zeros(non_event_shape),
                                -potential, -grad))[:2]
    return functional_ops.scan(
        body, array_ops.zeros(n_iterations, dtype=initial_x.dtype),
        (initial_x,
         array_ops.zeros(non_event_shape, dtype=initial_x.dtype),
         -potential, -grad))[:2]


def ais_chain(n_iterations, step_size, n_leapfrog_steps, initial_x,
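The dtype arguments added to the functional_ops.scan calls (here and in ais_chain below) matter because scan carries its loop state through a while_loop, which keeps each loop variable's dtype fixed: a float32-default initializer cannot be combined with a float64 state inside the body. A rough sketch of the mismatch, shown with the public tf.scan for brevity (details here are assumptions, not part of the diff):

import numpy as np
import tensorflow as tf  # TF 1.x graph-mode API assumed for this sketch

initial_x = tf.constant(np.zeros(5, dtype=np.float64))

# Old behaviour: float32-default zeros clash with the float64 state inside
# the body, since TF 1.x does not promote dtypes automatically.
# bad = tf.scan(lambda acc, _: acc + initial_x,
#               tf.zeros(10), initializer=tf.zeros(5))

# Fixed behaviour, mirroring the diff: thread initial_x.dtype through.
good = tf.scan(lambda acc, _: acc + initial_x,
               tf.zeros(10, dtype=initial_x.dtype),
               initializer=tf.zeros(5, dtype=initial_x.dtype))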
@@ -298,8 +301,9 @@ def ais_chain(n_iterations, step_size, n_leapfrog_steps, initial_x,
      return updated_x, acceptance_probs, w

    x, acceptance_probs, w = functional_ops.scan(
        _body, beta_series, (initial_x, array_ops.zeros(non_event_shape),
                             array_ops.zeros(non_event_shape)))
        _body, beta_series,
        (initial_x, array_ops.zeros(non_event_shape, dtype=initial_x.dtype),
         array_ops.zeros(non_event_shape, dtype=initial_x.dtype)))
    return w[-1], x[-1], acceptance_probs[-1]


@@ -446,9 +450,10 @@ def kernel(step_size, n_leapfrog_steps, x, target_log_prob_fn, event_dims=(),
  """
  with ops.name_scope(name, 'hmc_kernel', [step_size, n_leapfrog_steps, x]):
    potential_and_grad = _make_potential_and_grad(target_log_prob_fn)
    x = ops.convert_to_tensor(x, name='x')

    x_shape = array_ops.shape(x)
    m = random_ops.random_normal(x_shape)
    m = random_ops.random_normal(x_shape, dtype=x.dtype)

    kinetic_0 = 0.5 * math_ops.reduce_sum(math_ops.square(m), event_dims)

@@ -475,23 +480,26 @@ def kernel(step_size, n_leapfrog_steps, x, target_log_prob_fn, event_dims=(),
        array_ops.fill(array_ops.shape(energy_change),
                       energy_change.dtype.as_numpy_dtype(np.inf)),
        energy_change)
    acceptance_probs = math_ops.exp(math_ops.minimum(0., -energy_change))
    accepted = math_ops.cast(
        random_ops.random_uniform(array_ops.shape(acceptance_probs)) <
        acceptance_probs, log_potential_0.dtype)
    new_log_prob = (-log_potential_0 * (1. - accepted) -
                    log_potential_1 * accepted)
    acceptance_probs = math_ops.exp(math_ops.minimum(-energy_change, 0.))
    accepted = (
        random_ops.random_uniform(
            array_ops.shape(acceptance_probs), dtype=x.dtype)
        < acceptance_probs)
    new_log_prob = -array_ops.where(accepted, log_potential_1, log_potential_0)

    # TODO(b/65738010): This should work, but it doesn't for now.
    # reduced_shape = math_ops.reduced_shape(x_shape, event_dims)
    reduced_shape = array_ops.shape(math_ops.reduce_sum(x, event_dims,
                                                        keep_dims=True))
    accepted = array_ops.reshape(accepted, reduced_shape)
    accepted = math_ops.cast(accepted, x.dtype)
    new_x = x * (1. - accepted) + new_x * accepted
    accepted = math_ops.cast(accepted, accepted.dtype)
    new_grad = -grad_0 * (1. - accepted) - grad_1 * accepted
    accepted = math_ops.logical_or(
        accepted, math_ops.cast(array_ops.zeros_like(x), dtypes.bool))
    new_x = array_ops.where(accepted, new_x, x)
    new_grad = -array_ops.where(accepted, grad_1, grad_0)

    # TODO(langmore) Gradients of acceptance_probs and new_log_prob with respect
    # to initial_x will propagate NaNs (see testNanFromGradsDontPropagate). This
    # should be fixed.
    return new_x, acceptance_probs, new_log_prob, new_grad
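A note on the reshape / logical_or sequence above (explanatory sketch, not part of the diff): array_ops.where does not broadcast its condition the way elementwise ops do, while the accept decision only has the batch shape. Reducing x over event_dims with keep_dims=True yields that batch shape with size-1 event dimensions, and logical_or against an all-False mask shaped like x then broadcasts the per-chain decision across the event dimensions. In NumPy terms:

import numpy as np

x = np.zeros((3, 4))                      # 3 chains, event dimension of size 4
proposal = np.ones((3, 4))
accepted = np.array([True, False, True])  # one accept decision per chain

# Batch shape with event dims kept as size 1, mirroring keep_dims=True:
reduced_shape = np.sum(x, axis=-1, keepdims=True).shape        # (3, 1)
accepted = accepted.reshape(reduced_shape)

# Broadcasting against an all-False mask tiles the decision over event dims:
accepted = np.logical_or(accepted, np.zeros_like(x, dtype=bool))  # (3, 4)

new_x = np.where(accepted, proposal, x)   # rows 0 and 2 take the proposal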