diff --git a/tensorflow/python/eager/forwardprop_test.py b/tensorflow/python/eager/forwardprop_test.py
index d6905f0a850..b91592413a4 100644
--- a/tensorflow/python/eager/forwardprop_test.py
+++ b/tensorflow/python/eager/forwardprop_test.py
@@ -42,11 +42,13 @@ from tensorflow.python.module import module
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import gradient_checker_v2
+from tensorflow.python.ops import map_fn
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_impl
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import variables
+from tensorflow.python.ops.parallel_for import control_flow_ops
 from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
 from tensorflow.python.platform import test
 from tensorflow.python.util import nest
@@ -107,7 +109,52 @@ def _grad(f, argnums=0):
 
 def _hvp(f, primals, tangents):
   """Compute a forward-over-back Hessian-vector product."""
-  return _jvp(_grad(f), primals, tangents)[1]
+  with forwardprop.ForwardAccumulator(primals, tangents) as acc:
+    with backprop.GradientTape() as tape:
+      tape.watch(primals)
+      f_out = f(*primals)
+      f_out.shape.assert_is_compatible_with([])
+    return acc.jvp(tape.gradient(f_out, primals))
+
+
+def _vectorize_parameters(f, params, use_pfor, dtype):
+  """Loop over `params`, providing a one-hot mask to `f` for each."""
+  parameter_sizes = [array_ops.size(param) for param in params]
+  total_size = math_ops.add_n(parameter_sizes)
+
+  def _wrapper(index):
+    full_onehot = array_ops.one_hot(index, total_size)
+    split_onehot = array_ops.split(full_onehot, parameter_sizes)
+    tangents = [array_ops.reshape(v, array_ops.shape(param))
+                for param, v in zip(params, split_onehot)]
+    return f(tangents)
+
+  if use_pfor:
+    return control_flow_ops.vectorized_map(_wrapper, math_ops.range(total_size))
+  else:
+    return map_fn.map_fn(_wrapper, math_ops.range(total_size), dtype)
+
+
+def _forward_over_back_hessian(f, params, use_pfor, dtype=None):
+  """Computes the full Hessian matrix for the scalar-valued f(*params).
+
+  Args:
+    f: A function taking `params` and returning a scalar.
+    params: A possibly nested structure of tensors.
+    use_pfor: If true, uses `tf.vectorized_map` calls instead of looping.
+    dtype: Required if `use_pfor=False`. A possibly nested structure of dtypes
+      (e.g. `tf.float32`) matching the structure of `f`'s returns.
+
+  Returns:
+    A possibly nested structure of matrix slices corresponding to `params`.
+    Each slice has shape [P, p_s] where `p_s` is the number of parameters
+    (`tf.size`) in the corresponding element of `params` and `P` is the total
+    number of parameters (`sum_s(p_s)`). The full matrix can be obtained by
+    concatenating along the second axis.
+ """ + return _vectorize_parameters( + functools.partial(_hvp, f, params), + params, use_pfor=use_pfor, dtype=dtype) def _test_gradients(testcase, @@ -374,7 +421,7 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase): loss = _loss() vector = tape.gradient(loss, model.trainable_variables) variable_input_fn = lambda unused_variables: _loss() - forward_over_back_hvp = _hvp( + forward_over_back_hvp, = _hvp( variable_input_fn, [model.trainable_variables], [vector]) with backprop.GradientTape(persistent=True) as tape: tape.watch(model.trainable_variables) @@ -557,9 +604,9 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase): primals = constant_op.constant([1., 2., 3.]) tangents = constant_op.constant([3., 4., 5.]) - forwardback_hvp_eager = _hvp(fun, (primals,), (tangents,)) - forwardback_hvp_function = def_function.function(_hvp)(fun, (primals,), - (tangents,)) + forwardback_hvp_eager, = _hvp(fun, (primals,), (tangents,)) + forwardback_hvp_function, = def_function.function(_hvp)(fun, (primals,), + (tangents,)) with backprop.GradientTape(persistent=True) as g: g.watch(primals) @@ -812,6 +859,57 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase): self.assertIsNone(acc.jvp(result)) +class HessianTests(test.TestCase, parameterized.TestCase): + + def testHessian1D(self): + # Note: stolen from ops/gradients_test.py + m = 4 + rng = np.random.RandomState([1, 2, 3]) + mat_value = rng.randn(m, m).astype("float32") + x_value = rng.randn(m).astype("float32") + hess_value = mat_value + mat_value.T + mat = variables.Variable(mat_value) + + def _f(x): + return math_ops.reduce_sum(x[:, None] * mat * x[None, :]) + + hessian_eager, = _forward_over_back_hessian( + _f, [constant_op.constant(x_value)], + use_pfor=False, dtype=[dtypes.float32]) + self.assertAllClose(hess_value, hessian_eager) + hessian_function, = def_function.function(_forward_over_back_hessian)( + _f, [constant_op.constant(x_value)], + use_pfor=False, dtype=[dtypes.float32]) + self.assertAllClose(hess_value, hessian_function) + hessian_pfor, = def_function.function(_forward_over_back_hessian)( + _f, [constant_op.constant(x_value)], + use_pfor=True, dtype=[dtypes.float32]) + self.assertAllClose(hess_value, hessian_pfor) + + @parameterized.named_parameters( + [("PFor", True), + ("MapFn", False)]) + def testHessianOfVariables(self, use_pfor): + model = core.Dense(1) + model.build([2]) + + def _loss(*unused_args): + input_value = constant_op.constant([[-0.5, 1.], [0.5, -1.]]) + target = constant_op.constant([[-1.], [2.]]) + return math_ops.reduce_sum((model(input_value) - target) ** 2.) + + kernel_hess, bias_hess = _forward_over_back_hessian( + _loss, [model.kernel, model.bias], use_pfor=use_pfor, + dtype=[dtypes.float32, dtypes.float32]) + # 3 total parameters, the whole hessian is the 3x3 concatenation + self.assertEqual([3, 2, 1], kernel_hess.shape) + self.assertEqual([3, 1], bias_hess.shape) + full_hessian = array_ops.concat( + [array_ops.reshape(kernel_hess, [3, 2]), bias_hess], axis=1) + # The full Hessian should be symmetric. + self.assertAllClose(full_hessian, array_ops.transpose(full_hessian)) + + if __name__ == "__main__": # TODO(allenl): Also test with 1.x-style graph mode. ops.enable_eager_execution()