From 05cf54667e1e72ebab29dc287baa0ab57a5d5d6a Mon Sep 17 00:00:00 2001
From: Allen Lavoie
Date: Wed, 28 Aug 2019 12:37:11 -0700
Subject: [PATCH] Forwardprop: allow watching variables

The gradient function transpose logic just needed to support non-float
trainable types.

PiperOrigin-RevId: 265969263
---
 tensorflow/python/eager/BUILD               |  1 +
 tensorflow/python/eager/forwardprop.py      | 39 +++++++++++----------
 tensorflow/python/eager/forwardprop_test.py | 36 +++++++++++++++++++
 3 files changed, 57 insertions(+), 19 deletions(-)

diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index 9a55ace76ac..4c93ba13fbc 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -273,6 +273,7 @@ cuda_py_test(
         ":forwardprop",
         ":forwardprop_util",
         ":test",
+        "//tensorflow/python/distribute:mirrored_strategy",
     ],
     shard_count = 5,
     xla_enable_strict_auto_jit = True,
diff --git a/tensorflow/python/eager/forwardprop.py b/tensorflow/python/eager/forwardprop.py
index bd153277485..74fc9db8b08 100644
--- a/tensorflow/python/eager/forwardprop.py
+++ b/tensorflow/python/eager/forwardprop.py
@@ -18,8 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import functools
-
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import def_function
@@ -28,6 +26,7 @@ from tensorflow.python.eager import execute
 
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_util
 from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import nest
@@ -55,35 +54,35 @@ def _forward_gradient(op_name, attr_tuple, inputs, outputs, tangents):
   Returns:
     A flat list of tangents corresponding to `outputs`.
   """
-  float_inputs = []
-  float_indices = []
+  trainable_inputs = []
+  trainable_indices = []
   nontrivial_tangents = []
   for input_index, tensor in enumerate(inputs):
-    if tensor.dtype.is_floating:
-      float_inputs.append(tensor)
-      float_indices.append(input_index)
+    if gradients_util.IsTrainable(tensor):
+      trainable_inputs.append(tensor)
+      trainable_indices.append(input_index)
       nontrivial_tangents.append(tangents[input_index])
 
   with backprop.GradientTape() as transpose_tape:
     with backprop.GradientTape() as backfunc_tape:
-      backfunc_tape.watch(float_inputs)
+      backfunc_tape.watch(trainable_inputs)
       execute.record_gradient(op_name, inputs, attr_tuple, outputs,
                               "forward_op_replay")
 
     forwardprop_aids = []
-    float_outputs = []
+    trainable_outputs = []
     nontrivial_output_indices = []
     for output_index, output in enumerate(outputs):
-      if output.dtype.is_floating:
+      if gradients_util.IsTrainable(output):
         forwardprop_aids.append(
             array_ops.ones_like(output, name="unused_forwardprop_aid"))
-        float_outputs.append(output)
+        trainable_outputs.append(output)
         nontrivial_output_indices.append(output_index)
 
     transpose_tape.watch(forwardprop_aids)
     grads = backfunc_tape.gradient(
-        float_outputs,
-        float_inputs,
+        trainable_outputs,
+        trainable_inputs,
         forwardprop_aids,
         unconnected_gradients=UnconnectedGradients.ZERO)
   nontrivial_output_tangents = transpose_tape.gradient(
@@ -183,10 +182,9 @@ class ForwardGradientAccumulator(object):
         logging.log_first_n(
             logging.WARN, "The dtype of the watched tensor must be "
            "floating (e.g. tf.float32), got %r", 5, t.dtype)
-      if hasattr(t, "handle"):
-        # TODO(allenl): Handle watching variables.
-        raise NotImplementedError("Currently only Tensors may be watched.")
       g = ops.convert_to_tensor(g, dtype=t.dtype)
+      if hasattr(t, "handle"):
+        t = t.handle
       pywrap_tensorflow.TFE_Py_ForwardAccumulatorWatch(self._accumulator, t, g)
 
   def jvp(self, target):
@@ -206,6 +204,9 @@ class ForwardGradientAccumulator(object):
     """
     if self._accumulator is None:
       raise ValueError("Called jvp() without first tracing anything.")
-    return nest.map_structure(
-        functools.partial(pywrap_tensorflow.TFE_Py_ForwardAccumulatorJVP,
-                          self._accumulator), target)
+    def _fetch_jvp(tensor):
+      if hasattr(tensor, "handle"):
+        tensor = tensor.handle
+      return pywrap_tensorflow.TFE_Py_ForwardAccumulatorJVP(
+          self._accumulator, tensor)
+    return nest.map_structure(_fetch_jvp, target)
diff --git a/tensorflow/python/eager/forwardprop_test.py b/tensorflow/python/eager/forwardprop_test.py
index 35678650481..3fd3cdfcdbc 100644
--- a/tensorflow/python/eager/forwardprop_test.py
+++ b/tensorflow/python/eager/forwardprop_test.py
@@ -24,6 +24,7 @@ from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.distribute import mirrored_strategy
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import forwardprop
@@ -37,6 +38,7 @@ from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import gradient_checker_v2
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variables
 from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
 from tensorflow.python.platform import test
 from tensorflow.python.util import nest
@@ -470,6 +472,7 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
     self.assertAllClose(backback_hvp, forwardback_hvp_eager)
     self.assertAllClose(backback_hvp, forwardback_hvp_function)
 
+  @test_util.assert_no_new_pyobjects_executing_eagerly
   def testShouldRecordAndStopRecord(self):
     with forwardprop.ForwardGradientAccumulator() as acc:
       c = constant_op.constant(1.)
@@ -495,6 +498,7 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
       self.assertIsNone(acc.jvp(d))
       self.assertIsNone(tape.gradient(d, c))
 
+  @test_util.assert_no_new_pyobjects_executing_eagerly
   def testRecordingSelectively(self):
     with forwardprop.ForwardGradientAccumulator() as acc:
       c = constant_op.constant(1.)
@@ -522,6 +526,7 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
       self.assertIsNone(tape.gradient(d, c))
       self.assertAllClose(3., tape.gradient(e, c))
 
+  @test_util.assert_no_new_pyobjects_executing_eagerly
   def testRecordingWithJVPIndices(self):
     with forwardprop.ForwardGradientAccumulator() as acc:
       c = constant_op.constant(1.)
@@ -538,6 +543,37 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
           None, (((0, 1),),))
       self.assertAllClose(3., acc.jvp(d))
 
+  @test_util.assert_no_new_pyobjects_executing_eagerly
+  def testVariableWatched(self):
+    v = variables.Variable([1., 2., 3.])
+    with forwardprop.ForwardGradientAccumulator() as acc:
+      acc.watch(v, constant_op.constant([.1, -.2, .3]))
+      self.assertAllClose([.1, -.2, .3], acc.jvp(v))
+      x = v * 2.
+      self.assertAllClose([.2, -.4, .6], acc.jvp(x))
+      x2 = v + .1
+      self.assertAllClose([.1, -.2, .3], acc.jvp(x2))
+
+  # NOTE: assert_no_new_pyobjects_executing_eagerly fails flakily on this
+  # test... could be something wrong with the test decorator, or some sort of
+  # nondeterministic caching.
+  def testMirroredVariableWatched(self):
+
+    def _replicated(input_tangent):
+      with forwardprop.ForwardGradientAccumulator() as acc:
+        acc.watch(v, input_tangent)
+        self.assertAllClose([.1, -.2, .3], acc.jvp(v))
+        x = v * 2.
+        self.assertAllClose([.2, -.4, .6], acc.jvp(x))
+        x2 = v + .1
+        self.assertAllClose([.1, -.2, .3], acc.jvp(x2))
+
+    strategy = mirrored_strategy.MirroredStrategy()
+    with strategy.scope():
+      v = variables.Variable([1., 2., 3.])
+    strategy.experimental_run_v2(
+        _replicated, args=(constant_op.constant([.1, -.2, .3]),))
+
 
 if __name__ == "__main__":
   # TODO(allenl): Also test with 1.x-style graph mode.
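
For reference, a minimal usage sketch of the behavior this patch enables: watching a
variable directly with the (internal, at this revision) ForwardGradientAccumulator. It
mirrors the new testVariableWatched test above and assumes eager execution is enabled,
as in the test suite; the expected JVP values follow from the watched tangent [.1, -.2, .3].

    from tensorflow.python.eager import forwardprop
    from tensorflow.python.framework import constant_op
    from tensorflow.python.ops import variables

    v = variables.Variable([1., 2., 3.])
    with forwardprop.ForwardGradientAccumulator() as acc:
      # Watching a variable previously raised NotImplementedError; after this
      # patch the variable's handle is watched instead.
      acc.watch(v, constant_op.constant([.1, -.2, .3]))
      y = v * 2.
      # d(2v) = 2 * dv, so the JVP of y is twice the watched tangent.
      print(acc.jvp(y))  # [0.2, -0.4, 0.6]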