From 3706ccd34aa161a1fb07f1efb31cb34bcd111424 Mon Sep 17 00:00:00 2001
From: Allen Lavoie
Date: Wed, 4 Sep 2019 11:01:04 -0700
Subject: [PATCH] Fix another overzealous variable checking issue with
 custom_gradient

PiperOrigin-RevId: 267188272
---
 tensorflow/python/eager/backprop_test.py | 11 ++++++-----
 tensorflow/python/ops/custom_gradient.py |  2 +-
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 27f093de704..54ab48053e9 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -1491,24 +1491,25 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
     def f(x):
 
       @custom_gradient.custom_gradient(primals=(x,))
-      def g(unused_dz):
+      def g(dz):
 
         def h(unused_ddz):
           return 2.2
 
-        return x * 2.1, h
+        return x * 2.1 * dz, h
 
       return x + 1., g
 
     with backprop.GradientTape(persistent=True) as t:
       with backprop.GradientTape(persistent=True) as tt:
         v = variables.Variable(1.)
-        self.evaluate(v.initializer)
+        w = variables.Variable(0.)
+        self.evaluate([v.initializer, w.initializer])
         t.watch(v)
         tt.watch(v)
-        output = f(v)
+        output = f(v + w)
         self.assertAllClose(2., output)
-        g = tt.gradient(output, v)
+        g = tt.gradient(output, v, output_gradients=1. + w)
         self.assertAllClose(2.1, g)
       gg = t.gradient(g, v)
     self.assertAllClose(2.2, gg)
diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py
index f4f85b0fa5f..56f5c20e324 100644
--- a/tensorflow/python/ops/custom_gradient.py
+++ b/tensorflow/python/ops/custom_gradient.py
@@ -320,7 +320,7 @@ def _graph_mode_decorator(f, primals, *args, **kwargs):
     inputs = args
   else:
     primals = [ops.convert_to_tensor(x) for x in nest.flatten(primals)]
-    inputs = primals
+    inputs = primals + args
   variables_in_tape = frozenset([
       v.experimental_ref() for v in tape.watched_variables()
   ]) - frozenset(v.experimental_ref() for v in inputs)
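
Context for the one-line fix: `_graph_mode_decorator` treats any variable that was
watched while the wrapped function ran, but is not among its explicit inputs, as an
implicit capture that the returned grad_fn must account for. With `primals` supplied,
the pre-fix code built `inputs` from the primals alone, so tensors arriving through
the remaining positional args were excluded from that subtraction and could be
misreported as implicit captures. The sketch below restates the updated test as a
standalone eager script; it is a sketch only, with the test harness
(run_in_graph_and_eager_modes, self.evaluate) simplified away, and `primals=` is the
internal keyword this file defines, not documented public API:

# Sketch distilled from the updated test; assumes TF at this revision
# with eager execution enabled.
from tensorflow.python.eager import backprop
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import variables

v = variables.Variable(1.)
w = variables.Variable(0.)


@custom_gradient.custom_gradient
def f(x):

  @custom_gradient.custom_gradient(primals=(x,))
  def g(dz):
    # dz reaches g through the positional args, not through primals,
    # which is exactly the case the variable check used to reject.

    def h(unused_ddz):  # second-order gradient for g
      return 2.2

    return x * 2.1 * dz, h

  return x + 1., g


with backprop.GradientTape(persistent=True) as t:
  with backprop.GradientTape(persistent=True) as tt:
    t.watch(v)
    tt.watch(v)
    # Mixing w into both the argument and the output gradient routes a
    # variable-dependent tensor through g's args rather than its primals.
    output = f(v + w)                                        # 2.0
    grad = tt.gradient(output, v, output_gradients=1. + w)  # 2.1
  grad_grad = t.gradient(grad, v)                            # 2.2

Because `inputs = primals + args` now covers the args as well, the tensors feeding
`dz` no longer survive the watched-variables subtraction, and the script above is
expected to run without the spurious variable-checking error the commit describes.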