Forwardprop: Add utilities for temporarily pushing forward accumulator state

More work toward adding a function special case. An accumulator triggers function-building, then needs to work on symbolic tensors captured by the function before returning to its original task.

PiperOrigin-RevId: 265120491

parent aba46a95e0, commit e9ff15d98a
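
In brief: instead of a single `backward_tape_` pointer and `accumulating_` flag, each ForwardAccumulator now keeps a stack of per-call state, so an accumulator that is mid-jvp can push a fresh, idle state while a function is being traced and pop back to the interrupted call afterwards. The standalone C++ sketch below uses toy names rather than the TensorFlow types from the diff and is only meant to illustrate that stack discipline:

// Illustrative sketch only; the real implementation is AccumulatorCallState +
// call_state_ in the tape.h diff below.
#include <cassert>
#include <stack>
#include <vector>

struct CallState {
  void* backward_tape = nullptr;  // stand-in for the GradientTape pointer
  bool accumulating = false;
};

class ToyAccumulator {
 public:
  ToyAccumulator() { call_state_.emplace(); }

  // Only the top of the stack matters: states pushed later start out idle.
  bool BusyAccumulating() const { return call_state_.top().accumulating; }

  void BeginAccumulate() { call_state_.top().accumulating = true; }
  void EndAccumulate() { call_state_.top().accumulating = false; }

  void PushState() { call_state_.emplace(); }
  void PopState() { call_state_.pop(); }

 private:
  std::stack<CallState, std::vector<CallState>> call_state_;
};

int main() {
  ToyAccumulator acc;
  acc.BeginAccumulate();
  assert(acc.BusyAccumulating());   // ops from the accumulator's own jvp are ignored

  acc.PushState();                  // e.g. while tracing a function body
  assert(!acc.BusyAccumulating());  // captured symbolic tensors are visible again
  acc.PopState();

  assert(acc.BusyAccumulating());   // back to the interrupted call's state
  acc.EndAccumulate();
  return 0;
}
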
@@ -18,6 +18,7 @@ limitations under the License.
 // Language-agnostic gradient tape. Does not perform backpropagation, just
 // maintains the data structures required to do so.
 
+#include <stack>
 #include <vector>
 
 #include "tensorflow/core/framework/tensor_shape.h"
@@ -209,7 +210,9 @@ class ForwardAccumulator {
   // ForwardAccumulator.
   explicit ForwardAccumulator(
       const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace)
-      : vspace_(vspace), backward_tape_(nullptr), accumulating_(false) {}
+      : vspace_(vspace) {
+    call_state_.emplace(nullptr, false);
+  }
 
   virtual ~ForwardAccumulator() {
     for (auto accumulated : accumulated_gradients_) {
@@ -262,11 +265,11 @@ class ForwardAccumulator {
       const std::function<BackwardFunction*()>& backward_function_getter,
       const std::function<void(BackwardFunction*)>& backward_function_deleter);
 
-  // Returns true if `Accumulate` is active somewhere above on the stack. This
-  // is useful for ordering ForwardAccumulators, where more deeply nested
-  // accumulators should not see computations from less deeply nested
-  // accumulators.
-  bool BusyAccumulating() const { return this->accumulating_; }
+  // Returns true if `Accumulate` is active somewhere above on the stack and
+  // there isn't an intervening PushState. This is useful for ordering
+  // ForwardAccumulators, where more deeply nested accumulators should not see
+  // computations from less deeply nested accumulators.
+  bool BusyAccumulating() const { return call_state_.top().accumulating; }
 
   // Fetches the current Jacobian-vector product associated with `tensor_id`, or
   // a nullptr if none is available.
@@ -282,6 +285,15 @@ class ForwardAccumulator {
   bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids,
                     gtl::ArraySlice<tensorflow::DataType> dtypes);
 
+  // Temporarily push or pop transient state for this accumulator.
+  //
+  // Allows an accumulator which is currently processing an operation to
+  // temporarily reset its state. Without pushing and popping, accumulators
+  // ignore operations executed as a direct result of their own jvp
+  // computations.
+  void PushState() { call_state_.emplace(nullptr, false); }
+  void PopState() { call_state_.pop(); }
+
  private:
   // Helper for Accumulate: uses a GradientTape to compute forward gradients
   // from a backward gradient function. Fills `out_grads` corresponding to
@@ -289,7 +301,7 @@ class ForwardAccumulator {
   //
   // Executes the backward function in order to trace its gradient, which will
   // waste computation if executing eagerly (when graph building the unneeded
-  // computation is pruned). Temporarily sets `backward_tape_` so that
+  // computation is pruned). Temporarily sets `backward_tape` so that
   // Accumulate will forward op executions to the tape while the backward
   // function is running; this effectively adds the backward tape to the active
   // set (but does not require complicated callbacks to the language bindings).
@@ -305,16 +317,25 @@ class ForwardAccumulator {
   // Not owned; provides operations on Tensors which are currently only
   // available in language bindings (e.g. Python).
   const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace_;
-  // Set temporarily while in the Accumulate method; if backward_tape_ is not
-  // nullptr then we forward op executions to it so Accumulate can compute a
-  // backward pass on its backward function.
-  //
-  // Not owned by the ForwardAccumulator. The method which sets `backward_tape_`
-  // keeps ownership.
-  GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape_;
-  // While the Accumulate method is running (accumulating_ is True), any op
-  // executions not forwarded to backward_tape_ should be ignored.
-  bool accumulating_;
+
+  struct AccumulatorCallState {
+    AccumulatorCallState(
+        GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape,
+        bool accumulating)
+        : backward_tape(backward_tape), accumulating(accumulating) {}
+    // Set temporarily while in the Accumulate method; if backward_tape is not
+    // nullptr then we forward op executions to it so Accumulate can compute a
+    // backward pass on its backward function.
+    //
+    // Not owned by the ForwardAccumulator. The method which sets
+    // `backward_tape` keeps ownership.
+    GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape;
+    // While the Accumulate method is running (accumulating is True), any op
+    // executions not forwarded to backward_tape should be ignored.
+    bool accumulating;
+  };
+  std::stack<AccumulatorCallState, std::vector<AccumulatorCallState>>
+      call_state_;
 };
 
 // Template instantiations here
@@ -847,12 +868,12 @@ template <typename Gradient, typename BackwardFunction, typename TapeTensor>
 bool ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
     gtl::ArraySlice<int64> tensor_ids,
     gtl::ArraySlice<tensorflow::DataType> dtypes) {
-  if (backward_tape_ != nullptr) {
-    // If we're forwarding Accumulate calls to backward_tape_'s RecordOperation,
+  if (call_state_.top().backward_tape != nullptr) {
+    // If we're forwarding Accumulate calls to backward_tape's RecordOperation,
     // we should also delegate ShouldRecord.
-    return backward_tape_->ShouldRecord(tensor_ids, dtypes);
+    return call_state_.top().backward_tape->ShouldRecord(tensor_ids, dtypes);
   }
-  if (accumulating_) {
+  if (call_state_.top().accumulating) {
     return false;
   }
   for (int i = 0; i < tensor_ids.size(); ++i) {
@@ -884,9 +905,10 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
   */
   std::unique_ptr<GradientTape<Gradient, BackwardFunction, TapeTensor>> tape(
       new GradientTape<Gradient, BackwardFunction, TapeTensor>(false));
-  backward_tape_ = tape.get();
+  AccumulatorCallState& call_state = call_state_.top();
+  call_state.backward_tape = tape.get();
   auto pop_backward_tape =
-      gtl::MakeCleanup([this] { this->backward_tape_ = nullptr; });
+      gtl::MakeCleanup([&call_state] { call_state.backward_tape = nullptr; });
   std::vector<Gradient*> forwardprop_aids;
   std::vector<int64> sources;
   std::unordered_set<int64> sources_set;
@@ -961,10 +983,10 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
     const ForwardFunction<Gradient>* forward_function,
     const std::function<BackwardFunction*()>& backward_function_getter,
     const std::function<void(BackwardFunction*)>& backward_function_deleter) {
-  if (backward_tape_ != nullptr) {
-    // If backward_tape_ is not null, then this call to Accumulate is the result
+  if (call_state_.top().backward_tape != nullptr) {
+    // If backward_tape is not null, then this call to Accumulate is the result
     // of a still-active call to Accumulate which is running operations. We
-    // forward these operations to backward_tape_ so the outer Accumulate call
+    // forward these operations to backward_tape so the outer Accumulate call
     // can do its work.
     //
     // Rather than re-entering and delegating Accumulate like this, we could
@@ -972,9 +994,9 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
     // (so it can deactivate itself and activate its GradientTape). Currently
     // that is managed by the language binding and would require relatively
     // messy callbacks.
-    backward_tape_->RecordOperation(op_type, output_tensors, input_tensor_id,
-                                    input_dtypes, backward_function_getter,
-                                    backward_function_deleter);
+    call_state_.top().backward_tape->RecordOperation(
+        op_type, output_tensors, input_tensor_id, input_dtypes,
+        backward_function_getter, backward_function_deleter);
     return Status::OK();
   }
   if (!ShouldRecord(input_tensor_id, input_dtypes)) {
@@ -1012,9 +1034,8 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
 
   // Avoid infinite recursion. Whichever forward function we run, it'll end up
   // executing ops, and we don't want to watch those with this accumulator.
-  accumulating_ = true;
-  auto reset_accumulating =
-      gtl::MakeCleanup([this] { this->accumulating_ = false; });
+  call_state_.emplace(nullptr, true);
+  auto pop_call_state = gtl::MakeCleanup([this] { this->call_state_.pop(); });
 
   std::vector<Gradient*> forward_grads;
   if (forward_function == nullptr) {
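
With the stack in place, every recording decision in the tape header above consults the top entry, and Accumulate brackets its own work by pushing a busy state and popping it on exit. The rough standalone sketch below mirrors that dispatch under simplified, assumed types (`TapeStub` stands in for `GradientTape`, and a local scope guard stands in for `gtl::MakeCleanup`); it is illustrative only, not the real implementation:

// Illustrative sketch only; mirrors the ShouldRecord/Accumulate changes above.
#include <iostream>
#include <stack>
#include <vector>

struct TapeStub {
  bool ShouldRecord() const { return true; }  // placeholder for the real tape
};

struct CallState {
  CallState(TapeStub* backward_tape, bool accumulating)
      : backward_tape(backward_tape), accumulating(accumulating) {}
  TapeStub* backward_tape;
  bool accumulating;
};

class SketchAccumulator {
 public:
  SketchAccumulator() { call_state_.emplace(nullptr, false); }

  bool ShouldRecord() const {
    const CallState& top = call_state_.top();
    if (top.backward_tape != nullptr) {
      // An outer Accumulate is tracing a backward function: delegate.
      return top.backward_tape->ShouldRecord();
    }
    // Ignore ops triggered by this accumulator's own jvp computation.
    return !top.accumulating;
  }

  void Accumulate() {
    call_state_.emplace(nullptr, true);
    // Scope guard standing in for gtl::MakeCleanup: pop on every exit path.
    struct Popper {
      std::stack<CallState, std::vector<CallState>>* stack;
      ~Popper() { stack->pop(); }
    } popper{&call_state_};
    std::cout << "records its own ops? " << ShouldRecord() << "\n";  // 0
  }

 private:
  std::stack<CallState, std::vector<CallState>> call_state_;
};

int main() {
  SketchAccumulator acc;
  std::cout << "records before Accumulate? " << acc.ShouldRecord() << "\n";  // 1
  acc.Accumulate();
  std::cout << "records after Accumulate? " << acc.ShouldRecord() << "\n";   // 1
  return 0;
}

Both the backward-tape delegation and the accumulating check read only the top of the stack, which is why pushing a fresh state makes previously ignored operations visible to the accumulator again.
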
@@ -67,6 +67,7 @@ py_library(
         ":execute",
         ":execution_callbacks",
         ":forwardprop",
+        ":forwardprop_util",
         ":function",
         ":graph_only_ops",
         ":monitoring",
@@ -270,6 +271,7 @@ cuda_py_test(
     srcs = ["forwardprop_test.py"],
     additional_deps = [
         ":forwardprop",
+        ":forwardprop_util",
         ":test",
     ],
     shard_count = 5,
@@ -529,6 +531,16 @@ py_library(
     ],
 )
 
+py_library(
+    name = "forwardprop_util",
+    srcs = ["forwardprop_util.py"],
+    srcs_version = "PY2AND3",
+    visibility = ["//tensorflow:internal"],
+    deps = [
+        "//tensorflow/python:pywrap_tensorflow",
+    ],
+)
+
 cuda_py_test(
     name = "benchmarks_test",
     srcs = ["benchmarks_test.py"],
@@ -27,6 +27,7 @@ from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import forwardprop
+from tensorflow.python.eager import forwardprop_util
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
@@ -274,6 +275,31 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
     with self.assertRaisesRegexp(ValueError, "test_error_string"):
       f(c)
 
+  def testPushPopAccumulatorState(self):
+    # Note that this example is somewhat contrived. push_forwardprop_state is
+    # probably only useful in practice for building functions that compute jvps
+    # alongside their usual outputs.
+    with forwardprop.ForwardGradientAccumulator() as acc:
+
+      @custom_gradient.custom_gradient
+      def f(x):
+        y = math_ops.sin(x.numpy())
+
+        def grad(dy):
+          with forwardprop_util.push_forwardprop_state():
+            x_copy = constant_op.constant(x.numpy())
+            acc.watch(x_copy, dy)
+            y_copy = math_ops.sin(x_copy)
+          return dy * acc.jvp(y_copy)
+
+        return y, grad
+
+      c = constant_op.constant(1.)
+      d = constant_op.constant(2.)
+      acc.watch(c, d)
+      output = f(c)
+      self.assertAllClose(d * math_ops.cos(c), acc.jvp(output))
+
   @parameterized.named_parameters(
       [("Order{}".format(order), order, expected)
        for order, expected in enumerate(_X11_35_DERIVATIVES)])
tensorflow/python/eager/forwardprop_util.py (new file, 47 lines)
@@ -0,0 +1,47 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for managing forward accumulators.
+
+A separate file from forwardprop.py so that functions can use these utilities.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+
+from tensorflow.python import pywrap_tensorflow
+
+
+@contextlib.contextmanager
+def push_forwardprop_state():
+  """Temporarily push or pop transient state for accumulators in the active set.
+
+  Allows an accumulator which is currently processing an operation to
+  temporarily reset its state. This is useful when building forwardprop versions
+  of functions, where an accumulator will trigger function building and then
+  must process captured symbolic tensors while building it. Without pushing and
+  popping, accumulators ignore operations executed as a direct result of their
+  own jvp computations.
+
+  Yields:
+    None (used for its side effect).
+  """
+  try:
+    pywrap_tensorflow.TFE_Py_ForwardAccumulatorPushState()
+    yield
+  finally:
+    pywrap_tensorflow.TFE_Py_ForwardAccumulatorPopState()
@@ -256,6 +256,17 @@ void TFE_Py_ForwardAccumulatorWatch(PyObject* accumulator, PyObject* tensor,
 // `accumulator`. Returns None if no JVP is available.
 PyObject* TFE_Py_ForwardAccumulatorJVP(PyObject* accumulator, PyObject* tensor);
 
+// Temporarily push or pop transient state for accumulators in the active set.
+//
+// Allows an accumulator which is currently processing an operation to
+// temporarily reset its state. This is useful when building forwardprop
+// versions of functions, where an accumulator will trigger function building
+// and then must process captured symbolic tensors while building it. Without
+// pushing and popping, accumulators ignore operations executed as a direct
+// result of their own jvp computations.
+PyObject* TFE_Py_ForwardAccumulatorPushState();
+PyObject* TFE_Py_ForwardAccumulatorPopState();
+
 // Collects state from all current forward accumulators related to `tensors`.
 //
 // This is useful for packing JVPs as function inputs before executing a
@@ -1640,6 +1640,22 @@ PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) {
   Py_RETURN_FALSE;
 }
 
+PyObject* TFE_Py_ForwardAccumulatorPushState() {
+  auto forward_accumulators = *GetAccumulatorSet();
+  for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
+    accumulator->accumulator->PushState();
+  }
+  Py_RETURN_NONE;
+}
+
+PyObject* TFE_Py_ForwardAccumulatorPopState() {
+  auto forward_accumulators = *GetAccumulatorSet();
+  for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
+    accumulator->accumulator->PopState();
+  }
+  Py_RETURN_NONE;
+}
+
 PyObject* TFE_Py_TapeSetPossibleGradientTypes(PyObject* tensors) {
   if (!TapeCouldPossiblyRecord(tensors)) {
     return GetPythonObjectFromInt(0);
@@ -88,6 +88,8 @@ limitations under the License.
 %rename("%s") TFE_Py_ForwardAccumulatorSetRemove;
 %rename("%s") TFE_Py_ForwardAccumulatorWatch;
 %rename("%s") TFE_Py_ForwardAccumulatorJVP;
+%rename("%s") TFE_Py_ForwardAccumulatorPushState;
+%rename("%s") TFE_Py_ForwardAccumulatorPopState;
 %rename("%s") TFE_Py_PackForwardGradients;
 %rename("%s") TFE_NewContextOptions;
 %rename("%s") TFE_ContextOptionsSetConfig;