Forwardprop: Add utilities for temporarily pushing forward accumulator state

More work toward adding a function special case. An accumulator triggers function-building, then needs to work on symbolic tensors captured by the function before returning to its original task.
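In Python terms, the pattern this enables looks roughly like the sketch below. It is a condensed, standalone version of the test added in this change; the explicit imports, the eager-execution assumption, and the trailing comment about the expected value are added here for illustration only.

from tensorflow.python.eager import forwardprop
from tensorflow.python.eager import forwardprop_util
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import math_ops

with forwardprop.ForwardGradientAccumulator() as acc:

  @custom_gradient.custom_gradient
  def f(x):
    y = math_ops.sin(x.numpy())

    def grad(dy):
      # Ops executed here run as part of the accumulator's own jvp
      # computation, so `acc` would normally ignore them;
      # push_forwardprop_state temporarily resets that state so the
      # accumulator can watch and differentiate through them.
      with forwardprop_util.push_forwardprop_state():
        x_copy = constant_op.constant(x.numpy())
        acc.watch(x_copy, dy)
        y_copy = math_ops.sin(x_copy)
      return dy * acc.jvp(y_copy)

    return y, grad

  c = constant_op.constant(1.)
  d = constant_op.constant(2.)
  acc.watch(c, d)
  output = f(c)
  # acc.jvp(output) now equals d * cos(c): the tangent `d` pushed through sin.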

PiperOrigin-RevId: 265120491
Allen Lavoie 2019-08-23 13:18:22 -07:00 committed by TensorFlower Gardener
parent aba46a95e0
commit e9ff15d98a
7 changed files with 167 additions and 32 deletions


@@ -18,6 +18,7 @@ limitations under the License.
// Language-agnostic gradient tape. Does not perform backpropagation, just
// maintains the data structures required to do so.
#include <stack>
#include <vector>
#include "tensorflow/core/framework/tensor_shape.h"
@@ -209,7 +210,9 @@ class ForwardAccumulator {
// ForwardAccumulator.
explicit ForwardAccumulator(
const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace)
: vspace_(vspace), backward_tape_(nullptr), accumulating_(false) {}
: vspace_(vspace) {
call_state_.emplace(nullptr, false);
}
virtual ~ForwardAccumulator() {
for (auto accumulated : accumulated_gradients_) {
@@ -262,11 +265,11 @@ class ForwardAccumulator {
const std::function<BackwardFunction*()>& backward_function_getter,
const std::function<void(BackwardFunction*)>& backward_function_deleter);
// Returns true if `Accumulate` is active somewhere above on the stack. This
// is useful for ordering ForwardAccumulators, where more deeply nested
// accumulators should not see computations from less deeply nested
// accumulators.
bool BusyAccumulating() const { return this->accumulating_; }
// Returns true if `Accumulate` is active somewhere above on the stack and
// there isn't an intervening PushState. This is useful for ordering
// ForwardAccumulators, where more deeply nested accumulators should not see
// computations from less deeply nested accumulators.
bool BusyAccumulating() const { return call_state_.top().accumulating; }
// Fetches the current Jacobian-vector product associated with `tensor_id`, or
// a nullptr if none is available.
@@ -282,6 +285,15 @@ class ForwardAccumulator {
bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids,
gtl::ArraySlice<tensorflow::DataType> dtypes);
// Temporarily push or pop transient state for this accumulator.
//
// Allows an accumulator which is currently processing an operation to
// temporarily reset its state. Without pushing and popping, accumulators
// ignore operations executed as a direct result of their own jvp
// computations.
void PushState() { call_state_.emplace(nullptr, false); }
void PopState() { call_state_.pop(); }
private:
// Helper for Accumulate: uses a GradientTape to compute forward gradients
// from a backward gradient function. Fills `out_grads` corresponding to
@@ -289,7 +301,7 @@ class ForwardAccumulator {
//
// Executes the backward function in order to trace its gradient, which will
// waste computation if executing eagerly (when graph building the unneeded
// computation is pruned). Temporarily sets `backward_tape_` so that
// computation is pruned). Temporarily sets `backward_tape` so that
// Accumulate will forward op executions to the tape while the backward
// function is running; this effectively adds the backward tape to the active
// set (but does not require complicated callbacks to the language bindings).
@@ -305,16 +317,25 @@ class ForwardAccumulator {
// Not owned; provides operations on Tensors which are currently only
// available in language bindings (e.g. Python).
const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace_;
// Set temporarily while in the Accumulate method; if backward_tape_ is not
// nullptr then we forward op executions to it so Accumulate can compute a
// backward pass on its backward function.
//
// Not owned by the ForwardAccumulator. The method which sets `backward_tape_`
// keeps ownership.
GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape_;
// While the Accumulate method is running (accumulating_ is True), any op
// executions not forwarded to backward_tape_ should be ignored.
bool accumulating_;
struct AccumulatorCallState {
AccumulatorCallState(
GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape,
bool accumulating)
: backward_tape(backward_tape), accumulating(accumulating) {}
// Set temporarily while in the Accumulate method; if backward_tape is not
// nullptr then we forward op executions to it so Accumulate can compute a
// backward pass on its backward function.
//
// Not owned by the ForwardAccumulator. The method which sets
// `backward_tape` keeps ownership.
GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape;
// While the Accumulate method is running (accumulating is True), any op
// executions not forwarded to backward_tape should be ignored.
bool accumulating;
};
std::stack<AccumulatorCallState, std::vector<AccumulatorCallState>>
call_state_;
};
// Template instantiations here
@@ -847,12 +868,12 @@ template <typename Gradient, typename BackwardFunction, typename TapeTensor>
bool ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
gtl::ArraySlice<int64> tensor_ids,
gtl::ArraySlice<tensorflow::DataType> dtypes) {
if (backward_tape_ != nullptr) {
// If we're forwarding Accumulate calls to backward_tape_'s RecordOperation,
if (call_state_.top().backward_tape != nullptr) {
// If we're forwarding Accumulate calls to backward_tape's RecordOperation,
// we should also delegate ShouldRecord.
return backward_tape_->ShouldRecord(tensor_ids, dtypes);
return call_state_.top().backward_tape->ShouldRecord(tensor_ids, dtypes);
}
if (accumulating_) {
if (call_state_.top().accumulating) {
return false;
}
for (int i = 0; i < tensor_ids.size(); ++i) {
@@ -884,9 +905,10 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
*/
std::unique_ptr<GradientTape<Gradient, BackwardFunction, TapeTensor>> tape(
new GradientTape<Gradient, BackwardFunction, TapeTensor>(false));
backward_tape_ = tape.get();
AccumulatorCallState& call_state = call_state_.top();
call_state.backward_tape = tape.get();
auto pop_backward_tape =
gtl::MakeCleanup([this] { this->backward_tape_ = nullptr; });
gtl::MakeCleanup([&call_state] { call_state.backward_tape = nullptr; });
std::vector<Gradient*> forwardprop_aids;
std::vector<int64> sources;
std::unordered_set<int64> sources_set;
@@ -961,10 +983,10 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
const ForwardFunction<Gradient>* forward_function,
const std::function<BackwardFunction*()>& backward_function_getter,
const std::function<void(BackwardFunction*)>& backward_function_deleter) {
if (backward_tape_ != nullptr) {
// If backward_tape_ is not null, then this call to Accumulate is the result
if (call_state_.top().backward_tape != nullptr) {
// If backward_tape is not null, then this call to Accumulate is the result
// of a still-active call to Accumulate which is running operations. We
// forward these operations to backward_tape_ so the outer Accumulate call
// forward these operations to backward_tape so the outer Accumulate call
// can do its work.
//
// Rather than re-entering and delegating Accumulate like this, we could
@@ -972,9 +994,9 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
// (so it can deactivate itself and activate its GradientTape). Currently
// that is managed by the language binding and would require relatively
// messy callbacks.
backward_tape_->RecordOperation(op_type, output_tensors, input_tensor_id,
input_dtypes, backward_function_getter,
backward_function_deleter);
call_state_.top().backward_tape->RecordOperation(
op_type, output_tensors, input_tensor_id, input_dtypes,
backward_function_getter, backward_function_deleter);
return Status::OK();
}
if (!ShouldRecord(input_tensor_id, input_dtypes)) {
@@ -1012,9 +1034,8 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
// Avoid infinite recursion. Whichever forward function we run, it'll end up
// executing ops, and we don't want to watch those with this accumulator.
accumulating_ = true;
auto reset_accumulating =
gtl::MakeCleanup([this] { this->accumulating_ = false; });
call_state_.emplace(nullptr, true);
auto pop_call_state = gtl::MakeCleanup([this] { this->call_state_.pop(); });
std::vector<Gradient*> forward_grads;
if (forward_function == nullptr) {


@@ -67,6 +67,7 @@ py_library(
":execute",
":execution_callbacks",
":forwardprop",
":forwardprop_util",
":function",
":graph_only_ops",
":monitoring",
@@ -270,6 +271,7 @@ cuda_py_test(
srcs = ["forwardprop_test.py"],
additional_deps = [
":forwardprop",
":forwardprop_util",
":test",
],
shard_count = 5,
@@ -529,6 +531,16 @@ py_library(
],
)
py_library(
name = "forwardprop_util",
srcs = ["forwardprop_util.py"],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/python:pywrap_tensorflow",
],
)
cuda_py_test(
name = "benchmarks_test",
srcs = ["benchmarks_test.py"],


@@ -27,6 +27,7 @@ from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.eager import forwardprop
from tensorflow.python.eager import forwardprop_util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
@@ -274,6 +275,31 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
with self.assertRaisesRegexp(ValueError, "test_error_string"):
f(c)
def testPushPopAccumulatorState(self):
# Note that this example is somewhat contrived. push_forwardprop_state is
# probably only useful in practice for building functions that compute jvps
# alongside their usual outputs.
with forwardprop.ForwardGradientAccumulator() as acc:
@custom_gradient.custom_gradient
def f(x):
y = math_ops.sin(x.numpy())
def grad(dy):
with forwardprop_util.push_forwardprop_state():
x_copy = constant_op.constant(x.numpy())
acc.watch(x_copy, dy)
y_copy = math_ops.sin(x_copy)
return dy * acc.jvp(y_copy)
return y, grad
c = constant_op.constant(1.)
d = constant_op.constant(2.)
acc.watch(c, d)
output = f(c)
self.assertAllClose(d * math_ops.cos(c), acc.jvp(output))
@parameterized.named_parameters(
[("Order{}".format(order), order, expected)
for order, expected in enumerate(_X11_35_DERIVATIVES)])


@@ -0,0 +1,47 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for managing forward accumulators.
A separate file from forwardprop.py so that functions can use these utilities.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
from tensorflow.python import pywrap_tensorflow
@contextlib.contextmanager
def push_forwardprop_state():
"""Temporarily push or pop transient state for accumulators in the active set.
Allows an accumulator which is currently processing an operation to
temporarily reset its state. This is useful when building forwardprop versions
of functions, where an accumulator will trigger function building and then
must process captured symbolic tensors while building it. Without pushing and
popping, accumulators ignore operations executed as a direct result of their
own jvp computations.
Yields:
None (used for its side effect).
"""
try:
pywrap_tensorflow.TFE_Py_ForwardAccumulatorPushState()
yield
finally:
pywrap_tensorflow.TFE_Py_ForwardAccumulatorPopState()
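A usage note on the context manager above (an observation, not part of the change): because each accumulator now keeps a stack of call states rather than a single flag, pushes may nest, and the try/finally guarantees that every push is matched by exactly one pop even if the wrapped computation raises. A hypothetical sketch:

from tensorflow.python.eager import forwardprop_util

with forwardprop_util.push_forwardprop_state():
  # Every active accumulator has had one call state pushed.
  with forwardprop_util.push_forwardprop_state():
    # Nested push: each accumulator's call-state stack is now two entries
    # deeper than before the outer `with`.
    pass
# Both entries have been popped here, so the accumulators resume whatever
# they were doing before the outer `with`.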


@@ -256,6 +256,17 @@ void TFE_Py_ForwardAccumulatorWatch(PyObject* accumulator, PyObject* tensor,
// `accumulator`. Returns None if no JVP is available.
PyObject* TFE_Py_ForwardAccumulatorJVP(PyObject* accumulator, PyObject* tensor);
// Temporarily push or pop transient state for accumulators in the active set.
//
// Allows an accumulator which is currently processing an operation to
// temporarily reset its state. This is useful when building forwardprop
// versions of functions, where an accumulator will trigger function building
// and then must process captured symbolic tensors while building it. Without
// pushing and popping, accumulators ignore operations executed as a direct
// result of their own jvp computations.
PyObject* TFE_Py_ForwardAccumulatorPushState();
PyObject* TFE_Py_ForwardAccumulatorPopState();
// Collects state from all current forward accumulators related to `tensors`.
//
// This is useful for packing JVPs as function inputs before executing a


@@ -1640,6 +1640,22 @@ PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) {
Py_RETURN_FALSE;
}
PyObject* TFE_Py_ForwardAccumulatorPushState() {
auto forward_accumulators = *GetAccumulatorSet();
for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
accumulator->accumulator->PushState();
}
Py_RETURN_NONE;
}
PyObject* TFE_Py_ForwardAccumulatorPopState() {
auto forward_accumulators = *GetAccumulatorSet();
for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
accumulator->accumulator->PopState();
}
Py_RETURN_NONE;
}
PyObject* TFE_Py_TapeSetPossibleGradientTypes(PyObject* tensors) {
if (!TapeCouldPossiblyRecord(tensors)) {
return GetPythonObjectFromInt(0);


@@ -88,6 +88,8 @@ limitations under the License.
%rename("%s") TFE_Py_ForwardAccumulatorSetRemove;
%rename("%s") TFE_Py_ForwardAccumulatorWatch;
%rename("%s") TFE_Py_ForwardAccumulatorJVP;
%rename("%s") TFE_Py_ForwardAccumulatorPushState;
%rename("%s") TFE_Py_ForwardAccumulatorPopState;
%rename("%s") TFE_Py_PackForwardGradients;
%rename("%s") TFE_NewContextOptions;
%rename("%s") TFE_ContextOptionsSetConfig;