Fix another overzealous variable checking issue with custom_gradient

PiperOrigin-RevId: 267188272
Author: Allen Lavoie, 2019-09-04 11:01:04 -07:00; committed by TensorFlower Gardener
parent b883c28ffe
commit 3706ccd34a
2 changed files with 7 additions and 6 deletions

tensorflow/python/eager/backprop_test.py

@@ -1491,24 +1491,25 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
     def f(x):
 
       @custom_gradient.custom_gradient(primals=(x,))
-      def g(unused_dz):
+      def g(dz):
 
         def h(unused_ddz):
           return 2.2
 
-        return x * 2.1, h
+        return x * 2.1 * dz, h
 
       return x + 1., g
 
     with backprop.GradientTape(persistent=True) as t:
       with backprop.GradientTape(persistent=True) as tt:
         v = variables.Variable(1.)
-        self.evaluate(v.initializer)
+        w = variables.Variable(0.)
+        self.evaluate([v.initializer, w.initializer])
         t.watch(v)
         tt.watch(v)
-        output = f(v)
+        output = f(v + w)
         self.assertAllClose(2., output)
-      g = tt.gradient(output, v)
+      g = tt.gradient(output, v, output_gradients=1. + w)
     self.assertAllClose(2.1, g)
     gg = t.gradient(g, v)
     self.assertAllClose(2.2, gg)
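
In the updated test, `g` now consumes its upstream gradient `dz`, and both the forward input (`v + w`) and the upstream gradient (`output_gradients=1. + w`) become computed tensors that depend on a variable. Since `w == 0`, the numeric assertions are unchanged; the new operands exist to route variable-derived tensors through the code paths the decorator's variable check inspects. A minimal sketch of the same call pattern on the public TF 2.x API (illustrative names; the test itself drives the internal `custom_gradient(primals=...)` machinery):

```python
import tensorflow as tf


@tf.custom_gradient
def times_2_1(x):
  """Illustrative op with a handwritten first-order gradient."""

  def grad(dz):
    # Scale by the incoming upstream gradient dz; in the usage below,
    # dz itself is a tensor computed from a variable (1. + w).
    return dz * 2.1

  return x * 2.1, grad


v = tf.Variable(1.)
w = tf.Variable(0.)
with tf.GradientTape() as tape:
  y = times_2_1(v + w)  # forward input is a computed tensor, not a variable
g = tape.gradient(y, v, output_gradients=1. + w)
print(g)  # 2.1, matching the test's first-order assertion
```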

tensorflow/python/ops/custom_gradient.py

@@ -320,7 +320,7 @@ def _graph_mode_decorator(f, primals, *args, **kwargs):
     inputs = args
   else:
     primals = [ops.convert_to_tensor(x) for x in nest.flatten(primals)]
-    inputs = primals
+    inputs = primals + args
   variables_in_tape = frozenset([
       v.experimental_ref() for v in tape.watched_variables()
   ]) - frozenset(v.experimental_ref() for v in inputs)
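
Why the one-line fix works: `_graph_mode_decorator` infers which variables `f` implicitly reads by subtracting the declared inputs from the set of variables watched on a tape while `f` ran. With explicit `primals`, the remaining positional `args` were left out of that input set, so tensors and variables legitimately passed through `args` could surface as undeclared captures and trip the check. A schematic of the set arithmetic with hypothetical stand-in values (plain strings instead of real variable refs):

```python
# Schematic only: stand-in strings play the role of experimental_ref() values.
watched = frozenset({"v_ref", "w_ref"})  # variables watched while f ran
primals = ["v_ref"]                      # explicitly declared primals
args = ["w_ref"]                         # extra tensors passed positionally

# Before the fix: args were excluded from the declared inputs, so w_ref
# looked like an undeclared captured variable and tripped the check.
leftover_before = watched - frozenset(primals)        # {'w_ref'}

# After the fix: primals + args together form the declared inputs.
leftover_after = watched - frozenset(primals + args)  # empty

assert leftover_before == {"w_ref"}
assert leftover_after == frozenset()
```

Widening `inputs` rather than weakening the subtraction keeps the decorator able to flag genuinely implicit variable reads.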