Fix chain rule to not imply second differentiation (#3650)
Fixes part of #3629.
This commit is contained in:
parent c5f94b10bb
commit 9f47ac6445
@@ -1011,13 +1011,13 @@ function which computes gradients with respect to the ops' inputs given
 gradients with respect to the ops' outputs.
 
 Mathematically, if an op computes \\(y = f(x)\\) the registered gradient op
-converts gradients \\(\partial / \partial y\\) with respect to \\(y\\) into
-gradients \\(\partial / \partial x\\) with respect to \\(x\\) via the chain
-rule:
+converts gradients \\(\partial L/ \partial y\\) of loss \\(L\\) with respect to
+\\(y\\) into gradients \\(\partial L/ \partial x\\) with respect to \\(x\\) via
+the chain rule:
 
-$$\frac{\partial}{\partial x}
-    = \frac{\partial}{\partial y} \frac{\partial y}{\partial x}
-    = \frac{\partial}{\partial y} \frac{\partial f}{\partial x}.$$
+$$\frac{\partial L}{\partial x}
+    = \frac{\partial L}{\partial y} \frac{\partial y}{\partial x}
+    = \frac{\partial L}{\partial y} \frac{\partial f}{\partial x}.$$
 
 In the case of `ZeroOut`, only one entry in the input affects the output, so the
 gradient with respect to the input is a sparse "one hot" tensor. This is
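For context, the hunk above only rewrites the prose and the formula; it leads into the guide's gradient registration for `ZeroOut`, which this commit does not touch. Below is a minimal sketch of what such a registration can look like, assuming the `ZeroOut` op library is already loaded. The function name `_zero_out_grad` and the use of the TF 1.x-era `tf.sparse_to_dense` are illustrative choices, not part of this diff.

```python
import tensorflow as tf

@tf.RegisterGradient("ZeroOut")
def _zero_out_grad(op, grad):
  """Sketch of a gradient for ZeroOut.

  `grad` is the incoming gradient dL/dy with respect to the op's output;
  we return dL/dx. Since only the first input entry affects the output,
  the result is a "one hot" tensor carrying grad[0] at index 0 and zeros
  everywhere else.
  """
  to_zero = op.inputs[0]
  shape = tf.shape(to_zero)
  index = tf.zeros_like(shape)             # index [0, ..., 0]
  first_grad = tf.reshape(grad, [-1])[0]   # dL/dy for the surviving entry
  to_zero_grad = tf.sparse_to_dense([index], shape, first_grad, 0)
  return [to_zero_grad]                    # one gradient Tensor per op input
```

The returned list applies the chain rule in the sense of the corrected text: the registered function receives \\(\partial L/ \partial y\\) and multiplies through \\(\partial y/ \partial x\\) (here a one-hot selection) to produce \\(\partial L/ \partial x\\).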