Do not create an extra GradientTape in custom_gradient/recompute_grad. The extra tape leads the tf.function gradient code to believe that the user intends to compute higher-order derivatives, which requires generating a forward function with all possible side outputs and is therefore expensive. It also interacts poorly with control flow and causes the test added in this change to fail.

Instead, we create a VariableWatcher object that keeps track of variables that
have been accessed.

PiperOrigin-RevId: 307157027
Change-Id: Ifd628b421dc725ad2366af2f6f63cf52dd1511e9
Author: Rohan Jain, 2020-04-17 19:55:17 -07:00 (committed by TensorFlower Gardener)
Parent: 4164702c55
Commit: 65b4e47ece
10 changed files with 288 additions and 26 deletions
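
In short: custom_gradient/recompute_grad used to open a full backprop.GradientTape solely to learn which variables the wrapped function touched, and this change swaps that for the lightweight VariableWatcher scope added below. A minimal sketch of the resulting usage pattern (assumes TF 2.x with this change applied; VariableWatcher lives in the private tensorflow.python.eager.tape module, so this is illustrative rather than a public API):

import tensorflow as tf
from tensorflow.python.eager import tape as tape_lib

v = tf.Variable(1.0)

def f(x):
  return v * x  # reads v, so the watcher should record it

# Previously a `with backprop.GradientTape() as tape:` block was used here,
# which made tf.function prepare for higher-order derivatives. Now only
# variable accesses are tracked; no gradient-recording tape is created.
with tape_lib.VariableWatcher() as variable_watcher:
  y = f(tf.constant(2.0))

print(variable_watcher.watched_variables())  # expected: (v,)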


@@ -331,6 +331,7 @@ cuda_py_test(
"//tensorflow/python:embedding_ops",
"//tensorflow/python:layers",
"//tensorflow/python:math_ops",
"//tensorflow/python:memory_checker",
"//tensorflow/python:nn_grad",
"//tensorflow/python:nn_ops",
"//tensorflow/python:random_ops",
@@ -662,6 +663,7 @@ tf_py_test(
deps = [
":backprop",
":context",
":tape",
":test",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",


@@ -35,6 +35,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
from tensorflow.python.framework.memory_checker import MemoryChecker
from tensorflow.python.layers.pooling import max_pooling3d
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -1532,6 +1533,39 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
self.assertIn('gradient_tape/my_scope/', op.name)
self.assertEqual(num_sin_ops_found, 2)
@test_util.assert_no_new_pyobjects_executing_eagerly
def testRecomputeGradWithNestedFunctionAndWhileLoop(self):
@custom_gradient.recompute_grad
@def_function.function
def outer(x):
@def_function.function
def middle(y):
@def_function.function
def inner(z):
return z + 1
i = constant_op.constant(0.0)
c = lambda y, i: i < 10.
b = lambda y, i: (inner(y), i + 1.0)
y, i = control_flow_ops.while_loop(c, b, [y, i])
return y
return middle(x)
with MemoryChecker() as memory_checker:
for _ in range(5):
x = variables.Variable(1.0, name='x')
with backprop.GradientTape():
y = outer(x)
self.assertAllEqual(y, 11.0)
memory_checker.report()
memory_checker.assert_no_leak_if_all_possibly_except_one()
class JacobianTest(test.TestCase):


@@ -331,6 +331,22 @@ PyObject* TFE_Py_ForwardAccumulatorPopState();
// appended to `tensors`.
PyObject* TFE_Py_PackJVPs(PyObject* tensors);
// Variable Watcher methods.
// Creates a new variable watcher and adds it to the set of active variable
// watchers.
PyObject* TFE_Py_VariableWatcherNew();
// Removes the passed variable watcher from the set of active variable watchers.
void TFE_Py_VariableWatcherRemove(PyObject* variable_watcher);
// Notifies all variable watchers that a variable has been accessed.
void TFE_Py_VariableWatcherVariableAccessed(PyObject* variable);
// Returns all variables watched by the given variable_watcher in the order
// those variables were created.
PyObject* TFE_Py_VariableWatcherWatchedVariables(PyObject* variable_watcher);
// Returns an EagerTensor of dimension [len(`tensors`)] containing
// the `slice_dim`'th dimension of each tensor in `tensors`. In other words,
// TFE_Py_TensorShapeSlice takes a slice of dimensions of tensors in
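
The four VariableWatcher entry points declared above are exposed via pybind11 later in this change and driven from Python by the VariableWatcher context manager in tape.py. A rough sketch of that call sequence (error handling omitted; mirrors VariableWatcher.__enter__/__exit__ shown below):

from tensorflow.python import pywrap_tfe

watcher = pywrap_tfe.TFE_Py_VariableWatcherNew()     # create and register
try:
  # ... run code that accesses trainable variables; the runtime calls
  # TFE_Py_VariableWatcherVariableAccessed(variable) for each access ...
  watched = pywrap_tfe.TFE_Py_VariableWatcherWatchedVariables(watcher)
finally:
  pywrap_tfe.TFE_Py_VariableWatcherRemove(watcher)    # deregister and release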


@@ -1375,38 +1375,24 @@ PyObject* PyTapeTensor::ZerosLike() const {
return result;
}
class GradientTape
: public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
PyTapeTensor> {
// Keeps track of all variables that have been accessed during execution.
class VariableWatcher {
public:
explicit GradientTape(bool persistent, bool watch_accessed_variables)
: tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
PyTapeTensor>(persistent),
watch_accessed_variables_(watch_accessed_variables) {}
VariableWatcher() {}
virtual ~GradientTape() {
~VariableWatcher() {
for (const IdAndVariable& v : watched_variables_) {
Py_DECREF(v.variable);
}
}
void VariableAccessed(PyObject* v) {
if (watch_accessed_variables_) {
WatchVariable(v);
}
}
void WatchVariable(PyObject* v) {
tensorflow::int64 WatchVariable(PyObject* v) {
tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
if (handle == nullptr) {
return;
return -1;
}
tensorflow::int64 id = FastTensorId(handle.get());
if (!PyErr_Occurred()) {
this->Watch(id);
}
tensorflow::mutex_lock l(watched_variables_mu_);
auto insert_result = watched_variables_.emplace(id, v);
@@ -1415,6 +1401,8 @@ class GradientTape
// variable.
Py_INCREF(v);
}
return id;
}
PyObject* GetVariablesAsPyTuple() {
@@ -1445,12 +1433,45 @@ class GradientTape
}
};
bool watch_accessed_variables_;
tensorflow::mutex watched_variables_mu_;
std::set<IdAndVariable, CompareById> watched_variables_
TF_GUARDED_BY(watched_variables_mu_);
};
class GradientTape
: public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
PyTapeTensor> {
public:
explicit GradientTape(bool persistent, bool watch_accessed_variables)
: tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
PyTapeTensor>(persistent),
watch_accessed_variables_(watch_accessed_variables) {}
virtual ~GradientTape() {}
void VariableAccessed(PyObject* v) {
if (watch_accessed_variables_) {
WatchVariable(v);
}
}
void WatchVariable(PyObject* v) {
tensorflow::int64 id = variable_watcher_.WatchVariable(v);
if (!PyErr_Occurred()) {
this->Watch(id);
}
}
PyObject* GetVariablesAsPyTuple() {
return variable_watcher_.GetVariablesAsPyTuple();
}
private:
bool watch_accessed_variables_;
VariableWatcher variable_watcher_;
};
typedef tensorflow::eager::ForwardAccumulator<PyObject, PyBackwardFunction,
PyTapeTensor>
ForwardAccumulator;
@@ -1535,6 +1556,41 @@ static PyTypeObject TFE_Py_ForwardAccumulator_Type = {
"TFE_Py_ForwardAccumulator objects", /* tp_doc */
};
typedef struct {
PyObject_HEAD
/* Type-specific fields go here. */
VariableWatcher* variable_watcher;
} TFE_Py_VariableWatcher;
static void TFE_Py_VariableWatcher_Delete(PyObject* variable_watcher) {
delete reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)
->variable_watcher;
Py_TYPE(variable_watcher)->tp_free(variable_watcher);
}
static PyTypeObject TFE_Py_VariableWatcher_Type = {
PyVarObject_HEAD_INIT(nullptr, 0) "tfe.VariableWatcher", /* tp_name */
sizeof(TFE_Py_VariableWatcher), /* tp_basicsize */
0, /* tp_itemsize */
&TFE_Py_VariableWatcher_Delete, /* tp_dealloc */
0, /* tp_print */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
nullptr, /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
nullptr, /* tp_hash */
nullptr, /* tp_call */
nullptr, /* tp_str */
nullptr, /* tp_getattro */
nullptr, /* tp_setattro */
nullptr, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT, /* tp_flags */
"TFE_Py_VariableWatcher objects", /* tp_doc */
};
// Note: in the current design no mutex is needed here because of the python
// GIL, which is always held when any TFE_Py_* methods are called. We should
// revisit this if/when we decide to not hold the GIL while manipulating the tape
@@ -1548,6 +1604,18 @@ tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* GetTapeSet() {
return tape_set.get();
}
tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>*
GetVariableWatcherSet() {
thread_local std::unique_ptr<
tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>
variable_watcher_set = nullptr;
if (variable_watcher_set == nullptr) {
variable_watcher_set.reset(
new tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>);
}
return variable_watcher_set.get();
}
// A linked hash set, where iteration is in insertion order.
//
// Nested accumulators rely on op recording happening in insertion order, so an
@@ -1670,6 +1738,16 @@ class SafeAccumulatorSet : public SafeSetCopy<AccumulatorSet> {
}
};
class SafeVariableWatcherSet
: public SafeSetCopy<
tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>> {
public:
SafeVariableWatcherSet()
: SafeSetCopy<
tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>(
*GetVariableWatcherSet()) {}
};
bool* ThreadTapeIsStopped() {
thread_local bool thread_tape_is_stopped{false};
return &thread_tape_is_stopped;
@@ -2037,6 +2115,36 @@ PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
return reinterpret_cast<TFE_Py_Tape*>(tape)->tape->GetVariablesAsPyTuple();
}
PyObject* TFE_Py_VariableWatcherNew() {
TFE_Py_VariableWatcher_Type.tp_new = PyType_GenericNew;
if (PyType_Ready(&TFE_Py_VariableWatcher_Type) < 0) return nullptr;
TFE_Py_VariableWatcher* variable_watcher =
PyObject_NEW(TFE_Py_VariableWatcher, &TFE_Py_VariableWatcher_Type);
variable_watcher->variable_watcher = new VariableWatcher();
Py_INCREF(variable_watcher);
GetVariableWatcherSet()->insert(variable_watcher);
return reinterpret_cast<PyObject*>(variable_watcher);
}
void TFE_Py_VariableWatcherRemove(PyObject* variable_watcher) {
auto* stack = GetVariableWatcherSet();
stack->erase(reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher));
// We kept a reference to the variable watcher in the set to ensure it
// wouldn't get deleted under us; clean it up here.
Py_DECREF(variable_watcher);
}
void TFE_Py_VariableWatcherVariableAccessed(PyObject* variable) {
for (TFE_Py_VariableWatcher* variable_watcher : SafeVariableWatcherSet()) {
variable_watcher->variable_watcher->WatchVariable(variable);
}
}
PyObject* TFE_Py_VariableWatcherWatchedVariables(PyObject* variable_watcher) {
return reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)
->variable_watcher->GetVariablesAsPyTuple();
}
namespace {
std::vector<tensorflow::DataType> MakeTensorDtypeList(PyObject* tensors) {
PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
@@ -3086,6 +3194,7 @@ void MaybeNotifyVariableAccessed(PyObject* input) {
PyObject_GetAttrString(input, "_trainable"));
if (trainable.get() == Py_False) return;
TFE_Py_TapeVariableAccessed(input);
TFE_Py_VariableWatcherVariableAccessed(input);
}
bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
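
Since MaybeNotifyVariableAccessed returns early for non-trainable variables before notifying tapes or variable watchers, a VariableWatcher only ever reports trainable variables. A small sketch of the observable behavior (assumes TF 2.x with this change; mirrors the tape_test.py cases below):

import tensorflow as tf
from tensorflow.python.eager import tape as tape_lib

v_train = tf.Variable(1.0)                   # trainable by default
v_fixed = tf.Variable(2.0, trainable=False)  # filtered out by the _trainable check

with tape_lib.VariableWatcher() as watcher:
  v_train.assign_add(1.0)
  v_fixed.assign_add(1.0)

print(watcher.watched_variables())  # expected: (v_train,)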


@@ -58,6 +58,36 @@ def watch(tape, tensor):
pywrap_tfe.TFE_Py_TapeWatch(tape._tape, tensor) # pylint: disable=protected-access
class VariableWatcher(object):
"""A scope that tracks all trainable variable accesses within it.
This explicitly ignores variables that are not marked as trainable.
Sample usage:
var = tf.Variable(0.0)
with VariableWatcher() as variable_watcher:
var.assign_add(1.0)
assert variable_watcher.watched_variables() == (var,)
"""
def __init__(self):
self._variable_watcher = None
def __enter__(self):
self._variable_watcher = pywrap_tfe.TFE_Py_VariableWatcherNew()
return self
def __exit__(self, typ, value, traceback):
pywrap_tfe.TFE_Py_VariableWatcherRemove(self._variable_watcher)
def watched_variables(self):
"""Returns a tuple of variables accessed under this scope."""
return pywrap_tfe.TFE_Py_VariableWatcherWatchedVariables(
self._variable_watcher)
def watch_variable(tape, variable):
"""Marks this variable to be watched by the given tape."""
strategy, context = (
@@ -68,6 +98,7 @@ def watch_variable(tape, variable):
variables = strategy.experimental_local_results(variable)
for var in variables:
pywrap_tfe.TFE_Py_TapeWatchVariable(tape._tape, var) # pylint: disable=protected-access
pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(var)
def variable_accessed(variable):
@@ -84,6 +115,7 @@ def variable_accessed(variable):
variables = strategy.experimental_local_results(variable)
for var in variables:
pywrap_tfe.TFE_Py_TapeVariableAccessed(var)
pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(var)
def variables_accessed(variables):
@@ -107,6 +139,7 @@ def variables_accessed(variables):
for var in accessed:
pywrap_tfe.TFE_Py_TapeVariableAccessed(var)
pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(var)
def pop_tape(tape):


@@ -21,6 +21,7 @@ from __future__ import print_function
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -31,6 +32,7 @@ from tensorflow.python.ops import math_ops
# Importing nn_grad for the registration functions.
from tensorflow.python.ops import nn_grad # pylint: disable=unused-import
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables
@custom_gradient.custom_gradient
@@ -166,5 +168,48 @@ class TapeTest(test.TestCase):
self.assertAllEqual(g, 1.0)
class VariableWatcherTest(test.TestCase):
def testBasic(self):
var1 = variables.Variable(0.0)
var2 = variables.Variable(1.0)
with tape.VariableWatcher() as variable_watcher:
var1.assign_add(1.0)
var2.assign_add(2.0)
self.assertAllEqual(variable_watcher.watched_variables(), (var1, var2))
def testNonTrainableVariables(self):
var1 = variables.Variable(0.0)
var2 = variables.Variable(1.0, trainable=False)
with tape.VariableWatcher() as variable_watcher:
var1.assign_add(1.0)
var2.assign_add(2.0)
self.assertAllEqual(variable_watcher.watched_variables(), (var1,))
def testMultipleScopes(self):
var1 = variables.Variable(0.0)
var2 = variables.Variable(1.0)
with tape.VariableWatcher() as variable_watcher1:
var1.assign_add(1.0)
with tape.VariableWatcher() as variable_watcher2:
var2.assign_add(2.0)
# variable_watcher1 should see both vars, while variable_watcher2 only sees
# var2.
self.assertAllEqual(variable_watcher1.watched_variables(), (var1, var2))
self.assertAllEqual(variable_watcher2.watched_variables(), (var2,))
def testCreateVariables(self):
with tape.VariableWatcher() as variable_watcher:
var1 = variables.Variable(0.0)
var2 = variables.Variable(1.0)
var1.assign_add(1.0)
var2.assign_add(2.0)
self.assertAllEqual(variable_watcher.watched_variables(), (var1, var2))
if __name__ == '__main__':
test.main()


@@ -315,7 +315,7 @@ def _graph_mode_decorator(f, args, kwargs):
v.ref() for v in current_var_scope.global_variables() +
current_var_scope.local_variables()
])
with backprop.GradientTape() as tape:
with tape_lib.VariableWatcher() as variable_watcher:
result, grad_fn = f(*args)
after_vars = set([
v.ref() for v in current_var_scope.global_variables() +
@@ -332,8 +332,9 @@ def _graph_mode_decorator(f, args, kwargs):
# The variables that grad_fn needs to return gradients for are the set of
# variables used that are *not* part of the inputs.
inputs = args
variables_in_tape = frozenset([v.ref() for v in tape.watched_variables()
]) - frozenset(v.ref() for v in inputs)
variables_in_tape = frozenset([
v.ref() for v in variable_watcher.watched_variables()
]) - frozenset(v.ref() for v in inputs)
variables_in_subgraph = frozenset([
v.ref()
for v in get_dependent_variables(input_ops=inputs, output_ops=result)
@@ -405,14 +406,14 @@ def _graph_mode_decorator(f, args, kwargs):
def _eager_mode_decorator(f, args, kwargs):
"""Implement custom gradient decorator for eager mode."""
with backprop.GradientTape() as tape:
with tape_lib.VariableWatcher() as variable_watcher:
result, grad_fn = f(*args, **kwargs)
all_inputs = list(args) + list(kwargs.values())
# The variables that grad_fn needs to return gradients for are the set of
# variables used that are *not* part of the inputs.
variables = [
v.deref() # pylint: disable=g-complex-comprehension
for v in set(v.ref() for v in tape.watched_variables())
for v in set(v.ref() for v in variable_watcher.watched_variables())
if all(v.deref() is not i for i in all_inputs)
]
grad_argspec = tf_inspect.getfullargspec(grad_fn)
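
The input-filtering comprehension above is easy to misread; the sketch below (illustrative variable names, not the library code) shows what the ref()/deref() dance achieves: deduplicate watched variables by object identity, then drop any variable that was passed explicitly as an input, so grad_fn only has to return gradients for implicitly captured variables.

import tensorflow as tf

v_captured = tf.Variable(3.0)   # used inside f but not passed as an argument
v_input = tf.Variable(4.0)      # passed explicitly, so it must be filtered out

all_inputs = [v_input]
watched = (v_captured, v_input, v_captured)  # what a watcher might report

variables = [
    v.deref()
    for v in set(v.ref() for v in watched)          # dedupe by identity
    if all(v.deref() is not i for i in all_inputs)  # drop explicit inputs
]
assert variables == [v_captured]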


@@ -665,6 +665,23 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
return tensorflow::PyoOrThrow(TFE_Py_TapeWatchedVariables(tape.ptr()));
});
// TFE_Py_VariableWatcher logic.
m.def("TFE_Py_VariableWatcherNew",
[]() { return tensorflow::PyoOrThrow(TFE_Py_VariableWatcherNew()); });
m.def("TFE_Py_VariableWatcherRemove", [](const py::handle& variable_watcher) {
TFE_Py_VariableWatcherRemove(variable_watcher.ptr());
});
m.def("TFE_Py_VariableWatcherVariableAccessed",
[](const py::handle& variable) {
TFE_Py_VariableWatcherVariableAccessed(variable.ptr());
});
m.def("TFE_Py_VariableWatcherWatchedVariables",
[](const py::handle& variable_watcher) {
return tensorflow::PyoOrThrow(
TFE_Py_VariableWatcherWatchedVariables(variable_watcher.ptr()));
});
// TFE_Py_ForwardAccumulator logic.
m.def("TFE_Py_ForwardAccumulatorNew", []() {
return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorNew());
});


@@ -173,6 +173,10 @@ TFE_Py_TapeGradient
TFE_Py_FastPathExecute_C
TFE_Py_RecordGradient
TFE_Py_TapeWatchedVariables
TFE_Py_VariableWatcherNew
TFE_Py_VariableWatcherRemove
TFE_Py_VariableWatcherVariableAccessed
TFE_Py_VariableWatcherWatchedVariables
TFE_Py_ForwardAccumulatorNew
TFE_Py_ForwardAccumulatorSetAdd
TFE_Py_ForwardAccumulatorSetRemove


@@ -111,6 +111,7 @@ COMMON_PIP_DEPS = [
"//tensorflow/python/distribute:distribute_test_lib_pip",
"//tensorflow/python:loss_scale",
"//tensorflow/python:loss_scale_optimizer",
"//tensorflow/python:memory_checker",
"//tensorflow/python:meta_graph_testdata",
"//tensorflow/python:util_example_parser_configuration",
"//tensorflow/python/data/benchmarks:benchmark_base",