From e9ff15d98aee395d81391fce6b1d4ac186140d86 Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Fri, 23 Aug 2019 13:18:22 -0700 Subject: [PATCH] Forwardprop: Add utilities for temporarily pushing forward accumulator state More work toward adding a function special case. An accumulator triggers function-building, then needs to work on symbolic tensors captured by the function before returning to its original task. PiperOrigin-RevId: 265120491 --- tensorflow/c/eager/tape.h | 85 +++++++++++++-------- tensorflow/python/eager/BUILD | 12 +++ tensorflow/python/eager/forwardprop_test.py | 26 +++++++ tensorflow/python/eager/forwardprop_util.py | 47 ++++++++++++ tensorflow/python/eager/pywrap_tfe.h | 11 +++ tensorflow/python/eager/pywrap_tfe_src.cc | 16 ++++ tensorflow/python/pywrap_tfe.i | 2 + 7 files changed, 167 insertions(+), 32 deletions(-) create mode 100644 tensorflow/python/eager/forwardprop_util.py diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index d87781dd346..f3d9bb4ab27 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -18,6 +18,7 @@ limitations under the License. // Language-agnostic gradient tape. Does not perform backpropagation, just // maintains the data structures required to do so. +#include #include #include "tensorflow/core/framework/tensor_shape.h" @@ -209,7 +210,9 @@ class ForwardAccumulator { // ForwardAccumulator. explicit ForwardAccumulator( const VSpace& vspace) - : vspace_(vspace), backward_tape_(nullptr), accumulating_(false) {} + : vspace_(vspace) { + call_state_.emplace(nullptr, false); + } virtual ~ForwardAccumulator() { for (auto accumulated : accumulated_gradients_) { @@ -262,11 +265,11 @@ class ForwardAccumulator { const std::function& backward_function_getter, const std::function& backward_function_deleter); - // Returns true if `Accumulate` is active somewhere above on the stack. This - // is useful for ordering ForwardAccumulators, where more deeply nested - // accumulators should not see computations from less deeply nested - // accumulators. - bool BusyAccumulating() const { return this->accumulating_; } + // Returns true if `Accumulate` is active somewhere above on the stack and + // there isn't an intervening PushState. This is useful for ordering + // ForwardAccumulators, where more deeply nested accumulators should not see + // computations from less deeply nested accumulators. + bool BusyAccumulating() const { return call_state_.top().accumulating; } // Fetches the current Jacobian-vector product associated with `tensor_id`, or // a nullptr if none is available. @@ -282,6 +285,15 @@ class ForwardAccumulator { bool ShouldRecord(gtl::ArraySlice tensor_ids, gtl::ArraySlice dtypes); + // Temporarily push or pop transient state for this accumulator. + // + // Allows an accumulator which is currently processing an operation to + // temporarily reset its state. Without pushing and poping, accumulators + // ignore operations executed as a direct result of their own jvp + // computations. + void PushState() { call_state_.emplace(nullptr, false); } + void PopState() { call_state_.pop(); } + private: // Helper for Accumulate: uses a GradientTape to compute forward gradients // from a backward gradient function. Fills `out_grads` corresponding to @@ -289,7 +301,7 @@ class ForwardAccumulator { // // Executes the backward function in order to trace its gradient, which will // waste computation if executing eagerly (when graph building the unneeded - // computation is pruned). 
Temporarily sets `backward_tape_` so that + // computation is pruned). Temporarily sets `backward_tape` so that // Accumulate will forward op executions to the tape while the backward // function is running; this effectively adds the backward tape to the active // set (but does not require complicated callbacks to the language bindings). @@ -305,16 +317,25 @@ class ForwardAccumulator { // Not owned; provides operations on Tensors which are currently only // available in language bindings (e.g. Python). const VSpace& vspace_; - // Set temporarily while in the Accumulate method; if backward_tape_ is not - // nullptr then we forward op executions to it so Accumulate can compute a - // backward pass on its backward function. - // - // Not owned by the ForwardAccumulator. The method which sets `backward_tape_` - // keeps ownership. - GradientTape* backward_tape_; - // While the Accumulate method is running (accumulating_ is True), any op - // executions not forwarded to backward_tape_ should be ignored. - bool accumulating_; + + struct AccumulatorCallState { + AccumulatorCallState( + GradientTape* backward_tape, + bool accumulating) + : backward_tape(backward_tape), accumulating(accumulating) {} + // Set temporarily while in the Accumulate method; if backward_tape is not + // nullptr then we forward op executions to it so Accumulate can compute a + // backward pass on its backward function. + // + // Not owned by the ForwardAccumulator. The method which sets + // `backward_tape` keeps ownership. + GradientTape* backward_tape; + // While the Accumulate method is running (accumulating is True), any op + // executions not forwarded to backward_tape should be ignored. + bool accumulating; + }; + std::stack> + call_state_; }; // Template instantiations here @@ -847,12 +868,12 @@ template bool ForwardAccumulator::ShouldRecord( gtl::ArraySlice tensor_ids, gtl::ArraySlice dtypes) { - if (backward_tape_ != nullptr) { - // If we're forwarding Accumulate calls to backward_tape_'s RecordOperation, + if (call_state_.top().backward_tape != nullptr) { + // If we're forwarding Accumulate calls to backward_tape's RecordOperation, // we should also delegate ShouldRecord. - return backward_tape_->ShouldRecord(tensor_ids, dtypes); + return call_state_.top().backward_tape->ShouldRecord(tensor_ids, dtypes); } - if (accumulating_) { + if (call_state_.top().accumulating) { return false; } for (int i = 0; i < tensor_ids.size(); ++i) { @@ -884,9 +905,10 @@ ForwardAccumulator::ForwardpropFromTape( */ std::unique_ptr> tape( new GradientTape(false)); - backward_tape_ = tape.get(); + AccumulatorCallState& call_state = call_state_.top(); + call_state.backward_tape = tape.get(); auto pop_backward_tape = - gtl::MakeCleanup([this] { this->backward_tape_ = nullptr; }); + gtl::MakeCleanup([&call_state] { call_state.backward_tape = nullptr; }); std::vector forwardprop_aids; std::vector sources; std::unordered_set sources_set; @@ -961,10 +983,10 @@ Status ForwardAccumulator::Accumulate( const ForwardFunction* forward_function, const std::function& backward_function_getter, const std::function& backward_function_deleter) { - if (backward_tape_ != nullptr) { - // If backward_tape_ is not null, then this call to Accumulate is the result + if (call_state_.top().backward_tape != nullptr) { + // If backward_tape is not null, then this call to Accumulate is the result // of a still-active call to Accumulate which is running operations. 
We - // forward these operations to backward_tape_ so the outer Accumulate call + // forward these operations to backward_tape so the outer Accumulate call // can do its work. // // Rather than re-entering and delegating Accumulate like this, we could @@ -972,9 +994,9 @@ Status ForwardAccumulator::Accumulate( // (so it can deactivate itself and activate its GradientTape). Currently // that is managed by the language binding and would require relatively // messy callbacks. - backward_tape_->RecordOperation(op_type, output_tensors, input_tensor_id, - input_dtypes, backward_function_getter, - backward_function_deleter); + call_state_.top().backward_tape->RecordOperation( + op_type, output_tensors, input_tensor_id, input_dtypes, + backward_function_getter, backward_function_deleter); return Status::OK(); } if (!ShouldRecord(input_tensor_id, input_dtypes)) { @@ -1012,9 +1034,8 @@ Status ForwardAccumulator::Accumulate( // Avoid infinite recursion. Whichever forward function we run, it'll end up // executing ops, and we don't want to watch those with this accumulator. - accumulating_ = true; - auto reset_accumulating = - gtl::MakeCleanup([this] { this->accumulating_ = false; }); + call_state_.emplace(nullptr, true); + auto pop_call_state = gtl::MakeCleanup([this] { this->call_state_.pop(); }); std::vector forward_grads; if (forward_function == nullptr) { diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 4d502e9b23d..9a55ace76ac 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -67,6 +67,7 @@ py_library( ":execute", ":execution_callbacks", ":forwardprop", + ":forwardprop_util", ":function", ":graph_only_ops", ":monitoring", @@ -270,6 +271,7 @@ cuda_py_test( srcs = ["forwardprop_test.py"], additional_deps = [ ":forwardprop", + ":forwardprop_util", ":test", ], shard_count = 5, @@ -529,6 +531,16 @@ py_library( ], ) +py_library( + name = "forwardprop_util", + srcs = ["forwardprop_util.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/python:pywrap_tensorflow", + ], +) + cuda_py_test( name = "benchmarks_test", srcs = ["benchmarks_test.py"], diff --git a/tensorflow/python/eager/forwardprop_test.py b/tensorflow/python/eager/forwardprop_test.py index f0c1fe96bf7..90942c74cfd 100644 --- a/tensorflow/python/eager/forwardprop_test.py +++ b/tensorflow/python/eager/forwardprop_test.py @@ -27,6 +27,7 @@ from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import backprop from tensorflow.python.eager import def_function from tensorflow.python.eager import forwardprop +from tensorflow.python.eager import forwardprop_util from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.framework import test_util @@ -274,6 +275,31 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase): with self.assertRaisesRegexp(ValueError, "test_error_string"): f(c) + def testPushPopAccumulatorState(self): + # Note that this example is somewhat contrived. push_forwardprop_state is + # probably only useful in practice for building functions that compute jvps + # alongside their usual outputs. 
+    with forwardprop.ForwardGradientAccumulator() as acc:
+
+      @custom_gradient.custom_gradient
+      def f(x):
+        y = math_ops.sin(x.numpy())
+
+        def grad(dy):
+          with forwardprop_util.push_forwardprop_state():
+            x_copy = constant_op.constant(x.numpy())
+            acc.watch(x_copy, dy)
+            y_copy = math_ops.sin(x_copy)
+          return dy * acc.jvp(y_copy)
+
+        return y, grad
+
+      c = constant_op.constant(1.)
+      d = constant_op.constant(2.)
+      acc.watch(c, d)
+      output = f(c)
+      self.assertAllClose(d * math_ops.cos(c), acc.jvp(output))
+
   @parameterized.named_parameters(
       [("Order{}".format(order), order, expected)
        for order, expected in enumerate(_X11_35_DERIVATIVES)])
diff --git a/tensorflow/python/eager/forwardprop_util.py b/tensorflow/python/eager/forwardprop_util.py
new file mode 100644
index 00000000000..81d6c61db0c
--- /dev/null
+++ b/tensorflow/python/eager/forwardprop_util.py
@@ -0,0 +1,47 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for managing forward accumulators.
+
+A separate file from forwardprop.py so that functions can use these utilities.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+
+from tensorflow.python import pywrap_tensorflow
+
+
+@contextlib.contextmanager
+def push_forwardprop_state():
+  """Temporarily push or pop transient state for accumulators in the active set.
+
+  Allows an accumulator which is currently processing an operation to
+  temporarily reset its state. This is useful when building forwardprop versions
+  of functions, where an accumulator will trigger function building and then
+  must process captured symbolic tensors while building it. Without pushing and
+  popping, accumulators ignore operations executed as a direct result of their
+  own jvp computations.
+
+  Yields:
+    None (used for its side effect).
+  """
+  try:
+    pywrap_tensorflow.TFE_Py_ForwardAccumulatorPushState()
+    yield
+  finally:
+    pywrap_tensorflow.TFE_Py_ForwardAccumulatorPopState()
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
index 96152ae362c..88d7664b311 100755
--- a/tensorflow/python/eager/pywrap_tfe.h
+++ b/tensorflow/python/eager/pywrap_tfe.h
@@ -256,6 +256,17 @@ void TFE_Py_ForwardAccumulatorWatch(PyObject* accumulator, PyObject* tensor,
 // `accumulator`. Returns None if no JVP is available.
 PyObject* TFE_Py_ForwardAccumulatorJVP(PyObject* accumulator, PyObject* tensor);
 
+// Temporarily push or pop transient state for accumulators in the active set.
+//
+// Allows an accumulator which is currently processing an operation to
+// temporarily reset its state. This is useful when building forwardprop
+// versions of functions, where an accumulator will trigger function building
+// and then must process captured symbolic tensors while building it. Without
+// pushing and popping, accumulators ignore operations executed as a direct
+// result of their own jvp computations.
+PyObject* TFE_Py_ForwardAccumulatorPushState();
+PyObject* TFE_Py_ForwardAccumulatorPopState();
+
 // Collects state from all current forward accumulators related to `tensors`.
 //
 // This is useful for packing JVPs as function inputs before executing a
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index c42d2818ed4..77c3e93b3cf 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -1640,6 +1640,22 @@ PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) {
   Py_RETURN_FALSE;
 }
 
+PyObject* TFE_Py_ForwardAccumulatorPushState() {
+  auto forward_accumulators = *GetAccumulatorSet();
+  for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
+    accumulator->accumulator->PushState();
+  }
+  Py_RETURN_NONE;
+}
+
+PyObject* TFE_Py_ForwardAccumulatorPopState() {
+  auto forward_accumulators = *GetAccumulatorSet();
+  for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
+    accumulator->accumulator->PopState();
+  }
+  Py_RETURN_NONE;
+}
+
 PyObject* TFE_Py_TapeSetPossibleGradientTypes(PyObject* tensors) {
   if (!TapeCouldPossiblyRecord(tensors)) {
     return GetPythonObjectFromInt(0);
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index 2c76f50aa5d..6cd2b1a039e 100755
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -88,6 +88,8 @@ limitations under the License.
 %rename("%s") TFE_Py_ForwardAccumulatorSetRemove;
 %rename("%s") TFE_Py_ForwardAccumulatorWatch;
 %rename("%s") TFE_Py_ForwardAccumulatorJVP;
+%rename("%s") TFE_Py_ForwardAccumulatorPushState;
+%rename("%s") TFE_Py_ForwardAccumulatorPopState;
 %rename("%s") TFE_Py_PackForwardGradients;
 %rename("%s") TFE_NewContextOptions;
 %rename("%s") TFE_ContextOptionsSetConfig;
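
Usage sketch: the snippet below shows the pattern push_forwardprop_state enables, adapted from the testPushPopAccumulatorState case added in this patch. It assumes the internal module paths and the ForwardGradientAccumulator watch/jvp API of this revision; variable names are illustrative only.

from tensorflow.python.eager import forwardprop
from tensorflow.python.eager import forwardprop_util
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import math_ops

# Requires eager execution, matching the test above.
with forwardprop.ForwardGradientAccumulator() as acc:

  @custom_gradient.custom_gradient
  def f(x):
    y = math_ops.sin(x.numpy())

    def grad(dy):
      # Without push_forwardprop_state the accumulator would ignore these ops,
      # since they run as a direct result of its own JVP computation.
      with forwardprop_util.push_forwardprop_state():
        x_copy = constant_op.constant(x.numpy())
        acc.watch(x_copy, dy)
        y_copy = math_ops.sin(x_copy)
      return dy * acc.jvp(y_copy)

    return y, grad

  primal = constant_op.constant(1.)
  tangent = constant_op.constant(2.)
  acc.watch(primal, tangent)
  output = f(primal)
  print(acc.jvp(output))  # tangent * cos(primal)

The push re-enables recording while the accumulator is busy with its own JVP computation; the matching pop happens automatically when the context manager exits, returning the accumulator to its previous state.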