Use a tf.function to more efficiently compute op jvps
Allows the unused backward computation to be pruned out. Does not change custom_gradient or function forward-mode computations. Some fiddling with the memory checking on the unit tests, since tf.function creates persistent symbolic Tensors the first time it's called. This means we need to do warmup runs and ignore Tensors allocated there. Forward gradients still need some followups after this: - Functions should have a special-cased forward function so that they're efficient when executing eagerly. - Watching variables on an accumulator should be possible After those the remaining case is custom gradients, which are probably fine to leave as they are for now (they work, they're just a bit less efficient than they could be if the user provided a jvp or told us the code was safe to wrap in a tf.function). From //tensorflow/python/eager:benchmarks_test: benchmark_forwardprop_in_defun_matmul_256_by_2096_CPU 487 examples/second no change benchmark_forwardprop_matmul_256_by_2096_CPU 406 examples/second 1.6x speedup benchmark_forwardprop_of_defun_matmul_256_by_2096_CPU 176 examples/second no change benchmark_forwardprop_in_defun_matmul_100_by_784_CPU 2872 examples/second no change benchmark_forwardprop_matmul_100_by_784_CPU 1766 examples/second 1.4x speedup benchmark_forwardprop_of_defun_matmul_100_by_784_CPU 909 examples/second no change PiperOrigin-RevId: 257832992
This commit is contained in:
parent
d329b403ed
commit
acab6a2051
@ -212,6 +212,11 @@ class ForwardAccumulator {
|
||||
// Tensor associated with `tensor_id` is deleted.
|
||||
void DeleteGradient(int64 tensor_id);
|
||||
|
||||
// Describes a callback for special-cased and more efficient jvp computation.
|
||||
typedef std::function<Status(const std::vector<Gradient*>&,
|
||||
std::vector<Gradient*>*)>
|
||||
ForwardFunction;
|
||||
|
||||
// Runs forward autodiff. Should be called whenever a new operation is
|
||||
// available and the accumulator is active.
|
||||
//
|
||||
@ -222,6 +227,12 @@ class ForwardAccumulator {
|
||||
// between calls to ShouldRecord and Accumulator), and its outputs
|
||||
// (`output_tensors`).
|
||||
//
|
||||
// If provided, a non-null `forward_function` will be used instead of the
|
||||
// backward function (`backward_function_getter` /
|
||||
// `backward_function_deleter`) to compute jvps for this operation. If
|
||||
// `forward_function` is null, a GradientTape is used on the backward function
|
||||
// to compute the jvp, which will waste computation when executing eagerly.
|
||||
//
|
||||
// Unlike GradientTape::RecordOperation, Accumulate runs gradient computation
|
||||
// immediately. It stores the results, which feed into Accumulate for future
|
||||
// operations and may be fetched by calling FetchJVP. ForwardAccumulator
|
||||
@ -237,6 +248,7 @@ class ForwardAccumulator {
|
||||
const std::vector<TapeTensor>& output_tensors,
|
||||
gtl::ArraySlice<int64> input_tensor_id,
|
||||
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
|
||||
const ForwardFunction* forward_function,
|
||||
const std::function<BackwardFunction*()>& backward_function_getter,
|
||||
const std::function<void(BackwardFunction*)>& backward_function_deleter);
|
||||
|
||||
@ -930,6 +942,8 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
|
||||
const std::vector<TapeTensor>& output_tensors,
|
||||
gtl::ArraySlice<int64> input_tensor_id,
|
||||
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
|
||||
const ForwardAccumulator<Gradient, BackwardFunction,
|
||||
TapeTensor>::ForwardFunction* forward_function,
|
||||
const std::function<BackwardFunction*()>& backward_function_getter,
|
||||
const std::function<void(BackwardFunction*)>& backward_function_deleter) {
|
||||
if (backward_tape_ != nullptr) {
|
||||
@ -981,23 +995,36 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
|
||||
}
|
||||
}
|
||||
|
||||
// Avoid infinite recursion. Whichever forward function we run, it'll end up
|
||||
// executing ops, and we don't want to watch those with this accumulator.
|
||||
accumulating_ = true;
|
||||
auto reset_accumulating =
|
||||
gtl::MakeCleanup([this] { this->accumulating_ = false; });
|
||||
|
||||
std::vector<Gradient*> forward_grads;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ForwardpropFromTape(output_tensors, backward_function_getter,
|
||||
backward_function_deleter, in_grads, &forward_grads));
|
||||
|
||||
if (forward_function == nullptr) {
|
||||
// We have no special-cased forward gradient. Fall back to running the
|
||||
// backward function under a gradient tape.
|
||||
TF_RETURN_IF_ERROR(ForwardpropFromTape(
|
||||
output_tensors, backward_function_getter, backward_function_deleter,
|
||||
in_grads, &forward_grads));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR((*forward_function)(in_grads, &forward_grads));
|
||||
}
|
||||
for (int i = 0; i < forward_grads.size(); ++i) {
|
||||
if (forward_grads[i] != nullptr) {
|
||||
int64 tensor_id = output_tensors[i].GetID();
|
||||
auto existing = accumulated_gradients_.find(tensor_id);
|
||||
if (existing != accumulated_gradients_.end()) {
|
||||
vspace_.DeleteGradient(existing->second);
|
||||
// This is a somewhat odd case to be in, since it means we have two
|
||||
// operations which supposedly both created the same Tensor. It comes up
|
||||
// in recompute_grad, where the gradients have the same value. However,
|
||||
// only the original gradient is connected to everything else, so we
|
||||
// should still use that.
|
||||
vspace_.DeleteGradient(forward_grads[i]);
|
||||
} else {
|
||||
accumulated_gradients_[output_tensors[i].GetID()] = forward_grads[i];
|
||||
}
|
||||
accumulated_gradients_[output_tensors[i].GetID()] = forward_grads[i];
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
|
@ -250,6 +250,7 @@ cuda_py_test(
|
||||
":forwardprop",
|
||||
":test",
|
||||
],
|
||||
shard_count = 5,
|
||||
xla_enable_strict_auto_jit = True,
|
||||
)
|
||||
|
||||
|
@ -21,13 +21,83 @@ from __future__ import print_function
|
||||
import functools
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import execute
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
# TODO(allenl): Special-case op gradients and tf.functions to avoid unnecessary
|
||||
# evaluation of gradient functions.
|
||||
# TODO(allenl): experimental_relax_shapes for gradients which rely on static
|
||||
# shape information may be underspecialized. We may want hand-written forward
|
||||
# implementations.
|
||||
@def_function.function(experimental_relax_shapes=True)
|
||||
def _forward_gradient(op_name, attr_tuple, inputs, outputs, tangents):
|
||||
"""Computes a Jacobian-vector product for an op.
|
||||
|
||||
Note that this function would be wasteful if executed eagerly. It runs the
|
||||
backward gradient function and throws away the result just to record its
|
||||
operations on a GradientTape. These unused ops are pruned away when this
|
||||
function is traced.
|
||||
|
||||
Args:
|
||||
op_name: A string, the type of operation being executed.
|
||||
attr_tuple: Attributes of the operation.
|
||||
inputs: A flat list of input Tensors to the operation.
|
||||
outputs: A flat list of output Tensors from the operation.
|
||||
tangents: A flat list of Tensors, same shape as `inputs`.
|
||||
|
||||
Returns:
|
||||
A flat list of tangents corresponding to `outputs`.
|
||||
"""
|
||||
float_inputs = []
|
||||
float_indices = []
|
||||
nontrivial_tangents = []
|
||||
for input_index, tensor in enumerate(inputs):
|
||||
if tensor.dtype.is_floating:
|
||||
float_inputs.append(tensor)
|
||||
float_indices.append(input_index)
|
||||
nontrivial_tangents.append(tangents[input_index])
|
||||
|
||||
with backprop.GradientTape() as transpose_tape:
|
||||
with backprop.GradientTape() as backfunc_tape:
|
||||
backfunc_tape.watch(float_inputs)
|
||||
execute.record_gradient(op_name, inputs, attr_tuple, outputs,
|
||||
"forward_op_replay")
|
||||
|
||||
forwardprop_aids = []
|
||||
float_outputs = []
|
||||
nontrivial_output_indices = []
|
||||
for output_index, output in enumerate(outputs):
|
||||
if output.dtype.is_floating:
|
||||
forwardprop_aids.append(
|
||||
array_ops.ones_like(output, name="unused_forwardprop_aid"))
|
||||
float_outputs.append(output)
|
||||
nontrivial_output_indices.append(output_index)
|
||||
|
||||
transpose_tape.watch(forwardprop_aids)
|
||||
grads = backfunc_tape.gradient(
|
||||
float_outputs,
|
||||
float_inputs,
|
||||
forwardprop_aids,
|
||||
unconnected_gradients=UnconnectedGradients.ZERO)
|
||||
nontrivial_output_tangents = transpose_tape.gradient(
|
||||
grads, forwardprop_aids, output_gradients=nontrivial_tangents)
|
||||
output_tangents = [None] * len(outputs)
|
||||
for index, tangent in zip(nontrivial_output_indices,
|
||||
nontrivial_output_tangents):
|
||||
output_tangents[index] = tangent
|
||||
return output_tangents
|
||||
|
||||
|
||||
pywrap_tensorflow.TFE_Py_RegisterForwardGradientFunction(_forward_gradient)
|
||||
|
||||
|
||||
class ForwardGradientAccumulator(object):
|
||||
"""Computes Jacobian-vector products using forward-mode autodiff.
|
||||
|
||||
|
@ -22,6 +22,7 @@ import weakref
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import forwardprop
|
||||
@ -32,6 +33,7 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import custom_gradient
|
||||
from tensorflow.python.ops import gradient_checker_v2
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
@ -74,7 +76,10 @@ def _grad(f, argnums=0):
|
||||
with backprop.GradientTape() as tape:
|
||||
tape.watch(params)
|
||||
primals_out = f(*params)
|
||||
return tape.gradient(primals_out, params[argnums])
|
||||
return tape.gradient(
|
||||
primals_out,
|
||||
params[argnums],
|
||||
unconnected_gradients=UnconnectedGradients.ZERO)
|
||||
|
||||
return _f
|
||||
|
||||
@ -93,8 +98,8 @@ def _test_gradients(testcase,
|
||||
atol=1e-6):
|
||||
"""Tests forward/backward jacobians of `f`'s [0, `order`)-order gradients."""
|
||||
if order < 1:
|
||||
raise ValueError("`order` should be a positive integer, got '{}'."
|
||||
.format(order))
|
||||
raise ValueError(
|
||||
"`order` should be a positive integer, got '{}'.".format(order))
|
||||
if order > 1:
|
||||
_test_gradients(
|
||||
testcase=testcase,
|
||||
@ -117,6 +122,60 @@ def _test_gradients(testcase,
|
||||
|
||||
class ForwardpropTest(test.TestCase):
|
||||
|
||||
def testForwardGradientFunction(self):
|
||||
add_outputs = (constant_op.constant(4.),)
|
||||
vp, = forwardprop._forward_gradient(
|
||||
op_name="Add",
|
||||
attr_tuple=(),
|
||||
inputs=(constant_op.constant(1.), constant_op.constant(3.)),
|
||||
outputs=add_outputs,
|
||||
tangents=(
|
||||
constant_op.constant(1.),
|
||||
constant_op.constant(5.),
|
||||
))
|
||||
self.assertAllClose(1. + 5., self.evaluate(vp))
|
||||
|
||||
mul_outputs = (constant_op.constant([20.]),)
|
||||
vp, = forwardprop._forward_gradient(
|
||||
op_name="Mul",
|
||||
attr_tuple=(),
|
||||
inputs=(constant_op.constant([4.]), constant_op.constant([5.])),
|
||||
outputs=mul_outputs,
|
||||
tangents=(
|
||||
constant_op.constant([2.]),
|
||||
constant_op.constant([3.]),
|
||||
))
|
||||
self.assertAllClose([2. * 5. + 3. * 4.], self.evaluate(vp))
|
||||
|
||||
def testForwardGradientFunctionUsedByAccumulatorForOps(self):
|
||||
previous_fn = forwardprop._forward_gradient
|
||||
try:
|
||||
with forwardprop.ForwardGradientAccumulator() as acc:
|
||||
x = constant_op.constant(1.)
|
||||
acc.watch(x, 2.)
|
||||
y = x + x
|
||||
pywrap_tensorflow.TFE_Py_RegisterForwardGradientFunction(
|
||||
lambda *args, **kwargs: [constant_op.constant(-15.)])
|
||||
z = x + x
|
||||
self.assertAllClose(4., acc.jvp(y))
|
||||
self.assertAllClose(-15., acc.jvp(z))
|
||||
finally:
|
||||
pywrap_tensorflow.TFE_Py_RegisterForwardGradientFunction(previous_fn)
|
||||
|
||||
@test_util.assert_no_new_pyobjects_executing_eagerly
|
||||
def testFunctionCacheLimited(self):
|
||||
# Every time this test is executed, it will create a slightly larger Tensor
|
||||
# and push it through Add's gradient. Since we check for new pyobjects after
|
||||
# the warmup, retracing each time without cleaning up old traces fails the
|
||||
# test. It works because of experimental_relax_shapes.
|
||||
execution_count = getattr(self, "_execution_count", 0)
|
||||
self._execution_count = execution_count + 1
|
||||
x = array_ops.zeros([execution_count])
|
||||
with forwardprop.ForwardGradientAccumulator() as acc:
|
||||
acc.watch(x, array_ops.ones_like(x))
|
||||
y = x + x
|
||||
self.assertAllClose(2. * array_ops.ones_like(x), acc.jvp(y))
|
||||
|
||||
@test_util.assert_no_new_pyobjects_executing_eagerly
|
||||
def testMultipleWatchesAdd(self):
|
||||
x = constant_op.constant(-2.)
|
||||
@ -151,14 +210,14 @@ class ForwardpropTest(test.TestCase):
|
||||
self.assertIsNone(derived_tensor_weak())
|
||||
self.assertIsNone(derived_tensor_grad_weak())
|
||||
|
||||
@test_util.assert_no_new_tensors
|
||||
@test_util.assert_no_new_pyobjects_executing_eagerly
|
||||
def testJVPManual(self):
|
||||
primal, tangent = _jvp(math_ops.sin, (constant_op.constant(0.1),),
|
||||
(constant_op.constant(0.2),))
|
||||
self.assertAllClose(math_ops.sin(0.1), primal)
|
||||
self.assertAllClose(math_ops.cos(0.1) * 0.2, tangent)
|
||||
|
||||
@test_util.assert_no_new_tensors
|
||||
@test_util.assert_no_new_pyobjects_executing_eagerly
|
||||
def testNumericHigherOrder(self):
|
||||
|
||||
def f(x):
|
||||
@ -169,7 +228,7 @@ class ForwardpropTest(test.TestCase):
|
||||
_test_gradients(
|
||||
self, f, [constant_op.constant([[2.0, 3.0], [1.0, 4.0]])], order=3)
|
||||
|
||||
@test_util.assert_no_new_tensors
|
||||
@test_util.assert_no_new_pyobjects_executing_eagerly
|
||||
def testCustomGradient(self):
|
||||
|
||||
@custom_gradient.custom_gradient
|
||||
@ -182,7 +241,7 @@ class ForwardpropTest(test.TestCase):
|
||||
|
||||
_test_gradients(self, f, [constant_op.constant([1., 2.])], order=3)
|
||||
|
||||
@test_util.assert_no_new_tensors
|
||||
@test_util.assert_no_new_pyobjects_executing_eagerly
|
||||
def testCustomGradientRecomputeGrad(self):
|
||||
|
||||
@custom_gradient.recompute_grad
|
||||
@ -257,7 +316,7 @@ class ForwardpropTest(test.TestCase):
|
||||
tangents = constant_op.constant([3., 4., 5.])
|
||||
_hvp(fun, (primals,), (tangents,))
|
||||
|
||||
@test_util.assert_no_new_tensors
|
||||
@test_util.assert_no_new_pyobjects_executing_eagerly
|
||||
def testHVPCorrectness(self):
|
||||
|
||||
def fun(x):
|
||||
@ -285,6 +344,7 @@ class ForwardpropTest(test.TestCase):
|
||||
self.assertAllClose(backback_hvp, forwardback_hvp_function)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
# TODO(allenl): Also test with 1.x-style graph mode.
|
||||
ops.enable_eager_execution()
|
||||
test.main()
|
||||
|
@ -86,6 +86,14 @@ PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e);
|
||||
// This function is not thread-safe.
|
||||
PyObject* TFE_Py_RegisterGradientFunction(PyObject* e);
|
||||
|
||||
// Registers e as the forward_gradient_function. The registered function takes
|
||||
// (op_name, attrs, inputs, outputs, tangents) and returns the output
|
||||
// tangents. This function is used only for operations, not for custom gradients
|
||||
// or functional ops.
|
||||
//
|
||||
// This function is not thread-safe.
|
||||
PyObject* TFE_Py_RegisterForwardGradientFunction(PyObject* e);
|
||||
|
||||
// Returns 0 if 'status' is TF_OK. Otherwise, raises an exception (using
|
||||
// `exception` if not nullptr, else using the class registered via
|
||||
// TFE_Py_RegisterExceptionClass), and returns -1.
|
||||
|
@ -713,6 +713,9 @@ PyObject* fallback_exception_class = nullptr;
|
||||
// Python function that returns input gradients given output gradients.
|
||||
PyObject* gradient_function = nullptr;
|
||||
|
||||
// Python function that returns output gradients given input gradients.
|
||||
PyObject* forward_gradient_function = nullptr;
|
||||
|
||||
PyTypeObject* resource_variable_type = nullptr;
|
||||
|
||||
tensorflow::mutex _uid_mutex(tensorflow::LINKER_INITIALIZED);
|
||||
@ -821,6 +824,23 @@ PyObject* TFE_Py_RegisterGradientFunction(PyObject* e) {
|
||||
}
|
||||
}
|
||||
|
||||
PyObject* TFE_Py_RegisterForwardGradientFunction(PyObject* e) {
|
||||
if (forward_gradient_function != nullptr) {
|
||||
Py_DECREF(forward_gradient_function);
|
||||
}
|
||||
if (!PyCallable_Check(e)) {
|
||||
forward_gradient_function = nullptr;
|
||||
PyErr_SetString(PyExc_TypeError,
|
||||
"TFE_Py_RegisterForwardGradientFunction: "
|
||||
"Registered object should be function.");
|
||||
return nullptr;
|
||||
} else {
|
||||
Py_INCREF(e);
|
||||
forward_gradient_function = e;
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
}
|
||||
|
||||
void RaiseFallbackException(const char* message) {
|
||||
if (fallback_exception_class != nullptr) {
|
||||
PyErr_SetString(fallback_exception_class, message);
|
||||
@ -1825,7 +1845,8 @@ void TapeSetRecordOperation(
|
||||
const std::vector<tensorflow::int64>& input_ids,
|
||||
const std::vector<tensorflow::DataType>& input_dtypes,
|
||||
const std::function<PyBackwardFunction*()>& backward_function_getter,
|
||||
const std::function<void(PyBackwardFunction*)>& backward_function_killer) {
|
||||
const std::function<void(PyBackwardFunction*)>& backward_function_killer,
|
||||
const ForwardAccumulator::ForwardFunction* forward_function) {
|
||||
std::vector<PyTapeTensor> output_info;
|
||||
tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
|
||||
output_tensors, "expected a sequence of integer tensor ids"));
|
||||
@ -1882,7 +1903,7 @@ void TapeSetRecordOperation(
|
||||
for (TFE_Py_ForwardAccumulator* accumulator : accumulator_set) {
|
||||
tensorflow::Status status = accumulator->accumulator->Accumulate(
|
||||
op_type_str, input_info, output_info, input_ids, input_dtypes,
|
||||
backward_function_getter, backward_function_killer);
|
||||
forward_function, backward_function_getter, backward_function_killer);
|
||||
if (PyErr_Occurred()) return; // Don't swallow Python exceptions.
|
||||
if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
|
||||
return;
|
||||
@ -1919,7 +1940,8 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
|
||||
[backward_function](PyBackwardFunction* py_backward_function) {
|
||||
Py_DECREF(backward_function);
|
||||
delete py_backward_function;
|
||||
});
|
||||
},
|
||||
nullptr /* No special-cased forward function */);
|
||||
}
|
||||
|
||||
void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
|
||||
@ -2392,6 +2414,67 @@ PyObject* CopySequenceSettingIndicesToNull(
|
||||
return result;
|
||||
}
|
||||
|
||||
// Calls the registered forward_gradient_function, computing `output_tangents`
|
||||
// from `input_tangents`. `output_tangents` must not be null.
|
||||
//
|
||||
// `op_name`, `attrs`, `inputs`, and `results` describe the operation for which
|
||||
// the forward function is being called.
|
||||
tensorflow::Status CallForwardGradientFunction(
|
||||
PyObject* op_name, PyObject* attrs, PyObject* inputs, PyObject* results,
|
||||
const std::vector<PyObject*>& input_tangents,
|
||||
std::vector<PyObject*>* output_tangents) {
|
||||
if (forward_gradient_function == nullptr) {
|
||||
return tensorflow::errors::Internal(
|
||||
"No forward gradient function registered.");
|
||||
}
|
||||
tensorflow::Safe_PyObjectPtr py_input_tangents(
|
||||
PyTuple_New(input_tangents.size()));
|
||||
for (int i = 0; i < input_tangents.size(); ++i) {
|
||||
PyObject* element;
|
||||
if (input_tangents[i] == nullptr) {
|
||||
element = Py_None;
|
||||
} else {
|
||||
element = input_tangents[i];
|
||||
}
|
||||
Py_INCREF(element);
|
||||
PyTuple_SET_ITEM(py_input_tangents.get(), i, element);
|
||||
}
|
||||
// Normalize the input sequence to a tuple so it works with function
|
||||
// caching; otherwise it may be an opaque _InputList object.
|
||||
tensorflow::Safe_PyObjectPtr input_tuple(PySequence_Tuple(inputs));
|
||||
tensorflow::Safe_PyObjectPtr callback_args(
|
||||
Py_BuildValue("OOOOO", op_name, attrs, input_tuple.get(), results,
|
||||
py_input_tangents.get()));
|
||||
tensorflow::Safe_PyObjectPtr py_result(
|
||||
PyObject_CallObject(forward_gradient_function, callback_args.get()));
|
||||
if (py_result == nullptr || PyErr_Occurred()) {
|
||||
return tensorflow::errors::Internal(
|
||||
"forward gradient function threw exceptions");
|
||||
}
|
||||
if (py_result.get() == Py_None) {
|
||||
// No connected gradients.
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
tensorflow::Safe_PyObjectPtr fast_result(PySequence_Fast(
|
||||
py_result.get(), "expected a sequence of forward gradients"));
|
||||
if (fast_result == nullptr) {
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
"forward gradient function did not return a sequence.");
|
||||
}
|
||||
int len = PySequence_Fast_GET_SIZE(fast_result.get());
|
||||
output_tangents->reserve(len);
|
||||
for (int i = 0; i < len; ++i) {
|
||||
PyObject* item = PySequence_Fast_GET_ITEM(fast_result.get(), i);
|
||||
if (item == Py_None) {
|
||||
output_tangents->push_back(nullptr);
|
||||
} else {
|
||||
Py_INCREF(item);
|
||||
output_tangents->push_back(item);
|
||||
}
|
||||
}
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
|
||||
PyObject* results, PyObject* name) {
|
||||
std::vector<tensorflow::int64> input_ids = MakeTensorIDList(inputs);
|
||||
@ -2450,6 +2533,14 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
|
||||
op_inputs = inputs;
|
||||
}
|
||||
|
||||
ForwardAccumulator::ForwardFunction forward_function =
|
||||
[op_name, attrs, inputs, results](
|
||||
const std::vector<PyObject*>& input_tangents,
|
||||
std::vector<PyObject*>* output_tangents) {
|
||||
return CallForwardGradientFunction(op_name, attrs, inputs, results,
|
||||
input_tangents, output_tangents);
|
||||
};
|
||||
|
||||
PyObject* num_inputs = PyLong_FromLong(PySequence_Size(inputs));
|
||||
|
||||
TapeSetRecordOperation(
|
||||
@ -2502,7 +2593,8 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
|
||||
Py_DECREF(op_outputs);
|
||||
|
||||
delete backward_function;
|
||||
});
|
||||
},
|
||||
&forward_function);
|
||||
|
||||
Py_DECREF(num_inputs);
|
||||
if (op_outputs_tuple_created) Py_DECREF(op_outputs);
|
||||
|
@ -59,6 +59,7 @@ limitations under the License.
|
||||
%rename("%s") TFE_Py_InitEagerTensor;
|
||||
%rename("%s") TFE_Py_SetEagerTensorProfiler;
|
||||
%rename("%s") TFE_Py_RegisterExceptionClass;
|
||||
%rename("%s") TFE_Py_RegisterForwardGradientFunction;
|
||||
%rename("%s") TFE_Py_RegisterGradientFunction;
|
||||
%rename("%s") TFE_Py_RegisterFallbackExceptionClass;
|
||||
%rename("%s") TFE_Py_RegisterResourceVariableType;
|
||||
|
Loading…
Reference in New Issue
Block a user