Forwardprop: allow watching variables
The gradient function transpose logic just needed to support non-float trainable types.

PiperOrigin-RevId: 265969263
parent 16e4f42086
commit 05cf54667e
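For context, a minimal usage sketch of what this change enables, mirroring the new testVariableWatched test added below (assumes eager execution and the experimental ForwardGradientAccumulator API as of this revision):

    from tensorflow.python.eager import forwardprop
    from tensorflow.python.framework import constant_op
    from tensorflow.python.ops import variables

    # Variables (anything with a .handle) can now be watched directly;
    # previously watch() raised NotImplementedError for them.
    v = variables.Variable([1., 2., 3.])
    with forwardprop.ForwardGradientAccumulator() as acc:
      # Associate a tangent vector with the variable.
      acc.watch(v, constant_op.constant([.1, -.2, .3]))
      x = v * 2.
      # jvp() now resolves variables through their handles as well.
      print(acc.jvp(x))  # approximately [.2, -.4, .6]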
@@ -273,6 +273,7 @@ cuda_py_test(
         ":forwardprop",
         ":forwardprop_util",
         ":test",
+        "//tensorflow/python/distribute:mirrored_strategy",
     ],
     shard_count = 5,
     xla_enable_strict_auto_jit = True,
@@ -18,8 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import functools
-
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import def_function
@@ -28,6 +26,7 @@ from tensorflow.python.eager import execute
 from tensorflow.python.framework import ops
 
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_util
 from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import nest
@@ -55,35 +54,35 @@ def _forward_gradient(op_name, attr_tuple, inputs, outputs, tangents):
   Returns:
     A flat list of tangents corresponding to `outputs`.
   """
-  float_inputs = []
-  float_indices = []
+  trainable_inputs = []
+  trainable_indices = []
   nontrivial_tangents = []
   for input_index, tensor in enumerate(inputs):
-    if tensor.dtype.is_floating:
-      float_inputs.append(tensor)
-      float_indices.append(input_index)
+    if gradients_util.IsTrainable(tensor):
+      trainable_inputs.append(tensor)
+      trainable_indices.append(input_index)
       nontrivial_tangents.append(tangents[input_index])
 
   with backprop.GradientTape() as transpose_tape:
     with backprop.GradientTape() as backfunc_tape:
-      backfunc_tape.watch(float_inputs)
+      backfunc_tape.watch(trainable_inputs)
       execute.record_gradient(op_name, inputs, attr_tuple, outputs,
                               "forward_op_replay")
 
     forwardprop_aids = []
-    float_outputs = []
+    trainable_outputs = []
     nontrivial_output_indices = []
     for output_index, output in enumerate(outputs):
-      if output.dtype.is_floating:
+      if gradients_util.IsTrainable(output):
         forwardprop_aids.append(
             array_ops.ones_like(output, name="unused_forwardprop_aid"))
-        float_outputs.append(output)
+        trainable_outputs.append(output)
         nontrivial_output_indices.append(output_index)
 
     transpose_tape.watch(forwardprop_aids)
     grads = backfunc_tape.gradient(
-        float_outputs,
-        float_inputs,
+        trainable_outputs,
+        trainable_inputs,
         forwardprop_aids,
         unconnected_gradients=UnconnectedGradients.ZERO)
   nontrivial_output_tangents = transpose_tape.gradient(
@@ -183,10 +182,9 @@ class ForwardGradientAccumulator(object):
         logging.log_first_n(
             logging.WARN, "The dtype of the watched tensor must be "
             "floating (e.g. tf.float32), got %r", 5, t.dtype)
-      if hasattr(t, "handle"):
-        # TODO(allenl): Handle watching variables.
-        raise NotImplementedError("Currently only Tensors may be watched.")
       g = ops.convert_to_tensor(g, dtype=t.dtype)
+      if hasattr(t, "handle"):
+        t = t.handle
       pywrap_tensorflow.TFE_Py_ForwardAccumulatorWatch(self._accumulator, t, g)
 
   def jvp(self, target):
@@ -206,6 +204,9 @@ class ForwardGradientAccumulator(object):
     """
     if self._accumulator is None:
       raise ValueError("Called jvp() without first tracing anything.")
-    return nest.map_structure(
-        functools.partial(pywrap_tensorflow.TFE_Py_ForwardAccumulatorJVP,
-                          self._accumulator), target)
+    def _fetch_jvp(tensor):
+      if hasattr(tensor, "handle"):
+        tensor = tensor.handle
+      return pywrap_tensorflow.TFE_Py_ForwardAccumulatorJVP(
+          self._accumulator, tensor)
+    return nest.map_structure(_fetch_jvp, target)
@@ -24,6 +24,7 @@ from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.distribute import mirrored_strategy
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import forwardprop
@@ -37,6 +38,7 @@ from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import gradient_checker_v2
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variables
 from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
 from tensorflow.python.platform import test
 from tensorflow.python.util import nest
@@ -470,6 +472,7 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
     self.assertAllClose(backback_hvp, forwardback_hvp_eager)
     self.assertAllClose(backback_hvp, forwardback_hvp_function)
 
+  @test_util.assert_no_new_pyobjects_executing_eagerly
   def testShouldRecordAndStopRecord(self):
     with forwardprop.ForwardGradientAccumulator() as acc:
       c = constant_op.constant(1.)
@@ -495,6 +498,7 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
         self.assertIsNone(acc.jvp(d))
       self.assertIsNone(tape.gradient(d, c))
 
+  @test_util.assert_no_new_pyobjects_executing_eagerly
   def testRecordingSelectively(self):
     with forwardprop.ForwardGradientAccumulator() as acc:
       c = constant_op.constant(1.)
@@ -522,6 +526,7 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
       self.assertIsNone(tape.gradient(d, c))
       self.assertAllClose(3., tape.gradient(e, c))
 
+  @test_util.assert_no_new_pyobjects_executing_eagerly
   def testRecordingWithJVPIndices(self):
     with forwardprop.ForwardGradientAccumulator() as acc:
       c = constant_op.constant(1.)
@@ -538,6 +543,37 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
           None, (((0, 1),),))
       self.assertAllClose(3., acc.jvp(d))
 
+  @test_util.assert_no_new_pyobjects_executing_eagerly
+  def testVariableWatched(self):
+    v = variables.Variable([1., 2., 3.])
+    with forwardprop.ForwardGradientAccumulator() as acc:
+      acc.watch(v, constant_op.constant([.1, -.2, .3]))
+      self.assertAllClose([.1, -.2, .3], acc.jvp(v))
+      x = v * 2.
+      self.assertAllClose([.2, -.4, .6], acc.jvp(x))
+      x2 = v + .1
+      self.assertAllClose([.1, -.2, .3], acc.jvp(x2))
+
+  # NOTE: assert_no_new_pyobjects_executing_eagerly fails flakily on this
+  # test... could be something wrong with the test decorator, or some sort of
+  # nondeterministic caching.
+  def testMirroredVariableWatched(self):
+
+    def _replicated(input_tangent):
+      with forwardprop.ForwardGradientAccumulator() as acc:
+        acc.watch(v, input_tangent)
+        self.assertAllClose([.1, -.2, .3], acc.jvp(v))
+        x = v * 2.
+        self.assertAllClose([.2, -.4, .6], acc.jvp(x))
+        x2 = v + .1
+        self.assertAllClose([.1, -.2, .3], acc.jvp(x2))
+
+    strategy = mirrored_strategy.MirroredStrategy()
+    with strategy.scope():
+      v = variables.Variable([1., 2., 3.])
+      strategy.experimental_run_v2(
+          _replicated, args=(constant_op.constant([.1, -.2, .3]),))
+
 
 if __name__ == "__main__":
   # TODO(allenl): Also test with 1.x-style graph mode.