Merge pull request #41628 from abhichou4:exp/accumulator
PiperOrigin-RevId: 324254130
Change-Id: Ib4d7b222ade81d359a1f7002fe2e7fcffec8e055
commit ce3315d20c
@@ -177,12 +177,12 @@ class GradientTape {
 template <typename Gradient>
 class ForwardFunction
     : public std::function<Status(const std::vector<Gradient*>&,
-                                  std::vector<Gradient*>*)> {
+                                  std::vector<Gradient*>*, bool)> {
  public:
  template <typename lambda_type>
  explicit ForwardFunction(lambda_type lambda)
      : std::function<Status(const std::vector<Gradient*>&,
-                            std::vector<Gradient*>*)>(lambda) {}
+                            std::vector<Gradient*>*, bool)>(lambda) {}
 };

 // Computes Jacobian-vector products using forward-mode automatic
@@ -205,8 +205,9 @@ class ForwardAccumulator {
   // Does not take ownership of `vspace`, which must outlive the
   // ForwardAccumulator.
   explicit ForwardAccumulator(
-      const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace)
-      : vspace_(vspace) {
+      const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
+      bool use_batch)
+      : vspace_(vspace), use_batch_(use_batch) {
     call_state_.emplace(nullptr, false);
   }

@@ -314,6 +315,9 @@ class ForwardAccumulator {
   // available in language bindings (e.g. Python).
   const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace_;

+  // Decides if tangents are vectorized or not
+  bool use_batch_;
+
   struct AccumulatorCallState {
     AccumulatorCallState(
         GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape,
@@ -1062,7 +1066,8 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
         output_tensors, backward_function_getter, backward_function_deleter,
         in_grads, &forward_grads));
   } else {
-    TF_RETURN_IF_ERROR((*forward_function)(in_grads, &forward_grads));
+    TF_RETURN_IF_ERROR(
+        (*forward_function)(in_grads, &forward_grads, use_batch_));
   }
   for (int i = 0; i < forward_grads.size(); ++i) {
     if (forward_grads[i] != nullptr) {

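Note: the `bool` parameter threaded through `ForwardFunction` above is the batching switch. A minimal plain-Python sketch of the intended semantics (the function names and the doubling JVP below are illustrative, not part of the diff): a call with `use_batch=True` over stacked tangents should agree with calling the unbatched path once per tangent and stacking the results.

def jvp_unbatched(tangent):
  # Stand-in JVP of f(x) = 2 * x for a single tangent vector.
  return [2.0 * t for t in tangent]

def jvp_batched(tangent_batch):
  # What the use_batch=True path is expected to compute: one JVP per
  # tangent along the leading (batch) dimension.
  return [jvp_unbatched(t) for t in tangent_batch]

batch = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
assert jvp_batched(batch) == [jvp_unbatched(t) for t in batch]
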
@@ -29,6 +29,7 @@ from tensorflow.python.eager import forwardprop_util
 from tensorflow.python.eager import function

 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops.numpy_ops import np_arrays
@@ -219,7 +220,7 @@ pywrap_tfe.TFE_Py_RegisterJVPFunction(_jvp_dispatch)


 @tf_export("autodiff.ForwardAccumulator", v1=[])
-class ForwardAccumulator(object):
+class ForwardAccumulator():
   """Computes Jacobian-vector products ("JVP"s) using forward-mode autodiff.

   Compare to `tf.GradientTape` which computes vector-Jacobian products ("VJP"s)
@@ -349,7 +350,7 @@ class ForwardAccumulator(object):
       ValueError: If the same tensor or variable is specified multiple times in
         `primals`.
     """
-    self._accumulator = pywrap_tfe.TFE_Py_ForwardAccumulatorNew()
+    self._accumulator = pywrap_tfe.TFE_Py_ForwardAccumulatorNew(False)
     self._recording = False
     primal_ids = set()
     for primal in nest.flatten(primals):
@@ -451,3 +452,32 @@ class ForwardAccumulator(object):
       return result

     return nest.map_structure(_fetch_jvp, primals)
+
+  @classmethod
+  def _batch_accumulator(cls, primals, tangents):
+    """Factory constructor to test accumulator on batches of tangents.
+
+    Args:
+      primals: A tensor or nested structure of tensors to watch.
+      tangents: A tensor or nested structure of tensors, with the same nesting
+        structure as `primals`, with each element being a vector with compatible
+        shape `[None] + primal.shape` of the corresponding primal element.
+
+    Returns:
+      A batch accumulator object.
+    """
+    acc = super(ForwardAccumulator, cls).__new__(cls, primals, tangents)
+    acc._recording = False
+    acc._accumulator = pywrap_tfe.TFE_Py_ForwardAccumulatorNew(True)
+    primal_ids = set()
+    for primal, tangent in zip(nest.flatten(primals), nest.flatten(tangents)):
+      tangent.shape.assert_is_compatible_with(
+          tensor_shape.TensorShape([None]) + primal.shape)
+      if id(primal) in primal_ids:
+        raise ValueError(
+            "Tensor {} was specified as a primal multiple times. This may "
+            "indicate an error. If it was intended, please sum the "
+            "corresponding tangents.")
+      primal_ids.add(id(primal))
+    acc._watch(primals, tangents)
+    return acc

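Note on the shape requirement in the `_batch_accumulator` docstring: each tangent is a stack of per-example tangents whose trailing dimensions match the primal. A small hedged sketch of that check (the shapes below are made up for illustration; only the `tensor_shape` calls mirror the diff):

from tensorflow.python.framework import tensor_shape

primal_shape = tensor_shape.TensorShape([2, 3])
tangent_shape = tensor_shape.TensorShape([5, 2, 3])  # a batch of 5 tangents
# Mirrors the check in _batch_accumulator: tangent.shape must be
# compatible with [None] + primal.shape.
tangent_shape.assert_is_compatible_with(
    tensor_shape.TensorShape([None]) + primal_shape)  # passes
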
@@ -1018,7 +1018,7 @@ class HessianTests(test.TestCase, parameterized.TestCase):
     self.assertAllClose(hess_value, hessian_pfor)


-class JacobianTests(test.TestCase, parameterized.TestCase):
+class BatchTests(test.TestCase, parameterized.TestCase):

   @parameterized.parameters([(math_ops.sin, (2, 3), 5),
                              (math_ops.sin, (2, 3, 4), 10)])
@@ -1029,6 +1029,18 @@ class JacobianTests(test.TestCase, parameterized.TestCase):
         _jvp_batch(f, primals, tangent_batch)[1],
         _jvp_batch_matmul(f, primals, *tangent_batch))

+  def testBatchCorrectness(self):
+    x = constant_op.constant(2.0)
+    y = constant_op.constant(5.0)
+    tangents = (
+        constant_op.constant([1., 0., 1.]),
+        constant_op.constant([0., 1., 1.]),
+    )
+    with forwardprop.ForwardAccumulator._batch_accumulator((x, y),
+                                                            tangents) as acc:
+      z = x * y
+    self.assertAllClose(acc.jvp(z), constant_op.constant([5.0, 2.0, 7.0]))
+

 if __name__ == "__main__":
   # TODO(allenl): Also test with 1.x-style graph mode.

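As a quick sanity check on the expected value in `testBatchCorrectness` above: for `z = x * y` the JVP is `y * dx + x * dy`, applied independently to each tangent in the batch. A plain-Python verification (no TensorFlow needed):

x, y = 2.0, 5.0
dx_batch = [1., 0., 1.]
dy_batch = [0., 1., 1.]
jvps = [y * dx + x * dy for dx, dy in zip(dx_batch, dy_batch)]
print(jvps)  # [5.0, 2.0, 7.0], matching acc.jvp(z) in the test
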
@@ -284,7 +284,7 @@ PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
 PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape);

 // Creates a new forward accumulator. Does not add it to the active set.
-PyObject* TFE_Py_ForwardAccumulatorNew();
+PyObject* TFE_Py_ForwardAccumulatorNew(bool use_batch);

 // Adds a ForwardAccumulator to the active set, meaning it will watch executed
 // operations. It must not already be in the active set.

@@ -2419,7 +2419,8 @@ tensorflow::Status ParseTangentOutputs(
 tensorflow::Status CallJVPFunction(PyObject* op_name, PyObject* attrs,
                                    PyObject* inputs, PyObject* results,
                                    const std::vector<PyObject*>& input_tangents,
-                                   std::vector<PyObject*>* output_tangents) {
+                                   std::vector<PyObject*>* output_tangents,
+                                   bool use_batch) {
   if (forward_gradient_function == nullptr) {
     return tensorflow::errors::Internal(
         "No forward gradient function registered.");
@@ -2430,9 +2431,10 @@ tensorflow::Status CallJVPFunction(PyObject* op_name, PyObject* attrs,
   // Normalize the input sequence to a tuple so it works with function
   // caching; otherwise it may be an opaque _InputList object.
   tensorflow::Safe_PyObjectPtr input_tuple(PySequence_Tuple(inputs));
+  PyObject* to_batch = (use_batch) ? Py_True : Py_False;
   tensorflow::Safe_PyObjectPtr callback_args(
-      Py_BuildValue("OOOOO", op_name, attrs, input_tuple.get(), results,
-                    py_input_tangents.get()));
+      Py_BuildValue("OOOOOO", op_name, attrs, input_tuple.get(), results,
+                    py_input_tangents.get(), to_batch));
   tensorflow::Safe_PyObjectPtr py_result(
       PyObject_CallObject(forward_gradient_function, callback_args.get()));
   if (py_result == nullptr || PyErr_Occurred()) {
@@ -2555,7 +2557,8 @@ PyObject* TFE_Py_TapeSetRecordOperation(PyObject* op_type,
   } else {
     tensorflow::eager::ForwardFunction<PyObject> wrapped_forward_function(
         [forward_function](const std::vector<PyObject*>& input_tangents,
-                           std::vector<PyObject*>* output_tangents) {
+                           std::vector<PyObject*>* output_tangents,
+                           bool use_batch = false) {
           return CallOpSpecificJVPFunction(forward_function, input_tangents,
                                            output_tangents);
         });
@@ -2797,7 +2800,7 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target,
   return PyList_New(0);
 }

-PyObject* TFE_Py_ForwardAccumulatorNew() {
+PyObject* TFE_Py_ForwardAccumulatorNew(bool use_batch) {
   TFE_Py_ForwardAccumulator_Type.tp_new = PyType_GenericNew;
   if (PyType_Ready(&TFE_Py_ForwardAccumulator_Type) < 0) return nullptr;
   TFE_Py_ForwardAccumulator* accumulator =
@@ -2808,7 +2811,7 @@ PyObject* TFE_Py_ForwardAccumulatorNew() {
           "ForwardAccumulator requires a PyVSpace to be registered."),
       nullptr);
   }
-  accumulator->accumulator = new ForwardAccumulator(*py_vspace);
+  accumulator->accumulator = new ForwardAccumulator(*py_vspace, use_batch);
   return reinterpret_cast<PyObject*>(accumulator);
 }

@@ -3166,9 +3169,9 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
   tensorflow::eager::ForwardFunction<PyObject> py_forward_function(
       [op_name, attrs, inputs, results](
           const std::vector<PyObject*>& input_tangents,
-          std::vector<PyObject*>* output_tangents) {
+          std::vector<PyObject*>* output_tangents, bool use_batch) {
        return CallJVPFunction(op_name, attrs, inputs, results, input_tangents,
-                              output_tangents);
+                              output_tangents, use_batch);
       });
   tensorflow::eager::ForwardFunction<PyObject>* forward_function;
   if (c_op_name == "While" || c_op_name == "StatelessWhile" ||

@@ -730,8 +730,8 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
   });

   // TFE_Py_ForwardAccumulator logic.
-  m.def("TFE_Py_ForwardAccumulatorNew", []() {
-    return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorNew());
+  m.def("TFE_Py_ForwardAccumulatorNew", [](bool use_batch) {
+    return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorNew(use_batch));
   });

   m.def("TFE_Py_ForwardAccumulatorSetAdd", [](const py::handle& accumulator) {