Forwardprop: Add some tests that compute Hessians, including using pfor

Still not completely sure how we'd want to wrap this up. An API requiring variables/parameters up-front would certainly be the easiest. Just wanted a proof-of-concept before adding a symbol for forwardprop.

PiperOrigin-RevId: 272060371
parent 11c3c50cf9
commit 0eaefe9fdc
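For context, the pattern exercised in this change is "forward-over-back": run a backward pass under a GradientTape while a ForwardAccumulator watches the same primals, then read off the JVP of the gradient to get a Hessian-vector product. Below is a minimal sketch of that pattern using the public `tf.autodiff.ForwardAccumulator` symbol from later TensorFlow 2.x releases; the test itself uses the internal `forwardprop` module, since no public symbol existed at the time of this commit.

import tensorflow as tf

def hvp(f, primals, tangents):
  # Forward-over-back Hessian-vector product of a scalar-valued f.
  with tf.autodiff.ForwardAccumulator(primals, tangents) as acc:
    with tf.GradientTape() as tape:
      tape.watch(primals)
      loss = f(*primals)
    # The backward pass runs while the accumulator is active, so forward-mode
    # JVPs are pushed through the gradient computation.
    grads = tape.gradient(loss, primals)
  return acc.jvp(grads)

x = tf.constant([1., 2., 3.])
v = tf.constant([3., 4., 5.])
hvp_value, = hvp(lambda t: tf.reduce_sum(t ** 3), [x], [v])
# f(x) = sum(x**3) has Hessian diag(6 * x), so the HVP is 6 * x * v:
print(hvp_value)  # approximately [18., 48., 90.]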
@@ -42,11 +42,13 @@ from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.parallel_for import control_flow_ops
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.platform import test
from tensorflow.python.util import nest
@@ -107,7 +109,52 @@ def _grad(f, argnums=0):

def _hvp(f, primals, tangents):
  """Compute a forward-over-back Hessian-vector product."""
-  return _jvp(_grad(f), primals, tangents)[1]
+  with forwardprop.ForwardAccumulator(primals, tangents) as acc:
+    with backprop.GradientTape() as tape:
+      tape.watch(primals)
+      f_out = f(*primals)
+      f_out.shape.assert_is_compatible_with([])
+    return acc.jvp(tape.gradient(f_out, primals))


def _vectorize_parameters(f, params, use_pfor, dtype):
  """Loop over `params`, providing a one-hot mask to `f` for each."""
  parameter_sizes = [array_ops.size(param) for param in params]
  total_size = math_ops.add_n(parameter_sizes)

  def _wrapper(index):
    full_onehot = array_ops.one_hot(index, total_size)
    split_onehot = array_ops.split(full_onehot, parameter_sizes)
    tangents = [array_ops.reshape(v, array_ops.shape(param))
                for param, v in zip(params, split_onehot)]
    return f(tangents)

  if use_pfor:
    return control_flow_ops.vectorized_map(_wrapper, math_ops.range(total_size))
  else:
    return map_fn.map_fn(_wrapper, math_ops.range(total_size), dtype)


def _forward_over_back_hessian(f, params, use_pfor, dtype=None):
  """Computes the full Hessian matrix for the scalar-valued f(*params).

  Args:
    f: A function taking `params` and returning a scalar.
    params: A possibly nested structure of tensors.
    use_pfor: If true, uses `tf.vectorized_map` calls instead of looping.
    dtype: Required if `use_pfor=False`. A possibly nested structure of dtypes
      (e.g. `tf.float32`) matching the structure of `f`'s returns.

  Returns:
    A possibly nested structure of matrix slices corresponding to `params`. Each
    slice has shape [P, p_s] where `p_s` is the number of parameters (`tf.size`)
    in the corresponding element of `params` and `P` is the total number of
    parameters (`sum_s(p_s)`). The full matrix can be obtained by concatenating
    along the second axis.
  """
  return _vectorize_parameters(
      functools.partial(_hvp, f, params),
      params, use_pfor=use_pfor, dtype=dtype)


def _test_gradients(testcase,
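The `_vectorize_parameters` / `_forward_over_back_hessian` helpers added above build the full Hessian one row at a time: a one-hot tangent is swept over the flattened parameters, and each one-hot produces one Hessian row via a forward-over-back HVP, either under `vectorized_map` (the pfor path) or `map_fn`. The following is a rough, self-contained sketch of the same idea in terms of the public API; the names `full_hessian` and `hessian_row` are illustrative only, and as in the tests below the pfor path would normally be run inside a `tf.function`.

import tensorflow as tf

def full_hessian(f, params):
  # One forward-over-back HVP per one-hot tangent. Stacking the rows gives one
  # slice of shape [total, *param.shape] per parameter; flattening each slice
  # and concatenating along axis 1 yields the [total, total] Hessian.
  sizes = [tf.size(p) for p in params]
  total = tf.add_n(sizes)

  def hessian_row(index):
    onehot = tf.one_hot(index, total)
    tangents = [tf.reshape(v, tf.shape(p))
                for p, v in zip(params, tf.split(onehot, sizes))]
    with tf.autodiff.ForwardAccumulator(params, tangents) as acc:
      with tf.GradientTape() as tape:
        tape.watch(params)
        out = f(*params)
      grads = tape.gradient(out, params)
    return acc.jvp(grads)

  return tf.vectorized_map(hessian_row, tf.range(total))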
@@ -374,7 +421,7 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
      loss = _loss()
    vector = tape.gradient(loss, model.trainable_variables)
    variable_input_fn = lambda unused_variables: _loss()
-    forward_over_back_hvp = _hvp(
+    forward_over_back_hvp, = _hvp(
        variable_input_fn, [model.trainable_variables], [vector])
    with backprop.GradientTape(persistent=True) as tape:
      tape.watch(model.trainable_variables)
@@ -557,9 +604,9 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):

    primals = constant_op.constant([1., 2., 3.])
    tangents = constant_op.constant([3., 4., 5.])
-    forwardback_hvp_eager = _hvp(fun, (primals,), (tangents,))
-    forwardback_hvp_function = def_function.function(_hvp)(fun, (primals,),
-                                                            (tangents,))
+    forwardback_hvp_eager, = _hvp(fun, (primals,), (tangents,))
+    forwardback_hvp_function, = def_function.function(_hvp)(fun, (primals,),
+                                                             (tangents,))

    with backprop.GradientTape(persistent=True) as g:
      g.watch(primals)
@@ -812,6 +859,57 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
    self.assertIsNone(acc.jvp(result))


class HessianTests(test.TestCase, parameterized.TestCase):

  def testHessian1D(self):
    # Note: stolen from ops/gradients_test.py
    m = 4
    rng = np.random.RandomState([1, 2, 3])
    mat_value = rng.randn(m, m).astype("float32")
    x_value = rng.randn(m).astype("float32")
    hess_value = mat_value + mat_value.T
    mat = variables.Variable(mat_value)

    def _f(x):
      return math_ops.reduce_sum(x[:, None] * mat * x[None, :])

    hessian_eager, = _forward_over_back_hessian(
        _f, [constant_op.constant(x_value)],
        use_pfor=False, dtype=[dtypes.float32])
    self.assertAllClose(hess_value, hessian_eager)
    hessian_function, = def_function.function(_forward_over_back_hessian)(
        _f, [constant_op.constant(x_value)],
        use_pfor=False, dtype=[dtypes.float32])
    self.assertAllClose(hess_value, hessian_function)
    hessian_pfor, = def_function.function(_forward_over_back_hessian)(
        _f, [constant_op.constant(x_value)],
        use_pfor=True, dtype=[dtypes.float32])
    self.assertAllClose(hess_value, hessian_pfor)

  @parameterized.named_parameters(
      [("PFor", True),
       ("MapFn", False)])
  def testHessianOfVariables(self, use_pfor):
    model = core.Dense(1)
    model.build([2])

    def _loss(*unused_args):
      input_value = constant_op.constant([[-0.5, 1.], [0.5, -1.]])
      target = constant_op.constant([[-1.], [2.]])
      return math_ops.reduce_sum((model(input_value) - target) ** 2.)

    kernel_hess, bias_hess = _forward_over_back_hessian(
        _loss, [model.kernel, model.bias], use_pfor=use_pfor,
        dtype=[dtypes.float32, dtypes.float32])
    # 3 total parameters, the whole hessian is the 3x3 concatenation
    self.assertEqual([3, 2, 1], kernel_hess.shape)
    self.assertEqual([3, 1], bias_hess.shape)
    full_hessian = array_ops.concat(
        [array_ops.reshape(kernel_hess, [3, 2]), bias_hess], axis=1)
    # The full Hessian should be symmetric.
    self.assertAllClose(full_hessian, array_ops.transpose(full_hessian))


if __name__ == "__main__":
  # TODO(allenl): Also test with 1.x-style graph mode.
  ops.enable_eager_execution()