Do not create an extra GradientTape in custom_gradient/recompute_grad. The extra tape leads the tf.function gradient code to believe that the user intends to compute higher-order derivatives, which requires generating a forward function with all possible side outputs. That is expensive, interacts badly with control flow, and causes the added test to fail.

Instead, we create a VariableWatcher object that keeps track of the variables that have been accessed.

PiperOrigin-RevId: 307157027
Change-Id: Ifd628b421dc725ad2366af2f6f63cf52dd1511e9
parent 4164702c55
commit 65b4e47ece
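For orientation, below is a minimal sketch of the VariableWatcher scope this change introduces, based on the docstring and tests in the diff. The tensorflow.python.eager.tape module is TensorFlow-internal API, and the variable names are illustrative, not part of the commit:

# Sketch of the new VariableWatcher scope (internal API; names illustrative).
import tensorflow as tf
from tensorflow.python.eager import tape as tape_lib

var1 = tf.Variable(0.0)
var2 = tf.Variable(1.0, trainable=False)

with tape_lib.VariableWatcher() as variable_watcher:
  var1.assign_add(1.0)  # trainable access: recorded
  var2.assign_add(2.0)  # not trainable: ignored by the watcher

# Unlike GradientTape, nothing is recorded for a backward pass; the scope only
# reports which trainable variables were touched.
print(variable_watcher.watched_variables())  # (var1,)

Because the watcher only tracks variable accesses, custom_gradient and recompute_grad no longer need a nested GradientTape, which avoids generating the expensive higher-order forward function described above.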
@@ -331,6 +331,7 @@ cuda_py_test(
         "//tensorflow/python:embedding_ops",
         "//tensorflow/python:layers",
         "//tensorflow/python:math_ops",
+        "//tensorflow/python:memory_checker",
         "//tensorflow/python:nn_grad",
         "//tensorflow/python:nn_ops",
         "//tensorflow/python:random_ops",
@@ -662,6 +663,7 @@ tf_py_test(
     deps = [
         ":backprop",
         ":context",
+        ":tape",
         ":test",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:constant_op",
@@ -35,6 +35,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.framework import test_util
+from tensorflow.python.framework.memory_checker import MemoryChecker
 from tensorflow.python.layers.pooling import max_pooling3d
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
@@ -1532,6 +1533,39 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
         self.assertIn('gradient_tape/my_scope/', op.name)
     self.assertEqual(num_sin_ops_found, 2)
 
+  @test_util.assert_no_new_pyobjects_executing_eagerly
+  def testRecomputeGradWithNestedFunctionAndWhileLoop(self):
+
+    @custom_gradient.recompute_grad
+    @def_function.function
+    def outer(x):
+
+      @def_function.function
+      def middle(y):
+
+        @def_function.function
+        def inner(z):
+          return z + 1
+
+        i = constant_op.constant(0.0)
+        c = lambda y, i: i < 10.
+        b = lambda y, i: (inner(y), i + 1.0)
+        y, i = control_flow_ops.while_loop(c, b, [y, i])
+
+        return y
+
+      return middle(x)
+
+    with MemoryChecker() as memory_checker:
+      for _ in range(5):
+        x = variables.Variable(1.0, name='x')
+        with backprop.GradientTape():
+          y = outer(x)
+          self.assertAllEqual(y, 11.0)
+
+    memory_checker.report()
+    memory_checker.assert_no_leak_if_all_possibly_except_one()
+
 
 class JacobianTest(test.TestCase):
 
@@ -331,6 +331,22 @@ PyObject* TFE_Py_ForwardAccumulatorPopState();
 // appended to `tensors`.
 PyObject* TFE_Py_PackJVPs(PyObject* tensors);
 
+// Variable Watcher methods.
+
+// Creates a new variable watcher and adds it to the set of active variable
+// watchers.
+PyObject* TFE_Py_VariableWatcherNew();
+
+// Removes the passed variable watcher from the set of active variable watchers.
+void TFE_Py_VariableWatcherRemove(PyObject* variable_watcher);
+
+// Notifies all variable watchers that a variable has been accessed.
+void TFE_Py_VariableWatcherVariableAccessed(PyObject* variable);
+
+// Returns all variables watched by the given variable_watcher in the order
+// those variables were created.
+PyObject* TFE_Py_VariableWatcherWatchedVariables(PyObject* variable_watcher);
+
 // Returns an EagerTensor of dimension [len(`tensors`)] containing
 // the `slice_dim`'th dimension of each tensor in `tensors`. In other words,
 // TFE_Py_TensorShapeSlice takes a slice of dimensions of tensors in
@@ -1375,38 +1375,24 @@ PyObject* PyTapeTensor::ZerosLike() const {
   return result;
 }
 
-class GradientTape
-    : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
-                                             PyTapeTensor> {
+// Keeps track of all variables that have been accessed during execution.
+class VariableWatcher {
  public:
-  explicit GradientTape(bool persistent, bool watch_accessed_variables)
-      : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
-                                        PyTapeTensor>(persistent),
-        watch_accessed_variables_(watch_accessed_variables) {}
+  VariableWatcher() {}
 
-  virtual ~GradientTape() {
+  ~VariableWatcher() {
     for (const IdAndVariable& v : watched_variables_) {
       Py_DECREF(v.variable);
     }
   }
 
-  void VariableAccessed(PyObject* v) {
-    if (watch_accessed_variables_) {
-      WatchVariable(v);
-    }
-  }
-
-  void WatchVariable(PyObject* v) {
+  tensorflow::int64 WatchVariable(PyObject* v) {
     tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
     if (handle == nullptr) {
-      return;
+      return -1;
     }
     tensorflow::int64 id = FastTensorId(handle.get());
 
-    if (!PyErr_Occurred()) {
-      this->Watch(id);
-    }
-
     tensorflow::mutex_lock l(watched_variables_mu_);
     auto insert_result = watched_variables_.emplace(id, v);
 
@@ -1415,6 +1401,8 @@ class GradientTape
       // variable.
       Py_INCREF(v);
     }
+
+    return id;
   }
 
   PyObject* GetVariablesAsPyTuple() {
@@ -1445,12 +1433,45 @@ class GradientTape
     }
   };
 
-  bool watch_accessed_variables_;
   tensorflow::mutex watched_variables_mu_;
   std::set<IdAndVariable, CompareById> watched_variables_
       TF_GUARDED_BY(watched_variables_mu_);
 };
 
+class GradientTape
+    : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
+                                             PyTapeTensor> {
+ public:
+  explicit GradientTape(bool persistent, bool watch_accessed_variables)
+      : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
+                                        PyTapeTensor>(persistent),
+        watch_accessed_variables_(watch_accessed_variables) {}
+
+  virtual ~GradientTape() {}
+
+  void VariableAccessed(PyObject* v) {
+    if (watch_accessed_variables_) {
+      WatchVariable(v);
+    }
+  }
+
+  void WatchVariable(PyObject* v) {
+    tensorflow::int64 id = variable_watcher_.WatchVariable(v);
+
+    if (!PyErr_Occurred()) {
+      this->Watch(id);
+    }
+  }
+
+  PyObject* GetVariablesAsPyTuple() {
+    return variable_watcher_.GetVariablesAsPyTuple();
+  }
+
+ private:
+  bool watch_accessed_variables_;
+  VariableWatcher variable_watcher_;
+};
+
 typedef tensorflow::eager::ForwardAccumulator<PyObject, PyBackwardFunction,
                                               PyTapeTensor>
     ForwardAccumulator;
@@ -1535,6 +1556,41 @@ static PyTypeObject TFE_Py_ForwardAccumulator_Type = {
     "TFE_Py_ForwardAccumulator objects", /* tp_doc */
 };
 
+typedef struct {
+  PyObject_HEAD
+      /* Type-specific fields go here. */
+      VariableWatcher* variable_watcher;
+} TFE_Py_VariableWatcher;
+
+static void TFE_Py_VariableWatcher_Delete(PyObject* variable_watcher) {
+  delete reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)
+      ->variable_watcher;
+  Py_TYPE(variable_watcher)->tp_free(variable_watcher);
+}
+
+static PyTypeObject TFE_Py_VariableWatcher_Type = {
+    PyVarObject_HEAD_INIT(nullptr, 0) "tfe.VariableWatcher", /* tp_name */
+    sizeof(TFE_Py_VariableWatcher),  /* tp_basicsize */
+    0,                               /* tp_itemsize */
+    &TFE_Py_VariableWatcher_Delete,  /* tp_dealloc */
+    0,                               /* tp_print */
+    nullptr,                         /* tp_getattr */
+    nullptr,                         /* tp_setattr */
+    nullptr,                         /* tp_reserved */
+    nullptr,                         /* tp_repr */
+    nullptr,                         /* tp_as_number */
+    nullptr,                         /* tp_as_sequence */
+    nullptr,                         /* tp_as_mapping */
+    nullptr,                         /* tp_hash */
+    nullptr,                         /* tp_call */
+    nullptr,                         /* tp_str */
+    nullptr,                         /* tp_getattro */
+    nullptr,                         /* tp_setattro */
+    nullptr,                         /* tp_as_buffer */
+    Py_TPFLAGS_DEFAULT,              /* tp_flags */
+    "TFE_Py_VariableWatcher objects", /* tp_doc */
+};
+
 // Note: in the current design no mutex is needed here because of the python
 // GIL, which is always held when any TFE_Py_* methods are called. We should
 // revisit this if/when decide to not hold the GIL while manipulating the tape
@@ -1548,6 +1604,18 @@ tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* GetTapeSet() {
   return tape_set.get();
 }
 
+tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>*
+GetVariableWatcherSet() {
+  thread_local std::unique_ptr<
+      tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>
+      variable_watcher_set = nullptr;
+  if (variable_watcher_set == nullptr) {
+    variable_watcher_set.reset(
+        new tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>);
+  }
+  return variable_watcher_set.get();
+}
+
 // A linked hash set, where iteration is in insertion order.
 //
 // Nested accumulators rely on op recording happening in insertion order, so an
@@ -1670,6 +1738,16 @@ class SafeAccumulatorSet : public SafeSetCopy<AccumulatorSet> {
   }
 };
 
+class SafeVariableWatcherSet
+    : public SafeSetCopy<
+          tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>> {
+ public:
+  SafeVariableWatcherSet()
+      : SafeSetCopy<
+            tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>(
+            *GetVariableWatcherSet()) {}
+};
+
 bool* ThreadTapeIsStopped() {
   thread_local bool thread_tape_is_stopped{false};
   return &thread_tape_is_stopped;
@@ -2037,6 +2115,36 @@ PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
   return reinterpret_cast<TFE_Py_Tape*>(tape)->tape->GetVariablesAsPyTuple();
 }
 
+PyObject* TFE_Py_VariableWatcherNew() {
+  TFE_Py_VariableWatcher_Type.tp_new = PyType_GenericNew;
+  if (PyType_Ready(&TFE_Py_VariableWatcher_Type) < 0) return nullptr;
+  TFE_Py_VariableWatcher* variable_watcher =
+      PyObject_NEW(TFE_Py_VariableWatcher, &TFE_Py_VariableWatcher_Type);
+  variable_watcher->variable_watcher = new VariableWatcher();
+  Py_INCREF(variable_watcher);
+  GetVariableWatcherSet()->insert(variable_watcher);
+  return reinterpret_cast<PyObject*>(variable_watcher);
+}
+
+void TFE_Py_VariableWatcherRemove(PyObject* variable_watcher) {
+  auto* stack = GetVariableWatcherSet();
+  stack->erase(reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher));
+  // We kept a reference to the variable watcher in the set to ensure it
+  // wouldn't get deleted under us; cleaning it up here.
+  Py_DECREF(variable_watcher);
+}
+
+void TFE_Py_VariableWatcherVariableAccessed(PyObject* variable) {
+  for (TFE_Py_VariableWatcher* variable_watcher : SafeVariableWatcherSet()) {
+    variable_watcher->variable_watcher->WatchVariable(variable);
+  }
+}
+
+PyObject* TFE_Py_VariableWatcherWatchedVariables(PyObject* variable_watcher) {
+  return reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)
+      ->variable_watcher->GetVariablesAsPyTuple();
+}
+
 namespace {
 std::vector<tensorflow::DataType> MakeTensorDtypeList(PyObject* tensors) {
   PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
@@ -3086,6 +3194,7 @@ void MaybeNotifyVariableAccessed(PyObject* input) {
       PyObject_GetAttrString(input, "_trainable"));
   if (trainable.get() == Py_False) return;
   TFE_Py_TapeVariableAccessed(input);
+  TFE_Py_VariableWatcherVariableAccessed(input);
 }
 
 bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
@@ -58,6 +58,36 @@ def watch(tape, tensor):
   pywrap_tfe.TFE_Py_TapeWatch(tape._tape, tensor)  # pylint: disable=protected-access
 
 
+class VariableWatcher(object):
+  """A scope that tracks all trainable variable accesses within it.
+
+  This explicitly ignores variables that are not marked as trainable.
+
+  Sample usage:
+
+  var = tf.Variable(0.0)
+  with VariableWatcher() as variable_watcher:
+    var.assign_add(1.0)
+
+  assert variable_watcher.watched_variables == [var]
+  """
+
+  def __init__(self):
+    self._variable_watcher = None
+
+  def __enter__(self):
+    self._variable_watcher = pywrap_tfe.TFE_Py_VariableWatcherNew()
+    return self
+
+  def __exit__(self, typ, value, traceback):
+    pywrap_tfe.TFE_Py_VariableWatcherRemove(self._variable_watcher)
+
+  def watched_variables(self):
+    """Returns a tuple of variables accessed under this scope."""
+    return pywrap_tfe.TFE_Py_VariableWatcherWatchedVariables(
+        self._variable_watcher)
+
+
 def watch_variable(tape, variable):
   """Marks this variable to be watched by the given tape."""
   strategy, context = (
@@ -68,6 +98,7 @@ def watch_variable(tape, variable):
     variables = strategy.experimental_local_results(variable)
   for var in variables:
     pywrap_tfe.TFE_Py_TapeWatchVariable(tape._tape, var)  # pylint: disable=protected-access
+    pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(var)
 
 
 def variable_accessed(variable):
@@ -84,6 +115,7 @@ def variable_accessed(variable):
     variables = strategy.experimental_local_results(variable)
   for var in variables:
     pywrap_tfe.TFE_Py_TapeVariableAccessed(var)
+    pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(var)
 
 
 def variables_accessed(variables):
@@ -107,6 +139,7 @@ def variables_accessed(variables):
 
   for var in accessed:
     pywrap_tfe.TFE_Py_TapeVariableAccessed(var)
+    pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(var)
 
 
 def pop_tape(tape):
@@ -21,6 +21,7 @@ from __future__ import print_function
 
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
+from tensorflow.python.eager import tape
 from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -31,6 +32,7 @@ from tensorflow.python.ops import math_ops
 # Importing nn_grad for the registration functions.
 from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
 from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import variables
 
 
 @custom_gradient.custom_gradient
@@ -166,5 +168,48 @@ class TapeTest(test.TestCase):
     self.assertAllEqual(g, 1.0)
 
 
+class VariableWatcherTest(test.TestCase):
+
+  def testBasic(self):
+    var1 = variables.Variable(0.0)
+    var2 = variables.Variable(1.0)
+    with tape.VariableWatcher() as variable_watcher:
+      var1.assign_add(1.0)
+      var2.assign_add(2.0)
+
+    self.assertAllEqual(variable_watcher.watched_variables(), (var1, var2))
+
+  def testNonTrainableVariables(self):
+    var1 = variables.Variable(0.0)
+    var2 = variables.Variable(1.0, trainable=False)
+    with tape.VariableWatcher() as variable_watcher:
+      var1.assign_add(1.0)
+      var2.assign_add(2.0)
+
+    self.assertAllEqual(variable_watcher.watched_variables(), (var1,))
+
+  def testMultipleScopes(self):
+    var1 = variables.Variable(0.0)
+    var2 = variables.Variable(1.0)
+    with tape.VariableWatcher() as variable_watcher1:
+      var1.assign_add(1.0)
+      with tape.VariableWatcher() as variable_watcher2:
+        var2.assign_add(2.0)
+
+    # variable_watcher1 should see both vars and variable_watcher2 only sees
+    # var2
+    self.assertAllEqual(variable_watcher1.watched_variables(), (var1, var2))
+    self.assertAllEqual(variable_watcher2.watched_variables(), (var2,))
+
+  def testCreateVariables(self):
+    with tape.VariableWatcher() as variable_watcher:
+      var1 = variables.Variable(0.0)
+      var2 = variables.Variable(1.0)
+      var1.assign_add(1.0)
+      var2.assign_add(2.0)
+
+    self.assertAllEqual(variable_watcher.watched_variables(), (var1, var2))
+
+
 if __name__ == '__main__':
   test.main()
@@ -315,7 +315,7 @@ def _graph_mode_decorator(f, args, kwargs):
       v.ref() for v in current_var_scope.global_variables() +
       current_var_scope.local_variables()
   ])
-  with backprop.GradientTape() as tape:
+  with tape_lib.VariableWatcher() as variable_watcher:
     result, grad_fn = f(*args)
   after_vars = set([
       v.ref() for v in current_var_scope.global_variables() +
@@ -332,8 +332,9 @@ def _graph_mode_decorator(f, args, kwargs):
   # The variables that grad_fn needs to return gradients for are the set of
   # variables used that are *not* part of the inputs.
   inputs = args
-  variables_in_tape = frozenset([v.ref() for v in tape.watched_variables()
-                                ]) - frozenset(v.ref() for v in inputs)
+  variables_in_tape = frozenset([
+      v.ref() for v in variable_watcher.watched_variables()
+  ]) - frozenset(v.ref() for v in inputs)
   variables_in_subgraph = frozenset([
       v.ref()
       for v in get_dependent_variables(input_ops=inputs, output_ops=result)
@@ -405,14 +406,14 @@ def _graph_mode_decorator(f, args, kwargs):
 
 def _eager_mode_decorator(f, args, kwargs):
   """Implement custom gradient decorator for eager mode."""
-  with backprop.GradientTape() as tape:
+  with tape_lib.VariableWatcher() as variable_watcher:
     result, grad_fn = f(*args, **kwargs)
   all_inputs = list(args) + list(kwargs.values())
   # The variables that grad_fn needs to return gradients for are the set of
   # variables used that are *not* part of the inputs.
   variables = [
       v.deref()  # pylint: disable=g-complex-comprehension
-      for v in set(v.ref() for v in tape.watched_variables())
+      for v in set(v.ref() for v in variable_watcher.watched_variables())
       if all(v.deref() is not i for i in all_inputs)
   ]
   grad_argspec = tf_inspect.getfullargspec(grad_fn)
@@ -665,6 +665,23 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
     return tensorflow::PyoOrThrow(TFE_Py_TapeWatchedVariables(tape.ptr()));
   });
 
+  // TFE_Py_VariableWatcher logic.
+  m.def("TFE_Py_VariableWatcherNew",
+        []() { return tensorflow::PyoOrThrow(TFE_Py_VariableWatcherNew()); });
+  m.def("TFE_Py_VariableWatcherRemove", [](const py::handle& variable_watcher) {
+    TFE_Py_VariableWatcherRemove(variable_watcher.ptr());
+  });
+  m.def("TFE_Py_VariableWatcherVariableAccessed",
+        [](const py::handle& variable) {
+          TFE_Py_VariableWatcherVariableAccessed(variable.ptr());
+        });
+  m.def("TFE_Py_VariableWatcherWatchedVariables",
+        [](const py::handle& variable_watcher) {
+          return tensorflow::PyoOrThrow(
+              TFE_Py_VariableWatcherWatchedVariables(variable_watcher.ptr()));
+        });
+
+  // TFE_Py_ForwardAccumulator logic.
   m.def("TFE_Py_ForwardAccumulatorNew", []() {
     return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorNew());
   });
@@ -173,6 +173,10 @@ TFE_Py_TapeGradient
 TFE_Py_FastPathExecute_C
 TFE_Py_RecordGradient
 TFE_Py_TapeWatchedVariables
+TFE_Py_VariableWatcherNew
+TFE_Py_VariableWatcherRemove
+TFE_Py_VariableWatcherVariableAccessed
+TFE_Py_VariableWatcherWatchedVariables
 TFE_Py_ForwardAccumulatorNew
 TFE_Py_ForwardAccumulatorSetAdd
 TFE_Py_ForwardAccumulatorSetRemove
@@ -111,6 +111,7 @@ COMMON_PIP_DEPS = [
     "//tensorflow/python/distribute:distribute_test_lib_pip",
     "//tensorflow/python:loss_scale",
     "//tensorflow/python:loss_scale_optimizer",
+    "//tensorflow/python:memory_checker",
    "//tensorflow/python:meta_graph_testdata",
    "//tensorflow/python:util_example_parser_configuration",
    "//tensorflow/python/data/benchmarks:benchmark_base",