Merge pull request #41628 from abhichou4:exp/accumulator

PiperOrigin-RevId: 324254130
Change-Id: Ib4d7b222ade81d359a1f7002fe2e7fcffec8e055
Author: TensorFlower Gardener
Date: 2020-07-31 12:15:11 -07:00
Commit: ce3315d20c
6 changed files with 69 additions and 19 deletions
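For orientation, the change threads a use_batch flag from the Python ForwardAccumulator constructor through TFE_Py_ForwardAccumulatorNew and into the C++ ForwardAccumulator and ForwardFunction, so a single accumulator can propagate a batch of tangents per primal. A minimal usage sketch (not part of the commit), mirroring the testBatchCorrectness case added below and assuming TF 2.x eager execution:

import tensorflow as tf
from tensorflow.python.eager import forwardprop

# Scalar primals and a batch of three tangents per primal (shape [3]).
x = tf.constant(2.0)
y = tf.constant(5.0)
tangents = (tf.constant([1., 0., 1.]), tf.constant([0., 1., 1.]))

# _batch_accumulator is the private factory added by this commit; it calls
# TFE_Py_ForwardAccumulatorNew(True) so the tangents are treated as batched.
with forwardprop.ForwardAccumulator._batch_accumulator((x, y), tangents) as acc:
  z = x * y
print(acc.jvp(z))  # one JVP per tangent column: [5., 2., 7.]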

--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h

@@ -177,12 +177,12 @@ class GradientTape {
 template <typename Gradient>
 class ForwardFunction
     : public std::function<Status(const std::vector<Gradient*>&,
-                                  std::vector<Gradient*>*)> {
+                                  std::vector<Gradient*>*, bool)> {
  public:
   template <typename lambda_type>
   explicit ForwardFunction(lambda_type lambda)
       : std::function<Status(const std::vector<Gradient*>&,
-                             std::vector<Gradient*>*)>(lambda) {}
+                             std::vector<Gradient*>*, bool)>(lambda) {}
 };
 
 // Computes Jacobian-vector products using forward-mode automatic
@@ -205,8 +205,9 @@ class ForwardAccumulator {
   // Does not take ownership of `vspace`, which must outlive the
   // ForwardAccumulator.
   explicit ForwardAccumulator(
-      const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace)
-      : vspace_(vspace) {
+      const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
+      bool use_batch)
+      : vspace_(vspace), use_batch_(use_batch) {
     call_state_.emplace(nullptr, false);
   }
@@ -314,6 +315,9 @@ class ForwardAccumulator {
   // available in language bindings (e.g. Python).
   const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace_;
 
+  // Decides if tangents are vectorized or not
+  bool use_batch_;
+
   struct AccumulatorCallState {
     AccumulatorCallState(
         GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape,
@@ -1062,7 +1066,8 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
         output_tensors, backward_function_getter, backward_function_deleter,
         in_grads, &forward_grads));
   } else {
-    TF_RETURN_IF_ERROR((*forward_function)(in_grads, &forward_grads));
+    TF_RETURN_IF_ERROR(
+        (*forward_function)(in_grads, &forward_grads, use_batch_));
   }
   for (int i = 0; i < forward_grads.size(); ++i) {
     if (forward_grads[i] != nullptr) {

--- a/tensorflow/python/eager/forwardprop.py
+++ b/tensorflow/python/eager/forwardprop.py

@@ -29,6 +29,7 @@ from tensorflow.python.eager import forwardprop_util
 from tensorflow.python.eager import function
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops.numpy_ops import np_arrays
@@ -219,7 +220,7 @@ pywrap_tfe.TFE_Py_RegisterJVPFunction(_jvp_dispatch)
 @tf_export("autodiff.ForwardAccumulator", v1=[])
-class ForwardAccumulator(object):
+class ForwardAccumulator():
   """Computes Jacobian-vector products ("JVP"s) using forward-mode autodiff.
 
   Compare to `tf.GradientTape` which computes vector-Jacobian products ("VJP"s)
@@ -349,7 +350,7 @@ class ForwardAccumulator(object):
       ValueError: If the same tensor or variable is specified multiple times in
         `primals`.
     """
-    self._accumulator = pywrap_tfe.TFE_Py_ForwardAccumulatorNew()
+    self._accumulator = pywrap_tfe.TFE_Py_ForwardAccumulatorNew(False)
     self._recording = False
     primal_ids = set()
     for primal in nest.flatten(primals):
@@ -451,3 +452,32 @@ class ForwardAccumulator(object):
       return result
     return nest.map_structure(_fetch_jvp, primals)
+
+  @classmethod
+  def _batch_accumulator(cls, primals, tangents):
+    """Factory constructor to test accumulator on batches of tangents.
+
+    Args:
+      primals: A tensor or nested structure of tensors to watch.
+      tangents: A tensor or nested structure of tensors, with the same nesting
+        structure as `primals`, with each element being a batch of tangents
+        with shape compatible with `[None] + primal.shape` of its primal.
+
+    Returns:
+      A batch accumulator object.
+    """
+    acc = super(ForwardAccumulator, cls).__new__(cls, primals, tangents)
+    acc._recording = False
+    acc._accumulator = pywrap_tfe.TFE_Py_ForwardAccumulatorNew(True)
+    primal_ids = set()
+    for primal, tangent in zip(nest.flatten(primals), nest.flatten(tangents)):
+      tangent.shape.assert_is_compatible_with(
+          tensor_shape.TensorShape([None]) + primal.shape)
+      if id(primal) in primal_ids:
+        raise ValueError(
+            "Tensor {} was specified as a primal multiple times. This may "
+            "indicate an error. If it was intended, please sum the "
+            "corresponding tangents.".format(primal))
+      primal_ids.add(id(primal))
+    acc._watch(primals, tangents)
+    return acc
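The assert_is_compatible_with check above requires each tangent to add exactly one leading batch dimension to its primal's shape. A small standalone sketch of that rule (not part of the commit); the (2, 3) primal with a batch of 5 matches the parameterized case in the test file below:

import tensorflow as tf
from tensorflow.python.framework import tensor_shape

primal = tf.zeros([2, 3])
tangent_batch = tf.zeros([5, 2, 3])  # five tangents, each shaped like the primal

# The same check _batch_accumulator runs for every (primal, tangent) pair:
tangent_batch.shape.assert_is_compatible_with(
    tensor_shape.TensorShape([None]) + primal.shape)  # passes

# A tangent batch whose trailing dims do not match the primal, e.g.
# tf.zeros([5, 3, 2]), would fail the same check with a ValueError.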

--- a/tensorflow/python/eager/forwardprop_test.py
+++ b/tensorflow/python/eager/forwardprop_test.py

@@ -1018,7 +1018,7 @@ class HessianTests(test.TestCase, parameterized.TestCase):
     self.assertAllClose(hess_value, hessian_pfor)
 
 
-class JacobianTests(test.TestCase, parameterized.TestCase):
+class BatchTests(test.TestCase, parameterized.TestCase):
 
   @parameterized.parameters([(math_ops.sin, (2, 3), 5),
                              (math_ops.sin, (2, 3, 4), 10)])
@@ -1029,6 +1029,18 @@ class JacobianTests(test.TestCase, parameterized.TestCase):
         _jvp_batch(f, primals, tangent_batch)[1],
         _jvp_batch_matmul(f, primals, *tangent_batch))
 
+  def testBatchCorrectness(self):
+    x = constant_op.constant(2.0)
+    y = constant_op.constant(5.0)
+    tangents = (
+        constant_op.constant([1., 0., 1.]),
+        constant_op.constant([0., 1., 1.]),
+    )
+    with forwardprop.ForwardAccumulator._batch_accumulator((x, y),
+                                                           tangents) as acc:
+      z = x * y
+    self.assertAllClose(acc.jvp(z), constant_op.constant([5.0, 2.0, 7.0]))
+
 
 if __name__ == "__main__":
   # TODO(allenl): Also test with 1.x-style graph mode.
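As a quick sanity check on the expected values (a sketch, not part of the test): for z = x * y, the JVP of a tangent pair (dx, dy) is y*dx + x*dy, so the three tangent columns (1, 0), (0, 1) and (1, 1) give 5*1 + 2*0 = 5, 5*0 + 2*1 = 2 and 5*1 + 2*1 = 7, matching [5., 2., 7.]:

x, y = 2.0, 5.0
tangent_batch = [(1., 0.), (0., 1.), (1., 1.)]
# d(x*y) = y*dx + x*dy, evaluated per batch entry.
print([y * dx + x * dy for dx, dy in tangent_batch])  # [5.0, 2.0, 7.0]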

--- a/tensorflow/python/eager/pywrap_tfe.h
+++ b/tensorflow/python/eager/pywrap_tfe.h

@@ -284,7 +284,7 @@ PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
 PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape);
 
 // Creates a new forward accumulator. Does not add it to the active set.
-PyObject* TFE_Py_ForwardAccumulatorNew();
+PyObject* TFE_Py_ForwardAccumulatorNew(bool use_batch);
 
 // Adds a ForwardAccumulator to the active set, meaning it will watch executed
 // operations. It must not already be in the active set.

--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc

@@ -2419,7 +2419,8 @@ tensorflow::Status ParseTangentOutputs(
 tensorflow::Status CallJVPFunction(PyObject* op_name, PyObject* attrs,
                                    PyObject* inputs, PyObject* results,
                                    const std::vector<PyObject*>& input_tangents,
-                                   std::vector<PyObject*>* output_tangents) {
+                                   std::vector<PyObject*>* output_tangents,
+                                   bool use_batch) {
   if (forward_gradient_function == nullptr) {
     return tensorflow::errors::Internal(
         "No forward gradient function registered.");
@@ -2430,9 +2431,10 @@ tensorflow::Status CallJVPFunction(PyObject* op_name, PyObject* attrs,
   // Normalize the input sequence to a tuple so it works with function
   // caching; otherwise it may be an opaque _InputList object.
   tensorflow::Safe_PyObjectPtr input_tuple(PySequence_Tuple(inputs));
+  PyObject* to_batch = (use_batch) ? Py_True : Py_False;
   tensorflow::Safe_PyObjectPtr callback_args(
-      Py_BuildValue("OOOOO", op_name, attrs, input_tuple.get(), results,
-                    py_input_tangents.get()));
+      Py_BuildValue("OOOOOO", op_name, attrs, input_tuple.get(), results,
+                    py_input_tangents.get(), to_batch));
   tensorflow::Safe_PyObjectPtr py_result(
       PyObject_CallObject(forward_gradient_function, callback_args.get()));
   if (py_result == nullptr || PyErr_Occurred()) {
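With the sixth "O" added to Py_BuildValue, the Python callback registered through TFE_Py_RegisterJVPFunction now receives six positional arguments, the last being a Python bool that signals batching. The updated Python-side signature is not part of this excerpt; a hypothetical sketch of the arity implied by the call above, with assumed parameter names:

# Hypothetical sketch only; the real dispatcher is the _jvp_dispatch function
# registered via pywrap_tfe.TFE_Py_RegisterJVPFunction in forwardprop.py.
def _jvp_dispatch(op_name, attr_tuple, inputs, outputs, tangents, use_batch=False):
  # The trailing flag mirrors the `to_batch` bool built in CallJVPFunction.
  ...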
@@ -2555,7 +2557,8 @@ PyObject* TFE_Py_TapeSetRecordOperation(PyObject* op_type,
   } else {
     tensorflow::eager::ForwardFunction<PyObject> wrapped_forward_function(
         [forward_function](const std::vector<PyObject*>& input_tangents,
-                           std::vector<PyObject*>* output_tangents) {
+                           std::vector<PyObject*>* output_tangents,
+                           bool use_batch = false) {
           return CallOpSpecificJVPFunction(forward_function, input_tangents,
                                            output_tangents);
         });
@@ -2797,7 +2800,7 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target,
   return PyList_New(0);
 }
 
-PyObject* TFE_Py_ForwardAccumulatorNew() {
+PyObject* TFE_Py_ForwardAccumulatorNew(bool use_batch) {
   TFE_Py_ForwardAccumulator_Type.tp_new = PyType_GenericNew;
   if (PyType_Ready(&TFE_Py_ForwardAccumulator_Type) < 0) return nullptr;
   TFE_Py_ForwardAccumulator* accumulator =
@@ -2808,7 +2811,7 @@ PyObject* TFE_Py_ForwardAccumulatorNew() {
             "ForwardAccumulator requires a PyVSpace to be registered."),
         nullptr);
   }
-  accumulator->accumulator = new ForwardAccumulator(*py_vspace);
+  accumulator->accumulator = new ForwardAccumulator(*py_vspace, use_batch);
   return reinterpret_cast<PyObject*>(accumulator);
 }
@@ -3166,9 +3169,9 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
   tensorflow::eager::ForwardFunction<PyObject> py_forward_function(
       [op_name, attrs, inputs, results](
           const std::vector<PyObject*>& input_tangents,
-          std::vector<PyObject*>* output_tangents) {
+          std::vector<PyObject*>* output_tangents, bool use_batch) {
         return CallJVPFunction(op_name, attrs, inputs, results, input_tangents,
-                               output_tangents);
+                               output_tangents, use_batch);
       });
   tensorflow::eager::ForwardFunction<PyObject>* forward_function;
   if (c_op_name == "While" || c_op_name == "StatelessWhile" ||

--- a/tensorflow/python/tfe_wrapper.cc
+++ b/tensorflow/python/tfe_wrapper.cc

@@ -730,8 +730,8 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
   });
 
   // TFE_Py_ForwardAccumulator logic.
-  m.def("TFE_Py_ForwardAccumulatorNew", []() {
-    return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorNew());
+  m.def("TFE_Py_ForwardAccumulatorNew", [](bool use_batch) {
+    return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorNew(use_batch));
   });
   m.def("TFE_Py_ForwardAccumulatorSetAdd", [](const py::handle& accumulator) {