Do not create an extra GradientTape in custom_gradient/recompute_grad. The extra tape leads the tf.function gradient code to believe that the user intends to compute higher order derivatives. This requires generating a forward function with all possible side outputs which is expensive. This also doesn't work well with control flow and causes the added test to fail.
Instead, we create a VariableWatcher object that keeps track of variables that have been accessed. PiperOrigin-RevId: 307157027 Change-Id: Ifd628b421dc725ad2366af2f6f63cf52dd1511e9
This commit is contained in:
parent
4164702c55
commit
65b4e47ece
@ -331,6 +331,7 @@ cuda_py_test(
|
||||
"//tensorflow/python:embedding_ops",
|
||||
"//tensorflow/python:layers",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:memory_checker",
|
||||
"//tensorflow/python:nn_grad",
|
||||
"//tensorflow/python:nn_ops",
|
||||
"//tensorflow/python:random_ops",
|
||||
@ -662,6 +663,7 @@ tf_py_test(
|
||||
deps = [
|
||||
":backprop",
|
||||
":context",
|
||||
":tape",
|
||||
":test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
|
@ -35,6 +35,7 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.framework.memory_checker import MemoryChecker
|
||||
from tensorflow.python.layers.pooling import max_pooling3d
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
@ -1532,6 +1533,39 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertIn('gradient_tape/my_scope/', op.name)
|
||||
self.assertEqual(num_sin_ops_found, 2)
|
||||
|
||||
@test_util.assert_no_new_pyobjects_executing_eagerly
|
||||
def testRecomputeGradWithNestedFunctionAndWhileLoop(self):
|
||||
|
||||
@custom_gradient.recompute_grad
|
||||
@def_function.function
|
||||
def outer(x):
|
||||
|
||||
@def_function.function
|
||||
def middle(y):
|
||||
|
||||
@def_function.function
|
||||
def inner(z):
|
||||
return z + 1
|
||||
|
||||
i = constant_op.constant(0.0)
|
||||
c = lambda y, i: i < 10.
|
||||
b = lambda y, i: (inner(y), i + 1.0)
|
||||
y, i = control_flow_ops.while_loop(c, b, [y, i])
|
||||
|
||||
return y
|
||||
|
||||
return middle(x)
|
||||
|
||||
with MemoryChecker() as memory_checker:
|
||||
for _ in range(5):
|
||||
x = variables.Variable(1.0, name='x')
|
||||
with backprop.GradientTape():
|
||||
y = outer(x)
|
||||
self.assertAllEqual(y, 11.0)
|
||||
|
||||
memory_checker.report()
|
||||
memory_checker.assert_no_leak_if_all_possibly_except_one()
|
||||
|
||||
|
||||
class JacobianTest(test.TestCase):
|
||||
|
||||
|
@ -331,6 +331,22 @@ PyObject* TFE_Py_ForwardAccumulatorPopState();
|
||||
// appended to `tensors`.
|
||||
PyObject* TFE_Py_PackJVPs(PyObject* tensors);
|
||||
|
||||
// Variable Watcher methods.
|
||||
|
||||
// Creates a new variable watcher and adds it to the set of active variable
|
||||
// watchers.
|
||||
PyObject* TFE_Py_VariableWatcherNew();
|
||||
|
||||
// Removes the passed variable watcher from the set of active variable watchers.
|
||||
void TFE_Py_VariableWatcherRemove(PyObject* variable_watcher);
|
||||
|
||||
// Notifies all variable watchers that a variable has been accessed.
|
||||
void TFE_Py_VariableWatcherVariableAccessed(PyObject* variable);
|
||||
|
||||
// Returns all variables watched by the given variable_watcher in the order
|
||||
// those variables were created.
|
||||
PyObject* TFE_Py_VariableWatcherWatchedVariables(PyObject* variable_watcher);
|
||||
|
||||
// Returns an EagerTensor of dimension [len(`tensors`)] containing
|
||||
// the `slice_dim`'th dimension of each tensor in `tensors`. In other words,
|
||||
// TFE_Py_TensorShapeSlice takes a slice of dimensions of tensors in
|
||||
|
@ -1375,38 +1375,24 @@ PyObject* PyTapeTensor::ZerosLike() const {
|
||||
return result;
|
||||
}
|
||||
|
||||
class GradientTape
|
||||
: public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
|
||||
PyTapeTensor> {
|
||||
// Keeps track of all variables that have been accessed during execution.
|
||||
class VariableWatcher {
|
||||
public:
|
||||
explicit GradientTape(bool persistent, bool watch_accessed_variables)
|
||||
: tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
|
||||
PyTapeTensor>(persistent),
|
||||
watch_accessed_variables_(watch_accessed_variables) {}
|
||||
VariableWatcher() {}
|
||||
|
||||
virtual ~GradientTape() {
|
||||
~VariableWatcher() {
|
||||
for (const IdAndVariable& v : watched_variables_) {
|
||||
Py_DECREF(v.variable);
|
||||
}
|
||||
}
|
||||
|
||||
void VariableAccessed(PyObject* v) {
|
||||
if (watch_accessed_variables_) {
|
||||
WatchVariable(v);
|
||||
}
|
||||
}
|
||||
|
||||
void WatchVariable(PyObject* v) {
|
||||
tensorflow::int64 WatchVariable(PyObject* v) {
|
||||
tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
|
||||
if (handle == nullptr) {
|
||||
return;
|
||||
return -1;
|
||||
}
|
||||
tensorflow::int64 id = FastTensorId(handle.get());
|
||||
|
||||
if (!PyErr_Occurred()) {
|
||||
this->Watch(id);
|
||||
}
|
||||
|
||||
tensorflow::mutex_lock l(watched_variables_mu_);
|
||||
auto insert_result = watched_variables_.emplace(id, v);
|
||||
|
||||
@ -1415,6 +1401,8 @@ class GradientTape
|
||||
// variable.
|
||||
Py_INCREF(v);
|
||||
}
|
||||
|
||||
return id;
|
||||
}
|
||||
|
||||
PyObject* GetVariablesAsPyTuple() {
|
||||
@ -1445,12 +1433,45 @@ class GradientTape
|
||||
}
|
||||
};
|
||||
|
||||
bool watch_accessed_variables_;
|
||||
tensorflow::mutex watched_variables_mu_;
|
||||
std::set<IdAndVariable, CompareById> watched_variables_
|
||||
TF_GUARDED_BY(watched_variables_mu_);
|
||||
};
|
||||
|
||||
class GradientTape
|
||||
: public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
|
||||
PyTapeTensor> {
|
||||
public:
|
||||
explicit GradientTape(bool persistent, bool watch_accessed_variables)
|
||||
: tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
|
||||
PyTapeTensor>(persistent),
|
||||
watch_accessed_variables_(watch_accessed_variables) {}
|
||||
|
||||
virtual ~GradientTape() {}
|
||||
|
||||
void VariableAccessed(PyObject* v) {
|
||||
if (watch_accessed_variables_) {
|
||||
WatchVariable(v);
|
||||
}
|
||||
}
|
||||
|
||||
void WatchVariable(PyObject* v) {
|
||||
tensorflow::int64 id = variable_watcher_.WatchVariable(v);
|
||||
|
||||
if (!PyErr_Occurred()) {
|
||||
this->Watch(id);
|
||||
}
|
||||
}
|
||||
|
||||
PyObject* GetVariablesAsPyTuple() {
|
||||
return variable_watcher_.GetVariablesAsPyTuple();
|
||||
}
|
||||
|
||||
private:
|
||||
bool watch_accessed_variables_;
|
||||
VariableWatcher variable_watcher_;
|
||||
};
|
||||
|
||||
typedef tensorflow::eager::ForwardAccumulator<PyObject, PyBackwardFunction,
|
||||
PyTapeTensor>
|
||||
ForwardAccumulator;
|
||||
@ -1535,6 +1556,41 @@ static PyTypeObject TFE_Py_ForwardAccumulator_Type = {
|
||||
"TFE_Py_ForwardAccumulator objects", /* tp_doc */
|
||||
};
|
||||
|
||||
typedef struct {
|
||||
PyObject_HEAD
|
||||
/* Type-specific fields go here. */
|
||||
VariableWatcher* variable_watcher;
|
||||
} TFE_Py_VariableWatcher;
|
||||
|
||||
static void TFE_Py_VariableWatcher_Delete(PyObject* variable_watcher) {
|
||||
delete reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)
|
||||
->variable_watcher;
|
||||
Py_TYPE(variable_watcher)->tp_free(variable_watcher);
|
||||
}
|
||||
|
||||
static PyTypeObject TFE_Py_VariableWatcher_Type = {
|
||||
PyVarObject_HEAD_INIT(nullptr, 0) "tfe.VariableWatcher", /* tp_name */
|
||||
sizeof(TFE_Py_VariableWatcher), /* tp_basicsize */
|
||||
0, /* tp_itemsize */
|
||||
&TFE_Py_VariableWatcher_Delete, /* tp_dealloc */
|
||||
0, /* tp_print */
|
||||
nullptr, /* tp_getattr */
|
||||
nullptr, /* tp_setattr */
|
||||
nullptr, /* tp_reserved */
|
||||
nullptr, /* tp_repr */
|
||||
nullptr, /* tp_as_number */
|
||||
nullptr, /* tp_as_sequence */
|
||||
nullptr, /* tp_as_mapping */
|
||||
nullptr, /* tp_hash */
|
||||
nullptr, /* tp_call */
|
||||
nullptr, /* tp_str */
|
||||
nullptr, /* tp_getattro */
|
||||
nullptr, /* tp_setattro */
|
||||
nullptr, /* tp_as_buffer */
|
||||
Py_TPFLAGS_DEFAULT, /* tp_flags */
|
||||
"TFE_Py_VariableWatcher objects", /* tp_doc */
|
||||
};
|
||||
|
||||
// Note: in the current design no mutex is needed here because of the python
|
||||
// GIL, which is always held when any TFE_Py_* methods are called. We should
|
||||
// revisit this if/when decide to not hold the GIL while manipulating the tape
|
||||
@ -1548,6 +1604,18 @@ tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* GetTapeSet() {
|
||||
return tape_set.get();
|
||||
}
|
||||
|
||||
tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>*
|
||||
GetVariableWatcherSet() {
|
||||
thread_local std::unique_ptr<
|
||||
tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>
|
||||
variable_watcher_set = nullptr;
|
||||
if (variable_watcher_set == nullptr) {
|
||||
variable_watcher_set.reset(
|
||||
new tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>);
|
||||
}
|
||||
return variable_watcher_set.get();
|
||||
}
|
||||
|
||||
// A linked hash set, where iteration is in insertion order.
|
||||
//
|
||||
// Nested accumulators rely on op recording happening in insertion order, so an
|
||||
@ -1670,6 +1738,16 @@ class SafeAccumulatorSet : public SafeSetCopy<AccumulatorSet> {
|
||||
}
|
||||
};
|
||||
|
||||
class SafeVariableWatcherSet
|
||||
: public SafeSetCopy<
|
||||
tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>> {
|
||||
public:
|
||||
SafeVariableWatcherSet()
|
||||
: SafeSetCopy<
|
||||
tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>(
|
||||
*GetVariableWatcherSet()) {}
|
||||
};
|
||||
|
||||
bool* ThreadTapeIsStopped() {
|
||||
thread_local bool thread_tape_is_stopped{false};
|
||||
return &thread_tape_is_stopped;
|
||||
@ -2037,6 +2115,36 @@ PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
|
||||
return reinterpret_cast<TFE_Py_Tape*>(tape)->tape->GetVariablesAsPyTuple();
|
||||
}
|
||||
|
||||
PyObject* TFE_Py_VariableWatcherNew() {
|
||||
TFE_Py_VariableWatcher_Type.tp_new = PyType_GenericNew;
|
||||
if (PyType_Ready(&TFE_Py_VariableWatcher_Type) < 0) return nullptr;
|
||||
TFE_Py_VariableWatcher* variable_watcher =
|
||||
PyObject_NEW(TFE_Py_VariableWatcher, &TFE_Py_VariableWatcher_Type);
|
||||
variable_watcher->variable_watcher = new VariableWatcher();
|
||||
Py_INCREF(variable_watcher);
|
||||
GetVariableWatcherSet()->insert(variable_watcher);
|
||||
return reinterpret_cast<PyObject*>(variable_watcher);
|
||||
}
|
||||
|
||||
void TFE_Py_VariableWatcherRemove(PyObject* variable_watcher) {
|
||||
auto* stack = GetVariableWatcherSet();
|
||||
stack->erase(reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher));
|
||||
// We kept a reference to the variable watcher in the set to ensure it
|
||||
// wouldn't get deleted under us; cleaning it up here.
|
||||
Py_DECREF(variable_watcher);
|
||||
}
|
||||
|
||||
void TFE_Py_VariableWatcherVariableAccessed(PyObject* variable) {
|
||||
for (TFE_Py_VariableWatcher* variable_watcher : SafeVariableWatcherSet()) {
|
||||
variable_watcher->variable_watcher->WatchVariable(variable);
|
||||
}
|
||||
}
|
||||
|
||||
PyObject* TFE_Py_VariableWatcherWatchedVariables(PyObject* variable_watcher) {
|
||||
return reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)
|
||||
->variable_watcher->GetVariablesAsPyTuple();
|
||||
}
|
||||
|
||||
namespace {
|
||||
std::vector<tensorflow::DataType> MakeTensorDtypeList(PyObject* tensors) {
|
||||
PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
|
||||
@ -3086,6 +3194,7 @@ void MaybeNotifyVariableAccessed(PyObject* input) {
|
||||
PyObject_GetAttrString(input, "_trainable"));
|
||||
if (trainable.get() == Py_False) return;
|
||||
TFE_Py_TapeVariableAccessed(input);
|
||||
TFE_Py_VariableWatcherVariableAccessed(input);
|
||||
}
|
||||
|
||||
bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
|
||||
|
@ -58,6 +58,36 @@ def watch(tape, tensor):
|
||||
pywrap_tfe.TFE_Py_TapeWatch(tape._tape, tensor) # pylint: disable=protected-access
|
||||
|
||||
|
||||
class VariableWatcher(object):
|
||||
"""A scope that tracks all trainable variable accesses within it.
|
||||
|
||||
This explicitly ignores variables that are not marked as trainable.
|
||||
|
||||
Sample usage:
|
||||
|
||||
var = tf.Variable(0.0)
|
||||
with VariableWatcher() as variable_watcher:
|
||||
var.assign_add(1.0)
|
||||
|
||||
assert variable_watcher.watched_variables == [var]
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._variable_watcher = None
|
||||
|
||||
def __enter__(self):
|
||||
self._variable_watcher = pywrap_tfe.TFE_Py_VariableWatcherNew()
|
||||
return self
|
||||
|
||||
def __exit__(self, typ, value, traceback):
|
||||
pywrap_tfe.TFE_Py_VariableWatcherRemove(self._variable_watcher)
|
||||
|
||||
def watched_variables(self):
|
||||
"""Returns a tuple of variables accessed under this scope."""
|
||||
return pywrap_tfe.TFE_Py_VariableWatcherWatchedVariables(
|
||||
self._variable_watcher)
|
||||
|
||||
|
||||
def watch_variable(tape, variable):
|
||||
"""Marks this variable to be watched by the given tape."""
|
||||
strategy, context = (
|
||||
@ -68,6 +98,7 @@ def watch_variable(tape, variable):
|
||||
variables = strategy.experimental_local_results(variable)
|
||||
for var in variables:
|
||||
pywrap_tfe.TFE_Py_TapeWatchVariable(tape._tape, var) # pylint: disable=protected-access
|
||||
pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(var)
|
||||
|
||||
|
||||
def variable_accessed(variable):
|
||||
@ -84,6 +115,7 @@ def variable_accessed(variable):
|
||||
variables = strategy.experimental_local_results(variable)
|
||||
for var in variables:
|
||||
pywrap_tfe.TFE_Py_TapeVariableAccessed(var)
|
||||
pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(var)
|
||||
|
||||
|
||||
def variables_accessed(variables):
|
||||
@ -107,6 +139,7 @@ def variables_accessed(variables):
|
||||
|
||||
for var in accessed:
|
||||
pywrap_tfe.TFE_Py_TapeVariableAccessed(var)
|
||||
pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(var)
|
||||
|
||||
|
||||
def pop_tape(tape):
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import tape
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -31,6 +32,7 @@ from tensorflow.python.ops import math_ops
|
||||
# Importing nn_grad for the registration functions.
|
||||
from tensorflow.python.ops import nn_grad # pylint: disable=unused-import
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import variables
|
||||
|
||||
|
||||
@custom_gradient.custom_gradient
|
||||
@ -166,5 +168,48 @@ class TapeTest(test.TestCase):
|
||||
self.assertAllEqual(g, 1.0)
|
||||
|
||||
|
||||
class VariableWatcherTest(test.TestCase):
|
||||
|
||||
def testBasic(self):
|
||||
var1 = variables.Variable(0.0)
|
||||
var2 = variables.Variable(1.0)
|
||||
with tape.VariableWatcher() as variable_watcher:
|
||||
var1.assign_add(1.0)
|
||||
var2.assign_add(2.0)
|
||||
|
||||
self.assertAllEqual(variable_watcher.watched_variables(), (var1, var2))
|
||||
|
||||
def testNonTrainableVariables(self):
|
||||
var1 = variables.Variable(0.0)
|
||||
var2 = variables.Variable(1.0, trainable=False)
|
||||
with tape.VariableWatcher() as variable_watcher:
|
||||
var1.assign_add(1.0)
|
||||
var2.assign_add(2.0)
|
||||
|
||||
self.assertAllEqual(variable_watcher.watched_variables(), (var1,))
|
||||
|
||||
def testMultipleScopes(self):
|
||||
var1 = variables.Variable(0.0)
|
||||
var2 = variables.Variable(1.0)
|
||||
with tape.VariableWatcher() as variable_watcher1:
|
||||
var1.assign_add(1.0)
|
||||
with tape.VariableWatcher() as variable_watcher2:
|
||||
var2.assign_add(2.0)
|
||||
|
||||
# variable_watcher1 should see both vars and variable_watcher2 only sees
|
||||
# var2
|
||||
self.assertAllEqual(variable_watcher1.watched_variables(), (var1, var2))
|
||||
self.assertAllEqual(variable_watcher2.watched_variables(), (var2,))
|
||||
|
||||
def testCreateVariables(self):
|
||||
with tape.VariableWatcher() as variable_watcher:
|
||||
var1 = variables.Variable(0.0)
|
||||
var2 = variables.Variable(1.0)
|
||||
var1.assign_add(1.0)
|
||||
var2.assign_add(2.0)
|
||||
|
||||
self.assertAllEqual(variable_watcher.watched_variables(), (var1, var2))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -315,7 +315,7 @@ def _graph_mode_decorator(f, args, kwargs):
|
||||
v.ref() for v in current_var_scope.global_variables() +
|
||||
current_var_scope.local_variables()
|
||||
])
|
||||
with backprop.GradientTape() as tape:
|
||||
with tape_lib.VariableWatcher() as variable_watcher:
|
||||
result, grad_fn = f(*args)
|
||||
after_vars = set([
|
||||
v.ref() for v in current_var_scope.global_variables() +
|
||||
@ -332,8 +332,9 @@ def _graph_mode_decorator(f, args, kwargs):
|
||||
# The variables that grad_fn needs to return gradients for are the set of
|
||||
# variables used that are *not* part of the inputs.
|
||||
inputs = args
|
||||
variables_in_tape = frozenset([v.ref() for v in tape.watched_variables()
|
||||
]) - frozenset(v.ref() for v in inputs)
|
||||
variables_in_tape = frozenset([
|
||||
v.ref() for v in variable_watcher.watched_variables()
|
||||
]) - frozenset(v.ref() for v in inputs)
|
||||
variables_in_subgraph = frozenset([
|
||||
v.ref()
|
||||
for v in get_dependent_variables(input_ops=inputs, output_ops=result)
|
||||
@ -405,14 +406,14 @@ def _graph_mode_decorator(f, args, kwargs):
|
||||
|
||||
def _eager_mode_decorator(f, args, kwargs):
|
||||
"""Implement custom gradient decorator for eager mode."""
|
||||
with backprop.GradientTape() as tape:
|
||||
with tape_lib.VariableWatcher() as variable_watcher:
|
||||
result, grad_fn = f(*args, **kwargs)
|
||||
all_inputs = list(args) + list(kwargs.values())
|
||||
# The variables that grad_fn needs to return gradients for are the set of
|
||||
# variables used that are *not* part of the inputs.
|
||||
variables = [
|
||||
v.deref() # pylint: disable=g-complex-comprehension
|
||||
for v in set(v.ref() for v in tape.watched_variables())
|
||||
for v in set(v.ref() for v in variable_watcher.watched_variables())
|
||||
if all(v.deref() is not i for i in all_inputs)
|
||||
]
|
||||
grad_argspec = tf_inspect.getfullargspec(grad_fn)
|
||||
|
@ -665,6 +665,23 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
||||
return tensorflow::PyoOrThrow(TFE_Py_TapeWatchedVariables(tape.ptr()));
|
||||
});
|
||||
|
||||
// TFE_Py_VariableWatcher logic.
|
||||
m.def("TFE_Py_VariableWatcherNew",
|
||||
[]() { return tensorflow::PyoOrThrow(TFE_Py_VariableWatcherNew()); });
|
||||
m.def("TFE_Py_VariableWatcherRemove", [](const py::handle& variable_watcher) {
|
||||
TFE_Py_VariableWatcherRemove(variable_watcher.ptr());
|
||||
});
|
||||
m.def("TFE_Py_VariableWatcherVariableAccessed",
|
||||
[](const py::handle& variable) {
|
||||
TFE_Py_VariableWatcherVariableAccessed(variable.ptr());
|
||||
});
|
||||
m.def("TFE_Py_VariableWatcherWatchedVariables",
|
||||
[](const py::handle& variable_watcher) {
|
||||
return tensorflow::PyoOrThrow(
|
||||
TFE_Py_VariableWatcherWatchedVariables(variable_watcher.ptr()));
|
||||
});
|
||||
|
||||
// TFE_Py_ForwardAccumulator logic.
|
||||
m.def("TFE_Py_ForwardAccumulatorNew", []() {
|
||||
return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorNew());
|
||||
});
|
||||
|
@ -173,6 +173,10 @@ TFE_Py_TapeGradient
|
||||
TFE_Py_FastPathExecute_C
|
||||
TFE_Py_RecordGradient
|
||||
TFE_Py_TapeWatchedVariables
|
||||
TFE_Py_VariableWatcherNew
|
||||
TFE_Py_VariableWatcherRemove
|
||||
TFE_Py_VariableWatcherVariableAccessed
|
||||
TFE_Py_VariableWatcherWatchedVariables
|
||||
TFE_Py_ForwardAccumulatorNew
|
||||
TFE_Py_ForwardAccumulatorSetAdd
|
||||
TFE_Py_ForwardAccumulatorSetRemove
|
||||
|
@ -111,6 +111,7 @@ COMMON_PIP_DEPS = [
|
||||
"//tensorflow/python/distribute:distribute_test_lib_pip",
|
||||
"//tensorflow/python:loss_scale",
|
||||
"//tensorflow/python:loss_scale_optimizer",
|
||||
"//tensorflow/python:memory_checker",
|
||||
"//tensorflow/python:meta_graph_testdata",
|
||||
"//tensorflow/python:util_example_parser_configuration",
|
||||
"//tensorflow/python/data/benchmarks:benchmark_base",
|
||||
|
Loading…
Reference in New Issue
Block a user