Commit Graph

5 Commits

Saurabh Saxena
6d7211299d Support taking gradients of tf.cond and tf.while_loop using LookupTable.
For tf.cond, this required that we not create a default zeros output grad when the output grad for all branch functions is None. For example, since LookupTable ops are marked non-differentiable, the output gradient w.r.t. the LookupTable resource tensor is always None; previously we tried to convert that to a zeros tensor, which is not supported.
Also added support for tf.cond v2 to have branch functions with no outputs. This is necessary now that we may have grad If ops with no outputs.

In tf.while_loop, a captured LookupTable resource is also a loop output because input and output signatures must match, so gradients_util tries to create a default gradient for the LookupTable, which is not supported. So in gradients_util we now check whether the resource is differentiable before building the default grad. Hopefully we can avoid this once we have explicit captures in While.

PiperOrigin-RevId: 277099963
Change-Id: Ib1e87fe42213bd10294d63c6ed4e77859489f1ce
2019-10-28 11:06:28 -07:00
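
To make the tf.cond case above concrete, here is a minimal sketch, assuming a tf.lookup.StaticHashTable captured by both branches (the table contents and the function body are illustrative, not from the commit): the lookup ops are non-differentiable, so the gradient w.r.t. the table's resource tensor is None and must not be replaced by a synthesized zeros tensor.

```python
import tensorflow as tf

# Illustrative only: a lookup table captured by both branches of tf.cond.
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant([0, 1], dtype=tf.int64),
        values=tf.constant([1.0, 2.0])),
    default_value=0.0)


@tf.function
def f(x):
  # Both branches capture the table's resource handle. Because LookupTable
  # ops are marked non-differentiable, the output gradient w.r.t. that
  # resource is None; the generated grad If op must not be handed a
  # synthesized zeros gradient for it.
  return tf.cond(x > 0.0,
                 lambda: x * table.lookup(tf.constant(1, dtype=tf.int64)),
                 lambda: x * table.lookup(tf.constant(0, dtype=tf.int64)))


x = tf.constant(3.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = f(x)

grad = tape.gradient(y, x)  # gradient flows through x; the table contributes None
```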
Allen Lavoie
6077ea44e3 Fix tape/accumulator variant handling
Variant dtypes need to be special-cased since they require using zeros_like to construct zero gradients.

Fixes forwardprop of functions containing control flow

PiperOrigin-RevId: 272938367
2019-10-04 19:03:57 -07:00
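
A rough sketch of the case this fixes, assuming a traced function whose Python loop is lowered to a While op (the function body is illustrative): forward-mode accumulation through control flow exercises the variant-dtype zeros handling.

```python
import tensorflow as tf


@tf.function
def f(x):
  # AutoGraph lowers this loop to a While op; the gradient machinery for such
  # loops carries intermediate values in variant-dtype TensorList tensors,
  # which is what required the zeros_like special case.
  for _ in tf.range(3):
    x = tf.sin(x)
  return x


primal = tf.constant(0.5)
tangent = tf.constant(1.0)
with tf.autodiff.ForwardAccumulator(primals=primal, tangents=tangent) as acc:
  y = f(primal)

jvp = acc.jvp(y)  # forward-mode derivative of f at 0.5
```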
Saurabh Saxena
c1a6d34eeb Remove code in default_gradient that returns float32 if resource handle_data is not set, since that should never happen.
PiperOrigin-RevId: 272885112
2019-10-04 08:58:23 -07:00
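
For context, a hypothetical helper (not the actual default_gradient code) sketching how a default gradient dtype can be read from a resource handle's handle_data; the proto field names follow CppShapeInferenceResult.HandleData, and the error path reflects this change's assumption that handle_data is always set.

```python
import tensorflow as tf


def default_grad_dtype(handle_data):
  """Hypothetical sketch: pick the default gradient dtype for a resource.

  `handle_data` is assumed to be a CppShapeInferenceResult.HandleData proto
  attached to the resource tensor. After this change, unset handle_data is an
  error rather than a silent float32 fallback.
  """
  if handle_data is None or not handle_data.is_set:
    raise ValueError("handle_data must be set to build a default gradient.")
  return tf.dtypes.as_dtype(handle_data.shape_and_type[0].dtype)
```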
Allen Lavoie
bd4feec252 Accept output gradients of side outputs when calling functions
Fixes higher-order gradients of function calls

When running a function under a tape, we build a forward function which outputs everything the backward function needs, and a backward function which accepts output gradients for all of the outputs of the forward function. This sometimes needs a few iterations to converge, but the resulting pair does not need to be regenerated if higher-order gradients are eventually requested.

When taking symbolic gradients of function call operations (tf.gradients), we just need to do a bit less caching than we were doing previously. When we mutate the forward-pass op with new side outputs, tf.gradients is smart enough to re-request the backward function when taking higher-order gradients, but previously we were caching too aggressively and so ignored this request.

PiperOrigin-RevId: 256268751
2019-07-02 18:46:00 -07:00
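
A minimal sketch of the behavior this enables, with an illustrative function body: nested tapes requesting a second-order gradient of a tf.function call.

```python
import tensorflow as tf


@tf.function
def f(x):
  return x ** 3


x = tf.constant(2.0)
with tf.GradientTape() as outer:
  outer.watch(x)
  with tf.GradientTape() as inner:
    inner.watch(x)
    y = f(x)
  # First-order gradient: the forward function is rewritten to expose the
  # side outputs the backward function needs.
  dy_dx = inner.gradient(y, x)        # 3 * x**2 -> 12.0
# Second-order gradient: the backward pass itself is differentiated, which
# requires accepting output gradients for those side outputs.
d2y_dx2 = outer.gradient(dy_dx, x)    # 6 * x -> 12.0
```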
Saurabh Saxena
026d33ff03 Fixes default gradient of non-float32 ResourceVariables.
Uses the handle_data of a resource variable to build its default gradient instead of falling back to float32.

PiperOrigin-RevId: 250715790
2019-05-30 10:46:29 -07:00
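
A minimal sketch of the fixed case, with an illustrative loop body: a float64 resource variable captured by a tf.while_loop, where the default zeros gradient must use the dtype recorded in the handle's handle_data rather than float32.

```python
import tensorflow as tf

v = tf.Variable(2.0, dtype=tf.float64)


@tf.function
def loop(x):
  # The While op captures v's resource handle; building a default gradient
  # for that capture must use float64 (from handle_data), not float32.
  return tf.while_loop(lambda i, acc: i < 2,
                       lambda i, acc: [i + 1, acc * v],
                       [tf.constant(0), x])[1]


with tf.GradientTape() as tape:
  y = loop(tf.constant(3.0, dtype=tf.float64))

grad = tape.gradient(y, v)  # a tf.float64 gradient, not a float32 fallback
```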