Forwardprop: allow watching variables

The gradient function transpose logic just needed to support non-float trainable types.

PiperOrigin-RevId: 265969263
Author: Allen Lavoie, 2019-08-28 12:37:11 -07:00 (committed by TensorFlower Gardener)
parent 16e4f42086
commit 05cf54667e
3 changed files with 57 additions and 19 deletions
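
A minimal sketch of what "non-float trainable types" means here, assuming the usual `tf` alias, eager execution, and that gradients_util.IsTrainable accepts the resource dtype carried by variable handles, which is what this change relies on:

# Illustration only: a variable is watched through its handle, which is not a float tensor.
import tensorflow as tf
from tensorflow.python.ops import gradients_util

v = tf.Variable([1., 2., 3.])
print(v.handle.dtype.is_floating)            # False: the old dtype.is_floating filter skipped handles
print(gradients_util.IsTrainable(v.handle))  # True: the new filter lets tangents flow through them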

View File

@@ -273,6 +273,7 @@ cuda_py_test(
         ":forwardprop",
         ":forwardprop_util",
         ":test",
+        "//tensorflow/python/distribute:mirrored_strategy",
     ],
     shard_count = 5,
     xla_enable_strict_auto_jit = True,

View File

@@ -18,8 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import functools
-
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import def_function
@@ -28,6 +26,7 @@ from tensorflow.python.eager import execute
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_util
 from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import nest
@@ -55,35 +54,35 @@ def _forward_gradient(op_name, attr_tuple, inputs, outputs, tangents):
   Returns:
     A flat list of tangents corresponding to `outputs`.
   """
-  float_inputs = []
-  float_indices = []
+  trainable_inputs = []
+  trainable_indices = []
   nontrivial_tangents = []
   for input_index, tensor in enumerate(inputs):
-    if tensor.dtype.is_floating:
-      float_inputs.append(tensor)
-      float_indices.append(input_index)
+    if gradients_util.IsTrainable(tensor):
+      trainable_inputs.append(tensor)
+      trainable_indices.append(input_index)
       nontrivial_tangents.append(tangents[input_index])
 
   with backprop.GradientTape() as transpose_tape:
     with backprop.GradientTape() as backfunc_tape:
-      backfunc_tape.watch(float_inputs)
+      backfunc_tape.watch(trainable_inputs)
       execute.record_gradient(op_name, inputs, attr_tuple, outputs,
                               "forward_op_replay")
 
     forwardprop_aids = []
-    float_outputs = []
+    trainable_outputs = []
     nontrivial_output_indices = []
     for output_index, output in enumerate(outputs):
-      if output.dtype.is_floating:
+      if gradients_util.IsTrainable(output):
         forwardprop_aids.append(
             array_ops.ones_like(output, name="unused_forwardprop_aid"))
-        float_outputs.append(output)
+        trainable_outputs.append(output)
         nontrivial_output_indices.append(output_index)
 
     transpose_tape.watch(forwardprop_aids)
     grads = backfunc_tape.gradient(
-        float_outputs,
-        float_inputs,
+        trainable_outputs,
+        trainable_inputs,
         forwardprop_aids,
         unconnected_gradients=UnconnectedGradients.ZERO)
   nontrivial_output_tangents = transpose_tape.gradient(
@@ -183,10 +182,9 @@ class ForwardGradientAccumulator(object):
         logging.log_first_n(
             logging.WARN, "The dtype of the watched tensor must be "
             "floating (e.g. tf.float32), got %r", 5, t.dtype)
-      if hasattr(t, "handle"):
-        # TODO(allenl): Handle watching variables.
-        raise NotImplementedError("Currently only Tensors may be watched.")
       g = ops.convert_to_tensor(g, dtype=t.dtype)
+      if hasattr(t, "handle"):
+        t = t.handle
       pywrap_tensorflow.TFE_Py_ForwardAccumulatorWatch(self._accumulator, t, g)
 
   def jvp(self, target):
@@ -206,6 +204,9 @@ class ForwardGradientAccumulator(object):
     """
     if self._accumulator is None:
       raise ValueError("Called jvp() without first tracing anything.")
-    return nest.map_structure(
-        functools.partial(pywrap_tensorflow.TFE_Py_ForwardAccumulatorJVP,
-                          self._accumulator), target)
+    def _fetch_jvp(tensor):
+      if hasattr(tensor, "handle"):
+        tensor = tensor.handle
+      return pywrap_tensorflow.TFE_Py_ForwardAccumulatorJVP(
+          self._accumulator, tensor)
+    return nest.map_structure(_fetch_jvp, target)
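
The _forward_gradient hunk above computes a JVP by transposing a VJP with two nested tapes; the change only swaps its dtype.is_floating filters for gradients_util.IsTrainable so that non-float trainable tensors (such as variable handles) participate. A standalone sketch of the same double-tape trick using only public GradientTape APIs (the function name and example values are illustrative, not taken from the commit):

import tensorflow as tf

def jvp_via_double_tape(f, primal, tangent):
  # Computes J(f)(primal) @ tangent with two reverse-mode tapes.
  with tf.GradientTape() as transpose_tape:
    with tf.GradientTape() as backfunc_tape:
      backfunc_tape.watch(primal)
      output = f(primal)
    # The "aid" plays the role of unused_forwardprop_aid above: a dummy output
    # gradient that the outer tape will differentiate with respect to.
    aid = tf.ones_like(output)
    transpose_tape.watch(aid)
    vjp = backfunc_tape.gradient(output, primal, output_gradients=aid)
  # Differentiating the aid -> VJP map, with the tangent fed in as its output
  # gradient, transposes the VJP back into a JVP.
  return transpose_tape.gradient(vjp, aid, output_gradients=tangent)

x = tf.constant([1., 2., 3.])
t = tf.constant([.1, -.2, .3])
print(jvp_via_double_tape(lambda p: p * p, x, t))  # 2 * x * t = [0.2, -0.8, 1.8]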

View File

@@ -24,6 +24,7 @@ from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.distribute import mirrored_strategy
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import forwardprop
@@ -37,6 +38,7 @@ from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import gradient_checker_v2
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variables
 from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
 from tensorflow.python.platform import test
 from tensorflow.python.util import nest
@@ -470,6 +472,7 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
     self.assertAllClose(backback_hvp, forwardback_hvp_eager)
     self.assertAllClose(backback_hvp, forwardback_hvp_function)
 
+  @test_util.assert_no_new_pyobjects_executing_eagerly
   def testShouldRecordAndStopRecord(self):
     with forwardprop.ForwardGradientAccumulator() as acc:
       c = constant_op.constant(1.)
@@ -495,6 +498,7 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
         self.assertIsNone(acc.jvp(d))
       self.assertIsNone(tape.gradient(d, c))
 
+  @test_util.assert_no_new_pyobjects_executing_eagerly
   def testRecordingSelectively(self):
     with forwardprop.ForwardGradientAccumulator() as acc:
       c = constant_op.constant(1.)
@@ -522,6 +526,7 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
       self.assertIsNone(tape.gradient(d, c))
       self.assertAllClose(3., tape.gradient(e, c))
 
+  @test_util.assert_no_new_pyobjects_executing_eagerly
   def testRecordingWithJVPIndices(self):
     with forwardprop.ForwardGradientAccumulator() as acc:
       c = constant_op.constant(1.)
@@ -538,6 +543,37 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
           None, (((0, 1),),))
       self.assertAllClose(3., acc.jvp(d))
 
+  @test_util.assert_no_new_pyobjects_executing_eagerly
+  def testVariableWatched(self):
+    v = variables.Variable([1., 2., 3.])
+    with forwardprop.ForwardGradientAccumulator() as acc:
+      acc.watch(v, constant_op.constant([.1, -.2, .3]))
+      self.assertAllClose([.1, -.2, .3], acc.jvp(v))
+      x = v * 2.
+      self.assertAllClose([.2, -.4, .6], acc.jvp(x))
+      x2 = v + .1
+      self.assertAllClose([.1, -.2, .3], acc.jvp(x2))
+
+  # NOTE: assert_no_new_pyobjects_executing_eagerly fails flakily on this
+  # test... could be something wrong with the test decorator, or some sort of
+  # nondeterministic caching.
+  def testMirroredVariableWatched(self):
+
+    def _replicated(input_tangent):
+      with forwardprop.ForwardGradientAccumulator() as acc:
+        acc.watch(v, input_tangent)
+        self.assertAllClose([.1, -.2, .3], acc.jvp(v))
+        x = v * 2.
+        self.assertAllClose([.2, -.4, .6], acc.jvp(x))
+        x2 = v + .1
+        self.assertAllClose([.1, -.2, .3], acc.jvp(x2))
+
+    strategy = mirrored_strategy.MirroredStrategy()
+    with strategy.scope():
+      v = variables.Variable([1., 2., 3.])
+      strategy.experimental_run_v2(
+          _replicated, args=(constant_op.constant([.1, -.2, .3]),))
 
 
 if __name__ == "__main__":
   # TODO(allenl): Also test with 1.x-style graph mode.