Forwardprop: fix variables inside functions, test with layers

We were trying to use the EagerTensor variable handle inside the function. This change stops doing that, and also keeps the accumulator from segfaulting if it were handed such a handle anyway.
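For context, a minimal sketch of the pattern this commit fixes, mirroring the new testVariableWatchedFunction test below (ForwardGradientAccumulator was an internal, experimental API at this revision):

import tensorflow as tf
from tensorflow.python.eager import forwardprop

v = tf.Variable([1., 2., 3.])

@tf.function
def compute_jvp():
  with forwardprop.ForwardGradientAccumulator() as acc:
    # Inside a tf.function, the variable's EagerTensor handle has to be
    # resolved to the captured handle via convert_to_tensor; watching the
    # raw eager handle is what this commit stops doing.
    acc.watch(v, tf.constant([.1, -.2, .3]))
    x = v * 2.
  return acc.jvp(x)

compute_jvp()  # Expected tangent of x: [.2, -.4, .6]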

A couple of miscellaneous fixes to the test infrastructure to support the new tests; e.g., the gradient checker didn't handle unconnected gradients.
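The unconnected-gradients issue arises when an output does not depend on a watched input: the gradient comes back as None rather than zeros. A small illustration of the case the checker now tolerates (hypothetical repro, not part of this commit):

import tensorflow as tf

def f(x):
  # The output ignores x, so the gradient with respect to x is unconnected.
  return tf.constant([1., 2.])

x = tf.constant([3.])
with tf.GradientTape() as tape:
  tape.watch(x)
  y = f(x)
print(tape.gradient(y, x))  # None; the checker's Jacobian loop now skips this.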

PiperOrigin-RevId: 266958407
Allen Lavoie 2019-09-03 10:17:38 -07:00 committed by TensorFlower Gardener
parent 619ad608cb
commit 7403920848
4 changed files with 105 additions and 4 deletions

tensorflow/c/eager/tape.h

@@ -922,6 +922,11 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
   for (const TapeTensor& output_tensor : output_tensors) {
     // Ownership of `aid` transferred to CallBackwardFunction below.
     Gradient* aid = vspace_.Ones(output_tensor);
+    if (TF_PREDICT_FALSE(aid == nullptr)) {
+      return tensorflow::errors::Internal(
+          "Failed to create ones tensor for tensor ", output_tensor.GetID(),
+          " with dtype ", output_tensor.GetDType());
+    }
     forwardprop_aids.push_back(aid);
     int64 aid_id = vspace_.TensorId(aid);
     sources.push_back(aid_id);

tensorflow/python/eager/forwardprop.py

@@ -187,7 +187,9 @@ class ForwardGradientAccumulator(object):
             "floating (e.g. tf.float32), got %r", 5, t.dtype)
       g = ops.convert_to_tensor(g, dtype=t.dtype)
       if hasattr(t, "handle"):
-        t = t.handle
+        # Run convert_to_tensor to get the captured handle from whichever
+        # function we're running if necessary.
+        t = ops.convert_to_tensor(t.handle)
       pywrap_tensorflow.TFE_Py_ForwardAccumulatorWatch(self._accumulator, t, g)

   def jvp(self, target):
@@ -209,7 +211,7 @@ class ForwardGradientAccumulator(object):
       raise ValueError("Called jvp() without first tracing anything.")
     def _fetch_jvp(tensor):
       if hasattr(tensor, "handle"):
-        tensor = tensor.handle
+        tensor = ops.convert_to_tensor(tensor.handle)
       return pywrap_tensorflow.TFE_Py_ForwardAccumulatorJVP(
           self._accumulator, tensor)
     return nest.map_structure(_fetch_jvp, target)
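For reference, the basic eager usage that watch() and jvp() implement, as a sketch against the internal API at this revision:

import tensorflow as tf
from tensorflow.python.eager import forwardprop

x = tf.constant([1., 2.])
with forwardprop.ForwardGradientAccumulator() as acc:
  # The tangent [1., 0.] selects the derivative with respect to x[0].
  acc.watch(x, tf.constant([1., 0.]))
  y = x ** 2.
print(acc.jvp(y))  # [2., 0.]: the Jacobian of x**2 applied to the tangent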

tensorflow/python/eager/forwardprop_test.py

@@ -31,8 +31,13 @@ from tensorflow.python.eager import forwardprop
 from tensorflow.python.eager import forwardprop_util
 from tensorflow.python.eager import tape as tape_lib
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
+from tensorflow.python.keras.layers import convolutional
+from tensorflow.python.keras.layers import core
+from tensorflow.python.keras.layers import normalization_v2
+from tensorflow.python.module import module
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import gradient_checker_v2
@@ -59,7 +64,12 @@ def _jvp(f, primals, tangents):
   with forwardprop.ForwardGradientAccumulator() as acc:
     acc.watch(primals, tangents)
     primals_out = f(*primals)
-  return primals_out, acc.jvp(primals_out)
+  tangents_out = acc.jvp(primals_out)
+  if primals_out is not None and tangents_out is None:
+    # TODO(allenl): Support UnconnectedGradients as an accumulator constructor
+    # argument.
+    return primals_out, array_ops.zeros_like(primals_out)
+  return primals_out, tangents_out


 def _jacfwd(f, primals):
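A usage sketch for this helper (math_ops is already imported by the test file):

primals = [constant_op.constant([1., 2.])]
tangents = [constant_op.constant([1., 1.])]
# d/dt sin(x + t * v) at t = 0 is cos(x) * v.
primals_out, tangents_out = _jvp(math_ops.sin, primals, tangents)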
@@ -288,6 +298,66 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
   def testElementwiseNNOps(self, value, op_fn):
     _test_gradients(self, op_fn, [constant_op.constant(value)], order=3)

+  @parameterized.named_parameters(
+      [("Dense", [[0.1]], functools.partial(core.Dense, 5)),
+       ("Conv2D",
+        np.reshape(np.arange(start=-1., stop=1., step=2. / (1 * 2 * 4 * 4)),
+                   [1, 2, 4, 4]),
+        functools.partial(convolutional.Conv2D, 2, 2),
+        1e-4),
+       ("BatchNorm", [[0.1], [0.2], [-0.3]],
+        normalization_v2.BatchNormalization)])
+  def testKerasLayers(self, value, op_fn, atol=1e-6):
+    layer = op_fn()
+    input_value = constant_op.constant(value, dtype=dtypes.float32)
+    _test_gradients(
+        self, layer, [input_value], atol=atol,
+        # These are linear, so second-order is pretty boring.
+        order=2)
+
+  @parameterized.named_parameters(
+      [("Function", def_function.function),
+       ("NoFunction", lambda f: f)])
+  def testVariablesHVP(self, decorator):
+
+    class _Model(module.Module):
+
+      def __init__(self):
+        self._first_dense = core.Dense(18)
+        self._conv = convolutional.Conv2D(2, 2)
+        self._norm = normalization_v2.BatchNormalization()
+        self._second_dense = core.Dense(1)
+
+      def __call__(self, x):
+        x = self._first_dense(x)
+        x = nn_ops.relu(x)
+        x = self._norm(x)
+        x = nn_ops.relu(self._conv(array_ops.reshape(x, [-1, 2, 3, 3])))
+        return self._second_dense(x)
+
+    model = _Model()
+
+    def _loss():
+      input_value = constant_op.constant([[-0.5, 1.], [0.5, -1.]])
+      target = constant_op.constant([[-1.], [2.]])
+      return math_ops.reduce_sum((model(input_value) - target) ** 2.)
+
+    @decorator
+    def _compute_hvps():
+      with backprop.GradientTape() as tape:
+        loss = _loss()
+      vector = tape.gradient(loss, model.trainable_variables)
+      variable_input_fn = lambda unused_variables: _loss()
+      forward_over_back_hvp = _hvp(
+          variable_input_fn, [model.trainable_variables], [vector])
+      with backprop.GradientTape(persistent=True) as tape:
+        tape.watch(model.trainable_variables)
+        loss = _loss()
+        first_grads = tape.gradient(loss, model.trainable_variables)
+      back_over_back_hvp = tape.gradient(
+          first_grads, model.trainable_variables, output_gradients=vector)
+      return forward_over_back_hvp, back_over_back_hvp
+    self.assertAllClose(*_compute_hvps(), rtol=1e-5, atol=1e-5)

   def testPushPopAccumulatorState(self):
     # Note that this example is somewhat contrived. push_forwardprop_state is
     # probably only useful in practice for building functions that compute jvps
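testVariablesHVP above relies on a _hvp helper that is outside this diff. A sketch consistent with how it is called here (forward-over-back: a forward accumulator wrapped around a backward tape):

def _hvp(f, primals, tangents):
  # Hessian-vector product: the forward-mode JVP of the reverse-mode gradient.
  with forwardprop.ForwardGradientAccumulator() as acc:
    acc.watch(primals, tangents)
    with backprop.GradientTape() as tape:
      tape.watch(primals)
      primals_out = f(*primals)
    grads = tape.gradient(primals_out, primals)
  return acc.jvp(grads)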
@@ -571,6 +641,30 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
       x2 = v + .1
       self.assertAllClose([.1, -.2, .3], acc.jvp(x2))

+  @test_util.assert_no_new_pyobjects_executing_eagerly
+  def testVariableWatchedFunction(self):
+
+    class _Model(module.Module):
+
+      def __init__(self):
+        self._v = None
+
+      @def_function.function
+      def compute_jvps(self):
+        if self._v is None:
+          self._v = variables.Variable([1., 2., 3.])
+        with forwardprop.ForwardGradientAccumulator() as acc:
+          acc.watch(self._v, constant_op.constant([.1, -.2, .3]))
+          x = self._v * 2.
+          x2 = self._v + .1
+          return acc.jvp((self._v, x, x2))
+
+    model = _Model()
+    v_jvp, x_jvp, x2_jvp = model.compute_jvps()
+    self.assertAllClose([.1, -.2, .3], v_jvp)
+    self.assertAllClose([.2, -.4, .6], x_jvp)
+    self.assertAllClose([.1, -.2, .3], x2_jvp)
+
   # NOTE: assert_no_new_pyobjects_executing_eagerly fails flakily on this
   # test... could be something wrong with the test decorator, or some sort of
   # nondeterministic caching.

tensorflow/python/ops/gradient_checker_v2.py

@@ -175,7 +175,7 @@ def _compute_theoretical_jacobian(f, y_shape, y_dtype, xs, param):
         r_begin = i * x_val_size
         r_end = r_begin + x_val_size
         jacobian[r_begin:r_end, col] += v.flat
-    else:
+    elif grad is not None:
       jacobian[:, col] = grad.ravel().view(jacobian.dtype)

   # If the output is empty, run the gradients at least once and make sure
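This path is exercised through gradient_checker_v2's public compute_gradient entry point, roughly as the test helper above uses it (sketch):

theoretical, numerical = gradient_checker_v2.compute_gradient(f, primals)
# With an unconnected input, the corresponding theoretical Jacobian column
# now simply stays zero instead of failing when grad is None.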