The assertAll* statements already evaluate the arguments.

PiperOrigin-RevId: 320729457
Change-Id: Ie1564419eb5cf8f69d0e700c000074e248401dbc
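For context, the change summarized above relies on the fact that the assertAll* helpers on tf.test.TestCase (assertAllEqual, assertAllClose, and friends) evaluate tensor arguments themselves, so wrapping an argument in an explicit self.evaluate() call is redundant. Below is a minimal sketch of that behavior; the class and test names are illustrative only and are not part of the file that follows.

from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test


class AssertEvaluatesTensorsExample(test.TestCase):
  # Illustrative only: assertAll* helpers accept un-evaluated tensors.

  def testTensorArgumentIsEvaluated(self):
    x = constant_op.constant(2.0)
    # Redundant form: the tensor is evaluated explicitly before the check.
    self.assertAllEqual(self.evaluate(x), 2.0)
    # Equivalent simplified form: assertAllEqual evaluates the tensor itself.
    self.assertAllEqual(x, 2.0)


if __name__ == '__main__':
  test.main()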
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_grad
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.util import nest


_COS_DERIVATIVES = [math_ops.cos,
                    lambda x: -math_ops.sin(x),
                    lambda x: -math_ops.cos(x),
                    math_ops.sin,
                    math_ops.cos]


class FunctionGradientsTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(FunctionGradientsTest, self).setUp()
    cpus = config.list_physical_devices('CPU')
    # Set 4 virtual CPUs
    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration()
    ])

  def testGraphModeWithGradients(self):
    v = resource_variable_ops.ResourceVariable(1.0, name='v')

    @def_function.function
    def step():
      def inner():
        return v * v

      return backprop.implicit_grad(inner)()[0][0]

    self.assertAllEqual(step(), 2.0)

  def testGraphGradientVariable(self):
    with ops.Graph().as_default(), self.cached_session():
      v = variables.Variable(1.0)

      @def_function.function
      def f():
        return 2.0 * v

      node = f()
      grads, = gradients_impl.gradients(node, v)
      v.initializer.run()
      self.assertAllEqual(grads, 2.0)
      self.assertEqual(grads.shape, v.shape)

  def testSymbolicHigherOrder(self):
    @def_function.function
    def f(x, order):
      y = def_function.function(lambda: math_ops.cos(x))()
      for _ in range(order):
        y, = gradients_impl.gradients(y, [x])
      return y
    for order, expected in enumerate(_COS_DERIVATIVES):
      self.assertAllClose(
          expected(constant_op.constant(1.)),
          f(constant_op.constant(1.), order))

  @parameterized.parameters([dict(persistent=True),
                             dict(persistent=False)])
  def testSymbolicHigherOrderUnderTape(self, persistent):
    @def_function.function
    def f(x, order):
      with backprop.GradientTape(persistent=persistent) as tape:
        tape.watch(x)
        # Note that having a tape active, even if we don't use it, forces us
        # down a different function call path. Symbolic gradients should work
        # here too; correctness of tape gradients are tested elsewhere.
        y = def_function.function(lambda: math_ops.cos(x))()
      tape_dy = tape.gradient(y, x)
      for _ in range(order):
        y, = gradients_impl.gradients(y, [x])
      if order > 0:
        y1 = tape_dy
        for _ in range(order - 1):
          y1, = gradients_impl.gradients(y1, [x])
      else:
        y1 = y
      return y, y1
    for order, expected_f in enumerate(_COS_DERIVATIVES):
      expected = self.evaluate(expected_f(constant_op.constant(1.)))
      self.assertAllClose(
          (expected, expected),
          f(constant_op.constant(1.), order))

  def testIteratedGradientsNested(self):

    def _grad(f):
      def _grad_function(primal):
        with backprop.GradientTape() as tape:
          tape.watch(primal)
          primal_out = f(primal)
        return tape.gradient(primal_out, primal)
      return _grad_function

    @def_function.function
    def _forward(x):
      return math_ops.cos(x)

    f = _forward
    traced_f = def_function.function(f)
    one = constant_op.constant(1.)
    for expected in _COS_DERIVATIVES:
      self.assertAllClose(expected(one), f(one))
      self.assertAllClose(expected(one), traced_f(one))
      self.assertAllClose(expected(one), def_function.function(f)(one))
      f = _grad(f)
      traced_f = def_function.function(_grad(traced_f))

  def testIteratedGradientsNestedWithVariable(self):

    def _grad(f):
      def _grad_function():
        with backprop.GradientTape() as tape:
          primal_out = f()
        g, = tape.gradient(primal_out, tape.watched_variables())
        return g
      return _grad_function

    v = variables.Variable(2.)

    @def_function.function
    def _forward():
      return math_ops.cos(v)

    f = _forward

    two = constant_op.constant(2.)

    for expected in _COS_DERIVATIVES:
      self.assertAllClose(expected(two), f())
      self.assertAllClose(expected(two), def_function.function(f)())
      f = _grad(f)

  def testIteratedGradientsPersistent(self):

    @def_function.function
    def _forward(z):
      return math_ops.cos(z)

    f = _forward
    with backprop.GradientTape(persistent=True) as tape:
      start = constant_op.constant(1.)
      tape.watch(start)
      x = f(start)
      for expected in _COS_DERIVATIVES:
        self.assertAllClose(expected(start), x)
        x = tape.gradient(x, start)

  def testHigherOrderWithVariable(self):

    v = variables.Variable(1.)

    @def_function.function
    def _forward():
      return math_ops.cos(v)

    f = _forward
    with backprop.GradientTape(persistent=True) as tape:
      x = f()
      for expected in _COS_DERIVATIVES:
        self.assertAllClose(expected(constant_op.constant(1.)), x)
        x, = tape.gradient(x, tape.watched_variables())

  def testGradientsChained(self):

    @def_function.function
    def _forward(z):
      return math_ops.cos(z)

    f = _forward
    x = constant_op.constant(1.)
    with backprop.GradientTape() as t:
      t.watch(x)
      y = f(x)
    with backprop.GradientTape() as tt:
      doutputs = constant_op.constant(2.)
      tt.watch(doutputs)
      g = t.gradient(y, x, doutputs)
      self.assertAllClose(-2. * math_ops.sin(x), g)
    gg = tt.gradient(g, doutputs)
    # We're taking gradients with respect to doutputs, which is just a linear
    # function of the gradient.
    self.assertAllClose(-math_ops.sin(x), gg)

  def testSymGradGatherNd(self):
    with ops.Graph().as_default(), self.cached_session():

      @def_function.function
      def f(x):
        return array_ops.gather_nd(x, [[0]])

      c = constant_op.constant([[2.]])
      f_c = f(c)
      g, = gradients_impl.gradients(f_c, c)
      self.assertAllEqual(self.evaluate(g).values, [[1.0]])

  def testNoSymGradNestedDefun(self):

    @def_function.function
    def outer():

      @def_function.function
      def f(x):
        return array_ops.gather_nd(x, [[0]])

      c = constant_op.constant([[2.]])
      f_c = f(c)
      g, = gradients_impl.gradients(f_c, c)
      self.assertIsInstance(g, ops.IndexedSlices)

    outer()

  def testGraphFunctionWithGradients(self):
    v = resource_variable_ops.ResourceVariable(1.0, name='v')

    @def_function.function
    def step():
      def inner():
        return v * v

      return backprop.implicit_grad(inner)()[0][0]

    step_op = step.get_concrete_function()
    self.assertEqual(step_op.output_dtypes, dtypes.float32)
    self.assertEqual(step_op.output_shapes, tensor_shape.TensorShape([]))
    self.assertAllEqual(step_op(), 2.0)

  @test_util.run_in_graph_and_eager_modes()
  def testDefunCondGradient(self):

    @def_function.function
    def f(x):
      return control_flow_ops.cond(x > 0.5, lambda: 2 * x, lambda: 3 * x)

    with backprop.GradientTape() as t:
      x = constant_op.constant(1.0)
      t.watch(x)
      y = f(x)
    self.assertAllEqual(self.evaluate(t.gradient(y, x)), 2.0)

  @test_util.run_in_graph_and_eager_modes()
  def testGraphLoopGradient(self):

    @def_function.function
    def f(x):
      return control_flow_ops.while_loop(lambda _, i: i < 2,
                                         lambda x, i: (2*x, i + 1),
                                         [x, 0])[0]

    with backprop.GradientTape() as t:
      x = constant_op.constant(1.0)
      t.watch(x)
      y = f(x)
    self.assertAllEqual(self.evaluate(t.gradient(y, x)), 4.0)

  def testGraphLoopGradientInsideSession(self):
    with ops.Graph().as_default():
      n = constant_op.constant(2.0)
      x = array_ops.placeholder(dtypes.float32, shape=None)

      @def_function.function
      def f():
        c = lambda n: n < 10
        b = lambda n: n * x
        return control_flow_ops.while_loop(c, b, [n],
                                           [tensor_shape.unknown_shape()])

      l = f()
      dx = gradients_impl.gradients(l, [x])[0]

      with self.cached_session():
        self.assertEqual(dx.eval(feed_dict={x: 2.0}), 24.0)

  def testDefunDifferentiable(self):
    v = resource_variable_ops.ResourceVariable(1.0)

    @def_function.function
    def f():
      return v * v

    self.assertAllEqual(backprop.implicit_grad(f)()[0][0], 2.0)

  def testDefunCanBeDifferentiatedTwice(self):
    v = resource_variable_ops.ResourceVariable(1.0)

    @def_function.function
    def f():
      return v * v

    self.assertAllEqual(backprop.implicit_grad(f)()[0][0], 2.0)
    # Ensure that v is watched again.
    self.assertAllEqual(backprop.implicit_grad(f)()[0][0], 2.0)

  def testSymbolicGradientVariableNoneNotZerosLike(self):
    with ops.Graph().as_default():
      v = variables.Variable(1.0)

      @def_function.function
      def f(x, v):
        v.read_value()
        return x * x

      x = constant_op.constant(1.0)
      l = f(x, v)
      _, dv = gradients_impl.gradients(l, [x, v])
      with self.cached_session():
        v.initializer.run()
        self.assertEqual(dv, None)

  def testDefunCallBackprop(self):

    @def_function.function
    def f(x):
      return math_ops.add(x, x)

    @def_function.function
    def g(x):
      return backprop.gradients_function(f, [0])(x)[0]

    self.assertAllEqual(2, g(constant_op.constant(2.)))

  @test_util.run_v1_only('b/120545219')
  def testGraphModeEagerGradError(self):
    with context.graph_mode():
      def f():
        x = variable_scope.get_variable(
            'v', initializer=constant_op.constant(1.0))
        return x * constant_op.constant(2.0)

      with self.assertRaisesRegex(ValueError,
                                  'No trainable variables were accessed'):
        backprop.implicit_val_and_grad(f)()

  def testDefunCallBackpropUsingSameObjectForMultipleArguments(self):

    @def_function.function
    def g(x):
      return backprop.gradients_function(math_ops.multiply, [0, 1])(x, x)

    def np_g(x):
      return [d.numpy() for d in g(x)]

    x = constant_op.constant(1.)
    self.assertAllEqual([1., 1.], np_g(x))
    self.assertAllEqual([1., 1.], np_g(1.))

  def testGradientTensorConversionWithDefun(self):
    three = resource_variable_ops.ResourceVariable(3.0, name='v')

    @def_function.function
    def f(x):
      return math_ops.add(x, three)

    def g(x):
      return f(x)

    g = backprop.implicit_grad(g)(constant_op.constant(1.0))[0][0]
    self.assertAllEqual(g, 1.0)

  def testGradient(self):
    matmul = def_function.function(math_ops.matmul)

    def sq(x):
      return matmul(x, x, transpose_a=True)

    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    grad_t, = backprop.gradients_function(sq, [0])(t)
    self.assertAllEqual(grad_t, [[6, 6], [14, 14]])

  def testGradientInFunction(self):

    @def_function.function
    def f(x):
      return backprop.gradients_function(lambda y: y * y, [0])(x)[0]

    self.assertAllEqual(f(constant_op.constant(1.0)), 2.0)

  def testGradientOfGatherWithDefun(self):
    v = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])

    def sum_gather():
      return math_ops.reduce_sum(array_ops.gather(v, [1, 2]))

    grad_fn = backprop.implicit_grad(sum_gather)
    gradient = grad_fn()
    defun_grad_fn = backprop.implicit_grad(def_function.function(sum_gather))
    defun_gradient = defun_grad_fn()
    self.assertEqual(len(gradient), len(defun_gradient))

    gradient = gradient[0][0]
    defun_gradient = defun_gradient[0][0]
    self.assertAllEqual(gradient.values, defun_gradient.values)
    self.assertAllEqual(gradient.indices, defun_gradient.indices)
    self.assertAllEqual(gradient.dense_shape, defun_gradient.dense_shape)

  def testDifferentiableFunctionNoneOutputs(self):

    @def_function.function
    def my_function(x):
      return x, None

    def wrapper(x):
      return my_function(x)[0]

    g = backprop.gradients_function(wrapper, [0])(constant_op.constant(0.0))
    self.assertAllEqual(g[0], 1.)

    @def_function.function
    def foo(a):
      return None, a * a

    x = constant_op.constant(5.0)
    with backprop.GradientTape() as tp:
      tp.watch(x)
      none, r = foo(x)
    g = tp.gradient(r, x)

    self.assertIs(none, None)
    self.assertAllEqual(r, 25.0)
    self.assertAllEqual(g, 2 * 5.0)

  @test_util.run_in_graph_and_eager_modes
  def testNestedDifferentiableFunction(self):
    @def_function.function
    def inner_fn(a, b):
      return a * math_ops.add(a, b)

    @def_function.function
    def outer_fn(x):
      return inner_fn(x, 1.0)

    x = constant_op.constant(5.0)
    with backprop.GradientTape() as tp:
      tp.watch(x)
      result = outer_fn(x)
    grad = tp.gradient(result, x)

    self.assertAllEqual(grad, 2 * 5.0 + 1.0)

  @test_util.run_in_graph_and_eager_modes
  def testDeeplyNestedDifferentiableFunction(self):
    @def_function.function
    def inner_inner_fn(a, b):
      return math_ops.add(a, b)

    @def_function.function
    def inner_fn(a, b):
      return inner_inner_fn(a, b)

    @def_function.function
    def middle_fn(a, b):
      return a * inner_fn(a, b)

    @def_function.function
    def outer_fn(x):
      return middle_fn(x, 1.0)

    x = constant_op.constant(5.0)
    with backprop.GradientTape() as tp:
      tp.watch(x)
      result = outer_fn(x)
    grad = tp.gradient(result, x)

    self.assertAllEqual(grad, 2 * 5.0 + 1.0)

  @test_util.run_in_graph_and_eager_modes
  def testDeeplyNestedDifferentiableFunctionWithMultipleGradCalls(self):
    @def_function.function
    def inner_fn(a, b):
      return math_ops.add(a, b)

    @def_function.function
    def middle_fn(a, b):
      return math_ops.mul(a, inner_fn(a, b))

    @def_function.function
    def outer_fn(x):
      return middle_fn(x, 3.0)

    x = constant_op.constant(5.0)
    self.assertAllEqual(outer_fn(x), 5.0 * (5.0 + 3.0))

    with backprop.GradientTape() as tp:
      tp.watch(x)
      result = outer_fn(x)
    grad = tp.gradient(result, x)

    self.assertAllEqual(grad, 2 * 5.0 + 3.0)
    self.assertAllEqual(outer_fn(x), 5.0 * (5.0 + 3.0))
    self.assertAllEqual(middle_fn(3.0, x), 3.0 * (3.0 + 5.0))

    with backprop.GradientTape() as tp:
      tp.watch(x)
      result = outer_fn(x)
    grad = tp.gradient(result, x)

    self.assertAllEqual(grad, 2 * 5.0 + 3.0)

    y = constant_op.constant(4.0)
    with backprop.GradientTape() as tp:
      tp.watch(y)
      result = outer_fn(y)
    grad = tp.gradient(result, y)

    self.assertAllEqual(grad, 2 * 4.0 + 3.0)

    with backprop.GradientTape() as tp:
      tp.watch(y)
      result = inner_fn(y, y)
    grad = tp.gradient(result, y)

    self.assertAllEqual(grad, 2.0)

  @test_util.run_in_graph_and_eager_modes
  def testDeeplyNestedDifferentiableFunctionGradientTapeInDefun(self):
    @def_function.function
    def inner_inner_fn(a, b):
      return math_ops.add(a, b)

    @def_function.function
    def inner_fn(a, b):
      return inner_inner_fn(a, b)

    @def_function.function
    def middle_fn(a, b):
      return a * inner_fn(a, b)

    @def_function.function
    def outer_fn(x):
      with backprop.GradientTape() as tp:
        tp.watch(x)
        result = middle_fn(x, 1.0)
      grad = tp.gradient(result, x)
      return grad

    x = constant_op.constant(5.0)
    grad = outer_fn(x)
    self.assertAllEqual(grad, 2 * 5.0 + 1.0)

  @test_util.run_in_graph_and_eager_modes
  def testDeeplyNestedDifferentiableFunctionGradientTapeInNestedDefun(self):
    @def_function.function
    def inner_inner_fn(a, b):
      return math_ops.add(a, b)

    @def_function.function
    def inner_fn(a, b):
      return inner_inner_fn(a, b)

    @def_function.function
    def middle_fn(a, b):
      return a * inner_fn(a, b)

    @def_function.function
    def almost_outer_fn(x):
      with backprop.GradientTape() as tp:
        tp.watch(x)
        result = middle_fn(x, 1.0)
      grad = tp.gradient(result, x)
      return grad

    @def_function.function
    def outer_fn(x):
      return almost_outer_fn(x)

    x = constant_op.constant(5.0)
    grad = outer_fn(x)
    self.assertAllEqual(grad, 2 * 5.0 + 1.0)

  @test_util.run_in_graph_and_eager_modes
  def testDeeplyNestedDifferentiableFunctionGradientTapeInMultNestedDefun(self):
    @def_function.function
    def inner_inner_fn(a, b):
      return math_ops.add(a, b)

    @def_function.function
    def inner_fn(a, b):
      return inner_inner_fn(a, b)

    @def_function.function
    def middle_fn(a, b):
      return a * inner_fn(a, b)

    @def_function.function
    def almost_outer_fn(x):
      with backprop.GradientTape() as tp:
        tp.watch(x)
        result = middle_fn(x, 1.0)
      grad = tp.gradient(result, x)
      return grad

    @def_function.function
    def outer_fn(x):
      return almost_outer_fn(x)

    @def_function.function
    def outer_outer_fn(x):
      return outer_fn(x)

    x = constant_op.constant(5.0)
    grad = outer_outer_fn(x)
    self.assertAllEqual(grad, 2 * 5.0 + 1.0)

  @test_util.run_in_graph_and_eager_modes
  def testDeeplyNestedDifferentiableFunctionTFGradientInDefun(self):
    @def_function.function
    def inner_inner_fn(a, b):
      return math_ops.add(a, b)

    @def_function.function
    def inner_fn(a, b):
      return inner_inner_fn(a, b)

    @def_function.function
    def middle_fn(a, b):
      return a * inner_fn(a, b)

    @def_function.function
    def outer_fn(x):
      result = middle_fn(x, 1.0)
      return gradients_impl.gradients(result, [x])[0]

    x = constant_op.constant(5.0)
    grad = outer_fn(x)
    self.assertAllEqual(grad, 2 * 5.0 + 1.0)

  @test_util.run_in_graph_and_eager_modes
  def testDeeplyNestedDifferentiableFunctionTFGradientInNestedDefun(self):
    @def_function.function
    def inner_inner_fn(a, b):
      return math_ops.add(a, b)

    @def_function.function
    def inner_fn(a, b):
      return inner_inner_fn(a, b)

    @def_function.function
    def middle_fn(a, b):
      return a * inner_fn(a, b)

    @def_function.function
    def almost_outer_fn(x):
      result = middle_fn(x, 1.0)
      return gradients_impl.gradients(result, [x])[0]

    @def_function.function
    def outer_fn(x):
      return almost_outer_fn(x)

    x = constant_op.constant(5.0)
    grad = outer_fn(x)
    self.assertAllEqual(grad, 2 * 5.0 + 1.0)

  @test_util.run_in_graph_and_eager_modes
  def testDeeplyNestedDifferentiableFunctionTFGradientInMultNestedDefun(self):
    @def_function.function
    def inner_inner_fn(a, b):
      return math_ops.add(a, b)

    @def_function.function
    def inner_fn(a, b):
      return inner_inner_fn(a, b)

    @def_function.function
    def middle_fn(a, b):
      return a * inner_fn(a, b)

    @def_function.function
    def almost_outer_fn(x):
      result = middle_fn(x, 1.0)
      return gradients_impl.gradients(result, [x])[0]

    @def_function.function
    def outer_fn(x):
      return almost_outer_fn(x)

    @def_function.function
    def outer_outer_fn(x):
      return outer_fn(x)

    x = constant_op.constant(5.0)
    grad = outer_outer_fn(x)
    self.assertAllEqual(grad, 2 * 5.0 + 1.0)

  def testDeeplyNestedDifferentiableFunctionWithVariable(self):
    var = variables.Variable(constant_op.constant(1.0))

    @def_function.function
    def inner_fn(a, b):
      return math_ops.add(a, b)

    @def_function.function
    def middle_fn(a, b):
      return a * inner_fn(a, b)

    @def_function.function
    def outer_fn(x):
      return middle_fn(x, var)

    x = constant_op.constant(5.0)
    with backprop.GradientTape() as tp:
      tp.watch(x)
      result = outer_fn(x)
    grad = tp.gradient(result, x)

    self.assertAllEqual(grad, 2 * 5.0 + 1.0)

  def testDeeplyNestedDifferentiableFunctionWithVariableMultipleGradCalls(self):
    v = variables.Variable(constant_op.constant(3.0))

    @def_function.function
    def inner_fn(a, b):
      return math_ops.add(a, b)

    @def_function.function
    def middle_fn(a, b):
      return math_ops.mul(a, inner_fn(a, b))

    @def_function.function
    def outer_fn(x):
      return middle_fn(x, v)

    x = constant_op.constant(5.0)
    self.assertAllEqual(outer_fn(x), 5.0 * (5.0 + 3.0))

    with backprop.GradientTape() as tp:
      tp.watch(x)
      result = outer_fn(x)
    grad = tp.gradient(result, x)

    self.assertAllEqual(grad, 2 * 5.0 + 3.0)
    self.assertAllEqual(outer_fn(x), 5.0 * (5.0 + 3.0))
    self.assertAllEqual(middle_fn(v, x), 3.0 * (3.0 + 5.0))

    with backprop.GradientTape() as tp:
      tp.watch(x)
      result = outer_fn(x)
    grad = tp.gradient(result, x)

    self.assertAllEqual(grad, 2 * 5.0 + 3.0)

    y = constant_op.constant(4.0)
    with backprop.GradientTape() as tp:
      tp.watch(y)
      result = outer_fn(y)
    grad = tp.gradient(result, y)

    self.assertAllEqual(grad, 2 * 4.0 + 3.0)

    v.assign(constant_op.constant(1.5))
    with backprop.GradientTape() as tp:
      tp.watch(y)
      result = outer_fn(y)
    grad = tp.gradient(result, y)

    self.assertAllEqual(grad, 2 * 4.0 + 1.5)

    with backprop.GradientTape() as tp:
      tp.watch(y)
      result = inner_fn(y, v)
    grad = tp.gradient(result, y)

    self.assertAllEqual(grad, 1.0)

  def testDeeplyNestedDifferentiableFunctionWithVariableMultipleTFGrads(self):
    with context.graph_mode(), self.cached_session():
      v = resource_variable_ops.ResourceVariable(3.0)
      v.initializer.run()

      @def_function.function
      def inner_fn(a, b):
        return math_ops.add(a, b)

      @def_function.function
      def middle_fn(a, b):
        return math_ops.mul(a, inner_fn(a, b))

      @def_function.function
      def outer_fn(x):
        return middle_fn(x, v)

      x = constant_op.constant(5.0)
      self.assertAllEqual(outer_fn(x), 5.0 * (5.0 + 3.0))

      grad, = gradients_impl.gradients(outer_fn(x), x)

      self.assertAllEqual(grad, 2 * 5.0 + 3.0)
      self.assertAllEqual(outer_fn(x), 5.0 * (5.0 + 3.0))
      self.assertAllEqual(middle_fn(v, x), 3.0 * (3.0 + 5.0))

      grad, = gradients_impl.gradients(outer_fn(x), x)

      self.assertAllEqual(grad, 2 * 5.0 + 3.0)

      y = constant_op.constant(4.0)
      grad, = gradients_impl.gradients(outer_fn(y), y)
      self.assertAllEqual(grad, 2 * 4.0 + 3.0)

      self.evaluate(v.assign(constant_op.constant(1.5)))
      grad, = gradients_impl.gradients(outer_fn(y), y)

      self.assertAllEqual(grad, 2 * 4.0 + 1.5)

      grad, = gradients_impl.gradients(inner_fn(y, v), y)
      self.assertAllEqual(grad, 1.0)

  def testNestedDifferentiableFunctionNoneOutputs(self):
    @def_function.function
    def foo(a, b):
      return None, a * math_ops.add(a, b), None, 2*a

    @def_function.function
    def bar(x):
      return foo(x, 1.0)

    x = constant_op.constant(5.0)
    with backprop.GradientTape(persistent=True) as tp:
      tp.watch(x)
      none1, r1, none2, r2 = bar(x)
    g1 = tp.gradient(r1, x)
    g2 = tp.gradient(r2, x)

    self.assertAllEqual(r1, 30.0)
    self.assertAllEqual(r2, 10.0)
    self.assertIs(none1, None)
    self.assertIs(none2, None)
    self.assertAllEqual(g1, 2 * 5.0 + 1.0)
    self.assertAllEqual(g2, 2.0)

  def testGradientWithKeywordArguments(self):
    matmul = def_function.function(math_ops.matmul)

    def sq(x):
      return matmul(a=x, b=x, transpose_a=True)

    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    grad_t, = backprop.gradients_function(sq, [0])(t)
    self.assertAllEqual(grad_t, [[6, 6], [14, 14]])

    with backprop.GradientTape(persistent=True) as tape:
      tape.watch(t)
      one = matmul(t, b=t, transpose_a=True)
      two = matmul(b=t, a=t, transpose_a=True)
      three = matmul(a=t, b=t, transpose_a=True)

    for output in [one, two, three]:
      self.assertAllEqual(tape.gradient(output, t), [[6, 6], [14, 14]])

  def testGradientInFunctionWithKeywordArguments(self):

    @def_function.function
    def f(x):
      return backprop.gradients_function(lambda y: y * y, [0])(x)[0]

    self.assertAllEqual(f(x=constant_op.constant(1.0)), 2.0)

  def testFunctionHasNoSecondOrderGradient(self):

    # This test needs nn_grad imported. We could just disable the lint error,
    # but this way if the test is deleted we'll know the import isn't needed.
    _ = nn_grad

    v = variables.Variable(1.)

    @def_function.function
    def f(labels, logits):
      return def_function.function(
          nn_ops.sparse_softmax_cross_entropy_with_logits)(
              labels=labels, logits=logits + v)

    @def_function.function
    def f_grad():
      with backprop.GradientTape() as tape:
        logits = constant_op.constant([1., 2.])
        tape.watch(logits)
        out = f(constant_op.constant(1), logits)
      return tape.gradient(out, logits)
    # Mainly we want to check that the function builds despite
    # sparse_softmax_cross_entropy_with_logits not having a second-order
    # gradient defined.
    self.assertAllEqual([2], f_grad().shape)

  @test_util.run_in_graph_and_eager_modes
  def testBackwardNone(self):
    model = variables.Variable(1.0, name='model')
    count = variables.Variable(0)

    @function.defun
    def forward_pass(value):
      count.assign_add(1)
      residuals = value - model
      loss = 0.5 * math_ops.reduce_mean(math_ops.pow(residuals, 2))
      # Note: count is an integer, so its doutput will be None
      return loss, count

    def reduce_fn(x):
      if context.executing_eagerly():
        with backprop.GradientTape() as t:
          loss, count = forward_pass(x)
        return t.gradient(loss, model), count
      loss, count = forward_pass(x)
      grad_only = gradients_impl.gradients(loss, model)
      return grad_only, count

    g, _ = reduce_fn(constant_op.constant([7.0]))

    self.evaluate(variables.global_variables_initializer())
    self.assertAllEqual(nest.flatten(self.evaluate(g)), [-6.0])


if __name__ == '__main__':
  ops.enable_eager_execution()
  test.main()