Use a tf.function to more efficiently compute op jvps

Allows the unused backward computation to be pruned out.

Does not change custom_gradient or function forward-mode computations.

Some fiddling with the memory checking on the unit tests, since tf.function creates persistent symbolic Tensors the first time it's called. This means we need to do warmup runs and ignore Tensors allocated there.

Forward gradients still need some followups after this:
  - Functions should have a special-cased forward function so that they're efficient when executing eagerly.
  - Watching variables on an accumulator should be possible

After those the remaining case is custom gradients, which are probably fine to leave as they are for now (they work, they're just a bit less efficient than they could be if the user provided a jvp or told us the code was safe to wrap in a tf.function).

From //tensorflow/python/eager:benchmarks_test:

benchmark_forwardprop_in_defun_matmul_256_by_2096_CPU 487 examples/second no change
benchmark_forwardprop_matmul_256_by_2096_CPU          406 examples/second 1.6x speedup
benchmark_forwardprop_of_defun_matmul_256_by_2096_CPU 176 examples/second no change

benchmark_forwardprop_in_defun_matmul_100_by_784_CPU 2872 examples/second no change
benchmark_forwardprop_matmul_100_by_784_CPU          1766 examples/second 1.4x speedup
benchmark_forwardprop_of_defun_matmul_100_by_784_CPU  909 examples/second no change

PiperOrigin-RevId: 257832992
This commit is contained in:
Allen Lavoie 2019-07-12 11:02:54 -07:00 committed by TensorFlower Gardener
parent d329b403ed
commit acab6a2051
7 changed files with 280 additions and 21 deletions

View File

@ -212,6 +212,11 @@ class ForwardAccumulator {
// Tensor associated with `tensor_id` is deleted.
void DeleteGradient(int64 tensor_id);
// Describes a callback for special-cased and more efficient jvp computation.
typedef std::function<Status(const std::vector<Gradient*>&,
std::vector<Gradient*>*)>
ForwardFunction;
// Runs forward autodiff. Should be called whenever a new operation is
// available and the accumulator is active.
//
@ -222,6 +227,12 @@ class ForwardAccumulator {
// between calls to ShouldRecord and Accumulator), and its outputs
// (`output_tensors`).
//
// If provided, a non-null `forward_function` will be used instead of the
// backward function (`backward_function_getter` /
// `backward_function_deleter`) to compute jvps for this operation. If
// `forward_function` is null, a GradientTape is used on the backward function
// to compute the jvp, which will waste computation when executing eagerly.
//
// Unlike GradientTape::RecordOperation, Accumulate runs gradient computation
// immediately. It stores the results, which feed into Accumulate for future
// operations and may be fetched by calling FetchJVP. ForwardAccumulator
@ -237,6 +248,7 @@ class ForwardAccumulator {
const std::vector<TapeTensor>& output_tensors,
gtl::ArraySlice<int64> input_tensor_id,
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
const ForwardFunction* forward_function,
const std::function<BackwardFunction*()>& backward_function_getter,
const std::function<void(BackwardFunction*)>& backward_function_deleter);
@ -930,6 +942,8 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
const std::vector<TapeTensor>& output_tensors,
gtl::ArraySlice<int64> input_tensor_id,
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
const ForwardAccumulator<Gradient, BackwardFunction,
TapeTensor>::ForwardFunction* forward_function,
const std::function<BackwardFunction*()>& backward_function_getter,
const std::function<void(BackwardFunction*)>& backward_function_deleter) {
if (backward_tape_ != nullptr) {
@ -981,23 +995,36 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
}
}
// Avoid infinite recursion. Whichever forward function we run, it'll end up
// executing ops, and we don't want to watch those with this accumulator.
accumulating_ = true;
auto reset_accumulating =
gtl::MakeCleanup([this] { this->accumulating_ = false; });
std::vector<Gradient*> forward_grads;
TF_RETURN_IF_ERROR(
ForwardpropFromTape(output_tensors, backward_function_getter,
backward_function_deleter, in_grads, &forward_grads));
if (forward_function == nullptr) {
// We have no special-cased forward gradient. Fall back to running the
// backward function under a gradient tape.
TF_RETURN_IF_ERROR(ForwardpropFromTape(
output_tensors, backward_function_getter, backward_function_deleter,
in_grads, &forward_grads));
} else {
TF_RETURN_IF_ERROR((*forward_function)(in_grads, &forward_grads));
}
for (int i = 0; i < forward_grads.size(); ++i) {
if (forward_grads[i] != nullptr) {
int64 tensor_id = output_tensors[i].GetID();
auto existing = accumulated_gradients_.find(tensor_id);
if (existing != accumulated_gradients_.end()) {
vspace_.DeleteGradient(existing->second);
// This is a somewhat odd case to be in, since it means we have two
// operations which supposedly both created the same Tensor. It comes up
// in recompute_grad, where the gradients have the same value. However,
// only the original gradient is connected to everything else, so we
// should still use that.
vspace_.DeleteGradient(forward_grads[i]);
} else {
accumulated_gradients_[output_tensors[i].GetID()] = forward_grads[i];
}
accumulated_gradients_[output_tensors[i].GetID()] = forward_grads[i];
}
}
return Status::OK();

View File

@ -250,6 +250,7 @@ cuda_py_test(
":forwardprop",
":test",
],
shard_count = 5,
xla_enable_strict_auto_jit = True,
)

View File

@ -21,13 +21,83 @@ from __future__ import print_function
import functools
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.eager import execute
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
# TODO(allenl): Special-case op gradients and tf.functions to avoid unnecessary
# evaluation of gradient functions.
# TODO(allenl): experimental_relax_shapes for gradients which rely on static
# shape information may be underspecialized. We may want hand-written forward
# implementations.
@def_function.function(experimental_relax_shapes=True)
def _forward_gradient(op_name, attr_tuple, inputs, outputs, tangents):
"""Computes a Jacobian-vector product for an op.
Note that this function would be wasteful if executed eagerly. It runs the
backward gradient function and throws away the result just to record its
operations on a GradientTape. These unused ops are pruned away when this
function is traced.
Args:
op_name: A string, the type of operation being executed.
attr_tuple: Attributes of the operation.
inputs: A flat list of input Tensors to the operation.
outputs: A flat list of output Tensors from the operation.
tangents: A flat list of Tensors, same shape as `inputs`.
Returns:
A flat list of tangents corresponding to `outputs`.
"""
float_inputs = []
float_indices = []
nontrivial_tangents = []
for input_index, tensor in enumerate(inputs):
if tensor.dtype.is_floating:
float_inputs.append(tensor)
float_indices.append(input_index)
nontrivial_tangents.append(tangents[input_index])
with backprop.GradientTape() as transpose_tape:
with backprop.GradientTape() as backfunc_tape:
backfunc_tape.watch(float_inputs)
execute.record_gradient(op_name, inputs, attr_tuple, outputs,
"forward_op_replay")
forwardprop_aids = []
float_outputs = []
nontrivial_output_indices = []
for output_index, output in enumerate(outputs):
if output.dtype.is_floating:
forwardprop_aids.append(
array_ops.ones_like(output, name="unused_forwardprop_aid"))
float_outputs.append(output)
nontrivial_output_indices.append(output_index)
transpose_tape.watch(forwardprop_aids)
grads = backfunc_tape.gradient(
float_outputs,
float_inputs,
forwardprop_aids,
unconnected_gradients=UnconnectedGradients.ZERO)
nontrivial_output_tangents = transpose_tape.gradient(
grads, forwardprop_aids, output_gradients=nontrivial_tangents)
output_tangents = [None] * len(outputs)
for index, tangent in zip(nontrivial_output_indices,
nontrivial_output_tangents):
output_tangents[index] = tangent
return output_tangents
pywrap_tensorflow.TFE_Py_RegisterForwardGradientFunction(_forward_gradient)
class ForwardGradientAccumulator(object):
"""Computes Jacobian-vector products using forward-mode autodiff.

View File

@ -22,6 +22,7 @@ import weakref
import numpy as np
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.eager import forwardprop
@ -32,6 +33,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.platform import test
from tensorflow.python.util import nest
@ -74,7 +76,10 @@ def _grad(f, argnums=0):
with backprop.GradientTape() as tape:
tape.watch(params)
primals_out = f(*params)
return tape.gradient(primals_out, params[argnums])
return tape.gradient(
primals_out,
params[argnums],
unconnected_gradients=UnconnectedGradients.ZERO)
return _f
@ -93,8 +98,8 @@ def _test_gradients(testcase,
atol=1e-6):
"""Tests forward/backward jacobians of `f`'s [0, `order`)-order gradients."""
if order < 1:
raise ValueError("`order` should be a positive integer, got '{}'."
.format(order))
raise ValueError(
"`order` should be a positive integer, got '{}'.".format(order))
if order > 1:
_test_gradients(
testcase=testcase,
@ -117,6 +122,60 @@ def _test_gradients(testcase,
class ForwardpropTest(test.TestCase):
def testForwardGradientFunction(self):
add_outputs = (constant_op.constant(4.),)
vp, = forwardprop._forward_gradient(
op_name="Add",
attr_tuple=(),
inputs=(constant_op.constant(1.), constant_op.constant(3.)),
outputs=add_outputs,
tangents=(
constant_op.constant(1.),
constant_op.constant(5.),
))
self.assertAllClose(1. + 5., self.evaluate(vp))
mul_outputs = (constant_op.constant([20.]),)
vp, = forwardprop._forward_gradient(
op_name="Mul",
attr_tuple=(),
inputs=(constant_op.constant([4.]), constant_op.constant([5.])),
outputs=mul_outputs,
tangents=(
constant_op.constant([2.]),
constant_op.constant([3.]),
))
self.assertAllClose([2. * 5. + 3. * 4.], self.evaluate(vp))
def testForwardGradientFunctionUsedByAccumulatorForOps(self):
previous_fn = forwardprop._forward_gradient
try:
with forwardprop.ForwardGradientAccumulator() as acc:
x = constant_op.constant(1.)
acc.watch(x, 2.)
y = x + x
pywrap_tensorflow.TFE_Py_RegisterForwardGradientFunction(
lambda *args, **kwargs: [constant_op.constant(-15.)])
z = x + x
self.assertAllClose(4., acc.jvp(y))
self.assertAllClose(-15., acc.jvp(z))
finally:
pywrap_tensorflow.TFE_Py_RegisterForwardGradientFunction(previous_fn)
@test_util.assert_no_new_pyobjects_executing_eagerly
def testFunctionCacheLimited(self):
# Every time this test is executed, it will create a slightly larger Tensor
# and push it through Add's gradient. Since we check for new pyobjects after
# the warmup, retracing each time without cleaning up old traces fails the
# test. It works because of experimental_relax_shapes.
execution_count = getattr(self, "_execution_count", 0)
self._execution_count = execution_count + 1
x = array_ops.zeros([execution_count])
with forwardprop.ForwardGradientAccumulator() as acc:
acc.watch(x, array_ops.ones_like(x))
y = x + x
self.assertAllClose(2. * array_ops.ones_like(x), acc.jvp(y))
@test_util.assert_no_new_pyobjects_executing_eagerly
def testMultipleWatchesAdd(self):
x = constant_op.constant(-2.)
@ -151,14 +210,14 @@ class ForwardpropTest(test.TestCase):
self.assertIsNone(derived_tensor_weak())
self.assertIsNone(derived_tensor_grad_weak())
@test_util.assert_no_new_tensors
@test_util.assert_no_new_pyobjects_executing_eagerly
def testJVPManual(self):
primal, tangent = _jvp(math_ops.sin, (constant_op.constant(0.1),),
(constant_op.constant(0.2),))
self.assertAllClose(math_ops.sin(0.1), primal)
self.assertAllClose(math_ops.cos(0.1) * 0.2, tangent)
@test_util.assert_no_new_tensors
@test_util.assert_no_new_pyobjects_executing_eagerly
def testNumericHigherOrder(self):
def f(x):
@ -169,7 +228,7 @@ class ForwardpropTest(test.TestCase):
_test_gradients(
self, f, [constant_op.constant([[2.0, 3.0], [1.0, 4.0]])], order=3)
@test_util.assert_no_new_tensors
@test_util.assert_no_new_pyobjects_executing_eagerly
def testCustomGradient(self):
@custom_gradient.custom_gradient
@ -182,7 +241,7 @@ class ForwardpropTest(test.TestCase):
_test_gradients(self, f, [constant_op.constant([1., 2.])], order=3)
@test_util.assert_no_new_tensors
@test_util.assert_no_new_pyobjects_executing_eagerly
def testCustomGradientRecomputeGrad(self):
@custom_gradient.recompute_grad
@ -257,7 +316,7 @@ class ForwardpropTest(test.TestCase):
tangents = constant_op.constant([3., 4., 5.])
_hvp(fun, (primals,), (tangents,))
@test_util.assert_no_new_tensors
@test_util.assert_no_new_pyobjects_executing_eagerly
def testHVPCorrectness(self):
def fun(x):
@ -285,6 +344,7 @@ class ForwardpropTest(test.TestCase):
self.assertAllClose(backback_hvp, forwardback_hvp_function)
if __name__ == '__main__':
if __name__ == "__main__":
# TODO(allenl): Also test with 1.x-style graph mode.
ops.enable_eager_execution()
test.main()

View File

@ -86,6 +86,14 @@ PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e);
// This function is not thread-safe.
PyObject* TFE_Py_RegisterGradientFunction(PyObject* e);
// Registers e as the forward_gradient_function. The registered function takes
// (op_name, attrs, inputs, outputs, tangents) and returns the output
// tangents. This function is used only for operations, not for custom gradients
// or functional ops.
//
// This function is not thread-safe.
PyObject* TFE_Py_RegisterForwardGradientFunction(PyObject* e);
// Returns 0 if 'status' is TF_OK. Otherwise, raises an exception (using
// `exception` if not nullptr, else using the class registered via
// TFE_Py_RegisterExceptionClass), and returns -1.

View File

@ -713,6 +713,9 @@ PyObject* fallback_exception_class = nullptr;
// Python function that returns input gradients given output gradients.
PyObject* gradient_function = nullptr;
// Python function that returns output gradients given input gradients.
PyObject* forward_gradient_function = nullptr;
PyTypeObject* resource_variable_type = nullptr;
tensorflow::mutex _uid_mutex(tensorflow::LINKER_INITIALIZED);
@ -821,6 +824,23 @@ PyObject* TFE_Py_RegisterGradientFunction(PyObject* e) {
}
}
PyObject* TFE_Py_RegisterForwardGradientFunction(PyObject* e) {
if (forward_gradient_function != nullptr) {
Py_DECREF(forward_gradient_function);
}
if (!PyCallable_Check(e)) {
forward_gradient_function = nullptr;
PyErr_SetString(PyExc_TypeError,
"TFE_Py_RegisterForwardGradientFunction: "
"Registered object should be function.");
return nullptr;
} else {
Py_INCREF(e);
forward_gradient_function = e;
Py_RETURN_NONE;
}
}
void RaiseFallbackException(const char* message) {
if (fallback_exception_class != nullptr) {
PyErr_SetString(fallback_exception_class, message);
@ -1825,7 +1845,8 @@ void TapeSetRecordOperation(
const std::vector<tensorflow::int64>& input_ids,
const std::vector<tensorflow::DataType>& input_dtypes,
const std::function<PyBackwardFunction*()>& backward_function_getter,
const std::function<void(PyBackwardFunction*)>& backward_function_killer) {
const std::function<void(PyBackwardFunction*)>& backward_function_killer,
const ForwardAccumulator::ForwardFunction* forward_function) {
std::vector<PyTapeTensor> output_info;
tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
output_tensors, "expected a sequence of integer tensor ids"));
@ -1882,7 +1903,7 @@ void TapeSetRecordOperation(
for (TFE_Py_ForwardAccumulator* accumulator : accumulator_set) {
tensorflow::Status status = accumulator->accumulator->Accumulate(
op_type_str, input_info, output_info, input_ids, input_dtypes,
backward_function_getter, backward_function_killer);
forward_function, backward_function_getter, backward_function_killer);
if (PyErr_Occurred()) return; // Don't swallow Python exceptions.
if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
return;
@ -1919,7 +1940,8 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
[backward_function](PyBackwardFunction* py_backward_function) {
Py_DECREF(backward_function);
delete py_backward_function;
});
},
nullptr /* No special-cased forward function */);
}
void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
@ -2392,6 +2414,67 @@ PyObject* CopySequenceSettingIndicesToNull(
return result;
}
// Calls the registered forward_gradient_function, computing `output_tangents`
// from `input_tangents`. `output_tangents` must not be null.
//
// `op_name`, `attrs`, `inputs`, and `results` describe the operation for which
// the forward function is being called.
tensorflow::Status CallForwardGradientFunction(
PyObject* op_name, PyObject* attrs, PyObject* inputs, PyObject* results,
const std::vector<PyObject*>& input_tangents,
std::vector<PyObject*>* output_tangents) {
if (forward_gradient_function == nullptr) {
return tensorflow::errors::Internal(
"No forward gradient function registered.");
}
tensorflow::Safe_PyObjectPtr py_input_tangents(
PyTuple_New(input_tangents.size()));
for (int i = 0; i < input_tangents.size(); ++i) {
PyObject* element;
if (input_tangents[i] == nullptr) {
element = Py_None;
} else {
element = input_tangents[i];
}
Py_INCREF(element);
PyTuple_SET_ITEM(py_input_tangents.get(), i, element);
}
// Normalize the input sequence to a tuple so it works with function
// caching; otherwise it may be an opaque _InputList object.
tensorflow::Safe_PyObjectPtr input_tuple(PySequence_Tuple(inputs));
tensorflow::Safe_PyObjectPtr callback_args(
Py_BuildValue("OOOOO", op_name, attrs, input_tuple.get(), results,
py_input_tangents.get()));
tensorflow::Safe_PyObjectPtr py_result(
PyObject_CallObject(forward_gradient_function, callback_args.get()));
if (py_result == nullptr || PyErr_Occurred()) {
return tensorflow::errors::Internal(
"forward gradient function threw exceptions");
}
if (py_result.get() == Py_None) {
// No connected gradients.
return tensorflow::Status::OK();
}
tensorflow::Safe_PyObjectPtr fast_result(PySequence_Fast(
py_result.get(), "expected a sequence of forward gradients"));
if (fast_result == nullptr) {
return tensorflow::errors::InvalidArgument(
"forward gradient function did not return a sequence.");
}
int len = PySequence_Fast_GET_SIZE(fast_result.get());
output_tangents->reserve(len);
for (int i = 0; i < len; ++i) {
PyObject* item = PySequence_Fast_GET_ITEM(fast_result.get(), i);
if (item == Py_None) {
output_tangents->push_back(nullptr);
} else {
Py_INCREF(item);
output_tangents->push_back(item);
}
}
return tensorflow::Status::OK();
}
PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
PyObject* results, PyObject* name) {
std::vector<tensorflow::int64> input_ids = MakeTensorIDList(inputs);
@ -2450,6 +2533,14 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
op_inputs = inputs;
}
ForwardAccumulator::ForwardFunction forward_function =
[op_name, attrs, inputs, results](
const std::vector<PyObject*>& input_tangents,
std::vector<PyObject*>* output_tangents) {
return CallForwardGradientFunction(op_name, attrs, inputs, results,
input_tangents, output_tangents);
};
PyObject* num_inputs = PyLong_FromLong(PySequence_Size(inputs));
TapeSetRecordOperation(
@ -2502,7 +2593,8 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
Py_DECREF(op_outputs);
delete backward_function;
});
},
&forward_function);
Py_DECREF(num_inputs);
if (op_outputs_tuple_created) Py_DECREF(op_outputs);

View File

@ -59,6 +59,7 @@ limitations under the License.
%rename("%s") TFE_Py_InitEagerTensor;
%rename("%s") TFE_Py_SetEagerTensorProfiler;
%rename("%s") TFE_Py_RegisterExceptionClass;
%rename("%s") TFE_Py_RegisterForwardGradientFunction;
%rename("%s") TFE_Py_RegisterGradientFunction;
%rename("%s") TFE_Py_RegisterFallbackExceptionClass;
%rename("%s") TFE_Py_RegisterResourceVariableType;