Fix chain rule to not imply second differentiation (#3650)

Fixes part of #3629.
This commit is contained in:
Geoffrey Irving 2016-08-04 16:00:26 -07:00 committed by GitHub
parent c5f94b10bb
commit 9f47ac6445

View File

@ -1011,13 +1011,13 @@ function which computes gradients with respect to the ops' inputs given
gradients with respect to the ops' outputs.
Mathematically, if an op computes \\(y = f(x)\\) the registered gradient op
converts gradients \\(\partial / \partial y\\) with respect to \\(y\\) into
gradients \\(\partial / \partial x\\) with respect to \\(x\\) via the chain
rule:
converts gradients \\(\partial L/ \partial y\\) of loss \\(L\\) with respect to
\\(y\\) into gradients \\(\partial L/ \partial x\\) with respect to \\(x\\) via
the chain rule:
$$\frac{\partial}{\partial x}
= \frac{\partial}{\partial y} \frac{\partial y}{\partial x}
= \frac{\partial}{\partial y} \frac{\partial f}{\partial x}.$$
$$\frac{\partial L}{\partial x}
= \frac{\partial L}{\partial y} \frac{\partial y}{\partial x}
= \frac{\partial L}{\partial y} \frac{\partial f}{\partial x}.$$
In the case of `ZeroOut`, only one entry in the input affects the output, so the
gradient with respect to the input is a sparse "one hot" tensor. This is