Fix another overzealous variable checking issue with custom_gradient
PiperOrigin-RevId: 267188272
This commit is contained in:
parent
b883c28ffe
commit
3706ccd34a
@ -1491,24 +1491,25 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
|
||||
def f(x):
|
||||
|
||||
@custom_gradient.custom_gradient(primals=(x,))
|
||||
def g(unused_dz):
|
||||
def g(dz):
|
||||
|
||||
def h(unused_ddz):
|
||||
return 2.2
|
||||
|
||||
return x * 2.1, h
|
||||
return x * 2.1 * dz, h
|
||||
|
||||
return x + 1., g
|
||||
|
||||
with backprop.GradientTape(persistent=True) as t:
|
||||
with backprop.GradientTape(persistent=True) as tt:
|
||||
v = variables.Variable(1.)
|
||||
self.evaluate(v.initializer)
|
||||
w = variables.Variable(0.)
|
||||
self.evaluate([v.initializer, w.initializer])
|
||||
t.watch(v)
|
||||
tt.watch(v)
|
||||
output = f(v)
|
||||
output = f(v + w)
|
||||
self.assertAllClose(2., output)
|
||||
g = tt.gradient(output, v)
|
||||
g = tt.gradient(output, v, output_gradients=1. + w)
|
||||
self.assertAllClose(2.1, g)
|
||||
gg = t.gradient(g, v)
|
||||
self.assertAllClose(2.2, gg)
|
||||
|
@ -320,7 +320,7 @@ def _graph_mode_decorator(f, primals, *args, **kwargs):
|
||||
inputs = args
|
||||
else:
|
||||
primals = [ops.convert_to_tensor(x) for x in nest.flatten(primals)]
|
||||
inputs = primals
|
||||
inputs = primals + args
|
||||
variables_in_tape = frozenset([
|
||||
v.experimental_ref() for v in tape.watched_variables()
|
||||
]) - frozenset(v.experimental_ref() for v in inputs)
|
||||
|
Loading…
Reference in New Issue
Block a user