For tf.cond, this required that we not create a default zeros output grad when the output grad for all branch functions is None. For example, since LookupTable ops are marked non-differentiable, the output gradient with respect to the LookupTable resource tensor is always None; currently we try to convert that to a zeros tensor, which is not supported.
Also added support for tf.cond v2 to have branch functions with no outputs. This is necessary now that we may have grad If ops with no outputs.
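A minimal sketch of the tf.cond case above, assuming TF2's public APIs (tf.lookup.StaticHashTable, tf.GradientTape); the function name `branchy` is illustrative only. The table resource is captured by both branches, is non-differentiable, and so gets a None output gradient that must not be turned into a zeros tensor:

  import tensorflow as tf

  table = tf.lookup.StaticHashTable(
      tf.lookup.KeyValueTensorInitializer(
          keys=tf.constant(["a", "b"]), values=tf.constant([1.0, 2.0])),
      default_value=0.0)

  @tf.function
  def branchy(x):
    return tf.cond(x > 0.0,
                   lambda: x * table.lookup(tf.constant("a")),
                   lambda: x * table.lookup(tf.constant("b")))

  x = tf.constant(3.0)
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = branchy(x)
  print(tape.gradient(y, x))  # Gradient w.r.t. x; the table resource gets no grad.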
In tf.while_loop, a captured LookupTable resource is also a loop output (because input and output signatures must match), so gradients_util tries to create a default gradient for the LookupTable, which is not supported. gradients_util now checks whether the resource is differentiable before building the default grad. Hopefully we can avoid this check once While has explicit captures.
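A minimal sketch of the tf.while_loop case, again assuming TF2 APIs and with illustrative names only; the table resource is an implicit capture and therefore also a loop output, and no default gradient should be built for it:

  import tensorflow as tf

  table = tf.lookup.StaticHashTable(
      tf.lookup.KeyValueTensorInitializer(
          keys=tf.constant([0, 1], dtype=tf.int64),
          values=tf.constant([2.0, 3.0])),
      default_value=1.0)

  @tf.function
  def loop(x):
    def body(i, acc):
      return i + 1, acc * table.lookup(tf.cast(i % 2, tf.int64))
    _, result = tf.while_loop(lambda i, acc: i < 3, body, [tf.constant(0), x])
    return result

  x = tf.constant(2.0)
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = loop(x)
  print(tape.gradient(y, x))  # d(result)/dx; no default grad for the table resource.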
PiperOrigin-RevId: 277099963
Change-Id: Ib1e87fe42213bd10294d63c6ed4e77859489f1ce
Variant dtypes need to be special-cased since building their default gradient requires zeros_like rather than zeros.
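A hedged sketch (not the actual TensorFlow implementation) of why the special case is needed: a variant tensor such as a TensorList has no plain numeric shape, so its zero gradient must be built from the forward value with tf.zeros_like instead of from a shape and dtype with tf.zeros:

  import tensorflow as tf

  def default_gradient(t):
    if t.dtype == tf.variant:
      # zeros_like dispatches to the variant's registered zeros function,
      # preserving internal structure (e.g. TensorList element shapes).
      return tf.zeros_like(t)
    return tf.zeros(tf.shape(t), dtype=t.dtype)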
Fixes forwardprop of functions containing control flow
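A small example of the fixed behavior, assuming TF2's tf.autodiff API: forward-mode autodiff through a tf.function whose body contains tf.cond.

  import tensorflow as tf

  @tf.function
  def f(x):
    return tf.cond(x > 0.0, lambda: x * x, lambda: -x)

  x = tf.constant(3.0)
  with tf.autodiff.ForwardAccumulator(primals=x, tangents=tf.constant(1.0)) as acc:
    y = f(x)
  print(acc.jvp(y))  # 2 * x = 6.0 in the positive branch.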
PiperOrigin-RevId: 272938367
Fixes higher-order gradients of function calls
When running a function under a tape, we build a forward function which outputs everything the backward function needs, and a backward function which accepts output gradients for all of the outputs of the forward function. This sometimes needs a few iterations to converge, but the resulting pair does not need to be regenerated if higher-order gradients are eventually requested.
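A sketch of that tape path with nested tapes: the forward/backward pair built for f on the first backward pass is reused when the outer tape asks for the second-order gradient.

  import tensorflow as tf

  @tf.function
  def f(x):
    return x ** 3

  x = tf.constant(2.0)
  with tf.GradientTape() as outer:
    outer.watch(x)
    with tf.GradientTape() as inner:
      inner.watch(x)
      y = f(x)
    dy = inner.gradient(y, x)   # 3 * x**2 = 12.0
  d2y = outer.gradient(dy, x)   # 6 * x = 12.0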
When taking symbolic gradients of function call operations (tf.gradients), we just need to do a bit less caching than we were doing previously. When we mutate the forward-pass op with new side outputs, tf.gradients is smart enough to re-request the backward function when taking higher-order gradients, but previously we were caching too aggressively and so ignored this request.
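A hedged sketch of the symbolic path, assuming a tf.function call op can be differentiated twice with tf.gradients inside a v1-style graph: the second tf.gradients call re-requests the backward function after the forward op has been mutated with new side outputs.

  import tensorflow as tf

  @tf.function
  def f(x):
    return x ** 3

  g = tf.Graph()
  with g.as_default():
    x = tf.compat.v1.placeholder(tf.float32, shape=[])
    y = f(x)                      # function call op in the graph
    dy = tf.gradients(y, x)[0]    # first-order symbolic gradient
    d2y = tf.gradients(dy, x)[0]  # higher-order gradient of the call op

  with tf.compat.v1.Session(graph=g) as sess:
    print(sess.run(d2y, feed_dict={x: 2.0}))  # 6 * x = 12.0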
PiperOrigin-RevId: 256268751